In [None]:
# =============================
# INSTALLS
# =============================
# Install third-party dependencies from inside the notebook (IPython shell escape).
# NOTE(review): the captured output of this cell shows pip aborting with
# "externally-managed-environment", so this line installed nothing in that run;
# run the notebook from a virtual environment if these pins are actually required.
!pip install --quiet timm scikit-learn albumentations==1.4.10 opencv-python

# =============================
# IMPORTS
# =============================
import os
import gc
import math
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
import timm

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, balanced_accuracy_score, classification_report
import seaborn as sns

import albumentations as A
from albumentations.pytorch import ToTensorV2

# =============================
# CONFIG
# =============================
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root of the preprocessed dataset; JPEG images and the .npy arrays live in sub-folders.
DATA_DIR = "/home/dineshkumarv.22it/preprocessed_data"
IMG_DIR = os.path.join(DATA_DIR, "images")
ARR_DIR = os.path.join(DATA_DIR, "arrays")

BATCH_SIZE = 10
EPOCHS = 60  # upper bound; early stopping can end training sooner
IMG_SIZE = 224  # the transforms resize every image to IMG_SIZE x IMG_SIZE
LR = 2e-4  # initial AdamW learning rate (cosine-annealed down to 5% of this value)
WD = 1e-4  # AdamW weight decay
PATIENCE = 10  # epochs without a validation macro-F1 improvement before stopping
MIXUP_ALPHA = 0.4  # Beta(alpha, alpha) parameter for the MixUp coefficient
CUTMIX_ALPHA = 0.4  # Beta(alpha, alpha) parameter for the CutMix coefficient
# One uniform draw r per batch picks the augmentation: MixUp when r < P_MIXUP,
# CutMix when P_MIXUP <= r < P_MIXUP + P_CUTMIX, otherwise the batch is left as is.
P_MIXUP = 0.5
P_CUTMIX = 0.5
SEED = 42  # used for RNG seeding and for the train/validation split
FREEZE_WARMUP_EPOCHS = 3  # the two backbones stay frozen for this many initial epochs

# Enable cuDNN benchmark mode.
# NOTE(review): this may make runs non-reproducible despite seeding — confirm that is acceptable.
torch.backends.cudnn.benchmark = True

def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


set_seed()

# =============================
# LOAD ARRAYS
# =============================
def _load_array(filename):
    """Load one ``.npy`` file from ARR_DIR (pickled object arrays are allowed)."""
    return np.load(os.path.join(ARR_DIR, filename), allow_pickle=True)


image_ids = _load_array("image_ids.npy")
labels = _load_array("y_labels.npy")
class_names = _load_array("label_classes.npy")

# The number of classes is taken from the stored class-name array.
num_classes = int(len(class_names))
print("Classes:", class_names, "\n#classes:", num_classes, " #images:", len(image_ids))

# =============================
# TRANSFORMS
# =============================
# Training-time augmentation pipeline: resize, geometric and photometric jitter,
# random erasing, then normalisation and conversion to a CHW torch tensor.
train_tf = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomRotate90(p=0.3),
    # A (low, high) tuple is a sampling range. The previous (0.05, 0.05) collapsed that
    # range to a single value, so every affine-augmented image was shifted by the same
    # fixed +5%. A symmetric range gives random shifts of up to 5% in either direction.
    A.Affine(scale=(0.9, 1.1), translate_percent=(-0.05, 0.05), rotate=(-15, 15), p=0.5),
    A.ColorJitter(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.CLAHE(p=0.2),
    A.GaussianBlur(blur_limit=(3, 5), p=0.2),
    # NOTE(review): the recorded run output contains a truncated warning pointing at this
    # call; presumably these argument names are deprecated in the installed Albumentations
    # version — confirm against the version that is actually in use.
    A.CoarseDropout(
        min_holes=1, max_holes=8,
        min_height=8, max_height=16,
        min_width=8, max_width=16,
        fill_value=0, p=0.5
    ),
    # NOTE(review): mean/std of 0.5 differ from the ImageNet statistics the pretrained
    # backbones were presumably trained with — confirm this is intentional.
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])


# Validation-time pipeline: deterministic resize, the same normalisation as training,
# and conversion to a torch tensor — no random augmentation.
val_tf = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])

# =============================
# DATASET
# =============================
class SkinDataset(Dataset):
    """Dataset of ``<img_dir>/<image_id>.jpg`` files paired with integer class labels.

    Args:
        image_ids: iterable of image identifiers (file names without ``.jpg``).
        labels: iterable of class indices aligned with ``image_ids``.
        img_dir: directory that contains the JPEG files.
        transform: optional Albumentations-style callable, invoked as
            ``transform(image=ndarray)["image"]``.

    Raises:
        ValueError: if ``image_ids`` and ``labels`` have different lengths.
    """

    def __init__(self, image_ids, labels, img_dir, transform=None):
        self.image_ids = list(image_ids)
        self.labels = list(labels)
        if len(self.image_ids) != len(self.labels):
            raise ValueError(
                f"image_ids ({len(self.image_ids)}) and labels ({len(self.labels)}) "
                "must have the same length"
            )
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``label`` is a 0-dim ``torch.long`` tensor.

        An image that cannot be opened or decoded is replaced by a black
        ``IMG_SIZE`` x ``IMG_SIZE`` RGB image so that training keeps running,
        and a ``RuntimeWarning`` is emitted so the substitution is not silent.
        """
        img_id = self.image_ids[idx]
        label = int(self.labels[idx])
        img_path = os.path.join(self.img_dir, img_id + ".jpg")
        try:
            # The context manager closes the file as soon as the pixels are decoded.
            with Image.open(img_path) as handle:
                img = handle.convert("RGB")
        except Exception as exc:
            # Best-effort fallback kept on purpose, but made visible: a silently
            # substituted black image would otherwise be trained on with a real label.
            import warnings  # local import keeps this block self-contained

            warnings.warn(
                f"Could not read image {img_path!r} ({exc!r}); using a black placeholder.",
                RuntimeWarning,
                stacklevel=2,
            )
            img = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (0, 0, 0))
        img = np.array(img)
        # Compare against None explicitly: a transform object may define its own truthiness.
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img, torch.tensor(label, dtype=torch.long)

# =============================
# SPLIT TRAIN/VAL
# =============================
from sklearn.model_selection import train_test_split

# Stratified 70/30 hold-out split, reproducible through SEED.
train_ids, val_ids, train_labels, val_labels = train_test_split(
    image_ids,
    labels,
    test_size=0.3,
    stratify=labels,
    random_state=SEED,
)

train_dataset = SkinDataset(train_ids, train_labels, IMG_DIR, transform=train_tf)
val_dataset = SkinDataset(val_ids, val_labels, IMG_DIR, transform=val_tf)

# =============================
# HANDLE IMBALANCE
# =============================
# Each training sample is drawn with probability proportional to the inverse
# frequency of its class, sampling with replacement for one dataset-length per epoch.
class_counts = Counter(train_labels)
class_weights = {
    label_id: 1.0 / max(1, n_samples)
    for label_id, n_samples in class_counts.items()
}
sample_weights = [class_weights[int(label_id)] for label_id in train_labels]
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# =============================
# MIXUP / CUTMIX HELPERS
# =============================
def rand_bbox(W, H, lam):
    """Sample a random box inside a ``W`` x ``H`` image for CutMix.

    The box is ``sqrt(1 - lam)`` times the image size per side before clipping,
    centred on a uniformly drawn pixel and clipped to the image bounds.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 <= x2 <= W`` and ``0 <= y1 <= y2 <= H``.
    """
    side_ratio = math.sqrt(1 - lam)
    half_w = int(W * side_ratio) // 2
    half_h = int(H * side_ratio) // 2

    # Draw x before y so the global NumPy RNG is consumed in a fixed order.
    center_x = np.random.randint(W)
    center_y = np.random.randint(H)

    left = np.clip(center_x - half_w, 0, W)
    top = np.clip(center_y - half_h, 0, H)
    right = np.clip(center_x + half_w, 0, W)
    bottom = np.clip(center_y + half_h, 0, H)
    return left, top, right, bottom

def apply_mixup_cutmix(images, targets, mixup_alpha=MIXUP_ALPHA, cutmix_alpha=CUTMIX_ALPHA,
                         p_mixup=P_MIXUP, p_cutmix=P_CUTMIX):
    """Randomly apply MixUp or CutMix to a batch.

    One uniform draw ``r`` selects the mode: MixUp when ``r < p_mixup``, CutMix when
    ``p_mixup <= r < p_mixup + p_cutmix``, otherwise nothing. A mode whose alpha is
    not positive is never applied, and batches with fewer than two samples are
    returned unchanged.

    Returns:
        ``(images, targets_a, targets_b, lam, mode)``. ``mode`` is ``'mixup'``,
        ``'cutmix'`` or ``None``. ``lam`` is the weight of ``targets_a``, and is
        ``1.0`` when nothing was applied.
    """
    passthrough = (images, targets, targets, 1.0, None)
    batch_size = images.size(0)
    if batch_size < 2:
        return passthrough

    draw = random.random()
    use_mixup = mixup_alpha > 0 and draw < p_mixup
    use_cutmix = cutmix_alpha > 0 and p_mixup <= draw < p_mixup + p_cutmix
    if not (use_mixup or use_cutmix):
        return passthrough

    if use_mixup:
        lam = np.random.beta(mixup_alpha, mixup_alpha)
        perm = torch.randperm(batch_size, device=images.device)
        blended = lam * images + (1 - lam) * images[perm, :]
        return blended, targets, targets[perm], lam, 'mixup'

    # CutMix: paste one rectangle taken from a shuffled copy of the batch.
    lam = np.random.beta(cutmix_alpha, cutmix_alpha)
    perm = torch.randperm(batch_size, device=images.device)
    height, width = images.size(2), images.size(3)
    x1, y1, x2, y2 = rand_bbox(width, height, lam)
    patched = images.clone()
    patched[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Clipping can shrink the box, so recompute lam from the area actually pasted.
    lam = 1 - ((x2 - x1) * (y2 - y1) / (width * height))
    return patched, targets, targets[perm], lam, 'cutmix'

# =============================
# FOCAL LOSS
# =============================
class FocalLoss(nn.Module):
    """Multi-class focal loss built on per-sample cross-entropy.

    Computes ``alpha * (1 - p_t) ** gamma * CE`` per sample, where
    ``p_t = exp(-CE)``.

    Args:
        alpha: constant scale applied to every sample's loss.
        gamma: focusing exponent; ``0`` gives ``alpha`` times the cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample losses.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction
        # Unreduced CE is required because the focal factor is applied per sample.
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        per_sample_ce = self.ce(inputs, targets)
        p_true = torch.exp(-per_sample_ce)
        focal = self.alpha * (1 - p_true) ** self.gamma * per_sample_ce
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        return focal

def mix_criterion(criterion, outputs, y1, y2, lam, mode):
    """Evaluate ``criterion`` for a possibly mixed batch.

    For ``mode`` in ``('mixup', 'cutmix')`` the result is the ``lam``-weighted
    blend of the losses against ``y1`` and ``y2``; for any other mode it is
    simply the loss against ``y1``.
    """
    if mode not in ('mixup', 'cutmix'):
        return criterion(outputs, y1)
    loss_a = criterion(outputs, y1)
    loss_b = criterion(outputs, y2)
    return lam * loss_a + (1 - lam) * loss_b

# =============================
# MODEL
# =============================
class ConvNeXtSwin(nn.Module):
    """Two-backbone classifier: ConvNeXt-Base and Swin-Base features feed one MLP head.

    Both timm backbones are created with ``num_classes=0`` and average pooling.
    Their outputs are concatenated along the feature axis and passed through a
    three-layer MLP that ends in ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone_kwargs = dict(pretrained=True, num_classes=0, global_pool="avg")
        self.convnext = timm.create_model("convnext_base", **backbone_kwargs)
        self.swin = timm.create_model("swin_base_patch4_window7_224", **backbone_kwargs)

        fused_dim = self.convnext.num_features + self.swin.num_features
        head_layers = [
            nn.Linear(fused_dim, 768),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(768, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        # Run both backbones on the same input and fuse along the feature axis.
        pooled = [backbone(x) for backbone in (self.convnext, self.swin)]
        return self.head(torch.cat(pooled, dim=1))

model = ConvNeXtSwin(num_classes).to(DEVICE)


def set_backbone_grad(enabled: bool):
    """Set ``requires_grad`` on every parameter of both backbones of the global ``model``."""
    for backbone in (model.convnext, model.swin):
        for param in backbone.parameters():
            param.requires_grad = enabled


# Start with both backbones frozen; the head parameters are not touched here.
set_backbone_grad(False)

# =============================
# OPTIMIZER + SCHEDULER
# =============================
# Only parameters that currently require gradients are handed to the optimizer.
optimizer = optim.AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=LR,
    weight_decay=WD,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=LR * 0.05,
)
criterion = FocalLoss(alpha=1.0, gamma=2.0).to(DEVICE)

# =============================
# TRAINING LOOP
# =============================
# Best-checkpoint tracking and early-stopping state.
best_val_f1, best_state, patience_counter = -1.0, None, 0

def evaluate(model, loader, device=DEVICE):
    """Run ``model`` over ``loader`` in eval mode and compute validation metrics.

    Uses the module-level ``criterion`` for the loss.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy, macro_f1, y_true, y_pred)``,
        where ``mean_loss`` is averaged over ``len(loader.dataset)`` and the last
        two items are NumPy arrays of true and predicted class indices.
    """
    model.eval()
    seen_targets, seen_preds = [], []
    loss_sum = 0.0
    with torch.no_grad():
        for batch_imgs, batch_lbls in loader:
            batch_imgs = batch_imgs.to(device, non_blocking=True)
            batch_lbls = batch_lbls.to(device, non_blocking=True)
            logits = model(batch_imgs)
            # Weight each batch loss by its size so the final division by the
            # dataset length yields a per-sample average.
            loss_sum += criterion(logits, batch_lbls).item() * batch_imgs.size(0)
            seen_targets += batch_lbls.cpu().numpy().tolist()
            seen_preds += logits.argmax(dim=1).cpu().numpy().tolist()

    y_true = np.array(seen_targets)
    y_pred = np.array(seen_preds)
    accuracy = accuracy_score(y_true, y_pred)
    balanced = balanced_accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    mean_loss = loss_sum / len(loader.dataset)
    return mean_loss, accuracy, balanced, macro_f1, y_true, y_pred

# Mixed precision is only active when CUDA is available; otherwise scaler and autocast are no-ops.
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    if epoch == FREEZE_WARMUP_EPOCHS + 1:
        # Warm-up is over: unfreeze both backbones and rebuild the optimizer and
        # scheduler so the newly trainable parameters are optimised. The cosine
        # schedule restarts at LR and spans the remaining epochs.
        set_backbone_grad(True)
        optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS - epoch + 1, eta_min=LR * 0.05)
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [Train]", leave=False)
    for imgs, lbls in pbar:
        imgs = imgs.to(DEVICE, non_blocking=True)
        lbls = lbls.to(DEVICE, non_blocking=True)
        imgs_mixed, y1, y2, lam, mode = apply_mixup_cutmix(imgs, lbls)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(imgs_mixed)
            loss = mix_criterion(criterion, outputs, y1, y2, lam, mode)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)

        # For mixed batches, credit a prediction against both label sets,
        # weighted by the mixing coefficient.
        preds = outputs.argmax(dim=1)
        if mode in ('mixup', 'cutmix'):
            correct += (preds == y1).sum().item() * lam + (preds == y2).sum().item() * (1 - lam)
        else:
            correct += (preds == lbls).sum().item()

        total += lbls.size(0)
        pbar.set_postfix(loss=f"{running_loss / total:.4f}", acc=f"{correct / total:.4f}")

    scheduler.step()
    val_loss, val_acc, val_bacc, val_mf1, y_true, y_pred = evaluate(model, val_loader)
    # The training loader uses drop_last=True, so fewer than len(dataset) samples are
    # processed per epoch. Normalise by the samples actually seen so the epoch summary
    # matches the running progress-bar values; max() guards against an empty loader.
    seen = max(total, 1)
    train_loss = running_loss / seen
    train_acc = correct / seen
    print(f"Epoch {epoch:03d}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} BalAcc: {val_bacc:.4f} MacroF1: {val_mf1:.4f} | "
          f"LR: {scheduler.get_last_lr()[0]:.6f}")
    if val_mf1 > best_val_f1:
        # New best validation macro-F1: snapshot the weights on the CPU and reset patience.
        best_val_f1 = val_mf1
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch}. Best Macro F1: {best_val_f1:.4f}")
            break

# =============================
# LOAD BEST & FINAL EVAL
# =============================
# Restore the best checkpoint (by validation macro-F1) when one was recorded.
if best_state is not None:
    model.load_state_dict(best_state)

val_loss, val_acc, val_bacc, val_mf1, y_true, y_pred = evaluate(model, val_loader)

print("\n=== FINAL VALIDATION ===")
print(f"Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Balanced Acc: {val_bacc:.4f} | Macro F1: {val_mf1:.4f}")

print("\nPer-class F1:")
per_class_f1 = f1_score(y_true, y_pred, average=None)
print(dict(zip(class_names, (f"{score:.3f}" for score in per_class_f1))))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=list(map(str, class_names))))

# =============================
# CONFUSION MATRIX
# =============================
# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

# =============================
# SAVE MODEL
# =============================
# Persist the weights currently held by `model` (the restored best checkpoint
# when one was recorded during training).
save_path = "/home/dineshkumarv.22it/convnext_swin_focal_mix_aug_best.pth"
torch.save(model.state_dict(), save_path)
print("Model saved to:", save_path)

# Collect unreferenced Python objects, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

[1;31merror[0m: [1mexternally-managed-environment[0m

[31m×[0m This environment is externally managed
[31m╰─>[0m To install Python packages system-wide, try apt install
[31m   [0m python3-xyz, where xyz is the package you are trying to
[31m   [0m install.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian-packaged Python package,
[31m   [0m create a virtual environment using python3 -m venv path/to/venv.
[31m   [0m Then use path/to/venv/bin/python and path/to/venv/bin/pip. Make
[31m   [0m sure you have python3-full installed.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian packaged Python application,
[31m   [0m it may be easiest to use pipx install xyz, which will manage a
[31m   [0m virtual environment for you. Make sure you have pipx installed.
[31m   [0m 
[31m   [0m See /usr/share/doc/python3.12/README.venv for more information.

[1;35mnote[0m: If you believe this is a mistake, please contact your Python installation or OS dist

  A.CoarseDropout(
                                                                                              

Epoch 001/60 | Train Loss: 1.0505 Acc: 0.4008 | Val Loss: 0.6964 Acc: 0.5483 BalAcc: 0.6092 MacroF1: 0.4248 | LR: 0.000200


                                                                                              

Epoch 002/60 | Train Loss: 0.9069 Acc: 0.4850 | Val Loss: 0.5643 Acc: 0.6293 BalAcc: 0.6216 MacroF1: 0.4972 | LR: 0.000199


                                                                                              

Epoch 003/60 | Train Loss: 0.8639 Acc: 0.5097 | Val Loss: 0.6015 Acc: 0.5833 BalAcc: 0.6354 MacroF1: 0.4654 | LR: 0.000199


                                                                                              

Epoch 004/60 | Train Loss: 0.8665 Acc: 0.5346 | Val Loss: 0.4281 Acc: 0.6809 BalAcc: 0.7267 MacroF1: 0.6351 | LR: 0.000200


                                                                                              

Epoch 005/60 | Train Loss: 0.6872 Acc: 0.6226 | Val Loss: 0.3158 Acc: 0.7228 BalAcc: 0.7319 MacroF1: 0.6554 | LR: 0.000199


                                                                                              

Epoch 006/60 | Train Loss: 0.5945 Acc: 0.6571 | Val Loss: 0.3885 Acc: 0.7358 BalAcc: 0.7323 MacroF1: 0.6776 | LR: 0.000199


                                                                                              

Epoch 007/60 | Train Loss: 0.5723 Acc: 0.6707 | Val Loss: 0.3200 Acc: 0.7578 BalAcc: 0.7515 MacroF1: 0.6939 | LR: 0.000198


                                                                                              

Epoch 008/60 | Train Loss: 0.5253 Acc: 0.6930 | Val Loss: 0.3522 Acc: 0.6816 BalAcc: 0.7261 MacroF1: 0.6732 | LR: 0.000196


                                                                                              

Epoch 009/60 | Train Loss: 0.5143 Acc: 0.7006 | Val Loss: 0.3142 Acc: 0.6685 BalAcc: 0.7397 MacroF1: 0.6844 | LR: 0.000195


                                                                                               

Epoch 010/60 | Train Loss: 0.4934 Acc: 0.7077 | Val Loss: 0.2610 Acc: 0.8131 BalAcc: 0.7604 MacroF1: 0.7231 | LR: 0.000193


                                                                                               

Epoch 011/60 | Train Loss: 0.4514 Acc: 0.7333 | Val Loss: 0.2417 Acc: 0.8038 BalAcc: 0.7370 MacroF1: 0.7374 | LR: 0.000191


                                                                                               

Epoch 012/60 | Train Loss: 0.4659 Acc: 0.7271 | Val Loss: 0.3657 Acc: 0.7602 BalAcc: 0.7518 MacroF1: 0.7019 | LR: 0.000189


                                                                                               

Epoch 013/60 | Train Loss: 0.4453 Acc: 0.7345 | Val Loss: 0.3639 Acc: 0.7066 BalAcc: 0.7537 MacroF1: 0.6446 | LR: 0.000186


                                                                                               

Epoch 014/60 | Train Loss: 0.4650 Acc: 0.7266 | Val Loss: 0.2722 Acc: 0.7400 BalAcc: 0.6944 MacroF1: 0.6988 | LR: 0.000183


                                                                                               

Epoch 015/60 | Train Loss: 0.4248 Acc: 0.7437 | Val Loss: 0.2555 Acc: 0.7870 BalAcc: 0.7320 MacroF1: 0.7237 | LR: 0.000180


                                                                                               

Epoch 016/60 | Train Loss: 0.4054 Acc: 0.7553 | Val Loss: 0.2482 Acc: 0.7942 BalAcc: 0.7568 MacroF1: 0.7322 | LR: 0.000177


                                                                                               

Epoch 017/60 | Train Loss: 0.3914 Acc: 0.7651 | Val Loss: 0.2235 Acc: 0.8344 BalAcc: 0.7667 MacroF1: 0.7605 | LR: 0.000173


                                                                                               

Epoch 018/60 | Train Loss: 0.3842 Acc: 0.7602 | Val Loss: 0.2452 Acc: 0.8224 BalAcc: 0.7763 MacroF1: 0.7425 | LR: 0.000169


                                                                                               

Epoch 019/60 | Train Loss: 0.3855 Acc: 0.7606 | Val Loss: 0.2266 Acc: 0.8482 BalAcc: 0.7661 MacroF1: 0.7514 | LR: 0.000165


                                                                                               

Epoch 020/60 | Train Loss: 0.3803 Acc: 0.7686 | Val Loss: 0.1970 Acc: 0.8712 BalAcc: 0.7698 MacroF1: 0.7970 | LR: 0.000161


                                                                                               

Epoch 021/60 | Train Loss: 0.3651 Acc: 0.7736 | Val Loss: 0.2775 Acc: 0.7314 BalAcc: 0.7718 MacroF1: 0.7349 | LR: 0.000157


                                                                                               

Epoch 022/60 | Train Loss: 0.3564 Acc: 0.7803 | Val Loss: 0.2341 Acc: 0.8382 BalAcc: 0.7743 MacroF1: 0.7689 | LR: 0.000153


                                                                                               

Epoch 023/60 | Train Loss: 0.3637 Acc: 0.7702 | Val Loss: 0.2094 Acc: 0.8667 BalAcc: 0.7818 MacroF1: 0.7860 | LR: 0.000148


                                                                                               

Epoch 024/60 | Train Loss: 0.3541 Acc: 0.7732 | Val Loss: 0.1957 Acc: 0.8695 BalAcc: 0.7353 MacroF1: 0.7748 | LR: 0.000143


                                                                                               

Epoch 025/60 | Train Loss: 0.3364 Acc: 0.7894 | Val Loss: 0.2008 Acc: 0.8379 BalAcc: 0.7691 MacroF1: 0.7718 | LR: 0.000138


                                                                                               

Epoch 026/60 | Train Loss: 0.3434 Acc: 0.7800 | Val Loss: 0.2099 Acc: 0.8633 BalAcc: 0.7493 MacroF1: 0.7821 | LR: 0.000133


                                                                                               

Epoch 027/60 | Train Loss: 0.3284 Acc: 0.7792 | Val Loss: 0.2041 Acc: 0.8664 BalAcc: 0.7697 MacroF1: 0.7894 | LR: 0.000128


                                                                                               

Epoch 028/60 | Train Loss: 0.3225 Acc: 0.7875 | Val Loss: 0.2588 Acc: 0.8482 BalAcc: 0.7630 MacroF1: 0.7684 | LR: 0.000123


                                                                                               

Epoch 029/60 | Train Loss: 0.2982 Acc: 0.8026 | Val Loss: 0.2104 Acc: 0.8598 BalAcc: 0.7977 MacroF1: 0.7968 | LR: 0.000118


Epoch 30/60 [Train]:   5%|▌         | 37/679 [00:05<01:35,  6.73it/s, acc=0.7860, loss=0.3461]