In [None]:
!pip install medmnist

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
import medmnist
from medmnist import INFO
import numpy as np
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
import os
import signal
import sys
import json

In [None]:
# Handle interruptions gracefully
def handle_interrupt(signum, frame):
    """Signal handler: try to persist training progress, then exit with status 1.

    Args:
        signum: Number of the signal that was delivered (unused).
        frame: Stack frame active when the signal arrived (unused).

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print("\nTraining interrupted! Saving model and metrics...")
    try:
        save_checkpoint()
    except Exception as exc:
        # save_checkpoint and the objects it reads are created in later cells.
        # If the interrupt arrives before they exist (or the write fails),
        # report the problem rather than masking the interrupt with a new error.
        print(f"Could not save checkpoint: {exc!r}")
    sys.exit(1)

# Install handle_interrupt as the process-wide SIGINT handler, so an interrupt
# runs the checkpoint-and-exit path instead of the default behaviour.
signal.signal(signal.SIGINT, handle_interrupt)

In [None]:
# CUDA allocator setting used for Colab GPU runs
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True')

# Compute device: CUDA when available, otherwise CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Fixed seeds for the NumPy and PyTorch random number generators
np.random.seed(42)
torch.manual_seed(42)

# Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# os.makedirs('/content/drive/MyDrive/pathmnist', exist_ok=True)


# Run configuration: dataset choice, training schedule and output locations
config = dict(
    data_flag='pathmnist',      # Example dataset from MedMNIST
    batch_size=128,
    max_epochs=50,
    early_stop_patience=5,
    checkpoint_path='/content/checkpoint.pth',
    best_model_path='/content/best_model.pth',
    metrics_path='/content/training_metrics.json',
    resume_training=False,      # Set to True to resume from checkpoint
)

print(f"Using device: {device}")


In [None]:
# Data Preparation
info = INFO[config['data_flag']]
DataClass = getattr(medmnist, info['python_class'])

# Compute class weights.
# The labels are read from the dataset's label array. Iterating the dataset
# instead would decode every training image only to discard it.
train_dataset = DataClass(split='train', download=True)
train_labels = np.asarray(train_dataset.labels).reshape(-1).astype(np.int64)
label_values = np.unique(train_labels)
balanced_weights = compute_class_weight('balanced', classes=label_values, y=train_labels)
class_weights = torch.tensor(balanced_weights, dtype=torch.float).to(device)

# Data Augmentation
# NOTE(review): the Normalize statistics below are the commonly used ImageNet
# values, not statistics computed on this dataset - confirm this is intended.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Datasets and DataLoaders
train_dataset = DataClass(split='train', transform=train_transform)
val_dataset = DataClass(split='val', transform=test_transform)
test_dataset = DataClass(split='test', transform=test_transform)

# Per-sample sampling weight = balanced weight of the sample's class.
# searchsorted maps each label value to its position in label_values, so the
# lookup stays correct even if label values are not exactly 0..K-1. The weights
# are kept on the CPU because sampling runs in the main process.
# NOTE(review): class imbalance is compensated both here (sampler) and again
# in the weighted cross-entropy term of the loss - confirm both are wanted.
sample_weights = torch.as_tensor(
    balanced_weights[np.searchsorted(label_values, train_labels)],
    dtype=torch.double,
)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(train_dataset),
    replacement=True
)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], sampler=sampler,
                         num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                       num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False,
                        num_workers=4, pin_memory=True)

In [None]:
# Model Components
class SqueezeExcitation(nn.Module):
    """Channel gate: rescales each channel by a learned factor in (0, 1).

    The factor is produced from the globally average-pooled feature map by a
    two-layer bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid).

    Args:
        channel: Number of input (and output) channels.
        reduction: Divisor applied to ``channel`` to get the bottleneck width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channel, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its gate; ``x`` must be 4-D."""
        batch, channels, _height, _width = x.shape
        pooled = self.avgpool(x).reshape(batch, channels)
        gate = self.fc(pooled).reshape(batch, channels, 1, 1)
        return x * gate.expand_as(x)

class MBConvBlock(nn.Module):
    """Expand -> depthwise -> squeeze-excite -> project convolution block.

    A 1x1 convolution widens the input to ``in_c * expansion`` channels, a
    depthwise 3x3 convolution and a SqueezeExcitation gate operate at that
    width, and a final 1x1 convolution projects to ``out_c`` channels. The
    input is added back only when ``in_c == out_c``.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        expansion: Multiplier for the hidden width.
    """

    def __init__(self, in_c, out_c, expansion=4):
        super().__init__()
        hidden_dim = in_c * expansion
        self.use_residual = in_c == out_c

        expand = [
            nn.Conv2d(in_c, hidden_dim, 1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        ]
        depthwise = [
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        ]
        gate_and_project = [
            SqueezeExcitation(hidden_dim),
            nn.Conv2d(hidden_dim, out_c, 1, bias=False),
            nn.BatchNorm2d(out_c),
        ]
        self.block = nn.Sequential(*expand, *depthwise, *gate_and_project)

    def forward(self, x):
        """Apply the block, adding the skip path when channel counts match."""
        out = self.block(x)
        if not self.use_residual:
            return out
        return x + out

class EfficientVSSBlock(nn.Module):
    """Residual block: depthwise conv, channel gate, per-position norm and MLP.

    The input passes through a depthwise 3x3 convolution and a
    SqueezeExcitation gate. Each spatial position's channel vector is then
    group-normalised and fed through a two-layer MLP (dim -> 2*dim -> dim,
    GELU in between). The result is scaled by the learnable per-channel
    ``gamma`` (initialised to 1e-6) and added to the input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

        # Largest candidate group count that divides dim (1 always does).
        possible_groups = [32, 16, 8, 4, 2, 1]
        self.num_groups = next((g for g in possible_groups if dim % g == 0), 1)

        self.norm = nn.GroupNorm(num_groups=self.num_groups, num_channels=dim)
        self.pwconv1 = nn.Linear(dim, 2 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(2 * dim, dim)
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1) * 1e-6)
        self.se = SqueezeExcitation(dim)

    def forward(self, x):
        """Return ``x + gamma * branch(x)`` for a ``(b, c, h, w)`` input."""
        residual = x
        x = self.dwconv(x)
        x = self.se(x)

        b, c, h, w = x.shape
        # GroupNorm expects channels on dim 1. The tensor is channels-last
        # after the permute, so it is flattened to (b*h*w, c) for every group
        # count. The original only did this when num_groups > 1, which made
        # the single-group case (odd dim) raise a channel-count error.
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        x = self.norm(x)
        x = x.reshape(b, h, w, c)

        # Position-wise MLP over the trailing channel dimension.
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        x = x.permute(0, 3, 1, 2).contiguous()
        return residual + self.gamma * x

In [None]:
# Main Model Architecture
class EnhancedPathoNet(nn.Module):
    """Image classifier: conv stem, MBConv encoder, VSS block, upsampling head.

    Pipeline for a 3-channel image batch:
      1. ``stem``: two 3x3 conv/BN/ReLU stages, the first with stride 2.
      2. ``encoder``: four MBConvBlocks (32 -> 64 -> 128 -> 256 -> 512
         channels), then a depthwise 3x3 and a 1x1 conv to 1024 channels.
      3. ``vss_block``: an EfficientVSSBlock at 1024 channels.
      4. ``output2`` reduces to 512 channels; ``decoder1``/``decoder2`` each
         upsample by 2 and reduce channels, with 1x1-projected, bilinearly
         resized skip tensors added at each stage.
      5. ``classifier``: conv/BN/ReLU, global average pool, dropout, linear.

    ``output1`` is registered but not used by ``forward``; it is kept so that
    state-dict keys stay compatible with existing checkpoints.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=9):
        super().__init__()

        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        # Encoder
        self.encoder = nn.Sequential(
            MBConvBlock(32, 64),
            MBConvBlock(64, 128),
            MBConvBlock(128, 256),
            MBConvBlock(256, 512),
            nn.Conv2d(512, 512, 3, padding=1, groups=512),
            nn.Conv2d(512, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )

        # Transformer
        self.vss_block = EfficientVSSBlock(1024)

        # Multi-Scale Features
        # output1 is not evaluated in forward(); retained for checkpoint compatibility.
        self.output1 = nn.Sequential(
            SqueezeExcitation(1024),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512)
        )

        self.output2 = nn.Sequential(
            nn.Conv2d(1024, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        # Decoders
        self.decoder1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.decoder2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Skip Connections
        self.skip1 = nn.Conv2d(1024, 256, 1)
        self.skip2 = nn.Conv2d(512, 128, 1)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes))

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv and linear weights across all submodules.

        Conv2d weights get Kaiming-normal (fan_out, ReLU) initialisation and
        their biases are left untouched. Linear weights get Xavier-normal
        initialisation and their biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Stem
        x = self.stem(x)

        # Encoder
        x = self.encoder(x)

        # Transformer
        x = self.vss_block(x)

        # Spatial path. self.output1 used to be evaluated here, but its result
        # was never consumed, so the call is omitted.
        spatial_feat = self.output2(x)

        # Decoder with skip connections: each skip tensor is resized to the
        # decoder output's spatial size before the addition.
        d1 = self.decoder1(spatial_feat)
        skip1 = F.interpolate(self.skip1(x), size=d1.shape[2:], mode='bilinear', align_corners=True)

        d2 = self.decoder2(d1 + skip1)
        skip2 = F.interpolate(self.skip2(spatial_feat), size=d2.shape[2:], mode='bilinear', align_corners=True)

        out = self.classifier(d2 + skip2)
        return out


In [None]:
# Loss Functions
class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross-entropy.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t = exp(-cross_entropy)``, and the batch mean is returned.

    Args:
        alpha: Constant scale applied to every sample.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
    """

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and class-index ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        modulator = (1 - prob_true) ** self.gamma
        weighted = self.alpha * modulator * per_sample_ce
        return weighted.mean()

# class CombinedLoss(nn.Module):
#     def __init__(self, class_weights, alpha=0.75, gamma=2.0, label_smoothing=0.2):
#         super().__init__()
#         self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
#         self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
        
#     def forward(self, inputs, targets):
#         return self.ce_loss(inputs, targets) + self.focal_loss(inputs, targets)

class CombinedLoss(nn.Module):
    """Equal-weight blend of focal loss and class-weighted, smoothed cross-entropy.

    Args:
        class_weights: Per-class weight tensor passed to ``nn.CrossEntropyLoss``.
        alpha: Forwarded to ``FocalLoss``.
        gamma: Forwarded to ``FocalLoss``.
        label_smoothing: Label-smoothing amount for the cross-entropy term.
    """

    def __init__(self, class_weights, alpha=0.75, gamma=2.0, label_smoothing=0.2):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)
        # Constructor arguments are also kept as plain attributes.
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing
        self.ce_loss = nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=label_smoothing,
        )

    def forward(self, inputs, targets):
        """Return ``0.5 * focal + 0.5 * cross_entropy`` for the batch."""
        focal_term = self.focal(inputs, targets)
        ce_term = self.ce_loss(inputs, targets)
        blended = 0.5 * focal_term + 0.5 * ce_term
        return blended



In [None]:
# Training State
# Mutable bookkeeping shared by the training loop and the checkpoint helpers:
# progress counters first, then one history list per tracked metric.
training_state = dict(
    epoch=0,
    best_val_acc=0,
    best_epoch=0,
    early_stop_counter=0,
    train_loss_history=[],
    train_acc_history=[],
    val_loss_history=[],
    val_acc_history=[],
)

def save_checkpoint():
    """Write the full training checkpoint and a JSON copy of the metrics.

    Reads the notebook-level ``model``, ``optimizer``, ``scheduler``,
    ``scaler``, ``training_state`` and ``config`` objects. The checkpoint goes
    to ``config['checkpoint_path']`` and the metrics to ``config['metrics_path']``.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        scaler_state_dict=scaler.state_dict(),
        training_state=training_state,
    )
    torch.save(checkpoint, config['checkpoint_path'])

    # Plain-JSON mirror of the bookkeeping, readable without torch.
    with open(config['metrics_path'], 'w') as metrics_file:
        metrics_file.write(json.dumps(training_state))

    saved_epoch = training_state['epoch']
    print(f"Checkpoint saved at epoch {saved_epoch}")

def load_checkpoint():
    """Restore model, optimizer, scheduler, scaler and bookkeeping from disk.

    Reads ``config['checkpoint_path']`` and loads its contents into the
    notebook-level ``model``, ``optimizer``, ``scheduler`` and ``scaler``.
    The global ``training_state`` is rebound to the stored state and then
    updated from ``config['metrics_path']`` if that file exists.

    Returns:
        bool: True if a checkpoint file was found and loaded, False otherwise.
    """
    global training_state

    checkpoint_path = config['checkpoint_path']
    if not os.path.exists(checkpoint_path):
        return False

    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU runtime can be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

    training_state = checkpoint['training_state']

    # Values from the JSON metrics file override those stored in the checkpoint.
    if os.path.exists(config['metrics_path']):
        with open(config['metrics_path'], 'r') as f:
            training_state.update(json.load(f))

    print(f"Resuming training from epoch {training_state['epoch'] + 1}")
    return True

def check_early_stopping():
    """Report whether the no-improvement counter has reached the patience limit.

    When the limit is reached, a message is printed and a checkpoint is saved
    before returning.

    Returns:
        bool: True if training should stop, False otherwise.
    """
    patience = config['early_stop_patience']
    if training_state['early_stop_counter'] < patience:
        return False

    print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
    save_checkpoint()
    return True


In [None]:
# Initialize Model and Training Components
model = EnhancedPathoNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5)
criterion = CombinedLoss(class_weights=class_weights)
scaler = torch.cuda.amp.GradScaler()
# OneCycleLR is stepped once per batch, so it needs the total step count
# (steps_per_epoch * epochs) up front.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=5e-4, steps_per_epoch=len(train_loader), epochs=config['max_epochs'],
    pct_start=0.1, anneal_strategy='cos'
)

# Resume training if requested
if config['resume_training'] and load_checkpoint():
    start_epoch = training_state['epoch'] + 1
else:
    start_epoch = 0

# Training Loop
for epoch in range(start_epoch, config['max_epochs']):
    model.train()
    train_loss = correct = total = 0

    for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["max_epochs"]}'):
        # reshape(-1) flattens the label column to 1-D. Unlike squeeze(), it
        # keeps the batch dimension when a batch holds a single sample.
        inputs, targets = inputs.to(device), targets.reshape(-1).long().to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Mixed-precision step: scale the loss, step through the scaler, then
        # advance the per-batch learning-rate schedule.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Validation
    model.eval()
    val_loss = val_correct = val_total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.reshape(-1).long().to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += targets.size(0)
            val_correct += predicted.eq(targets).sum().item()

    # Calculate metrics (accuracies in percent, losses averaged per batch)
    train_acc = 100 * correct / total
    val_acc = 100 * val_correct / val_total
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    # Update training state
    training_state['epoch'] = epoch
    training_state['train_loss_history'].append(avg_train_loss)
    training_state['train_acc_history'].append(train_acc)
    training_state['val_loss_history'].append(avg_val_loss)
    training_state['val_acc_history'].append(val_acc)

    # Check for best model: a strict improvement resets the early-stop counter.
    if val_acc > training_state['best_val_acc']:
        training_state['best_val_acc'] = val_acc
        training_state['best_epoch'] = epoch
        training_state['early_stop_counter'] = 0
        torch.save(model.state_dict(), config['best_model_path'])
        print(f'New best model saved with val_acc: {val_acc:.2f}% at epoch {epoch+1}')
    else:
        training_state['early_stop_counter'] += 1

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Save checkpoint periodically
    if (epoch + 1) % 5 == 0:
        save_checkpoint()

    # Early stopping check
    if check_early_stopping():
        break

# Final save
save_checkpoint()

In [None]:
# Testing
# Evaluate the weights with the best validation accuracy. map_location keeps
# the load working when the file was written on a different device type.
model.load_state_dict(torch.load(config['best_model_path'], map_location=device))
model.eval()
test_correct = test_total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        # reshape(-1) flattens the label column without dropping the batch
        # dimension of a single-sample batch (squeeze() would).
        inputs, targets = inputs.to(device), targets.reshape(-1).long().to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        test_total += targets.size(0)
        test_correct += predicted.eq(targets).sum().item()

print(f'\nFinal Test Accuracy: {100 * test_correct / test_total:.2f}%')

# Save final model (the best-validation weights loaded at the top of this cell)
torch.save(model.state_dict(), 'pathonet_final.pth')

In [None]:
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

def evaluate_and_visualize(model, dataloader, class_names, save_path="results"):
    """Run the model over a dataloader and save a report and confusion matrix.

    Writes ``classification_report.csv`` and ``confusion_matrix.png`` into
    ``save_path`` (created if missing).

    Args:
        model: Module returning per-class scores of shape ``(batch, classes)``.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        class_names: Display names; entry ``i`` names class index ``i``.
        save_path: Output directory for the two files.
    """
    model.eval()
    all_preds = []
    all_labels = []

    # Run on the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating"):
            # Flatten targets so labels accumulate as scalars, not (1,) arrays.
            inputs, targets = inputs.to(run_device), targets.reshape(-1).to(run_device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    # Fix the label set explicitly so the matrix and the report always have
    # one row per entry of class_names, even if a class never appears.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    report = classification_report(
        all_labels, all_preds, labels=label_ids, target_names=class_names, output_dict=True
    )
    report_df = pd.DataFrame(report).transpose()

    # Save classification report
    os.makedirs(save_path, exist_ok=True)
    report_df.to_csv(os.path.join(save_path, "classification_report.csv"), index=True)

    # Save confusion matrix plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, "confusion_matrix.png"))
    plt.close()

    print("Confusion Matrix and Classification Report saved to:", save_path)

# Example usage
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = EnhancedPathoNet(num_classes=9).to(device)
    
    # map_location remaps the saved tensors onto the device selected above.
    model.load_state_dict(torch.load("pathonet_final.pth", map_location=device))
    
    # info['label'] is indexed by the string form of the class index ('0', '1', ...).
    class_names = [info['label'][str(i)] for i in range(9)]
    
    evaluate_and_visualize(model, test_loader, class_names)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import IPython.display as display

# Show the saved classification report as a table
report_df = pd.read_csv('results/classification_report.csv', index_col=0)
display.display(report_df)

# Show the saved confusion-matrix image without axis decorations
img = Image.open('results/confusion_matrix.png')
plt.figure(figsize=(10, 8))
plt.gca().imshow(img)
plt.gca().axis('off')
plt.show()
