In [None]:
import os
import sys
import argparse
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
# Dataset roots. The fallbacks are the original author's local Windows paths;
# on other machines set DR_TRAIN_DIR / DR_VAL_DIR (or pass --train-dir /
# --val-dir to main) instead of editing this cell.
DEFAULT_TRAIN_DIR = os.environ.get(
    "DR_TRAIN_DIR", r"C:\Users\sudeepta\Desktop\DR_project\augmented_resized_V2\train")
DEFAULT_VAL_DIR = os.environ.get(
    "DR_VAL_DIR", r"C:\Users\sudeepta\Desktop\DR_project\augmented_resized_V2\val")

IMAGE_SIZE = 224       # square crop size fed to all three backbones
BATCH_SIZE = 16        # reduce if GPU memory limited
PROJ_DIM = 256         # width of each gated projection; the fusion head sees 3 * PROJ_DIM
NUM_CLASSES = 5
HEAD_EPOCHS = 5        # stage 1 (frozen backbones); 20 is a more thorough setting
PARTIAL_EPOCHS = 0     # stage 2 (partial unfreeze); 0 skips the stage
FINE_EPOCHS = 5        # stage 3 (full fine-tune); use 30 for thorough training
LR = 1e-3              # base LR; stage 2 trains at LR * 0.1, stage 3 at LR * 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_DIR = "outputs_hybrid"          # directory for loss/accuracy curves and the confusion matrix
CHECKPOINT = "hybrid_r_d_m.pth"     # best-validation-accuracy weights are saved here

In [None]:
class ResNet50_Feature(nn.Module):
    """ResNet-50 trunk that maps an image batch to pooled feature vectors.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
        out_dim: expected feature width (2048 for ResNet-50). Kept for backward
            compatibility; the real width is read from the backbone, and a
            mismatching value raises here instead of surfacing later as a
            Linear shape error in the projection layer.

    Raises:
        ValueError: if ``out_dim`` differs from the backbone's feature width.
    """

    def __init__(self, pretrained=True, out_dim=2048):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; kept here for compatibility with older installs.
        net = models.resnet50(pretrained=pretrained)
        # The final fc layer consumes exactly the pooled feature vector, so its
        # input width is the true width of what forward() returns.
        actual_dim = net.fc.in_features
        if out_dim != actual_dim:
            raise ValueError(
                f"ResNet50_Feature produces {actual_dim}-d features, got out_dim={out_dim}")
        # Drop only the fc classifier; the remaining children end with
        # AdaptiveAvgPool2d, so the encoder outputs (B, 2048, 1, 1).
        self.encoder = nn.Sequential(*list(net.children())[:-1])
        self.out_dim = actual_dim

    def forward(self, x):
        """Return pooled features of shape (B, out_dim)."""
        x = self.encoder(x).flatten(1)
        return x

    def feature_dim(self):
        """Width of the vectors returned by forward()."""
        return self.out_dim

class DenseNet121_Feature(nn.Module):
    """DenseNet-121 trunk that maps an image batch to pooled feature vectors.

    Args:
        pretrained: forwarded to ``torchvision.models.densenet121``.
        out_dim: expected feature width (1024 for DenseNet-121). Kept for
            backward compatibility; the real width is read from the backbone,
            and a mismatching value raises here instead of surfacing later as a
            Linear shape error in the projection layer.

    Raises:
        ValueError: if ``out_dim`` differs from the backbone's feature width.
    """

    def __init__(self, pretrained=True, out_dim=1024):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; kept here for compatibility with older installs.
        net = models.densenet121(pretrained=pretrained)
        # The classifier consumes the pooled feature vector, so its input width
        # is the true width of what forward() returns.
        actual_dim = net.classifier.in_features
        if out_dim != actual_dim:
            raise ValueError(
                f"DenseNet121_Feature produces {actual_dim}-d features, got out_dim={out_dim}")
        # net.features -> feature map (B, 1024, H, W)
        self.features = net.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out_dim = actual_dim

    def forward(self, x):
        """Return pooled features of shape (B, out_dim)."""
        x = self.features(x)
        # torchvision's DenseNet applies a ReLU after `features` before pooling;
        # reproduce it so the pooled vector matches what the classifier saw.
        x = self.pool(F.relu(x)).flatten(1)
        return x

    def feature_dim(self):
        """Width of the vectors returned by forward()."""
        return self.out_dim

class MobileNetV2_Feature(nn.Module):
    """MobileNetV2 trunk that maps an image batch to pooled feature vectors.

    Args:
        pretrained: forwarded to ``torchvision.models.mobilenet_v2``.
        out_dim: expected feature width (1280 for MobileNetV2). Kept for
            backward compatibility; the real width is read from the backbone,
            and a mismatching value raises here instead of surfacing later as a
            Linear shape error in the projection layer.

    Raises:
        ValueError: if ``out_dim`` differs from the backbone's feature width.
    """

    def __init__(self, pretrained=True, out_dim=1280):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; kept here for compatibility with older installs.
        net = models.mobilenet_v2(pretrained=pretrained)
        # The last Linear of the classifier consumes the pooled feature vector,
        # so its input width is the true width of what forward() returns.
        actual_dim = net.classifier[-1].in_features
        if out_dim != actual_dim:
            raise ValueError(
                f"MobileNetV2_Feature produces {actual_dim}-d features, got out_dim={out_dim}")
        # net.features -> (B, 1280, H, W)
        self.features = net.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out_dim = actual_dim

    def forward(self, x):
        """Return pooled features of shape (B, out_dim)."""
        x = self.features(x)
        x = self.pool(x).flatten(1)
        return x

    def feature_dim(self):
        """Width of the vectors returned by forward()."""
        return self.out_dim

In [None]:
class GatedProj(nn.Module):
    """Project a feature vector to ``proj_dim`` dims and scale it by a learned gate.

    The projection is Linear -> BatchNorm1d -> ReLU. A small bottleneck MLP ending
    in a sigmoid turns the projected vector into one weight in (0, 1) per sample,
    which multiplies the whole projected vector.

    Attribute names (fc, bn, act, gate) are state_dict keys; keep them stable.
    Because of BatchNorm1d, training mode needs batches of more than one sample.
    """

    def __init__(self, in_dim, proj_dim=PROJ_DIM):
        super().__init__()
        # Bottleneck width of the gate MLP: proj_dim / 8, but never below 8.
        gate_hidden = max(8, proj_dim // 8)
        self.fc = nn.Linear(in_dim, proj_dim, bias=True)
        self.bn = nn.BatchNorm1d(proj_dim)
        self.act = nn.ReLU(inplace=True)
        self.gate = nn.Sequential(
            nn.Linear(proj_dim, gate_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(gate_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the gated projection of ``x`` (assumed (B, in_dim)) as (B, proj_dim)."""
        projected = self.act(self.bn(self.fc(x)))
        weight = self.gate(projected)      # (B, 1), broadcast across the feature axis
        return projected * weight

class HybridFusionNet(nn.Module):
    """Late-fusion classifier over ResNet-50, DenseNet-121 and MobileNetV2 features.

    Every backbone sees the same image batch. Each backbone's pooled features go
    through its own GatedProj to ``proj_dim`` dims. The three projections are
    concatenated (3 * proj_dim) and an MLP head produces ``num_classes`` logits.

    Submodule attribute names (resnet, densenet, mobilenet, proj_r, proj_d,
    proj_m, head) become state_dict key prefixes, so saved checkpoints depend
    on them.
    """

    def __init__(self, proj_dim=PROJ_DIM, num_classes=NUM_CLASSES, pretrained=True):
        super().__init__()
        self.resnet = ResNet50_Feature(pretrained=pretrained)
        self.densenet = DenseNet121_Feature(pretrained=pretrained)
        self.mobilenet = MobileNetV2_Feature(pretrained=pretrained)

        # One gated projection per backbone, created and registered in backbone order.
        self.proj_r, self.proj_d, self.proj_m = (
            GatedProj(backbone.feature_dim(), proj_dim)
            for backbone in (self.resnet, self.densenet, self.mobilenet)
        )

        head_in = 3 * proj_dim
        head_hidden = 512
        self.head = nn.Sequential(
            nn.Linear(head_in, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(head_hidden, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        backbones = (self.resnet, self.densenet, self.mobilenet)
        projections = (self.proj_r, self.proj_d, self.proj_m)
        # All backbones run before any projection is applied.
        features = [backbone(x) for backbone in backbones]
        fused = torch.cat([proj(feat) for proj, feat in zip(projections, features)], dim=1)
        return self.head(fused)

In [None]:
def get_loaders(train_dir: str, val_dir: str,
                image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, num_workers=4) -> Tuple[DataLoader, DataLoader]:
    """Build ImageFolder train/val loaders with ImageNet normalisation.

    Args:
        train_dir: ImageFolder root for training (one sub-directory per class).
        val_dir: ImageFolder root for validation; must have the same class folders.
        image_size: side length of the square crops fed to the network.
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, val_loader). The train loader shuffles and augments
        (random resized crop, horizontal flip, colour jitter). The val loader
        resizes and centre-crops deterministically.

    Raises:
        ValueError: if the two directories define different class folders. That
            would silently give the same label index different meanings.
    """
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.02),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(int(image_size * 1.14)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tf)
    if train_ds.classes != val_ds.classes:
        raise ValueError(
            f"Class folders differ between {train_dir!r} {train_ds.classes} and "
            f"{val_dir!r} {val_ds.classes}; label indices would not line up.")

    # Pinned host memory only helps when batches are copied to a CUDA device
    # (and PyTorch warns about it otherwise).
    pin = torch.cuda.is_available()
    # BatchNorm1d (used in the projection heads) raises in train mode on a batch
    # of one sample, so drop the final training batch only when it would be a
    # singleton.
    drop_singleton = len(train_ds) > batch_size and len(train_ds) % batch_size == 1
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin,
                              drop_last=drop_singleton)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader




In [None]:

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass of ``model`` over ``loader``.

    Puts the model in train mode and takes one optimizer step per batch with
    cross-entropy loss.

    Returns:
        (mean_loss, accuracy) over every sample seen. Each batch's mean loss is
        weighted by its batch size. An empty loader raises ZeroDivisionError.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()
    loss_sum, n_seen, n_right = 0.0, 0, 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        scores = model(batch_x)
        batch_loss = criterion(scores, batch_y)
        batch_loss.backward()
        optimizer.step()

        n = batch_x.size(0)
        loss_sum += batch_loss.item() * n
        n_seen += n
        n_right += (scores.argmax(dim=1) == batch_y).sum().item()
    return loss_sum / n_seen, n_right / n_seen

def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without updating it.

    Runs in eval mode under ``torch.no_grad()``.

    Returns:
        (mean_loss, accuracy) with cross-entropy averaged per sample (each
        batch's mean loss is weighted by its batch size). An empty loader raises
        ZeroDivisionError.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = model(batch_x)
            n = batch_x.size(0)
            loss_sum += criterion(scores, batch_y).item() * n
            n_seen += n
            n_right += (scores.argmax(dim=1) == batch_y).sum().item()
    return loss_sum / n_seen, n_right / n_seen

In [None]:
def plot_metrics(train_losses, val_losses, train_accs, val_accs, out_dir):
    """Save per-epoch train/val loss and accuracy curves as PNGs.

    Writes ``loss_curve.png`` and ``acc_curve.png`` into ``out_dir``, creating
    the directory if needed. Epochs are numbered from 1 and the x-axis length
    follows ``train_losses``.
    """
    os.makedirs(out_dir, exist_ok=True)
    epoch_axis = range(1, len(train_losses) + 1)
    panels = (
        ('Loss', train_losses, val_losses, 'loss_curve.png'),
        ('Accuracy', train_accs, val_accs, 'acc_curve.png'),
    )
    for metric, train_series, val_series, filename in panels:
        fig, ax = plt.subplots()
        ax.plot(epoch_axis, train_series, marker='o')
        ax.plot(epoch_axis, val_series, marker='o')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.set_title(metric)
        ax.legend(['train', 'val'])
        ax.grid(True)
        fig.savefig(os.path.join(out_dir, filename))
        plt.close(fig)

def confusion_and_report(model, loader, device, class_names: List[str], out_dir):
    """Predict on ``loader``, print a confusion matrix and classification report, and save a heat-map.

    Args:
        model: classifier returning logits whose argmax indexes ``class_names``.
        loader: yields (images, integer labels) batches.
        device: device the images are moved to before the forward pass.
        class_names: display names ordered by label index.
        out_dir: directory for ``confusion_matrix.png`` (created if missing).

    Side effects:
        Prints both tables to stdout and writes ``confusion_matrix.png``.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            logits = model(imgs.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    # Pin the label set to every class index. The matrix is then always K x K,
    # matching the tick labels below. classification_report also stops raising
    # "Number of classes ... does not match size of target_names" when some
    # class never occurs in either y_true or y_pred.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion matrix:\n", cm)
    print("Classification report:\n",
          classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(cm, interpolation='nearest')
    fig.colorbar(image, ax=ax)
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion matrix')
    fig.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'confusion_matrix.png'))
    plt.close(fig)


In [None]:
def _set_trainable(module: nn.Module, flag: bool) -> None:
    """Enable or disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def _make_optimizer(model: nn.Module, lr: float, weight_decay: float) -> torch.optim.Optimizer:
    """Build AdamW over only the parameters of ``model`` that currently require gradients."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)


def _run_stage(tag, num_epochs, model, optimizer, train_loader, val_loader, device,
               history, best_val_acc, checkpoint_path):
    """Train for ``num_epochs`` epochs, append metrics to ``history`` and checkpoint improvements.

    ``history`` maps 'train_loss' / 'val_loss' / 'train_acc' / 'val_acc' to
    lists that are extended in place. Returns the updated best validation
    accuracy.
    """
    for stage_epoch in range(1, num_epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, device)
        history["train_loss"].append(tr_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(val_acc)
        global_epoch = len(history["train_loss"])
        print(f"[{tag}] Epoch {stage_epoch}: train_loss={tr_loss:.4f}, train_acc={tr_acc:.4f} "
              f"| val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # 'epoch' counts across all stages, so checkpoints from different
            # stages are distinguishable; 'stage'/'stage_epoch' locate it within a stage.
            torch.save({'epoch': global_epoch, 'stage': tag, 'stage_epoch': stage_epoch,
                        'model_state': model.state_dict(), 'val_acc': val_acc},
                       checkpoint_path)
    return best_val_acc


def main():
    """Train HybridFusionNet in up to three stages, then save curves and a confusion matrix.

    Stages:
      1. Backbones frozen; projections and head train at ``--lr``.
      2. Optional (``--partial-epochs`` > 0). The last ResNet children plus all
         of DenseNet/MobileNet are unfrozen; trains at ``--lr * 0.1``.
      3. Everything is unfrozen; trains at ``--lr * 0.01``.

    The best-validation-accuracy weights are saved to ``--checkpoint``. They are
    reloaded before the final report, so the confusion matrix matches the
    printed best accuracy.

    Raises:
        ValueError: if ``--num-classes`` disagrees with the class folders found
            in ``--train-dir``.
    """
    # CLI safe in Jupyter: parse_known_args ignores the kernel's own argv entries.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', type=str, default=DEFAULT_TRAIN_DIR)
    parser.add_argument('--val-dir', type=str, default=DEFAULT_VAL_DIR)
    parser.add_argument('--batch', type=int, default=BATCH_SIZE)
    parser.add_argument('--image-size', type=int, default=IMAGE_SIZE)
    parser.add_argument('--proj-dim', type=int, default=PROJ_DIM)
    parser.add_argument('--num-classes', type=int, default=NUM_CLASSES)
    parser.add_argument('--head-epochs', type=int, default=HEAD_EPOCHS)
    parser.add_argument('--partial-epochs', type=int, default=PARTIAL_EPOCHS)
    parser.add_argument('--fine-epochs', type=int, default=FINE_EPOCHS)
    parser.add_argument('--lr', type=float, default=LR)
    parser.add_argument('--out-dir', type=str, default=OUT_DIR)
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT)
    parser.add_argument('--seed', type=int, default=None,
                        help='seed torch/numpy RNGs for a reproducible run (default: unseeded)')
    args, _ = parser.parse_known_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # loaders
    train_loader, val_loader = get_loaders(args.train_dir, args.val_dir,
                                           image_size=args.image_size, batch_size=args.batch)
    # Take names from the dataset itself so they follow exactly the label
    # indices ImageFolder assigned.
    class_names = list(train_loader.dataset.classes)
    print("Classes:", class_names)
    if len(class_names) != args.num_classes:
        raise ValueError(
            f"--num-classes={args.num_classes} but {args.train_dir!r} contains "
            f"{len(class_names)} class folders: {class_names}")

    # model
    model = HybridFusionNet(proj_dim=args.proj_dim, num_classes=args.num_classes, pretrained=True)
    model.to(device)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0

    # Stage 1: freeze backbones, train projection + head.
    # NOTE(review): train_one_epoch calls model.train(), so BatchNorm running
    # statistics inside the "frozen" backbones still update during this stage.
    for backbone in (model.resnet, model.densenet, model.mobilenet):
        _set_trainable(backbone, False)
    for trainable in (model.proj_r, model.proj_d, model.proj_m, model.head):
        _set_trainable(trainable, True)
    print(f"Stage1 (head-only) epochs: {args.head_epochs}")
    optimizer = _make_optimizer(model, args.lr, weight_decay=1e-4)
    best_val_acc = _run_stage("Head", args.head_epochs, model, optimizer, train_loader,
                              val_loader, device, history, best_val_acc, args.checkpoint)

    # Stage 2: optional partial unfreeze
    if args.partial_epochs > 0:
        print(f"Stage2 (partial) epochs: {args.partial_epochs}")
        # ResNet: the last three encoder children (layer3, layer4 and the pool).
        _set_trainable(model.resnet.encoder[-3:], True)
        # DenseNet / MobileNet: simply unfreeze everything (tailor if needed).
        _set_trainable(model.densenet, True)
        _set_trainable(model.mobilenet, True)
        optimizer = _make_optimizer(model, args.lr * 0.1, weight_decay=1e-4)
        best_val_acc = _run_stage("Partial", args.partial_epochs, model, optimizer, train_loader,
                                  val_loader, device, history, best_val_acc, args.checkpoint)

    # Stage 3: full fine-tune
    print(f"Stage3 (fine-tune) epochs: {args.fine_epochs}")
    _set_trainable(model, True)
    optimizer = _make_optimizer(model, args.lr * 0.01, weight_decay=1e-5)
    best_val_acc = _run_stage("Fine", args.fine_epochs, model, optimizer, train_loader,
                              val_loader, device, history, best_val_acc, args.checkpoint)

    print("Training finished. Best val acc:", best_val_acc)

    # Report on the best checkpoint rather than on the last epoch. A checkpoint
    # was written by this run only if some epoch beat the initial 0.0, which
    # also guards against loading a stale file from an earlier run.
    if best_val_acc > 0.0:
        best_state = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(best_state['model_state'])

    # plots & confusion
    plot_metrics(history["train_loss"], history["val_loss"],
                 history["train_acc"], history["val_acc"], args.out_dir)
    confusion_and_report(model, val_loader, device, class_names, args.out_dir)
    print("Plots and confusion matrix saved to", args.out_dir)


if __name__ == "__main__":
    main()