**Final working code: hybrid ResNet-18 + Swin-Tiny fusion classifier**

In [None]:
import torch
import torch.nn as nn
import timm

class HybridFusionModel(nn.Module):
    """Image classifier that fuses ResNet-18 and Swin-Tiny feature maps.

    Both backbones are created through ``timm`` with ``features_only=True``.
    The deepest feature map of each backbone is concatenated along the channel
    axis, mixed by a 1x1 convolution (``fusion_conv``) and classified by a
    globally pooled linear layer (``classifier``).

    The fusion head is built in ``__init__`` so that its parameters follow
    ``.to(device)``, are visible to an optimizer created right after
    construction, and are present in ``state_dict()`` without a warm-up pass.

    Note:
        Earlier revisions fused the Swin map without converting it to
        channels-first, which produced a 519-channel fusion layer. Checkpoints
        saved with that layout do not match the corrected layer shapes.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.cnn = timm.create_model("resnet18", pretrained=True, features_only=True)
        self.swin = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, features_only=True)
        self.num_classes = num_classes

        # Channel width of the deepest stage, read from timm's feature metadata
        # instead of being inferred from a (possibly channels-last) tensor.
        self._cnn_channels = self.cnn.feature_info.channels()[-1]
        self._swin_channels = self.swin.feature_info.channels()[-1]

        self._build_head(self._cnn_channels + self._swin_channels)

    def _build_head(self, in_channels):
        """Create ``fusion_conv`` and ``classifier`` for ``in_channels`` fused channels."""
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, self.num_classes)
        )

    def _init_fusion(self, cnn_feat, swin_feat):
        """Rebuild the fusion head from two channels-first feature maps.

        Kept for backward compatibility; ``__init__`` already builds the head.
        The new layers are moved to the device of ``cnn_feat`` so they can be
        applied to those features immediately.
        """
        c_cnn = cnn_feat.shape[1]
        c_swin = swin_feat.shape[1]
        total = c_cnn + c_swin
        print(f"Auto-configuring fusion: CNN {c_cnn} + Swin {c_swin} ‚Üí total {total}")
        self._build_head(total)
        self.fusion_conv.to(cnn_feat.device)
        self.classifier.to(cnn_feat.device)

    def _swin_to_nchw(self, feats):
        """Return ``feats`` as (B, C, H, W), permuting when it arrives as (B, H, W, C).

        The layout is detected by locating the known Swin channel width: if it
        sits on the last axis rather than axis 1, the map is channels-last.
        """
        channels_last = (
            feats.ndim == 4
            and feats.shape[1] != self._swin_channels
            and feats.shape[-1] == self._swin_channels
        )
        if channels_last:
            return feats.permute(0, 3, 1, 2).contiguous()
        return feats

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by both backbones.

        Returns:
            Tuple ``(out, fused)``: ``out`` holds the class logits with shape
            (B, num_classes); ``fused`` is the 512-channel fused feature map
            produced by ``fusion_conv``.
        """
        cnn_feats = self.cnn(x)[-1]
        swin_feats = self._swin_to_nchw(self.swin(x)[-1])

        # Align spatial sizes only; both maps are channels-first at this point.
        if swin_feats.shape[2:] != cnn_feats.shape[2:]:
            swin_feats = torch.nn.functional.interpolate(
                swin_feats, size=cnn_feats.shape[2:], mode="bilinear", align_corners=False
            )

        fused = torch.cat((cnn_feats, swin_feats), dim=1)
        fused = self.fusion_conv(fused)
        out = self.classifier(fused)
        return out, fused


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the model on the CPU until the warm-up pass below has run: layers that
# are created during the first forward pass are allocated on the CPU, so
# moving the model to the GPU beforehand would leave them on the wrong device.
model = HybridFusionModel(num_classes=6)

# ImageNet mean/std normalisation on 224x224 inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder("/content/drive/MyDrive/Dataset2.0", transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Warm-up pass on one batch. eval() + no_grad() keep this shape-probing pass
# from updating BatchNorm running statistics or building an autograd graph.
imgs, _ = next(iter(loader))
model.eval()
with torch.no_grad():
    _ = model(imgs)
model.train()

# Every layer now exists, so a single .to(device) moves all of them.
model = model.to(device)
print("Model initialized successfully ‚úîÔ∏è")



Auto-configuring fusion: CNN 512 + Swin 7 ‚Üí total 519
Model initialized successfully ‚úîÔ∏è


In [None]:

from tqdm import tqdm

# Train the model and keep the checkpoint with the best validation accuracy.
# The model must already contain all of its layers at this point: parameters
# created after the optimizer is built would never be updated.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10
# Tie the cosine schedule to the epoch count so the two cannot drift apart.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
best_val_acc = 0

# NOTE: despite its name, val_split is the size of the TRAINING subset (80 %).
val_split = int(0.8 * len(dataset))
# A seeded generator makes the train/validation partition identical on every
# run, so validation accuracies are comparable between runs.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [val_split, len(dataset) - val_split],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    total_loss, correct, total = 0, 0, 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        out, _ = model(imgs)
        loss = criterion(out, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, pred = torch.max(out, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    train_acc = 100 * correct / total
    scheduler.step()  # one scheduler step per epoch
    print(f"Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}%")

    # ---- validation pass ----
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            out, _ = model(imgs)
            _, pred = torch.max(out, 1)
            val_total += labels.size(0)
            val_correct += (pred == labels).sum().item()

    val_acc = 100 * val_correct / val_total
    print(f"Validation Acc: {val_acc:.2f}%")

    # Checkpoint only on strict improvement of validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "/content/hybrid_fusion_gradcam_best.pth")
        print(f"Saved best model (Val Acc: {val_acc:.2f}%)")

print(f"\nüèÜ Training Complete | Best Validation Accuracy: {best_val_acc:.2f}%")


Epoch 1/10:   3%|‚ñé         | 2/60 [00:06<03:07,  3.23s/it]


KeyboardInterrupt: 