In [None]:
import os
import argparse
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
# ---------------------------------------------------------------------------
# Run configuration. These values are the defaults for the command-line
# options parsed in main() (--train-dir, --val-dir, --batch, ...).
# ---------------------------------------------------------------------------

# Dataset location. NOTE: machine-specific absolute path; the train/val
# folders are read with torchvision's ImageFolder (one sub-folder per class).
_DATASET_ROOT = r"C:\Users\sudeepta\Desktop\DR_project\augmented_resized_V2"
DEFAULT_TRAIN_DIR = _DATASET_ROOT + "\\train"
DEFAULT_VAL_DIR = _DATASET_ROOT + "\\val"

# Input pipeline.
IMAGE_SIZE = 224   # side length (pixels) of the square crops fed to the model
BATCH_SIZE = 16    # reduce if OOM (8 or 4)

# Model.
PROJ_DIM = 256     # width of each gated projection branch before fusion

# Training schedule: head-only -> optional partial unfreeze -> full fine-tune.
HEAD_EPOCHS = 20
PARTIAL_EPOCHS = 0   # 0 disables the partial-unfreeze stage
FINE_EPOCHS = 30
LR = 1e-3            # base learning rate; later stages scale it down

# Artifacts.
OUT_DIR = "outputs_resnet_vision_fusion"   # plots are written here
CHECKPOINT = "hybrid_resnet_vision.pth"    # best-validation-accuracy weights


In [None]:
class ConvBlock(nn.Module):
    """Basic building block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: zero padding applied on each side.
        groups: number of convolution groups (``groups == in_ch`` gives a
            depthwise convolution).
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1, groups=1):
        super().__init__()
        # bias=False: the BatchNorm that follows has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, normalisation and activation in that order."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

class VisionNet(nn.Module):
    """Small plain-CNN encoder producing one feature vector per image.

    A stem and three stages of ``ConvBlock`` layers widen the channels from
    ``in_channels`` to 512 (four of the blocks use stride 2). The resulting
    map is global-average-pooled and linearly projected to ``out_features``.

    Args:
        in_channels: channels of the input image tensor.
        out_features: size of the returned feature vector.
    """

    def __init__(self, in_channels=3, out_features=512):
        super().__init__()

        def build(*specs):
            # Each spec is (in_channels, out_channels, stride) for one ConvBlock.
            return nn.Sequential(
                *(ConvBlock(c_in, c_out, stride=step) for c_in, c_out, step in specs)
            )

        self.stem = build((in_channels, 32, 2), (32, 32, 1), (32, 64, 2))
        self.stage1 = build((64, 128, 1), (128, 128, 1))
        self.stage2 = build((128, 256, 2), (256, 256, 1))
        self.stage3 = build((256, 512, 2), (512, 512, 1))
        self.pool = nn.AdaptiveAvgPool2d(1)
        self._out = out_features
        self.project = nn.Linear(512, self._out)

    def forward(self, x):
        """Return a ``(batch, out_features)`` tensor for an image batch ``x``."""
        for stage in (self.stem, self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = torch.flatten(self.pool(x), 1)
        return self.project(pooled)

    def feature_dim(self):
        """Size of the feature vector returned by ``forward``."""
        return self._out

class VisionMamba(nn.Module):
    """Lightweight encoder built from grouped/depthwise-style ``ConvBlock`` layers.

    Despite the name, this block is purely convolutional: a two-block stem,
    a section that starts with two grouped convolutions (``groups`` equal to
    the input channel count), and a 1x1 + 3x3 pair, followed by global
    average pooling and a linear projection to ``out_features``.

    Args:
        in_channels: channels of the input image tensor.
        out_features: size of the returned feature vector.
    """

    def __init__(self, in_channels=3, out_features=512):
        super().__init__()
        stem_w, group_w, mid_w, top_w = 24, 48, 96, 192

        self.stem = nn.Sequential(
            ConvBlock(in_channels, stem_w, stride=2),
            ConvBlock(stem_w, stem_w),
        )
        self.depthwise = nn.Sequential(
            # groups == input channels: every input channel is filtered on its own.
            ConvBlock(stem_w, group_w, stride=2, groups=stem_w),
            ConvBlock(group_w, group_w, groups=group_w),
            ConvBlock(group_w, mid_w, stride=2),
            ConvBlock(mid_w, mid_w),
        )
        self.spp = nn.Sequential(
            ConvBlock(mid_w, top_w, kernel=1, padding=0),  # 1x1 channel expansion
            ConvBlock(top_w, top_w),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self._out = out_features
        self.project = nn.Linear(top_w, self._out)

    def forward(self, x):
        """Return a ``(batch, out_features)`` tensor for an image batch ``x``."""
        feature_map = self.spp(self.depthwise(self.stem(x)))
        pooled = torch.flatten(self.pool(feature_map), 1)
        return self.project(pooled)

    def feature_dim(self):
        """Size of the feature vector returned by ``forward``."""
        return self._out

In [None]:
class ResNet50_Feature(nn.Module):
    """ResNet-50 with its final fully connected layer removed (feature extractor).

    Args:
        pretrained: if True, load ImageNet weights; otherwise the network is
            randomly initialised.

    ``forward`` returns the globally pooled features flattened to
    ``(batch, feature_dim())``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        if hasattr(models, "ResNet50_Weights"):
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
            # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to resolve to,
            # so the loaded weights are unchanged.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            net = models.resnet50(weights=weights)
        else:
            # Older torchvision releases only understand the boolean flag.
            net = models.resnet50(pretrained=pretrained)
        # Drop the final fc layer; the remaining stack ends with AdaptiveAvgPool2d.
        self.encoder = nn.Sequential(*list(net.children())[:-1])
        # Width of the pooled features == input width of the removed fc (2048).
        self.out_dim = net.fc.in_features

    def forward(self, x):
        """Encode an image batch into flattened pooled features."""
        return self.encoder(x).flatten(1)

    def feature_dim(self):
        """Size of the feature vector returned by ``forward``."""
        return self.out_dim

In [None]:
class GatedProjection(nn.Module):
    """Project features to ``proj_dim`` and scale them by a learned scalar gate.

    The input goes through Linear -> BatchNorm1d -> ReLU. A small two-layer
    MLP ending in a sigmoid then produces one gate value per sample, which
    multiplies the whole projected vector.

    Args:
        in_dim: size of the incoming feature vector.
        proj_dim: size of the projected (output) vector.
    """

    def __init__(self, in_dim, proj_dim=PROJ_DIM):
        super().__init__()
        gate_hidden = max(8, proj_dim // 8)  # bottleneck width, never below 8

        self.fc = nn.Linear(in_dim, proj_dim)
        self.bn = nn.BatchNorm1d(proj_dim)
        self.act = nn.ReLU(inplace=True)
        self.gate = nn.Sequential(
            nn.Linear(proj_dim, gate_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(gate_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the projected features multiplied by their per-sample gate."""
        projected = self.fc(x)
        projected = self.bn(projected)
        projected = self.act(projected)
        weight = self.gate(projected)  # shape (B, 1), broadcast over features
        return projected * weight

class HybridResNetVisionFusion(nn.Module):
    """Three-branch classifier fusing ResNet-50, VisionNet and VisionMamba features.

    Every branch encodes the same input image; each encoding is passed through
    its own ``GatedProjection`` to ``proj_dim`` values, the three results are
    concatenated, and an MLP head maps them to ``num_classes`` logits.

    Args:
        proj_dim: width of each projected branch.
        num_classes: number of output logits.
        pretrained_resnet: forwarded to ``ResNet50_Feature(pretrained=...)``.
    """

    def __init__(self, proj_dim=PROJ_DIM, num_classes=5, pretrained_resnet=True):
        super().__init__()
        # Encoders. The two small nets already emit proj_dim features.
        self.resnet = ResNet50_Feature(pretrained=pretrained_resnet)
        self.visionnet = VisionNet(out_features=proj_dim)
        self.visionmamba = VisionMamba(out_features=proj_dim)

        # One gated projection per branch (ResNet: 2048 -> proj_dim; the other
        # two keep their width but get the same gating for a uniform interface).
        self.proj_r = GatedProjection(self.resnet.feature_dim(), proj_dim)
        self.proj_vn = GatedProjection(self.visionnet.feature_dim(), proj_dim)
        self.proj_vm = GatedProjection(self.visionmamba.feature_dim(), proj_dim)

        # Classification head over the concatenated branches.
        branch_count = 3
        head_hidden = 512
        self.head = nn.Sequential(
            nn.Linear(branch_count * proj_dim, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(head_hidden, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        encoders = (self.resnet, self.visionnet, self.visionmamba)
        projections = (self.proj_r, self.proj_vn, self.proj_vm)
        encoded = [encoder(x) for encoder in encoders]
        gated = [proj(feat) for proj, feat in zip(projections, encoded)]
        return self.head(torch.cat(gated, dim=1))

In [None]:
def get_loaders(train_dir: str, val_dir: str, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, num_workers=4) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders from ImageFolder directories.

    Args:
        train_dir: root folder of the training images (one sub-folder per class).
        val_dir: root folder of the validation images, same layout.
        image_size: side length of the square crops produced by the transforms.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes used by each DataLoader.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and applies
        random augmentation; the validation loader is deterministic.
    """
    # Both pipelines end with the same ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.15,0.15,0.15,0.02),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        # Resize slightly larger than the target, then take the centre crop.
        transforms.Resize(int(image_size*1.14)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])
    train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tf)

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
    # requesting it just produces a warning.
    pin_memory = torch.cuda.is_available()

    # A trailing training batch holding exactly one sample makes BatchNorm1d
    # raise in training mode ("Expected more than 1 value per channel").
    # Drop the last batch only in that situation so no other data is lost.
    n_train = len(train_ds)
    drop_last = n_train > batch_size and n_train % batch_size == 1

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              pin_memory=pin_memory, drop_last=drop_last)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            pin_memory=pin_memory)
    return train_loader, val_loader

In [None]:
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Puts ``model`` in training mode, performs a cross-entropy gradient step
    per batch, and returns ``(mean_loss, accuracy)`` where both are averaged
    over the number of samples seen (accuracy uses the pre-step logits).
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    loss_sum = 0.0
    seen = 0
    hits = 0
    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        n = batch_imgs.size(0)

        optimizer.zero_grad()
        outputs = model(batch_imgs)
        batch_loss = criterion(outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        loss_sum += batch_loss.item() * n
        seen += n
        hits += (outputs.argmax(dim=1) == batch_labels).sum().item()

    return loss_sum / seen, hits / seen

def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without updating any weights.

    Puts ``model`` in eval mode, disables gradient tracking, and returns
    ``(mean_loss, accuracy)`` averaged over the number of samples seen.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()

    loss_sum = 0.0
    seen = 0
    hits = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            n = batch_imgs.size(0)

            outputs = model(batch_imgs)
            # Batch-mean loss re-weighted by batch size -> per-sample average.
            loss_sum += criterion(outputs, batch_labels).item() * n
            seen += n
            hits += (outputs.argmax(dim=1) == batch_labels).sum().item()

    return loss_sum / seen, hits / seen

In [None]:
def plot_metrics(train_losses, val_losses, train_accs, val_accs, out_dir):
    """Save the loss and accuracy learning curves as PNG files in ``out_dir``.

    Writes ``loss_curve.png`` and ``acc_curve.png``; each figure plots the
    training and validation series against the 1-based epoch index. The
    directory is created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    epoch_axis = range(1, len(train_losses)+1)

    # (train series, val series, y-axis label, title, output file name)
    curves = (
        (train_losses, val_losses, 'Loss', 'Training & Validation Loss', 'loss_curve.png'),
        (train_accs, val_accs, 'Accuracy', 'Training & Validation Accuracy', 'acc_curve.png'),
    )
    for train_series, val_series, y_label, title, file_name in curves:
        plt.figure()
        plt.plot(epoch_axis, train_series, marker='o')
        plt.plot(epoch_axis, val_series, marker='o')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend(['train','val'])
        plt.grid(True)
        plt.savefig(os.path.join(out_dir, file_name))
        plt.close()  # release the figure so repeated calls do not accumulate

    print("Saved loss/accuracy plots to", out_dir)

def confusion_and_report(model, loader, device, class_names: List[str], out_dir: str):
    """Print and plot the confusion matrix and classification report for ``loader``.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(images, labels)`` batches with integer labels.
        device: device the images are moved to before the forward pass.
        class_names: display names; index ``i`` names label ``i``.
        out_dir: directory receiving ``confusion_matrix.png`` (created if needed).
    """
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for imgs, labels in loader:
            logits = model(imgs.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            # .cpu() keeps this working if a loader ever yields device tensors.
            y_true.extend(labels.cpu().tolist())

    # Pin the label set to the full class list. Without it, a class that never
    # occurs in y_true/y_pred shrinks the matrix (misaligning the tick labels)
    # and makes classification_report reject target_names of a different length.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion matrix:\n", cm)
    print("Classification report:\n",
          classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

    plt.figure(figsize=(8,6))
    plt.imshow(cm, interpolation='nearest')
    plt.colorbar()
    ticks = np.arange(len(class_names))
    plt.xticks(ticks, class_names, rotation=45)
    plt.yticks(ticks, class_names)
    # confusion_matrix puts true labels on rows and predictions on columns.
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    os.makedirs(out_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, 'confusion_matrix.png'))
    plt.close()
    print("Saved confusion matrix to", out_dir)


In [None]:
def _run_stage(tag, num_epochs, model, train_loader, val_loader, optimizer, device,
               history, best_val_acc, checkpoint_path, brief_log=False):
    """Train and validate for ``num_epochs`` epochs, checkpointing on improvement.

    Per-epoch metrics are appended to the lists in ``history``. Whenever the
    validation accuracy exceeds ``best_val_acc`` the model is saved to
    ``checkpoint_path``. Returns the (possibly updated) best accuracy.
    """
    for epoch in range(1, num_epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, device)
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        if brief_log:
            print(f"[{tag}] Epoch {epoch}: val_acc={val_acc:.4f}")
        else:
            print(f"[{tag}] Epoch {epoch}: train_loss={tr_loss:.4f}, train_acc={tr_acc:.4f} | val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # 'epoch' counts within the current stage, so the stage tag is
            # stored alongside it to make the checkpoint unambiguous.
            torch.save({'epoch': epoch, 'stage': tag, 'model_state': model.state_dict(), 'val_acc': val_acc},
                       checkpoint_path)
    return best_val_acc


def main():
    """Train the hybrid fusion model in stages and write plots and reports.

    Stages: (1) head-only training with all three backbones frozen,
    (2) optional partial unfreeze of the last VisionNet/VisionMamba sections,
    (3) full fine-tuning of every parameter. The weights with the best
    validation accuracy are saved to ``--checkpoint`` and are the ones used
    for the final confusion matrix and classification report.

    Raises:
        ValueError: if the training and validation folders do not contain the
            same class sub-folders.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-dir', type=str, default=DEFAULT_TRAIN_DIR)
    parser.add_argument('--val-dir', type=str, default=DEFAULT_VAL_DIR)
    parser.add_argument('--batch', type=int, default=BATCH_SIZE)
    parser.add_argument('--proj-dim', type=int, default=PROJ_DIM)
    parser.add_argument('--head-epochs', type=int, default=HEAD_EPOCHS)
    parser.add_argument('--partial-epochs', type=int, default=PARTIAL_EPOCHS)
    parser.add_argument('--fine-epochs', type=int, default=FINE_EPOCHS)
    parser.add_argument('--lr', type=float, default=LR)
    parser.add_argument('--out-dir', type=str, default=OUT_DIR)
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT)
    args, _ = parser.parse_known_args()  # safe in notebooks

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    train_loader, val_loader = get_loaders(args.train_dir, args.val_dir, image_size=IMAGE_SIZE, batch_size=args.batch)
    # Take the names from the dataset itself: they are exactly what the integer
    # labels index into (ImageFolder sorts the class sub-folders).
    class_names = list(train_loader.dataset.classes)
    val_class_names = list(val_loader.dataset.classes)
    if val_class_names != class_names:
        # Differing folder sets would shift label indices between the two
        # splits and silently corrupt every validation metric.
        raise ValueError(
            f"Class folders differ between train ({class_names}) and val ({val_class_names})"
        )
    print("Classes:", class_names)

    # Build model (visionnet/vm already output proj-dim from project, but we still use gated proj for consistency)
    model = HybridResNetVisionFusion(proj_dim=args.proj_dim, num_classes=len(class_names), pretrained_resnet=True)
    model.to(device)

    # Stage 1: freeze backbones (resnet, visionnet, visionmamba); train only
    # the gated projections and the classification head.
    # NOTE(review): train_one_epoch() calls model.train(), so BatchNorm layers
    # inside the frozen backbones still update their running statistics during
    # this stage - confirm that this is intended.
    for backbone in (model.resnet, model.visionnet, model.visionmamba):
        for p in backbone.parameters():
            p.requires_grad = False
    for trainable in (model.proj_r, model.proj_vn, model.proj_vm, model.head):
        for p in trainable.parameters():
            p.requires_grad = True

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=1e-4)

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_val_acc = 0.0

    print(f"Head-only training for {args.head_epochs} epochs...")
    best_val_acc = _run_stage("Head", args.head_epochs, model, train_loader, val_loader, optimizer, device,
                              history, best_val_acc, args.checkpoint)

    # Partial unfreeze (if requested)
    if args.partial_epochs > 0:
        print(f"Partial unfreeze for {args.partial_epochs} epochs...")
        # Example unfreeze - adapt as desired
        for p in model.visionnet.stage3.parameters():
            p.requires_grad = True
        for p in model.visionmamba.spp.parameters():
            p.requires_grad = True
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr * 0.1, weight_decay=1e-4)
        best_val_acc = _run_stage("Partial", args.partial_epochs, model, train_loader, val_loader, optimizer, device,
                                  history, best_val_acc, args.checkpoint, brief_log=True)

    # Full fine-tune: everything trainable, at 1/100 of the base learning rate.
    print(f"Full fine-tune for {args.fine_epochs} epochs...")
    for p in model.parameters():
        p.requires_grad = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr * 0.01, weight_decay=1e-5)
    best_val_acc = _run_stage("Fine", args.fine_epochs, model, train_loader, val_loader, optimizer, device,
                              history, best_val_acc, args.checkpoint)

    print("Training finished. Best val acc:", best_val_acc)
    plot_metrics(history['train_loss'], history['val_loss'], history['train_acc'], history['val_acc'], args.out_dir)

    # Report on the checkpointed (best) weights rather than whatever the last
    # epoch left in memory, so the confusion matrix matches the saved model.
    # best_val_acc > 0 means a checkpoint was written during this run.
    if best_val_acc > 0.0:
        best_state = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(best_state['model_state'])
    confusion_and_report(model, val_loader, device, class_names, args.out_dir)
    print("Outputs saved to", args.out_dir)

# Entry point: run the full training pipeline when this code is executed
# directly. In a notebook kernel __name__ is "__main__" as well, so running
# this cell starts training.
if __name__ == "__main__":
    main()