In [None]:
import os
import argparse
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
# Dataset locations. Set the DR_TRAIN_DIR / DR_VAL_DIR environment variables
# (or pass --train-dir / --val-dir) instead of editing machine-specific paths.
DEFAULT_TRAIN_DIR = os.environ.get(
    "DR_TRAIN_DIR", r"C:\Users\sudeepta\Desktop\DR_project\augmented_resized_V2\train"
)
DEFAULT_VAL_DIR = os.environ.get(
    "DR_VAL_DIR", r"C:\Users\sudeepta\Desktop\DR_project\augmented_resized_V2\val"
)

IMAGE_SIZE = 224       # side length images are cropped to by the data transforms
BATCH_SIZE = 16        # reduce to 8 or 4 if OOM
PROJ_DIM = 256         # per-branch projection width; fused width is 3 * PROJ_DIM
NUM_CLASSES = 5        # model default only; main() derives the count from the data
HEAD_EPOCHS = 20       # stage 1: backbones frozen, train projections + head
PARTIAL_EPOCHS = 0     # stage 2: partial VGG unfreeze (0 = skip the stage)
FINE_EPOCHS = 30       # stage 3: full fine-tune
LR = 1e-3              # base LR; stages 2 and 3 use LR * 0.1 and LR * 0.01
OUT_DIR = "outputs_vgg_densenet_condense"
CHECKPOINT = "hybrid_vgg_dn_condense.pth"

In [None]:
class VGG19_Feature(nn.Module):
    """VGG19 convolutional trunk returning a globally average-pooled feature vector.

    Only ``vgg19().features`` is kept; torchvision's avgpool/classifier are
    discarded. ``forward`` returns a (B, feature_dim()) tensor.

    Args:
        pretrained: if true, load torchvision's ImageNet weights (IMAGENET1K_V1).
        out_dim: not used for any computation; kept so existing call sites and
            the ``out_dim`` attribute keep working.
    """

    def __init__(self, pretrained=True, out_dim=512):
        super().__init__()
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only offers the legacy boolean flag.
            net = models.vgg19(pretrained=pretrained)
        else:
            # `pretrained=` is deprecated since torchvision 0.13; IMAGENET1K_V1 is
            # the checkpoint it used to select.
            net = models.vgg19(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        self.features = net.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out_dim = out_dim      # stored only; nothing in this module reads it
        self.feat_channels = 512    # channel count of VGG19's final conv block

    def forward(self, x):
        """Map an image batch to pooled features of shape (B, 512)."""
        x = self.features(x)              # (B, 512, H', W')
        return self.pool(x).flatten(1)    # (B, 512)

    def feature_dim(self):
        """Width of the vector produced by ``forward``."""
        return self.feat_channels

class DenseNet121_Feature(nn.Module):
    """DenseNet121 feature stack + ReLU + global average pool.

    Only ``densenet121().features`` is kept. ``forward`` returns a
    (B, feature_dim()) tensor.

    Args:
        pretrained: if true, load torchvision's ImageNet weights (IMAGENET1K_V1).
        out_dim: not used; accepted for call-site compatibility.
    """

    def __init__(self, pretrained=True, out_dim=1024):
        super().__init__()
        weights_enum = getattr(models, "DenseNet121_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only offers the legacy boolean flag.
            net = models.densenet121(pretrained=pretrained)
        else:
            # `pretrained=` is deprecated since torchvision 0.13; IMAGENET1K_V1 is
            # the checkpoint it used to select.
            net = models.densenet121(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        self.features = net.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.feat_channels = 1024   # channel count after DenseNet121's last block

    def forward(self, x):
        """Map an image batch to pooled features of shape (B, 1024)."""
        x = self.features(x)              # (B, 1024, H', W')
        x = F.relu(x, inplace=True)       # features end without an activation
        return self.pool(x).flatten(1)    # (B, 1024)

    def feature_dim(self):
        """Width of the vector produced by ``forward``."""
        return self.feat_channels

# Simple CondenseNet-like module (lightweight)
class CondenseBlock(nn.Module):
    """Grouped 1x1 conv -> BN -> ReLU, then depthwise kxk conv -> BN -> ReLU.

    ``in_ch`` and ``out_ch`` must both be divisible by ``groups`` (required by
    the grouped pointwise ``nn.Conv2d``). ``stride`` is applied only by the
    depthwise conv; padding is ``kernel // 2``.
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, groups=4):
        super().__init__()
        # Attribute names (and creation order) define the state_dict layout.
        self.pw = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False, groups=groups)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.act1 = nn.ReLU(inplace=True)
        pad = kernel // 2
        self.dw = nn.Conv2d(
            out_ch,
            out_ch,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            groups=out_ch,  # one filter per channel -> depthwise
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act2 = nn.ReLU(inplace=True)

    def forward(self, x):
        mixed = self.act1(self.bn1(self.pw(x)))
        return self.act2(self.bn2(self.dw(mixed)))

class CondenseNetSimple(nn.Module):
    """Small randomly initialised CNN branch: strided stem conv + three CondenseBlocks.

    Args:
        in_channels: channels of the input image batch.
        feat_channels: output width of the last CondenseBlock, and therefore of
            the pooled vector returned by ``forward``. It must be a positive
            multiple of 8 because that block's pointwise conv uses ``groups=8``.
            The default (256) reproduces the original architecture.

    Raises:
        ValueError: if ``feat_channels`` is not a positive multiple of 8.
    """

    def __init__(self, in_channels=3, feat_channels=256):
        super().__init__()
        if feat_channels <= 0 or feat_channels % 8 != 0:
            raise ValueError(f"feat_channels must be a positive multiple of 8, got {feat_channels}")
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            CondenseBlock(32, 64, groups=4),
            CondenseBlock(64, 128, groups=4),
            CondenseBlock(128, feat_channels, groups=8),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Previously hard-coded to 256 regardless of the argument.
        self.feat_channels = feat_channels

    def forward(self, x):
        """Map an image batch to pooled features of shape (B, feat_channels)."""
        x = self.stem(x)
        return self.pool(x).flatten(1)

    def feature_dim(self):
        """Width of the vector produced by ``forward``."""
        return self.feat_channels

In [None]:
class GatedProj(nn.Module):
    """Linear -> BatchNorm1d -> ReLU projection, scaled by a learned per-sample gate.

    The gate is a small MLP (proj_dim -> max(8, proj_dim // 8) -> 1, sigmoid)
    evaluated on the projected vector; its (B, 1) output multiplies every
    projected feature of that sample. BatchNorm1d in training mode needs more
    than one sample per batch.

    Args:
        in_dim: width of the incoming feature vector.
        proj_dim: width of the projected output vector.
    """

    def __init__(self, in_dim, proj_dim=PROJ_DIM):
        super().__init__()
        gate_hidden = max(8, proj_dim // 8)
        self.fc = nn.Linear(in_dim, proj_dim)
        self.bn = nn.BatchNorm1d(proj_dim)
        self.act = nn.ReLU(inplace=True)
        # Sequential indices 0..3 become state_dict keys (gate.0.*, gate.2.*).
        self.gate = nn.Sequential(
            nn.Linear(proj_dim, gate_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(gate_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """(B, in_dim) -> (B, proj_dim) gated projection."""
        projected = self.act(self.bn(self.fc(x)))
        # The (B, 1) gate broadcasts across all proj_dim features.
        return projected * self.gate(projected)

class HybridVGG_DN_Condense(nn.Module):
    """Three-branch classifier fusing VGG19, DenseNet121 and CondenseNetSimple features.

    Each branch's pooled vector passes through its own GatedProj to ``proj_dim``.
    The three projections are concatenated (3 * proj_dim) and classified by an
    MLP head: Linear(->512) -> ReLU -> Dropout(0.4) -> Linear(->num_classes).

    Args:
        proj_dim: per-branch projection width.
        num_classes: number of output logits.
        pretrained: forwarded to the VGG19 and DenseNet121 branches only; the
            CondenseNetSimple branch is always randomly initialised.
    """

    def __init__(self, proj_dim=PROJ_DIM, num_classes=NUM_CLASSES, pretrained=True):
        super().__init__()
        # Creation order determines random-init RNG consumption; keep it fixed.
        self.vgg = VGG19_Feature(pretrained=pretrained)
        self.dn = DenseNet121_Feature(pretrained=pretrained)
        self.cond = CondenseNetSimple(in_channels=3)

        self.proj_v = GatedProj(self.vgg.feature_dim(), proj_dim)
        self.proj_d = GatedProj(self.dn.feature_dim(), proj_dim)
        self.proj_c = GatedProj(self.cond.feature_dim(), proj_dim)

        n_branches = 3
        self.head = nn.Sequential(
            nn.Linear(n_branches * proj_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return (B, num_classes) logits for the image batch ``x``."""
        branch_feats = (self.vgg(x), self.dn(x), self.cond(x))
        projectors = (self.proj_v, self.proj_d, self.proj_c)
        fused = torch.cat([proj(feat) for proj, feat in zip(projectors, branch_feats)], dim=1)
        return self.head(fused)

In [None]:
def get_loaders(train_dir: str, val_dir: str, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, num_workers=4) -> Tuple[DataLoader, DataLoader]:
    """Build train/val ``ImageFolder`` loaders with ImageNet normalisation.

    Train images get RandomResizedCrop + horizontal flip + colour jitter. Val
    images are resized to ``int(image_size * 1.14)`` and centre-cropped.

    Args:
        train_dir: root of the training split (one sub-folder per class).
        val_dir: root of the validation split (same class sub-folders).
        image_size: final square crop size.
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader)``; the train loader is shuffled, the val
        loader is not.

    Raises:
        ValueError: if the train and val folders define different class lists,
            which would make their integer labels disagree.
    """
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.02),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(int(image_size * 1.14)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tf)
    if train_ds.classes != val_ds.classes:
        raise ValueError(
            f"Class folders differ between train ({train_ds.classes}) and val ({val_ds.classes})"
        )

    # GatedProj uses BatchNorm1d, which raises in training mode on a one-sample
    # batch; drop the final train batch only when it would be such a singleton.
    drop_last = len(train_ds) % batch_size == 1
    # Pinned host memory only speeds up copies to a GPU.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              pin_memory=pin, drop_last=drop_last)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            pin_memory=pin)
    return train_loader, val_loader

In [None]:
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module whose output is class logits, indexed by class on dim 1.
        loader: iterable of ``(inputs, integer_labels)`` batches.
        optimizer: optimizer over the parameters to update.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen. Accuracy is computed
        from the logits produced before each optimizer step.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum, n_seen, n_correct = 0.0, 0, 0
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        # CrossEntropyLoss averages over the batch; re-weight by batch size so
        # the epoch mean is per-sample even when the last batch is smaller.
        loss_sum += loss.item() * batch_n
        n_seen += batch_n
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
    if n_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, device):
    """Compute mean loss and accuracy of ``model`` on ``loader`` without gradients.

    Args:
        model: module whose output is class logits, indexed by class on dim 1.
        loader: iterable of ``(inputs, integer_labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over every sample; the model is left in eval mode.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum, n_seen, n_correct = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            batch_n = imgs.size(0)
            # Re-weight the batch-mean loss so the result is a per-sample mean.
            loss_sum += loss_fn(logits, labels).item() * batch_n
            n_seen += batch_n
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


In [None]:
def plot_metrics(train_losses, val_losses, train_accs, val_accs, out_dir):
    """Save train/val loss and accuracy curves as PNG files in ``out_dir``.

    Writes ``loss_curve.png`` and ``acc_curve.png``, one marker per epoch with
    the x-axis starting at 1. ``out_dir`` is created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    epoch_axis = range(1, len(train_losses) + 1)
    panels = (
        ("Loss", train_losses, val_losses, "loss_curve.png"),
        ("Accuracy", train_accs, val_accs, "acc_curve.png"),
    )
    for metric, train_vals, val_vals, fname in panels:
        fig, ax = plt.subplots()
        ax.plot(epoch_axis, train_vals, marker="o")
        ax.plot(epoch_axis, val_vals, marker="o")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.set_title(metric)
        ax.legend(["train", "val"])
        ax.grid(True)
        fig.savefig(os.path.join(out_dir, fname))
        plt.close(fig)

def confusion_and_report(model, loader, device, class_names, out_dir):
    """Evaluate ``model`` on ``loader``, print metrics and save a confusion-matrix plot.

    Prints the confusion matrix and an sklearn classification report, and writes
    ``confusion_matrix.png`` to ``out_dir`` (created if missing).

    Args:
        model: classifier returning logits, indexed by class on dim 1.
        loader: iterable of ``(inputs, integer_labels)`` batches.
        device: device the inputs are moved to.
        class_names: display names; index ``i`` names integer label ``i``.
        out_dir: directory for the saved figure.

    Returns:
        The confusion matrix (rows = true label, columns = predicted label),
        always of shape ``(len(class_names), len(class_names))``.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            logits = model(imgs.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())

    # Pin the label set so a class absent from this split still gets a row and
    # column. Otherwise the matrix shrinks and no longer lines up with the tick
    # labels, and classification_report rejects target_names of the wrong length.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion matrix:\n", cm)
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names, zero_division=0
    )
    print("Classification report:\n", report)

    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(cm, interpolation='nearest')
    fig.colorbar(im, ax=ax)
    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion matrix')
    # Write each count into its cell, flipping text colour for contrast.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center',
                color='white' if cm[i, j] < thresh else 'black')
    fig.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'confusion_matrix.png'))
    plt.close(fig)
    return cm


In [None]:
def _set_trainable(modules, flag):
    """Set ``requires_grad = flag`` on every parameter of each module in ``modules``."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad = flag


def _make_adamw(model, lr, weight_decay):
    """Build AdamW over the parameters of ``model`` that currently require grad."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)


def _run_stage(model, train_loader, val_loader, optimizer, device, num_epochs,
               log_fmt, history, best_val_acc, checkpoint_path):
    """Train for ``num_epochs`` epochs, logging each and checkpointing on val-acc gains.

    ``history`` maps 'train_loss' / 'val_loss' / 'train_acc' / 'val_acc' to
    lists that are extended in place. ``log_fmt`` is formatted with ``epoch``
    (stage-relative), ``tr_loss``, ``tr_acc``, ``val_loss`` and ``val_acc``.
    Checkpoints store the global epoch index counted across all stages.
    Returns the updated best validation accuracy.
    """
    for stage_epoch in range(1, num_epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, device)
        history["train_loss"].append(tr_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(val_acc)
        print(log_fmt.format(epoch=stage_epoch, tr_loss=tr_loss, tr_acc=tr_acc,
                             val_loss=val_loss, val_acc=val_acc))
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            global_epoch = len(history["train_loss"])
            torch.save({'epoch': global_epoch, 'model_state': model.state_dict(), 'val_acc': val_acc},
                       checkpoint_path)
    return best_val_acc


def main():
    """Three-stage training of HybridVGG_DN_Condense, followed by plots and a val report.

    Stage 1 trains the projections and head with all backbones frozen. Stage 2
    (optional) also unfreezes the tail of VGG19's feature stack. Stage 3
    fine-tunes every parameter. The best-val-accuracy weights are saved to
    --checkpoint and reloaded before the confusion matrix / report.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', type=str, default=DEFAULT_TRAIN_DIR)
    parser.add_argument('--val-dir', type=str, default=DEFAULT_VAL_DIR)
    parser.add_argument('--batch', type=int, default=BATCH_SIZE)
    parser.add_argument('--proj-dim', type=int, default=PROJ_DIM)
    parser.add_argument('--head-epochs', type=int, default=HEAD_EPOCHS)
    parser.add_argument('--partial-epochs', type=int, default=PARTIAL_EPOCHS)
    parser.add_argument('--fine-epochs', type=int, default=FINE_EPOCHS)
    parser.add_argument('--lr', type=float, default=LR)
    parser.add_argument('--out-dir', type=str, default=OUT_DIR)
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT)
    parser.add_argument('--seed', type=int, default=42)
    args, _ = parser.parse_known_args()  # Jupyter-safe: ignores kernel-injected argv

    # Seed torch (all devices) and numpy so augmentation, shuffling and random
    # inits repeat across runs (cuDNN kernels may still be nondeterministic).
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    train_loader, val_loader = get_loaders(args.train_dir, args.val_dir, IMAGE_SIZE, args.batch)
    # Use the dataset's own class list so names line up with its label indices.
    class_names = list(train_loader.dataset.classes)
    print("Classes:", class_names)

    model = HybridVGG_DN_Condense(proj_dim=args.proj_dim, num_classes=len(class_names), pretrained=True)
    model.to(device)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0

    # Stage 1: freeze backbone feature extractors, train projections + head.
    # NOTE(review): train_one_epoch calls model.train(), so BatchNorm running
    # statistics inside the frozen backbones still update here, and the
    # CondenseNetSimple branch (never pretrained) stays at its random init until
    # stage 3 — confirm this is intended.
    _set_trainable((model.vgg, model.dn, model.cond), False)
    _set_trainable((model.proj_v, model.proj_d, model.proj_c, model.head), True)
    optimizer = _make_adamw(model, args.lr, 1e-4)
    print(f"Head-only training for {args.head_epochs} epochs...")
    best_val_acc = _run_stage(
        model, train_loader, val_loader, optimizer, device, args.head_epochs,
        "[Head] Epoch {epoch}: train_loss={tr_loss:.4f}, train_acc={tr_acc:.4f} | "
        "val_loss={val_loss:.4f}, val_acc={val_acc:.4f}",
        history, best_val_acc, args.checkpoint,
    )

    # Stage 2: optional partial unfreeze.
    if args.partial_epochs > 0:
        print(f"Partial unfreeze for {args.partial_epochs} epochs...")
        # Unfreeze the last six modules of vgg.features. For torchvision's VGG19
        # these should be conv5_3 and conv5_4 with their ReLUs plus the final
        # max-pool — re-check if the backbone changes. DenseNet stays frozen.
        _set_trainable(list(model.vgg.features.children())[-6:], True)
        optimizer = _make_adamw(model, args.lr * 0.1, 1e-4)
        best_val_acc = _run_stage(
            model, train_loader, val_loader, optimizer, device, args.partial_epochs,
            "[Partial] Epoch {epoch}: val_acc={val_acc:.4f}",
            history, best_val_acc, args.checkpoint,
        )

    # Stage 3: full fine-tune at a much lower learning rate.
    print(f"Full fine-tune for {args.fine_epochs} epochs...")
    _set_trainable((model,), True)
    optimizer = _make_adamw(model, args.lr * 0.01, 1e-5)
    best_val_acc = _run_stage(
        model, train_loader, val_loader, optimizer, device, args.fine_epochs,
        "[Fine] Epoch {epoch}: train_acc={tr_acc:.4f}, val_acc={val_acc:.4f}",
        history, best_val_acc, args.checkpoint,
    )

    print("Training finished. Best val acc:", best_val_acc)
    plot_metrics(history["train_loss"], history["val_loss"],
                 history["train_acc"], history["val_acc"], args.out_dir)

    # A checkpoint is written by this run iff some val_acc exceeded 0.0; in that
    # case report on the best weights rather than the last epoch's.
    if best_val_acc > 0.0:
        best = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(best['model_state'])
        print(f"Loaded best checkpoint (epoch {best['epoch']}, val_acc={best['val_acc']:.4f}) for the final report")
    confusion_and_report(model, val_loader, device, class_names, args.out_dir)
    print("Outputs saved to", args.out_dir)

if __name__ == "__main__":
    main()