# Custom CNN — Transfer Learning (PyTorch)
*Generated on 2025-09-08 17:39*

This keeps **your SimpleCNN** architecture and adds **transfer learning**:
- Load weights from an older SimpleCNN checkpoint (even if channels/classes differ)
- Freeze encoder initially, optionally **unfreeze later**
- Conv1 **channel adapter** (1↔3)
- Strong training loop: AMP, cosine LR, early stopping, class weights

In [None]:
import os, math, json, random
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a global by the training, evaluation and inference code below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

## Config

In [None]:
# Paths (edit if needed)
CSV_TRAIN = "/mnt/data/train.csv"
CSV_VAL   = "/mnt/data/val.csv"
CSV_TEST  = "/mnt/data/test.csv"  # optional: the test split is skipped when this file does not exist
# Optional JSON label map; when the file is missing the encoder is fitted on the training labels.
LABEL_MAP = "/mnt/data/label_map.json"
# Directory joined in front of relative image paths; an empty string uses the CSV paths as-is.
IMG_ROOT  = ""

# Columns (CSV column names holding the image path and the class label)
COL_IMAGE = "image"
COL_LABEL = "label"

# Model
IN_CHANNELS   = 3        # 1 for grayscale, 3 for RGB
IMG_SIZE      = 224      # set 48 for 48x48 data
# Width of the first conv stage; each later stage doubles it (c, 2c, 4c, 8c).
BASE_CHANNELS = 32
# Dropout probability applied just before the final linear layer.
DROPOUT       = 0.3

# Transfer
# An empty or non-existent path means no weights are loaded (training from scratch).
SOURCE_CKPT   = ""       # e.g., "/content/drive/MyDrive/emo_cnn_baseline_best.pt"
FREEZE_ENCODER = True
# Compared against the 0-based epoch index of the training loop.
UNFREEZE_AT_EPOCH = 5    # 0 = never

# Training
EPOCHS            = 30
BATCH_SIZE        = 64
LR                = 3e-4
WEIGHT_DECAY      = 1e-4
LABEL_SMOOTH      = 0.05
# Pass inverse-frequency class weights to the loss.
USE_CLASS_WEIGHTS = True
# Draw training samples with a WeightedRandomSampler instead of plain shuffling.
USE_WEIGHTED_SAMPLER = False
# Mixed precision is only enabled when this is True AND the device is CUDA.
USE_AMP           = True
# Max gradient norm; clipping is skipped when this is not > 0.
GRAD_CLIP_NORM    = 1.0
# Number of consecutive epochs without a validation-accuracy improvement before stopping.
EARLY_STOP        = 8

# Output directory for the best checkpoint; created on the spot if missing.
OUT_DIR = "./runs_custom_cnn_transfer"
os.makedirs(OUT_DIR, exist_ok=True)

# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

## Data & Transforms

In [None]:
from PIL import Image

class CSVDataset(Dataset):
    """Image-classification dataset described by a CSV file.

    Each CSV row provides an image path (``col_img``) and a class label
    (``col_lab``). Labels are cast to ``str`` and encoded to integers with
    ``encoder``; when no encoder is given a new ``LabelEncoder`` is fitted on
    this file's labels.

    Args:
        csv_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        transform: Optional callable applied to the loaded PIL image.
        img_root: Directory joined in front of relative image paths; an empty
            string leaves the paths untouched.
        col_img: Name of the image-path column.
        col_lab: Name of the label column.
        encoder: Already-fitted encoder exposing ``transform`` and ``classes_``.

    Note:
        ``__getitem__`` reads the module-level ``IN_CHANNELS`` to decide
        between RGB (3) and single-channel ("L") decoding.
    """

    def __init__(self, csv_path, transform=None, img_root="", col_img="image", col_lab="label", encoder=None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform
        self.img_root = img_root
        self.col_img = col_img
        self.col_lab = col_lab

        raw_labels = self.df[self.col_lab].astype(str).values
        if encoder is None:
            self.encoder = LabelEncoder().fit(raw_labels)
        else:
            self.encoder = encoder
        self.labels = self.encoder.transform(raw_labels)
        self.paths = self.df[self.col_img].astype(str).tolist()
        self.n_classes = len(self.encoder.classes_)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, encoded_label)`` for row ``idx``."""
        p = self.paths[idx]
        if self.img_root and not os.path.isabs(p):
            p = os.path.join(self.img_root, p)
        mode = "RGB" if IN_CHANNELS == 3 else "L"
        # Image.open is lazy and keeps the file open; convert() decodes the
        # pixels into a new image, so the handle can be released right away
        # instead of leaking one descriptor per sample.
        with Image.open(p) as raw:
            img = raw.convert(mode)
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

def build_transforms(img_size=224, in_ch=3):
    """Build the training and evaluation torchvision pipelines.

    Small inputs (``img_size <= 64``) are resized directly and normalised with
    mean/std 0.5 per channel. Larger inputs are resized to 110% of the target,
    centre-cropped, and normalised with ImageNet statistics for 3-channel
    input (0.5 otherwise). Training adds a horizontal flip and, for large
    3-channel input, colour jitter.

    Args:
        img_size: Target side length in pixels.
        in_ch: Number of image channels (1 or 3).

    Returns:
        Tuple ``(train_tf, eval_tf)`` of ``T.Compose`` objects.
    """
    small = img_size <= 64
    if small:
        mean, std = [0.5] * in_ch, [0.5] * in_ch
    elif in_ch == 3:
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        mean, std = [0.5], [0.5]

    def _geometry():
        # Fresh transform objects for each pipeline.
        if small:
            return [T.Resize((img_size, img_size))]
        return [T.Resize(int(img_size * 1.1)), T.CenterCrop(img_size)]

    def _to_normalized_tensor():
        return [T.ToTensor(), T.Normalize(mean, std)]

    train_steps = _geometry() + [T.RandomHorizontalFlip(0.5)]
    if not small and in_ch == 3:
        train_steps.append(T.ColorJitter(0.15, 0.15, 0.1, 0.02))
    # No identity T.Lambda placeholder for the grayscale case: a lambda inside
    # the pipeline cannot be pickled, which breaks DataLoader workers on
    # platforms that start them with "spawn" (Windows, macOS).
    train_tf = T.Compose(train_steps + _to_normalized_tensor())
    eval_tf = T.Compose(_geometry() + _to_normalized_tensor())
    return train_tf, eval_tf

def read_label_map(path):
    """Load a JSON label map and return an encoder that honours its indices.

    Two layouts are accepted: ``{"str2idx": {...}, "idx2str": {...}}`` or a
    flat ``{label: index}`` dict. Classes are ordered by their index, so the
    encoded value of a label is its rank in the map (equal to the stored index
    when indices are 0..n-1).

    Args:
        path: Path to the JSON file.

    Returns:
        A ``LabelEncoder`` whose ``classes_`` follow the map's index order, or
        ``None`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "str2idx" in data and "idx2str" in data:
        mapping = data["str2idx"]
    else:
        mapping = data
    str2idx = {str(k): int(v) for k, v in mapping.items()}
    classes_sorted = sorted(str2idx, key=str2idx.get)
    # LabelEncoder.fit() re-sorts classes alphabetically, which would discard
    # the index order stored in the map. Setting classes_ directly keeps it.
    # An object array makes the encoder use its dict-based lookup, which does
    # not require sorted classes.
    encoder = LabelEncoder()
    encoder.classes_ = np.array(classes_sorted, dtype=object)
    return encoder

# Transforms, optional label encoder, and the dataset splits.
# Validation and test reuse the encoder of the training split.
train_tf, eval_tf = build_transforms(IMG_SIZE, IN_CHANNELS)
enc = read_label_map(LABEL_MAP)

ds_tr = CSVDataset(CSV_TRAIN, train_tf, IMG_ROOT, COL_IMAGE, COL_LABEL, enc)
ds_va = CSVDataset(CSV_VAL, eval_tf, IMG_ROOT, COL_IMAGE, COL_LABEL, ds_tr.encoder)
if os.path.exists(CSV_TEST):
    ds_te = CSVDataset(CSV_TEST, eval_tf, IMG_ROOT, COL_IMAGE, COL_LABEL, ds_tr.encoder)
else:
    ds_te = None

classes = list(ds_tr.encoder.classes_)
num_classes = len(classes)
print("Classes:", classes)
print("Sizes:", len(ds_tr), len(ds_va), len(ds_te) if ds_te is not None else 0)

# Inverse-frequency class weights: N / (num_classes * count), with empty
# classes counted as 1 to avoid division by zero.
y_tr = ds_tr.labels
class_counts = np.bincount(y_tr, minlength=num_classes).astype(float)
class_weights = class_counts.sum() / (num_classes * np.maximum(class_counts, 1.0))
class_weights_t = torch.tensor(class_weights, dtype=torch.float32)

# Settings shared by every DataLoader.
_loader_common = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)

if not USE_WEIGHTED_SAMPLER:
    loader_tr = DataLoader(ds_tr, shuffle=True, **_loader_common)
else:
    # Each training sample is drawn with the weight of its class.
    sample_w = class_weights[y_tr]
    sampler = WeightedRandomSampler(sample_w, num_samples=len(sample_w), replacement=True)
    loader_tr = DataLoader(ds_tr, sampler=sampler, **_loader_common)

loader_va = DataLoader(ds_va, shuffle=False, **_loader_common)
loader_te = DataLoader(ds_te, shuffle=False, **_loader_common) if ds_te else None

print("Class counts:", class_counts)

## Custom CNN

In [None]:
class ConvBlock(nn.Module):
    """Bias-free 2-D convolution followed by batch norm and in-place ReLU.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

class SimpleCNN(nn.Module):
    """Four-stage CNN: each stage is two ConvBlocks plus a 2x2 max-pool.

    Stage widths are ``c, 2c, 4c, 8c``. The head applies global average
    pooling, flattening, dropout and a linear classifier.

    Args:
        in_ch: Input image channels.
        num_classes: Number of output logits.
        c: Width of the first stage.
        dropout: Dropout probability in the head.
    """

    def __init__(self, in_ch=3, num_classes=7, c=32, dropout=0.3):
        super().__init__()
        # Layers are appended in the same order as a literal Sequential would
        # list them, so state_dict keys (features.0 ... features.11) are unchanged.
        stages = []
        prev_width = in_ch
        for width in (c, 2 * c, 4 * c, 8 * c):
            stages.append(ConvBlock(prev_width, width))
            stages.append(ConvBlock(width, width))
            stages.append(nn.MaxPool2d(2))
            prev_width = width
        self.features = nn.Sequential(*stages)

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(8 * c, num_classes),
        )

    def forward(self, x):
        encoded = self.features(x)
        return self.head(encoded)

def build_model(in_ch, n_classes, c=32, dropout=0.3):
    """Create a SimpleCNN with the given input channels, class count, width and dropout."""
    net = SimpleCNN(in_ch=in_ch, num_classes=n_classes, c=c, dropout=dropout)
    return net

## Transfer loader (partial load + conv1 adapter)

In [None]:
def adapt_conv1(w, target_in_ch):
    """Adapt a first-layer conv weight to a different number of input channels.

    The adaptation keeps the layer's response (approximately) unchanged for
    an image whose channels all carry the same values, so the statistics the
    following BatchNorm was trained with stay meaningful:

    * same channel count: ``w`` is returned untouched;
    * 3 -> 1: channels are summed (a gray image fed to the RGB filter as
      ``(x, x, x)`` responds with ``sum_c w_c * x``);
    * more channels (including 1 -> 3): channels are tiled, trimmed, and
      rescaled by ``in_ch / target_in_ch``;
    * fewer channels otherwise: the leading ``target_in_ch`` channels are kept.

    Args:
        w: Weight tensor of shape ``(out_ch, in_ch, kH, kW)``.
        target_in_ch: Desired number of input channels.

    Returns:
        Tensor of shape ``(out_ch, target_in_ch, kH, kW)``.
    """
    in_ch = w.shape[1]
    if in_ch == target_in_ch:
        return w
    if target_in_ch < in_ch:
        if in_ch == 3 and target_in_ch == 1:
            return w.sum(dim=1, keepdim=True)
        return w[:, :target_in_ch, :, :]
    # Tile the existing channels until there are enough, then trim.
    reps = (target_in_ch + in_ch - 1) // in_ch
    tiled = w.repeat(1, reps, 1, 1)[:, :target_in_ch, :, :]
    # Rescale so the summed response over the widened input matches the
    # original. This now also covers the 1 -> 3 case.
    return tiled * (in_ch / target_in_ch)

def load_for_transfer(model, ckpt_path, in_ch_target):
    """Copy every compatible tensor from a checkpoint into ``model``.

    A tensor is loaded when its name exists in the model and the shapes
    match. The first conv weight (``features.0.conv.weight``) is additionally
    passed through ``adapt_conv1`` when its shape differs, so a checkpoint
    with a different number of input channels can still seed it. Keys saved
    from a ``DataParallel`` wrapper (``module.`` prefix) are matched too.
    Everything else (e.g. a classifier head with a different class count) is
    skipped and reported.

    Args:
        model: Target network; modified in place.
        ckpt_path: Checkpoint file. Empty or missing path means nothing is
            loaded.
        in_ch_target: Number of input channels of ``model``.

    Returns:
        None.
    """
    if not ckpt_path or not os.path.exists(ckpt_path):
        print("No SOURCE_CKPT -> training from scratch.")
        return
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(ckpt_path, map_location="cpu")
    # Accept {"model": sd}, {"state_dict": sd}, or a bare state dict.
    sd = state.get("model", state.get("state_dict", state))
    msd = model.state_dict()
    conv1_key = "features.0.conv.weight"
    prefix = "module."
    loaded = 0
    skipped = []
    for k, v in sd.items():
        if not torch.is_tensor(v):
            skipped.append(k)
            continue
        name = k
        if name not in msd and name.startswith(prefix):
            name = name[len(prefix):]
        if name in msd and msd[name].shape == v.shape:
            msd[name] = v
            loaded += 1
            continue
        if k.endswith(conv1_key) and conv1_key in msd:
            v2 = adapt_conv1(v, in_ch_target)
            if msd[conv1_key].shape == v2.shape:
                msd[conv1_key] = v2
                loaded += 1
                continue
        skipped.append(k)
    model.load_state_dict(msd, strict=False)
    print(f"Loaded {loaded} tensors from {ckpt_path}")
    if skipped:
        print(f"Skipped {len(skipped)} incompatible tensors: {skipped}")

## Train setup

In [None]:
# Build the network on the target device, then seed it from SOURCE_CKPT.
# load_for_transfer leaves the model untouched when the path is empty or missing.
model = build_model(IN_CHANNELS, num_classes, c=BASE_CHANNELS, dropout=DROPOUT).to(device)
load_for_transfer(model, SOURCE_CKPT, IN_CHANNELS)

# Freeze encoder
def set_encoder_trainable(net, flag: bool):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``net.features``."""
    for _name, param in net.features.named_parameters():
        param.requires_grad = flag

# Optionally start with the feature extractor frozen; the message reports
# the epoch at which the training loop will unfreeze it.
# NOTE(review): freezing only clears requires_grad. The BatchNorm layers inside
# model.features are still switched to train mode by train_one_epoch, so their
# running statistics presumably keep updating while "frozen" — confirm this is intended.
if FREEZE_ENCODER:
    set_encoder_trainable(model, False)
    print("Encoder frozen. Will unfreeze at epoch:", UNFREEZE_AT_EPOCH if UNFREEZE_AT_EPOCH>0 else "never")

class LabelSmoothingCE(nn.Module):
    """Cross-entropy with label smoothing and optional per-class weights.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` evenly over the other ``n - 1`` classes. With
    ``weight``, each class term of the per-sample sum is multiplied by that
    class's weight. The result is the plain mean over the batch.

    Args:
        smoothing: Probability mass moved away from the true class.
        weight: Optional 1-D tensor (or sequence) of per-class weights.
    """

    def __init__(self, smoothing=0.0, weight=None):
        super().__init__()
        self.smoothing = smoothing
        if weight is not None and not torch.is_tensor(weight):
            weight = torch.as_tensor(weight, dtype=torch.float32)
        # A buffer (rather than a plain attribute) follows the module through
        # .to() / .cuda(); it is still readable as self.weight.
        self.register_buffer("weight", weight)

    def forward(self, logits, target):
        """Return the scalar loss for ``logits`` (N, C) and integer ``target`` (N,)."""
        n = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        # With one class there is no "other" class to receive the smoothing
        # mass; avoid dividing by zero.
        off_value = self.smoothing / (n - 1) if n > 1 else 0.0
        with torch.no_grad():
            true = torch.zeros_like(logp)
            true.fill_(off_value)
            # scatter_ needs int64 indices; labels may arrive as int32.
            true.scatter_(1, target.long().unsqueeze(1), 1 - self.smoothing)
        per_class = true * logp
        if self.weight is not None:
            per_class = per_class * self.weight.unsqueeze(0)
        loss = -per_class.sum(dim=1)
        return loss.mean()

# Loss: label-smoothed cross-entropy; class weights are moved to the training
# device and only used when USE_CLASS_WEIGHTS is set.
criterion = LabelSmoothingCE(LABEL_SMOOTH, class_weights_t.to(device) if USE_CLASS_WEIGHTS else None)
# Only parameters that currently require gradients are handed to AdamW, so a
# frozen encoder is left out of this optimizer.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine annealing with a period of EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# A gradient scaler exists only when AMP is requested and the device is CUDA;
# None makes train_one_epoch take the full-precision path.
scaler = torch.cuda.amp.GradScaler() if (USE_AMP and device.type=='cuda') else None

def acc(logits, y):
    """Return the fraction of rows whose arg-max logit equals the label, as a Python float."""
    predicted = logits.argmax(1)
    hits = predicted.eq(y).float()
    return hits.mean().item()

## Train/Eval loops

In [None]:
@torch.no_grad()
def evaluate(net, loader):
    """Evaluate ``net`` on ``loader`` without gradients.

    Loss and accuracy are averaged over samples, not over batches, so a
    shorter final batch does not get extra weight. Uses the module-level
    ``criterion`` and ``device``.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(mean_loss, accuracy, y_true, y_pred)``; the last two are
        NumPy arrays of labels and predictions. Loss and accuracy are NaN
        when the loader yields no samples.
    """
    net.eval()
    loss_sum = 0.0
    n_seen = 0
    y_true = []
    y_pred = []
    for x, y in tqdm(loader, desc="Eval", leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = net(x)
        batch = y.size(0)
        # criterion returns a batch mean; weight it by the batch size.
        loss_sum += criterion(logits, y).item() * batch
        n_seen += batch
        y_true.extend(y.cpu().tolist())
        y_pred.extend(logits.argmax(1).cpu().tolist())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if n_seen == 0:
        return float("nan"), float("nan"), y_true, y_pred
    return loss_sum / n_seen, float((y_true == y_pred).mean()), y_true, y_pred

def train_one_epoch(net, loader, opt, scaler=None):
    """Run one optimisation pass over ``loader``.

    Uses the module-level ``criterion``, ``device`` and ``GRAD_CLIP_NORM``.
    When ``scaler`` is given, the forward pass runs under CUDA autocast and
    gradients are unscaled before clipping. Loss and accuracy are averaged
    over samples, so a shorter final batch is not over-weighted.

    Args:
        net: Model to train; it is switched to train mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        opt: Optimizer stepping the model parameters.
        scaler: Optional ``GradScaler``; ``None`` selects full precision.

    Returns:
        Tuple ``(mean_loss, accuracy)`` for the epoch (NaN for an empty loader).
    """
    net.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        if scaler is not None:
            with torch.cuda.amp.autocast():
                logits = net(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            if GRAD_CLIP_NORM > 0:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(net.parameters(), GRAD_CLIP_NORM)
            scaler.step(opt)
            scaler.update()
        else:
            logits = net(x)
            loss = criterion(logits, y)
            loss.backward()
            if GRAD_CLIP_NORM > 0:
                nn.utils.clip_grad_norm_(net.parameters(), GRAD_CLIP_NORM)
            opt.step()
        batch = y.size(0)
        loss_sum += loss.item() * batch
        n_correct += (logits.argmax(1) == y).sum().item()
        n_seen += batch
    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen

## Train

In [None]:
# Training loop: keeps the checkpoint with the best validation accuracy and
# stops after EARLY_STOP epochs without improvement.
best_acc = -1.0
noimp = 0
ckpt = os.path.join(OUT_DIR, "custom_cnn_best.pt")
hist = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for epoch in range(EPOCHS):
    if FREEZE_ENCODER and UNFREEZE_AT_EPOCH>0 and epoch==UNFREEZE_AT_EPOCH:
        set_encoder_trainable(model, True)
        # The encoder parameters were not in the first optimizer, so build a
        # new one over all parameters.
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        # The scheduler must be rebuilt as well: the old one still points at
        # the discarded optimizer, which would leave the new one at a constant
        # LR. The fine-tuning phase anneals from LR over the remaining epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, EPOCHS - epoch))
        print(f"🔓 Unfroze encoder at epoch {epoch}")
    tr_loss, tr_acc = train_one_epoch(model, loader_tr, optimizer, scaler)
    scheduler.step()
    va_loss, va_acc, y_true, y_pred = evaluate(model, loader_va)

    hist["train_loss"].append(tr_loss); hist["train_acc"].append(tr_acc)
    hist["val_loss"].append(va_loss);   hist["val_acc"].append(va_acc)

    print(f"Epoch {epoch+1:02d} | train {tr_loss:.4f}/{tr_acc:.4f}  val {va_loss:.4f}/{va_acc:.4f}")
    if va_acc > best_acc:
        best_acc = va_acc
        noimp = 0
        # Store class names as plain str: encoder classes can be numpy.str_
        # objects, which torch.load(weights_only=True) refuses to unpickle.
        torch.save({"model": model.state_dict(), "in_channels": IN_CHANNELS,
                    "classes": [str(c) for c in classes],
                    "arch": {"base_channels": BASE_CHANNELS, "dropout": DROPOUT}}, ckpt)
        print("  ✅ Saved:", ckpt)
    else:
        noimp += 1
        if noimp >= EARLY_STOP:
            print("⏹️ Early stopping"); break

# plots: training curves collected in `hist`
plt.figure(figsize=(6,4)); plt.plot(hist["train_acc"],label="train_acc"); plt.plot(hist["val_acc"],label="val_acc"); plt.legend(); plt.title("Accuracy"); plt.show()
plt.figure(figsize=(6,4)); plt.plot(hist["train_loss"],label="train_loss"); plt.plot(hist["val_loss"],label="val_loss"); plt.legend(); plt.title("Loss"); plt.show()

# Final eval (best): rebuild the model from the saved checkpoint so the report
# reflects the best epoch rather than the last one.
state = torch.load(ckpt, map_location=device)
model = SimpleCNN(in_ch=state["in_channels"], num_classes=len(state["classes"]),
                  c=state["arch"]["base_channels"], dropout=state["arch"]["dropout"]).to(device)
model.load_state_dict(state["model"])

# Pass the full label set explicitly: without `labels`, classification_report
# derives the classes from the data and raises when a split lacks a class,
# because `target_names` then has the wrong length.
label_ids = list(range(len(classes)))

_, va_acc, y_true, y_pred = evaluate(model, loader_va)
print("Best Val Acc:", va_acc)
print("\nClassification Report (Val):")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=classes, digits=4))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(6,6)); plt.imshow(cm, interpolation="nearest"); plt.title("Confusion Matrix (Val)"); plt.colorbar()
ticks=np.arange(len(classes)); plt.xticks(ticks, classes, rotation=45, ha="right"); plt.yticks(ticks, classes)
plt.xlabel("Predicted"); plt.ylabel("True"); plt.tight_layout(); plt.show()

if loader_te is not None:
    _, te_acc, y_true_te, y_pred_te = evaluate(model, loader_te)
    print("\nTest Acc:", te_acc)
    print("\nClassification Report (Test):")
    print(classification_report(y_true_te, y_pred_te, labels=label_ids, target_names=classes, digits=4))

## Inference helper

In [None]:
from PIL import Image
def build_infer_transform(img_size, in_ch):
    """Build the deterministic preprocessing pipeline used for inference.

    Inputs of at most 64 px are resized directly and normalised with 0.5 per
    channel. Larger inputs are resized to 110%, centre-cropped, and normalised
    with ImageNet statistics for 3 channels (0.5 otherwise).
    """
    if img_size <= 64:
        mean = [0.5] * in_ch
        std = [0.5] * in_ch
        steps = [T.Resize((img_size, img_size))]
    else:
        is_rgb = in_ch == 3
        mean = [0.485, 0.456, 0.406] if is_rgb else [0.5]
        std = [0.229, 0.224, 0.225] if is_rgb else [0.5]
        steps = [T.Resize(int(img_size * 1.1)), T.CenterCrop(img_size)]
    steps.append(T.ToTensor())
    steps.append(T.Normalize(mean, std))
    return T.Compose(steps)

@torch.no_grad()
def predict_image(path, ckpt_path, topk=5):
    """Classify one image with a saved checkpoint.

    The model is rebuilt from the checkpoint's ``in_channels``, ``classes``
    and ``arch`` entries. Preprocessing uses the module-level ``IMG_SIZE``.

    Args:
        path: Image file to classify.
        ckpt_path: Checkpoint written by the training loop.
        topk: Maximum number of predictions to return.

    Returns:
        List of ``(class_name, probability)`` pairs, most likely first.
    """
    st = torch.load(ckpt_path, map_location=device)
    net = SimpleCNN(in_ch=st["in_channels"], num_classes=len(st["classes"]),
                    c=st["arch"]["base_channels"], dropout=st["arch"]["dropout"]).to(device)
    net.load_state_dict(st["model"])
    net.eval()
    mode = "RGB" if st["in_channels"] == 3 else "L"
    # convert() decodes into a new image, so the file handle opened by
    # Image.open can be closed immediately.
    with Image.open(path) as raw:
        img = raw.convert(mode)
    tfm = build_infer_transform(IMG_SIZE, st["in_channels"])
    x = tfm(img).unsqueeze(0).to(device)
    # squeeze(0) drops only the batch axis, so a single-class model still
    # yields a 1-D probability vector.
    probs = net(x).softmax(1).squeeze(0).cpu().numpy()
    order = probs.argsort()[::-1][:topk]
    return [(st["classes"][i], float(probs[i])) for i in order]

# Example:
# predict_image("/path/to/img.jpg", ckpt)