In [1]:
import os, zipfile, requests, time, numpy as np
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

In [6]:

# ======================================================
# 2. Dataset + Transformations
# ======================================================
class SkinDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Class names are the sorted sub-directory names, and each label is the
    index of its class in that sorted list. Within a class, files are indexed
    in sorted filename order, so two instances built from the same directory
    always agree on which index maps to which file (this matters when one set
    of split indices is applied to several instances).

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes set in ``__init__``: ``classes``, ``class_to_idx``,
    ``image_paths`` and ``labels`` (the last two are parallel lists).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        self.image_paths, self.labels = [], []
        valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.jfif')
        for cls in self.classes:
            class_dir = os.path.join(root_dir, cls)
            # os.listdir/os.scandir return entries in arbitrary order; sort so
            # the sample order is deterministic. Only regular files are kept,
            # so a directory named like "foo.png" is not indexed as an image.
            with os.scandir(class_dir) as entries:
                fnames = sorted(e.name for e in entries if e.is_file())
            for fname in fnames:
                if fname.lower().endswith(valid_exts):
                    self.image_paths.append(os.path.join(class_dir, fname))
                    self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one is set. If the file cannot be read or decoded, a warning is
        emitted and a black 120x120 RGB placeholder is used instead, so one
        bad file does not abort a whole epoch.
        """
        path, label = self.image_paths[idx], self.labels[idx]
        try:
            # The context manager releases the file handle as soon as the
            # pixels have been decoded by convert().
            with Image.open(path) as src:
                img = src.convert("RGB")
        except Exception as exc:
            # `Exception` (not a bare except) keeps the best-effort fallback
            # but lets KeyboardInterrupt / SystemExit propagate.
            import warnings
            warnings.warn(
                f"Could not read image {path!r} ({exc}); using a black placeholder.",
                RuntimeWarning,
            )
            img = Image.new("RGB", (120, 120), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, label


# Per-channel normalization statistics shared by both pipelines
# (the values match the standard ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: light geometric + photometric augmentation, ending in a
# normalized 120x120 tensor.
_train_steps = [
    transforms.Resize(128),
    transforms.RandomResizedCrop(120, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_tf = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize to 120x120, no augmentation.
_val_steps = [
    transforms.Resize((120, 120)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
val_tf = transforms.Compose(_val_steps)

# Data location and split settings, defined once so the three dataset
# instances below cannot drift apart.
DATA_DIR = "../data/"
TRAIN_FRACTION = 0.9
SPLIT_SEED = 42

# Untransformed instance: used only to count classes/images and size the split.
dataset = SkinDataset(DATA_DIR, transform=None)
num_classes = len(dataset.classes)
print(f"Found {num_classes} classes, {len(dataset)} total images.")

# split
# A dedicated, seeded generator makes the train/val partition identical on
# every run without touching the global RNG state.
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_rng)
split = int(TRAIN_FRACTION * len(dataset))
train_idx, val_idx = indices[:split], indices[split:]

# Separate instances are needed because train and val use different transforms.
train_base = SkinDataset(DATA_DIR, transform=train_tf)
val_base = SkinDataset(DATA_DIR, transform=val_tf)
# The same index tensor is applied to different instances, so they must index
# exactly the same files in the same order; otherwise samples could leak
# between the train and validation subsets.
assert train_base.image_paths == dataset.image_paths == val_base.image_paths, \
    "SkinDataset instances disagree on file order; train/val split would be inconsistent"
assert len(train_idx) + len(val_idx) == len(dataset)

train_ds = Subset(train_base, train_idx)
val_ds   = Subset(val_base, val_idx)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4)


Found 10 classes, 10000 total images.


In [7]:


# ======================================================
# 3. Model (Compact Depthwise-Separable CNN)
# ======================================================
def dwpw(cin, cout, stride=1):
    """Build a depthwise-separable convolution block.

    The block is a 3x3 depthwise convolution (``groups=cin``, carrying the
    stride) followed by a 1x1 pointwise convolution that maps ``cin`` to
    ``cout`` channels. Each convolution is bias-free and followed by
    BatchNorm and an in-place ReLU.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        stride: Stride of the depthwise convolution (default 1).

    Returns:
        An ``nn.Sequential`` of six layers in the order
        conv-dw, BN, ReLU, conv-pw, BN, ReLU.
    """
    depthwise = [
        nn.Conv2d(cin, cin, kernel_size=3, stride=stride, padding=1, groups=cin, bias=False),
        nn.BatchNorm2d(cin),
        nn.ReLU(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(cin, cout, kernel_size=1, bias=False),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)

class TinyDermNet(nn.Module):
    """Compact CNN built from depthwise-separable blocks.

    Layout: a stride-2 3x3 conv stem (3 -> 16 channels), seven ``dwpw``
    blocks widening to 160 channels (three of them stride-2), global average
    pooling, and a linear classification head.

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        stem_width = 16
        # (in_channels, out_channels, stride) of each block, in execution order.
        block_specs = [
            (16, 32, 1), (32, 64, 2),
            (64, 64, 1), (64, 96, 2),
            (96, 96, 1), (96, 128, 2),
            (128, 160, 1),
        ]
        final_width = block_specs[-1][1]

        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )
        # Blocks are created in order so submodule indices (and therefore
        # state_dict keys) follow the spec list.
        self.blocks = nn.Sequential(
            *[dwpw(cin, cout, stride) for cin, cout, stride in block_specs]
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(final_width, num_classes)

    def forward(self, x):
        """Map an image batch (3 input channels) to ``(batch, num_classes)`` logits."""
        features = self.blocks(self.stem(x))
        pooled = self.pool(features).flatten(1)
        return self.head(pooled)



In [None]:

# ======================================================
# 4. Training Loop + Metrics
# ======================================================
# Seed the global RNG before the model is built so that weight initialization
# (and the DataLoader shuffle order derived from the global RNG) is the same
# on every run of this cell.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model = TinyDermNet(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
epochs = 10

def evaluate(model, loader):
    """Compute accuracy and macro-averaged F1 of ``model`` over ``loader``.

    Switches the model to eval mode (and leaves it there), runs every batch
    without gradient tracking on the module-level ``device``, and takes the
    argmax over dim 1 of the model output as the predicted class.

    Args:
        model: Module called on each input batch.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(accuracy, macro_f1)``.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            batch_pred = torch.argmax(logits, dim=1)
            # Collect on CPU so the metrics can be computed outside the device.
            pred_chunks.append(batch_pred.cpu())
            label_chunks.append(batch_y.cpu())
    y_pred = torch.cat(pred_chunks)
    y_true = torch.cat(label_chunks)
    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    return accuracy, macro_f1

# One pass over the training set per epoch, followed by a validation pass.
for epoch in range(epochs):
    model.train()  # evaluate() leaves the model in eval mode, so reset each epoch
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for imgs, labs in progress:
        imgs = imgs.to(device)
        labs = labs.to(device)

        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labs)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight by batch size so the
        # epoch figure is a per-sample average even if the last batch is short.
        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size

    n_train = len(train_loader.dataset)
    train_loss = running_loss / n_train
    acc, f1 = evaluate(model, val_loader)
    print(f"Epoch {epoch+1}: loss={train_loss:.4f}, val_acc={acc:.4f}, val_f1={f1:.4f}")

# ======================================================
# 5. Save TorchScript model (<5MB)
# ======================================================
# Script the model in eval mode, write it to disk, and report the file size.
save_path = "TinyDermNet_scripted.pt"
model.eval()
scripted = torch.jit.script(model)
scripted.save(save_path)
size_bytes = os.stat(save_path).st_size
size_mb = size_bytes/1024/1024
print(f"Saved {save_path} ({size_mb:.2f} MB)")


# ======================================================
# 6. Local evaluation summary
# ======================================================
# Re-run validation on the in-memory model (not the saved scripted copy).
final_metrics = evaluate(model, val_loader)
acc, f1 = final_metrics
print(f"\nFinal Validation Accuracy: {acc:.4f}, Macro F1: {f1:.4f}")
