In [1]:
import os
import random
import time
import gc
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import classification_report
from pytorch_metric_learning import distances
from pytorch_metric_learning.losses import ArcFaceLoss

from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict


# --------------------------------------------------
# ◆ WeightedArcFaceLoss (ArcFaceLoss の継承)
# --------------------------------------------------
class WeightedArcFaceLoss(ArcFaceLoss):
    """ArcFaceLoss whose per-sample cross-entropy criterion carries class weights.

    After the parent constructor has run, ``self.cross_entropy`` is replaced by
    ``nn.CrossEntropyLoss(weight=class_weights, reduction="none")``.

    NOTE(review): this presumably overrides the criterion the parent loss uses
    internally -- confirm against the installed pytorch_metric_learning version.

    Args:
        num_classes: Number of classes, forwarded to ``ArcFaceLoss``.
        embedding_size: Embedding dimensionality, forwarded to ``ArcFaceLoss``.
        margin: Angular margin, forwarded to ``ArcFaceLoss``.
        scale: Logit scale, forwarded to ``ArcFaceLoss``.
        class_weights: Optional per-class weights (tensor or array-like);
            non-tensor values are converted to a float32 tensor.
        **kwargs: Any further keyword arguments accepted by ``ArcFaceLoss``.
    """

    def __init__(
            self,
            num_classes,
            embedding_size,
            margin=28.6,
            scale=64,
            class_weights=None,
            **kwargs
    ):
        super().__init__(
            num_classes=num_classes,
            embedding_size=embedding_size,
            margin=margin,
            scale=scale,
            **kwargs
        )
        needs_conversion = (
            class_weights is not None
            and not isinstance(class_weights, torch.Tensor)
        )
        weight_tensor = (
            torch.tensor(class_weights, dtype=torch.float32)
            if needs_conversion
            else class_weights
        )
        # reduction="none" keeps one loss value per sample.
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight_tensor, reduction="none")


def map_pids_to_0_known(df, known_pids, label_col="pid", new_label_col="pid_trans"):
    """Map known person IDs to contiguous labels 1..K and everything else to 0.

    The distinct values of ``known_pids`` are sorted and numbered from 1, so
    e.g. ``known_pids=[10, 11, 15]`` yields ``10 -> 1, 11 -> 2, 15 -> 3``.
    Duplicates in ``known_pids`` are ignored; previously they left gaps in the
    label range.

    Args:
        df: DataFrame containing ``label_col``. It is modified in place.
        known_pids: Iterable of person IDs that should receive a non-zero label.
        label_col: Name of the column holding the original IDs.
        new_label_col: Name of the column that receives the mapped labels.

    Returns:
        The same ``df`` object, with ``new_label_col`` added or overwritten.
    """
    # set() removes duplicates so the labels stay contiguous.
    sorted_kps = sorted(set(known_pids))
    pid_map = {p: i + 1 for i, p in enumerate(sorted_kps)}
    df[new_label_col] = df[label_col].apply(lambda x: pid_map.get(x, 0))
    return df


def downsample_label0(df, label_col="pid", ratio=30, min_count=128, random_state=42):
    """Cap the number of label-0 ("unknown") rows relative to the known classes.

    The cap is ``max(largest_known_class_count * ratio, min_count)``. When
    there are more label-0 rows than the cap, a random subset of that size is
    kept; all rows with a non-zero label are always kept.

    Args:
        df: DataFrame containing ``label_col``.
        label_col: Column holding the labels (0 means "unknown").
        ratio: Multiplier applied to the largest known-class count.
        min_count: Lower bound for the cap.
        random_state: Seed passed to ``DataFrame.sample``.

    Returns:
        A new DataFrame with the (possibly sampled) label-0 rows first,
        followed by all other rows, with a fresh 0..N-1 index.
    """
    label0_df = df[df[label_col] == 0]
    others_df = df[df[label_col] != 0]
    counts = others_df[label_col].value_counts()
    largest_known = counts.max() if not counts.empty else 0
    # int() turns the numpy scalar into a plain Python int for sample().
    max_count = int(max(largest_known * ratio, min_count))
    print("各ラベルのサンプル数:")
    print(df[label_col].value_counts())
    print(f"→ ラベル0を {max_count} 枚に合わせます")
    if len(label0_df) > max_count:
        label0_df = label0_df.sample(n=max_count, random_state=random_state)
    return pd.concat([label0_df, others_df], ignore_index=True)


def set_random_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed_value: Seed applied to every generator (also exported as
            ``PYTHONHASHSEED``).
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    # Same seeding calls, in the same order, expressed as one loop.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def extract_features_and_labels(df):
    """Return ``(X, y)`` arrays extracted from ``df``.

    Features come from the ``embedding`` column (one array per row, stacked
    into a 2-D array) when it exists, otherwise from every column whose name
    starts with ``dim``. Labels come from ``pid_trans`` when present,
    otherwise from ``pid``.
    """
    columns = df.columns
    if "embedding" not in columns:
        dim_columns = [name for name in columns if name.startswith("dim")]
        features = df[dim_columns].values
    else:
        features = np.stack(df["embedding"].values)

    label_column = "pid_trans" if "pid_trans" in columns else "pid"
    return features, df[label_column].values


class Net128(nn.Module):
    """MLP that maps 2048-dimensional inputs to 128-dimensional embeddings.

    Layer stack: 2048 -> 1024 (ReLU, BatchNorm) -> 512 (ReLU, Dropout 0.2)
    -> 384 (ReLU, Dropout 0.2) -> 128 (ReLU).

    NOTE(review): the final ReLU makes every embedding component
    non-negative, which restricts cosine similarities between embeddings to
    [0, 1] -- confirm this is intended for an angular-margin loss.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so seeded initialisation
        # is unchanged.
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.bn1 = nn.BatchNorm1d(num_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.dropout_1 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(in_features=512, out_features=384)
        self.dropout_2 = nn.Dropout(p=0.2)
        self.fc4 = nn.Linear(in_features=384, out_features=128)

    def forward(self, x):
        """Compute the 128-dimensional embedding for a batch ``x``."""
        hidden = self.bn1(F.relu(self.fc1(x)))
        hidden = self.dropout_1(F.relu(self.fc2(hidden)))
        hidden = self.dropout_2(F.relu(self.fc3(hidden)))
        return F.relu(self.fc4(hidden))


def train_one_epoch(model, loss_func, device, loader, optimizer):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module producing embeddings from a feature batch.
        loss_func: Callable ``(embeddings, labels) -> loss``. A dict result is
            reduced via ``result["loss"]["losses"].mean()``; anything else is
            used as the loss tensor directly.
        device: Device the batches are moved to.
        loader: Iterable of ``(features, labels)`` batches.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The average of the per-batch loss values, or ``nan`` when ``loader``
        yielded no batches (this previously raised ``ZeroDivisionError``,
        e.g. with ``drop_last=True`` and fewer samples than one batch).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work
    for features, labels in loader:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(features)
        loss_output = loss_func(embeddings, labels)
        if isinstance(loss_output, dict):
            loss_val = loss_output["loss"]["losses"].mean()
        else:
            loss_val = loss_output
        loss_val.backward()
        optimizer.step()
        total_loss += loss_val.item()
        num_batches += 1
    if num_batches == 0:
        # Nothing was trained; report "no loss" instead of dividing by zero.
        return float("nan")
    return total_loss / num_batches


def train_net128_weighted_arcface(X_train, y_train, device, epochs=100, lr=1e-6, batch_size=128):
    """Train a fresh ``Net128`` with ``WeightedArcFaceLoss`` and return it.

    Class weights are ``1 / (count + 1e-5)`` for every label present in
    ``y_train`` and 0 for absent label indices. The model and the loss share
    a single Adam optimizer, and ``ReduceLROnPlateau`` (factor 0.1,
    patience 5) is driven by the mean training loss of each epoch.

    Args:
        X_train: Feature array convertible to a float32 tensor.
        y_train: Non-negative integer labels, one per row of ``X_train``.
        device: Device used for the model, the loss and the batches.
        epochs: Number of passes over the data.
        lr: Initial Adam learning rate.
        batch_size: Mini-batch size; incomplete final batches are dropped.

    Returns:
        The trained ``Net128`` instance (left in training mode).

    Raises:
        ValueError: If ``y_train`` is empty.
    """
    from torch.utils.data import DataLoader, TensorDataset

    y_train = np.asarray(y_train)
    if y_train.size == 0:
        raise ValueError("y_train is empty; cannot train without labelled samples")

    # One pass yields both the label set and the per-label counts.
    unique_labels, label_counts = np.unique(y_train, return_counts=True)
    unique_labels = unique_labels.astype(np.int64)
    # Plain int: the value sizes arrays and the loss' weight matrix.
    num_classes = int(unique_labels.max()) + 1

    dataset = TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long)
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    model = Net128().to(device)

    # Inverse-frequency weights; absent label indices keep weight 0.
    class_weights = np.zeros(num_classes, dtype=np.float32)
    class_weights[unique_labels] = 1.0 / (label_counts + 1e-5)

    distance = distances.CosineSimilarity()
    loss_func = WeightedArcFaceLoss(
        num_classes=num_classes,
        embedding_size=128,
        margin=28.6,
        scale=64,
        class_weights=class_weights,
        distance=distance
    ).to(device)

    # The loss owns trainable parameters too, so both are optimised together.
    optimizer = optim.Adam(
        list(model.parameters()) + list(loss_func.parameters()), lr=lr
    )
    # `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
    # reductions are reported manually below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5
    )

    for epoch in range(1, epochs + 1):
        ep_start = time.time()
        avg_loss = train_one_epoch(model, loss_func, device, loader, optimizer)
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"  [Epoch {epoch}] reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")
        print(f"  [Epoch {epoch}] train_loss={avg_loss:.4f}, time={(time.time() - ep_start):.2f} sec")

    return model


def get_embeddings_128(model, X_input, device, batch_size=128):
    """Run ``model`` over ``X_input`` in batches and return the stacked outputs.

    The model is switched to eval mode and evaluated without gradient
    tracking; the per-batch outputs are concatenated along axis 0 as a NumPy
    array.
    """
    from torch.utils.data import DataLoader, TensorDataset

    model.eval()
    inputs = torch.tensor(X_input, dtype=torch.float32)
    batches = DataLoader(TensorDataset(inputs), batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        outputs = [model(chunk.to(device)).cpu().numpy() for (chunk,) in batches]
    return np.concatenate(outputs, axis=0)


def compute_average_tracklet_similarity(gall_emb, label_embs):
    """Return the mean cosine similarity over all (label, gallery) pairs.

    The mean of the full similarity matrix equals the dot product of the two
    mean unit vectors, so the (N_label, N_gall) matrix is never built:
    ``mean_ij(a_i . b_j) == mean_i(a_i) . mean_j(b_j)`` for L2-normalised
    rows. All-zero rows are treated as having similarity 0 to everything.

    Args:
        gall_emb: 2-D array of embeddings of one gallery tracklet.
        label_embs: 2-D array of embeddings belonging to one label.

    Returns:
        The average cosine similarity as a Python float.

    Raises:
        ValueError: If an input is not a non-empty, finite 2-D array or the
            feature dimensions differ.
    """
    gall = np.asarray(gall_emb, dtype=np.float64)
    labs = np.asarray(label_embs, dtype=np.float64)
    for arg_name, arr in (("gall_emb", gall), ("label_embs", labs)):
        if arr.ndim != 2 or arr.size == 0:
            raise ValueError(f"{arg_name} must be a non-empty 2-D array, got shape {arr.shape}")
        if not np.isfinite(arr).all():
            raise ValueError(f"{arg_name} contains NaN or infinite values")
    if gall.shape[1] != labs.shape[1]:
        raise ValueError(
            f"Feature dimensions differ: gall_emb has {gall.shape[1]}, label_embs has {labs.shape[1]}"
        )

    def _mean_unit_vector(rows):
        """Average of the L2-normalised rows (all-zero rows stay zero)."""
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0  # avoid 0/0 for all-zero rows
        return (rows / norms).mean(axis=0)

    return float(np.dot(_mean_unit_vector(labs), _mean_unit_vector(gall)))


def predict_label_for_gall_tracklet(gall_emb, label_embs_dict):
    """Pick the label whose embeddings are on average most similar to a tracklet.

    Args:
        gall_emb: Embeddings of one gallery tracklet.
        label_embs_dict: Mapping ``label -> embeddings`` of that label.

    Returns:
        ``(best_label, best_score)``. With an empty mapping the result is
        ``(0, -1e9)``; on ties the first label in iteration order wins.
    """
    winner = (0, -1e9)
    for candidate, candidate_embs in label_embs_dict.items():
        score = compute_average_tracklet_similarity(gall_emb, candidate_embs)
        # Only a strictly better score replaces the current winner.
        if not score > winner[1]:
            continue
        winner = (candidate, score)
    return winner


def build_label_embs_dict(X_train, y_train, model, device, batch_size=128):
    """Embed ``X_train`` with ``model`` and group the embeddings by label.

    Returns:
        Dict mapping each label in ``y_train`` (in order of first appearance)
        to a 2-D array holding the embeddings of that label's samples.
    """
    embeddings = get_embeddings_128(model, X_train, device=device, batch_size=batch_size)
    grouped = {}
    for vector, label in zip(embeddings, y_train):
        grouped.setdefault(label, []).append(vector)
    return {label: np.stack(rows, axis=0) for label, rows in grouped.items()}


def _list_pkl_files(folder):
    """Return the sorted full paths of every ``.pkl`` file directly inside ``folder``."""
    return sorted(
        os.path.join(folder, f)
        for f in os.listdir(folder)
        if f.lower().endswith(".pkl")
    )


def _read_pickle_df(path):
    """Load a pickle and return it as a DataFrame (a list payload is converted).

    NOTE(review): ``pd.read_pickle`` can execute arbitrary code while loading;
    only use it on files from a trusted source.
    """
    loaded = pd.read_pickle(path)
    if isinstance(loaded, list):
        loaded = pd.DataFrame(loaded)
    return loaded


def _reset_gallery_annotations(gallery_files):
    """Set ``annotated=False`` on every gallery chunk and write each file back."""
    for gal_fpath in gallery_files:
        df_gal_chunk = _read_pickle_df(gal_fpath)
        df_gal_chunk["annotated"] = False
        df_gal_chunk.to_pickle(gal_fpath)
        del df_gal_chunk
        gc.collect()


def main():
    """Run the incremental query-chunk training / gallery annotation experiment.

    Flow:
      1. Sample training rows evenly from the train chunk pickles.
      2. Index which query file holds each query PID and reset the
         ``annotated`` flag of every gallery chunk.
      3. For each group of query PIDs: add the first tracklet of each PID to
         the training set, then repeat for several rounds -- train a model,
         predict a label for every non-augmented gallery tracklet, move the
         best-scoring un-annotated predictions into the training set, and
         append one summary row to the CSV log.

    Side effects: rewrites the gallery pickle files in place, deletes and
    recreates the summary CSV, and prints progress to stdout.

    NOTE(review): the data folders are absolute paths on one machine; consider
    moving them to a configuration cell.

    Raises:
        FileNotFoundError: If the train folder contains no ``.pkl`` file.
    """
    set_random_seed(999)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Experiment constants.
    total_train_samples = 70000   # rows sampled across all train chunks
    chunk_size_q = 7              # query PIDs introduced per chunk
    num_rounds = 3                # train/annotate rounds per query chunk
    top_k_per_label = 5           # guaranteed picks per predicted label
    annotation_capacity = 50      # max tracklets annotated per round
    summary_csv = "summary_results_aug_matrix_next.csv"

    # ===============================
    # 1) Load the training data
    # ===============================
    print("\n=== Loading train data (original only) from multiple pickle files ===")
    train_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/train_split_ReSNet_pkl/filtered"
    train_chunk_files = _list_pkl_files(train_chunks_folder)
    print(f"Found {len(train_chunk_files)} train chunk PKL files in '{train_chunks_folder}'")
    if not train_chunk_files:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise FileNotFoundError(f"No .pkl files found in '{train_chunks_folder}'")

    samples_per_file, remainder = divmod(total_train_samples, len(train_chunk_files))

    train_dfs = []
    for i, file in enumerate(train_chunk_files):
        df_part = _read_pickle_df(file)
        n_samples = samples_per_file + (1 if i < remainder else 0)
        # DataFrame.sample raises when asked for more rows than exist.
        n_samples = min(n_samples, len(df_part))
        train_dfs.append(df_part.sample(n=n_samples, random_state=42))
        del df_part
        gc.collect()
    df_train_orig = pd.concat(train_dfs, ignore_index=True)
    print(f"Combined train data shape={df_train_orig.shape}")
    del train_dfs
    gc.collect()

    # ===============================
    # 2) Query / gallery setup
    # ===============================
    print("\n=== Building pid-to-file mapping for query data ===")
    query_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/query_split_ReSNet_pkl_mini/fillter"
    query_chunk_files = _list_pkl_files(query_chunks_folder)
    print(f"Found {len(query_chunk_files)} query chunk PKL files in '{query_chunks_folder}'")

    # When a PID occurs in several files, the last file (sorted order) wins.
    pid_to_file = {}
    for qf in query_chunk_files:
        df_qpart = _read_pickle_df(qf)
        for pid_ in df_qpart["pid"].unique():
            pid_to_file[pid_] = qf
        del df_qpart
        gc.collect()

    unique_query_pids = sorted(pid_to_file.keys())
    num_chunks = (len(unique_query_pids) + chunk_size_q - 1) // chunk_size_q
    print(f"Total query pids = {len(unique_query_pids)} => {num_chunks} query chunks (size={chunk_size_q})")

    # Gallery
    gallery_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/gallery_split_ReSNet_pkl_mini/fillter"
    gallery_chunk_files = _list_pkl_files(gallery_chunks_folder)
    print(f"Found {len(gallery_chunk_files)} gallery chunk PKL files in '{gallery_chunks_folder}'")

    print("Resetting `annotated=False` in all gallery chunk PKL files...")
    _reset_gallery_annotations(gallery_chunk_files)
    print("All gallery chunk PKLs have annotated=False initialized.")

    # Start from a clean summary file.
    if os.path.exists(summary_csv):
        os.remove(summary_csv)

    # ===============================
    # 3) Add query PIDs chunk by chunk; train and annotate in rounds
    # ===============================
    for chunk_idx in range(num_chunks):
        chunk_start_t = time.time()
        start_idx = chunk_idx * chunk_size_q
        end_idx = min((chunk_idx + 1) * chunk_size_q, len(unique_query_pids))
        pids_to_add = unique_query_pids[start_idx:end_idx]
        print(f"\n=== Query Chunk {chunk_idx+1}/{num_chunks} ===")
        print(f"  -> pids_to_add = {pids_to_add}")

        # Every query chunk starts with an un-annotated gallery.
        print("Resetting annotated=False for all gallery chunk PKLs (start of THIS Query Chunk)...")
        _reset_gallery_annotations(gallery_chunk_files)

        # (A) Collect the query rows of this chunk's PIDs.
        needed_filepaths = set(pid_to_file[pid_] for pid_ in pids_to_add)
        query_dfs_for_chunk = []
        for fpath in needed_filepaths:
            df_qpart = _read_pickle_df(fpath)
            df_qpart_sub = df_qpart[df_qpart["pid"].isin(pids_to_add)]
            if len(df_qpart_sub) > 0:
                query_dfs_for_chunk.append(df_qpart_sub)
            del df_qpart, df_qpart_sub
            gc.collect()

        if len(query_dfs_for_chunk) == 0:
            print("  -> No tracklets found for these PIDs in splitted query files.")
            continue

        df_query_chunk = pd.concat(query_dfs_for_chunk, ignore_index=True)
        del query_dfs_for_chunk
        gc.collect()

        # Keep only the first tracklet of every PID.
        selected_tracklets = []
        for pid_ in pids_to_add:
            df_pid = df_query_chunk[df_query_chunk["pid"] == pid_]
            if df_pid.empty:
                continue
            tracklet_id = df_pid["tracklet_id"].unique()[0]
            selected_tracklets.append(df_pid[df_pid["tracklet_id"] == tracklet_id])
            del df_pid
            gc.collect()

        if len(selected_tracklets) == 0:
            print("  -> No valid tracklets found for these PIDs.")
            continue

        df_initial_add = pd.concat(selected_tracklets, ignore_index=True)
        del selected_tracklets
        gc.collect()

        count_orig = (df_initial_add["is_aug"] == 0).sum()
        count_aug = (df_initial_add["is_aug"] == 1).sum()
        print(f"  -> Add {len(df_initial_add)} query samples to training (original={count_orig}, augmentation={count_aug})")

        # concat already returns a new frame, so df_train_orig stays untouched.
        new_train = pd.concat([df_train_orig, df_initial_add], ignore_index=True)
        del df_initial_add
        gc.collect()

        for round_idx in range(1, num_rounds + 1):
            print(f"\n--- Query Chunk {chunk_idx+1}, Round {round_idx} ---")

            # (1) Map PIDs to labels and cap the "unknown" class.
            # The copy is required: map_pids_to_0_known writes into its input.
            df_train_mapped = map_pids_to_0_known(new_train.copy(), pids_to_add,
                                                  label_col="pid", new_label_col="pid_trans")
            df_train_mapped = downsample_label0(df_train_mapped, label_col="pid_trans")

            # (2) Train
            X_train_m, y_train_m = extract_features_and_labels(df_train_mapped)
            model = train_net128_weighted_arcface(
                X_train_m, y_train_m, device=device,
                epochs=100, lr=1e-6, batch_size=128
            )

            # (3) Per-label embedding dictionary, reusing the arrays extracted
            # for training instead of extracting them a second time.
            label_embs_dict = build_label_embs_dict(X_train_m, y_train_m, model, device=device)
            del df_train_mapped, X_train_m, y_train_m
            gc.collect()

            # (4) Gallery inference (non-augmented rows only)
            all_results = []
            for gal_fpath in gallery_chunk_files:
                df_gal_chunk = _read_pickle_df(gal_fpath)
                df_gal_chunk = map_pids_to_0_known(df_gal_chunk, pids_to_add,
                                                   label_col="pid", new_label_col="pid_trans")

                df_gal_infer = df_gal_chunk[df_gal_chunk["is_aug"] == 0]
                if len(df_gal_infer) == 0:
                    del df_gal_chunk, df_gal_infer
                    gc.collect()
                    continue

                pred_list = []
                # groupby visits each tracklet once, in order of first appearance.
                for tid, sub_t in df_gal_infer.groupby("tracklet_id", sort=False):
                    X_gall_t, _ = extract_features_and_labels(sub_t)
                    emb_gall_t = get_embeddings_128(model, X_gall_t, device=device)
                    pid_true = sub_t["pid_trans"].iloc[0]

                    pred_label, best_score = predict_label_for_gall_tracklet(emb_gall_t, label_embs_dict)
                    pred_list.append((tid, pid_true, pred_label, best_score, gal_fpath))

                df_res_chunk = pd.DataFrame(pred_list, columns=["track_id", "pid_true", "pid_pred", "score", "gal_fpath"])

                # Attach the current annotated flag of each tracklet.
                df_anno_chunk = df_gal_chunk[["tracklet_id", "annotated"]].drop_duplicates()
                df_anno_chunk["gal_fpath"] = gal_fpath

                df_res_chunk = pd.merge(
                    df_res_chunk,
                    df_anno_chunk,
                    how="left",
                    left_on=["track_id", "gal_fpath"],
                    right_on=["tracklet_id", "gal_fpath"]
                ).drop(columns=["tracklet_id"])

                all_results.append(df_res_chunk)
                del df_gal_chunk, df_gal_infer, df_res_chunk, df_anno_chunk
                gc.collect()

            if not all_results:
                print("No tracklets were inferred in this round.")
                continue

            df_res_all = pd.concat(all_results, ignore_index=True)
            df_res_all["correct"] = (df_res_all["pid_true"] == df_res_all["pid_pred"]).astype(int)

            # (5) Reports
            y_true_mv = df_res_all["pid_true"].values
            y_pred_mv = df_res_all["pid_pred"].values
            rep_dict = classification_report(y_true_mv, y_pred_mv, output_dict=True, zero_division=0)
            rep_df = pd.DataFrame(rep_dict).transpose().round(4)
            print("\n=== classification_report (tracklet-level) ===")
            print(rep_df)

            grp = df_res_all.groupby("pid_true")["correct"].agg(total="count", correct="sum")
            grp["detect_rate(%)"] = 100.0 * grp["correct"] / grp["total"]
            print("\n=== Per-PID detection in this round (ALL) ===")
            print(grp.sort_values("pid_true").head(30))

            # (7) Candidates: predicted as a known PID and not yet annotated,
            # best score first. `== False` is kept on purpose so missing flags
            # (NaN after the left merge) are excluded.
            df_valid = df_res_all[
                (df_res_all["pid_pred"] != 0) &
                (df_res_all["annotated"] == False)
            ].sort_values(by="score", ascending=False)

            # Picks are tracked by row index, which is unique per
            # (track_id, gal_fpath); track_id alone may repeat across files.
            label_top_idx = []
            for _, sub_df in df_valid.groupby("pid_pred"):
                label_top_idx.extend(sub_df.head(top_k_per_label).index.tolist())

            n_label_based = len(label_top_idx)
            is_label_pick = df_valid.index.isin(label_top_idx)
            if n_label_based >= annotation_capacity:
                # Too many per-label picks: keep only the best-scoring ones.
                picked = df_valid[is_label_pick].sort_values(by="score", ascending=False)
                final_idx = picked.head(annotation_capacity).index.tolist()
            else:
                # Fill the remaining capacity with the best of the rest.
                remain_needed = annotation_capacity - n_label_based
                leftover_idx = df_valid[~is_label_pick].head(remain_needed).index.tolist()
                final_idx = label_top_idx + leftover_idx

            print(f"  -> We'll annotate up to {annotation_capacity} tracklets.")
            print(f"  -> label-based picks = {n_label_based}, final to annotate = {len(final_idx)}")

            # (8) Move the selected tracklets from the gallery into training.
            newly_annotated_count = 0
            cond_correct_list = []

            df_topN = df_valid[df_valid.index.isin(final_idx)]
            group_fpath = df_topN.groupby("gal_fpath")["track_id"].unique()

            for gfpath, tids_array in group_fpath.items():
                df_gal_chunk = _read_pickle_df(gfpath)
                df_gal_chunk = map_pids_to_0_known(df_gal_chunk, pids_to_add,
                                                   label_col="pid", new_label_col="pid_trans")

                to_add_mask = (
                    df_gal_chunk["tracklet_id"].isin(tids_array) &
                    (df_gal_chunk["annotated"] == False)
                )
                df_to_add = df_gal_chunk[to_add_mask].copy()
                if not df_to_add.empty:
                    newly_ids = df_to_add["tracklet_id"].unique()
                    newly_annotated_count += len(newly_ids)

                    # Correct predictions among the newly annotated tracklets
                    # of THIS gallery file.
                    cond_correct_part = df_res_all[
                        (df_res_all["gal_fpath"] == gfpath) &
                        (df_res_all["track_id"].isin(newly_ids)) &
                        (df_res_all["correct"] == 1)
                    ]
                    if not cond_correct_part.empty:
                        cond_correct_list.append(cond_correct_part)

                    # Shift tracklet ids past the ones already used in training.
                    max_tid = new_train["tracklet_id"].max() if "tracklet_id" in new_train.columns else -1
                    df_to_add["tracklet_id"] = df_to_add["tracklet_id"] + max_tid + 1

                    new_train = pd.concat([new_train, df_to_add], ignore_index=True)

                    # Mark the moved rows as annotated and persist the chunk.
                    df_gal_chunk.loc[to_add_mask, "annotated"] = True
                    df_gal_chunk.to_pickle(gfpath)

                del df_gal_chunk, df_to_add
                gc.collect()

            if cond_correct_list:
                cond_correct_all = pd.concat(cond_correct_list, ignore_index=True)
            else:
                cond_correct_all = pd.DataFrame()
            correct_newly_annotated = len(cond_correct_all)

            print(f"  -> newly_annotated={newly_annotated_count}, correct_newly_annotated={correct_newly_annotated}")

            # Per-PID breakdown, printed once and serialised for the CSV log.
            if not cond_correct_all.empty:
                grp_label = cond_correct_all.groupby("pid_true").size().reset_index(name="correct_count")
                print("\n=== correct_newly_annotated breakdown by pid_true ===")
                print(grp_label.sort_values("pid_true"))
                breakdown_str = "; ".join(f"{int(r.pid_true)}:{int(r.correct_count)}"
                                          for r in grp_label.itertuples())
            else:
                print("\nNo newly annotated tracklets in this round -> No breakdown")
                breakdown_str = ""

            # (9) Append one summary row; the header is written only when the
            # file does not exist yet.
            row_dict = {
                "chunk": chunk_idx + 1,
                "round": round_idx,
                "pids_to_add": str(pids_to_add),
                "train_samples": len(new_train),
                "newly_annotated_count": newly_annotated_count,
                "correct_newly_annotated": correct_newly_annotated,
                "correct_newly_annotated_breakdown": breakdown_str
            }
            pd.DataFrame([row_dict]).to_csv(
                summary_csv, index=False, mode="a",
                header=not os.path.exists(summary_csv)
            )
            del row_dict
            gc.collect()

        print(f"\n[Chunk {chunk_idx+1} done] pids_to_add={pids_to_add}, chunk_time={(time.time()-chunk_start_t):.2f}s")

    print("\n=== All Done ===")
    # Report the file that was actually written.
    print(f"See '{summary_csv}' for chunk-by-chunk progress.")
    print("Also check the classification reports printed each round for details.")


# Run the experiment only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'pandas'

# 必要ライブラリ

In [2]:
import os
import random
import time
import gc
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import classification_report
from pytorch_metric_learning import distances
from pytorch_metric_learning.losses import ArcFaceLoss

from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict


ModuleNotFoundError: No module named 'pandas'

# 学習やデータ前処理に関するユーティリティ関数群

In [3]:
def map_pids_to_0_known(df, known_pids, label_col="pid", new_label_col="pid_trans"):
    """Map known person IDs to contiguous labels 1..K and everything else to 0.

    The distinct values of ``known_pids`` are sorted and numbered from 1, so
    e.g. ``known_pids=[10, 11, 15]`` yields ``10 -> 1, 11 -> 2, 15 -> 3``.
    Duplicates in ``known_pids`` are ignored; previously they left gaps in the
    label range.

    Args:
        df: DataFrame containing ``label_col``. It is modified in place.
        known_pids: Iterable of person IDs that should receive a non-zero label.
        label_col: Name of the column holding the original IDs.
        new_label_col: Name of the column that receives the mapped labels.

    Returns:
        The same ``df`` object, with ``new_label_col`` added or overwritten.
    """
    # set() removes duplicates so the labels stay contiguous.
    sorted_kps = sorted(set(known_pids))
    pid_map = {p: i + 1 for i, p in enumerate(sorted_kps)}
    df[new_label_col] = df[label_col].apply(lambda x: pid_map.get(x, 0))
    return df


def downsample_label0(df, label_col="pid", ratio=30, min_count=128, random_state=42):
    """Cap the number of label-0 ("unknown") rows relative to the known classes.

    The cap is ``max(largest_known_class_count * ratio, min_count)``. When
    there are more label-0 rows than the cap, a random subset of that size is
    kept; all rows with a non-zero label are always kept.

    Args:
        df: DataFrame containing ``label_col``.
        label_col: Column holding the labels (0 means "unknown").
        ratio: Multiplier applied to the largest known-class count.
        min_count: Lower bound for the cap.
        random_state: Seed passed to ``DataFrame.sample``.

    Returns:
        A new DataFrame with the (possibly sampled) label-0 rows first,
        followed by all other rows, with a fresh 0..N-1 index.
    """
    label0_df = df[df[label_col] == 0]
    others_df = df[df[label_col] != 0]
    counts = others_df[label_col].value_counts()
    largest_known = counts.max() if not counts.empty else 0
    # int() turns the numpy scalar into a plain Python int for sample().
    max_count = int(max(largest_known * ratio, min_count))
    print("各ラベルのサンプル数:")
    print(df[label_col].value_counts())
    print(f"→ ラベル0を {max_count} 枚に合わせます")
    if len(label0_df) > max_count:
        label0_df = label0_df.sample(n=max_count, random_state=random_state)
    return pd.concat([label0_df, others_df], ignore_index=True)


def set_random_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed_value: Seed applied to every generator (also exported as
            ``PYTHONHASHSEED``).
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    # Same seeding calls, in the same order, expressed as one loop.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def extract_features_and_labels(df):
    """Return ``(X, y)`` arrays extracted from ``df``.

    Features come from the ``embedding`` column (one array per row, stacked
    into a 2-D array) when it exists, otherwise from every column whose name
    starts with ``dim``. Labels come from ``pid_trans`` when present,
    otherwise from ``pid``.
    """
    columns = df.columns
    if "embedding" not in columns:
        dim_columns = [name for name in columns if name.startswith("dim")]
        features = df[dim_columns].values
    else:
        features = np.stack(df["embedding"].values)

    label_column = "pid_trans" if "pid_trans" in columns else "pid"
    return features, df[label_column].values

# WeightedArcFaceLoss クラス定義

In [5]:
class WeightedArcFaceLoss(ArcFaceLoss):
    """ArcFaceLoss whose per-sample cross-entropy criterion carries class weights.

    After the parent constructor has run, ``self.cross_entropy`` is replaced by
    ``nn.CrossEntropyLoss(weight=class_weights, reduction="none")``.

    NOTE(review): this presumably overrides the criterion the parent loss uses
    internally -- confirm against the installed pytorch_metric_learning version.

    Args:
        num_classes: Number of classes, forwarded to ``ArcFaceLoss``.
        embedding_size: Embedding dimensionality, forwarded to ``ArcFaceLoss``.
        margin: Angular margin, forwarded to ``ArcFaceLoss``.
        scale: Logit scale, forwarded to ``ArcFaceLoss``.
        class_weights: Optional per-class weights (tensor or array-like);
            non-tensor values are converted to a float32 tensor.
        **kwargs: Any further keyword arguments accepted by ``ArcFaceLoss``.
    """

    def __init__(
            self,
            num_classes,
            embedding_size,
            margin=28.6,
            scale=64,
            class_weights=None,
            **kwargs
    ):
        super().__init__(
            num_classes=num_classes,
            embedding_size=embedding_size,
            margin=margin,
            scale=scale,
            **kwargs
        )
        needs_conversion = (
            class_weights is not None
            and not isinstance(class_weights, torch.Tensor)
        )
        weight_tensor = (
            torch.tensor(class_weights, dtype=torch.float32)
            if needs_conversion
            else class_weights
        )
        # reduction="none" keeps one loss value per sample.
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight_tensor, reduction="none")

NameError: name 'ArcFaceLoss' is not defined

# ネットワークモデル Net128 の定義

In [6]:
class Net128(nn.Module):
    """MLP that maps 2048-dimensional inputs to 128-dimensional embeddings.

    Layer stack: 2048 -> 1024 (ReLU, BatchNorm) -> 512 (ReLU, Dropout 0.2)
    -> 384 (ReLU, Dropout 0.2) -> 128 (ReLU).

    NOTE(review): the final ReLU makes every embedding component
    non-negative, which restricts cosine similarities between embeddings to
    [0, 1] -- confirm this is intended for an angular-margin loss.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so seeded initialisation
        # is unchanged.
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.bn1 = nn.BatchNorm1d(num_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.dropout_1 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(in_features=512, out_features=384)
        self.dropout_2 = nn.Dropout(p=0.2)
        self.fc4 = nn.Linear(in_features=384, out_features=128)

    def forward(self, x):
        """Compute the 128-dimensional embedding for a batch ``x``."""
        hidden = self.bn1(F.relu(self.fc1(x)))
        hidden = self.dropout_1(F.relu(self.fc2(hidden)))
        hidden = self.dropout_2(F.relu(self.fc3(hidden)))
        return F.relu(self.fc4(hidden))


NameError: name 'nn' is not defined

# 学習（train）に関する関数

In [8]:
def train_one_epoch(model, loss_func, device, loader, optimizer):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module producing embeddings from a feature batch.
        loss_func: Callable ``(embeddings, labels) -> loss``. A dict result is
            reduced via ``result["loss"]["losses"].mean()``; anything else is
            used as the loss tensor directly.
        device: Device the batches are moved to.
        loader: Iterable of ``(features, labels)`` batches.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The average of the per-batch loss values, or ``nan`` when ``loader``
        yielded no batches (this previously raised ``ZeroDivisionError``,
        e.g. with ``drop_last=True`` and fewer samples than one batch).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work
    for features, labels in loader:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(features)
        loss_output = loss_func(embeddings, labels)
        if isinstance(loss_output, dict):
            loss_val = loss_output["loss"]["losses"].mean()
        else:
            loss_val = loss_output
        loss_val.backward()
        optimizer.step()
        total_loss += loss_val.item()
        num_batches += 1
    if num_batches == 0:
        # Nothing was trained; report "no loss" instead of dividing by zero.
        return float("nan")
    return total_loss / num_batches


def train_net128_weighted_arcface(X_train, y_train, device, epochs=100, lr=1e-6, batch_size=128):
    """Train a fresh ``Net128`` with ``WeightedArcFaceLoss`` and return it.

    Class weights are ``1 / (count + 1e-5)`` for every label present in
    ``y_train`` and 0 for absent label indices. The model and the loss share
    a single Adam optimizer, and ``ReduceLROnPlateau`` (factor 0.1,
    patience 5) is driven by the mean training loss of each epoch.

    Args:
        X_train: Feature array convertible to a float32 tensor.
        y_train: Non-negative integer labels, one per row of ``X_train``.
        device: Device used for the model, the loss and the batches.
        epochs: Number of passes over the data.
        lr: Initial Adam learning rate.
        batch_size: Mini-batch size; incomplete final batches are dropped.

    Returns:
        The trained ``Net128`` instance (left in training mode).

    Raises:
        ValueError: If ``y_train`` is empty.
    """
    from torch.utils.data import DataLoader, TensorDataset

    y_train = np.asarray(y_train)
    if y_train.size == 0:
        raise ValueError("y_train is empty; cannot train without labelled samples")

    # One pass yields both the label set and the per-label counts.
    unique_labels, label_counts = np.unique(y_train, return_counts=True)
    unique_labels = unique_labels.astype(np.int64)
    # Plain int: the value sizes arrays and the loss' weight matrix.
    num_classes = int(unique_labels.max()) + 1

    dataset = TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long)
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    model = Net128().to(device)

    # Inverse-frequency weights; absent label indices keep weight 0.
    class_weights = np.zeros(num_classes, dtype=np.float32)
    class_weights[unique_labels] = 1.0 / (label_counts + 1e-5)

    distance = distances.CosineSimilarity()
    loss_func = WeightedArcFaceLoss(
        num_classes=num_classes,
        embedding_size=128,
        margin=28.6,
        scale=64,
        class_weights=class_weights,
        distance=distance
    ).to(device)

    # The loss owns trainable parameters too, so both are optimised together.
    optimizer = optim.Adam(
        list(model.parameters()) + list(loss_func.parameters()), lr=lr
    )
    # `verbose=True` is deprecated in recent PyTorch releases, so learning-rate
    # reductions are reported manually below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5
    )

    for epoch in range(1, epochs + 1):
        ep_start = time.time()
        avg_loss = train_one_epoch(model, loss_func, device, loader, optimizer)
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"  [Epoch {epoch}] reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")
        print(f"  [Epoch {epoch}] train_loss={avg_loss:.4f}, time={(time.time() - ep_start):.2f} sec")

    return model


# 推論に関する関数 (埋め込み取得、類似度計算など)

In [9]:
def get_embeddings_128(model, X_input, device, batch_size=128):
    """
    Compute embeddings for ``X_input`` with a trained model.

    The model is put into eval mode, the input is converted to a float32 tensor
    and passed through the model in order (no shuffling) in mini-batches of
    ``batch_size`` with gradients disabled. Per-batch outputs are moved to the
    CPU and concatenated along axis 0 into a single NumPy array.

    Expected input shape is (N, 2048); the model is expected to emit
    128-dimensional embeddings.
    """
    from torch.utils.data import DataLoader, TensorDataset

    model.eval()
    features = torch.tensor(X_input, dtype=torch.float32)
    batch_iter = DataLoader(TensorDataset(features), batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        chunks = [model(inputs.to(device)).cpu().numpy() for (inputs,) in batch_iter]

    return np.concatenate(chunks, axis=0)


def compute_average_tracklet_similarity(gall_emb, label_embs):
    """
    Mean cosine similarity between one label's embeddings and a gallery tracklet.

    Every row of ``label_embs`` is compared against every row of ``gall_emb``
    and the mean over the whole similarity matrix is returned as a scalar.
    """
    return np.mean(cosine_similarity(label_embs, gall_emb))


def predict_label_for_gall_tracklet(gall_emb, label_embs_dict):
    """
    Pick the label whose embeddings are, on average, most similar to a gallery tracklet.

    ``gall_emb`` holds the embeddings of a single gallery tracklet and
    ``label_embs_dict`` maps each label to its embedding matrix. Labels are
    visited in dict order and a later label only wins with a strictly higher
    score. With an empty dict the fallback ``(0, -1e9)`` is returned.

    Returns: (best_label, best_score)
    """
    winner = (0, -1e9)
    for candidate, candidate_embs in label_embs_dict.items():
        score = compute_average_tracklet_similarity(gall_emb, candidate_embs)
        if not (score > winner[1]):
            continue
        winner = (candidate, score)
    return winner


def build_label_embs_dict(X_train, y_train, model, device, batch_size=128):
    """
    Embed ``X_train`` with the trained model and group the embeddings by label.

    Returns a dict where ``result[label]`` is an array of shape
    (N_label_samples, embedding_dim) stacking every embedding whose label in
    ``y_train`` equals ``label``. Keys appear in order of first occurrence.
    """
    embeddings = get_embeddings_128(model, X_train, device=device, batch_size=batch_size)

    grouped = {}
    for vector, label in zip(embeddings, y_train):
        grouped.setdefault(label, []).append(vector)

    return {label: np.stack(rows, axis=0) for label, rows in grouped.items()}


# 関数実行

In [10]:
##############################################
# Cell 7: メインフロー main() 関数の定義
##############################################

def _list_pkl_files(folder):
    """Return the sorted full paths of every ``.pkl`` file (case-insensitive) directly inside ``folder``."""
    return sorted(
        os.path.join(folder, fname)
        for fname in os.listdir(folder)
        if fname.lower().endswith(".pkl")
    )


def _read_pickle_as_df(path):
    """Load a pickle file as a DataFrame; a pickled list of records is wrapped in ``pd.DataFrame``."""
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    loaded = pd.read_pickle(path)
    if isinstance(loaded, list):
        loaded = pd.DataFrame(loaded)
    return loaded


def _reset_gallery_annotated(gallery_chunk_files):
    """Set ``annotated=False`` on every row of every gallery chunk and write each file back in place."""
    for gal_fpath in gallery_chunk_files:
        df_gal_chunk = _read_pickle_as_df(gal_fpath)
        df_gal_chunk["annotated"] = False
        df_gal_chunk.to_pickle(gal_fpath)
        del df_gal_chunk
        gc.collect()


def main():
    """
    Run the full load -> train -> infer -> annotate cycle.

    Steps:
      1. Sample up to 70000 rows, split evenly across the train chunk pickles,
         as the base training set.
      2. Map every query pid to the query chunk file that contains it and reset the
         ``annotated`` flag of all gallery chunk files (files are rewritten in place).
      3. Process the query pids in groups of 7. For each group, the first tracklet of
         each pid is appended to the base training set, and then 3 rounds are run:
         train an ArcFace model, predict a label for every non-augmented gallery
         tracklet, print a classification report, pick up to 50 not-yet-annotated
         tracklets predicted as a known pid, add them (with their true pid) to the
         training set, mark them as annotated in the gallery files, and append one
         summary row to the summary CSV.

    Side effects: rewrites the gallery pickle files, deletes and re-creates the
    summary CSV in the current working directory, and prints progress.

    Raises
    ------
    FileNotFoundError
        If the train chunk folder contains no ``.pkl`` file.
    """
    set_random_seed(999)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ===============================
    # 1) Load the training data
    # ===============================
    print("\n=== Loading train data (original only) from multiple pickle files ===")
    train_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/train_split_ReSNet_pkl/filtered"
    train_chunk_files = _list_pkl_files(train_chunks_folder)
    print(f"Found {len(train_chunk_files)} train chunk PKL files in '{train_chunks_folder}'")
    if not train_chunk_files:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise FileNotFoundError(f"No .pkl train chunk files found in '{train_chunks_folder}'")

    total_train_samples = 70000
    num_files = len(train_chunk_files)
    # Spread the sample budget evenly; the first `remainder` files get one extra row.
    samples_per_file, remainder = divmod(total_train_samples, num_files)

    train_dfs = []
    for i, file in enumerate(train_chunk_files):
        df_part = _read_pickle_as_df(file)
        n_samples = samples_per_file + (1 if i < remainder else 0)
        if n_samples > len(df_part):
            # sample() without replacement raises when n exceeds the row count.
            print(f"  -> '{file}' has only {len(df_part)} rows (requested {n_samples}); using all of them.")
            n_samples = len(df_part)
        df_sampled = df_part.sample(n=n_samples, random_state=42)
        train_dfs.append(df_sampled)
        del df_part, df_sampled
        gc.collect()
    df_train_orig = pd.concat(train_dfs, ignore_index=True)
    print(f"Combined train data shape={df_train_orig.shape}")
    del train_dfs
    gc.collect()

    # ===============================
    # 2) Initial query / gallery setup
    # ===============================
    print("\n=== Building pid-to-file mapping for query data ===")
    query_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/query_split_ReSNet_pkl_mini/fillter"
    query_chunk_files = _list_pkl_files(query_chunks_folder)
    print(f"Found {len(query_chunk_files)} query chunk PKL files in '{query_chunks_folder}'")

    # pid -> query chunk file; if a pid occurs in several files the last one wins.
    pid_to_file = {}
    for qf in query_chunk_files:
        df_qpart = _read_pickle_as_df(qf)
        for pid_ in df_qpart["pid"].unique():
            pid_to_file[pid_] = qf
        del df_qpart
        gc.collect()

    unique_query_pids = sorted(pid_to_file.keys())
    chunk_size_q = 7
    num_chunks = (len(unique_query_pids) + chunk_size_q - 1) // chunk_size_q  # ceil division
    print(f"Total query pids = {len(unique_query_pids)} => {num_chunks} query chunks (size={chunk_size_q})")

    # Gallery
    gallery_chunks_folder = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS/all_mini_direct/gallery_split_ReSNet_pkl_mini/fillter"
    gallery_chunk_files = _list_pkl_files(gallery_chunks_folder)
    print(f"Found {len(gallery_chunk_files)} gallery chunk PKL files in '{gallery_chunks_folder}'")

    print("Resetting `annotated=False` in all gallery chunk PKL files...")
    _reset_gallery_annotated(gallery_chunk_files)
    print("All gallery chunk PKLs have annotated=False initialized.")

    # Start from a fresh summary file on every run.
    summary_csv = "summary_results_aug_matrix_next.csv"
    if os.path.exists(summary_csv):
        os.remove(summary_csv)

    # ===============================
    # 3) Add query pids 7 at a time; train and annotate for each group
    # ===============================
    for chunk_idx in range(num_chunks):
        chunk_start_t = time.time()
        start_idx = chunk_idx * chunk_size_q
        end_idx = min((chunk_idx+1)*chunk_size_q, len(unique_query_pids))
        pids_to_add = unique_query_pids[start_idx:end_idx]
        print(f"\n=== Query Chunk {chunk_idx+1}/{num_chunks} ===")
        print(f"  -> pids_to_add = {pids_to_add}")

        # Annotations made for the previous query chunk must not carry over.
        print("Resetting annotated=False for all gallery chunk PKLs (start of THIS Query Chunk)...")
        _reset_gallery_annotated(gallery_chunk_files)

        # Load the query rows of this chunk's pids and combine them.
        # Sorted so that the concatenation order is the same on every run.
        needed_filepaths = sorted({pid_to_file[pid_] for pid_ in pids_to_add})
        query_dfs_for_chunk = []
        for fpath in needed_filepaths:
            df_qpart = _read_pickle_as_df(fpath)
            df_qpart_sub = df_qpart[df_qpart["pid"].isin(pids_to_add)]
            if len(df_qpart_sub) > 0:
                query_dfs_for_chunk.append(df_qpart_sub)
            del df_qpart, df_qpart_sub
            gc.collect()

        if len(query_dfs_for_chunk) == 0:
            print("  -> No tracklets found for these PIDs in splitted query files.")
            continue

        df_query_chunk = pd.concat(query_dfs_for_chunk, ignore_index=True)
        del query_dfs_for_chunk
        gc.collect()

        # Keep only the first tracklet of each pid.
        selected_tracklets = []
        for pid_ in pids_to_add:
            df_pid = df_query_chunk[df_query_chunk["pid"] == pid_]
            if df_pid.empty:
                continue
            tracklet_id = df_pid["tracklet_id"].unique()[0]
            df_tracklet = df_pid[df_pid["tracklet_id"] == tracklet_id]
            selected_tracklets.append(df_tracklet)
            del df_pid, df_tracklet
            gc.collect()

        if len(selected_tracklets) == 0:
            print("  -> No valid tracklets found for these PIDs.")
            continue

        df_initial_add = pd.concat(selected_tracklets, ignore_index=True)
        del selected_tracklets
        gc.collect()

        count_orig = (df_initial_add["is_aug"] == 0).sum()
        count_aug = (df_initial_add["is_aug"] == 1).sum()
        print(f"  -> Add {len(df_initial_add)} query samples to training (original={count_orig}, augmentation={count_aug})")

        # Each query chunk starts again from the untouched base training set.
        new_train = pd.concat([df_train_orig.copy(), df_initial_add], ignore_index=True)
        del df_initial_add
        gc.collect()

        # Run several train / infer / annotate rounds.
        for round_idx in range(1, 4):
            round_start_t = time.time()
            print(f"\n--- Query Chunk {chunk_idx+1}, Round {round_idx} ---")

            # (1) map_pids_to_0_known + downsample
            df_train_mapped = new_train.copy()
            df_train_mapped = map_pids_to_0_known(df_train_mapped, pids_to_add,
                                                  label_col="pid", new_label_col="pid_trans")
            df_train_mapped = downsample_label0(df_train_mapped, label_col="pid_trans")

            # (2) Training
            X_train_m, y_train_m = extract_features_and_labels(df_train_mapped)
            model = train_net128_weighted_arcface(
                X_train_m, y_train_m, device=device,
                epochs=100, lr=1e-6, batch_size=128
            )
            del X_train_m, y_train_m
            gc.collect()

            # (3) Per-label embedding dictionary
            X_label, y_label = extract_features_and_labels(df_train_mapped)
            label_embs_dict = build_label_embs_dict(X_label, y_label, model, device=device)
            del df_train_mapped, X_label, y_label
            gc.collect()

            # (4) Gallery inference (one prediction per non-augmented tracklet)
            all_results = []
            for gal_fpath in gallery_chunk_files:
                df_gal_chunk = _read_pickle_as_df(gal_fpath)
                df_gal_chunk = map_pids_to_0_known(df_gal_chunk, pids_to_add,
                                                   label_col="pid", new_label_col="pid_trans")

                df_gal_infer = df_gal_chunk[df_gal_chunk["is_aug"]==0]
                if len(df_gal_infer)==0:
                    del df_gal_chunk, df_gal_infer
                    gc.collect()
                    continue

                track_ids = df_gal_infer["tracklet_id"].unique()
                pred_list = []
                for tid in track_ids:
                    sub_t = df_gal_infer[df_gal_infer["tracklet_id"]==tid]
                    X_gall_t, _ = extract_features_and_labels(sub_t)
                    emb_gall_t = get_embeddings_128(model, X_gall_t, device=device)
                    pid_true = sub_t["pid_trans"].iloc[0]

                    pred_label, best_score = predict_label_for_gall_tracklet(emb_gall_t, label_embs_dict)
                    pred_list.append((tid, pid_true, pred_label, best_score, gal_fpath))

                df_res_chunk = pd.DataFrame(pred_list, columns=["track_id","pid_true","pid_pred","score","gal_fpath"])

                # Merge in the annotated flag
                df_anno_chunk = df_gal_chunk[["tracklet_id","annotated"]].drop_duplicates()
                df_anno_chunk["gal_fpath"] = gal_fpath

                df_res_chunk = pd.merge(
                    df_res_chunk,
                    df_anno_chunk,
                    how="left",
                    left_on=["track_id","gal_fpath"],
                    right_on=["tracklet_id","gal_fpath"]
                )
                df_res_chunk.drop(columns=["tracklet_id"], inplace=True)

                all_results.append(df_res_chunk)
                del df_gal_chunk, df_gal_infer, df_res_chunk, df_anno_chunk
                gc.collect()

            if not all_results:
                print("No tracklets were inferred in this round.")
                continue

            df_res_all = pd.concat(all_results, ignore_index=True)
            df_res_all["correct"] = (df_res_all["pid_true"] == df_res_all["pid_pred"]).astype(int)

            # (5) Report
            y_true_mv = df_res_all["pid_true"].values
            y_pred_mv = df_res_all["pid_pred"].values
            rep_dict = classification_report(y_true_mv, y_pred_mv, output_dict=True, zero_division=0)
            rep_df = pd.DataFrame(rep_dict).transpose().round(4)
            print("\n=== classification_report (tracklet-level) ===")
            print(rep_df)

            grp = df_res_all.groupby("pid_true")["correct"].agg(total="count", correct="sum")
            grp["detect_rate(%)"] = 100.0 * grp["correct"]/grp["total"]
            print("\n=== Per-PID detection in this round (ALL) ===")
            print(grp.sort_values("pid_true").head(30))

            # (7) Annotation candidates: annotated=False and pid_pred != 0
            df_valid = df_res_all[
                (df_res_all["pid_pred"] != 0) &
                (df_res_all["annotated"] == False)
            ].copy()
            df_valid.sort_values(by="score", ascending=False, inplace=True)

            # Top 5 per predicted label (df_valid is already sorted by score)
            label_top5_ids = []
            for pred_lab, sub_df in df_valid.groupby("pid_pred"):
                top_5_tracks = sub_df.head(5)["track_id"].tolist()
                label_top5_ids.extend(top_5_tracks)

            label_top5_ids = list(dict.fromkeys(label_top5_ids))  # de-duplicate, keep order

            annotation_capacity = 50
            n_label_based = len(label_top5_ids)
            if n_label_based >= annotation_capacity:
                # More per-label picks than capacity: keep the best-scoring ones.
                label_top5_df = df_valid[df_valid["track_id"].isin(label_top5_ids)].copy()
                label_top5_df.sort_values(by="score", ascending=False, inplace=True)
                final_ids_to_annotate = label_top5_df["track_id"].head(annotation_capacity).tolist()
            else:
                # Fill the remaining capacity with the next best-scoring candidates.
                remain_needed = annotation_capacity - n_label_based
                df_leftover = df_valid[~df_valid["track_id"].isin(label_top5_ids)].copy()
                leftover_ids = df_leftover["track_id"].head(remain_needed).tolist()
                final_ids_to_annotate = label_top5_ids + leftover_ids

            print(f"  -> We'll annotate up to {annotation_capacity} tracklets.")
            print(f"  -> label-based picks = {n_label_based}, final to annotate = {len(final_ids_to_annotate)}")

            # (8) Annotation
            newly_annotated_count = 0
            cond_correct_list = []

            # NOTE(review): candidates are selected by track_id alone, which assumes
            # tracklet ids are unique across gallery chunk files - confirm.
            df_topN = df_valid[df_valid["track_id"].isin(final_ids_to_annotate)].copy()
            group_fpath = df_topN.groupby("gal_fpath")["track_id"].unique()

            for gfpath, tids_array in group_fpath.items():
                df_gal_chunk = _read_pickle_as_df(gfpath)
                df_gal_chunk = map_pids_to_0_known(df_gal_chunk, pids_to_add,
                                                   label_col="pid", new_label_col="pid_trans")

                # Rows of the selected tracklets that have not been annotated yet.
                pending_mask = (
                    df_gal_chunk["tracklet_id"].isin(tids_array) &
                    (df_gal_chunk["annotated"] == False)
                )
                df_to_add = df_gal_chunk[pending_mask].copy()
                if not df_to_add.empty:
                    newly_ids = df_to_add["tracklet_id"].unique()
                    newly_annotated_count += len(newly_ids)

                    # Restrict to this gallery file so an id shared with another
                    # file cannot be counted twice.
                    cond_correct_part = df_res_all[
                        (df_res_all["gal_fpath"] == gfpath) &
                        (df_res_all["track_id"].isin(newly_ids)) &
                        (df_res_all["correct"] == 1)
                    ]
                    if not cond_correct_part.empty:
                        cond_correct_list.append(cond_correct_part)

                    # Shift the ids so they cannot collide with ids already in the training set.
                    max_tid = new_train["tracklet_id"].max() if "tracklet_id" in new_train.columns else -1
                    df_to_add["tracklet_id"] = df_to_add["tracklet_id"] + max_tid + 1

                    new_train = pd.concat([new_train, df_to_add], ignore_index=True)

                    # Persist the annotated flag so later rounds skip these tracklets.
                    df_gal_chunk.loc[pending_mask, "annotated"] = True

                    df_gal_chunk.to_pickle(gfpath)

                del df_gal_chunk, df_to_add
                gc.collect()

            if cond_correct_list:
                cond_correct_all = pd.concat(cond_correct_list, ignore_index=True)
                correct_newly_annotated = cond_correct_all.shape[0]
            else:
                cond_correct_all = pd.DataFrame()
                correct_newly_annotated = 0

            print(f"  -> newly_annotated={newly_annotated_count}, correct_newly_annotated={correct_newly_annotated}")

            if not cond_correct_all.empty:
                grp_label = cond_correct_all.groupby("pid_true").size().reset_index(name="correct_count")
                print("\n=== correct_newly_annotated breakdown by pid_true ===")
                print(grp_label.sort_values("pid_true"))
                breakdown_str = "; ".join(f"{int(r.pid_true)}:{int(r.correct_count)}"
                                          for r in grp_label.itertuples())
            else:
                print("\nNo newly annotated tracklets in this round -> No breakdown")
                breakdown_str = ""

            # (9) Append one summary row for this round
            row_dict = {
                "chunk": chunk_idx+1,
                "round": round_idx,
                "pids_to_add": str(pids_to_add),
                "train_samples": len(new_train),
                "newly_annotated_count": newly_annotated_count,
                "correct_newly_annotated": correct_newly_annotated,
                "correct_newly_annotated_breakdown": breakdown_str
            }
            df_summary = pd.DataFrame([row_dict])
            if not os.path.exists(summary_csv):
                df_summary.to_csv(summary_csv, index=False)
            else:
                df_summary.to_csv(summary_csv, index=False, mode="a", header=False)
            del df_summary, row_dict
            gc.collect()

        print(f"\n[Chunk {chunk_idx+1} done] pids_to_add={pids_to_add}, chunk_time={(time.time()-chunk_start_t):.2f}s")

    print("\n=== All Done ===")
    # Report the file that was actually written.
    print(f"See '{summary_csv}' for chunk-by-chunk progress.")
    print("Also check the classification reports printed each round for details.")


In [11]:
#############################################
# Cell 8: Example of calling main() from the notebook
#############################################

# Running this cell is optional.
# Execute it if you want to run the whole script end-to-end from this notebook.
# NOTE: the import cell and every definition cell above must have been run first;
# otherwise main() fails with a NameError (e.g. "name 'torch' is not defined").

if __name__ == "__main__":
    main()


NameError: name 'torch' is not defined