In [None]:
import os

# Checkpoint files left behind by earlier runs that should be removed before retraining.
# NOTE(review): the training cell further down writes fusion_checkpoint_v2.pt and
# fusion_best_v2.pt, which are not in this list — confirm whether they should be too.
checkpoints_to_delete = [
    "/kaggle/working/fusion_best.pt",
    "/kaggle/working/fusion_checkpoint.pt",
]

for checkpoint in checkpoints_to_delete:
    # Report and move on when there is nothing to delete.
    if not os.path.exists(checkpoint):
        print(f"File not found: {checkpoint}")
        continue
    os.remove(checkpoint)
    print(f"Deleted: {checkpoint}")

print("Cleanup complete! Ready for fresh training.")

In [None]:

import os
import sys
import math
import time
import json
import random
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.amp import autocast, GradScaler

import torchvision
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.utils.class_weight import compute_class_weight

# -----------------------
# Config
# -----------------------
SEED = 42  # fed to seed_everything() and used as random_state for the train/val split
NUM_EPOCHS = 12  # upper bound on epochs; early stopping can end training sooner
BATCH_SIZE = 32  # used by both the training and validation DataLoader
NUM_WORKERS = 2  # DataLoader worker count
IMAGE_SIZE = 224  # side length of the square image produced by the transforms
MAX_TFIDF_FEATURES = 5000  # max_features of the TF-IDF vectorizer
TEXT_NGRAMS = (1, 3)  # unigrams to trigrams
LR = 2e-4  # AdamW learning rate; the scheduler's eta_min is LR * 0.1
WEIGHT_DECAY = 1e-4  # AdamW weight decay
LABEL_SMOOTHING = 0.1  # forwarded to the loss
EARLY_STOP_PATIENCE = 4  # epochs without a validation-F1 improvement before stopping
CHECKPOINT_PATH = "/kaggle/working/fusion_checkpoint_v2.pt"  # overwritten every epoch
BEST_MODEL_PATH = "/kaggle/working/fusion_best_v2.pt"  # written only when validation F1 improves
RUN_NAME = "fusion_text_image_fashion_balanced_attention"  # not referenced elsewhere in the visible code
PCT_TRAIN = 0.8  # training fraction; the split uses test_size = 1 - PCT_TRAIN
TARGET_COLUMN = "masterCategory"  # column of styles.csv used as the class label
MIN_SAMPLES_PER_CLASS = 20  # not referenced in the visible code (class filtering uses EXTREME_MIN_SAMPLES)
MC_DROPOUT_PASSES = 30  # not referenced in the visible code

# Focal Loss parameters
FOCAL_ALPHA = 1.0  # global multiplier applied to the focal term
FOCAL_GAMMA = 2.0  # exponent of the (1 - pt) modulating factor

# Attention parameters
ATTENTION_DIM = 256  # width of the text/image tokens fed to cross-attention
FUSION_DIM = 512  # width of the first hidden layer of the fusion head

# -----------------------
# Utils
# -----------------------
def seed_everything(seed: int = 42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Value applied to Python's ``random``, NumPy, PyTorch (CPU and all
            CUDA devices) and exported as ``PYTHONHASHSEED``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs before any model or data object is created.
seed_everything(SEED)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Device-type string passed to autocast() and GradScaler() further down.
AMP_DEVICE_TYPE = 'cuda' if device.type == 'cuda' else 'cpu'

def find_fashion_dataset_root(base="/kaggle/input"):
    """Locate the directory that contains ``styles.csv`` and an ``images/`` folder.

    Each directory directly under ``base`` is checked first, then its immediate
    sub-directories, before moving on to the next top-level directory. Entries
    are visited in sorted order so the result is the same on every run
    (``os.listdir`` alone returns entries in arbitrary order).

    Args:
        base: Directory to search. Defaults to the Kaggle input mount.

    Returns:
        The path of the first matching directory, or ``None`` when ``base`` is
        not a directory or no candidate matches.
    """
    def _is_dataset_root(path):
        # A dataset root has the metadata CSV and the image folder side by side.
        return (
            os.path.exists(os.path.join(path, "styles.csv"))
            and os.path.isdir(os.path.join(path, "images"))
        )

    # isdir (not exists) so that a plain file passed as base cannot reach listdir.
    if not os.path.isdir(base):
        return None

    for name in sorted(os.listdir(base)):
        path = os.path.join(base, name)
        if not os.path.isdir(path):
            continue
        if _is_dataset_root(path):
            return path
        # Some datasets wrap their content in one extra directory level.
        for sub in sorted(os.listdir(path)):
            subpath = os.path.join(path, sub)
            if os.path.isdir(subpath) and _is_dataset_root(subpath):
                return subpath
    return None

# -----------------------
# Focal Loss for Class Imbalance
# -----------------------
class FocalLoss(nn.Module):
    """Focal loss on top of (optionally class-weighted, label-smoothed) cross entropy.

    The per-sample loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``p_t`` is
    the softmax probability the model assigns to the true class and ``CE`` is
    the cross entropy with ``class_weights`` and ``label_smoothing`` applied.

    Args:
        alpha: Global multiplier on the focal term.
        gamma: Focusing exponent; ``0`` reduces the loss to ``alpha * CE``.
        class_weights: Optional 1-D tensor of per-class weights passed to
            ``F.cross_entropy``. It must already live on the same device as
            the logits (it is stored as a plain attribute, not a buffer).
        label_smoothing: Label-smoothing factor passed to ``F.cross_entropy``.
    """

    def __init__(self, alpha=1.0, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Return the mean focal loss for ``inputs`` (logits) and ``targets`` (class indices)."""
        # The term being modulated: CE with class weights and label smoothing.
        ce_loss = F.cross_entropy(inputs, targets, weight=self.class_weights,
                                  label_smoothing=self.label_smoothing, reduction='none')

        # p_t must be the true-class probability. exp(-CE) only equals that
        # probability for the plain CE, so it is derived from an unweighted,
        # unsmoothed pass; taking it from ce_loss above would yield
        # p ** weight and let the class weights distort the focusing factor.
        log_pt = -F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(log_pt)

        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        return focal_loss.mean()

# -----------------------
# Cross-Modal Attention Mechanism
# -----------------------
class MultiHeadCrossAttention(nn.Module):
    """Two-way cross-attention between a text token sequence and an image token sequence.

    Each direction is attention + residual + LayerNorm followed by a
    feed-forward block + residual + LayerNorm. Both directions read the
    *input* tokens; neither sees the other's updated output.

    Args:
        d_model: Token width shared by both modalities.
        num_heads: Number of attention heads.
        dropout: Dropout rate used inside attention and the feed-forward blocks.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        attention_kwargs = dict(
            embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        # Submodules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay stable.
        self.text_to_img = nn.MultiheadAttention(**attention_kwargs)
        self.img_to_text = nn.MultiheadAttention(**attention_kwargs)
        self.text_ln1 = nn.LayerNorm(d_model)
        self.text_ln2 = nn.LayerNorm(d_model)
        self.img_ln1 = nn.LayerNorm(d_model)
        self.img_ln2 = nn.LayerNorm(d_model)
        self.text_ffn = self._build_ffn(d_model, dropout)
        self.img_ffn = self._build_ffn(d_model, dropout)

    @staticmethod
    def _build_ffn(d_model: int, dropout: float) -> nn.Sequential:
        """Feed-forward block: expand to 4x width, GELU, dropout, project back."""
        hidden = 4 * d_model
        return nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )

    def forward(self, text_tokens: torch.Tensor, img_tokens: torch.Tensor):
        """Return ``(text_updated, img_updated, t2i_attn)``.

        ``t2i_attn`` holds the head-averaged text-to-image attention weights,
        shaped ``[B, T_text, T_img]``.
        """
        # Text queries the image tokens.
        attended_text, t2i_attn = self.text_to_img(
            text_tokens,
            img_tokens,
            img_tokens,
            need_weights=True,
            average_attn_weights=True,
        )
        text_state = self.text_ln1(text_tokens + attended_text)
        text_ff = self.text_ffn(text_state)
        text_updated = self.text_ln2(text_state + text_ff)

        # Image queries the text tokens; the weights are not needed here.
        attended_img = self.img_to_text(
            img_tokens,
            text_tokens,
            text_tokens,
            need_weights=False,
        )[0]
        img_state = self.img_ln1(img_tokens + attended_img)
        img_ff = self.img_ffn(img_state)
        img_updated = self.img_ln2(img_state + img_ff)

        return text_updated, img_updated, t2i_attn

# -----------------------
# Enhanced Fusion Model with Attention
# -----------------------
class AttentionFusionNet(nn.Module):
    """Image + TF-IDF text classifier fused through cross-modal attention.

    The image is encoded by a ResNet-18 trunk into a grid of spatial tokens,
    the text vector by an MLP into a single token. Both token sets pass
    through ``MultiHeadCrossAttention``; the pooled image tokens and the text
    token are concatenated and classified by an MLP head.

    Args:
        num_text_features: Length of the text feature vector per sample.
        num_classes: Number of output logits.
    """

    def __init__(self, num_text_features: int, num_classes: int):
        super().__init__()

        # Image backbone -> spatial tokens
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.image_cnn = nn.Sequential(*list(backbone.children())[:-2])  # up to layer4 output
        self.img_channel_dim = 512
        # 1x1 conv maps backbone channels to the attention width.
        self.img_token_proj = nn.Conv2d(self.img_channel_dim, ATTENTION_DIM, kernel_size=1)

        # Text branch - deeper network (produces a single token)
        self.text_branch = nn.Sequential(
            nn.Linear(num_text_features, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.text_token_proj = nn.Linear(256, ATTENTION_DIM)

        # Cross-modal multi-head attention
        self.cross_attention = MultiHeadCrossAttention(
            d_model=ATTENTION_DIM,
            num_heads=8,
            dropout=0.1,
        )

        # Fusion head: takes [pooled image tokens ; text token], hence 2 * ATTENTION_DIM inputs.
        self.fusion_head = nn.Sequential(
            nn.Linear(ATTENTION_DIM * 2, FUSION_DIM),
            nn.ReLU(),
            nn.BatchNorm1d(FUSION_DIM),
            nn.Dropout(0.4),
            nn.Linear(FUSION_DIM, FUSION_DIM // 2),
            nn.ReLU(),
            nn.BatchNorm1d(FUSION_DIM // 2),
            nn.Dropout(0.3),
            nn.Linear(FUSION_DIM // 2, num_classes),
        )

    def forward(self, images: torch.Tensor, text_vecs: torch.Tensor):
        """Classify a batch of (image, text-vector) pairs.

        Args:
            images: Image batch accepted by the ResNet-18 trunk.
            text_vecs: Text features of shape ``[B, num_text_features]``.

        Returns:
            ``(logits, attention_stats)``, where ``logits`` is
            ``[B, num_classes]`` and ``attention_stats["t2i_entropy"]`` is a
            scalar tensor: the batch mean of the text-to-image attention
            entropy, normalized by ``log(T_img)``.
        """
        # Image -> tokens [B, T_img, D]
        feats = self.image_cnn(images)  # [B, 512, H, W]
        feats = self.img_token_proj(feats)  # [B, D, H, W]
        img_tokens = feats.flatten(2).transpose(1, 2)  # [B, H*W, D]

        # Text -> single token [B, 1, D]
        text_feats = self.text_branch(text_vecs)  # [B, 256]
        text_tokens = self.text_token_proj(text_feats).unsqueeze(1)  # [B, 1, D]

        # Cross-attention
        text_out, img_out_tokens, t2i_attn = self.cross_attention(text_tokens, img_tokens)

        # Pool image tokens and fuse
        img_pooled = img_out_tokens.mean(dim=1)  # [B, D]
        fused = torch.cat([img_pooled, text_out.squeeze(1)], dim=1)
        logits = self.fusion_head(fused)

        # Attention stats: normalized entropy of text->image attention.
        # Monitoring only, so it is kept out of the autograd graph.
        with torch.no_grad():
            p = t2i_attn.float().clamp(min=1e-6)  # [B, 1, T_img]
            num_img_tokens = p.size(-1)
            if num_img_tokens > 1:
                entropy = - (p * p.log()).sum(dim=-1) / math.log(num_img_tokens)  # [B, 1]
            else:
                # One image token: attention is a point mass, and log(1) == 0
                # would turn the normalization into 0/0 (NaN).
                entropy = torch.zeros_like(p.sum(dim=-1))
            t2i_entropy = entropy.mean()  # scalar tensor
        attention_stats = {"t2i_entropy": t2i_entropy}

        return logits, attention_stats

# -----------------------
# Data Augmentation for Minority Classes
# -----------------------
class BalancedFashionDataset(Dataset):
    """Image + text dataset with optional oversampling of under-represented classes.

    Args:
        frame: Rows with at least ``label_idx``, ``image_path`` and
            ``productDisplayName`` columns.
        text_csr: Sparse text-feature matrix whose i-th row belongs to the
            i-th row of ``frame`` (positional alignment).
        transform: Callable applied to the loaded RGB ``PIL`` image.
        text_vectorizer_vocab_size: Stored as ``num_text_features``.
        oversample_minority: Enable oversampling (only effective together
            with ``target_samples_per_class``).
        target_samples_per_class: Classes with fewer rows than this are
            repeated until they reach exactly this many rows.
    """

    def __init__(self, frame: pd.DataFrame, text_csr, transform, text_vectorizer_vocab_size: int,
                 oversample_minority=True, target_samples_per_class=None):
        self.original_frame = frame.reset_index(drop=True)
        self.text_csr = text_csr
        self.transform = transform
        self.num_text_features = text_vectorizer_vocab_size
        self.oversample_minority = oversample_minority

        balance = oversample_minority and target_samples_per_class is not None
        if not balance:
            # Serve the data exactly as given.
            self.frame = self.original_frame
            self.augmented_text_csr = text_csr
            return

        self.frame, self.augmented_text_csr = self._oversample_minority_classes(
            target_samples_per_class
        )

    def _oversample_minority_classes(self, target_samples):
        """Repeat rows of small classes until each has ``target_samples`` rows.

        Classes are emitted in ``value_counts()`` order; frame rows and text
        rows are built from the same positional indices so they stay aligned.
        """
        from scipy.sparse import vstack

        labels = self.original_frame['label_idx']
        frame_parts = []
        text_parts = []

        for class_idx in labels.value_counts().index:
            # Positions of this class; equal to the index labels because
            # original_frame carries a fresh 0..n-1 index.
            positions = np.flatnonzero((labels == class_idx).to_numpy())
            n_rows = len(positions)

            if n_rows < target_samples:
                # Cycle through the class rows and cut at the target size.
                n_repeats = math.ceil(target_samples / n_rows)
                positions = np.tile(positions, n_repeats)[:target_samples]

            frame_parts.append(self.original_frame.iloc[positions].reset_index(drop=True))
            text_parts.append(self.text_csr[positions])

        combined_frame = pd.concat(frame_parts, ignore_index=True)
        combined_text = vstack(text_parts)
        return combined_frame, combined_text

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx: int):
        """Return ``(image, text_tensor, label, productDisplayName, image_path)`` for row ``idx``."""
        record = self.frame.iloc[idx]
        image_path = record["image_path"]

        # Load as RGB, then apply the configured transform pipeline.
        img = self.transform(Image.open(image_path).convert("RGB"))

        # Densify only the requested sparse row.
        dense_row = self.augmented_text_csr[idx].toarray().astype(np.float32).squeeze(0)
        x_text = torch.from_numpy(dense_row)

        return img, x_text, int(record["label_idx"]), record["productDisplayName"], image_path

# -----------------------
# Load and prepare data
# -----------------------
# Locate the dataset directory; everything below depends on it, so fail fast.
DATASET_ROOT = find_fashion_dataset_root()
if DATASET_ROOT is None:
    raise RuntimeError("Could not find dataset. Please add the Kaggle dataset as input.")
print(f"Using dataset at: {DATASET_ROOT}")

styles_path = os.path.join(DATASET_ROOT, "styles.csv")
# on_bad_lines="skip" drops rows pandas cannot parse instead of raising.
df = pd.read_csv(styles_path, on_bad_lines="skip")

def resolve_image_path(row_id):
    """Return the path of the image file for ``row_id``, or ``None`` if none exists.

    Looks in ``<DATASET_ROOT>/images`` and tries lower- and upper-case
    jpg/jpeg/png extensions in a fixed order; the first existing file wins.
    """
    images_dir = os.path.join(DATASET_ROOT, "images")
    extensions = (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")
    candidates = (os.path.join(images_dir, f"{row_id}{ext}") for ext in extensions)
    return next((path for path in candidates if os.path.exists(path)), None)

# Attach the resolved image path to every row; rows with no image on disk get None.
df["image_path"] = df["id"].apply(resolve_image_path)
# Keep only rows that have an image, a product name and a target label.
df = df.dropna(subset=["image_path", "productDisplayName", TARGET_COLUMN]).copy()

# Handle extreme class imbalance - remove classes with very few samples
class_counts = df[TARGET_COLUMN].value_counts()
print("Original class distribution:")
for class_name, count in class_counts.items():
    print(f"  {class_name}: {count}")

# Remove extremely small classes (they cause training instability)
EXTREME_MIN_SAMPLES = 200  # Increase threshold significantly
valid_classes = class_counts[class_counts >= EXTREME_MIN_SAMPLES].index.tolist()
removed_classes = class_counts[class_counts < EXTREME_MIN_SAMPLES].index.tolist()

if removed_classes:
    print(f"\nRemoving classes with < {EXTREME_MIN_SAMPLES} samples: {removed_classes}")

df = df[df[TARGET_COLUMN].isin(valid_classes)].copy()

# Encode labels
# Sorting the class names makes the name -> index mapping independent of row order.
classes = sorted(df[TARGET_COLUMN].unique().tolist())
class_to_idx = {c: i for i, c in enumerate(classes)}
idx_to_class = {i: c for c, i in class_to_idx.items()}
df["label_idx"] = df[TARGET_COLUMN].map(class_to_idx).astype(int)

print(f"Num samples: {len(df)} | Num classes: {len(classes)}")
print("Class distribution:")
# class_counts was computed before filtering; indexing by valid_classes shows only the kept classes.
for class_name, count in class_counts[valid_classes].items():
    print(f"  {class_name}: {count}")

# Train/val split
# Stratified on the encoded label so both parts keep the class proportions.
train_df, val_df = train_test_split(
    df,
    test_size=1 - PCT_TRAIN,
    random_state=SEED,
    stratify=df["label_idx"]
)

# Compute class weights for focal loss
# The weights come from the training split as it is here, i.e. before the
# oversampling that BalancedFashionDataset applies later.
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(train_df['label_idx']),
    y=train_df['label_idx']
)
# Moved to the training device because the loss consumes it alongside the logits.
class_weights_tensor = torch.FloatTensor(class_weights).to(device)
print("Class weights:", class_weights)

# Text processing
def normalize_text(s):
    """Lower-case and trim ``s``; any value that is not a ``str`` becomes ``""``."""
    return s.lower().strip() if isinstance(s, str) else ""

# Normalized product names, in the same row order as train_df / val_df.
train_texts = train_df["productDisplayName"].fillna("").apply(normalize_text).tolist()
val_texts = val_df["productDisplayName"].fillna("").apply(normalize_text).tolist()

vectorizer = TfidfVectorizer(
    max_features=MAX_TFIDF_FEATURES,
    ngram_range=TEXT_NGRAMS,
    stop_words="english",
    min_df=2,
    max_df=0.95
)
# Fit on the training texts only; the validation texts are merely transformed,
# so the vocabulary and IDF statistics never see validation data.
X_train_text = vectorizer.fit_transform(train_texts)
X_val_text = vectorizer.transform(val_texts)

# Enhanced transforms
# Training pipeline: resize to a larger square, then random crop / flip /
# rotation / colour jitter, convert to a normalized tensor, and finally
# erase a random patch with probability 0.3.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE + 64, IMAGE_SIZE + 64)),
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.6, 1.0), ratio=(0.7, 1.4)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2), value="random"),
])

# Validation pipeline: no random operations, same normalization as training.
# NOTE(review): Resize already yields IMAGE_SIZE x IMAGE_SIZE, so the
# CenterCrop that follows looks like a no-op — confirm it is intentional.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Calculate reasonable target samples per class for oversampling
# NOTE: class_counts was computed on the full frame before the train/val
# split, so the "1/3 of majority" figure refers to the whole dataset, not to train_df.
target_samples_per_class = min(5000, max(class_counts[valid_classes]) // 3)  # Cap at 5000, use 1/3 of majority

# Datasets with oversampling
train_ds = BalancedFashionDataset(
    train_df, X_train_text, train_transform, X_train_text.shape[1],
    oversample_minority=True, target_samples_per_class=target_samples_per_class
)

# Validation data is never oversampled, so metrics reflect the split's own distribution.
val_ds = BalancedFashionDataset(
    val_df, X_val_text, val_transform, X_val_text.shape[1],
    oversample_minority=False
)

print(f"Training samples after balancing: {len(train_ds)}")
print(f"Validation samples: {len(val_ds)}")

# Use simpler, more stable data loading without weighted sampling
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,  # Simple shuffle instead of weighted sampling
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,  # the trailing partial batch is skipped every epoch
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=False,  # every validation sample is evaluated
)

# -----------------------
# Model Setup
# -----------------------
num_classes = len(classes)
model = AttentionFusionNet(
    num_text_features=X_train_text.shape[1], 
    num_classes=num_classes
).to(device)

# Focal loss with class weights
criterion = FocalLoss(
    alpha=FOCAL_ALPHA,
    gamma=FOCAL_GAMMA,
    class_weights=class_weights_tensor,
    label_smoothing=LABEL_SMOOTHING
)

optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scaler = GradScaler(AMP_DEVICE_TYPE)

# Cosine annealing with warm restarts
# The scheduler is stepped once per batch in train_one_epoch, so T_0 is counted
# in batches: the first cycle spans two epochs and each later one is twice as long.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=max(1, len(train_loader) * 2), T_mult=2, eta_min=LR * 0.1
)

# Training always starts from epoch 0 here; no checkpoint is loaded in this cell.
start_epoch = 0
# Sentinel below any achievable F1 so the first epoch always counts as an improvement.
best_val_f1 = -1.0

# -----------------------
# Training Functions
# -----------------------
def train_one_epoch(epoch: int):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer``, ``scaler``
    and ``scheduler``. The scheduler is stepped after every batch.

    Args:
        epoch: Epoch number, used only in the progress messages.

    Returns:
        ``(avg_loss, accuracy, macro_f1, avg_attn_entropy)``, where
        ``avg_loss`` is averaged over the samples actually processed and
        ``avg_attn_entropy`` is the mean of the per-batch text-to-image
        attention entropy.
    """
    model.train()
    total_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_targets = []
    attn_entropies = []

    for step, batch in enumerate(train_loader):
        images, text_vecs, labels, _, _ = batch
        images = images.to(device, non_blocking=True)
        text_vecs = text_vecs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=AMP_DEVICE_TYPE, enabled=(AMP_DEVICE_TYPE == 'cuda')):
            logits, attention_stats = model(images, text_vecs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        # After a scaled backward pass the gradients are still multiplied by
        # the loss scale. Unscale first so max_norm applies to the real
        # gradient magnitudes; scaler.step() sees this and does not unscale again.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size
        preds = torch.argmax(logits.detach(), dim=1).cpu().numpy().tolist()
        all_preds.extend(preds)
        all_targets.extend(labels.detach().cpu().numpy().tolist())

        t2i_entropy_val = attention_stats["t2i_entropy"]
        if isinstance(t2i_entropy_val, torch.Tensor):
            t2i_entropy_val = float(t2i_entropy_val.detach().cpu().item())
        attn_entropies.append(t2i_entropy_val)

        if (step + 1) % 100 == 0:
            avg_entropy = float(np.mean(attn_entropies)) if len(attn_entropies) > 0 else float('nan')
            print(
                f"Epoch {epoch} | Step {step+1}/{len(train_loader)} | Loss {loss.item():.4f} | "
                f"t2i_attn_entropy {avg_entropy:.3f}"
            )

    # Average over the samples actually seen: the loader may drop the trailing
    # partial batch, so dividing by the dataset length would understate the loss.
    avg_loss = total_loss / samples_seen if samples_seen else float('nan')
    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, average="macro")
    avg_attn_entropy = float(np.mean(attn_entropies)) if len(attn_entropies) > 0 else float('nan')

    return avg_loss, acc, f1, avg_attn_entropy

@torch.no_grad()
def evaluate():
    """Score the module-level ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        ``(avg_loss, accuracy, macro_f1, preds, targets, avg_attn_entropy)``;
        ``preds`` and ``targets`` are NumPy arrays covering every validation
        sample in loader order.
    """
    model.eval()
    use_amp = AMP_DEVICE_TYPE == 'cuda'

    loss_sum = 0.0
    pred_list = []
    target_list = []
    entropy_log = []

    for images, text_vecs, labels, _, _ in val_loader:
        images = images.to(device, non_blocking=True)
        text_vecs = text_vecs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with autocast(device_type=AMP_DEVICE_TYPE, enabled=use_amp):
            logits, attention_stats = model(images, text_vecs)
            loss = criterion(logits, labels)

        # Weight the batch mean by its size so the final average is per sample.
        loss_sum += loss.item() * images.size(0)
        pred_list.extend(logits.argmax(dim=1).cpu().numpy().tolist())
        target_list.extend(labels.cpu().numpy().tolist())

        batch_entropy = attention_stats["t2i_entropy"]
        if isinstance(batch_entropy, torch.Tensor):
            batch_entropy = float(batch_entropy.detach().cpu().item())
        entropy_log.append(batch_entropy)

    avg_loss = loss_sum / len(val_ds)
    acc = accuracy_score(target_list, pred_list)
    f1 = f1_score(target_list, pred_list, average="macro")
    if entropy_log:
        avg_attn_entropy = float(np.mean(entropy_log))
    else:
        avg_attn_entropy = float('nan')

    return avg_loss, acc, f1, np.array(pred_list), np.array(target_list), avg_attn_entropy

# -----------------------
# Training Loop
# -----------------------
# Consecutive epochs without a validation-F1 improvement (early-stopping counter).
epochs_no_improve = 0
# One dict of metrics per completed epoch.
history = []

print("Starting training with attention fusion and class balancing...")

for epoch in range(start_epoch, NUM_EPOCHS):
    t0 = time.time()
    tr_loss, tr_acc, tr_f1, tr_attn_entropy = train_one_epoch(epoch)
    va_loss, va_acc, va_f1, va_preds, va_tgts, va_attn_entropy = evaluate()

    print(
        f"[Epoch {epoch}] "
        f"train_loss={tr_loss:.4f} acc={tr_acc:.4f} f1={tr_f1:.4f} | "
        f"val_loss={va_loss:.4f} acc={va_acc:.4f} f1={va_f1:.4f} | "
        f"time={(time.time()-t0):.1f}s"
    )

    print(f"Train t2i attention entropy: {tr_attn_entropy:.3f}")
    print(f"Val t2i attention entropy: {va_attn_entropy:.3f}")

    history.append({
        "epoch": epoch,
        "train_loss": float(tr_loss),
        "train_acc": float(tr_acc),
        "train_f1": float(tr_f1),
        "val_loss": float(va_loss),
        "val_acc": float(va_acc),
        "val_f1": float(va_f1),
        "train_attn_entropy": float(tr_attn_entropy),
        "val_attn_entropy": float(va_attn_entropy),
    })

    # Update the best score and the patience counter BEFORE writing the
    # checkpoint, so the checkpoint describes the state after this epoch
    # rather than the state before it.
    improved = va_f1 > best_val_f1
    if improved:
        # Stored as a plain Python float so the saved file holds no NumPy scalar.
        best_val_f1 = float(va_f1)
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    # Save checkpoint
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val_f1": best_val_f1,
        "epochs_no_improve": epochs_no_improve,
        "classes": classes,
        "class_to_idx": class_to_idx,
        "idx_to_class": idx_to_class,
    }, CHECKPOINT_PATH)

    # Save best model
    if improved:
        torch.save({
            "model": model.state_dict(),
            "best_val_f1": best_val_f1,
            "classes": classes,
            "class_to_idx": class_to_idx,
            "idx_to_class": idx_to_class,
        }, BEST_MODEL_PATH)
        print(f"New best model saved with val_f1={best_val_f1:.4f}")
    elif epochs_no_improve >= EARLY_STOP_PATIENCE:
        print("Early stopping triggered.")
        break

# -----------------------
# Final Evaluation
# -----------------------
# Restore the best-scoring weights when a best-model file exists; otherwise
# the model is evaluated with whatever weights it currently holds.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects. Only load files written by this notebook; never an untrusted checkpoint.
if os.path.exists(BEST_MODEL_PATH):
    ckpt = torch.load(BEST_MODEL_PATH, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])

val_loss, val_acc, val_f1, val_preds, val_tgts, val_attn_entropy = evaluate()
print(f"\nFinal Results:")
print(f"Validation: loss={val_loss:.4f} acc={val_acc:.4f} f1={val_f1:.4f}")
print(f"Final t2i attention entropy: {val_attn_entropy:.3f}")
print("\nDetailed Classification Report:")
# target_names follows `classes`, whose order defines the label indices.
print(classification_report(val_tgts, val_preds, target_names=classes))

# Confusion matrix
cm = confusion_matrix(val_tgts, val_preds)
print("\nConfusion Matrix:")
print(cm)

print("\nTraining completed with attention-based fusion and class balancing!")