# Big Project – Session 2: Train Image Classifier
Train a small classifier (ResNet18) and save its weights to `game_project/models/classifier.pt`, with the class names written alongside to `game_project/models/classes.txt`.

You can start with CIFAR-10 for speed, then optionally fine-tune on your own assets.

In [None]:
# Setup
!pip -q install torch torchvision matplotlib --extra-index-url https://download.pytorch.org/whl/cu121
import torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import os, json, time, matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Output locations for the trained weights and the class-name list.
SAVE_PATH = "/mnt/data/nicegpu_ai_workshop/game_project/models/classifier.pt"
CLASSES_PATH = "/mnt/data/nicegpu_ai_workshop/game_project/models/classes.txt"

# Create the models folder up front so the later writes find it in place.
models_dir = os.path.dirname(SAVE_PATH)
os.makedirs(models_dir, exist_ok=True)
print('Saving model to:', SAVE_PATH)

### Option A: CIFAR-10 (fast & simple)

In [None]:
# Hyperparameters for the quick CIFAR-10 run.
BATCH_SIZE = 128
EPOCHS = 2
LR = 1e-3

# Training uses light augmentation (random crop from a 4-pixel-padded image and
# a random horizontal flip); evaluation only converts images to tensors.
train_tfms = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()])
test_tfms = T.Compose([T.ToTensor()])

# download=True fetches CIFAR-10 into ./data when it is not already there.
train_ds = torchvision.datasets.CIFAR10('./data', train=True, transform=train_tfms, download=True)
test_ds = torchvision.datasets.CIFAR10('./data', train=False, transform=test_tfms, download=True)

# Write the class names one per line, in the order of `train_ds.classes`.
# The context manager closes the file so the contents are flushed to disk
# before anything else reads it.
classes = train_ds.classes
with open(CLASSES_PATH, 'w', encoding='utf-8') as f:
    f.write("\n".join(classes))

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine just produces a warning, so enable it only for CUDA.
use_pinned_memory = (DEVICE == 'cuda')
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=2, pin_memory=use_pinned_memory)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                     num_workers=2, pin_memory=use_pinned_memory)

# ResNet18 with one output per class; no pretrained weights are requested.
model = resnet18(num_classes=len(classes))
model = model.to(DEVICE)

# Cross-entropy objective optimised with AdamW at the configured learning rate.
crit = nn.CrossEntropyLoss()
opt = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(1, EPOCHS + 1):
    model.train()
    seen = 0
    hits = 0
    running_loss = 0
    started = time.time()

    for images, labels in train_dl:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # One optimisation step on this mini-batch.
        opt.zero_grad()
        logits = model(images)
        loss = crit(logits, labels)
        loss.backward()
        opt.step()

        # `loss` is a batch mean, so weight it by batch size to get a correct
        # per-sample average over the whole epoch.
        batch_n = images.size(0)
        running_loss += loss.item() * batch_n
        hits += (logits.argmax(1) == labels).sum().item()
        seen += batch_n

    avg_loss = running_loss / seen
    train_acc = hits / seen
    elapsed = time.time() - started
    print(f"Epoch {epoch}/{EPOCHS} | loss={avg_loss:.4f} acc={train_acc:.3f} time={elapsed:.1f}s")

# Measure accuracy on the held-out test split.
model.eval()
tot = 0
correct = 0
with torch.no_grad():  # no gradients are needed for scoring
    for images, labels in test_dl:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        pred = model(images).argmax(1)
        correct += (pred == labels).sum().item()
        tot += images.size(0)
test_accuracy = correct / tot
print('Test accuracy:', test_accuracy)

# Persist only the weights (state_dict); the class list was written earlier.
torch.save(model.state_dict(), SAVE_PATH)
print('✅ Saved model to', SAVE_PATH)
print('✅ Saved classes to', CLASSES_PATH)

### Option B (Optional): Fine-tune on your own assets
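Create one folder per class, e.g. `my_assets/classA/*.png` and `my_assets/classB/*.png`, then run the cell below. It replaces the classifier's final layer to match your classes and overwrites the saved model and class list.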

In [None]:
from pathlib import Path

own_root = Path("my_assets")  # put your labeled folders here
# is_dir() rather than exists(): a stray file named "my_assets" is not a dataset.
if own_root.is_dir():
    print("Found:", own_root)
    FT_BATCH_SIZE = 64
    # Resize to 32x32 so custom images match the resolution used for CIFAR-10.
    tfms = T.Compose([T.Resize((32, 32)), T.ToTensor()])
    own_ds = torchvision.datasets.ImageFolder(own_root, transform=tfms)
    # In training mode BatchNorm cannot normalise a batch holding a single
    # image at this input size, so drop the final batch only when it would
    # contain exactly one sample. (A one-image dataset therefore trains on
    # nothing; add more images for the fine-tune to have any effect.)
    own_dl = DataLoader(
        own_ds,
        batch_size=FT_BATCH_SIZE,
        shuffle=True,
        drop_last=(len(own_ds) % FT_BATCH_SIZE == 1),
    )
    # quick single-epoch fine-tune
    # Swap the final layer for one sized to the custom class count; the rest
    # of the network keeps its current weights and is trained along with it.
    model.fc = nn.Linear(model.fc.in_features, len(own_ds.classes)).to(DEVICE)
    opt = optim.AdamW(model.parameters(), lr=5e-4)
    # The evaluation step leaves the model in eval mode; switch back to train
    # mode so BatchNorm uses batch statistics and updates its running averages.
    model.train()
    for x, y in own_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        opt.zero_grad()
        out = model(x)
        loss = crit(out, y)
        loss.backward()
        opt.step()
    # Return to eval mode so any inference that follows behaves as before.
    model.eval()
    # save new head + classes (overwrites the Option A artifacts)
    with open(CLASSES_PATH, 'w') as f:
        f.write("\n".join(own_ds.classes))
    torch.save(model.state_dict(), SAVE_PATH)
    print("✅ Fine-tuned on custom assets and saved model.")
else:
    print("Optional: add labeled images to ./my_assets/ to fine-tune.")