# PathoVision: BreakHis Binary Classification (Benign vs Malignant)
Academic support model for histopathology screening. Not a clinical diagnostic tool.

**Goals**
- Binary classification (Benign/Malignant)
- ResNet50 transfer learning
- Grad-CAM explainability
- Kaggle/Colab ready
- Exportable for backend inference

## 1. Setup

In [None]:
import os
import random
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_curve, auc, classification_report
)

# Reproducibility: feed one shared seed to every RNG used in this notebook
# (Python's `random`, NumPy, and PyTorch on CPU and on all CUDA devices).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# cuDNN: request deterministic algorithms and switch off the benchmark autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when CUDA reports one, otherwise run on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Device: {device}')
device

## 2. Dataset Setup (BreakHis)
Set the dataset path below. Example for Kaggle: `/kaggle/input/breakhis`

In [None]:
DATA_ROOT = '/kaggle/input/breakhis'  # change if needed

# Expected BreakHis path structure (one possible layout):
# /BreaKHis_v1/histology_slides/breast/benign/SOB/.../40X/*.png
# /BreaKHis_v1/histology_slides/breast/malignant/SOB/.../40X/*.png

# Collect every PNG found anywhere below a directory named 'benign' / 'malignant'.
# The results are sorted because glob order depends on the filesystem; without a
# stable order the seeded train/val/test split would not be reproducible.
benign_paths = sorted(glob(os.path.join(DATA_ROOT, '**', 'benign', '**', '*.png'), recursive=True))
malignant_paths = sorted(glob(os.path.join(DATA_ROOT, '**', 'malignant', '**', '*.png'), recursive=True))

print(f'Benign images: {len(benign_paths)}')
print(f'Malignant images: {len(malignant_paths)}')

# Fail early with an actionable message instead of a ZeroDivisionError below
# (or an obscure stratification error later) when DATA_ROOT is wrong.
if not benign_paths or not malignant_paths:
    raise FileNotFoundError(
        f'Expected PNG images for both classes under {DATA_ROOT!r}; '
        f'found {len(benign_paths)} benign and {len(malignant_paths)} malignant. '
        'Check DATA_ROOT.'
    )
print(f'Class imbalance ratio (M/B): {len(malignant_paths)/len(benign_paths):.2f}')

all_paths = benign_paths + malignant_paths
# Label convention used throughout the notebook: 0 = benign, 1 = malignant.
all_labels = [0] * len(benign_paths) + [1] * len(malignant_paths)

# Build dataframe for easy splits
df = pd.DataFrame({'path': all_paths, 'label': all_labels})

# Check for corrupted images: keep only files that can be fully decoded as RGB.
print('\nVerifying image integrity...')
valid_indices = []
for idx, path in df['path'].items():
    try:
        # The context manager closes the file handle; convert() forces a full decode.
        with Image.open(path) as img:
            img.convert('RGB')
    except Exception:
        # Any decoding failure marks the file as unusable. `Exception` (not a bare
        # except) so Ctrl-C / SystemExit still interrupt the scan.
        print(f'Corrupted or invalid image: {path}')
        continue
    valid_indices.append(idx)

df = df.loc[valid_indices].reset_index(drop=True)
print(f'Valid images: {len(df)} (removed {len(all_labels) - len(df)} corrupted)')
print(f'Final class distribution:\n{df["label"].value_counts().sort_index()}')
df.head()

## 3. Train/Val/Test Split (70/15/15)

In [None]:
# Stage 1: keep 70% for training and hold out 30%, stratified on the label.
# NOTE(review): BreakHis presumably contains several images per patient; an
# image-level split could place the same patient in train and test. Confirm
# whether a patient-grouped split is required for the reported metrics.
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df['label'],
    random_state=SEED,
)
# Stage 2: cut the hold-out in half -> validation and test (15% each overall).
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df['label'],
    random_state=SEED,
)

print('Train:', len(train_df), 'Val:', len(val_df), 'Test:', len(test_df))
# Per-split class counts, shown as the cell output.
tuple(part['label'].value_counts() for part in (train_df, val_df, test_df))

## 4. Transforms and Dataset

In [None]:
IMG_SIZE = 224

# Values reused by the training and evaluation pipelines.
_TARGET_SIZE = (IMG_SIZE, IMG_SIZE)
_NORM_MEAN = [0.485, 0.456, 0.406]  # ImageNet channel means
_NORM_STD = [0.229, 0.224, 0.225]   # ImageNet channel standard deviations

# Training augmentation: geometric and intensity variation only. Random noise is
# left out on purpose so it cannot be confused with histology artifacts.
_geometric_steps = [
    T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.85, 1.15), shear=5),
    T.RandomRotation(degrees=20),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
]
_intensity_steps = [
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.15, hue=0.05),
]
# Elastic deformation is deliberately skipped; a mild perspective warp is used instead.
_perspective_steps = [
    T.RandomPerspective(distortion_scale=0.2, p=0.3),
]
_to_normalized_tensor = [
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]

train_tfms = T.Compose(
    [T.Resize(_TARGET_SIZE)]
    + _geometric_steps
    + _intensity_steps
    + _perspective_steps
    + _to_normalized_tensor
)

# Evaluation pipeline: deterministic resize + tensor conversion + normalization.
val_tfms = T.Compose([
    T.Resize(_TARGET_SIZE),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Mixup augmentation implementation
def mixup(x1, x2, y1, y2, alpha=0.2):
    """Blend two samples using a coefficient drawn from Beta(alpha, alpha).

    Args:
        x1, x2: Inputs to blend; they must support scalar multiplication and
            addition with each other.
        y1, y2: Labels belonging to ``x1`` and ``x2``; passed through unchanged.
        alpha: Both shape parameters of the Beta distribution.

    Returns:
        ``(mixed_x, (lam, 1 - lam, y1, y2))`` where ``mixed_x`` equals
        ``lam * x1 + (1 - lam) * x2``.
    """
    weight_first = np.random.beta(alpha, alpha)
    weight_second = 1 - weight_first
    blended = weight_first * x1 + weight_second * x2
    return blended, (weight_first, weight_second, y1, y2)

class BreakHisDataset(Dataset):
    """Image dataset backed by a dataframe with ``path`` and ``label`` columns.

    Each item is ``(image, label)``: the image is opened from ``path``, converted
    to RGB and passed through ``transforms`` when given; ``label`` is an ``int``.

    Args:
        df: Dataframe with a ``path`` and a ``label`` column. Its index is reset.
        transforms: Optional callable applied to the PIL image.
        use_mixup: When true, a sample is blended with a randomly chosen partner
            from this dataset with probability ``mixup_prob``. Blending uses
            ``mixup`` and therefore needs ``transforms`` that return tensors.
        mixup_prob: Probability of blending a sample (default 0.3).
    """

    def __init__(self, df, transforms=None, use_mixup=False, mixup_prob=0.3):
        self.df = df.reset_index(drop=True)
        self.transforms = transforms
        self.use_mixup = use_mixup
        self.mixup_prob = mixup_prob

    def __len__(self):
        return len(self.df)

    def _load(self, idx):
        """Return the (transformed image, int label) pair stored at row ``idx``."""
        row = self.df.iloc[idx]
        # Close the file handle right after decoding; convert() returns a new image.
        with Image.open(row['path']) as handle:
            img = handle.convert('RGB')
        if self.transforms:
            img = self.transforms(img)
        return img, int(row['label'])

    def __getitem__(self, idx):
        img, label = self._load(idx)

        # Mixup: occasionally blend with another sample drawn from the whole dataset.
        if self.use_mixup and np.random.rand() < self.mixup_prob:
            partner_idx = np.random.randint(0, len(self.df))
            img2, label2 = self._load(partner_idx)
            img, (lam1, lam2, y1, y2) = mixup(img, img2, label, label2)
            # The training loss expects a single hard label, so use the label of
            # the component that dominates the blend. Always returning the first
            # sample's label mislabels every blend in which the partner dominates.
            if lam2 > lam1:
                label = y2

        return img, label

# One dataset per split: augmentation and mixup for training only; validation and
# test share the deterministic evaluation transforms.
train_ds = BreakHisDataset(train_df, transforms=train_tfms, use_mixup=True)
val_ds, test_ds = (
    BreakHisDataset(split, transforms=val_tfms, use_mixup=False)
    for split in (val_df, test_df)
)

# Inverse-frequency class weights, rescaled so the two weights sum to 2.
class_counts = train_df['label'].value_counts().sort_index().values
_inverse_freq = torch.Tensor(1.0 / class_counts)
class_weights = _inverse_freq / _inverse_freq.sum() * 2

# Every training row is drawn with the weight of its class (with replacement).
# NOTE(review): these weights drive this sampler and are also passed to the loss
# in the model cell, so class imbalance is compensated twice - confirm intended.
sample_weights = [class_weights[label] for label in train_df['label']]
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

BATCH_SIZE = 32
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, sampler=sampler, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

print(f'Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | Test samples: {len(test_ds)}')

## 5. Model (ResNet50 Transfer Learning)

In [None]:
from torchvision.models import ResNet50_Weights

model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Partial fine-tuning strategy:
# 1. Freeze every parameter first
for param in model.parameters():
    param.requires_grad = False

# 2. Unfreeze layer4 (most specific to task)
for param in model.layer4.parameters():
    param.requires_grad = True

# 3. Also unfreeze the last block of layer3
for param in model.layer3[-1].parameters():
    param.requires_grad = True

# 4. Unfreeze BN layers even in frozen layers (helps adapt to new domain)
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        # requires_grad has to be set on the parameters themselves; assigning
        # `module.requires_grad = True` only creates an unused attribute on the
        # module and leaves the BN weight/bias frozen.
        for param in module.parameters():
            param.requires_grad = True
        # Lower the BN momentum (PyTorch default: 0.1) so running statistics
        # change more slowly while fine-tuning.
        module.momentum = 0.01

# 5. Replace the final linear layer with a small two-class MLP head
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(num_ftrs, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.3),
    nn.Linear(512, 2)
)

model = model.to(device)

# Compute total and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Training ratio: {100 * trainable_params / total_params:.1f}%')

# Loss function with class weights for imbalanced data
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device), label_smoothing=0.1)

# Optimizer: a single learning rate over every trainable parameter. The earlier
# per-layer `param_groups` draft was never handed to the optimizer, listed some
# parameters in more than one group, and its catch-all group was always empty,
# so it has been removed rather than left as misleading dead code.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-3, weight_decay=1e-5, betas=(0.9, 0.999))

print('Model ready for training')

## 6. Training Loop with Early Stopping

In [None]:
class EarlyStoppingAUC:
    """Early stopping driven by validation AUC (higher is better).

    Call the instance once per epoch. An epoch counts as an improvement when its
    AUC exceeds the best AUC so far by more than ``min_delta``; a CPU copy of the
    model's ``state_dict`` is kept for the best epoch. After ``patience``
    consecutive non-improving calls ``early_stop`` becomes ``True``.

    Attributes:
        best_auc: Best validation AUC seen so far (starts at 0).
        best_epoch: Epoch of the best AUC; 0 until the first improvement.
        counter: Consecutive calls without improvement.
        early_stop: Set to ``True`` once patience is exhausted.
        best_model_state: CPU copy of the best ``state_dict`` or ``None``.
    """

    def __init__(self, patience=7, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_auc = 0
        self.best_epoch = 0
        self.early_stop = False
        self.best_model_state = None
        self._num_calls = 0  # used as the 1-based epoch when none is supplied

    def __call__(self, val_auc, model, epoch=None):
        """Record one epoch's validation AUC.

        Args:
            val_auc: Validation AUC of the epoch that just finished.
            model: Model whose weights are snapshotted on improvement.
            epoch: Optional epoch number to store in ``best_epoch``; when omitted
                the 1-based count of calls made so far is used.
        """
        self._num_calls += 1
        if val_auc > self.best_auc + self.min_delta:
            self.best_auc = val_auc
            # Previously declared but never updated.
            self.best_epoch = self._num_calls if epoch is None else epoch
            self.counter = 0
            # Clone onto the CPU so later training steps cannot alter the snapshot.
            self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def load_best_model(self, model):
        """Load the best snapshot into ``model``; a no-op if none was recorded."""
        if self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)

def train_one_epoch(model, loader, optimizer, criterion, scheduler=None):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking ``(outputs, labels)``.
        scheduler: Optional LR scheduler, stepped once at the end of the epoch.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1)``: sample-weighted mean loss,
        accuracy, and F1 of the hard predictions over the epoch.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    for images, labels in tqdm(loader, desc='Training', leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip AFTER backward(): gradients only exist once backward() has run, so
        # clipping before it (right after zero_grad) never limited the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # loss is a batch mean; weight it by batch size for a per-sample average.
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    if scheduler is not None:
        scheduler.step()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    epoch_f1 = f1_score(all_labels, all_preds, zero_division=0)

    return epoch_loss, epoch_acc, epoch_f1

def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1, epoch_auc)``. The AUC is computed from
        the softmax probability of class index 1.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_log, prob_log, label_log = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating', leave=False):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Batch-mean loss weighted by batch size -> per-sample average later.
            loss_sum += batch_loss.item() * images.size(0)
            class_probs = torch.softmax(logits, dim=1)
            hard_preds = torch.argmax(class_probs, dim=1)

            n_correct += (hard_preds == labels).sum().item()
            n_seen += labels.size(0)

            pred_log.extend(hard_preds.cpu().numpy())
            prob_log.extend(class_probs[:, 1].cpu().numpy())
            label_log.extend(labels.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    f1 = f1_score(label_log, pred_log, zero_division=0)

    # Area under the ROC curve built from the class-1 probabilities.
    fpr, tpr, _ = roc_curve(label_log, prob_log)
    roc_area = auc(fpr, tpr)

    return mean_loss, accuracy, f1, roc_area

# Learning rate scheduler: cosine annealing with warm restarts (first cycle of
# 5 epochs, each following cycle twice as long).
EPOCHS = 30
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)

# Early stopping based on AUC
early_stopping = EarlyStoppingAUC(patience=8, min_delta=0.002)

# Per-epoch metric history, consumed by the plotting and export cells.
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': [],
    'train_f1': [], 'val_f1': [],
    'val_auc': []
}

print('Starting training...')
for epoch in range(EPOCHS):
    train_loss, train_acc, train_f1 = train_one_epoch(model, train_loader, optimizer, criterion, scheduler)
    val_loss, val_acc, val_f1, val_auc = evaluate(model, val_loader, criterion)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)
    history['val_auc'].append(val_auc)

    print(f'Epoch {epoch+1:2d}/{EPOCHS} | '
          f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | '
          f'Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | '
          f'Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}')

    early_stopping(val_auc, model)
    if early_stopping.early_stop:
        print(f'Early stopping triggered at epoch {epoch+1}')
        break

# Restore the best checkpoint whether or not early stopping fired. Previously a
# run that used all EPOCHS kept the last-epoch weights, so the test evaluation
# and export did not match the "Best Val AUC" reported below.
early_stopping.load_best_model(model)

print(f'Training complete! Best Val AUC: {early_stopping.best_auc:.4f}')

## 7. Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Three panels share one layout: a train curve and a validation curve per metric.
# Each entry: (axes, history key suffix, short legend name, long axis/title name).
_paired_panels = (
    (axes[0, 0], 'loss', 'Loss', 'Loss'),
    (axes[0, 1], 'acc', 'Acc', 'Accuracy'),
    (axes[1, 0], 'f1', 'F1', 'F1 Score'),
)
for _ax, _key, _short, _long in _paired_panels:
    _ax.plot(history[f'train_{_key}'], label=f'Train {_short}', marker='o', markersize=3)
    _ax.plot(history[f'val_{_key}'], label=f'Val {_short}', marker='s', markersize=3)
    _ax.set_title(f'{_long} vs Epoch', fontsize=12, fontweight='bold')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_long)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

# Fourth panel: validation AUC only, with a dashed reference line at 0.95.
_auc_ax = axes[1, 1]
_auc_ax.plot(history['val_auc'], label='Val AUC', marker='o', color='green', markersize=3)
_auc_ax.axhline(y=0.95, color='r', linestyle='--', label='Target (0.95)')
_auc_ax.set_title('Validation AUC vs Epoch', fontsize=12, fontweight='bold')
_auc_ax.set_xlabel('Epoch')
_auc_ax.set_ylabel('AUC')
_auc_ax.legend()
_auc_ax.grid(True, alpha=0.3)
_auc_ax.set_ylim([0.5, 1.0])

plt.tight_layout()
plt.show()

# Text summary: best AUC over all epochs, then the last recorded value of each metric.
_rule = '='*60
print('\n' + _rule)
print('TRAINING SUMMARY')
print(_rule)
print(f'Best Validation AUC: {max(history["val_auc"]):.4f}')
for _caption, _hist_key in (
    ('Final Train Accuracy', 'train_acc'),
    ('Final Val Accuracy', 'val_acc'),
    ('Final Train F1', 'train_f1'),
    ('Final Val F1', 'val_f1'),
):
    print(f'{_caption}: {history[_hist_key][-1]:.4f}')

## 8. Evaluation on Test Set

In [None]:
print('Evaluating on test set with detailed metrics...')
model.eval()
all_labels, all_preds, all_probs = [], [], []

# Forward passes only; gradients are not needed to score the held-out split.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Test Evaluation'):
        images = images.to(device)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)

        # Labels were never moved to the device, so they convert directly.
        all_labels.extend(labels.numpy())
        all_preds.extend(preds.cpu().numpy())
        # Column 1 holds the probability of the malignant class.
        all_probs.extend(probs[:, 1].cpu().numpy())

# Switch to arrays so the element-wise comparisons below work.
all_labels, all_preds, all_probs = (
    np.array(seq) for seq in (all_labels, all_preds, all_probs)
)

# Threshold metrics, treating class 1 as the positive class.
acc = accuracy_score(all_labels, all_preds)
prec = precision_score(all_labels, all_preds, zero_division=0)
rec = recall_score(all_labels, all_preds, zero_division=0)
f1 = f1_score(all_labels, all_preds, zero_division=0)

# ROC-AUC from the class-1 probabilities
fpr, tpr, _ = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)

# Confusion-matrix cells counted with boolean masks
_pred_neg, _pred_pos = all_preds == 0, all_preds == 1
_true_neg, _true_pos = all_labels == 0, all_labels == 1
tn = (_pred_neg & _true_neg).sum()
fp = (_pred_pos & _true_neg).sum()
fn = (_pred_neg & _true_pos).sum()
tp = (_pred_pos & _true_pos).sum()

# Guard against a split that lacks one of the classes.
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

_rule = '='*60
print('\n' + _rule)
print('TEST SET EVALUATION RESULTS')
print(_rule)
print(f'Accuracy:       {acc:.4f} (Overall correctness)')
print(f'Precision:      {prec:.4f} (Of predicted positive, how many correct)')
print(f'Recall/Sensitivity: {rec:.4f} (Of actual positive, how many detected)')
print(f'Specificity:    {specificity:.4f} (Of actual negative, how many detected)')
print(f'F1 Score:       {f1:.4f} (Harmonic mean of precision & recall)')
print(f'ROC-AUC:        {roc_auc:.4f} (Probability of correct ranking)')
print(_rule)

# Per-class precision / recall / F1 table
print('\nDETAILED CLASSIFICATION REPORT:')
print(classification_report(all_labels, all_preds, target_names=['Benign', 'Malignant'], zero_division=0))

# Confusion matrix heatmap (rows: true label, columns: predicted label)
_tick_names = ['Benign', 'Malignant']
cm = confusion_matrix(all_labels, all_preds)
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=_tick_names,
    yticklabels=_tick_names,
    cbar_kws={'label': 'Count'},
    ax=ax,
)
ax.set_title('Confusion Matrix - Test Set', fontsize=12, fontweight='bold')
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
plt.tight_layout()
plt.show()

# ROC curve; the shaded region is the area under the curve (not a confidence band).
fig, ax = plt.subplots(1, 1, figsize=(7, 6))
ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.4f})', linewidth=2, color='#1f77b4')
ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')
ax.fill_between(fpr, tpr, alpha=0.2)
ax.set_xlabel('False Positive Rate (1 - Specificity)', fontsize=11)
ax.set_ylabel('True Positive Rate (Sensitivity)', fontsize=11)
ax.set_title('ROC Curve - Test Set', fontsize=12, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print('\n✓ Test evaluation complete!')

## 9. Grad-CAM Explainability

In [None]:
import cv2

class GradCAM:
    """Grad-CAM heatmaps for a single layer of a classification model.

    Hooks on ``target_layer`` capture its forward output (activations) and the
    gradient of the class score with respect to that output. ``generate``
    combines them into a class-activation map normalized to [0, 1].

    Args:
        model: Model producing ``(batch, num_classes)`` scores.
        target_layer: Sub-module of ``model`` whose output is visualized; the
            code assumes that output is ``(batch, channels, height, width)``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward/backward hooks and keep their handles for removal."""
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
        self._hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach the hooks from ``target_layer``; safe to call more than once."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_tensor, class_idx=None):
        """Compute the Grad-CAM map for one image.

        Args:
            input_tensor: Batch holding a single image; only element 0 is used.
            class_idx: Class to explain. Defaults to the predicted (argmax) class.

        Returns:
            ``(cam, class_idx)`` where ``cam`` is a 2-D NumPy array in [0, 1] with
            the spatial size of the target layer's output.

        Raises:
            RuntimeError: If the hooks did not capture activations and gradients
                (for example after ``remove_hooks``).
        """
        self.model.eval()
        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)

            if class_idx is None:
                class_idx = output.argmax(dim=1).item()

            self.model.zero_grad()

            # Back-propagate the score of the chosen class only (equivalent to a
            # one-hot mask on the output, without needing a device-specific tensor).
            score = output[0, class_idx]
            score.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError('Grad-CAM hooks captured no activations/gradients')

        gradients = self.gradients[0]
        activations = self.activations[0]

        # Channel weights: global average of the gradients over the spatial dims.
        weights = gradients.mean(dim=(1, 2))

        # Weighted sum of the activation maps over the channel dimension.
        cam = (weights[:, None, None] * activations).sum(dim=0)

        # Keep positive evidence only, then rescale to [0, 1].
        cam = torch.relu(cam)
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            cam = cam / cam_max

        cam = cam.detach().cpu().numpy()
        return cam, class_idx

def overlay_heatmap(image_path, cam, output_path='heatmap.png', alpha=0.5):
    """Blend a class-activation map over an image and save the result.

    Args:
        image_path: Path of the image to read with OpenCV.
        cam: 2-D array with values in [0, 1]; it is resized to the image size.
        output_path: Destination file. Missing parent directories are created.
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        ``output_path`` on success, or ``None`` when the image cannot be read or
        the overlay cannot be written.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f'Failed to load image: {image_path}')
        return None

    # OpenCV reads BGR; work in RGB so image and heatmap share a channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize expects (width, height).
    cam_resized = cv2.resize(cam, (img.shape[1], img.shape[0]))

    # Apply colormap
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Blend original and heatmap
    overlay = cv2.addWeighted(img, 1-alpha, heatmap, alpha, 0)

    # Create the parent directory only when there is one: for a bare file name
    # (such as the default 'heatmap.png') dirname is '' and os.makedirs('') raises.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(output_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
        print(f'Failed to write heatmap: {output_path}')
        return None

    return output_path

print('Initializing GradCAM for model layer4[-1]...')
gradcam = GradCAM(model, model.layer4[-1])

# Index -> label mapping (0 = benign, 1 = malignant, as assigned when the
# dataframe was built). Defined here because this cell runs before the inference
# cell that also defines it; without it the loop below raises NameError on a
# top-to-bottom run.
class_names = {0: 'Benign', 1: 'Malignant'}

# Generate heatmaps for sample test images
print('\nGenerating Grad-CAM heatmaps for test samples...')
os.makedirs('heatmaps', exist_ok=True)

# Up to five distinct test images, chosen at random.
sample_indices = np.random.choice(len(test_df), min(5, len(test_df)), replace=False)
for i, idx in enumerate(sample_indices):
    sample_path = test_df.iloc[idx]['path']

    # Same preprocessing as validation/test so the map matches real predictions.
    sample_img = Image.open(sample_path).convert('RGB')
    input_tensor = val_tfms(sample_img).unsqueeze(0).to(device)

    # class_idx is omitted, so the map explains the predicted class.
    cam, pred_class = gradcam.generate(input_tensor)

    class_name = class_names[pred_class]
    heatmap_name = f'heatmap_sample_{i+1}_{class_name}.png'
    heatmap_path = os.path.join('heatmaps', heatmap_name)

    overlay_heatmap(sample_path, cam, output_path=heatmap_path, alpha=0.5)
    print(f'Saved: {heatmap_path} (Predicted: {class_name})')

print('✓ GradCAM heatmaps generated successfully')

## 10. Inference Helper

In [None]:
# Maps a predicted class index to its display name (0 = benign, 1 = malignant,
# matching the labels assigned when the dataframe was built).
class_names = {0: 'Benign', 1: 'Malignant'}

# Test-Time Augmentation (TTA) for more robust predictions
def apply_tta_transforms(image, num_augmentations=5):
    """Build ``num_augmentations`` preprocessed views of ``image`` for TTA.

    View 0 is the plain resize + normalize pipeline. Later views add random
    augmentations step by step: a small affine jitter from view 1, a random
    horizontal flip from view 2, and a small random rotation from view 3 onwards.

    Args:
        image: PIL image to augment.
        num_augmentations: Number of views to produce.

    Returns:
        List of normalized image tensors, one per view (empty when
        ``num_augmentations`` is 0 or negative).
    """
    tta_list = []

    for view_idx in range(num_augmentations):
        # Assemble only the steps this view needs, instead of padding the
        # pipeline with a second Resize or empty Compose placeholders.
        steps = [T.Resize((IMG_SIZE, IMG_SIZE))]
        if view_idx > 0:
            steps.append(T.RandomAffine(degrees=5, translate=(0.05, 0.05)))
        if view_idx > 1:
            steps.append(T.RandomHorizontalFlip(p=0.5))
        if view_idx > 2:
            steps.append(T.RandomRotation(degrees=5))
        steps.append(T.ToTensor())
        steps.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

        tta_list.append(T.Compose(steps)(image))

    return tta_list

def predict_image_tta(image_path, model, use_tta=True, num_tta=5):
    """Classify one image file, optionally averaging over augmented views.

    Args:
        image_path: Path of the image to classify.
        model: Trained classifier; switched to eval mode.
        use_tta: When true, average the softmax outputs over ``num_tta`` views
            from ``apply_tta_transforms``; otherwise use ``val_tfms`` once.
        num_tta: Number of augmented views when ``use_tta`` is true.

    Returns:
        Dict with keys ``prediction`` (class name), ``confidence`` (probability
        of the predicted class), ``probabilities`` (``benign`` / ``malignant``)
        and ``tta_used``.
    """
    model.eval()
    image = Image.open(image_path).convert('RGB')

    def _class_probs(view):
        # Add the batch dimension, run the model, return a 1-D probability array.
        batch = view.unsqueeze(0).to(device)
        logits = model(batch)
        return torch.softmax(logits, dim=1).cpu().numpy()[0]

    with torch.no_grad():
        if not use_tta:
            avg_probs = _class_probs(val_tfms(image))
        else:
            views = apply_tta_transforms(image, num_tta)
            per_view_probs = [_class_probs(view) for view in views]
            # Average predictions across augmentations
            avg_probs = np.mean(per_view_probs, axis=0)

    pred_idx = int(np.argmax(avg_probs))
    benign_p, malignant_p = float(avg_probs[0]), float(avg_probs[1])

    result = {}
    result['prediction'] = class_names[pred_idx]
    result['confidence'] = float(avg_probs[pred_idx])
    result['probabilities'] = {'benign': benign_p, 'malignant': malignant_p}
    result['tta_used'] = use_tta
    return result

# Demo: run TTA inference on a few random test images and compare with the truth.
print('Example Predictions (with Test-Time Augmentation):')
print('='*60)

# Up to three distinct rows of the test split.
sample_indices = np.random.choice(len(test_df), min(3, len(test_df)), replace=False)

for idx in sample_indices:
    record = test_df.iloc[idx]
    sample_path, true_label = record['path'], record['label']

    result = predict_image_tta(sample_path, model, use_tta=True, num_tta=5)
    probabilities = result['probabilities']

    if true_label == 0:
        true_name = 'Benign'
    else:
        true_name = 'Malignant'

    print(f'Image: {os.path.basename(sample_path)}')
    print(f'True Label: {true_name}')
    print(f'Prediction: {result["prediction"]}')
    print(f'Confidence: {result["confidence"]:.4f}')
    print(f'Probabilities: Benign={probabilities["benign"]:.4f}, Malignant={probabilities["malignant"]:.4f}')
    print('-'*60)

## 11. Export Model

In [None]:
os.makedirs('models', exist_ok=True)

# Metadata is stored as plain Python numbers. The validation AUC may be a NumPy
# scalar, and a pickled NumPy scalar can be rejected by
# `torch.load(..., weights_only=True)`, which would make the checkpoint
# unreadable for the inference server.
best_val_auc = float(max(history['val_auc'])) if history['val_auc'] else 0.0
final_val_acc = float(history['val_acc'][-1]) if history['val_acc'] else 0.0

# Save model with metadata
model_info = {
    'model_state': model.state_dict(),
    'epoch': len(history['train_loss']),     # number of epochs actually run
    'best_auc': best_val_auc,
    'final_accuracy': final_val_acc,         # validation accuracy of the last epoch
    'class_names': class_names,
    'image_size': IMG_SIZE,
    'normalization': {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    }
}

torch.save(model_info, 'models/pathovision_resnet50_v2.pt')
print('✓ Model saved to models/pathovision_resnet50_v2.pt')

# Also save the legacy format for compatibility
torch.save(model.state_dict(), 'models/pathovision_resnet50_state_dict.pt')
print('✓ State dict saved to models/pathovision_resnet50_state_dict.pt')

print('\nModel Information:')
print('  - Architecture: ResNet50 (ImageNet pretrained)')
print(f'  - Training Epochs: {len(history["train_loss"])}')
# Reuses the guarded value so an empty history cannot raise here either.
print(f'  - Best Validation AUC: {best_val_auc:.4f}')
print(f'  - Final Test Accuracy: {acc:.4f}')
print(f'  - Test AUC-ROC: {roc_auc:.4f}')

## 12. Simple FastAPI Inference Example

In [None]:
# FastAPI inference server - Save as inference_server.py
# The source below is written verbatim to inference_server.py. In /predict,
# deliberate HTTP errors (503 model not loaded, 400 undecodable upload) are
# raised outside the generic handler so they are no longer reported as 500.
fastapi_code = r'''
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import uvicorn
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
import io
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="PathoVision Inference Server", version="2.0")

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = 'models/pathovision_resnet50_v2.pt'
IMG_SIZE = 224

# Global model variable
model = None

# Normalization transforms
transforms = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.on_event("startup")
async def load_model():
    """Load model on startup"""
    global model
    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        
        # Check if it's a full checkpoint with metadata or just state_dict
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            model_state = checkpoint['model_state']
            metadata = {k: v for k, v in checkpoint.items() if k != 'model_state'}
            logger.info(f"Model metadata: {metadata}")
        else:
            model_state = checkpoint
            logger.warning("No metadata found in checkpoint")
        
        # Build model (assuming ResNet50)
        from torchvision import models
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        
        # Recreate the custom head
        import torch.nn as nn
        model.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(num_ftrs, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2)
        )
        
        model.load_state_dict(model_state)
        model.to(DEVICE)
        model.eval()
        
        logger.info(f"Model loaded successfully on device: {DEVICE}")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict histopathology image classification
    
    Args:
        file: Image file (PNG, JPG, etc.)
    
    Returns:
        JSON with prediction, confidence, probabilities, and diagnosis
    
    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the upload
            cannot be decoded as an image.
    """
    # Raised outside the generic handler below so FastAPI returns the intended
    # status code instead of a 500.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    # Read image; an undecodable upload is a client error, not a server error
    image_data = await file.read()
    try:
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
    except Exception as e:
        logger.warning(f"Rejected upload: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") from e
    
    try:
        # Preprocess
        input_tensor = transforms(image).unsqueeze(0).to(DEVICE)
        
        # Inference with no gradient
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        
        # Determine diagnosis with confidence threshold
        if confidence < 0.60:
            diagnosis = "INCONCLUSIVE - Requires specialist review"
        elif pred_idx == 0:
            diagnosis = "BENIGN - Low malignancy risk"
        else:
            diagnosis = "MALIGNANT - High suspicion, urgent review recommended"
        
        return {
            "status": "success",
            "prediction": "Benign" if pred_idx == 0 else "Malignant",
            "confidence": confidence,
            "probabilities": {
                "benign": float(probs[0]),
                "malignant": float(probs[1])
            },
            "clinical_diagnosis": diagnosis,
            "timestamp": str(np.datetime64('now')),
            "model_version": "2.0"
        }
    
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )

@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(DEVICE)
    }

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8000, log_level='info')
'''

# Write the server source to disk so it can be started as a standalone script.
with open('inference_server.py', 'w') as server_file:
    server_file.write(fastapi_code)

print('✓ FastAPI inference server saved to: inference_server.py')
print('\nTo run the server:')
# Numbered launch instructions, printed one per line.
for _step in (
    '  1. pip install fastapi uvicorn pillow torch torchvision',
    '  2. python inference_server.py',
    '  3. API will be available at http://localhost:8000',
    '  4. Swagger docs: http://localhost:8000/docs',
):
    print(_step)