In [48]:
%pip install tqdm
import os
import time
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from PIL import Image
import io

Note: you may need to restart the kernel to use updated packages.


In [49]:
# Check GPU / MPS (for Mac)
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("‚úÖ Using GPU (CUDA)")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("‚úÖ Using Apple Metal (MPS)")
else:
    device = torch.device("cpu")
    print("‚ö†Ô∏è Using CPU ‚Äî training will be slower")

‚úÖ Using Apple Metal (MPS)


In [50]:
# Custom Dataset (for JPEG-bytes in Parquet)

from torch.utils.data import Dataset

class SneakerDataset(Dataset):
    def __init__(self, df, label_col="brand", transform=None):
        self.df = df
        self.images = df['image'].values
        self.labels = df[label_col].astype('category').cat.codes
        self.label2name = dict(enumerate(df[label_col].astype('category').cat.categories))
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_bytes = self.images[idx]
        if isinstance(img_bytes, str):
            img_bytes = bytes(img_bytes, 'utf-8')
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        return image, label

In [51]:
#  Transforms (resize + augment)
from torchvision import models, transforms

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [52]:
# Load Combined Dataset
import pandas as pd

# Load precombined dataset
df = pd.read_parquet("all_sneakers_combined.parquet")

# Drop incomplete rows
df = df.dropna(subset=['image', 'brand'])
df = df.groupby('brand').filter(lambda x: len(x) > 5)

# Encode once for stratified split
df['brand'] = df['brand'].astype('category')
df['label'] = df['brand'].cat.codes


In [53]:
# Balanced Train/Val Split

from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df['label'], random_state=42
)

print(f"Train samples: {len(train_df)}, Validation: {len(val_df)}")

Train samples: 72864, Validation: 18217


In [63]:
# DataLoaders
from torch.utils.data import DataLoader

train_dataset = SneakerDataset(train_df, transform=transform_train)
val_dataset = SneakerDataset(val_df, transform=transform_val)

train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True,
    num_workers=0, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=64, shuffle=False,
    num_workers=0, pin_memory=True
)
print(train_dataset.label2name )

{0: 'Adidas', 1: 'Alexander', 2: 'Amiri', 3: 'Asics', 4: 'Autry', 5: 'BAPE', 6: 'Balenciaga', 7: 'Birkenstock', 8: 'Camper', 9: 'Clarks', 10: 'Converse', 11: 'Crocs', 12: 'Diadora', 13: 'Dr.', 14: 'Ewing', 15: 'Hoka', 16: 'Jordan', 17: 'KangaROOS', 18: 'Karhu', 19: 'Keen', 20: 'Lacoste', 21: 'Lanvin', 22: 'Le', 23: 'Mizuno', 24: 'Moon', 25: 'New', 26: 'Nike', 27: 'ON', 28: 'Off-White', 29: 'Onitsuka', 30: 'Puma', 31: 'Reebok', 32: 'Salomon', 33: 'Saucony', 34: 'Suicoke', 35: 'Timberland', 36: 'Vans', 37: 'Veja', 38: 'adidas', 39: 'alexander'}


In [55]:
# Model, Loss, Optimizer, Scaler
from torchvision import models

num_classes = len(set(train_dataset.labels))

# path to the downloaded weights
weights_path = os.path.expanduser("~/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth")

# load model without triggering an online download
model = models.resnet18(weights=None)
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)

print("‚úÖ Loaded ResNet18 weights from local file successfully.")


model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')


‚úÖ Loaded ResNet18 weights from local file successfully.


In [56]:
# Training Loop (optimized)
from tqdm import tqdm

EPOCHS = 15
best_acc = 0.0

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}] Training", leave=False)
    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        with  torch.amp.autocast(device_type=device.type):
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()

        loop.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)
    print(f"üìâ Epoch [{epoch+1}] Avg Loss: {avg_loss:.4f}")

    # --- Validation ---
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    val_acc = 100 * correct / total
    print(f"üéØ Validation Accuracy: {val_acc:.2f}%")

    # --- Save best model ---
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_sneaker_model.pth")
        print(f"üíæ New best model saved (Acc: {best_acc:.2f}%)")

    # --- Save checkpoint every 2 epochs ---
    if (epoch + 1) % 2 == 0:
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}.pth")
        print(f"üìç Saved checkpoint at epoch {epoch+1}")


                                                                                      

üìâ Epoch [1] Avg Loss: 0.7836
üéØ Validation Accuracy: 85.06%
üíæ New best model saved (Acc: 85.06%)


                                                                                      

üìâ Epoch [2] Avg Loss: 0.3726
üéØ Validation Accuracy: 89.19%
üíæ New best model saved (Acc: 89.19%)
üìç Saved checkpoint at epoch 2


                                                                                      

üìâ Epoch [3] Avg Loss: 0.2606
üéØ Validation Accuracy: 89.71%
üíæ New best model saved (Acc: 89.71%)


                                                                                       

üìâ Epoch [4] Avg Loss: 0.2039
üéØ Validation Accuracy: 91.15%
üíæ New best model saved (Acc: 91.15%)
üìç Saved checkpoint at epoch 4


                                                                                       

üìâ Epoch [5] Avg Loss: 0.1594
üéØ Validation Accuracy: 91.13%


                                                                                       

üìâ Epoch [6] Avg Loss: 0.1381
üéØ Validation Accuracy: 91.08%
üìç Saved checkpoint at epoch 6


                                                                                       

üìâ Epoch [7] Avg Loss: 0.1193
üéØ Validation Accuracy: 91.65%
üíæ New best model saved (Acc: 91.65%)


                                                                                       

üìâ Epoch [8] Avg Loss: 0.1046
üéØ Validation Accuracy: 91.79%
üíæ New best model saved (Acc: 91.79%)
üìç Saved checkpoint at epoch 8


                                                                                       

üìâ Epoch [9] Avg Loss: 0.0952
üéØ Validation Accuracy: 92.14%
üíæ New best model saved (Acc: 92.14%)


                                                                                        

üìâ Epoch [10] Avg Loss: 0.0876
üéØ Validation Accuracy: 92.01%
üìç Saved checkpoint at epoch 10


                                                                                        

üìâ Epoch [11] Avg Loss: 0.0803
üéØ Validation Accuracy: 92.33%
üíæ New best model saved (Acc: 92.33%)


                                                                                         

üìâ Epoch [12] Avg Loss: 0.0781
üéØ Validation Accuracy: 92.02%
üìç Saved checkpoint at epoch 12


                                                                                         

üìâ Epoch [13] Avg Loss: 0.0704
üéØ Validation Accuracy: 91.38%


                                                                                        

üìâ Epoch [14] Avg Loss: 0.0681
üéØ Validation Accuracy: 92.47%
üíæ New best model saved (Acc: 92.47%)
üìç Saved checkpoint at epoch 14


                                                                                         

üìâ Epoch [15] Avg Loss: 0.0663
üéØ Validation Accuracy: 92.30%


In [62]:
# Quick Prediction Test

model.eval()
img_tensor, label = val_dataset[0]
img_tensor = img_tensor.unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(img_tensor)
    _, pred = torch.max(outputs, 1)

predicted_label = train_dataset.label2name[pred.item()]
actual_label = train_dataset.label2name[label.item()]

print(f"Predicted: {predicted_label} | Actual: {actual_label}")

Predicted: Asics | Actual: Asics
