# Mayo Clinic STRIP AI - Patch-Based Vision Transformer Training

This notebook continues training the patch-based ViT model on Google Colab with free GPU access.

**Progress at time of writing:** best validation accuracy 73.63% after epoch 2 of 30

**Setup Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Upload your `kaggle.json` API token when prompted (section 2), and optionally your checkpoint file to resume training (section 6)
3. Run all cells

## 1. Setup Environment

In [None]:
# Install dependencies
!pip install -q transformers albumentations kaggle

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Setup Kaggle API and Download Dataset

In [None]:
# Upload your kaggle.json file
from google.colab import files
print("Please upload your kaggle.json file:")
uploaded = files.upload()

# Setup Kaggle credentials
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the Mayo Clinic STRIP AI dataset
!kaggle competitions download -c mayo-clinic-strip-ai
!unzip -q mayo-clinic-strip-ai.zip -d raw_data
print("Dataset downloaded!")

## 3. Prepare Data (Patient-Level Split)

In [None]:
import os
import pandas as pd
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split

def reorganize_data(raw_data_dir, processed_dir, image_ext='.jpg'):
    """Copy raw training images into a per-label, per-patient directory tree.

    Reads ``<raw_data_dir>/train.csv`` (using its ``image_id``, ``patient_id``
    and ``label`` columns) and copies
    ``<raw_data_dir>/train/<image_id><image_ext>`` to
    ``<processed_dir>/<label>/patient_<patient_id>/<image_id><image_ext>``.
    Destination files that already exist are not copied again, so the function
    can be re-run.

    Args:
        raw_data_dir: Directory containing ``train.csv`` and a ``train/`` folder.
        processed_dir: Output root; created if missing.
        image_ext: Extension (including the dot) of the raw image files.
            NOTE(review): confirm this matches the format of the downloaded
            images; on a mismatch no source image is found and nothing is copied.

    Returns:
        The contents of ``train.csv`` as a DataFrame.
    """
    os.makedirs(processed_dir, exist_ok=True)

    train_df = pd.read_csv(os.path.join(raw_data_dir, 'train.csv'))

    # One directory per label, created even if none of its images are found.
    for label in train_df['label'].unique():
        os.makedirs(os.path.join(processed_dir, label), exist_ok=True)

    missing = []
    for row in train_df.itertuples(index=False):
        patient_dir = os.path.join(processed_dir, row.label, f"patient_{row.patient_id}")
        os.makedirs(patient_dir, exist_ok=True)

        src = os.path.join(raw_data_dir, 'train', f"{row.image_id}{image_ext}")
        dst = os.path.join(patient_dir, f"{row.image_id}{image_ext}")

        if not os.path.exists(src):
            missing.append(src)
        elif not os.path.exists(dst):
            shutil.copy2(src, dst)

    # Missing sources used to be skipped silently, which could leave the
    # processed tree empty with no indication of why.
    if missing:
        print(f"WARNING: {len(missing)}/{len(train_df)} source images not found "
              f"(e.g. {missing[0]}); check raw_data_dir and image_ext.")

    return train_df

def _move_patient_dir(src, dst):
    """Move directory ``src`` to ``dst``, merging file-by-file if ``dst`` exists."""
    if not os.path.isdir(dst):
        shutil.move(src, dst)
        return
    # shutil.move into an existing directory would nest src inside it
    # (dst/patient_x/patient_x) on a re-run; merge instead, keeping files
    # already present at the destination.
    for name in os.listdir(src):
        src_file = os.path.join(src, name)
        dst_file = os.path.join(dst, name)
        if os.path.exists(dst_file):
            os.remove(src_file)
        else:
            shutil.move(src_file, dst_file)
    os.rmdir(src)


def create_patient_splits(df, processed_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Move per-patient directories into patient-level train/val/test splits.

    Patients (not images) are split, so all images of a patient land in the
    same split. ``val_ratio + test_ratio`` of the patients are held out, and
    that hold-out is divided val:test in proportion ``val_ratio:test_ratio``;
    both splits use ``random_state=42``. ``train_ratio`` is not used directly:
    the train share is implied by the other two ratios.

    Expects the ``<processed_dir>/<label>/patient_<id>/`` layout written by
    ``reorganize_data`` and produces
    ``<processed_dir>/<split>/<label>/patient_<id>/``. Any patient directory
    whose ID is in neither the train nor the val set goes to ``test``. Label
    directories left empty afterwards are removed.

    Args:
        df: DataFrame with ``patient_id`` and ``label`` columns.
        processed_dir: Root of the per-label patient directories.
        train_ratio: Unused; kept for interface compatibility.
        val_ratio: Fraction of patients for validation.
        test_ratio: Fraction of patients for testing.
    """
    patients = df['patient_id'].unique()

    train_patients, temp_patients = train_test_split(patients, test_size=(val_ratio + test_ratio), random_state=42)
    val_patients, test_patients = train_test_split(temp_patients, test_size=test_ratio/(val_ratio + test_ratio), random_state=42)

    # Directory names are built as f"patient_{patient_id}", so compare IDs as
    # strings. Parsing the suffix with int() raised for non-numeric IDs and,
    # when pandas reads the column as strings, never matched (sending every
    # patient to test). Sets also make each lookup O(1).
    train_ids = {str(p) for p in train_patients}
    val_ids = {str(p) for p in val_patients}

    labels = df['label'].unique()

    # Create split directories
    for split in ('train', 'val', 'test'):
        for label in labels:
            os.makedirs(os.path.join(processed_dir, split, label), exist_ok=True)

    # Move patient directories to splits
    prefix = 'patient_'
    for label in labels:
        label_dir = os.path.join(processed_dir, label)
        if not os.path.isdir(label_dir):
            continue
        for patient_dir in os.listdir(label_dir):
            if not patient_dir.startswith(prefix):
                continue

            patient_id = patient_dir[len(prefix):]
            if patient_id in train_ids:
                split = 'train'
            elif patient_id in val_ids:
                split = 'val'
            else:
                split = 'test'

            src = os.path.join(label_dir, patient_dir)
            dst = os.path.join(processed_dir, split, label, patient_dir)
            _move_patient_dir(src, dst)

    # Clean up empty label directories
    for label in labels:
        label_dir = os.path.join(processed_dir, label)
        if os.path.exists(label_dir) and not os.listdir(label_dir):
            os.rmdir(label_dir)

# Build processed_data/{train,val,test}/<label>/patient_<id>/ from the raw download.
print("Reorganizing data...")
df = reorganize_data(raw_data_dir='raw_data', processed_dir='processed_data')

print("Creating patient-level splits...")
create_patient_splits(df, processed_dir='processed_data')

print("Data preparation complete!")

## 4. Setup Model Code

In [None]:
# Create directory structure
!mkdir -p src/data config experiments/patch_based/checkpoints

# Patch dataset code
patch_dataset_code = '''import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

class PatchDataset(Dataset):
    def __init__(self, data_dir, split="train", patch_size=224, num_patches_per_image=16, transform=None, mode="random"):
        self.data_dir = data_dir
        self.split = split
        self.patch_size = patch_size
        self.num_patches_per_image = num_patches_per_image
        self.transform = transform
        self.mode = mode
        
        self.classes = ["CE", "LAA"]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        
        self.samples = self._load_samples()
    
    def _load_samples(self):
        samples = []
        split_dir = os.path.join(self.data_dir, self.split)
        
        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                continue
            
            for patient_dir in os.listdir(class_dir):
                patient_path = os.path.join(class_dir, patient_dir)
                if not os.path.isdir(patient_path):
                    continue
                
                for img_file in os.listdir(patient_path):
                    if img_file.endswith((".jpg", ".jpeg", ".png")):
                        img_path = os.path.join(patient_path, img_file)
                        samples.append((img_path, self.class_to_idx[class_name]))
        
        return samples
    
    def __len__(self):
        return len(self.samples) * self.num_patches_per_image
    
    def __getitem__(self, idx):
        image_idx = idx // self.num_patches_per_image
        patch_idx = idx % self.num_patches_per_image
        
        img_path, label = self.samples[image_idx]
        image = Image.open(img_path).convert("RGB")
        image_np = np.array(image)
        
        h, w = image_np.shape[:2]
        
        if self.mode == "random":
            max_y = max(0, h - self.patch_size)
            max_x = max(0, w - self.patch_size)
            y = np.random.randint(0, max_y + 1) if max_y > 0 else 0
            x = np.random.randint(0, max_x + 1) if max_x > 0 else 0
            patch = image_np[y:y+self.patch_size, x:x+self.patch_size]
        else:
            grid_h = int(np.sqrt(self.num_patches_per_image))
            grid_w = self.num_patches_per_image // grid_h
            row = patch_idx // grid_w
            col = patch_idx % grid_w
            y = int(row * h / grid_h)
            x = int(col * w / grid_w)
            patch = image_np[y:min(y+self.patch_size, h), x:min(x+self.patch_size, w)]
        
        if patch.shape[0] != self.patch_size or patch.shape[1] != self.patch_size:
            patch = np.array(Image.fromarray(patch).resize((self.patch_size, self.patch_size)))
        
        if self.transform:
            patch = self.transform(image=patch)["image"]
        
        return patch, label

def get_patch_transforms(config, train=True):
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
'''

with open('src/data/patch_dataset.py', 'w') as f:
    f.write(patch_dataset_code)

print("Model code created!")

## 5. Create Config File

In [None]:
import yaml

# Experiment settings, written to config/mayo_patch_config.yaml. The training
# cell re-reads this file rather than using the in-memory dict.
config = dict(
    data=dict(
        data_dir='processed_data',
        patch_size=224,
        num_patches_per_image=16,
        patch_mode='random',
        batch_size=32,
        num_workers=2,
    ),
    model=dict(
        model_name='google/vit-base-patch16-224-in21k',
        pretrained=True,
        num_classes=2,
    ),
    loss=dict(
        type='weighted_cross_entropy',
        # Passed to CrossEntropyLoss(weight=...), so ordered by class index (CE, LAA).
        class_weights=[0.69, 1.82],
    ),
    training=dict(
        num_epochs=30,
        learning_rate=5e-5,
        weight_decay=0.01,
        min_learning_rate=1e-6,
        early_stopping_patience=10,
    ),
)

with open('config/mayo_patch_config.yaml', 'w') as f:
    yaml.dump(config, f)

print("Config file created!")

## 6. Upload Checkpoint (Optional - to resume from local training)

In [None]:
from google.colab import files

# Optional: resume from a checkpoint produced by an earlier run. Sets RESUME,
# which the training cell reads.
print("Upload your best_model.pth checkpoint file (or skip to train from scratch):")
uploaded = files.upload()

if not uploaded:
    print("No checkpoint uploaded, training from scratch")
    RESUME = False
else:
    import shutil
    # Only the first uploaded file is used; it is renamed to the path the
    # training cell loads from.
    checkpoint_name = next(iter(uploaded))
    shutil.move(checkpoint_name, 'experiments/patch_based/checkpoints/best_model.pth')
    print(f"Checkpoint uploaded: {checkpoint_name}")
    RESUME = True

## 7. Training Function

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from tqdm.notebook import tqdm
import sys
sys.path.append('src')
from data.patch_dataset import PatchDataset, get_patch_transforms

def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Classifier whose forward output exposes ``.logits`` (here a
            ``ViTForImageClassification``).
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.
        epoch: Epoch number; used only in the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f'Training Epoch {epoch}')
    for num_batches, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Average over the batches seen so far; dividing by len(pbar) (the
        # total batch count) under-reported the loss until the final batch.
        pbar.set_postfix({'loss': f'{running_loss/num_batches:.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    if num_batches == 0:
        raise ValueError('train_epoch: dataloader yielded no batches (is the training split empty?)')

    return running_loss / num_batches, correct / total

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Runs in ``eval()`` mode under ``torch.no_grad()``.

    Args:
        model: Classifier whose forward output exposes ``.logits``.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for num_batches, (images, labels) in enumerate(pbar, start=1):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Running mean over batches processed so far (not over len(pbar)).
            pbar.set_postfix({'loss': f'{running_loss/num_batches:.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    if num_batches == 0:
        raise ValueError('validate: dataloader yielded no batches (is the validation split empty?)')

    return running_loss / num_batches, correct / total

print("Training functions defined!")

## 8. Run Training

In [None]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load config (written by the config cell)
with open('config/mayo_patch_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

data_cfg = config['data']
train_cfg = config['training']
checkpoint_path = 'experiments/patch_based/checkpoints/best_model.pth'

# Create datasets
train_transform = get_patch_transforms(config, train=True)
val_transform = get_patch_transforms(config, train=False)

train_dataset = PatchDataset(
    data_dir=data_cfg['data_dir'],
    split='train',
    patch_size=data_cfg['patch_size'],
    num_patches_per_image=data_cfg['num_patches_per_image'],
    transform=train_transform,
    mode=data_cfg['patch_mode'],
)

# Validation drives checkpoint selection and early stopping, so it must score
# the same patches every epoch. PatchDataset's "random" mode redraws crops on
# every access; any other mode takes fixed grid crops. Override with
# data.val_patch_mode in the config if needed.
val_dataset = PatchDataset(
    data_dir=data_cfg['data_dir'],
    split='val',
    patch_size=data_cfg['patch_size'],
    num_patches_per_image=data_cfg['num_patches_per_image'],
    transform=val_transform,
    mode=data_cfg.get('val_patch_mode', 'grid'),
)

print(f'Train set: {len(train_dataset.samples)} images')
print(f'Val set: {len(val_dataset.samples)} images')

train_loader = DataLoader(train_dataset, batch_size=data_cfg['batch_size'], shuffle=True, num_workers=data_cfg['num_workers'])
val_loader = DataLoader(val_dataset, batch_size=data_cfg['batch_size'], shuffle=False, num_workers=data_cfg['num_workers'])

# Create model; the head size comes from model.num_classes instead of a hard-coded 2.
print(f"Loading model: {config['model']['model_name']}")
model = ViTForImageClassification.from_pretrained(
    config['model']['model_name'],
    num_labels=config['model'].get('num_classes', 2),
    ignore_mismatched_sizes=True,
).to(device)

print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')

# Loss and optimizer
class_weights = torch.tensor(config['loss']['class_weights'], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=train_cfg['learning_rate'], weight_decay=train_cfg['weight_decay'])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_cfg['num_epochs'], eta_min=train_cfg.get('min_learning_rate', 1e-6))

# Resume from checkpoint if available
start_epoch = 1
best_val_acc = 0.0
patience_counter = 0

# RESUME is set by the optional upload cell; if that cell was skipped, train
# from scratch instead of failing with a NameError.
if globals().get('RESUME', False) and os.path.exists(checkpoint_path):
    print('Loading checkpoint...')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    # NOTE: if this checkpoint was scored on random validation patches, its
    # val_acc is not strictly comparable with the grid-patch validation above.
    best_val_acc = checkpoint['val_acc']

    # A freshly built scheduler starts at last_epoch=0, so without this the
    # cosine schedule would restart instead of continuing from the checkpoint.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # Checkpoints without scheduler state: jump to the closed-form cosine
        # LR after checkpoint['epoch'] scheduler steps (one step per epoch).
        import math
        steps_done = checkpoint['epoch']
        for group, base_lr in zip(optimizer.param_groups, scheduler.base_lrs):
            group['lr'] = scheduler.eta_min + (base_lr - scheduler.eta_min) * (
                1 + math.cos(math.pi * steps_done / scheduler.T_max)) / 2
        scheduler.last_epoch = steps_done
    print(f'Resuming from epoch {start_epoch} (best val acc: {best_val_acc:.4f})')

# Training loop
num_epochs = train_cfg['num_epochs']
for epoch in range(start_epoch, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')
    print('-' * 50)

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f'Learning Rate: {current_lr:.6f}')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Saved after scheduler.step(), so a resume continues at epoch + 1.
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': config
        }, checkpoint_path)
        print(f'✓ Saved best model (val_acc: {val_acc:.4f})')
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= train_cfg['early_stopping_patience']:
        print(f'\nEarly stopping triggered after {epoch} epochs')
        break

print('\nTraining complete!')
print(f'Best validation accuracy: {best_val_acc:.4f}')

## 9. Download Trained Model

In [None]:
from google.colab import files

# Path the training cell saves the best checkpoint to.
checkpoint_path = 'experiments/patch_based/checkpoints/best_model.pth'

print("Downloading trained model checkpoint...")
files.download(checkpoint_path)
print("Download complete!")