# Fashion Image Classifier: Experiment Notebook

**Research question:** How much does transfer learning improve over a from-scratch CNN for fashion product classification, and which pretrained backbone performs best?

## Setup

In [None]:
import sys
sys.path.insert(0, "..")

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import torch
import wandb
from dotenv import load_dotenv

from src.config import (
    STYLES_CSV, IMAGES_DIR, FIGURES_DIR, WANDB_PROJECT,
    DEFAULT_BATCH_SIZE, DEFAULT_EPOCHS, DEFAULT_LR, DEFAULT_DROPOUT,
)
from src.data_loader import (
    load_metadata, filter_top_categories, verify_images_exist,
    encode_labels, split_data, build_dataloaders,
)
from src.models import build_model, MODEL_REGISTRY
from src.training import train_model, evaluate
from src.evaluation import (
    compute_accuracy, compute_classification_report,
    compute_confusion_matrix, compute_topk_accuracy,
)
from src.wandb_utils import init_wandb_run, log_confusion_matrix, log_sample_predictions

# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the W&B credentials — confirm which
# variables the project expects.
load_dotenv()

# Make sure the figure output directory exists before any cell saves into it.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Seed the global NumPy and PyTorch generators so that weight initialisation and
# any shuffling driven by them is repeatable under Restart & Run All.
# NOTE(review): helpers under src/ may apply their own seeds; this only covers
# the global generators visible from the notebook.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available; every training/evaluation call below
# receives this device explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Data Exploration

In [None]:
# Read the raw style metadata and report its size before any filtering.
df = load_metadata(STYLES_CSV)
n_raw_rows = len(df)
n_raw_types = df["articleType"].nunique()
print(f"Total rows in styles.csv: {n_raw_rows}")
print(f"Unique article types: {n_raw_types}")

# Keep only the categories selected by filter_top_categories.
# NOTE(review): the ">= 500 images" text in the message is hard-coded here; the
# real cutoff is decided inside filter_top_categories — confirm they agree.
df = filter_top_categories(df)
n_kept_rows = len(df)
n_kept_types = df["articleType"].nunique()
print(f"\nAfter filtering (>= 500 images): {n_kept_rows} rows, {n_kept_types} categories")

# Presumably drops rows whose image file is absent from IMAGES_DIR — verify
# against src.data_loader.
df = verify_images_exist(df, IMAGES_DIR)
n_verified_rows = len(df)
print(f"After verifying images exist: {n_verified_rows} rows")

In [None]:
# Horizontal bar chart of the number of rows per article type in the filtered
# frame. `counts` is reused by the sample-image cell that follows.
counts = df["articleType"].value_counts()

fig, ax = plt.subplots(figsize=(12, 6))
counts.plot.barh(ax=ax)
ax.set(xlabel="Number of Images", title="Category Distribution (Filtered)")
fig.tight_layout()

# Persist the figure next to the other report artefacts, then render it inline.
distribution_path = FIGURES_DIR / "category_distribution.png"
fig.savefig(distribution_path, dpi=150)
plt.show()

In [None]:
# Sample images: show the first image of each of the (up to) six most frequent
# categories in a 2x3 grid.
from PIL import Image
from src.config import IMAGE_SIZE  # imported by the original cell; not used below

sample_cats = counts.head(6).index.tolist()
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for ax, cat in zip(axes.flat, sample_cats):
    # First row of this category supplies the image id; files are named "<id>.jpg".
    sample_id = df[df["articleType"] == cat].iloc[0]["id"]
    # Open through a context manager so the file handle is released instead of
    # staying open until garbage collection.
    with Image.open(IMAGES_DIR / f"{sample_id}.jpg") as img:
        ax.imshow(img)
    ax.set_title(cat)

# Hide ticks/frames on every panel, including panels left unused when fewer
# than six categories are available, so they render blank rather than as
# empty axes.
for ax in axes.flat:
    ax.axis("off")

plt.suptitle("Sample Images by Category")
plt.tight_layout()
fig.savefig(FIGURES_DIR / "sample_images.png", dpi=150)
plt.show()

## 2. Data Preparation

In [None]:
# encode_labels returns the frame together with two lookup tables.
# NOTE(review): presumably it adds an integer label column to df — confirm in
# src.data_loader.
df, label_to_idx, idx_to_label = encode_labels(df)

# Class names ordered by label index, so position k holds the name for index k.
num_classes = len(idx_to_label)
class_names = [idx_to_label[label_idx] for label_idx in range(num_classes)]
print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

# Three-way split; the proportions and any stratification are decided inside
# split_data.
train_df, val_df, test_df = split_data(df)
n_train, n_val, n_test = len(train_df), len(val_df), len(test_df)
print(f"\nTrain: {n_train}, Val: {n_val}, Test: {n_test}")

In [None]:
# Build the train / validation / test loaders from the three split frames.
# The same three loaders are shared by every model trained below.
# num_workers=2 is set inline here rather than read from src.config.
# NOTE(review): shuffling and augmentation are decided inside build_dataloaders
# and are not visible from this notebook.
train_loader, val_loader, test_loader = build_dataloaders(
    train_df, val_df, test_df,
    images_dir=IMAGES_DIR,
    batch_size=DEFAULT_BATCH_SIZE,
    num_workers=2,
)

## 3. Train All Three Models

Each model gets its own W&B run for comparison.

In [None]:
# Maps model name -> the results object returned by train_model; read by the
# comparison and test-evaluation cells below.
all_results = {}

# One entry per architecture. The from-scratch CNN uses a larger learning rate
# than the two pretrained backbones; all three train for the same number of
# epochs.
model_configs = [
    {"name": "simple_cnn", "pretrained": False, "lr": 1e-3, "epochs": DEFAULT_EPOCHS},
    {"name": "resnet50", "pretrained": True, "lr": 1e-4, "epochs": DEFAULT_EPOCHS},
    {"name": "efficientnet_b0", "pretrained": True, "lr": 1e-4, "epochs": DEFAULT_EPOCHS},
]

for cfg in model_configs:
    print(f"\n{'='*60}")
    print(f"Training: {cfg['name']}")
    print(f"{'='*60}")

    # Start a dedicated W&B run for this model; the config keys recorded here
    # are the ones the sweep-results cell reads back through the W&B API.
    run = init_wandb_run(
        config={
            "model_name": cfg["name"],
            "learning_rate": cfg["lr"],
            "batch_size": DEFAULT_BATCH_SIZE,
            "dropout": DEFAULT_DROPOUT,
            "epochs": cfg["epochs"],
            "num_classes": num_classes,
        },
        project=WANDB_PROJECT,
        name=f"{cfg['name']}_baseline",
    )

    try:
        model = build_model(
            cfg["name"],
            num_classes=num_classes,
            dropout=DEFAULT_DROPOUT,
            pretrained=cfg.get("pretrained", False),
        )

        results = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=cfg["epochs"],
            lr=cfg["lr"],
            device=device,
            use_wandb=True,
            class_names=class_names,
        )

        all_results[cfg["name"]] = results
        print(f"Best val accuracy: {results['best_val_accuracy']:.4f}")
    except BaseException:
        # Also covers a manual kernel interrupt. Close the run with a non-zero
        # exit code so it is not recorded as "finished" (the state the
        # sweep-results cell filters on), and so the next model does not log
        # into a stale run. The original error is re-raised unchanged.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

print("\nAll training complete.")

## 4. Model Comparison

In [None]:
# Comparison table: all_results is keyed by model name, so after transposing
# each model becomes one row and each entry of its results becomes a column.
comparison = pd.DataFrame(all_results).transpose().rename_axis("Model")
comparison_markdown = comparison.to_markdown()
print(comparison_markdown)

In [None]:
# Bar chart of the best validation accuracy reached by each trained model.
models = list(all_results)
val_accs = [all_results[name]["best_val_accuracy"] for name in models]

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(models, val_accs)
ax.set(
    ylabel="Best Validation Accuracy",
    title="Model Comparison: Validation Accuracy",
    ylim=(0, 1),
)

# Write each bar's value just above its top edge.
for bar_pos, acc in enumerate(val_accs):
    ax.text(bar_pos, acc + 0.01, f"{acc:.3f}", ha="center")

fig.tight_layout()
fig.savefig(FIGURES_DIR / "model_comparison.png", dpi=150)
plt.show()

## 5. Test Set Evaluation (Best Model)

In [None]:
# Select the architecture with the highest validation accuracy, retrain it from
# scratch with the same settings used in the comparison loop, and evaluate the
# retrained model on the held-out test split. (Nothing is loaded from disk:
# the models trained in the comparison loop are not kept.)
best_model_name = max(all_results, key=lambda m: all_results[m]["best_val_accuracy"])
print(f"Best model: {best_model_name}")

# Reuse the winning model's own config entry so the rebuilt model matches the
# one that produced the validation score: same pretrained flag and same dropout.
best_cfg = next(c for c in model_configs if c["name"] == best_model_name)

best_model = build_model(
    best_model_name,
    num_classes=num_classes,
    dropout=DEFAULT_DROPOUT,
    pretrained=best_cfg.get("pretrained", False),
)

# Retrain without W&B logging; the returned results are discarded because only
# the trained weights held by best_model are needed below.
_ = train_model(
    model=best_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=best_cfg["epochs"],
    lr=best_cfg["lr"],
    device=device,
    use_wandb=False,
)

# Test set evaluation. y_true / y_pred are reused by the classification-report
# and confusion-matrix cells.
criterion = torch.nn.CrossEntropyLoss()
test_loss, test_acc, y_true, y_pred = evaluate(best_model, test_loader, criterion, device)
print(f"\nTest accuracy: {test_acc:.4f}")
print(f"Test loss: {test_loss:.4f}")

In [None]:
# Classification report
# Per-class metrics for the test-set predictions (y_true / y_pred come from the
# evaluation cell above), labelled with the class names in index order.
# NOTE(review): the report is assumed to be a dict of per-class dicts, so that
# the transpose below yields one row per class — confirm in src.evaluation.
report = compute_classification_report(y_true, y_pred, label_names=class_names)
report_df = pd.DataFrame(report).T
# DataFrame.to_markdown relies on the optional `tabulate` package.
print(report_df.to_markdown())

In [None]:
# Confusion matrix heatmap for the test-set predictions. `cm` is reused by the
# error-analysis cell, which reads rows as the true class and columns as the
# predicted class.
cm = compute_confusion_matrix(y_true, y_pred)

heatmap_options = dict(
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(cm, ax=ax, **heatmap_options)
ax.set(
    xlabel="Predicted",
    ylabel="True",
    title=f"Confusion Matrix: {best_model_name}",
)
fig.tight_layout()
fig.savefig(FIGURES_DIR / "confusion_matrix.png", dpi=150)
plt.show()

## 6. Error Analysis

Which class pairs are most commonly confused?

In [None]:
# Find most confused pairs (off-diagonal elements).
# Work on a copy with the diagonal zeroed so correct predictions never win the
# argmax and the original `cm` stays intact.
cm_copy = cm.copy()
np.fill_diagonal(cm_copy, 0)

# Top 5 confused pairs: repeatedly take the largest remaining cell and clear it.
# Stop early once the largest remaining count is zero, so pairs that were never
# confused are not reported when fewer than five distinct confusions exist.
confused_pairs = []
for _ in range(5):
    i, j = np.unravel_index(cm_copy.argmax(), cm_copy.shape)
    count = cm_copy[i, j]
    if count <= 0:
        break
    confused_pairs.append((class_names[i], class_names[j], count))
    cm_copy[i, j] = 0

print("Most confused pairs:")
if not confused_pairs:
    print("  (none - no misclassifications in the confusion matrix)")
for true_cls, pred_cls, count in confused_pairs:
    print(f"  {true_cls} -> {pred_cls}: {count} misclassifications")

## 7. Hyperparameter Sweep Results

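After running `wandb sweep sweep.yaml` and `wandb agent <sweep-id>`, pull the results via the W&B API. Note that the query below returns every finished run in the project, so the baseline runs from Section 3 appear alongside the sweep runs.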

In [None]:
# Pull sweep results from W&B API.
# NOTE(review): api.runs(WANDB_PROJECT) is not restricted to a sweep, so every
# finished run in the project is returned, not only sweep runs.
api = wandb.Api()
runs = api.runs(WANDB_PROJECT)

# One record per finished run. Config keys match what the training loop logs;
# anything a run did not record comes back as None ("unknown" for the model).
sweep_data = [
    {
        "name": run.name,
        "model": run.config.get("model_name", "unknown"),
        "lr": run.config.get("learning_rate"),
        "batch_size": run.config.get("batch_size"),
        "dropout": run.config.get("dropout"),
        "val_accuracy": run.summary.get("val/accuracy"),
    }
    for run in runs
    if run.state == "finished"
]

# Passing explicit columns keeps the frame well-formed when there are no
# finished runs; without them the sort below raises KeyError on an empty frame.
sweep_columns = ["name", "model", "lr", "batch_size", "dropout", "val_accuracy"]
sweep_df = pd.DataFrame(sweep_data, columns=sweep_columns)

# Runs that never logged "val/accuracy" yield None; coerce to NaN so the column
# is numeric and such runs sort last.
sweep_df["val_accuracy"] = pd.to_numeric(sweep_df["val_accuracy"], errors="coerce")

print(f"Total finished runs: {len(sweep_df)}")
sweep_df.sort_values("val_accuracy", ascending=False).head(10)

In [None]:
# Best config per model: for each model, the run with the highest validation
# accuracy.
if len(sweep_df) > 0:
    # Drop runs with no usable validation accuracy first. A model whose runs
    # are all missing the metric would otherwise make idxmax produce a label
    # that the .loc lookup cannot resolve.
    scored_runs = sweep_df.assign(
        val_accuracy=pd.to_numeric(sweep_df["val_accuracy"], errors="coerce")
    ).dropna(subset=["val_accuracy"])

    if scored_runs.empty:
        print("No finished runs with a recorded validation accuracy.")
    else:
        best_per_model = scored_runs.loc[scored_runs.groupby("model")["val_accuracy"].idxmax()]
        print(best_per_model[["model", "lr", "batch_size", "dropout", "val_accuracy"]].to_markdown())

## 8. Conclusions

**Research question:** How much does transfer learning improve over a from-scratch CNN, and which backbone performs best?

*Fill in after runs complete:*

- SimpleCNN (from scratch): __% validation accuracy
- ResNet-50 (pretrained): __% validation accuracy
- EfficientNet-B0 (pretrained): __% validation accuracy

**Key findings:**
- Transfer learning provides a ___ percentage point improvement over training from scratch.
- [Best model] achieves the highest accuracy while [comparison notes].
- The most commonly confused categories are [X] and [Y], which makes sense because [reason].
- Hyperparameter sweep found that [key insights about learning rate, batch size, etc.].