# Phase 2: Shoe Type Specialist Classifier
This notebook trains a specialized ResNet50 model to distinguish between different types of shoes.
It uses the existing `fashion-dataset` but filters specifically for footwear categories.

**Goal:** Provide the "Intelligence" for the "Vibe Check" logic (e.g., distinguishing Sneakers from Formal Shoes).

In [7]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from PIL import Image
from tqdm import tqdm
import pickle

# Configuration
DATA_DIR = r"d:/AAI3001/fashion-dataset/fashion-dataset"
CSV_PATH = os.path.join(DATA_DIR, "styles.csv")
IMAGES_DIR = os.path.join(DATA_DIR, "images")
MODEL_SAVE_PATH = r"d:/AAI3001/best_model_shoes.pth"
ENCODER_SAVE_PATH = r"d:/AAI3001/le_shoes.pkl"
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
# 1. Load and Filter Data
try:
    df = pd.read_csv(CSV_PATH, on_bad_lines='skip')
except Exception as e:
    print(f"Error reading CSV: {e}")
    # Fallback for some pandas versions
    df = pd.read_csv(CSV_PATH, error_bad_lines=False)

# Define the Shoe Classes we care about
SHOE_CLASSES = [
    "Casual Shoes", 
    "Sports Shoes", 
    "Formal Shoes", 
    "Heels", 
    "Flats", 
    "Sandals", 
    "Flip Flops"
]

# Filter
df_shoes = df[df["articleType"].isin(SHOE_CLASSES)].copy()
df_shoes["image_path"] = df_shoes["id"].apply(lambda x: os.path.join(IMAGES_DIR, f"{x}.jpg"))

# Check file existence
df_shoes = df_shoes[df_shoes["image_path"].apply(os.path.exists)]

print(f"Total Shoe Images Found: {len(df_shoes)}")
print(df_shoes["articleType"].value_counts())

# Encode Labels
le = LabelEncoder()
df_shoes["label"] = le.fit_transform(df_shoes["articleType"])

# Save Encoder
with open(ENCODER_SAVE_PATH, "wb") as f:
    pickle.dump(le, f)
print(f"Encoder saved to {ENCODER_SAVE_PATH}")
print(f"Classes: {le.classes_}")

Total Shoe Images Found: 9152
articleType
Casual Shoes    2845
Sports Shoes    2036
Heels           1323
Flip Flops       914
Sandals          897
Formal Shoes     637
Flats            500
Name: count, dtype: int64
Encoder saved to d:/AAI3001/le_shoes.pkl
Classes: ['Casual Shoes' 'Flats' 'Flip Flops' 'Formal Shoes' 'Heels' 'Sandals'
 'Sports Shoes']


In [4]:
# 2. Dataset and DataLoader
class ShoeDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row["image_path"]
        label = row["label"]
        
        try:
            image = Image.open(img_path).convert("RGB")
        except:
            # Fallback for corrupt images
            image = Image.new("RGB", (224, 224), (0, 0, 0))
            
        if self.transform:
            image = self.transform(image)
            
        return image, torch.tensor(label, dtype=torch.long)

# Transforms
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split
train_df, val_df = train_test_split(df_shoes, test_size=0.2, stratify=df_shoes["label"], random_state=42)

train_dataset = ShoeDataset(train_df, transform=train_transform)
val_dataset = ShoeDataset(val_df, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # num_workers=0 for Windows safety
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")

Train size: 7321
Val size: 1831


In [5]:
# 3. Model Setup
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Freeze early layers (optional, but good for speed)
for param in model.parameters():
    param.requires_grad = False

# Unfreeze layer4 for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

# Replace Head
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, len(le.classes_))
)

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [6]:
# 4. Training Loop
best_acc = 0.0

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    
    # Train
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = correct / total
    print(f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
    
    # Validate
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    val_loss = val_loss / len(val_dataset)
    val_acc = correct / total
    print(f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
    
    scheduler.step()
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"Saved new best model with acc: {best_acc:.4f}")

print("Training Complete.")

Epoch 1/10


Training: 100%|██████████| 229/229 [06:05<00:00,  1.59s/it]
Training: 100%|██████████| 229/229 [06:05<00:00,  1.59s/it]


Train Loss: 0.7733 Acc: 0.7167


Validation: 100%|██████████| 58/58 [01:23<00:00,  1.44s/it]



Val Loss: 0.7065 Acc: 0.7346
Saved new best model with acc: 0.7346
Epoch 2/10
Saved new best model with acc: 0.7346
Epoch 2/10


Training: 100%|██████████| 229/229 [02:58<00:00,  1.28it/s]
Training: 100%|██████████| 229/229 [02:58<00:00,  1.28it/s]


Train Loss: 0.5696 Acc: 0.7839


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.42it/s]



Val Loss: 0.5827 Acc: 0.7957
Saved new best model with acc: 0.7957
Epoch 3/10


Training: 100%|██████████| 229/229 [02:57<00:00,  1.29it/s]
Training: 100%|██████████| 229/229 [02:57<00:00,  1.29it/s]


Train Loss: 0.5026 Acc: 0.8066


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]



Val Loss: 0.4708 Acc: 0.8225
Saved new best model with acc: 0.8225
Epoch 4/10


Training: 100%|██████████| 229/229 [02:57<00:00,  1.29it/s]
Training: 100%|██████████| 229/229 [02:57<00:00,  1.29it/s]


Train Loss: 0.4001 Acc: 0.8466


Validation: 100%|██████████| 58/58 [00:41<00:00,  1.41it/s]



Val Loss: 0.4514 Acc: 0.8236
Saved new best model with acc: 0.8236
Epoch 5/10


Training: 100%|██████████| 229/229 [03:01<00:00,  1.26it/s]
Training: 100%|██████████| 229/229 [03:01<00:00,  1.26it/s]


Train Loss: 0.3733 Acc: 0.8547


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.44it/s]
Validation: 100%|██████████| 58/58 [00:40<00:00,  1.44it/s]


Val Loss: 0.4641 Acc: 0.8209
Epoch 6/10


Training: 100%|██████████| 229/229 [02:55<00:00,  1.31it/s]
Training: 100%|██████████| 229/229 [02:55<00:00,  1.31it/s]


Train Loss: 0.3592 Acc: 0.8589


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]
Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]


Val Loss: 0.4763 Acc: 0.8192
Epoch 7/10


Training: 100%|██████████| 229/229 [02:56<00:00,  1.30it/s]
Training: 100%|██████████| 229/229 [02:56<00:00,  1.30it/s]


Train Loss: 0.3391 Acc: 0.8661


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]



Val Loss: 0.4536 Acc: 0.8247
Saved new best model with acc: 0.8247
Epoch 8/10


Training: 100%|██████████| 229/229 [02:55<00:00,  1.30it/s]
Training: 100%|██████████| 229/229 [02:55<00:00,  1.30it/s]


Train Loss: 0.3293 Acc: 0.8678


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.42it/s]
Validation: 100%|██████████| 58/58 [00:40<00:00,  1.42it/s]


Val Loss: 0.4712 Acc: 0.8165
Epoch 9/10


Training: 100%|██████████| 229/229 [02:55<00:00,  1.30it/s]
Training: 100%|██████████| 229/229 [02:55<00:00,  1.30it/s]


Train Loss: 0.3321 Acc: 0.8701


Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]
Validation: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s]


Val Loss: 0.4525 Acc: 0.8236
Epoch 10/10


Training: 100%|██████████| 229/229 [02:55<00:00,  1.31it/s]
Training: 100%|██████████| 229/229 [02:55<00:00,  1.31it/s]


Train Loss: 0.3298 Acc: 0.8697


Validation: 100%|██████████| 58/58 [00:39<00:00,  1.45it/s]

Val Loss: 0.4597 Acc: 0.8225
Training Complete.





# Next Steps (The "Eyes")
Now that you have the "Brain" (this classifier), you need the "Eyes" (Detector) to find the shoes in the first place.

**To train the Detector (Step 1):**
1. Download the **Fashionpedia** dataset (or a subset with 'shoe' labels).
2. Train a YOLOv8 model on just the accessory classes.
3. Use that YOLO model to crop the feet, then pass the crop to this ResNet model.