# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import requests
import zipfile
from tqdm import tqdm
import random
import time

# Set random seeds for reproducibility
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every generator (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Deterministic cuDNN kernels and no autotuner: slower, but repeatable.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seeds()

# Run on the current CUDA device when PyTorch reports one available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Data preparation functions
def download_data(source, destination):
    """Download and extract dataset if it doesn't exist.

    Args:
        source: URL of the zip archive; its last path component is used as
            the local file name.
        destination: Directory to download into and extract under (created
            if missing).

    Returns:
        Path to the extracted ``tiny-imagenet-200`` directory.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    dest_path = Path(destination)
    dest_path.mkdir(parents=True, exist_ok=True)
    file_name = source.split('/')[-1]
    file_path = dest_path / file_name

    if not file_path.exists():
        print(f"Downloading {file_name}...")
        # Stream into a temporary file and rename only once the transfer has
        # finished. Writing straight to file_path (as before) left a truncated
        # archive behind on any failure, which later runs then treated as a
        # complete download. Streaming also avoids holding the archive in RAM.
        tmp_path = file_path.with_name(file_path.name + ".part")
        try:
            with requests.get(source, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            tmp_path.replace(file_path)
        finally:
            # Only still present if the download or write failed midway.
            if tmp_path.exists():
                tmp_path.unlink()
    else:
        print(f"{file_name} already exists")

    extract_path = dest_path / "tiny-imagenet-200"
    if not extract_path.exists():
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            print(f"Extracting {file_name}...")
            zip_ref.extractall(dest_path)
    else:
        print(f"{extract_path} already exists")

    return extract_path

class CustomImageDataset(Dataset):
    """Custom dataset for TinyImageNet validation set.

    The annotation file is read once at construction. Each non-empty line is
    tab-separated, with the image file name in column 0 and the class name in
    column 1; any further columns are ignored.

    Args:
        img_dir: Directory containing the image files.
        annotation_file: Path to the tab-separated annotation file.
        class_to_idx: Mapping from class name to integer label.
        transform: Optional callable applied to each PIL image.
    """
    def __init__(self, img_dir, annotation_file, class_to_idx, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = self._load_annotations(annotation_file)
        self.class_to_idx = class_to_idx
        # Freeze the index -> file-name order once. Rebuilding
        # list(self.annotations.keys()) inside __getitem__ made every lookup
        # O(n), i.e. O(n^2) for one pass over the dataset.
        self._img_names = list(self.annotations)

    def _load_annotations(self, annotation_file):
        """Return an insertion-ordered ``{image_name: class_name}`` dict."""
        annotations = {}
        with open(annotation_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing with IndexError on parts[1].
                    continue
                parts = line.split('\t')
                annotations[parts[0]] = parts[1]
        return annotations

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for the ``idx``-th annotated image."""
        img_name = self._img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        label_idx = self.class_to_idx[self.annotations[img_name]]

        if self.transform:
            image = self.transform(image)

        return image, label_idx

# Patch Embedding Layer
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies
    one shared linear projection to every patch.
    """
    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # [B, C, H, W] -> [B, embed_dim, H//p, W//p] -> [B, embed_dim, N] -> [B, N, embed_dim]
        patch_grid = self.proj(x)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)

# Multi-head Self-Attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention over ``num_heads`` parallel heads.

    Q, K and V come from a single fused Linear layer; the concatenated head
    outputs are mixed by a final output projection.
    """
    def __init__(self, dim, num_heads, dropout=0.0, qkv_bias=True):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot-product logits.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        b, n, c = x.shape

        # [B, N, 3*C] -> [3, B, heads, N, head_dim]; unpacking walks dim 0.
        fused = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        weights = self.attn_drop(scores.softmax(dim=-1))

        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(b, n, c)
        return self.proj_drop(self.proj(mixed))

# MLP Block
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Expands each token from ``dim`` to ``hidden_dim`` and projects it back.
    """
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout1(self.act(self.fc1(x)))
        return self.dropout2(self.fc2(hidden))

# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Each sub-layer (self-attention, then MLP) receives a LayerNorm-ed copy of
    the residual stream, and its dropout-ed output is added back to it.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, attn_dropout=0.0, qkv_bias=True):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, attn_dropout, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim, hidden_dim, dropout)

        # Dropout on each sub-layer output, applied before the residual add.
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.mlp)):
            x = x + self.drop(sublayer(norm(x)))
        return x

# Vision Transformer
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add
    learnable position embeddings -> ``depth`` pre-norm Transformer blocks ->
    final LayerNorm -> linear head applied to the class-token output.
    """
    def __init__(
        self,
        img_size=128,
        patch_size=16,
        in_channels=3,
        num_classes=200,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        dropout=0.0,
        attn_dropout=0.0,
        embed_dropout=0.0
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        seq_len = self.num_patches + 1  # patch tokens plus the class token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(embed_dropout)

        block_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attn_dropout=attn_dropout,
            qkv_bias=qkv_bias,
        )
        self.blocks = nn.Sequential(*(TransformerBlock(**block_kwargs) for _ in range(depth)))

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, num_classes)

        # Re-initialise Linear / LayerNorm layers, then the two token parameters
        # (position embeddings first, then the class token).
        self.apply(self._init_weights)
        for token_param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(token_param, std=0.02)

    def _init_weights(self, m):
        """Truncated-normal(0.02) Linear weights with zero bias; identity LayerNorm.

        Other module types (e.g. the patch Conv2d) keep PyTorch's default init.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Return the normalised class-token embedding, shape [B, embed_dim]."""
        tokens = self.patch_embed(x)                           # [B, N, D]
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)   # [B, 1, D]
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.blocks(self.pos_drop(tokens))            # [B, N+1, D]
        return self.norm(tokens)[:, 0]

    def forward(self, x):
        return self.head(self.forward_features(x))

# Training and evaluation functions
def train_one_epoch(model, dataloader, criterion, optimizer, scheduler=None, target_device=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to ``train()`` mode.
        dataloader: Iterable of ``(images, labels)`` batches. It does not need
            to implement ``__len__``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler, stepped once at the end of the epoch.
        target_device: Device batches are moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    dev = device if target_device is None else target_device
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for images, labels in progress_bar:
        # non_blocking lets host->GPU copies overlap compute when the loader
        # pins memory; it is a no-op for CPU targets.
        images = images.to(dev, non_blocking=True)
        labels = labels.to(dev, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Update statistics
        num_batches += 1
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix({
            'loss': running_loss / num_batches,
            'acc': 100. * correct / total
        })

    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")

    # Epoch-level learning-rate schedule.
    if scheduler is not None:
        scheduler.step()

    return running_loss / num_batches, 100. * correct / total

def evaluate(model, dataloader, criterion, target_device=None):
    """Compute mean batch loss and accuracy of ``model`` on ``dataloader``.

    Runs in ``eval()`` mode under ``torch.no_grad()``.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches. It does not need
            to implement ``__len__``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        target_device: Device batches are moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    dev = device if target_device is None else target_device
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(dev, non_blocking=True)
            labels = labels.to(dev, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            num_batches += 1
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if num_batches == 0:
        raise ValueError("evaluate: dataloader yielded no batches")

    return running_loss / num_batches, 100. * correct / total

# Main function to run the training
def main():
    """Train a Vision Transformer on TinyImageNet and keep the best checkpoint.

    Downloads and extracts the dataset if ``tiny-imagenet-200`` is not in the
    working directory. Trains on a seeded random subset of the training split,
    evaluates on the first 2,000 validation images every epoch, and saves the
    epoch with the highest *validation* accuracy to ``best_model.pt``.
    """
    # Data parameters
    IMG_SIZE = 128
    BATCH_SIZE = 64
    NUM_WORKERS = 4

    # Model parameters
    PATCH_SIZE = 16
    EMBED_DIM = 768
    DEPTH = 12
    NUM_HEADS = 12
    MLP_RATIO = 4.0
    DROPOUT = 0.1
    ATTN_DROPOUT = 0.0
    EMBED_DROPOUT = 0.1

    # Training parameters
    EPOCHS = 30
    LEARNING_RATE = 2e-4
    WEIGHT_DECAY = 0.05
    SUBSET_SEED = 42

    # Data augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation transforms - only resize and normalize
    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Check if TinyImageNet is already downloaded, otherwise download it
    if os.path.exists("tiny-imagenet-200"):
        print("Using existing TinyImageNet dataset")
        image_path = Path("tiny-imagenet-200")
    else:
        print("Downloading TinyImageNet dataset")
        image_path = download_data(
            source="https://cs231n.stanford.edu/tiny-imagenet-200.zip",
            destination="."
        )

    train_dir = image_path / "train"
    val_dir = image_path / "val"
    val_img_dir = val_dir / "images"
    val_annotations_file = val_dir / "val_annotations.txt"

    # Create datasets
    print("Creating datasets...")
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    class_to_idx = train_dataset.class_to_idx

    val_dataset = CustomImageDataset(
        img_dir=val_img_dir,
        annotation_file=val_annotations_file,
        class_to_idx=class_to_idx,
        transform=val_transform
    )

    # For faster training, use a subset of the data.
    # Comment these lines if you want to use the full dataset.
    # ImageFolder lists samples class by class (classes in sorted order), so
    # taking the first N indices would silently drop the last classes from
    # training while the validation set still contains them. Draw a seeded
    # random subset instead so every class is represented.
    subset_size = 90000
    subset_generator = torch.Generator().manual_seed(SUBSET_SEED)
    train_indices = torch.randperm(len(train_dataset), generator=subset_generator)[:subset_size].tolist()
    val_indices = list(range(min(len(val_dataset), 2000)))

    train_subset = Subset(train_dataset, train_indices)
    val_subset = Subset(val_dataset, val_indices)

    print(f"Train dataset size: {len(train_subset)}")
    print(f"Validation dataset size: {len(val_subset)}")

    # Create data loaders
    train_loader = DataLoader(
        train_subset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    # Create model
    print("Initializing model...")
    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=3,
        num_classes=len(class_to_idx),
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        mlp_ratio=MLP_RATIO,
        dropout=DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        embed_dropout=EMBED_DROPOUT
    )

    # Move model to device
    model = model.to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    # Learning rate scheduler (stepped once per epoch inside train_one_epoch)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # Train model
    print("Starting training...")
    best_acc = 0.0

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print("-" * 20)

        # Train
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler
        )

        # Evaluate
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save best model, selected on held-out accuracy. Training accuracy
        # keeps rising as the model overfits, so it cannot pick the epoch
        # that generalises best.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_acc': train_acc,
                'val_acc': val_acc,
            }, "best_model.pt")
            print(f"Saved model with validation accuracy: {val_acc:.2f}%")

    print(f"Training completed. Best validation accuracy: {best_acc:.2f}%")

# Entry point: run the download -> train -> checkpoint pipeline in main().
if __name__ == "__main__":
    main()

# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import requests
import zipfile
from tqdm import tqdm
import random
import time

# Set random seeds
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every generator (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Deterministic cuDNN kernels and no autotuner: slower, but repeatable.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seeds()

# Run on the current CUDA device when PyTorch reports one available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Data preparation functions
def download_data(source, destination):
    """Download and extract the dataset archive if it is not already present.

    Args:
        source: URL of the zip archive; its last path component is used as
            the local file name.
        destination: Directory to download into and extract under (created
            if missing).

    Returns:
        Path to the extracted ``tiny-imagenet-200`` directory.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    dest_path = Path(destination)
    dest_path.mkdir(parents=True, exist_ok=True)
    file_name = source.split('/')[-1]
    file_path = dest_path / file_name

    if not file_path.exists():
        print(f"Downloading {file_name}...")
        # Stream into a temporary file and rename only once the transfer has
        # finished. Writing straight to file_path (as before) left a truncated
        # archive behind on any failure, which later runs then treated as a
        # complete download. Streaming also avoids holding the archive in RAM.
        tmp_path = file_path.with_name(file_path.name + ".part")
        try:
            with requests.get(source, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            tmp_path.replace(file_path)
        finally:
            # Only still present if the download or write failed midway.
            if tmp_path.exists():
                tmp_path.unlink()
    else:
        print(f"{file_name} already exists")

    extract_path = dest_path / "tiny-imagenet-200"
    if not extract_path.exists():
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            print(f"Extracting {file_name}...")
            zip_ref.extractall(dest_path)
    else:
        print(f"{extract_path} already exists")

    return extract_path

class CustomImageDataset(Dataset):
    """Labelled image dataset for the TinyImageNet validation split.

    The annotation file is read once at construction. Each non-empty line is
    tab-separated, with the image file name in column 0 and the class name in
    column 1; any further columns are ignored.

    Args:
        img_dir: Directory containing the image files.
        annotation_file: Path to the tab-separated annotation file.
        class_to_idx: Mapping from class name to integer label.
        transform: Optional callable applied to each PIL image.
    """
    def __init__(self, img_dir, annotation_file, class_to_idx, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = self._load_annotations(annotation_file)
        self.class_to_idx = class_to_idx
        # Freeze the index -> file-name order once. Rebuilding
        # list(self.annotations.keys()) inside __getitem__ made every lookup
        # O(n), i.e. O(n^2) for one pass over the dataset.
        self._img_names = list(self.annotations)

    def _load_annotations(self, annotation_file):
        """Return an insertion-ordered ``{image_name: class_name}`` dict."""
        annotations = {}
        with open(annotation_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing with IndexError on parts[1].
                    continue
                parts = line.split('\t')
                annotations[parts[0]] = parts[1]
        return annotations

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for the ``idx``-th annotated image."""
        img_name = self._img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        label_idx = self.class_to_idx[self.annotations[img_name]]

        if self.transform:
            image = self.transform(image)

        return image, label_idx

# OPTIMIZATION 1: Improved PatchEmbedding with layer normalization
class PatchEmbedding(nn.Module):
    """Patchify with a strided Conv2d, then LayerNorm every patch embedding.

    Kernel size and stride both equal ``patch_size``, so the convolution is
    one shared linear map applied to each non-overlapping patch.
    """
    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Per-token normalisation of the projected patches.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        tokens = torch.flatten(self.proj(x), start_dim=2)  # [B, embed_dim, N]
        tokens = tokens.permute(0, 2, 1)                   # [B, N, embed_dim]
        return self.norm(tokens)

# OPTIMIZATION 2: MultiHeadAttention with a configurable extra temperature on
# the attention logits (scale_factor; 0.8 by default, 1.0 is standard ViT).
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an adjustable logit scale.

    Args:
        dim: Embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for both the attention weights and the
            output projection.
        qkv_bias: Whether the fused Q/K/V projection has a bias.
        scale_factor: Multiplier on top of the usual ``head_dim ** -0.5``
            scaling. The default 0.8 reproduces the previously hard-coded
            value: it shrinks the logits and so flattens the softmax relative
            to standard scaled dot-product attention. Pass 1.0 for the
            standard formulation.
    """
    def __init__(self, dim, num_heads, dropout=0.1, qkv_bias=True, scale_factor=0.8):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = (self.head_dim ** -0.5) * scale_factor

        # Combined projection for Q, K, V
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, num_tokens, dim = x.shape

        qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Calculate attention scores: [B, heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # No attention mask is applied: every token attends to every token.
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(batch_size, num_tokens, dim)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

# OPTIMIZATION 3: Feed-forward block (GELU activation; dropout defaults to 0.1)
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Only GELU is used as the non-linearity. Each token is expanded from
    ``dim`` to ``hidden_dim`` and projected back.
    """
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout1(self.act(self.fc1(x)))
        return self.dropout2(self.fc2(hidden))

# OPTIMIZATION 4: Improved TransformerBlock with StochDepth (drop path)
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer with stochastic depth on both residual branches."""
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1,
                 drop_path=0.0, qkv_bias=True):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, attn_dropout, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim, hidden_dim, dropout)

        # One drop-path module is shared by both branches; a zero rate makes
        # it a plain identity.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

# ADDED: DropPath (Stochastic Depth) implementation
class DropPath(nn.Module):
    """Stochastic depth: randomly drop an entire residual branch per sample.

    In training mode each sample in the batch keeps its branch output with
    probability ``1 - drop_prob``, rescaled by ``1 / (1 - drop_prob)`` so the
    expected value is unchanged, and otherwise gets zeros. In eval mode, or
    when ``drop_prob == 0``, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping the branch for a given sample.
    """
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # drop_prob >= 1 drops every sample. The general path below would
            # divide by keep_prob == 0 and yield NaN (0 * inf), so short-circuit.
            return torch.zeros_like(x)

        # One Bernoulli(keep_prob) draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"

# OPTIMIZATION 5: Enhanced Vision Transformer with deeper architecture
class VisionTransformer(nn.Module):
    """Compact ViT classifier with stochastic depth and a two-stage head.

    Pipeline: patch embedding -> prepend a learnable class token -> add
    learnable position embeddings -> ``depth`` Transformer blocks whose
    drop-path rate rises linearly from 0 to ``drop_path_rate`` -> final
    LayerNorm -> Linear/BatchNorm/GELU/Dropout pre-head -> Linear classifier
    on the class-token output.
    """
    def __init__(
        self,
        img_size=128,
        patch_size=16,
        in_channels=3,
        num_classes=200,
        embed_dim=512,  # reduced from 768 to limit model size
        depth=8,        # reduced from 12 for efficiency
        num_heads=8,    # reduced from 12
        mlp_ratio=3.0,  # reduced from 4.0
        qkv_bias=True,
        dropout=0.1,
        attn_dropout=0.1,
        embed_dropout=0.1,
        drop_path_rate=0.1  # drop-path rate of the last block (for depth > 1)
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        seq_len = self.num_patches + 1  # patch tokens plus the class token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(embed_dropout)

        # Linearly increasing stochastic-depth schedule across the blocks.
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()

        self.blocks = nn.ModuleList(
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout,
                drop_path=rate,
                qkv_bias=qkv_bias
            )
            for rate in drop_path_rates
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        head_dim = embed_dim // 2
        # NOTE(review): BatchNorm1d in train mode rejects a batch of size 1,
        # so a trailing single-sample training batch would raise; consider
        # drop_last=True on the training DataLoader.
        self.pre_head = nn.Sequential(
            nn.Linear(embed_dim, head_dim),
            nn.BatchNorm1d(head_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        self.head = nn.Linear(head_dim, num_classes)

        # Re-initialise Linear / LayerNorm layers, then the two token
        # parameters (position embeddings first, then the class token).
        self.apply(self._init_weights)
        for token_param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(token_param, std=0.02)

    def _init_weights(self, m):
        """Truncated-normal(0.02) Linear weights with zero bias; identity LayerNorm.

        Other module types (the patch Conv2d, BatchNorm1d) keep PyTorch's
        default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Return the normalised class-token embedding, shape [B, embed_dim]."""
        tokens = self.patch_embed(x)                           # [B, N, D]
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)   # [B, 1, D]
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens)[:, 0]

    def forward(self, x):
        features = self.forward_features(x)
        return self.head(self.pre_head(features))

# OPTIMIZATION 7: Enhanced training function with gradient clipping
def train_one_epoch(model, dataloader, criterion, optimizer, scheduler=None, clip_grad_norm=1.0,
                    target_device=None):
    """Run one optimisation pass over ``dataloader`` with gradient-norm clipping.

    Args:
        model: Network to train; switched to ``train()`` mode.
        dataloader: Iterable of ``(images, labels)`` batches. It does not need
            to implement ``__len__``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler, stepped once at the end of the epoch.
        clip_grad_norm: Max total gradient norm. ``None`` or a value <= 0
            disables clipping.
        target_device: Device batches are moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    dev = device if target_device is None else target_device
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for images, labels in progress_bar:
        # non_blocking lets host->GPU copies overlap compute when the loader
        # pins memory; it is a no-op for CPU targets.
        images = images.to(dev, non_blocking=True)
        labels = labels.to(dev, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()

        # Apply gradient clipping to prevent exploding gradients
        if clip_grad_norm is not None and clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

        optimizer.step()

        # Update statistics
        num_batches += 1
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix({
            'loss': running_loss / num_batches,
            'acc': 100. * correct / total
        })

    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")

    # Epoch-level learning-rate schedule.
    if scheduler is not None:
        scheduler.step()

    return running_loss / num_batches, 100. * correct / total

def evaluate(model, dataloader, criterion, target_device=None):
    """Compute mean batch loss and accuracy of ``model`` on ``dataloader``.

    Runs in ``eval()`` mode under ``torch.no_grad()``.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches. It does not need
            to implement ``__len__``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        target_device: Device batches are moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    dev = device if target_device is None else target_device
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(dev, non_blocking=True)
            labels = labels.to(dev, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            num_batches += 1
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if num_batches == 0:
        raise ValueError("evaluate: dataloader yielded no batches")

    return running_loss / num_batches, 100. * correct / total

# Main function to run the training
def main():
    """Train a Vision Transformer on TinyImageNet on a single device.

    Downloads and extracts the dataset when it is missing. Trains on a seeded
    random subset of the training split with AdamW, OneCycleLR and early
    stopping. Saves the checkpoint with the best validation accuracy to
    ``best_model.pt`` and finally re-evaluates that checkpoint.
    """
    # Data parameters
    IMG_SIZE = 128
    BATCH_SIZE = 128  # Increased for faster training
    NUM_WORKERS = 4
    SUBSET_SIZE = 90000      # training images to use
    VAL_SUBSET_SIZE = 5000   # validation images to use
    SUBSET_SEED = 42         # seed for drawing the training subset

    # Model parameters chosen for a performance/efficiency balance
    PATCH_SIZE = 16
    EMBED_DIM = 512  # Reduced from 768
    DEPTH = 8        # Reduced from 12
    NUM_HEADS = 8    # Reduced from 12
    MLP_RATIO = 3.0  # Reduced from 4.0
    DROPOUT = 0.1
    ATTN_DROPOUT = 0.1
    EMBED_DROPOUT = 0.1

    # Training parameters
    EPOCHS = 30
    LEARNING_RATE = 5e-4  # Increased for faster convergence
    WEIGHT_DECAY = 0.1    # Increased for better regularization
    EARLY_STOP_PATIENCE = 5
    BEST_MODEL_PATH = "best_model.pt"

    # Data augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),  # More aggressive crop
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  # Add rotation for robustness
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)  # Add random erasing for robustness
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Dataset preparation
    if os.path.exists("tiny-imagenet-200"):
        print("Using existing TinyImageNet dataset")
        image_path = Path("tiny-imagenet-200")
    else:
        print("Downloading TinyImageNet dataset")
        image_path = download_data(
            source="https://cs231n.stanford.edu/tiny-imagenet-200.zip",
            destination="."
        )

    train_dir = image_path / "train"
    val_dir = image_path / "val"
    val_img_dir = val_dir / "images"
    val_annotations_file = val_dir / "val_annotations.txt"

    # Create datasets
    print("Creating datasets...")
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    class_to_idx = train_dataset.class_to_idx

    val_dataset = CustomImageDataset(
        img_dir=val_img_dir,
        annotation_file=val_annotations_file,
        class_to_idx=class_to_idx,
        transform=val_transform
    )

    # ImageFolder enumerates samples class by class. Taking the first
    # SUBSET_SIZE indices would therefore cover only the leading classes, while
    # validation contains all of them. Draw a seeded random subset instead so
    # every class is represented.
    num_train = min(len(train_dataset), SUBSET_SIZE)
    train_indices = sorted(
        random.Random(SUBSET_SEED).sample(range(len(train_dataset)), num_train)
    )
    val_indices = list(range(min(len(val_dataset), VAL_SUBSET_SIZE)))

    train_subset = Subset(train_dataset, train_indices)
    val_subset = Subset(val_dataset, val_indices)

    print(f"Train dataset size: {len(train_subset)}")
    print(f"Validation dataset size: {len(val_subset)}")

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_subset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=pin_memory
    )

    # Create model
    print("Initializing model...")
    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=3,
        num_classes=len(class_to_idx),
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        mlp_ratio=MLP_RATIO,
        dropout=DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        embed_dropout=EMBED_DROPOUT
    )

    model = model.to(device)
    print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")

    # Label smoothing for better generalization
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Exclude 1-D parameters (biases and LayerNorm/BatchNorm affine weights)
    # from weight decay. Matching names against 'LayerNorm.weight' never works
    # for this model: its norm layers are registered under attribute names
    # such as 'norm1' or 'pre_head.1', not under their class names.
    decay_params, no_decay_params = [], []
    for param in model.parameters():
        if param.requires_grad:
            (no_decay_params if param.ndim <= 1 else decay_params).append(param)

    optimizer = optim.AdamW(
        [
            {'params': decay_params, 'weight_decay': WEIGHT_DECAY},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ],
        lr=LEARNING_RATE,
    )

    # OneCycleLR's schedule is measured in optimizer steps
    # (EPOCHS * batches per epoch), i.e. it is meant to be stepped per batch.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        total_steps=EPOCHS * len(train_loader),
        pct_start=0.2,  # Warm up for 20% of training
        div_factor=25,
        final_div_factor=1000
    )

    # Train model with early stopping
    print("Starting training...")
    best_acc = 0.0
    best_saved = False  # whether this run has written BEST_MODEL_PATH
    early_stop_counter = 0

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print("-" * 20)

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, clip_grad_norm=1.0
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save best model based on validation accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            early_stop_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_acc': train_acc,
                'val_acc': val_acc,
            }, BEST_MODEL_PATH)
            best_saved = True
            print(f"Saved model with val accuracy: {val_acc:.2f}%")
        else:
            early_stop_counter += 1

        if early_stop_counter >= EARLY_STOP_PATIENCE:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    print(f"Training completed. Best validation accuracy: {best_acc:.2f}%")

    # Reload the best checkpoint written by *this* run. The previous code
    # loaded an unrelated file from another location.
    if best_saved:
        checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("No checkpoint was saved; evaluating the final weights instead")

    _, final_val_acc = evaluate(model, val_loader, criterion)
    print(f"Final model validation accuracy: {final_val_acc:.2f}%")

# Entry point: run the single-device training pipeline defined by main() above.
if __name__ == "__main__":
    main()

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import requests
import zipfile
from tqdm import tqdm
import random
import time

# Import PyTorch/XLA libraries for TPU support
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Model classes and helper functions (all unchanged from your code)
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a residual branch for whole samples.

    In training mode each sample keeps its branch output with probability
    ``1 - drop_prob``. Kept samples are rescaled by ``1 / (1 - drop_prob)``, so
    the expected output equals the input. In eval mode, or when ``drop_prob``
    is 0, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch; must lie in [0, 1].

    Raises:
        ValueError: If ``drop_prob`` is outside [0, 1].
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # One random draw per sample (dim 0), broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob == 0:
            # drop_prob == 1 drops every sample. Dividing by zero would turn
            # the expected zeros into NaN.
            return x * random_tensor
        return x.div(keep_prob) * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every patch to ``embed_dim`` channels. The spatial grid is then flattened
    into a token sequence and layer-normalised.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        embed_dim: Dimension of each output token.
    """

    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, tokens, embed_dim)`` patch tokens."""
        # (B, E, H/p, W/p) -> (B, E, tokens) -> (B, tokens, E)
        tokens = self.proj(x).flatten(start_dim=2).transpose(1, 2)
        return self.norm(tokens)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied both to the attention weights and
            to the output projection.
        qkv_bias: Whether the fused query/key/value projection has a bias.
    """

    def __init__(self, dim, num_heads, dropout=0.1, qkv_bias=True):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) scaling times an extra 0.8, which shrinks the logits
        # and flattens the softmax.
        # NOTE(review): the 0.8 factor deviates from standard attention; confirm it is intended.
        self.scale = (self.head_dim ** -0.5) * 0.8

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attend over ``x`` of shape ``(B, N, dim)``; returns the same shape."""
        batch_size, num_tokens, dim = x.shape

        # (B, N, 3*dim) -> (3, B, heads, N, head_dim)
        projected = (
            self.qkv(x)
            .reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        query, key, value = projected.unbind(0)

        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))

        # (B, heads, N, head_dim) -> (B, N, heads*head_dim)
        context = (weights @ value).transpose(1, 2).reshape(batch_size, num_tokens, dim)
        return self.proj_drop(self.proj(context))

class MLP(nn.Module):
    """Transformer feed-forward sub-layer.

    Computes ``Dropout(fc2(Dropout(GELU(fc1(x)))))``, expanding the last
    dimension from ``dim`` to ``hidden_dim`` and back.

    Args:
        dim: Input/output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after fc2.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.dropout1, self.fc2, self.dropout2):
            x = layer(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block with optional stochastic depth.

    Computes ``x + DropPath(Attn(LN(x)))`` followed by
    ``x + DropPath(MLP(LN(x)))``. DropPath is replaced by an identity when
    ``drop_path`` is not positive.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP relative to ``dim``.
        dropout: Dropout probability inside the MLP.
        attn_dropout: Dropout probability inside the attention layer.
        drop_path: Stochastic-depth drop probability for both residual branches.
        qkv_bias: Whether the attention QKV projection has a bias.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1,
                 drop_path=0.0, qkv_bias=True):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, attn_dropout, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

        uses_stochastic_depth = drop_path > 0.
        self.drop_path = DropPath(drop_path) if uses_stochastic_depth else nn.Identity()

    def forward(self, x):
        # Two residual sub-layers, each normalised before it is applied.
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.mlp)):
            x = x + self.drop_path(sublayer(norm(x)))
        return x

class VisionTransformer(nn.Module):
    """ViT image classifier with a [CLS] token and an MLP pre-head.

    The pipeline is:
    - patch embedding;
    - prepend a learnable [CLS] token;
    - add learnable positional embeddings;
    - ``depth`` Transformer blocks, with the stochastic-depth rate rising
      linearly from 0 to ``drop_path_rate``;
    - final LayerNorm;
    - take the [CLS] feature;
    - Linear/BatchNorm1d/GELU/Dropout pre-head;
    - linear classifier over ``num_classes``.
    """

    def __init__(
        self,
        img_size=128,
        patch_size=16,
        in_channels=3,
        num_classes=200,
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=3.0,
        qkv_bias=True,
        dropout=0.1,
        attn_dropout=0.1,
        embed_dropout=0.1,
        drop_path_rate=0.1
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # One extra position for the [CLS] token prepended in forward_features.
        token_count = self.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(embed_dropout)

        # Per-block stochastic-depth rates: 0 for the first block, drop_path_rate for the last.
        drop_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout,
                drop_path=rate,
                qkv_bias=qkv_bias
            )
            for rate in drop_rates
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        head_hidden = embed_dim // 2
        self.pre_head = nn.Sequential(
            nn.Linear(embed_dim, head_hidden),
            nn.BatchNorm1d(head_hidden),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        self.head = nn.Linear(head_hidden, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for the embeddings, then per-module init."""
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)
        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        """``nn.Module.apply`` hook: init Linear and LayerNorm modules; skip others."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Return the final-normalised [CLS] token feature for each image."""
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens)[:, 0]

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.head(self.pre_head(self.forward_features(x)))

# Data preparation functions
def download_data(source, destination):
    """Download the TinyImageNet zip from *source* into *destination* and extract it.

    Each step is skipped when its output already exists. If the extracted
    ``tiny-imagenet-200`` directory is present, nothing is downloaded at all.
    The archive is streamed to a ``.part`` file and renamed to its final name
    only after the HTTP transfer succeeded. A failed or interrupted download
    therefore never leaves a truncated or error-page "zip" behind for later
    runs to mistake for a finished one.

    Args:
        source: URL of the zip archive. Its last ``/``-separated component is
            used as the local file name.
        destination: Directory to download and extract into; created if needed.

    Returns:
        Path to the extracted ``tiny-imagenet-200`` directory.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    dest_path = Path(destination)
    dest_path.mkdir(parents=True, exist_ok=True)
    file_name = source.split('/')[-1]
    file_path = dest_path / file_name
    extract_path = dest_path / "tiny-imagenet-200"

    if extract_path.exists():
        print(f"{extract_path} already exists")
        return extract_path

    if file_path.exists():
        print(f"{file_name} already exists")
    else:
        print(f"Downloading {file_name}...")
        partial_path = file_path.with_name(file_name + ".part")
        try:
            with requests.get(source, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
        except BaseException:
            # Do not leave a half-written archive behind.
            partial_path.unlink(missing_ok=True)
            raise
        partial_path.replace(file_path)

    print(f"Extracting {file_name}...")
    with zipfile.ZipFile(file_path, "r") as zip_ref:
        zip_ref.extractall(dest_path)

    return extract_path

class CustomImageDataset(Dataset):
    """TinyImageNet validation split, labelled from ``val_annotations.txt``.

    Args:
        img_dir: Directory containing the validation images.
        annotation_file: Tab-separated file whose first two columns are the
            image file name and its class id.
        class_to_idx: Mapping from class id to integer label. It is passed in
            by the caller (the training ``ImageFolder``'s mapping in this file)
            so labels agree across splits.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, img_dir, annotation_file, class_to_idx, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = self._load_annotations(annotation_file)
        self.class_to_idx = class_to_idx
        # Materialise the (file name, class id) pairs once. Rebuilding
        # list(self.annotations.keys()) inside __getitem__ made every lookup
        # O(n) and a full pass over the dataset O(n^2).
        self.samples = list(self.annotations.items())

    def _load_annotations(self, annotation_file):
        """Parse *annotation_file* into an insertion-ordered ``{file name: class id}`` dict.

        Blank lines are skipped.

        Raises:
            ValueError: If a non-blank line has fewer than two tab-separated fields.
        """
        annotations = {}
        with open(annotation_file, 'r') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue  # e.g. a trailing empty line
                parts = stripped.split('\t')
                if len(parts) < 2:
                    raise ValueError(f"Malformed annotation line in {annotation_file}: {line!r}")
                annotations[parts[0]] = parts[1]
        return annotations

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for the *idx*-th annotated image."""
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # Close the file handle immediately; convert() returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label_idx = self.class_to_idx[label]

        if self.transform:
            image = self.transform(image)

        return image, label_idx

# TPU training function (moved entirely inside _mp_fn)
def _mp_fn(index):
    """Per-process entry point launched by ``xmp.spawn`` for TPU training.

    Every process builds the same datasets, model and optimizer. The training
    ``DistributedSampler`` shards data across processes, and
    ``xm.optimizer_step`` all-reduces gradients so the replicas stay in sync.
    Only the master ordinal downloads the dataset and prints. Every ordinal
    takes part in each collective (``mesh_reduce``, ``rendezvous``,
    ``xm.save``), so all processes must follow the same control flow; any
    divergence would leave the others blocked.

    Args:
        index: Process index supplied by ``xmp.spawn`` (only logged).
    """
    # IMPORTANT: Only initialize the device after xmp.spawn is called
    device = xm.xla_device()

    # Identical seeds on every replica, so the freshly initialised weights
    # match across processes before the optional checkpoint overwrites them.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    xm.master_print(f"Using device: {device}")
    xm.master_print(f"TPU process {index} initialized")

    # Parameters
    IMG_SIZE = 128
    PATCH_SIZE = 16
    BATCH_SIZE = 64  # per-process batch size
    NUM_WORKERS = 2  # Reduced to avoid warnings
    LOG_EVERY = 50   # log training progress every N batches (0 disables)

    EMBED_DIM = 512
    DEPTH = 8
    NUM_HEADS = 8
    MLP_RATIO = 3.0
    DROPOUT = 0.1
    ATTN_DROPOUT = 0.1
    EMBED_DROPOUT = 0.1

    # Continue training from checkpoint. START_EPOCH is the 0-based index of
    # the first epoch run here.
    # NOTE(review): an earlier note said 29 epochs were already completed,
    # which would imply START_EPOCH = 29; confirm against checkpoint['epoch'].
    START_EPOCH = 28
    EPOCHS = 100
    LEARNING_RATE = 1e-4  # Reduced learning rate for fine-tuning
    WEIGHT_DECAY = 0.1
    CHECKPOINT_PATH = '/kaggle/input/best-model/best_model_v2.pt'

    # Transforms and dataset paths are needed by *every* process. Defining
    # them only on the master raised NameError on all other ordinals.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image_path = Path("tiny-imagenet-200")
    if xm.is_master_ordinal():
        if image_path.exists():
            xm.master_print("Using existing TinyImageNet dataset")
        else:
            xm.master_print("Downloading TinyImageNet dataset")
            image_path = download_data(
                source="https://cs231n.stanford.edu/tiny-imagenet-200.zip",
                destination="."
            )
    # Other processes wait here until the master has the dataset on disk.
    xm.rendezvous('dataset_prepared')

    train_dir = image_path / "train"
    val_dir = image_path / "val"

    xm.master_print("Creating datasets...")
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    class_to_idx = train_dataset.class_to_idx

    val_dataset = CustomImageDataset(
        img_dir=val_dir / "images",
        annotation_file=val_dir / "val_annotations.txt",
        class_to_idx=class_to_idx,
        transform=val_transform
    )

    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    xm.master_print(f"Train dataset size: {len(train_dataset)}")
    xm.master_print(f"Validation dataset size: {len(val_dataset)}")
    xm.master_print(f"TPU world size: {world_size}")

    # Each process sees a disjoint shard of the data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        num_workers=NUM_WORKERS,
        drop_last=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        sampler=valid_sampler,
        num_workers=NUM_WORKERS
    )

    # MpDeviceLoader moves batches to the device and marks a step per batch.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    val_device_loader = pl.MpDeviceLoader(val_loader, device)

    xm.master_print("Initializing model...")
    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=3,
        num_classes=len(class_to_idx),
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        mlp_ratio=MLP_RATIO,
        dropout=DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        embed_dropout=EMBED_DROPOUT
    )
    model = model.to(device)

    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    xm.master_print(f"Model has {param_count:,} parameters")

    if os.path.exists(CHECKPOINT_PATH):
        xm.master_print("Loading checkpoint with 18% accuracy...")
        # Load on the CPU first, then copy into the device-resident parameters.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        xm.master_print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with {checkpoint['val_acc']:.2f}% accuracy")
        best_acc = checkpoint['val_acc']
    else:
        xm.master_print("No checkpoint found, starting from scratch")
        best_acc = 0.0

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(_xla_param_groups(model, WEIGHT_DECAY), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(EPOCHS - START_EPOCH))

    patience = 7  # Increased patience for fine-tuning
    early_stop_counter = 0

    for epoch in range(START_EPOCH, EPOCHS):
        xm.master_print(f"\nEpoch {epoch+1}/{EPOCHS}")
        xm.master_print("-" * 20)

        # DistributedSampler derives its shuffle order from the epoch number;
        # without set_epoch every epoch would visit samples in the same order.
        train_sampler.set_epoch(epoch)

        train_loss, train_acc = _xla_train_one_epoch(
            model, train_device_loader, criterion, optimizer, scheduler, log_every=LOG_EVERY
        )
        val_loss, val_acc = _xla_evaluate(model, val_device_loader, criterion)

        xm.master_print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        xm.master_print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # val_acc is mesh-reduced, so every process takes the same branch and
        # best_acc / early_stop_counter stay identical on all ordinals. They
        # must all stop at the same epoch.
        if val_acc > best_acc:
            best_acc = val_acc
            early_stop_counter = 0

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_acc': train_acc,
                'val_acc': val_acc,
            }
            # xm.save writes only on the master but contains a rendezvous,
            # so every process must call it.
            xm.save(checkpoint, "best_tpu_model.pt")
            xm.master_print(f"Saved best model with val accuracy: {val_acc:.2f}%")
        else:
            early_stop_counter += 1
            xm.master_print(f"Val accuracy didn't improve. Counter: {early_stop_counter}/{patience}")

        if early_stop_counter >= patience:
            xm.master_print(f"Early stopping triggered after {epoch+1} epochs")
            break

        # Ensure all TPU processes are synchronized before continuing
        xm.rendezvous(f'epoch_{epoch}_complete')

    xm.master_print(f"Training completed. Best validation accuracy: {best_acc:.2f}%")


def _xla_param_groups(model, weight_decay):
    """Split trainable parameters into decayed and non-decayed AdamW groups.

    1-D parameters (biases and LayerNorm/BatchNorm affine weights) get no
    weight decay. Name matching on 'LayerNorm.weight' does not work for this
    model, because its norm layers are registered under attribute names such
    as 'norm1' or 'pre_head.1'.
    """
    decay, no_decay = [], []
    for param in model.parameters():
        if param.requires_grad:
            (no_decay if param.ndim <= 1 else decay).append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]


def _xla_train_one_epoch(model, dataloader, criterion, optimizer, scheduler=None, log_every=50):
    """Train *model* for one epoch on an XLA device.

    Loss and correct counts are accumulated as device tensors and copied to
    the host once per epoch. Calling ``.item()`` on every step forces XLA to
    execute and synchronise the graph each iteration. Progress logging runs
    in ``xm.add_step_closure``, after the step's graph has executed.

    Args:
        model: Network to train.
        dataloader: Device loader yielding ``(images, labels)`` already on the XLA device.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        scheduler: Optional scheduler, stepped once at the end of the epoch.
        log_every: Log every N batches on the master; 0 disables logging.

    Returns:
        ``(mean batch loss averaged over replicas, accuracy % over all replicas)``.
    """
    model.train()
    device = xm.xla_device()
    num_steps = len(dataloader)
    loss_sum = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    num_batches = 0

    def _log_step(step, step_loss, seen_correct, seen_total):
        xm.master_print(
            f'Batch {step}/{num_steps}, Loss: {step_loss.item():.4f}, '
            f'Acc: {100. * seen_correct.item() / seen_total:.2f}%'
        )

    for batch_idx, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # All-reduces gradients across replicas, then steps the optimizer.
        xm.optimizer_step(optimizer)

        loss_sum += loss.detach()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum()
        total += labels.size(0)
        num_batches += 1

        if log_every > 0 and batch_idx % log_every == 0 and xm.is_master_ordinal():
            xm.add_step_closure(
                _log_step, args=(batch_idx, loss.detach(), correct.clone(), total)
            )

    mean_loss = xm.mesh_reduce(
        'train_loss', loss_sum.item() / max(num_batches, 1), lambda vals: sum(vals) / len(vals)
    )
    correct = xm.mesh_reduce('train_correct', correct.item(), sum)
    total = xm.mesh_reduce('train_total', total, sum)

    if scheduler is not None:
        scheduler.step()

    return mean_loss, (100. * correct / total if total else 0.0)


def _xla_evaluate(model, dataloader, criterion):
    """Evaluate *model* on an XLA device.

    Metrics are accumulated on the device and synced once at the end.

    Returns:
        ``(mean batch loss averaged over replicas, accuracy % over all replicas)``.
    """
    model.eval()
    device = xm.xla_device()
    loss_sum = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            loss_sum += criterion(outputs, labels)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum()
            total += labels.size(0)
            num_batches += 1

    mean_loss = xm.mesh_reduce(
        'val_loss', loss_sum.item() / max(num_batches, 1), lambda vals: sum(vals) / len(vals)
    )
    correct = xm.mesh_reduce('val_correct', correct.item(), sum)
    total = xm.mesh_reduce('val_total', total, sum)

    return mean_loss, (100. * correct / total if total else 0.0)

# Entry point
def main():
    """Launch ``_mp_fn`` on every available TPU core via ``xmp.spawn``.

    ``nprocs=None`` lets ``xmp.spawn`` replicate across all available devices
    itself, so the XLA runtime is not queried in this parent process.
    Touching the runtime before spawning can initialise it in the parent.
    The previous pre-spawn ``xm.xrt_world_size()`` probe also relied on a bare
    ``except:`` that hid any real error.
    """
    print("Starting TPU training...")
    xmp.spawn(_mp_fn, args=(), nprocs=None)
    print("TPU training completed")

# Entry point: start multi-process TPU training through main() above.
if __name__ == "__main__":
    main()

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import requests
import zipfile
from tqdm import tqdm
import random
import time

# Import PyTorch/XLA libraries for TPU support
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Model classes and helper functions (all unchanged from your code)
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a residual branch for whole samples.

    In training mode each sample keeps its branch output with probability
    ``1 - drop_prob``. Kept samples are rescaled by ``1 / (1 - drop_prob)``, so
    the expected output equals the input. In eval mode, or when ``drop_prob``
    is 0, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch; must lie in [0, 1].

    Raises:
        ValueError: If ``drop_prob`` is outside [0, 1].
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # One random draw per sample (dim 0), broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob == 0:
            # drop_prob == 1 drops every sample. Dividing by zero would turn
            # the expected zeros into NaN.
            return x * random_tensor
        return x.div(keep_prob) * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every patch to ``embed_dim`` channels. The spatial grid is then flattened
    into a token sequence and layer-normalised.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        embed_dim: Dimension of each output token.
    """

    def __init__(self, img_size=128, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, tokens, embed_dim)`` patch tokens."""
        # (B, E, H/p, W/p) -> (B, E, tokens) -> (B, tokens, E)
        tokens = self.proj(x).flatten(start_dim=2).transpose(1, 2)
        return self.norm(tokens)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied both to the attention weights and
            to the output projection.
        qkv_bias: Whether the fused query/key/value projection has a bias.
    """

    def __init__(self, dim, num_heads, dropout=0.1, qkv_bias=True):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) scaling times an extra 0.8, which shrinks the logits
        # and flattens the softmax.
        # NOTE(review): the 0.8 factor deviates from standard attention; confirm it is intended.
        self.scale = (self.head_dim ** -0.5) * 0.8

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attend over ``x`` of shape ``(B, N, dim)``; returns the same shape."""
        batch_size, num_tokens, dim = x.shape

        # (B, N, 3*dim) -> (3, B, heads, N, head_dim)
        projected = (
            self.qkv(x)
            .reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        query, key, value = projected.unbind(0)

        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))

        # (B, heads, N, head_dim) -> (B, N, heads*head_dim)
        context = (weights @ value).transpose(1, 2).reshape(batch_size, num_tokens, dim)
        return self.proj_drop(self.proj(context))

class MLP(nn.Module):
    """Transformer feed-forward sub-layer.

    Computes ``Dropout(fc2(Dropout(GELU(fc1(x)))))``, expanding the last
    dimension from ``dim`` to ``hidden_dim`` and back.

    Args:
        dim: Input/output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after fc2.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.dropout1, self.fc2, self.dropout2):
            x = layer(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block with optional stochastic depth.

    Computes ``x + DropPath(Attn(LN(x)))`` followed by
    ``x + DropPath(MLP(LN(x)))``. DropPath is replaced by an identity when
    ``drop_path`` is not positive.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP relative to ``dim``.
        dropout: Dropout probability inside the MLP.
        attn_dropout: Dropout probability inside the attention layer.
        drop_path: Stochastic-depth drop probability for both residual branches.
        qkv_bias: Whether the attention QKV projection has a bias.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1,
                 drop_path=0.0, qkv_bias=True):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, attn_dropout, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

        uses_stochastic_depth = drop_path > 0.
        self.drop_path = DropPath(drop_path) if uses_stochastic_depth else nn.Identity()

    def forward(self, x):
        # Two residual sub-layers, each normalised before it is applied.
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.mlp)):
            x = x + self.drop_path(sublayer(norm(x)))
        return x

class VisionTransformer(nn.Module):
    """ViT image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    position embeddings -> ``depth`` transformer blocks -> LayerNorm. The [CLS]
    output is then passed through a Linear -> BatchNorm1d -> GELU -> Dropout(0.2)
    bottleneck of width ``embed_dim // 2`` and a final linear layer that
    produces ``num_classes`` logits.
    """

    def __init__(
        self,
        img_size=128,
        patch_size=16,
        in_channels=3,
        num_classes=200,
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=3.0,
        qkv_bias=True,
        dropout=0.1,
        attn_dropout=0.1,
        embed_dropout=0.1,
        drop_path_rate=0.1
    ):
        super().__init__()

        # NOTE: submodules are created in the same order as before. Default layer
        # init consumes the RNG, so the order determines seeded weights and also
        # the state_dict layout.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        seq_len = self.num_patches + 1  # patch tokens plus the [CLS] token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(embed_dropout)

        # Stochastic-depth rate rises linearly from 0 (first block) to drop_path_rate (last).
        drop_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout,
                drop_path=rate,
                qkv_bias=qkv_bias
            )
            for rate in drop_rates
        ])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        bottleneck = embed_dim // 2
        self.pre_head = nn.Sequential(
            nn.Linear(embed_dim, bottleneck),
            nn.BatchNorm1d(bottleneck),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        self.head = nn.Linear(bottleneck, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for the [CLS]/position embeddings, then every submodule."""
        # Order matters for reproducibility: pos_embed, then cls_token, then modules.
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)
        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        """Linear: trunc-normal(std=0.02) weights, zero bias. LayerNorm: unit weight, zero bias."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward_features(self, x):
        """Return the final LayerNorm-ed [CLS] token (index 0 of the token sequence)."""
        # assumes patch_embed returns (batch, num_patches, embed_dim) — TODO confirm
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)

        for blk in self.blocks:
            tokens = blk(tokens)

        return self.norm(tokens)[:, 0]

    def forward(self, x):
        """Map a batch of images to class logits."""
        return self.head(self.pre_head(self.forward_features(x)))

# Data preparation functions
def download_data(source, destination, timeout=60):
    """Download a zip archive (if not already present) and extract it.

    Args:
        source: URL of the zip archive. Its last path component is used as the
            local file name.
        destination: Directory the archive is saved to and extracted into.
            Created if missing.
        timeout: Seconds ``requests`` waits for the connection / each read
            before giving up (previously the request could hang forever).

    Returns:
        ``Path`` to ``<destination>/tiny-imagenet-200``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    dest_path = Path(destination)
    dest_path.mkdir(parents=True, exist_ok=True)

    file_name = source.split('/')[-1]
    file_path = dest_path / file_name
    if file_path.exists():
        print(f"{file_name} already exists")
    else:
        print(f"Downloading {file_name}...")
        # Stream into a side file and rename it only after the download has fully
        # succeeded, so a failed or interrupted download never leaves a truncated
        # archive that later runs would mistake for a complete one. Streaming also
        # avoids holding the whole archive in memory.
        tmp_path = file_path.with_name(file_name + ".part")
        try:
            with requests.get(source, stream=True, timeout=timeout) as response:
                # Without this, an HTML error page would be saved as the "zip".
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            tmp_path.replace(file_path)
        finally:
            if tmp_path.exists():
                tmp_path.unlink()

    extract_path = dest_path / "tiny-imagenet-200"
    if extract_path.exists():
        print(f"{extract_path} already exists")
    else:
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            print(f"Extracting {file_name}...")
            zip_ref.extractall(dest_path)

    return extract_path

class CustomImageDataset(Dataset):
    """Validation dataset labelled from a tab-separated annotation file.

    Each non-blank line of ``annotation_file`` is tab-separated. The first field
    is the image file name (relative to ``img_dir``) and the second is a class
    name, which is mapped to an integer label through ``class_to_idx``.
    """

    def __init__(self, img_dir, annotation_file, class_to_idx, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = self._load_annotations(annotation_file)
        self.class_to_idx = class_to_idx
        # Index -> image name, built once. Rebuilding list(self.annotations) inside
        # __getitem__ made every lookup O(n) and each epoch O(n^2).
        self._img_names = list(self.annotations)

    def _load_annotations(self, annotation_file):
        """Return an ordered {image_name: class_name} dict parsed from the file.

        Blank lines are skipped. A non-blank line with fewer than two tab-separated
        fields raises ``ValueError`` naming the offending line.
        """
        annotations = {}
        with open(annotation_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                parts = line.split('\t')
                if len(parts) < 2:
                    raise ValueError(
                        f"{annotation_file}:{line_no}: expected '<image>\\t<class>', got {line!r}"
                    )
                annotations[parts[0]] = parts[1]
        return annotations

    def __len__(self):
        return len(self._img_names)

    def __getitem__(self, idx):
        img_name = self._img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns a new, fully loaded image, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label_idx = self.class_to_idx[self.annotations[img_name]]

        if self.transform is not None:
            image = self.transform(image)

        return image, label_idx

# TPU training function
def _mp_fn(index):
    """Per-core training entry point launched by ``xmp.spawn``.

    Every process builds the same transforms, datasets, model, optimizer and
    scheduler. Each process trains on its own shard of the training set and
    evaluates on its own shard of the validation set (``DistributedSampler``).
    Metrics are combined with ``xm.mesh_reduce``, so every process sees
    identical values and therefore takes identical checkpoint / early-stopping
    decisions. This is required because ``xm.mesh_reduce``, ``xm.rendezvous`` and
    ``xm.save`` are collective calls that every process must reach.

    Args:
        index: Process index supplied by ``xmp.spawn`` (used for logging only).
    """
    # Initialize the device
    device = xm.xla_device()

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    is_master = xm.is_master_ordinal()

    # xm.master_print only prints from the master process, so no extra guard is needed.
    xm.master_print(f"Using device: {device}")
    xm.master_print(f"TPU process {index} initialized")

    # Parameters
    IMG_SIZE = 128
    PATCH_SIZE = 16
    BATCH_SIZE = 64
    NUM_WORKERS = 2

    EMBED_DIM = 512
    DEPTH = 8
    NUM_HEADS = 8
    MLP_RATIO = 3.0
    DROPOUT = 0.1
    ATTN_DROPOUT = 0.1
    EMBED_DROPOUT = 0.1

    # Training parameters
    MAX_EPOCHS = 100
    LEARNING_RATE = 8e-5  # Further reduced for fine-tuning
    WEIGHT_DECAY = 0.1

    DATA_URL = "https://cs231n.stanford.edu/tiny-imagenet-200.zip"
    RESUME_CHECKPOINT = '/kaggle/input/tpu-model/best_tpu_model_V2.pt'
    BEST_CHECKPOINT = "best_tpu_model.pt"

    # Transforms are needed by every process, because the datasets are built
    # everywhere. They must not live inside a master-only branch.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Only the master process downloads / extracts the dataset.
    image_path = Path("tiny-imagenet-200")
    if is_master:
        if image_path.exists():
            xm.master_print("Using existing TinyImageNet dataset")
        else:
            xm.master_print("Downloading TinyImageNet dataset")
            image_path = download_data(source=DATA_URL, destination=".")
        xm.master_print("Creating datasets...")

    # Other processes wait here until the master has finished preparing the data.
    xm.rendezvous('dataset_prepared')

    train_dir = image_path / "train"
    val_dir = image_path / "val"
    val_img_dir = val_dir / "images"
    val_annotations_file = val_dir / "val_annotations.txt"

    # Create datasets and data loaders (must be done in all processes)
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    class_to_idx = train_dataset.class_to_idx

    val_dataset = CustomImageDataset(
        img_dir=val_img_dir,
        annotation_file=val_annotations_file,
        class_to_idx=class_to_idx,
        transform=val_transform
    )

    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        num_workers=NUM_WORKERS,
        drop_last=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        sampler=valid_sampler,
        num_workers=NUM_WORKERS
    )

    # Device loaders move batches to the XLA device
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    val_device_loader = pl.MpDeviceLoader(val_loader, device)

    # Initialize model
    xm.master_print("Initializing model...")

    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=3,
        num_classes=len(class_to_idx),
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        mlp_ratio=MLP_RATIO,
        dropout=DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        embed_dropout=EMBED_DROPOUT
    )
    model = model.to(device)

    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    xm.master_print(f"Model has {param_count:,} parameters")

    # Resume from checkpoint if available. `checkpoint` stays None otherwise, so the
    # optimizer-restore step below can test it safely.
    checkpoint = None
    if os.path.exists(RESUME_CHECKPOINT):
        xm.master_print(f"Loading checkpoint from {RESUME_CHECKPOINT}...")
        # Load on CPU first then transfer to TPU
        checkpoint = torch.load(RESUME_CHECKPOINT, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['val_acc']
        xm.master_print(f"Resuming from epoch {start_epoch} with accuracy {best_acc:.2f}%")
    else:
        xm.master_print("No checkpoint found, starting from scratch")
        start_epoch = 0
        best_acc = 0.0

    # Loss function with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Parameter groups for weight decay. Biases and normalization (LayerNorm /
    # BatchNorm1d) weights are exactly the 1-D parameters. Matching names against
    # 'LayerNorm.weight' never fired, because names are module paths such as
    # 'blocks.0.norm1.weight' or 'pre_head.1.weight'.
    all_params = list(model.parameters())
    optimizer_grouped_parameters = [
        {'params': [p for p in all_params if p.ndim > 1], 'weight_decay': WEIGHT_DECAY},
        {'params': [p for p in all_params if p.ndim <= 1], 'weight_decay': 0.0},
    ]
    optimizer = optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)

    # Restore optimizer moments from the checkpoint that was actually loaded above.
    if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
        saved_state = checkpoint['optimizer_state_dict']
        saved_sizes = [len(g['params']) for g in saved_state['param_groups']]
        current_sizes = [len(g['params']) for g in optimizer.param_groups]
        if saved_sizes == current_sizes:
            optimizer.load_state_dict(saved_state)
            # load_state_dict also restores the saved lr / weight_decay (and any
            # 'initial_lr' left by an earlier scheduler). Reapply this run's settings
            # so LEARNING_RATE takes effect and only the Adam moment estimates are reused.
            for group, decay in zip(optimizer.param_groups, (WEIGHT_DECAY, 0.0)):
                group['lr'] = LEARNING_RATE
                group['weight_decay'] = decay
                group.pop('initial_lr', None)
            xm.master_print("Loaded optimizer state from checkpoint")
        else:
            xm.master_print(
                "Checkpoint optimizer state uses a different parameter grouping; "
                "starting with a fresh optimizer state"
            )

    # Cosine schedule over the remaining epochs (at least 1, so T_max is never <= 0).
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, MAX_EPOCHS - start_epoch)
    )

    def reduce_metrics(tag, loss_sum, correct, total, num_batches):
        """Combine per-core metrics; every process receives the same (mean loss, acc %)."""
        # Average the per-core loss sums, then divide by the per-core batch count.
        mean_loss_sum = xm.mesh_reduce(f'{tag}_loss', loss_sum.item(), lambda x: sum(x) / len(x))
        correct = xm.mesh_reduce(f'{tag}_correct', correct.item(), sum)
        total = xm.mesh_reduce(f'{tag}_total', total, sum)
        return mean_loss_sum / max(num_batches, 1), 100. * correct / max(total, 1)

    def train_one_epoch(model, dataloader, criterion, optimizer, scheduler):
        """Train for one epoch; return (mean loss, accuracy %) reduced over all cores."""
        model.train()
        # Metrics stay on the XLA device during the loop. Calling .item() every
        # step forces a device->host transfer per batch and stalls the TPU.
        loss_sum = torch.zeros((), device=device)
        correct = torch.zeros((), dtype=torch.long, device=device)
        total = 0
        num_batches = 0

        # Progress bar on the master process only
        pbar = tqdm(total=len(dataloader), desc="Training", leave=False) if is_master else None

        for images, labels in dataloader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            # TPU-specific optimizer step (all-reduces gradients across cores)
            xm.optimizer_step(optimizer)

            loss_sum += loss.detach()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum()
            total += labels.size(0)
            num_batches += 1

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        metrics = reduce_metrics('train', loss_sum, correct, total, num_batches)

        # Update scheduler after epoch
        if scheduler is not None:
            scheduler.step()

        return metrics

    def evaluate(model, dataloader, criterion):
        """Evaluate without gradients; return (mean loss, accuracy %) reduced over all cores."""
        model.eval()
        loss_sum = torch.zeros((), device=device)
        correct = torch.zeros((), dtype=torch.long, device=device)
        total = 0
        num_batches = 0

        pbar = tqdm(total=len(dataloader), desc="Validating", leave=False) if is_master else None

        with torch.no_grad():
            for images, labels in dataloader:
                outputs = model(images)
                loss = criterion(outputs, labels)

                loss_sum += loss
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum()
                total += labels.size(0)
                num_batches += 1

                if pbar is not None:
                    pbar.update(1)

        if pbar is not None:
            pbar.close()

        return reduce_metrics('val', loss_sum, correct, total, num_batches)

    # Resume training
    patience = 10  # Increased patience for fine-tuning
    early_stop_counter = 0

    # Print training summary header
    xm.master_print("\n" + "="*50)
    xm.master_print(f"Resuming training from epoch {start_epoch+1} to {MAX_EPOCHS}")
    xm.master_print(f"Best accuracy so far: {best_acc:.2f}%")
    xm.master_print("="*50 + "\n")

    for epoch in range(start_epoch, MAX_EPOCHS):
        epoch_start_time = time.time()
        xm.master_print(f"Epoch {epoch+1}/{MAX_EPOCHS}")

        # Reshuffle the distributed shards each epoch
        train_sampler.set_epoch(epoch)

        train_loss, train_acc = train_one_epoch(
            model, train_device_loader, criterion, optimizer, scheduler
        )
        val_loss, val_acc = evaluate(model, val_device_loader, criterion)

        epoch_time = time.time() - epoch_start_time

        current_lr = optimizer.param_groups[0]['lr']
        xm.master_print(f"Epoch {epoch+1}/{MAX_EPOCHS} completed in {epoch_time:.2f}s")
        xm.master_print(f"LR: {current_lr:.6f} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        xm.master_print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # val_acc comes from mesh_reduce, so it is identical on every process and all
        # processes take the same branch. That matters: xm.save() and the rendezvous
        # below are collective, and a process that skips them (or breaks out of the
        # loop early) leaves the others waiting forever.
        if val_acc > best_acc:
            best_acc = val_acc
            early_stop_counter = 0

            state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_acc': train_acc,
                'val_acc': val_acc,
            }
            # Called on every process; by default only the master writes the file.
            xm.save(state, BEST_CHECKPOINT)
            xm.master_print(f"✓ New best model saved with val accuracy: {val_acc:.2f}%")
        else:
            early_stop_counter += 1
            xm.master_print(f"! No improvement. Early stopping counter: {early_stop_counter}/{patience}")

        # Early stopping (same decision on every process)
        if early_stop_counter >= patience:
            xm.master_print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

        # Add a separator between epochs
        xm.master_print("-"*50)

        # Ensure all TPU processes are synchronized before continuing
        xm.rendezvous(f'epoch_{epoch}_complete')

    xm.master_print(f"\nTraining completed. Best validation accuracy: {best_acc:.2f}%")

# Entry point
def main():
    """Launch ``_mp_fn`` on the TPU cores via ``xmp.spawn``."""
    print("Starting TPU training...")
    # NOTE(review): this queries the world size from the parent process, before
    # xmp.spawn has created the per-core processes. Depending on the torch_xla
    # version it may report 1 instead of the real core count. Confirm this;
    # xmp.spawn(_mp_fn, nprocs=None) asks torch_xla to use every available device.
    try:
        num_cores = xm.xrt_world_size()
        print(f"Detected {num_cores} TPU cores")
    except Exception:
        # Catch only ordinary errors (not a bare except), so Ctrl-C / SystemExit
        # still stop the launcher. Fall back to a single TPU core.
        num_cores = 1
        print("Could not detect TPU configuration, using single core")

    # Spawn TPU processes
    xmp.spawn(_mp_fn, nprocs=num_cores)
    print("TPU training completed")

if __name__ == "__main__":
    main()