In [1]:
# 1. 환경 설정 및 경로, Import

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
import pandas as pd
import numpy as np
import random
import os
import sys
from tqdm import tqdm
from PIL import Image
from typing import Tuple, Dict, List
from google.colab import drive
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data.dataloader import default_collate


In [2]:
# 마운트 (공유 드라이브 경로가 MyDrive에 바로 연결되어 있다고 가정)
drive.mount('/content/drive')

# --- 경로 설정 ---
# 이 경로가 dataset.py, transforms.py, checkpoints 폴더가 있는 위치여야 합니다.
MODULE_PATH = "/content/drive/MyDrive/2025CV"
sys.path.append(MODULE_PATH)

Mounted at /content/drive


In [3]:
# Import Custom Modules (dataset.py, transforms.py)
# BBox 크롭 로직과 Transforms 정의를 가져옵니다.
from dataset_jw import DeepFashionC2S
from transforms import train_transform, val_transform

In [4]:
# --- Hyperparameters (팀원 간 통일 필수) ---
EXPERIMENT_SEED = 42
EMBEDDING_DIM = 128  # 128로 고정
LEARNING_RATE = 1e-4
TRIPLET_MARGIN = 0.5 # Online Semi-Hard Triplet Loss 마진 값
BATCH_SIZE = 32
PATIENCE = 5         # Early Stopping Patience (5 Epoch 동안 개선 없으면 중단)
MAX_EPOCHS = 40      # 최대 학습 Epoch 수
CHECKPOINT_DIR = os.path.join(MODULE_PATH, "checkpoints_C")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# 재현성 확보를 위한 시드 고정
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    return torch.Generator().manual_seed(seed)

generator = set_seed(EXPERIMENT_SEED)

Using device: cuda


In [5]:
# --- CSV 파일 로드 (샘플링된 CSV 사용) ---
CSV_PATH_LIGHT = os.path.join(MODULE_PATH, "meta_c2s_10_2_2_sampling_ID.csv")

In [6]:
import os
import shutil
import pandas as pd

# --- 경로 설정 확인 ---

DRIVE_IMG_ROOT = os.path.join(MODULE_PATH, "Images") # 원본 이미지 루트 경로 (Drive)
LOCAL_IMG_ROOT = "/content/Images"                   # 타겟 이미지 루트 경로 (Local)

In [7]:
# =========================================================
# 이게 40분 걸려요
# CSV 기반 선택적 이미지 로컬 복사 및 I/O 최적화
# =========================================================
print("CSV 기반 선택적 이미지 로컬 런타임 복사 시작...")

# 1. CSV 파일 로드
try:
    df_light = pd.read_csv(CSV_PATH_LIGHT)
except FileNotFoundError:
    print(f"오류: CSV 파일을 찾을 수 없습니다. 경로를 확인하세요: {CSV_PATH_LIGHT}")
    exit()

# 2. 필요한 모든 이미지 경로 추출
# consumer_path와 shop_path 열에서 유니크한 경로만 추출합니다.
required_paths = pd.concat([df_light['consumer_path'], df_light['shop_path']]).unique()
print(f"총 {len(required_paths)}개의 유니크한 이미지 파일을 복사합니다.")

# 3. 로컬 타겟 폴더 생성
os.makedirs(LOCAL_IMG_ROOT, exist_ok=True)

# 4. 파일 복사 및 폴더 구조 유지
copied_count = 0
for relative_path in required_paths:
    # 원본 파일 경로 (Drive)
    source_file_path = os.path.join(DRIVE_IMG_ROOT, relative_path)

    # 타겟 파일 경로 (Local)
    target_file_path = os.path.join(LOCAL_IMG_ROOT, relative_path)

    # 타겟 디렉토리 생성 (예: /content/Images/img/TOPS/Summer_ 에 필요한 폴더 생성)
    target_dir = os.path.dirname(target_file_path)
    os.makedirs(target_dir, exist_ok=True)

    # 파일 복사
    try:
        if not os.path.exists(target_file_path):
             shutil.copy2(source_file_path, target_file_path)
             copied_count += 1
    except FileNotFoundError:
        # CSV에 경로가 있지만 실제 Drive에 파일이 없는 경우 건너뜁니다.
        print(f"[경고] 원본 파일이 Drive에 없습니다. 건너뜀: {source_file_path}")
    except Exception as e:
        print(f"[오류] 복사 중 예외 발생 ({relative_path}): {e}")

print(f"✅ 로컬 복사 완료. 총 {copied_count}개의 파일 복사됨.")

CSV 기반 선택적 이미지 로컬 런타임 복사 시작...
총 10546개의 유니크한 이미지 파일을 복사합니다.
✅ 로컬 복사 완료. 총 10546개의 파일 복사됨.


In [8]:
# =========================================================
# 3. IMG_ROOT_DIR 변수를 로컬 경로로 변경
# =========================================================
# 기존 경로 변수를 새로운 로컬 경로로 덮어씁니다.
IMG_ROOT_DIR = LOCAL_IMG_ROOT

print(f"새로운 이미지 루트 경로: {IMG_ROOT_DIR}")

새로운 이미지 루트 경로: /content/Images


In [9]:
# ============================================
# item_id 문자열 → 숫자 라벨 변환 매핑 생성
# ============================================

df_full = pd.read_csv(CSV_PATH_LIGHT)
unique_ids = df_full["item_id"].unique()

id2label = {id_str: idx for idx, id_str in enumerate(unique_ids)}
print("총 unique item_id 개수:", len(id2label))


총 unique item_id 개수: 1467


In [10]:
# -------------------
# CLAHE & Gray World
# -------------------
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms

def apply_clahe_pil(img: Image.Image) -> Image.Image:
    arr = np.array(img)
    yuv = cv2.cvtColor(arr, cv2.COLOR_RGB2YUV)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    yuv[:,:,0] = clahe.apply(yuv[:,:,0])
    out = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)
    return Image.fromarray(out)

def apply_grayworld_pil(img: Image.Image) -> Image.Image:
    img_arr = np.array(img).astype(np.float32)
    mean_channels = img_arr.mean(axis=(0, 1), keepdims=True)
    gray_mean = mean_channels.mean()
    img_arr = img_arr * (gray_mean / (mean_channels + 1e-6))
    img_arr = np.clip(img_arr, 0, 255).astype(np.uint8)
    return Image.fromarray(img_arr)

# -------------------
# Transform Builder
# -------------------
def build_transforms(domain_mode="baseline", train=True, size=224):
    ops = []

    # resizing + augmentation
    ops.append(transforms.Resize((size, size)))
    if train:
        ops.append(transforms.RandomHorizontalFlip(0.5))
        ops.append(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))

    # domain normalization
    if domain_mode == "grayworld":
        ops.append(transforms.Lambda(apply_grayworld_pil))
    elif domain_mode == "clahe":
        ops.append(transforms.Lambda(apply_clahe_pil))
    elif domain_mode == "both":
        ops.append(transforms.Lambda(apply_clahe_pil))
        ops.append(transforms.Lambda(apply_grayworld_pil))
    # baseline: 아무것도 적용 안 함

    # tensor + normalize
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize(mean=[0.485,0.456,0.406],
                                    std=[0.229,0.224,0.225]))

    return transforms.Compose(ops)


In [11]:
# -------------------
# Dataset
# -------------------
class DeepFashionC2S(torch.utils.data.Dataset):
    def __init__(self, csv_path, img_root, transform=None, split='train'):
        self.df = pd.read_csv(csv_path)
        self.df = self.df[self.df['split']==split].reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def load_crop(self, img_path, x1, y1, x2, y2):
        full_path = os.path.join(self.img_root, img_path)
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")
        img = Image.open(full_path).convert("RGB")
        return img.crop((x1, y1, x2, y2))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        try:
            cons_img = self.load_crop(row['consumer_path'], row['cons_x1'], row['cons_y1'], row['cons_x2'], row['cons_y2'])
            shop_img = self.load_crop(row['shop_path'], row['shop_x1'], row['shop_y1'], row['shop_x2'], row['shop_y2'])
            if self.transform:
                cons_img = self.transform(cons_img)
                shop_img = self.transform(shop_img)
            return {"consumer": cons_img, "shop": shop_img, "item_id": row["item_id"]}
        except FileNotFoundError:
            print(f"[WARNING] File missing at idx {idx}. Skipping.")
            return None

In [12]:
# -------------------
# Feature Embedding (EfficientNet-B3)
# -------------------
class FeatureEmbedding(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        base_model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
        num_ftrs = base_model.classifier[1].in_features
        self.feature_extractor = base_model.features
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.final_fc = nn.Linear(num_ftrs, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.final_fc(x)
        x = self.bn(x)
        x = F.normalize(x, p=2, dim=1)
        return x.to(DEVICE)

In [13]:
# -------------------
# Loss: Batch-Hard Triplet
# -------------------
def pairwise_distance_sq(embeddings):
    dot_product = torch.matmul(embeddings, embeddings.t())
    square_norm = torch.diag(dot_product)
    distances = square_norm.unsqueeze(0) - 2*dot_product + square_norm.unsqueeze(1)
    distances[distances<0] = 0
    return distances

def batch_hard_triplet_loss(embeddings, labels, margin):
    pair_dist = pairwise_distance_sq(embeddings)
    labels_eq = labels.unsqueeze(0).eq(labels.unsqueeze(1))
    dist_ap = pair_dist.clone(); dist_ap[~labels_eq] = float('-inf')
    dist_an = pair_dist.clone(); dist_an[labels_eq] = float('inf')
    hardest_positive_dist, _ = dist_ap.max(dim=1)
    hardest_negative_dist, _ = dist_an.min(dim=1)
    triplet_loss = torch.clamp(hardest_positive_dist - hardest_negative_dist + margin, min=0.0)
    num_hard_triplets = triplet_loss.gt(1e-16).float().sum()
    return triplet_loss.sum()/num_hard_triplets if num_hard_triplets>0 else (embeddings*0).sum()

In [14]:
# -------------------
# Custom collate_fn
# -------------------
def custom_collate_fn(batch: List[Dict]):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return default_collate(batch)

In [15]:
# -------------------
# Recall@K
# -------------------
def calculate_recall_at_k(model, dataloader, device, ks=[1,5,10]):
    model.eval()
    all_query, all_gallery, all_labels = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            consumer_imgs = batch["consumer"]
            shop_imgs = batch["shop"]
            item_ids = batch["item_id"]
            if isinstance(item_ids, (list, tuple)):
                labels_tensor = torch.tensor([id2label[i] for i in item_ids], dtype=torch.long)
            else:
                labels_tensor = torch.tensor([id2label[item_ids]], dtype=torch.long)
            query_embs = model(consumer_imgs.to(device)).cpu().numpy()
            gallery_embs = model(shop_imgs.to(device)).cpu().numpy()
            all_query.append(query_embs)
            all_gallery.append(gallery_embs)
            all_labels.append(labels_tensor.cpu().numpy())
    query_embs = np.concatenate(all_query, axis=0)
    gallery_embs = np.concatenate(all_gallery, axis=0)
    gallery_labels = np.concatenate(all_labels, axis=0)
    sims = cosine_similarity(query_embs, gallery_embs)
    recalls = {}
    for k in ks:
        topk_idx = np.argsort(-sims, axis=1)[:, :k]
        correct = sum([gallery_labels[i] in gallery_labels[topk_idx[i]] for i in range(len(gallery_labels))])
        recalls[f'R@{k}'] = correct / len(gallery_labels)
    return recalls

In [16]:
# -------------------
# Train Loop
# -------------------
def get_checkpoint_paths():
    return os.path.join(CHECKPOINT_DIR, "checkpoint.pth"), os.path.join(CHECKPOINT_DIR, "best.pth")

def save_checkpoint(model, optimizer, epoch, best_val_metric, patience_count, filename):
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save({'epoch':epoch,'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'best_val_metric':best_val_metric,
                'patience_count':patience_count}, filename)

def load_checkpoint(model, optimizer, filename):
    if not os.path.exists(filename):
        return 0, 0.0, 0
    checkpoint = torch.load(filename,map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']+1, checkpoint['best_val_metric'], checkpoint['patience_count']

def train_model(train_dl, val_dl, experiment_name):

    # -------------------
    # Prepare checkpoint path
    # -------------------
    ckpt_dir = os.path.join(CHECKPOINT_DIR, experiment_name)
    os.makedirs(ckpt_dir, exist_ok=True)

    checkpoint_file = os.path.join(ckpt_dir, "checkpoint.pth")
    best_file = os.path.join(ckpt_dir, "best.pth")
    history_file = os.path.join(ckpt_dir, "history.csv")


    # -------------------
    # Load previous history (APPEND mode)
    # -------------------
    if os.path.exists(history_file):
        old_history = pd.read_csv(history_file)
        history = old_history.to_dict("records")
        print(f"Loaded previous history: {len(history)} records")
    else:
        history = []
        print("Starting new history file")

    # -------------------
    # Model & Optimizer
    # -------------------
    model = FeatureEmbedding(EMBEDDING_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    start_epoch, best_val_metric, patience_counter = load_checkpoint(model, optimizer, checkpoint_file)

    for epoch in range(start_epoch, MAX_EPOCHS):

        model.train()
        total_loss = 0

        for batch in tqdm(train_dl, leave=False):
            consumer_imgs = batch["consumer"].to(DEVICE)
            shop_imgs = batch["shop"].to(DEVICE)
            item_ids = batch["item_id"]

            if isinstance(item_ids, (list, tuple)):
                labels = torch.tensor([id2label[i] for i in item_ids], dtype=torch.long)
            else:
                labels = torch.tensor([id2label[item_ids]], dtype=torch.long)

            labels = labels.to(DEVICE)

            all_imgs = torch.cat([consumer_imgs, shop_imgs], dim=0)
            all_labels = torch.cat([labels, labels], dim=0)

            optimizer.zero_grad()
            embeddings = model(all_imgs)
            loss = batch_hard_triplet_loss(embeddings, all_labels, TRIPLET_MARGIN)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dl)

        # -------------------
        # Validation
        # -------------------
        val_recalls = calculate_recall_at_k(model, val_dl, DEVICE, ks=[1,5,10])
        val_metric = val_recalls["R@5"]

        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            "val_R@1": val_recalls["R@1"],
            "val_R@5": val_recalls["R@5"],
            "val_R@10": val_recalls["R@10"],
        })

        # Save full history (overwrite but includes appended records)
        pd.DataFrame(history).to_csv(history_file, index=False)

        # -------------------
        # Check improvement
        # -------------------
        if val_metric > best_val_metric:
            best_val_metric = val_metric
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, best_val_metric, patience_counter, best_file)
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Save intermediate checkpoint
        save_checkpoint(model, optimizer, epoch, best_val_metric, patience_counter, checkpoint_file)


    # -------------------
    # Load Best Model
    # -------------------
    load_checkpoint(model, optimizer, best_file)

    final_recalls = calculate_recall_at_k(model, val_dl, DEVICE, ks=[1,5,10])
    print(f"Final Best R@1: {final_recalls['R@1']:.4f}, "
          f"R@5: {final_recalls['R@5']:.4f}, "
          f"R@10: {final_recalls['R@10']:.4f}")

    # 이걸 실험 루프에서 받아서 저장함
    return final_recalls["R@1"], final_recalls["R@5"], final_recalls["R@10"]

In [17]:
NUM_WORKERS = 2

In [18]:
# DOMAIN_MODES = ["none", "grayworld", "clahe", "both"]
DOMAIN_MODES = ["both"]

In [19]:
results = {}

for mode in DOMAIN_MODES:
    print(f"\n===== Running: {mode} =====")

    train_transform = build_transforms(domain_mode=mode, train=True)
    val_transform   = build_transforms(domain_mode=mode, train=False)

    train_ds = DeepFashionC2S(csv_path=CSV_PATH_LIGHT, img_root=IMG_ROOT_DIR, transform=train_transform, split="train")
    val_ds   = DeepFashionC2S(csv_path=CSV_PATH_LIGHT, img_root=IMG_ROOT_DIR, transform=val_transform, split="val")

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          collate_fn=custom_collate_fn)
    val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          collate_fn=custom_collate_fn)

    best_r1, best_r5, best_r10 = train_model(
        train_dl, val_dl,
        experiment_name=f"domain_{mode}"
    )

    results[mode] = {
        "R1": best_r1,
        "R5": best_r5,
        "R10": best_r10
    }


===== Running: both =====
Loaded previous history: 16 records
Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth


100%|██████████| 47.2M/47.2M [00:00<00:00, 131MB/s]
100%|██████████| 35/35 [00:12<00:00,  2.86it/s]
100%|██████████| 35/35 [00:12<00:00,  2.84it/s]
100%|██████████| 35/35 [00:12<00:00,  2.77it/s]
100%|██████████| 35/35 [00:12<00:00,  2.83it/s]
100%|██████████| 35/35 [00:12<00:00,  2.80it/s]
100%|██████████| 35/35 [00:12<00:00,  2.81it/s]
100%|██████████| 35/35 [00:12<00:00,  2.83it/s]
100%|██████████| 35/35 [00:11<00:00,  2.95it/s]


Early stopping at epoch 24


100%|██████████| 35/35 [00:12<00:00,  2.81it/s]


Final Best R@1: 0.6117, R@5: 0.6782, R@10: 0.7284


In [24]:
base_dir = "/content/drive/MyDrive/2025CV/checkpoints_C"
modes = ["domain_none", "domain_grayworld", "domain_clahe", "domain_both"]

best_results = []

for mode in modes:
  csv_path = os.path.join(base_dir, mode, "history.csv")
  df = pd.read_csv(csv_path)

  best_r1_epoch = df["val_R@1"].idxmax()
  best_r5_epoch = df["val_R@10"].idxmax()
  best_r10_epoch = df["val_R@10"].idxmax()

  best_results.append({
      "mode": mode,
      "best_R1_epoch": int(df.loc[best_r1_epoch, "epoch"]),
      "best_R1_value": float(df.loc[best_r1_epoch, "val_R@1"]),
      "best_R5_epoch": int(df.loc[best_r5_epoch, "epoch"]),
      "best_R5_value": float(df.loc[best_r5_epoch, "val_R@5"]),
      "best_R10_epoch": int(df.loc[best_r10_epoch, "epoch"]),
      "best_R10_value": float(df.loc[best_r10_epoch, "val_R@10"]),
  })

pd.DataFrame(best_results)

Unnamed: 0,mode,best_R1_epoch,best_R1_value,best_R5_epoch,best_R5_value,best_R10_epoch,best_R10_value
0,domain_none,7,0.596171,9,0.644485,9,0.696445
1,domain_grayworld,11,0.610757,12,0.682771,12,0.736554
2,domain_clahe,20,0.602552,18,0.672744,18,0.731085
3,domain_both,19,0.611668,15,0.669098,15,0.729262
