In [None]:
import os
import csv
import random
import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image
from scipy.io import loadmat
import torch.nn.functional as F


# ========== A) MarsTrainLoader & MarsLoader ==========

class MarsTrainLoader:
    """Collect training tracklets by scanning a directory tree of JPEG frames.

    Frames are grouped by the (pid, camid, track id) triple encoded in each
    file name; every group becomes one tracklet.
    """

    def __init__(self, root, home_dir='bbox_train', min_seq_len=0):
        """Remember the dataset location and verify that it exists.

        Args:
            root: Dataset root directory.
            home_dir: Sub-directory of ``root`` that is scanned recursively.
            min_seq_len: Tracklets with fewer frames than this are dropped.

        Raises:
            RuntimeError: If ``root/home_dir`` does not exist.
        """
        self.root = root
        self.home_dir = home_dir
        self.min_seq_len = min_seq_len
        self.train_dir = os.path.join(self.root, self.home_dir)
        self._check_before_run()

    def _check_before_run(self):
        """Raise ``RuntimeError`` when the training directory is missing."""
        if not os.path.exists(self.train_dir):
            raise RuntimeError(f"Train directory '{self.train_dir}' does not exist.")

    def load_data(self):
        """Scan the training directory and build the tracklet list.

        Returns:
            A list of ``(img_paths, pid, camid, tracklet_id)`` tuples.
            ``img_paths`` is ordered by frame index and ``tracklet_id`` is a
            running number starting at 1.  Tracklets are numbered in sorted
            ``(pid, camid, track id)`` order so the ids do not depend on the
            order in which the filesystem happens to list the files.

        Raises:
            ValueError: If a ``.jpg`` file name does not follow the expected
                naming scheme (propagated from ``_parse_filename``).
        """
        tracklet_dict = {}
        for root_dir, _, files in os.walk(self.train_dir):
            for fname in files:
                if not fname.lower().endswith('.jpg'):
                    continue
                pid, camid, track_id, frame_idx = self._parse_filename(fname)
                tracklet_dict.setdefault((pid, camid, track_id), []).append(
                    (frame_idx, os.path.join(root_dir, fname))
                )

        tracklets = []
        # Sorted keys make the running tracklet id reproducible across runs.
        for key in sorted(tracklet_dict):
            pid, camid, _ = key
            # Tuples sort by frame index first; the path only breaks ties.
            img_paths = [path for _, path in sorted(tracklet_dict[key])]
            if len(img_paths) >= self.min_seq_len:
                tracklets.append((img_paths, pid, camid, len(tracklets) + 1))
        return tracklets

    def _parse_filename(self, filename):
        """Split a frame file name into its numeric fields.

        Example: '0000C1T0001F009.jpg'
          pid=0000, camid=1, trackid=0001, frame_idx=009

        Returns:
            ``(pid, camid, track_id, frame_idx)`` as integers.

        Raises:
            ValueError: If any of the fixed-position fields is not numeric.
        """
        name, _ = os.path.splitext(filename)
        pid_str     = name[0:4]
        camid_str   = name[5]
        trackid_str = name[7:11]
        frameid_str = name[12:]
        return int(pid_str), int(camid_str), int(trackid_str), int(frameid_str)


class MarsLoader:
    """Load query and gallery tracklets from the test-split index files.

    Reads ``test_name.txt`` (one image file name per line),
    ``tracks_test_info.mat`` (key ``track_test_info``; rows are unpacked as
    ``start, end, pid, camid``) and ``query_IDX.mat`` (key ``query_IDX``;
    1-based row numbers of the query tracklets).  Every row that is not a
    query row is treated as a gallery row.
    """

    def __init__(self, root, min_seq_len=0):
        """Read the index files found directly under ``root``.

        Args:
            root: Directory that contains the three index files.
            min_seq_len: Tracklets with fewer frames than this are dropped.

        Raises:
            RuntimeError: If any of the three index files is missing.
        """
        self.root = root
        self.min_seq_len = min_seq_len
        self.test_name_path = os.path.join(root, 'test_name.txt')
        self.track_test_info_path = os.path.join(root, 'tracks_test_info.mat')
        self.query_IDX_path = os.path.join(root, 'query_IDX.mat')
        self._check_before_run()

        self.test_names = self._get_names(self.test_name_path)
        self.track_test = loadmat(self.track_test_info_path)['track_test_info']
        # ravel() (unlike squeeze()) keeps the index 1-D even when there is a
        # single query; astype(int) guards against unsigned/float storage
        # before converting the 1-based indices to 0-based.
        self.query_IDX = (
            loadmat(self.query_IDX_path)['query_IDX'].astype(int).ravel() - 1
        )

        self.track_query = self.track_test[self.query_IDX, :]
        query_idx_set = set(self.query_IDX.tolist())
        # Walk the rows in order so the gallery order is deterministic
        # (iteration order of a set difference is not guaranteed).
        self.gallery_IDX = [
            i for i in range(self.track_test.shape[0]) if i not in query_idx_set
        ]
        self.track_gallery = self.track_test[self.gallery_IDX, :]

    def _check_before_run(self):
        """Raise ``RuntimeError`` for the first required file that is missing."""
        required_files = [
            self.test_name_path,
            self.track_test_info_path,
            self.query_IDX_path
        ]
        for f in required_files:
            if not os.path.exists(f):
                raise RuntimeError(f"'{f}' does not exist.")

    def _get_names(self, fpath):
        """Return the whitespace-stripped lines of ``fpath`` as a list."""
        with open(fpath, 'r') as f:
            return [line.strip() for line in f]

    def load_data(self, home_dir='bbox_test'):
        """Build the query and gallery tracklet lists.

        Args:
            home_dir: Sub-directory of ``root`` that holds the test images.

        Returns:
            ``(query_data, gallery_data)``; each is a list of
            ``(img_paths, pid, camid)`` tuples.
        """
        query_data   = self._process_data(self.track_query, home_dir)
        gallery_data = self._process_data(self.track_gallery, home_dir)
        return query_data, gallery_data

    def _process_data(self, track_info, home_dir):
        """Convert tracklet rows into ``(img_paths, pid, camid)`` tuples.

        Rows whose pid is -1 or 0 are skipped (presumably junk/distractor
        tracks - confirm against the dataset documentation).  ``camid`` is
        shifted to be 0-based, and image paths are built as
        ``root/home_dir/<first 4 chars of name>/<name>``.
        """
        tracklets = []
        for row in track_info:
            # Plain ints keep the slice below valid whatever dtype the .mat used.
            start_idx, end_idx, pid, camid = (int(v) for v in row)
            if pid in (-1, 0):
                continue
            # start/end are 1-based and inclusive in the .mat file.
            img_names = self.test_names[start_idx - 1:end_idx]
            img_paths = [
                os.path.join(self.root, home_dir, name[:4], name)
                for name in img_names
            ]
            if len(img_paths) >= self.min_seq_len:
                tracklets.append((img_paths, pid, camid - 1))
        return tracklets


# ========== B) カスタムDataset ==========

class ImageDataset(Dataset):
    """Dataset over a list of image paths that yields ``(image, path)`` pairs.

    A sample that cannot be loaded or transformed is reported on stdout, its
    path is recorded in ``skipped_files`` and ``None`` is returned instead of
    raising, so a collate function can filter it out.
    """

    def __init__(self, image_paths, transform=None, is_aug=0, aug_id=0):
        """Store the configuration; no image is opened here.

        Args:
            image_paths: Sequence of image file paths.
            transform: Optional callable applied to the RGB PIL image.
            is_aug: Stored as-is; not used by this class itself.
            aug_id: Stored as-is; not used by this class itself.
        """
        self.skipped_files = []
        self.image_paths = image_paths
        self.transform = transform
        self.is_aug = is_aug
        self.aug_id = aug_id

    def __len__(self):
        """Number of paths, including ones that may later fail to load."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for ``idx``, or ``None`` on any failure."""
        image_path = self.image_paths[idx]
        try:
            picture = Image.open(image_path).convert("RGB")
            sample = self.transform(picture) if self.transform else picture
        except Exception as err:
            # Deliberate best-effort: one unreadable file must not stop a run.
            print(f"Warning: {err}")
            self.skipped_files.append(image_path)
            return None
        return sample, image_path


def skip_none_collate(batch):
    """Collate a batch after discarding ``None`` samples.

    Returns ``None`` when nothing is left, otherwise whatever
    ``default_collate`` produces for the remaining samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return default_collate(kept) if kept else None


# ========== C) バッチ単位でpklへ保存 ==========

@torch.no_grad()
def extract_embeddings_and_save_batchwise_pkl(
    dataloader, model, device,
    pkl_path,
    pids=None,
    camids=None,
    tids=None,
    is_aug=0,
    aug_id=0,
    append_mode=False
):
    """Run ``model`` over ``dataloader`` and pickle one record per sample.

    Each record is a dict with the keys ``file_name`` (image path),
    ``embedding`` (feature vector as a list), ``pid``, ``camid``,
    ``tracklet_id``, ``is_aug`` and ``aug_id``.

    Args:
        dataloader: Yields ``(images, file_paths)`` batches, or ``None`` for a
            batch in which every sample was dropped.  Must not shuffle.
        model: Callable mapping an image batch to a ``(B, feat_dim)`` tensor.
        device: Device the image batch is moved to before the forward pass.
        pkl_path: Output pickle file.
        pids, camids, tids: Per-sample metadata in dataset order.  A list left
            as ``None`` is recorded as ``None`` in every record.
        is_aug, aug_id: Stored verbatim in every record.
        append_mode: When true and ``pkl_path`` exists, the new records are
            appended to the list already stored there.

    Returns:
        None.  Nothing is written when no sample could be processed.

    Metadata is aligned by file path when the dataset exposes ``image_paths``,
    so samples dropped by the collate function do not shift pid/camid/tid onto
    the wrong image.  Without that attribute the metadata is consumed
    sequentially.
    """
    total_samples = len(dataloader.dataset)

    # Paths in dataset order; used to find each returned sample's true index.
    dataset_paths = getattr(dataloader.dataset, "image_paths", None)
    if dataset_paths is not None:
        dataset_paths = list(dataset_paths)

    def _meta(values, index):
        """Look up one metadata value, tolerating an omitted list."""
        return None if values is None else values[index]

    new_rows = []
    cursor = 0  # dataset index expected for the next returned sample
    for batch_idx, batch in enumerate(dataloader, start=1):
        if batch is None:
            # Every sample of this batch failed to load; keep going.
            continue
        images, file_paths = batch
        feats = model(images.to(device)).cpu().numpy()  # shape=(B, feat_dim)
        for path, feat in zip(file_paths, feats):
            sample_idx = cursor
            if dataset_paths is not None:
                try:
                    # Search forward only: order is preserved, so skipped
                    # samples simply sit between cursor and the match.
                    sample_idx = dataset_paths.index(path, cursor)
                except ValueError:
                    pass  # unknown path: fall back to sequential alignment
            new_rows.append({
                "file_name": path,
                "embedding": feat.tolist(),
                "pid": _meta(pids, sample_idx),
                "camid": _meta(camids, sample_idx),
                "tracklet_id": _meta(tids, sample_idx),
                "is_aug": is_aug,
                "aug_id": aug_id
            })
            cursor = sample_idx + 1
        print(f"Processed batch={batch_idx}, total={len(new_rows)}/{total_samples}")

    if not new_rows:
        print("No data in dataloader.")
        return

    all_data = new_rows
    # Append to the existing file content if requested.
    if append_mode and os.path.exists(pkl_path):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point pkl_path at files produced by this function.
        with open(pkl_path, "rb") as f_in:
            existing_data = pickle.load(f_in)
        all_data = existing_data + new_rows

    with open(pkl_path, "wb") as f_out:
        pickle.dump(all_data, f_out)
    print(f"Saved PKL: {pkl_path}, total written={len(new_rows)}")


# ========== D) メイン: ===================
# ========== D) Main script ==========
# Extracts ResNet50 embeddings for the train, query and gallery splits and
# writes them to pickle files, one file per PID range ("chunk").
if __name__ == "__main__":
    root_path = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS"
    max_pid = 3000        # tracklets whose pid exceeds this are dropped
    batch_size = 64

    def flatten_tracklets(tracklets):
        """Expand (img_paths, pid, camid, tid) tracklets into four parallel
        per-image lists: paths, pids, camids, tids."""
        samples = [
            (path, pid, camid, tid)
            for (img_paths, pid, camid, tid) in tracklets
            for path in img_paths
        ]
        return (
            [s[0] for s in samples],
            [s[1] for s in samples],
            [s[2] for s in samples],
            [s[3] for s in samples],
        )

    def filter_and_number(tracklets):
        """Keep (img_paths, pid, camid) tracklets with pid <= max_pid and
        attach a running tracklet id that starts at 1."""
        kept = [t for t in tracklets if t[1] <= max_pid]
        return [
            (img_paths, pid, camid, tid)
            for tid, (img_paths, pid, camid) in enumerate(kept, start=1)
        ]

    #--------------------------------
    # (1) Load the train data and flatten it to per-image lists
    #--------------------------------
    train_loader_obj = MarsTrainLoader(root=root_path, home_dir='bbox_train/bbox_train')
    train_data = train_loader_obj.load_data()
    print(f"[Train] Loaded tracklets: {len(train_data)}")

    # Keep only pid <= max_pid (3000).
    train_data = [t for t in train_data if t[1] <= max_pid]
    print(f"[Train] tracklets after filtering: {len(train_data)}")

    train_image_paths, train_pids, train_camids, train_tids = flatten_tracklets(train_data)
    print(f"Final train images: {len(train_image_paths)}")

    #--------------------------------
    # (2) Query & Gallery
    #--------------------------------
    test_loader_obj = MarsLoader(root=root_path)
    query_data, gallery_data = test_loader_obj.load_data(home_dir='bbox_test/bbox_test')

    query_image_paths, query_pids, query_camids, query_tids = flatten_tracklets(
        filter_and_number(query_data)
    )
    gallery_image_paths, gallery_pids, gallery_camids, gallery_tids = flatten_tracklets(
        filter_and_number(gallery_data)
    )

    #--------------------------------
    # (3) Model setup (ResNet50 used as a feature extractor)
    #--------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.resnet50(pretrained=False)
    model.fc = nn.Identity()
    checkpoint = torch.load("C:/Users/sugie/PycharmProjects/pythonProject10/MARS/resnet50_msmt17_combineall_256x128_amsgrad_ep150_stp60_lr0.0015_b64_fb10_softmax_labelsmooth_flip_jitter.pth", map_location=device)
    state_dict = checkpoint.get("state_dict", checkpoint)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates key differences; report them so that a checkpoint
    # which matches nothing does not silently leave random weights in place.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint key mismatch - "
            f"missing={len(load_result.missing_keys)}, "
            f"unexpected={len(load_result.unexpected_keys)}"
        )
    model.to(device).eval()

    #--------------------------------
    # (4) Transform definitions
    #--------------------------------
    transform_no_aug = transforms.Compose([
        transforms.Resize((256, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    def get_aug_transform():
        """Return a fresh random-augmentation pipeline (crop, flip, colour
        jitter, normalisation, random erasing)."""
        return transforms.Compose([
            transforms.RandomResizedCrop(
                size=(256, 128),
                scale=(0.8, 1.0),
                ratio=(0.75, 1.3333)
            ),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
            ),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.5, scale=(0.02, 0.25))
        ])

    #--------------------------------
    # (5) Chunked export shared by train / query / gallery
    #--------------------------------
    def export_chunks(split_name, image_paths, pids, camids, tids, chunk_size, n_aug):
        """Write one pickle per PID range [k*chunk_size, (k+1)*chunk_size - 1].

        Each file first receives the un-augmented embeddings (aug_id 0) and
        then ``n_aug`` augmented passes (aug_id 1..n_aug) appended to it.
        PID ranges without any sample produce no file.
        """
        # Bucket sample indices by chunk once instead of rescanning every
        # pid for every chunk.
        buckets = {}
        for i, pid in enumerate(pids):
            buckets.setdefault(pid // chunk_size, []).append(i)

        for start_pid in range(0, max_pid + 1, chunk_size):
            chunk_indices = buckets.get(start_pid // chunk_size)
            if not chunk_indices:
                continue
            end_pid = start_pid + chunk_size
            pkl_path_chunk = f"{split_name}_embeddings_chunk_ReSNet_all_{start_pid}_{end_pid - 1}.pkl"

            chunk_image_paths = [image_paths[i] for i in chunk_indices]
            chunk_pids        = [pids[i]        for i in chunk_indices]
            chunk_camids      = [camids[i]      for i in chunk_indices]
            chunk_tids        = [tids[i]        for i in chunk_indices]

            # aug_i == 0 is the original image; 1..n_aug are augmented passes.
            for aug_i in range(0, n_aug + 1):
                is_aug = 1 if aug_i > 0 else 0
                transform = get_aug_transform() if is_aug else transform_no_aug
                dataset_chunk = ImageDataset(
                    chunk_image_paths,
                    transform,
                    is_aug=is_aug, aug_id=aug_i
                )
                loader_chunk = DataLoader(
                    dataset_chunk,
                    batch_size=batch_size,
                    shuffle=False,
                    collate_fn=skip_none_collate
                )
                extract_embeddings_and_save_batchwise_pkl(
                    dataloader=loader_chunk,
                    model=model,
                    device=device,
                    pkl_path=pkl_path_chunk,
                    pids=chunk_pids,
                    camids=chunk_camids,
                    tids=chunk_tids,
                    is_aug=is_aug,
                    aug_id=aug_i,
                    # The original pass overwrites; augmented passes append.
                    append_mode=bool(is_aug)
                )
            print(f"Done {split_name} chunk [PID {start_pid}..{end_pid - 1}]. PKL => {pkl_path_chunk}")

    #--------------------------------
    # (A) Train: one pkl per 50 PIDs, no augmentation passes
    #--------------------------------
    N_AUG_TRAIN = 0   # number of augmentation passes for train (change if needed)
    export_chunks("train", train_image_paths, train_pids, train_camids, train_tids,
                  chunk_size=50, n_aug=N_AUG_TRAIN)

    #--------------------------------
    # (B) Query: one pkl per 50 PIDs
    #--------------------------------
    N_AUG_QUERY = 3
    export_chunks("query", query_image_paths, query_pids, query_camids, query_tids,
                  chunk_size=50, n_aug=N_AUG_QUERY)

    #--------------------------------
    # (C) Gallery: one pkl per 25 PIDs
    #--------------------------------
    N_AUG_GAL = 3
    export_chunks("gallery", gallery_image_paths, gallery_pids, gallery_camids, gallery_tids,
                  chunk_size=25, n_aug=N_AUG_GAL)

    print("\nAll done.")


# 必要ライブラリ

In [1]:
import os
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image
from scipy.io import loadmat

ModuleNotFoundError: No module named 'scipy'

# データ読込（関数）

In [None]:
# ========== A) MarsTrainLoader & MarsLoader ==========

class MarsTrainLoader:
    """Collect training tracklets by scanning a directory tree of JPEG frames.

    Frames are grouped by the (pid, camid, track id) triple encoded in each
    file name; every group becomes one tracklet.
    """

    def __init__(self, root, home_dir='bbox_train', min_seq_len=0):
        """Remember the dataset location and verify that it exists.

        Args:
            root: Dataset root directory.
            home_dir: Sub-directory of ``root`` that is scanned recursively.
            min_seq_len: Tracklets with fewer frames than this are dropped.

        Raises:
            RuntimeError: If ``root/home_dir`` does not exist.
        """
        self.root = root
        self.home_dir = home_dir
        self.min_seq_len = min_seq_len
        self.train_dir = os.path.join(self.root, self.home_dir)
        self._check_before_run()

    def _check_before_run(self):
        """Raise ``RuntimeError`` when the training directory is missing."""
        if not os.path.exists(self.train_dir):
            raise RuntimeError(f"Train directory '{self.train_dir}' does not exist.")

    def load_data(self):
        """Scan the training directory and build the tracklet list.

        Returns:
            A list of ``(img_paths, pid, camid, tracklet_id)`` tuples.
            ``img_paths`` is ordered by frame index and ``tracklet_id`` is a
            running number starting at 1.  Tracklets are numbered in sorted
            ``(pid, camid, track id)`` order so the ids do not depend on the
            order in which the filesystem happens to list the files.

        Raises:
            ValueError: If a ``.jpg`` file name does not follow the expected
                naming scheme (propagated from ``_parse_filename``).
        """
        tracklet_dict = {}
        for root_dir, _, files in os.walk(self.train_dir):
            for fname in files:
                if not fname.lower().endswith('.jpg'):
                    continue
                pid, camid, track_id, frame_idx = self._parse_filename(fname)
                tracklet_dict.setdefault((pid, camid, track_id), []).append(
                    (frame_idx, os.path.join(root_dir, fname))
                )

        tracklets = []
        # Sorted keys make the running tracklet id reproducible across runs.
        for key in sorted(tracklet_dict):
            pid, camid, _ = key
            # Tuples sort by frame index first; the path only breaks ties.
            img_paths = [path for _, path in sorted(tracklet_dict[key])]
            if len(img_paths) >= self.min_seq_len:
                tracklets.append((img_paths, pid, camid, len(tracklets) + 1))
        return tracklets

    def _parse_filename(self, filename):
        """Split a frame file name into its numeric fields.

        Example: '0000C1T0001F009.jpg'
          pid=0000, camid=1, trackid=0001, frame_idx=009

        Returns:
            ``(pid, camid, track_id, frame_idx)`` as integers.

        Raises:
            ValueError: If any of the fixed-position fields is not numeric.
        """
        name, _ = os.path.splitext(filename)
        pid_str     = name[0:4]
        camid_str   = name[5]
        trackid_str = name[7:11]
        frameid_str = name[12:]
        return int(pid_str), int(camid_str), int(trackid_str), int(frameid_str)


class MarsLoader:
    """Load query and gallery tracklets from the test-split index files.

    Reads ``test_name.txt`` (one image file name per line),
    ``tracks_test_info.mat`` (key ``track_test_info``; rows are unpacked as
    ``start, end, pid, camid``) and ``query_IDX.mat`` (key ``query_IDX``;
    1-based row numbers of the query tracklets).  Every row that is not a
    query row is treated as a gallery row.
    """

    def __init__(self, root, min_seq_len=0):
        """Read the index files found directly under ``root``.

        Args:
            root: Directory that contains the three index files.
            min_seq_len: Tracklets with fewer frames than this are dropped.

        Raises:
            RuntimeError: If any of the three index files is missing.
        """
        self.root = root
        self.min_seq_len = min_seq_len
        self.test_name_path = os.path.join(root, 'test_name.txt')
        self.track_test_info_path = os.path.join(root, 'tracks_test_info.mat')
        self.query_IDX_path = os.path.join(root, 'query_IDX.mat')
        self._check_before_run()

        self.test_names = self._get_names(self.test_name_path)
        self.track_test = loadmat(self.track_test_info_path)['track_test_info']
        # ravel() (unlike squeeze()) keeps the index 1-D even when there is a
        # single query; astype(int) guards against unsigned/float storage
        # before converting the 1-based indices to 0-based.
        self.query_IDX = (
            loadmat(self.query_IDX_path)['query_IDX'].astype(int).ravel() - 1
        )

        self.track_query = self.track_test[self.query_IDX, :]
        query_idx_set = set(self.query_IDX.tolist())
        # Walk the rows in order so the gallery order is deterministic
        # (iteration order of a set difference is not guaranteed).
        self.gallery_IDX = [
            i for i in range(self.track_test.shape[0]) if i not in query_idx_set
        ]
        self.track_gallery = self.track_test[self.gallery_IDX, :]

    def _check_before_run(self):
        """Raise ``RuntimeError`` for the first required file that is missing."""
        required_files = [
            self.test_name_path,
            self.track_test_info_path,
            self.query_IDX_path
        ]
        for f in required_files:
            if not os.path.exists(f):
                raise RuntimeError(f"'{f}' does not exist.")

    def _get_names(self, fpath):
        """Return the whitespace-stripped lines of ``fpath`` as a list."""
        with open(fpath, 'r') as f:
            return [line.strip() for line in f]

    def load_data(self, home_dir='bbox_test'):
        """Build the query and gallery tracklet lists.

        Args:
            home_dir: Sub-directory of ``root`` that holds the test images.

        Returns:
            ``(query_data, gallery_data)``; each is a list of
            ``(img_paths, pid, camid)`` tuples.
        """
        query_data   = self._process_data(self.track_query, home_dir)
        gallery_data = self._process_data(self.track_gallery, home_dir)
        return query_data, gallery_data

    def _process_data(self, track_info, home_dir):
        """Convert tracklet rows into ``(img_paths, pid, camid)`` tuples.

        Rows whose pid is -1 or 0 are skipped (presumably junk/distractor
        tracks - confirm against the dataset documentation).  ``camid`` is
        shifted to be 0-based, and image paths are built as
        ``root/home_dir/<first 4 chars of name>/<name>``.
        """
        tracklets = []
        for row in track_info:
            # Plain ints keep the slice below valid whatever dtype the .mat used.
            start_idx, end_idx, pid, camid = (int(v) for v in row)
            if pid in (-1, 0):
                continue
            # start/end are 1-based and inclusive in the .mat file.
            img_names = self.test_names[start_idx - 1:end_idx]
            img_paths = [
                os.path.join(self.root, home_dir, name[:4], name)
                for name in img_names
            ]
            if len(img_paths) >= self.min_seq_len:
                tracklets.append((img_paths, pid, camid - 1))
        return tracklets


# ========== B) カスタムDataset ==========

class ImageDataset(Dataset):
    """Dataset over a list of image paths that yields ``(image, path)`` pairs.

    A sample that cannot be loaded or transformed is reported on stdout, its
    path is recorded in ``skipped_files`` and ``None`` is returned instead of
    raising, so a collate function can filter it out.
    """

    def __init__(self, image_paths, transform=None, is_aug=0, aug_id=0):
        """Store the configuration; no image is opened here.

        Args:
            image_paths: Sequence of image file paths.
            transform: Optional callable applied to the RGB PIL image.
            is_aug: Stored as-is; not used by this class itself.
            aug_id: Stored as-is; not used by this class itself.
        """
        self.skipped_files = []
        self.image_paths = image_paths
        self.transform = transform
        self.is_aug = is_aug
        self.aug_id = aug_id

    def __len__(self):
        """Number of paths, including ones that may later fail to load."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for ``idx``, or ``None`` on any failure."""
        image_path = self.image_paths[idx]
        try:
            picture = Image.open(image_path).convert("RGB")
            sample = self.transform(picture) if self.transform else picture
        except Exception as err:
            # Deliberate best-effort: one unreadable file must not stop a run.
            print(f"Warning: {err}")
            self.skipped_files.append(image_path)
            return None
        return sample, image_path


def skip_none_collate(batch):
    """Collate a batch after discarding ``None`` samples.

    Returns ``None`` when nothing is left, otherwise whatever
    ``default_collate`` produces for the remaining samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return default_collate(kept) if kept else None


# 特徴量抽出　（関数）

In [2]:
# ========== C) バッチ単位でpklへ保存 ==========

def _find_dataset_index(dataset_paths, file_path, cursor):
    """Return the first index >= ``cursor`` holding ``file_path``, else ``None``."""
    if dataset_paths is None:
        return None
    for idx in range(cursor, len(dataset_paths)):
        if dataset_paths[idx] == file_path:
            return idx
    return None


def _meta_at(values, idx):
    """Return ``values[idx]``, or ``None`` when no metadata list was supplied."""
    return None if values is None else values[idx]


@torch.no_grad()
def extract_embeddings_and_save_batchwise_pkl(
    dataloader, model, device,
    pkl_path,
    pids=None,
    camids=None,
    tids=None,
    is_aug=0,
    aug_id=0,
    append_mode=False
):
    """Extract embeddings batch by batch and store them in one pickle file.

    The file holds a single list with one dict per sample:
      - file_name: image path yielded by the dataloader
      - embedding: embedding vector as a plain list
      - pid, camid, tracklet_id, is_aug, aug_id

    All rows are accumulated in memory and written once at the end. With
    ``append_mode=True`` an existing pickle at ``pkl_path`` is loaded and the
    new rows are appended to its contents before rewriting the file.

    Args:
        dataloader: Iterable of ``(images, file_paths)`` batches with a
            ``dataset`` attribute supporting ``len()``. A batch may be
            ``None`` (every sample in it was skipped) and is then ignored.
        model: Callable mapping an image batch to a ``(B, feat_dim)`` tensor.
        device: Device the image batch is moved to before the forward pass.
        pkl_path: Output pickle path.
        pids, camids, tids: Per-sample metadata, indexed by the sample's
            position in the dataset. ``None`` stores ``None`` in every row.
        is_aug, aug_id: Stored verbatim in every row.
        append_mode: Extend an existing pickle instead of overwriting it.

    Metadata alignment: when ``dataloader.dataset`` exposes ``image_paths``,
    each yielded path is located in that list (scanning forward from the last
    match), so samples dropped by the dataset/collate function do not shift
    the metadata of later samples. This assumes the loader preserves dataset
    order (``shuffle=False``). Without ``image_paths``, or if a path is not
    found, the running count of emitted samples is used as the index.

    Nothing is written when no sample could be extracted.
    """
    total_samples = len(dataloader.dataset)
    dataset_paths = getattr(dataloader.dataset, "image_paths", None)

    all_data = []
    cursor = 0  # next dataset position that has not been matched yet
    for batch_idx, batch in enumerate(dataloader, start=1):
        if batch is None:
            # Every sample of this batch failed to load; later batches may
            # still contain data, so keep going.
            continue
        images, file_paths = batch
        feats = model(images.to(device)).cpu().numpy()  # shape=(B, feat_dim)

        for file_path, feat in zip(file_paths, feats):
            found = _find_dataset_index(dataset_paths, file_path, cursor)
            if found is None:
                meta_idx = len(all_data)
            else:
                meta_idx = found
                cursor = found + 1
            all_data.append({
                "file_name": file_path,
                "embedding": feat.tolist(),
                "pid": _meta_at(pids, meta_idx),
                "camid": _meta_at(camids, meta_idx),
                "tracklet_id": _meta_at(tids, meta_idx),
                "is_aug": is_aug,
                "aug_id": aug_id
            })

        print(f"Processed batch={batch_idx}, total={len(all_data)}/{total_samples}")

    if not all_data:
        print("No data in dataloader.")
        return

    written = len(all_data)

    # Prepend the contents of an existing file when appending.
    if append_mode and os.path.exists(pkl_path):
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # append_mode on pickles produced by this function.
        with open(pkl_path, "rb") as f_in:
            existing_data = pickle.load(f_in)
        all_data = existing_data + all_data

    with open(pkl_path, "wb") as f_out:
        pickle.dump(all_data, f_out)
    print(f"Saved PKL: {pkl_path}, total written={written}")

# コード実行

In [3]:
# ========== D) メイン: ===================
if __name__ == "__main__":
    root_path = "C:/Users/sugie/PycharmProjects/pythonProject10/MARS"

    # Only identities with pid <= MAX_PID are exported; the PID chunking in
    # the export step covers exactly the same range.
    MAX_PID = 3000
    BATCH_SIZE = 64

    def keep_pids_up_to(tracklets, max_pid):
        """Keep the tracklets whose pid (second tuple field) is <= max_pid."""
        return [tracklet for tracklet in tracklets if tracklet[1] <= max_pid]

    def number_tracklets(tracklets):
        """Add a sequential tracklet id (starting at 1) to (img_paths, pid, camid) tuples."""
        return [
            (img_paths, pid, camid, tid)
            for tid, (img_paths, pid, camid) in enumerate(tracklets, start=1)
        ]

    def flatten_tracklets(tracklets):
        """Expand (img_paths, pid, camid, tid) tracklets into per-image parallel lists."""
        image_paths, pids, camids, tids = [], [], [], []
        for img_paths, pid, camid, tid in tracklets:
            for p in img_paths:
                image_paths.append(p)
                pids.append(pid)
                camids.append(camid)
                tids.append(tid)
        return image_paths, pids, camids, tids

    #--------------------------------
    # (1) Load the train data and flatten it to one entry per image
    #--------------------------------
    train_loader_obj = MarsTrainLoader(root=root_path, home_dir='bbox_train/bbox_train')
    train_data = train_loader_obj.load_data()
    print(f"[Train] Loaded tracklets: {len(train_data)}")

    train_data = keep_pids_up_to(train_data, MAX_PID)
    print(f"[Train] tracklets after filtering: {len(train_data)}")

    # Train tracklets already carry their tracklet id as the fourth field.
    train_image_paths, train_pids, train_camids, train_tids = flatten_tracklets(train_data)
    print(f"Final train images: {len(train_image_paths)}")

    #--------------------------------
    # (2) Query & Gallery
    #--------------------------------
    test_loader_obj = MarsLoader(root=root_path)
    query_data, gallery_data = test_loader_obj.load_data(home_dir='bbox_test/bbox_test')

    # Query/gallery tracklets have no id of their own; they are numbered
    # 1, 2, ... per split after filtering.
    query_data = keep_pids_up_to(query_data, MAX_PID)
    query_image_paths, query_pids, query_camids, query_tids = flatten_tracklets(
        number_tracklets(query_data)
    )

    gallery_data = keep_pids_up_to(gallery_data, MAX_PID)
    gallery_image_paths, gallery_pids, gallery_camids, gallery_tids = flatten_tracklets(
        number_tracklets(gallery_data)
    )

    #--------------------------------
    # (3) Model setup (ResNet50)
    #--------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.resnet50(pretrained=False)
    # Replace the final layer with an identity so the forward pass returns
    # the features that feed it instead of class scores.
    model.fc = nn.Identity()
    # NOTE: torch.load unpickles the checkpoint; only load files from a trusted source.
    checkpoint = torch.load("C:/Users/sugie/PycharmProjects/pythonProject10/MARS/resnet50_msmt17_combineall_256x128_amsgrad_ep150_stp60_lr0.0015_b64_fb10_softmax_labelsmooth_flip_jitter.pth", map_location=device)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip any "module." from the key names (presumably a DataParallel
    # prefix; confirm against the checkpoint).
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # silently fails to populate the backbone is visible in the log.
    load_result = model.load_state_dict(state_dict, strict=False)
    print(f"[Model] missing keys: {len(load_result.missing_keys)}, "
          f"unexpected keys: {len(load_result.unexpected_keys)}")
    model.to(device).eval()

    #--------------------------------
    # (4) Transform definitions
    #--------------------------------
    transform_no_aug = transforms.Compose([
        transforms.Resize((256, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    def get_aug_transform():
        """Return a fresh random-augmentation pipeline producing 256x128 tensors."""
        return transforms.Compose([
            transforms.RandomResizedCrop(
                size=(256, 128),
                scale=(0.8, 1.0),
                ratio=(0.75, 1.3333)
            ),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
            ),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.5, scale=(0.02, 0.25))
        ])

    #--------------------------------
    # (5) Export: one pkl per PID chunk and split
    #--------------------------------
    def export_pass(pkl_path, image_paths, pids, camids, tids,
                    transform, is_aug, aug_id, append_mode):
        """Embed ``image_paths`` once with ``transform`` and write/append to ``pkl_path``."""
        dataset = ImageDataset(image_paths, transform, is_aug=is_aug, aug_id=aug_id)
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            collate_fn=skip_none_collate
        )
        extract_embeddings_and_save_batchwise_pkl(
            dataloader=loader,
            model=model,
            device=device,
            pkl_path=pkl_path,
            pids=pids,
            camids=camids,
            tids=tids,
            is_aug=is_aug,
            aug_id=aug_id,
            append_mode=append_mode
        )

    def export_split_in_pid_chunks(split_name, image_paths, pids, camids, tids,
                                   chunk_size, n_aug):
        """Export one split in PID ranges of ``chunk_size``, one pkl per range.

        For every range [start_pid, start_pid + chunk_size - 1] that contains
        samples, the un-augmented embeddings are written first (overwriting
        the file) and ``n_aug`` augmented passes are appended afterwards.
        """
        for start_pid in range(0, MAX_PID + 1, chunk_size):
            end_pid = start_pid + chunk_size
            pkl_path = f"{split_name}_embeddings_chunk_ReSNet_all_{start_pid}_{end_pid - 1}.pkl"

            chunk_indices = [i for i, pid in enumerate(pids) if start_pid <= pid < end_pid]
            if not chunk_indices:
                continue

            chunk_image_paths = [image_paths[i] for i in chunk_indices]
            chunk_pids        = [pids[i]        for i in chunk_indices]
            chunk_camids      = [camids[i]      for i in chunk_indices]
            chunk_tids        = [tids[i]        for i in chunk_indices]

            # ---- (A) original images ----
            export_pass(
                pkl_path, chunk_image_paths, chunk_pids, chunk_camids, chunk_tids,
                transform_no_aug, is_aug=0, aug_id=0, append_mode=False
            )

            # ---- (B) augmented passes (n_aug times) ----
            for aug_i in range(1, n_aug + 1):
                export_pass(
                    pkl_path, chunk_image_paths, chunk_pids, chunk_camids, chunk_tids,
                    get_aug_transform(), is_aug=1, aug_id=aug_i, append_mode=True
                )
            print(f"Done {split_name} chunk [PID {start_pid}..{end_pid - 1}]. PKL => {pkl_path}")

    # Number of augmented passes per split (change as needed).
    N_AUG_TRAIN = 0
    N_AUG_QUERY = 3
    N_AUG_GAL = 3

    # Train: chunks of 50 PIDs (0-49, 50-99, ...).
    export_split_in_pid_chunks(
        "train", train_image_paths, train_pids, train_camids, train_tids,
        chunk_size=50, n_aug=N_AUG_TRAIN
    )

    # Query: chunks of 50 PIDs.
    export_split_in_pid_chunks(
        "query", query_image_paths, query_pids, query_camids, query_tids,
        chunk_size=50, n_aug=N_AUG_QUERY
    )

    # Gallery: chunks of 25 PIDs.
    export_split_in_pid_chunks(
        "gallery", gallery_image_paths, gallery_pids, gallery_camids, gallery_tids,
        chunk_size=25, n_aug=N_AUG_GAL
    )

    print("\nAll done.")


NameError: name 'MarsTrainLoader' is not defined