In [None]:
!pip install -q medmnist

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import PathMNIST
import medmnist
from tqdm import tqdm

from train_utils import train, evaluate

# Colab-only: mount Google Drive at /content/drive so the training cell can
# write its best checkpoint under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# `ToTensor` is the only preprocessing step: images become float tensors scaled
# to [0, 1]; no mean/std normalisation or augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train and validation splits with the same preprocessing;
# download=True lets medmnist fetch the data files when they are needed.
train_dataset, val_dataset = (
    PathMNIST(split=split, transform=transform, download=True)
    for split in ('train', 'val')
)

# Batches of 64. Only the training loader reshuffles every epoch; validation
# keeps a fixed sample order.
train_loader, val_loader = (
    DataLoader(ds, batch_size=64, shuffle=is_train)
    for ds, is_train in ((train_dataset, True), (val_dataset, False))
)


In [None]:
class CNNBaseline(nn.Module):
    """Small three-stage convolutional classifier used as the PathMNIST baseline.

    Architecture: three 3x3 convolutions (32 -> 64 -> 128 channels, ``padding=1``)
    each followed by ReLU, with 2x2 max-pooling after the first two; then global
    average pooling and a single linear layer. The adaptive pooling makes the
    classifier head independent of the input's spatial size, as long as the
    image survives the two 2x2 poolings (at least 4x4).

    Args:
        num_classes: Number of output scores produced per sample (default 9).
        in_channels: Number of channels in the input images (default 3). Set to 1
            for single-channel datasets; the default keeps the original RGB model.
    """

    def __init__(self, num_classes=9, in_channels=3):
        super().__init__()
        # Layer order inside `features` (and the attribute names below) is kept
        # fixed so state_dict keys (features.0/3/6, fc) match saved checkpoints.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )
        # Collapse whatever spatial extent remains to 1x1 per channel.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Batched image tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes) with raw, unnormalised scores
            (no softmax is applied).
        """
        x = self.features(x)
        x = self.pool(x)            # (N, 128, 1, 1)
        x = torch.flatten(x, 1)     # (N, 128); keeps the batch dimension
        return self.fc(x)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 42              # seeds torch's global RNG (weight init + DataLoader shuffling)
NUM_EPOCHS = 30        # upper bound on the number of training epochs
LEARNING_RATE = 1e-3   # Adam step size
TARGET_VAL_ACC = 0.90  # stop early once validation accuracy exceeds this

# Seed before the model is built so weight initialisation and the training
# loader's shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)

model = CNNBaseline().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

save_path = '/content/drive/MyDrive/NCA/best_cnn_pathmnist.pth'
# torch.save does not create missing parent directories; ensure the checkpoint
# folder exists before the first save.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

best_acc = 0

for epoch in range(1, NUM_EPOCHS + 1):
    # `train` / `evaluate` come from the project-local train_utils module; each
    # result is unpacked here as a (loss, accuracy) pair.
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f"Epoch {epoch:02d}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}, "
          f"Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")

    # Checkpoint on improvement. This runs before the early-stop check so the
    # epoch that crosses the target is saved if it is the best so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print(f"Best model saved to: {save_path}")

    if val_acc > TARGET_VAL_ACC:
        print(f"Converged with >{TARGET_VAL_ACC:.0%} accuracy!")
        break

print("🎉 Training complete.")
print(f"Best Val Acc = {best_acc:.4f}")