# Ki-67 Scoring System - Google Colab Training

This notebook implements a comprehensive Ki-67 scoring system using ensemble deep learning models for breast cancer diagnosis.

## Features:
- **Three Model Ensemble**: InceptionV3, ResNet-50, and Vision Transformer (a small CNN is used instead if the ViT cannot be created, e.g. when `timm` is unavailable)
- **Google Drive Integration**: Automatic dataset loading and model saving
- **Robust Training**: Error handling and early stopping
- **Comprehensive Evaluation**: Multiple metrics and visualizations

## Setup Instructions:
1. Upload your dataset ZIP to: `MyDrive/Ki67_Dataset/Ki67_Dataset_for_Colab.zip`
2. Run cells sequentially
3. Model checkpoints and ensemble weights are saved to your MyDrive root folder; training histories are saved to `MyDrive/Ki67_Results`

In [None]:
# Install Required Packages
import subprocess
import sys
import os

def install_package(package):
    """Install *package* with pip into the interpreter running this notebook.

    Args:
        package: A pip requirement string, e.g. ``"timm"``.

    Returns:
        True if pip exited successfully, False otherwise. Failures are reported,
        not raised, so the installation cell can continue with other packages.
    """
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
    except subprocess.CalledProcessError as err:
        # pip exits 0 when a requirement is already satisfied, so a non-zero exit
        # is a genuine installation failure rather than "already installed".
        print(f"⚠️  {package} - installation failed (pip exit code {err.returncode})")
        return False
    print(f"✅ {package}")
    return True

# Everything the later cells import; install_package prints its own per-package status.
packages = [
    "torch", "torchvision",
    "scikit-learn",
    "matplotlib", "seaborn",
    "pandas", "numpy",
    "Pillow",
    "timm",
]
print("📦 Installing required packages...")

# Best effort: a failed install is reported but does not stop the loop.
for package in packages:
    install_package(package)

print("\n✅ Package installation completed!")

In [None]:
# Import Libraries and Setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import json
import pickle
from datetime import datetime
import zipfile
import shutil
import warnings
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.metrics import precision_score, recall_score, f1_score

# timm supplies the ViT backbone; when it is missing the model cell builds a small CNN instead.
try:
    import timm
except ImportError:
    print("⚠️  timm not available, will use fallback CNN")
    timm = None
else:
    print("✅ timm imported successfully")

warnings.filterwarnings('ignore')

# Pick the compute device once; later cells move models and batches onto it.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1024**3:.1f} GB")
    # Let cuDNN pick the fastest kernels for the fixed input sizes used here.
    torch.backends.cudnn.benchmark = True

print("✅ All libraries imported successfully!")

In [None]:
# Mount Google Drive and Setup Paths
from google.colab import drive

# Attach Drive under /content/drive; the paths below all live inside it.
drive.mount('/content/drive')

# Checkpoints go to the MyDrive root, auxiliary results into a dedicated subfolder.
MODELS_SAVE_PATH = "/content/drive/MyDrive"  # Models saved to MyDrive root
RESULTS_SAVE_PATH = "/content/drive/MyDrive/Ki67_Results"
os.makedirs(RESULTS_SAVE_PATH, exist_ok=True)

print(f"📁 Models will be saved to: {MODELS_SAVE_PATH}")
print(f"📁 Results will be saved to: {RESULTS_SAVE_PATH}")

# Show what the user uploaded so a misplaced ZIP is easy to spot.
dataset_folder = "/content/drive/MyDrive/Ki67_Dataset"
if not os.path.exists(dataset_folder):
    print("❌ Ki67_Dataset folder not found!")
else:
    print("\n📋 Contents of Ki67_Dataset folder:")
    for entry in os.listdir(dataset_folder):
        print(f"  - {entry}")

In [None]:
# Extract Dataset from Google Drive
DATASET_ZIP_PATH = "/content/drive/MyDrive/Ki67_Dataset/Ki67_Dataset_for_Colab.zip"

if not os.path.exists(DATASET_ZIP_PATH):
    print("❌ Dataset ZIP file not found!")
    print("Available files in Ki67_Dataset:")
    if os.path.exists("/content/drive/MyDrive/Ki67_Dataset"):
        for item in os.listdir("/content/drive/MyDrive/Ki67_Dataset"):
            print(f"  - {item}")

    # Point at the (empty) extraction folder so later cells report "no data" instead of crashing.
    # Edit this path manually if the ZIP file has a different name.
    DATASET_PATH = "/content/ki67_dataset"
else:
    print(f"✅ Found dataset at: {DATASET_ZIP_PATH}")

    # Unpack onto the Colab VM's local disk.
    EXTRACT_PATH = "/content/ki67_dataset"
    os.makedirs(EXTRACT_PATH, exist_ok=True)

    print("Extracting dataset...")
    with zipfile.ZipFile(DATASET_ZIP_PATH, 'r') as archive:
        archive.extractall(EXTRACT_PATH)
    print("✅ Dataset extracted successfully!")

    # Print a shallow tree: directories up to depth 2, and up to 3 file names for depths 0-1.
    print("\nExtracted contents:")
    for root, dirs, files in os.walk(EXTRACT_PATH):
        depth = root.replace(EXTRACT_PATH, '').count(os.sep)
        if depth >= 3:
            continue
        print(f"{'  ' * depth}{os.path.basename(root)}/")
        if depth >= 2:
            continue
        file_indent = '  ' * (depth + 1)
        for fname in files[:3]:
            print(f"{file_indent}{fname}")
        if len(files) > 3:
            print(f"{file_indent}... and {len(files)-3} more files")

    DATASET_PATH = EXTRACT_PATH

In [None]:
# Define Custom Dataset Class
class Ki67Dataset(Dataset):
    """Binary Ki-67 image dataset discovered from one of several on-disk layouts.

    Args:
        dataset_path: Root directory searched for the supported layouts.
        split: ``'train'``, ``'validation'`` or ``'test'``. For the h5 layouts it names
            the sub-directory to read. For the JSON layout, ``'train'`` and
            ``'validation'`` select the 70% / 15% portions and any other value selects
            the remainder.
        transform: Optional callable applied to each PIL image in ``__getitem__``.

    After construction ``self.images`` holds image paths (str) and ``self.labels`` the
    matching 0/1 ints, in the same order.
    """

    def __init__(self, dataset_path, split='train', transform=None):
        self.dataset_path = Path(dataset_path)
        self.split = split
        self.transform = transform

        # Load data
        self.load_from_directory()

    def load_from_directory(self):
        """Fill ``images``/``labels`` from the first layout whose image folder exists.

        Layouts are tried in this order:
          1. ``BCData/images/<split>`` with ``BCData/annotations/<split>/{positive,negative}``
          2. ``images/<split>`` with ``annotations/<split>/{positive,negative}``
          3. ``data/test256`` holding ``.jpg`` images with sibling ``.json`` annotations
        If none exists, the available top-level directories are printed and the dataset
        stays empty.
        """
        print(f"Loading {self.split} data from directory structure...")
        self.images = []
        self.labels = []

        # Layouts 1 and 2 share the same shape, rooted at BCData/ or at the dataset root.
        for root in (self.dataset_path / "BCData", self.dataset_path):
            images_dir = root / "images" / self.split
            if images_dir.exists():
                self._load_from_h5_structure(
                    images_dir,
                    root / "annotations" / self.split / "positive",
                    root / "annotations" / self.split / "negative",
                )
                return

        # Layout 3: one flat folder; the split is carved out in _load_from_json_structure.
        json_images_dir = self.dataset_path / "data" / "test256"
        if json_images_dir.exists():
            self._load_from_json_structure(json_images_dir)
            return

        print(f"⚠️  No data found for {self.split} split in any expected structure")
        # Show what's actually available
        print("Available directories:")
        if self.dataset_path.exists():
            for item in self.dataset_path.iterdir():
                if item.is_dir():
                    print(f"  - {item.name}/")

    def _load_from_json_structure(self, images_dir):
        """Label every ``.jpg`` from its sibling ``.json`` and keep this split's share.

        An image is positive when its JSON has a non-empty ``shapes`` list, else when
        ``label == 'positive'``, else ``int(ki67_positive)``; otherwise negative. Images
        without a JSON file are skipped. The pooled images are permuted with seed 42 and
        cut 70/15/15 into train/validation/test.

        NOTE(review): any annotated shape marks the image positive regardless of the
        shape's own label — confirm this matches the annotation scheme.
        """
        all_images = []
        all_labels = []

        # Sort before the seeded permutation: glob order depends on the filesystem, and an
        # unsorted listing would give different splits across sessions/machines (letting a
        # previously-trained model see today's test images).
        for img_file in sorted(images_dir.glob("*.jpg")):
            json_file = img_file.with_suffix('.json')
            if not json_file.exists():
                continue
            try:
                with open(json_file, 'r') as f:
                    annotation = json.load(f)

                # Determine label from JSON
                label = 0
                if 'shapes' in annotation and len(annotation['shapes']) > 0:
                    label = 1
                elif 'label' in annotation:
                    label = 1 if annotation['label'] == 'positive' else 0
                elif 'ki67_positive' in annotation:
                    label = int(annotation['ki67_positive'])

                all_images.append(str(img_file))
                all_labels.append(label)
            except Exception as e:
                print(f"Warning: Could not load {json_file}: {e}")

        # Create splits
        if all_images:
            indices = np.random.RandomState(42).permutation(len(all_images))
            n_total = len(indices)
            n_train = int(0.7 * n_total)
            n_val = int(0.15 * n_total)

            if self.split == 'train':
                selected_indices = indices[:n_train]
            elif self.split == 'validation':
                selected_indices = indices[n_train:n_train+n_val]
            else:  # test
                selected_indices = indices[n_train+n_val:]

            self.images = [all_images[i] for i in selected_indices]
            self.labels = [all_labels[i] for i in selected_indices]

        print(f"Loaded {len(self.images)} samples from directory")

    def _load_from_h5_structure(self, images_dir, pos_annotations_dir, neg_annotations_dir):
        """Append every ``.png`` that has a positive or negative ``.h5`` annotation file.

        Only the existence of ``<stem>.h5`` is checked; the file contents are not read.
        Images with neither file are skipped.

        NOTE(review): an image is labelled positive whenever its positive ``.h5`` exists,
        even if a negative ``.h5`` exists too. If the dataset ships both files for every
        image, every sample becomes positive — confirm against the annotation format.
        """
        # Sorted so sample order is deterministic across filesystems.
        for img_file in sorted(images_dir.glob("*.png")):
            img_name = img_file.stem
            pos_ann = pos_annotations_dir / f"{img_name}.h5"
            neg_ann = neg_annotations_dir / f"{img_name}.h5"

            if pos_ann.exists():
                self.images.append(str(img_file))
                self.labels.append(1)
            elif neg_ann.exists():
                self.images.append(str(img_file))
                self.labels.append(0)

        print(f"Loaded {len(self.images)} samples from directory")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with the label as a float32 scalar tensor.

        If the image cannot be opened, the error is printed and a black placeholder
        (run through ``transform`` when one is set) is returned with the real label.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # Load image
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, torch.tensor(label, dtype=torch.float32)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return fallback
            if self.transform:
                fallback = self.transform(Image.new('RGB', (640, 640), color='black'))
            else:
                fallback = torch.zeros((3, 224, 224))
            return fallback, torch.tensor(label, dtype=torch.float32)

print("✅ Dataset class defined successfully!")

In [None]:
# Create Data Transforms and Datasets
# Standard ImageNet channel statistics, shared by the train and eval pipelines.
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: resize, then random flips/rotation/colour jitter for augmentation.
# NOTE(review): hue jitter may shift stain colour, which the classifier presumably relies
# on — confirm hue=0.1 is intended.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    _normalize,
])

# Deterministic pipeline for validation and test.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _normalize,
])

print("🔄 Creating datasets...")
train_dataset = Ki67Dataset(DATASET_PATH, split='train', transform=train_transform)
val_dataset = Ki67Dataset(DATASET_PATH, split='validation', transform=val_transform)
test_dataset = Ki67Dataset(DATASET_PATH, split='test', transform=val_transform)

print("\n✅ Dataset creation completed!")
for caption, split_ds in (("Training", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{caption} samples: {len(split_ds)}")
# Check class distribution
def check_class_distribution(dataset, name, max_samples=100):
    """Print how many positive and negative labels *dataset* contains.

    When the dataset exposes a ``labels`` sequence (as ``Ki67Dataset`` does), every
    label is counted directly and no image is opened or transformed. Otherwise the
    first *max_samples* items are loaded via ``dataset[i]`` and their label tensors
    are counted. Empty datasets print nothing.

    Args:
        dataset: Sized, indexable dataset yielding ``(image, label_tensor)`` pairs.
        name: Caption printed before the counts.
        max_samples: Cap on items loaded in the fallback path (default 100).
    """
    if len(dataset) == 0:
        return

    stored_labels = getattr(dataset, 'labels', None)
    if stored_labels is not None:
        labels = [int(label) for label in stored_labels]
    else:
        # Fallback: loading items is expensive (image I/O + transforms), so sample a prefix.
        labels = [int(dataset[i][1].item()) for i in range(min(len(dataset), max_samples))]

    pos_count = sum(labels)
    neg_count = len(labels) - pos_count
    print(f"{name}: {pos_count} positive, {neg_count} negative (from {len(labels)} checked)")

if len(train_dataset) == 0:
    # Nothing was found: remind the user which layouts Ki67Dataset understands.
    print("⚠️  No data loaded. Please check your dataset structure.")
    print("Expected structures:")
    for layout_hint in (
        "1. BCData/images/{train,validation,test}/ with BCData/annotations/{train,validation,test}/{positive,negative}/",
        "2. images/{train,validation,test}/ with annotations/{train,validation,test}/{positive,negative}/",
        "3. data/test256/ with .jpg and .json files",
    ):
        print(layout_hint)
else:
    for split_ds, caption in ((train_dataset, "Training"), (val_dataset, "Validation"), (test_dataset, "Test")):
        check_class_distribution(split_ds, caption)

In [None]:
# Create Data Loaders (only if we have data)
if len(train_dataset) == 0:
    print("❌ Cannot create data loaders - no training data available")
    print("Please check your dataset and re-run the dataset creation cell")
else:
    # Larger batches when a GPU is available.
    batch_size = 16
    if torch.cuda.is_available():
        batch_size = 32

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"✅ Data loaders created with batch size: {batch_size}")
    for caption, loader in (("Train", train_loader), ("Validation", val_loader), ("Test", test_loader)):
        print(f"{caption} batches: {len(loader)}")

    # Smoke test: pull one training batch so pipeline errors surface here, not mid-training.
    try:
        images, labels = next(iter(train_loader))
        print(f"Sample batch shape: images={images.shape}, labels={labels.shape}")
        print("✅ Data loading test successful!")
    except Exception as e:
        print(f"❌ Data loading test failed: {e}")

In [None]:
# Model Creation and Setup (only if we have data)
if len(train_dataset) > 0:
    print("🏗️ Creating models...")

    # Class balance for the ResNet loss. Read the dataset's stored label list rather than
    # indexing train_dataset[i], which would open and transform every training image.
    labels = [int(label) for label in train_dataset.labels]

    pos_count = sum(labels)
    neg_count = len(labels) - pos_count

    if pos_count > 0 and neg_count > 0:
        pos_weight = len(labels) / (2 * pos_count)
        neg_weight = len(labels) / (2 * neg_count)
        # Simplifies to neg_count / pos_count, the BCEWithLogitsLoss pos_weight convention.
        pos_weight_ratio = pos_weight / neg_weight
    else:
        pos_weight_ratio = 1.0

    print(f"Class distribution: {neg_count} negative, {pos_count} positive")
    print(f"Positive weight ratio: {pos_weight_ratio:.3f}")

    # Create models
    models_dict = {}

    try:
        # InceptionV3 (ImageNet weights) ending in a single sigmoid probability for BCELoss.
        inception_model = models.inception_v3(pretrained=True)
        inception_model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(inception_model.fc.in_features, 1),
            nn.Sigmoid()
        )
        # Training also supervises the auxiliary output (use_aux_loss is enabled for
        # InceptionV3 and BCELoss is applied to it). The stock aux head emits 1000 ImageNet
        # logits, which cannot be compared with (B, 1) labels, so give it a matching head.
        if inception_model.AuxLogits is not None:
            inception_model.AuxLogits.fc = nn.Sequential(
                nn.Linear(inception_model.AuxLogits.fc.in_features, 1),
                nn.Sigmoid()
            )
        inception_model = inception_model.to(device)
        inception_optimizer = optim.Adam(inception_model.parameters(), lr=0.001, weight_decay=1e-4)

        models_dict['inception'] = {
            'model': inception_model,
            'criterion': nn.BCELoss(),
            'optimizer': inception_optimizer,
            # Bound to the optimizer that actually updates this model.
            'scheduler': ReduceLROnPlateau(inception_optimizer, mode='min', factor=0.1, patience=5),
            'name': 'InceptionV3'
        }
        print("✅ InceptionV3 model created")

        # ResNet-50 ending in a raw logit (paired with BCEWithLogitsLoss below).
        # NOTE(review): built with pretrained=False, i.e. trained from random initialisation
        # unlike InceptionV3 and the timm ViT — confirm this is intended.
        resnet_model = models.resnet50(pretrained=False)
        resnet_model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(resnet_model.fc.in_features, 1)
        )
        resnet_model = resnet_model.to(device)
        resnet_optimizer = optim.SGD(resnet_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

        models_dict['resnet'] = {
            'model': resnet_model,
            'criterion': nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight_ratio).to(device)),
            'optimizer': resnet_optimizer,
            'scheduler': StepLR(resnet_optimizer, step_size=10, gamma=0.1),
            'name': 'ResNet50'
        }
        print("✅ ResNet-50 model created")

        # ViT or CNN fallback
        try:
            if timm is not None:
                vit_model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1)
                vit_model = nn.Sequential(vit_model, nn.Sigmoid()).to(device)
                print("✅ ViT model created")
            else:
                raise ImportError("timm not available")
        except Exception:
            # Any failure building the ViT (timm missing, weight download error, ...) falls
            # back to this small CNN, which also ends in a sigmoid probability.
            class SimpleCNN(nn.Module):
                def __init__(self):
                    super(SimpleCNN, self).__init__()
                    self.features = nn.Sequential(
                        nn.Conv2d(3, 32, 3, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2d(2),
                        nn.Conv2d(32, 64, 3, padding=1),
                        nn.ReLU(),
                        nn.MaxPool2d(2),
                        nn.AdaptiveAvgPool2d(7)
                    )
                    self.classifier = nn.Sequential(
                        nn.Flatten(),
                        nn.Linear(64 * 7 * 7, 128),
                        nn.ReLU(),
                        nn.Dropout(0.5),
                        nn.Linear(128, 1),
                        nn.Sigmoid()
                    )

                def forward(self, x):
                    x = self.features(x)
                    x = self.classifier(x)
                    return x

            vit_model = SimpleCNN().to(device)
            print("✅ Simple CNN created as ViT fallback")

        # NOTE(review): lr=1e-3 with Adam is high for fine-tuning a pretrained ViT — confirm.
        vit_optimizer = optim.Adam(vit_model.parameters(), lr=1e-3, weight_decay=1e-4)
        models_dict['vit'] = {
            'model': vit_model,
            'criterion': nn.BCELoss(),
            'optimizer': vit_optimizer,
            'scheduler': ReduceLROnPlateau(vit_optimizer, mode='min', factor=0.1, patience=5),
            'name': 'ViT'
        }

        print(f"\n✅ All models created successfully!")

        # Print model info
        def count_parameters(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"\n📊 Model Parameters:")
        for key, model_info in models_dict.items():
            count = count_parameters(model_info['model'])
            print(f"  {model_info['name']}: {count:,}")

    except Exception as e:
        print(f"❌ Error creating models: {e}")
        models_dict = {}

else:
    print("⚠️  Skipping model creation - no training data available")
    models_dict = {}

In [None]:
# Model Saving Functions
def save_model_to_drive(model, model_name, epoch, val_loss, val_acc, save_path):
    """Write a checkpoint of *model* into *save_path*.

    The file name embeds *model_name* and a second-resolution timestamp, so calls made
    at different times produce separate files. The checkpoint dict stores the epoch,
    the model's ``state_dict``, the validation metrics, the timestamp, the model name
    and a one-line performance summary.

    Returns:
        The full path of the written file, or ``None`` if saving failed (the error is
        printed, not raised).
    """
    try:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_name = f"Ki67_{model_name}_best_model_{stamp}.pth"
        destination = os.path.join(save_path, checkpoint_name)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
            'timestamp': stamp,
            'model_name': model_name,
            'performance_summary': f"Epoch {epoch}, Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%",
        }
        torch.save(checkpoint, destination)
    except Exception as e:
        print(f"❌ Failed to save model {model_name}: {e}")
        return None

    print(f"✅ Model saved to MyDrive: {checkpoint_name}")
    return destination

def save_training_history(history, model_name, save_path):
    """Pickle *history* to ``<save_path>/<model_name>_history_<timestamp>.pkl``.

    Returns:
        The full path of the written file, or ``None`` if saving failed (the error is
        printed, not raised).
    """
    try:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        history_file = f"{model_name}_history_{stamp}.pkl"
        destination = os.path.join(save_path, history_file)
        with open(destination, 'wb') as handle:
            pickle.dump(history, handle)
    except Exception as e:
        print(f"❌ Failed to save history: {e}")
        return None

    print(f"✅ Training history saved: {history_file}")
    return destination

print("✅ Model saving functions defined!")

In [None]:
# Training Function
def _prepare_batch(inputs, labels, model_name, device):
    """Move a batch to *device* and shape it for the binary losses used in training.

    Labels become float tensors of shape (B, 1) clamped to [0, 1]. For InceptionV3,
    inputs whose width is not 299 are bilinearly resized to 299x299.
    """
    inputs = inputs.to(device)
    labels = labels.to(device)
    if labels.dim() == 1:
        labels = labels.unsqueeze(1)
    labels = torch.clamp(labels.float(), 0.0, 1.0)
    if model_name == "InceptionV3" and inputs.size(-1) != 299:
        inputs = F.interpolate(inputs, size=(299, 299), mode='bilinear', align_corners=False)
    return inputs, labels


def _binary_predictions(outputs, criterion):
    """Threshold outputs at 0.5, applying a sigmoid first when *criterion* expects logits."""
    if isinstance(criterion, nn.BCEWithLogitsLoss):
        outputs = torch.sigmoid(outputs)
    return (outputs > 0.5).float()


def _train_one_epoch(model, loader, criterion, optimizer, model_name, device, use_aux_loss):
    """Run one optimisation pass over *loader*.

    Batches with a non-finite loss, or that raise RuntimeError (e.g. CUDA OOM), are
    skipped. Returns ``(mean_loss, accuracy_percent)`` over the batches actually used,
    or ``(0.0, 0)`` if none were.
    """
    model.train()
    loss_sum, used_batches, correct, seen = 0.0, 0, 0, 0

    for batch_idx, (inputs, labels) in enumerate(loader):
        try:
            inputs, labels = _prepare_batch(inputs, labels, model_name, device)
            optimizer.zero_grad()

            if use_aux_loss and model.training:
                # Models with an auxiliary head (InceptionV3 here) return (main, aux) in
                # train mode; the aux loss is added with weight 0.4.
                outputs, aux_outputs = model(inputs)
                if outputs.dim() == 1:
                    outputs = outputs.unsqueeze(1)
                if aux_outputs.dim() == 1:
                    aux_outputs = aux_outputs.unsqueeze(1)
                # Keep probabilities away from 0/1 so BCELoss's log stays finite.
                outputs = torch.clamp(outputs, 1e-7, 1 - 1e-7)
                aux_outputs = torch.clamp(aux_outputs, 1e-7, 1 - 1e-7)
                loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
            else:
                outputs = model(inputs)
                if outputs.dim() == 1:
                    outputs = outputs.unsqueeze(1)
                if isinstance(criterion, nn.BCELoss):
                    outputs = torch.clamp(outputs, 1e-7, 1 - 1e-7)
                loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            loss_sum += loss.item()
            used_batches += 1
            predicted = _binary_predictions(outputs, criterion)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()

        except RuntimeError as e:
            print(f"Error in batch {batch_idx}: {e}")
            optimizer.zero_grad()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue

    mean_loss = loss_sum / used_batches if used_batches else 0.0
    accuracy = 100 * correct / seen if seen else 0
    return mean_loss, accuracy


def _validate_one_epoch(model, loader, criterion, model_name, device):
    """Evaluate *model* on *loader* without gradients.

    Failing batches are reported and skipped. Returns ``(mean_loss, accuracy_percent)``
    over the processed batches. If no batch could be evaluated the loss is ``inf``, so an
    epoch without a validation signal is never selected as the best model.
    """
    model.eval()
    loss_sum, used_batches, correct, seen = 0.0, 0, 0, 0

    with torch.no_grad():
        for inputs, labels in loader:
            try:
                inputs, labels = _prepare_batch(inputs, labels, model_name, device)

                outputs = model(inputs)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                if outputs.dim() == 1:
                    outputs = outputs.unsqueeze(1)
                if isinstance(criterion, nn.BCELoss):
                    outputs = torch.clamp(outputs, 1e-7, 1 - 1e-7)

                loss = criterion(outputs, labels)
                loss_sum += loss.item()
                used_batches += 1

                predicted = _binary_predictions(outputs, criterion)
                seen += labels.size(0)
                correct += (predicted == labels).sum().item()
            except Exception as e:
                print(f"⚠️  Validation batch skipped for {model_name}: {e}")

    if used_batches == 0:
        print(f"⚠️  No validation batches evaluated for {model_name}")
    mean_loss = loss_sum / used_batches if used_batches else float('inf')
    accuracy = 100 * correct / seen if seen else 0
    return mean_loss, accuracy


def train_individual_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                          model_name, device, num_epochs=15, use_aux_loss=False,
                          early_stopping_patience=7, save_outputs=True):
    """Train *model* with early stopping, then restore its best-validation-loss weights.

    Args:
        model: Module producing one output per sample (or ``(main, aux)`` in train mode
            when *use_aux_loss* is set).
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        criterion: ``nn.BCELoss`` (model outputs probabilities) or
            ``nn.BCEWithLogitsLoss`` (model outputs logits).
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: ``ReduceLROnPlateau`` (stepped with the validation loss), another
            epoch-level scheduler (stepped without arguments), ``CyclicLR`` (not stepped
            here), or ``None``.
        model_name: Display name; ``"InceptionV3"`` also triggers 299x299 resizing.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        use_aux_loss: Add 0.4 x the auxiliary-head loss during training.
        early_stopping_patience: Epochs without validation-loss improvement before stopping.
        save_outputs: When True (default), every new best model is checkpointed with
            ``save_model_to_drive`` into ``MODELS_SAVE_PATH`` and the history is pickled
            with ``save_training_history`` into ``RESULTS_SAVE_PATH``. Pass False to train
            without touching the filesystem.

    Returns:
        ``(history, best_val_loss)``; *history* maps ``train_loss``, ``val_loss``,
        ``train_acc`` and ``val_acc`` to per-epoch lists.
    """
    print(f"\n🚀 Training {model_name}...")

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs} - {model_name}")
        print("-" * 40)

        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, model_name, device, use_aux_loss
        )
        val_loss, val_acc = _validate_one_epoch(model, val_loader, criterion, model_name, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live tensors; a shallow dict copy would
            # be overwritten by later epochs. Clone (on CPU to spare GPU memory) instead.
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            print("✅ New best model found!")

            if save_outputs:
                save_model_to_drive(model, model_name, epoch+1, val_loss, val_acc, MODELS_SAVE_PATH)
        else:
            patience_counter += 1

        # Epoch-level scheduler step; CyclicLR is a per-batch scheduler and is left alone.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_loss)
        elif hasattr(scheduler, 'step') and not isinstance(scheduler, torch.optim.lr_scheduler.CyclicLR):
            scheduler.step()

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"✅ Best {model_name} model loaded!")

    if save_outputs:
        save_training_history(history, model_name, RESULTS_SAVE_PATH)

    return history, best_val_loss

print("✅ Training function defined!")

In [None]:
# Execute Training (only if we have data and models)
if len(train_dataset) > 0 and len(models_dict) > 0:
    print("🚀 Starting training process...")

    NUM_EPOCHS = 10  # Reduced for faster testing
    individual_histories = {}
    individual_best_losses = {}
    # Names of models whose training raised. They keep a placeholder history for the
    # summary but receive no ensemble weight.
    failed_models = set()

    session_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"🕐 Training session: {session_timestamp}")

    # Train each model
    for key, model_info in models_dict.items():
        try:
            print(f"\n{'='*60}")
            print(f"🏗️ TRAINING {model_info['name'].upper()} MODEL")
            print(f"{'='*60}")

            # Only InceptionV3 returns an auxiliary output in train mode.
            use_aux_loss = (model_info['name'] == 'InceptionV3')

            history, best_loss = train_individual_model(
                model_info['model'], train_loader, val_loader,
                model_info['criterion'], model_info['optimizer'], model_info['scheduler'],
                model_info['name'], device, NUM_EPOCHS, use_aux_loss=use_aux_loss
            )

            individual_histories[model_info['name']] = history
            individual_best_losses[model_info['name']] = best_loss

            print(f"✅ {model_info['name']} training completed")

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"❌ {model_info['name']} training failed: {e}")
            individual_histories[model_info['name']] = {
                'train_loss': [1.0], 'val_loss': [1.0],
                'train_acc': [50.0], 'val_acc': [50.0]
            }
            individual_best_losses[model_info['name']] = 1.0
            failed_models.add(model_info['name'])

    print(f"\n{'='*60}")
    print("✅ TRAINING PROCESS COMPLETED!")
    print(f"{'='*60}")

    # Display summary
    print(f"\n📊 Training Summary:")
    for model_name, best_loss in individual_best_losses.items():
        final_val_acc = max(individual_histories[model_name]['val_acc'], default=0.0)
        print(f"  {model_name}: Best Loss={best_loss:.4f}, Best Acc={final_val_acc:.2f}%")

    # Ensemble weights proportional to each model's best validation accuracy. Failed
    # models count as 0 so their 50% placeholder cannot earn a share of the vote.
    model_order = list(individual_histories.keys())
    best_accs = [
        0.0 if name in failed_models else max(individual_histories[name]['val_acc'], default=0.0)
        for name in model_order
    ]
    total_acc = sum(best_accs)
    if total_acc > 0:
        ensemble_weights = [acc / total_acc for acc in best_accs]

        print(f"\n⚖️ Calculated Ensemble Weights:")
        for model_name, weight in zip(model_order, ensemble_weights):
            print(f"  {model_name}: {weight:.4f}")
    else:
        # Equal weights sized to the number of models actually present.
        ensemble_weights = [1 / len(model_order)] * len(model_order)
        print(f"\n⚖️ Using equal ensemble weights (fallback)")

    # Save ensemble weights
    try:
        ensemble_path = os.path.join(MODELS_SAVE_PATH, f"Ki67_ensemble_weights_{session_timestamp}.json")
        with open(ensemble_path, 'w') as f:
            json.dump({
                'weights': ensemble_weights,
                'model_order': model_order,
                'session_timestamp': session_timestamp,
                'best_losses': individual_best_losses,
                'failed_models': sorted(failed_models),
                'description': 'Ensemble weights for Ki67 classification'
            }, f, indent=2)
        print(f"✅ Ensemble weights saved to MyDrive: Ki67_ensemble_weights_{session_timestamp}.json")
    except Exception as e:
        print(f"⚠️  Could not save ensemble weights: {e}")

    print("\n🎉 Training completed! Check your MyDrive for saved models.")

else:
    print("❌ Cannot start training - missing data or models")
    if len(train_dataset) == 0:
        print("  - No training data available")
    if len(models_dict) == 0:
        print("  - Models not created successfully")

In [None]:
# Evaluation Function
def _weighted_ensemble(model_outputs, ensemble_weights, template):
    """Return the weighted sum of per-model probabilities, shaped like ``template``.

    A list/tuple of weights is matched positionally to the model names
    'InceptionV3', 'ResNet50', 'ViT' (in that order); a dict is looked up by
    model name. Models that produced no output are skipped. Raises (e.g.
    IndexError for a too-short weight list) when the sum cannot be formed; the
    caller then falls back to a plain average.
    """
    ensemble = torch.zeros_like(template)
    if isinstance(ensemble_weights, dict):
        for name, weight in ensemble_weights.items():
            if name in model_outputs:
                ensemble += weight * model_outputs[name]
    else:
        for i, name in enumerate(['InceptionV3', 'ResNet50', 'ViT']):
            if name in model_outputs:
                ensemble += ensemble_weights[i] * model_outputs[name]
    return ensemble


def _predict_batch(models_dict, inputs, labels, device, ensemble_weights, failed_models):
    """Run every model and the ensemble on one batch.

    Returns ``(scores_by_name, labels_np)``. ``scores_by_name`` maps each model
    name and 'Ensemble' to a NumPy array of probabilities. Labels are returned as
    a float NumPy array, reshaped to (batch, 1) when they arrive 1-D.

    A model that raises contributes 0.5 for the whole batch. Its name is added
    to ``failed_models`` and the failure is reported once.
    """
    inputs, labels = inputs.to(device), labels.to(device)
    if labels.dim() == 1:
        labels = labels.unsqueeze(1)
    labels = labels.float()

    model_outputs = {}
    for model_info in models_dict.values():
        name = model_info['name']
        try:
            model_inputs = inputs
            # InceptionV3 is fed 299x299 inputs; resize other resolutions on the fly.
            if name == "InceptionV3" and inputs.size(-1) != 299:
                model_inputs = F.interpolate(inputs, size=(299, 299), mode='bilinear', align_corners=False)
            outputs = model_info['model'](model_inputs)
            # Some models return tuples (e.g. auxiliary heads); keep the primary output.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            if outputs.dim() == 1:
                outputs = outputs.unsqueeze(1)
            # Models trained with BCEWithLogitsLoss emit logits; convert to probabilities.
            if isinstance(model_info['criterion'], nn.BCEWithLogitsLoss):
                outputs = torch.sigmoid(outputs)
        except Exception as e:
            if name not in failed_models:
                failed_models.add(name)
                print(f"⚠️  {name} failed during evaluation ({e}); using 0.5 predictions")
            outputs = torch.ones_like(labels) * 0.5
        model_outputs[name] = outputs

    try:
        ensemble = _weighted_ensemble(model_outputs, ensemble_weights, labels)
    except Exception:
        # Weights unusable for this batch: fall back to an unweighted average.
        ensemble = torch.mean(torch.stack(list(model_outputs.values())), dim=0)

    scores = {name: out.cpu().numpy() for name, out in model_outputs.items()}
    scores['Ensemble'] = ensemble.cpu().numpy()
    return scores, labels.cpu().numpy()


def _binary_metrics(y_true, scores):
    """Return accuracy/AUC/precision/recall/F1 (all in percent) at a 0.5 threshold.

    AUC is 50.0 when only one class is present or it cannot be computed.
    Precision, recall and F1 fall back to 0.0 if scikit-learn raises.
    """
    pred_binary = (scores > 0.5).astype(int)
    accuracy = float((pred_binary == y_true).mean() * 100)

    auc = 50.0
    if len(np.unique(y_true)) > 1:
        try:
            auc = float(roc_auc_score(y_true, scores) * 100)
        except Exception as e:
            print(f"⚠️  AUC could not be computed: {e}")

    try:
        precision = float(precision_score(y_true, pred_binary, zero_division=0) * 100)
        recall = float(recall_score(y_true, pred_binary, zero_division=0) * 100)
        f1 = float(f1_score(y_true, pred_binary, zero_division=0) * 100)
    except Exception as e:
        print(f"⚠️  Precision/recall/F1 could not be computed: {e}")
        precision = recall = f1 = 0.0

    return {
        'accuracy': accuracy,
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }


def evaluate_models_and_save(models_dict, test_loader, device, ensemble_weights):
    """Evaluate each model and the weighted ensemble on the test set, then save a JSON summary.

    Args:
        models_dict: Mapping of key -> dict with 'model' (an nn.Module), 'name'
            (display name used in the results) and 'criterion' (training loss).
            When the criterion is nn.BCEWithLogitsLoss, the outputs are treated
            as logits and passed through a sigmoid.
        test_loader: DataLoader yielding (inputs, labels) batches.
        device: Device each batch is moved to before inference.
        ensemble_weights: Either a sequence of weights applied positionally to the
            models named 'InceptionV3', 'ResNet50', 'ViT', or a dict mapping
            model name -> weight.

    Returns:
        Dict of model name (plus 'Ensemble') -> {'accuracy', 'auc', 'precision',
        'recall', 'f1_score'}, each in percent. Returns an empty dict when the
        test set is empty.

    Writes Ki67_Results_Summary_<timestamp>.json into RESULTS_SAVE_PATH (a global
    defined elsewhere in the notebook). Save failures are reported, not raised.
    """
    if len(test_loader.dataset) == 0:
        print("❌ No test data available for evaluation")
        return {}

    print("🔍 Evaluating models on test set...")

    # Set models to evaluation mode
    for model_info in models_dict.values():
        model_info['model'].eval()

    # Per-name lists of per-batch score arrays. A batch is recorded only after all
    # of its predictions succeeded, so every list stays aligned with y_batches.
    score_batches = {model_info['name']: [] for model_info in models_dict.values()}
    score_batches['Ensemble'] = []
    y_batches = []
    failed_models = set()

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            try:
                batch_scores, batch_labels = _predict_batch(
                    models_dict, inputs, labels, device, ensemble_weights, failed_models
                )
            except Exception as e:
                print(f"⚠️  Skipping test batch {batch_idx}: {e}")
                continue
            for name, batch in batch_scores.items():
                score_batches[name].append(batch.reshape(-1))
            y_batches.append(batch_labels.reshape(-1))

    y_true = np.concatenate(y_batches) if y_batches else np.empty(0)

    print("\n📊 Evaluation Results:")
    print("="*50)

    results = {}
    for model_name, batches in score_batches.items():
        if not batches:
            continue
        scores = np.concatenate(batches)
        # e.g. a model emitting two logits per sample: metrics would be meaningless.
        if scores.shape != y_true.shape:
            print(f"⚠️  {model_name}: {scores.size} scores for {y_true.size} labels - skipping metrics")
            continue
        metrics = _binary_metrics(y_true, scores)
        results[model_name] = metrics
        print(f"{model_name:12}: Acc={metrics['accuracy']:6.2f}%, "
              f"AUC={metrics['auc']:6.2f}%, F1={metrics['f1_score']:6.2f}%")

    print("="*50)

    # Find best model
    if results:
        best_model = max(results, key=lambda k: results[k]['accuracy'])
        print(f"🏆 Best model: {best_model} (Accuracy: {results[best_model]['accuracy']:.2f}%)")

    # Save results (best-effort)
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = os.path.join(RESULTS_SAVE_PATH, f"Ki67_Results_Summary_{timestamp}.json")

        # Coerce weights to plain floats so json.dump cannot fail on NumPy/tensor values.
        if isinstance(ensemble_weights, dict):
            weights_record = {str(name): float(w) for name, w in ensemble_weights.items()}
        else:
            weights_record = [float(w) for w in ensemble_weights]

        with open(results_file, 'w') as f:
            json.dump({
                'timestamp': timestamp,
                'results': results,
                'ensemble_weights': weights_record,
                'test_set_size': len(y_true)
            }, f, indent=2)

        print(f"✅ Results saved: Ki67_Results_Summary_{timestamp}.json")

    except Exception as e:
        print(f"⚠️  Could not save results: {e}")

    return results


print("✅ Evaluation function defined!")

In [None]:
# Run Evaluation (only if we have trained models and test data).
# This cell runs at notebook top level, where locals() is the global namespace,
# so the last check asks whether the training cell defined `ensemble_weights`.
if (
    len(train_dataset) > 0
    and len(models_dict) > 0
    and len(test_dataset) > 0
    and 'ensemble_weights' in locals()
):
    try:
        results = evaluate_models_and_save(
            models_dict, test_loader, device, ensemble_weights
        )

        print("\n🎯 Evaluation completed!")
        print("📁 All results saved to Google Drive")

        # Final summary: fixed text describing where the notebook writes its outputs.
        print("\n📊 Final Summary:")
        print("  Models saved to: MyDrive/")
        print("  Results saved to: MyDrive/Ki67_Results/")
        if results:
            best_model = max(results, key=lambda name: results[name]['accuracy'])
            print(
                f"  Best model: {best_model} "
                f"({results[best_model]['accuracy']:.2f}% accuracy)"
            )

    except Exception as e:
        # Top-level notebook boundary: report the failure instead of aborting the run.
        print(f"❌ Evaluation failed: {e}")

else:
    # List every missing prerequisite, not just the first one.
    print("⚠️  Skipping evaluation - missing requirements:")
    if len(train_dataset) == 0:
        print("  - No training data")
    if len(models_dict) == 0:
        print("  - No models created")
    if len(test_dataset) == 0:
        print("  - No test data")
    if 'ensemble_weights' not in locals():
        print("  - No ensemble weights from training")

## Summary and Next Steps

### Files Saved to Your Google Drive:

**MyDrive/** (Model files):
- `Ki67_InceptionV3_best_model_YYYYMMDD_HHMMSS.pth`
- `Ki67_ResNet50_best_model_YYYYMMDD_HHMMSS.pth` 
- `Ki67_ViT_best_model_YYYYMMDD_HHMMSS.pth`
- `Ki67_ensemble_weights_YYYYMMDD_HHMMSS.json`

**MyDrive/Ki67_Results/** (Analysis files):
- Training histories (`.pkl` files)
- Results summaries (`Ki67_Results_Summary_YYYYMMDD_HHMMSS.json`)
- Detailed predictions and metrics

### Loading Saved Models:

```python
import torch

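# Load a saved checkpoint. map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only runtime.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if loading fails because the
# checkpoint holds other Python objects, pass weights_only=False (only for files you trust).
checkpoint = torch.load('/content/drive/MyDrive/Ki67_InceptionV3_best_model_YYYYMMDD_HHMMSS.pth', map_location='cpu')
print(f"Model performance: {checkpoint['performance_summary']}")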

# Load ensemble weights
import json
with open('/content/drive/MyDrive/Ki67_ensemble_weights_YYYYMMDD_HHMMSS.json', 'r') as f:
    ensemble_config = json.load(f)
    weights = ensemble_config['weights']
    model_order = ensemble_config['model_order']
```

### Troubleshooting:
- If no data loads: Check your ZIP file structure
- If training fails: Verify sufficient GPU memory
- If models don't save: Ensure Google Drive is mounted and has free space

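The ensemble combines the predictions of three different architectures (InceptionV3, ResNet-50 and ViT). Each model's weight is proportional to its best validation accuracy, which makes Ki-67 classification more robust than any single model.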