# In [None]:
"""
Vision Transformer (ViT) Implementation from Scratch - Optimized for T4 GPU

This file contains a complete implementation of Vision Transformer (ViT) for image classification 
using the Food-101 dataset. The implementation is built from scratch using PyTorch and optimized for T4 GPU.

Sections:
1. Imports and Setup
2. Helper Functions
3. Dataset Handling
4. ViT Components
5. Complete ViT Model
6. Training Functions
7. Main Training Script
8. Kaggle/Notebook Execution Functions
"""

# =============================================================================
# 1. IMPORTS AND SETUP
# =============================================================================

import torch
import torchvision
from pathlib import Path
import matplotlib.pyplot as plt
from torch import nn
from torchvision import transforms
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
import os
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from typing import List

# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Describe every visible GPU and switch on backend-level speed-ups
if device != "cuda":
    num_gpus = 0
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")

    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"GPU {i} Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
        print(f"GPU {i} CUDA Capability: {torch.cuda.get_device_capability(i)}")

    # Let cuDNN autotune its kernels (pays off when input sizes stay constant)
    torch.backends.cudnn.benchmark = True
    # Allow TF32 matmul / convolution paths. TF32 needs Ampere-or-newer
    # hardware. NOTE(review): a T4 presumably gains nothing from these two
    # flags -- confirm against the reported CUDA capability above.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Tell the user which training mode the GPU count implies
    if num_gpus <= 1:
        print("\n⚡ Single GPU Training")
    else:
        print(f"\n🚀 Multi-GPU Training Enabled with {num_gpus} GPUs!")
        print("This will significantly speed up training and allow larger batch sizes.")

# =============================================================================
# 2. HELPER FUNCTIONS
# =============================================================================

def plot_predictions(train_data, train_labels, test_data, test_labels, predictions=None):
    """
    Scatter-plots training data (blue) and test data (green) on one figure.

    When `predictions` is given, it is drawn in red against `test_data` so it
    can be compared visually with the true test labels.
    """
    plt.figure(figsize=(10, 7))

    # (x values, y values, colour, legend label) for each series, in draw order
    series = [
        (train_data, train_labels, "b", "Training data"),
        (test_data, test_labels, "g", "Testing data"),
    ]
    if predictions is not None:
        series.append((test_data, predictions, "r", "Predictions"))

    for xs, ys, colour, legend_label in series:
        plt.scatter(xs, ys, c=colour, s=4, label=legend_label)

    plt.legend(prop={"size": 14})


def accuracy_fn(y_true, y_pred):
    """Computes the percentage of predictions that match the truth labels.

    Args:
        y_true (torch.Tensor): Truth labels.
        y_pred (torch.Tensor): Predicted labels to be compared to the truth labels.

    Returns:
        float: Accuracy as a percentage in [0, 100], e.g. 78.45.
    """
    num_correct = torch.eq(y_true, y_pred).sum().item()
    num_samples = len(y_pred)
    return (num_correct / num_samples) * 100


def print_train_time(start, end, device=None):
    """Prints and returns the elapsed time between two timestamps.

    Args:
        start (float): Timestamp taken before the computation.
        end (float): Timestamp taken after the computation.
        device (optional): Label of the device the computation ran on; only
            used in the printed message. Defaults to None.

    Returns:
        float: `end - start`, i.e. the elapsed time in the unit of the inputs.
    """
    elapsed = end - start
    print(f"\nTrain time on {device}: {elapsed:.3f} seconds")
    return elapsed


def plot_loss_curves(results):
    """Draws loss and accuracy curves side by side from a results dictionary.

    Args:
        results (dict): per-epoch metric lists stored under the keys
            "train_loss", "train_acc", "test_loss" and "test_acc".
    """
    # Read every series up front so a missing key fails before a figure is opened
    train_loss = results["train_loss"]
    test_loss = results["test_loss"]
    train_accuracy = results["train_acc"]
    test_accuracy = results["test_acc"]

    # The x axis is the epoch index, sized by the training-loss series
    epochs = range(len(train_loss))

    # (panel title, then one (values, legend label) pair per curve)
    panels = (
        ("Loss", (train_loss, "train_loss"), (test_loss, "test_loss")),
        ("Accuracy", (train_accuracy, "train_accuracy"), (test_accuracy, "test_accuracy")),
    )

    plt.figure(figsize=(15, 7))
    for position, (panel_title, *curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, legend_label in curves:
            plt.plot(epochs, values, label=legend_label)
        plt.title(panel_title)
        plt.xlabel("Epochs")
        plt.legend()


def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: List[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Classifies one image file with `model` and shows it with the prediction as title.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
        image_path (str): filepath of the image to classify.
        class_names (List[str], optional): label names indexed by class id. When
            omitted, the raw predicted-label tensor is shown in the title.
        transform (callable, optional): applied to the float image tensor (values
            scaled to [0, 1]) before it is batched. Defaults to None.
        device (torch.device, optional): device the model and image are moved to.
            Defaults to "cuda" if torch.cuda.is_available() else "cpu".

    Returns:
        None. Draws the (transformed) image on the current Matplotlib axes.

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=None,
                            device=device)
    """
    # Decode the file to a float tensor and rescale pixel values to [0, 1]
    image = torchvision.io.read_image(str(image_path)).type(torch.float32) / 255.0

    if transform:
        image = transform(image)

    model.to(device)
    model.eval()

    with torch.inference_mode():
        # The model expects a leading batch dimension
        image = image.unsqueeze(dim=0)
        logits = model(image.to(device))

    probabilities = torch.softmax(logits, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1)

    # Drop the batch dimension and move channels last for imshow
    plt.imshow(image.squeeze().permute(1, 2, 0))

    shown_label = class_names[predicted_label.cpu()] if class_names else predicted_label
    plt.title(f"Pred: {shown_label} | Prob: {probabilities.max().cpu():.3f}")
    plt.axis(False)


def set_seeds(seed: int = 42):
    """Seeds the PyTorch (CPU and CUDA), NumPy and `random` generators.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers every GPU in a multi-GPU run
        np.random.seed,
        random.seed,
    )
    for seed_generator in seeders:
        seed_generator(seed)


def get_gpu_memory_usage():
    """Returns torch.cuda.memory_allocated() in GB, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024**3

# =============================================================================
# 3. DATASET HANDLING
# =============================================================================

class Food101Dataset(Dataset):
    """Map-style dataset pairing image file paths with integer labels.

    Images are opened lazily in `__getitem__`, converted to RGB and, when a
    transform was supplied, passed through it.
    """

    def __init__(self, image_paths, labels, transform=None):
        # image_paths[i] and labels[i] describe the same sample
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        """Number of samples, taken from the path list."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Returns `(image, label)` for sample `idx`."""
        sample_path, sample_label = self.image_paths[idx], self.labels[idx]

        # Force three channels regardless of the file's own colour mode
        rgb_image = Image.open(sample_path).convert('RGB')

        if self.transform:
            return self.transform(rgb_image), sample_label
        return rgb_image, sample_label


def load_split(data_path, split_file):
    """Parses and returns the JSON file `<data_path>/meta/meta/<split_file>`.

    `data_path` must support the `/` path operator (e.g. a pathlib.Path).
    """
    split_path = data_path / "meta/meta" / split_file
    with split_path.open('r') as handle:
        return json.load(handle)


def _collect_split_samples(split, class_to_idx, images_dir):
    """Flattens a {class_name: [relative_path, ...]} split into parallel path/label lists."""
    image_files = []
    labels = []
    for class_name, relative_paths in split.items():
        class_idx = class_to_idx[class_name]
        for relative_img_path in relative_paths:  # e.g., "poutine/1005364"
            # The relative path already contains the class folder; only the
            # ".jpg" suffix and the "images" directory need to be added.
            image_files.append(images_dir / f"{relative_img_path}.jpg")
            labels.append(class_idx)
    return image_files, labels


def prepare_data(data_path):
    """Builds image-path and label lists for the train and test splits.

    Reads "train.json" and "test.json" through load_split. Class indices follow
    the key order of the train split, and the SAME mapping is applied to the
    test split so that a given class always receives the same label in both.

    Args:
        data_path: dataset root supporting the `/` path operator; images are
            expected under `<data_path>/images/<relative_path>.jpg`.

    Returns:
        tuple: (train_images, train_labels, test_images, test_labels).

    Raises:
        ValueError: if the test split names a class absent from the train split,
            since such samples could not be given a consistent label.
    """
    train_split = load_split(data_path, "train.json")
    test_split = load_split(data_path, "test.json")

    # Built once: avoids re-scanning the key list for every single image
    class_to_idx = {class_name: idx for idx, class_name in enumerate(train_split)}

    unknown_classes = [name for name in test_split if name not in class_to_idx]
    if unknown_classes:
        raise ValueError(
            f"Test split contains classes missing from the train split: {unknown_classes}"
        )

    images_dir = data_path / "images"
    train_images, train_labels = _collect_split_samples(train_split, class_to_idx, images_dir)
    test_images, test_labels = _collect_split_samples(test_split, class_to_idx, images_dir)

    return train_images, train_labels, test_images, test_labels

# =============================================================================
# 4. VIT COMPONENTS
# =============================================================================

class PatchEmbeddings(nn.Module):
    """
    Converts input images into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal `patch_size` projects
    every non-overlapping patch to `embeddings_dimensions` channels; the spatial
    grid is then flattened so the output is (batch, num_patches, embeddings_dimensions).
    """
    def __init__(self,
                 in_channels: int = 3,
                 embeddings_dimensions: int = 512,  # Reduced for T4
                 patch_size: int = 16):
        super().__init__()

        self.patch_size = patch_size
        self.patched_embeddings = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embeddings_dimensions,
                                          stride=patch_size,
                                          padding=0,
                                          kernel_size=patch_size)

        # Merge the two spatial dimensions of the conv output into one patch axis
        self.flatten_embeddings = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        """Embeds a (batch, channels, height, width) image batch.

        Raises:
            AssertionError: if height or width is not a multiple of the patch
                size (the strided convolution would otherwise silently drop
                the remainder rows/columns).
        """
        image_height, image_width = x.shape[-2], x.shape[-1]
        assert image_height % self.patch_size == 0 and image_width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size, "
            f"image shape: {image_height}x{image_width}, patch size: {self.patch_size}"
        )

        x_patched = self.patched_embeddings(x)
        x_flatten = self.flatten_embeddings(x_patched)
        # (batch, embed, patches) -> (batch, patches, embed)
        return x_flatten.permute(0, 2, 1)


class MultiHeadSelfAttentionBlock(nn.Module):
    """
    Pre-norm multi-head self-attention: LayerNorm followed by nn.MultiheadAttention
    with the normalized input used as query, key and value.
    """
    def __init__(self,
                 num_heads: int = 8,  # Reduced for T4
                 embeddings_dimension: int = 512,  # Reduced for T4
                 attn_dropout: float = 0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(embeddings_dimension)
        # batch_first=True: inputs are laid out as (batch, sequence, embedding)
        self.multihead_attn_layer = nn.MultiheadAttention(
            embeddings_dimension,
            num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

    def forward(self, x):
        """Returns the attention output only; attention weights are not computed."""
        normed = self.layer_norm(x)
        attended = self.multihead_attn_layer(
            query=normed, key=normed, value=normed, need_weights=False
        )
        return attended[0]


class MLPBlock(nn.Module):
    """
    Pre-norm feed-forward block: LayerNorm, then Linear -> GELU -> Dropout ->
    Linear -> Dropout, mapping back to the embedding dimension.
    """
    def __init__(self,
                 embeddings_dimension: int = 512,  # Reduced for T4
                 mlp_size: int = 2048,  # Reduced for T4
                 dropout: float = 0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(embeddings_dimension)

        # Expand to the hidden width, apply the non-linearity, project back
        expand = nn.Linear(embeddings_dimension, mlp_size)
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout)
        project = nn.Linear(mlp_size, embeddings_dimension)
        output_dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(expand, activation, hidden_dropout, project, output_dropout)

    def forward(self, x):
        """Normalizes `x` and passes it through the feed-forward stack."""
        return self.mlp(self.layer_norm(x))


class TransformerEncoderBlock(nn.Module):
    """
    One Transformer encoder layer: a self-attention sub-block followed by an
    MLP sub-block, each wrapped in a residual (skip) connection.
    """
    def __init__(self,
                 num_heads: int = 8,  # Reduced for T4
                 embeddings_dimension: int = 512,  # Reduced for T4
                 dropout: float = 0.1,
                 mlp_size: int = 2048,  # Reduced for T4
                 attn_dropout: float = 0.1):
        super().__init__()

        self.msa_layer = MultiHeadSelfAttentionBlock(
            num_heads=num_heads,
            embeddings_dimension=embeddings_dimension,
            attn_dropout=attn_dropout,
        )
        self.mlp_block = MLPBlock(
            embeddings_dimension=embeddings_dimension,
            mlp_size=mlp_size,
            dropout=dropout,
        )

    def forward(self, x):
        """Applies attention and MLP, adding each sub-block's input back to its output."""
        after_attention = self.msa_layer(x) + x
        after_mlp = self.mlp_block(after_attention) + after_attention
        return after_mlp

# =============================================================================
# 5. COMPLETE VIT MODEL
# =============================================================================

class ViT(nn.Module):
    """
    Complete Vision Transformer (ViT) classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add learnable
    positional embeddings -> dropout -> stack of Transformer encoder blocks ->
    LayerNorm + Linear head applied to the class-token position.

    Args:
        num_heads: attention heads per encoder block.
        embeddings_dimension: width of every token embedding.
        dropout: dropout probability inside the MLP sub-blocks.
        mlp_size: hidden width of the MLP sub-blocks.
        attn_dropout: dropout probability inside attention.
        num_of_encoder_layers: number of stacked encoder blocks.
        patch_size: side length of the square patches.
        image_width, img_height: expected input resolution.
        no_channels: number of input image channels.
        classes: number of output logits.
        positional_embedding_dropout: dropout applied after adding positions.

    Raises:
        AssertionError: if the height or the width is not a multiple of `patch_size`.
    """
    def __init__(self,
                 num_heads: int = 8,  # Reduced for T4
                 embeddings_dimension: int = 512,  # Reduced for T4
                 dropout: float = 0.1,
                 mlp_size: int = 2048,  # Reduced for T4
                 attn_dropout: float = 0.1,
                 num_of_encoder_layers: int = 8,  # Reduced for T4
                 patch_size: int = 16,
                 image_width: int = 224,
                 img_height: int = 224,
                 no_channels: int = 3,
                 classes: int = 101,
                 positional_embedding_dropout: float = 0.1):

        # Each side must tile exactly. Checking only the pixel-count product
        # would accept shapes such as 8x32 with 16-pixel patches.
        assert img_height % patch_size == 0 and image_width % patch_size == 0, \
            f"Image dimensions ({img_height}x{image_width}) must each be divisible by patch_size ({patch_size})"

        super().__init__()
        # Patches per column times patches per row
        self.number_of_patches = (img_height // patch_size) * (image_width // patch_size)

        self.patch_embeddings = PatchEmbeddings(
            in_channels=no_channels,
            embeddings_dimensions=embeddings_dimension,
            patch_size=patch_size
        )

        # One position per patch plus one for the class token
        self.positional_embeddings = nn.Parameter(
            torch.randn(1, self.number_of_patches + 1, embeddings_dimension),
            requires_grad=True
        )

        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embeddings_dimension),
            requires_grad=True
        )

        self.encoder_block = nn.Sequential(*[
            TransformerEncoderBlock(
                num_heads=num_heads,
                embeddings_dimension=embeddings_dimension,
                dropout=dropout,
                mlp_size=mlp_size,
                attn_dropout=attn_dropout
            ) for _ in range(num_of_encoder_layers)
        ])

        self.classifier = nn.Sequential(
            nn.LayerNorm(embeddings_dimension),
            nn.Linear(in_features=embeddings_dimension, out_features=classes)
        )

        self.dropout_after_positional_embeddings = nn.Dropout(p=positional_embedding_dropout)

    def forward(self, x):
        """Maps an image batch to class logits of shape (batch, classes)."""
        batch_size = x.shape[0]
        x = self.patch_embeddings(x)
        # Broadcast the single learned class token across the batch (no copy)
        prepend_token = self.cls_token.expand(batch_size, -1, -1)

        x = torch.cat((prepend_token, x), dim=1)
        x = self.positional_embeddings + x
        x = self.dropout_after_positional_embeddings(x)
        x = self.encoder_block(x)
        # Classify from the class-token position only (index 0 of the sequence)
        x = self.classifier(x[:, 0])

        return x

# =============================================================================
# 6. TRAINING FUNCTIONS
# =============================================================================

def get_data_transforms(train: bool = True):
    """
    Returns the image preprocessing pipeline.

    Args:
        train (bool, optional): when True (the default, matching the previous
            behaviour) the pipeline includes random flip, rotation and colour
            jitter. Pass False for evaluation/inference to get only the
            deterministic resize, tensor conversion and normalization.

    Returns:
        transforms.Compose: resize to 224x224, optional augmentation, ToTensor,
        then normalization with the mean/std listed below.
    """
    steps = [transforms.Resize(size=(224, 224))]

    if train:
        # Random augmentation belongs in training only; at evaluation time it
        # would make the reported metrics noisy and non-reproducible.
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        ]

    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def train_step(model, dataloader, loss_fn, optimizer, device, scaler=None, num_gpus=1):
    """
    Trains `model` for one full pass over `dataloader`.

    Args:
        model: module to train (put into train mode here).
        dataloader: iterable of (inputs, targets) batches.
        loss_fn: callable taking (predictions, targets) and returning the loss.
        optimizer: optimizer stepping the model parameters.
        device: target device, given as a string ("cuda", "cuda:0", "cpu", ...)
            or a torch.device.
        scaler: optional gradient scaler; mixed precision is used only when a
            scaler is given AND the device is a CUDA device.
        num_gpus: number of GPUs whose allocated memory is summed for the
            progress-bar readout.

    Returns:
        tuple: (mean batch loss, mean batch accuracy in percent). Both are plain
        averages over batches, so a smaller final batch weighs as much as a
        full one.

    Raises:
        ZeroDivisionError: if `dataloader` has length zero.
    """
    model.train()
    train_loss, train_acc = 0, 0

    # Compare device *types*: a plain `device == "cuda"` is False for
    # torch.device objects and for indexed strings such as "cuda:0", which
    # would silently switch mixed precision off.
    on_cuda = torch.device(device).type == "cuda"
    use_amp = scaler is not None and on_cuda

    train_pbar = tqdm(dataloader, desc="Training", leave=False)

    for X, y in train_pbar:
        # Send data to target device
        X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad()

        if use_amp:
            # Forward pass under autocast, backward pass with gradient scaling
            with torch.amp.autocast(device_type='cuda'):
                y_pred = model(X)
                loss = loss_fn(y_pred, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Regular full-precision training
            y_pred = model(X)
            loss = loss_fn(y_pred, y)

            loss.backward()
            optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss

        # argmax over logits already gives the predicted class; softmax is
        # monotonic and therefore unnecessary here
        y_pred_class = y_pred.argmax(dim=1)
        current_acc = accuracy_fn(y_true=y, y_pred=y_pred_class)
        train_acc += current_acc

        # Update progress bar (with memory usage when running on CUDA)
        if on_cuda:
            memory_usage = sum(torch.cuda.memory_allocated(i) for i in range(num_gpus)) / 1024**3
            train_pbar.set_postfix(
                loss=batch_loss,
                acc=f"{current_acc:.2f}%",
                mem=f"{memory_usage:.1f}GB"
            )
        else:
            train_pbar.set_postfix(loss=batch_loss, acc=f"{current_acc:.2f}%")

    # Calculate average metrics for this epoch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)

    return train_loss, train_acc


def test_step(model, dataloader, loss_fn, device, num_gpus=1):
    """
    Evaluates `model` over `dataloader` without tracking gradients.

    Args:
        model: module to evaluate (put into eval mode here).
        dataloader: iterable of (inputs, targets) batches.
        loss_fn: callable taking (logits, targets) and returning the loss.
        device: target device, given as a string ("cuda", "cuda:0", "cpu", ...)
            or a torch.device. Autocast is used on CUDA devices.
        num_gpus: number of GPUs whose allocated memory is summed for the
            progress-bar readout.

    Returns:
        tuple: (mean batch loss, mean batch accuracy in percent). Both are plain
        averages over batches, so a smaller final batch weighs as much as a
        full one.

    Raises:
        ZeroDivisionError: if `dataloader` has length zero.
    """
    model.eval()
    test_loss, test_acc = 0, 0

    # Compare device *types*: a plain `device == "cuda"` is False for
    # torch.device objects and for indexed strings such as "cuda:0".
    on_cuda = torch.device(device).type == "cuda"

    with torch.inference_mode():
        test_pbar = tqdm(dataloader, desc="Testing", leave=False)

        for X, y in test_pbar:
            # Send data to target device
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

            if on_cuda:
                # Mixed precision inference
                with torch.amp.autocast(device_type='cuda'):
                    test_pred_logits = model(X)
                    loss = loss_fn(test_pred_logits, y)
            else:
                test_pred_logits = model(X)
                loss = loss_fn(test_pred_logits, y)

            batch_loss = loss.item()
            test_loss += batch_loss

            # Calculate and accumulate accuracy
            test_pred_labels = test_pred_logits.argmax(dim=1)
            current_acc = accuracy_fn(y_true=y, y_pred=test_pred_labels)
            test_acc += current_acc

            # Update progress bar (with memory usage when running on CUDA)
            if on_cuda:
                memory_usage = sum(torch.cuda.memory_allocated(i) for i in range(num_gpus)) / 1024**3
                test_pbar.set_postfix(
                    loss=batch_loss,
                    acc=f"{current_acc:.2f}%",
                    mem=f"{memory_usage:.1f}GB"
                )
            else:
                test_pbar.set_postfix(loss=batch_loss, acc=f"{current_acc:.2f}%")

    # Calculate final test metrics
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)

    return test_loss, test_acc

# =============================================================================
# 7. MAIN TRAINING SCRIPT - OPTIMIZED FOR MULTI-GPU T4 SETUP
# =============================================================================

def _locate_food101_dataset():
    """Returns the first candidate dataset directory that exists.

    When none exists, prints the directories found in the current location and
    in /kaggle/input as a debugging aid, then raises FileNotFoundError.
    """
    possible_paths = [
        Path("/kaggle/input/food41"),  # Kaggle input path
        Path("../input/food41"),      # Alternative Kaggle path
        Path("./food41"),              # Local path
        Path("food41"),                # Current directory
        Path("./food-101"),            # Original naming
        Path("food-101")               # Original naming alternative
    ]

    for path in possible_paths:
        if path.exists():
            print(f"Found dataset at: {path}")
            return path

    # Nothing matched: list available directories to help debug
    print("Available directories in current location:")
    for item in Path(".").iterdir():
        if item.is_dir():
            print(f"  - {item}")
    print("\nAvailable directories in /kaggle/input (if exists):")
    kaggle_input = Path("/kaggle/input")
    if kaggle_input.exists():
        for item in kaggle_input.iterdir():
            if item.is_dir():
                print(f"  - {item}")

    raise FileNotFoundError("Could not find food dataset. Please check the dataset path.")


def _save_training_curves(results, num_epochs, num_gpus, checkpoint_dir):
    """Plots per-epoch train loss/accuracy (plus the final test accuracy) and saves the figure."""
    plt.figure(figsize=(12, 5))
    title_suffix = f"(Multi-GPU: {num_gpus} T4s)" if num_gpus > 1 else "(Single T4)"
    epoch_axis = range(1, num_epochs + 1)

    # Plot training loss
    plt.subplot(1, 2, 1)
    plt.plot(epoch_axis, results["train_loss"], label="Train Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training Loss {title_suffix}")
    plt.legend()
    plt.grid(True)

    # Plot training accuracy (and final test accuracy as a horizontal line)
    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, results["train_acc"], label="Train Accuracy")
    if results.get("test_acc"):
        plt.axhline(y=results["test_acc"][0], color='r', linestyle='--',
                    label=f'Final Test Acc: {results["test_acc"][0]:.2f}%')
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title(f"Training vs Test Accuracy {title_suffix}")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(checkpoint_dir / "training_curves.png", dpi=300, bbox_inches='tight')
    plt.show()


def main():
    """
    Runs the full experiment: locate the dataset, build loaders and model,
    train for a fixed number of epochs, evaluate once on the test split, and
    write checkpoints, TensorBoard logs and a training-curve figure.

    Uses nn.DataParallel when more than one CUDA device is visible and mixed
    precision whenever the module-level `device` is "cuda".

    Raises:
        FileNotFoundError: if no dataset directory can be found.
    """
    print(f"Using device: {device}")

    # Get number of GPUs
    num_gpus = torch.cuda.device_count() if device == "cuda" else 0

    data_path = _locate_food101_dataset()

    # Define training parameters optimized for multi-GPU T4 setup
    NUM_EPOCHS = 25  # Can maintain or even increase with better GPU utilization

    # Optimize batch size for multi-GPU setup
    if num_gpus >= 2:
        BATCH_SIZE = 32  # Increased for 2 T4 GPUs (16 per GPU)
        LEARNING_RATE = 2e-4  # Slightly increased for larger effective batch size
        print(f"🚀 Multi-GPU Setup: Using batch size {BATCH_SIZE} ({BATCH_SIZE//num_gpus} per GPU)")
    else:
        BATCH_SIZE = 16  # Single GPU batch size
        LEARNING_RATE = 1e-4
        print(f"⚡ Single GPU Setup: Using batch size {BATCH_SIZE}")

    run_name = "vit_food101_multi_t4_optimized" if num_gpus > 1 else "vit_food101_t4_optimized"

    # On Kaggle, only /kaggle/working is writable, so outputs go there
    on_kaggle = "/kaggle" in str(Path.cwd())
    checkpoint_base_dir = Path("/kaggle/working/checkpoints") if on_kaggle else Path("./checkpoints")
    checkpoint_dir = checkpoint_base_dir / run_name

    # Training uses the augmenting pipeline. The test set gets a deterministic
    # one (resize + tensor + normalize only) so evaluation is reproducible and
    # not distorted by random flips, rotations or colour jitter.
    train_transform = get_data_transforms()
    eval_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Prepare data
    train_images, train_labels, test_images, test_labels = prepare_data(data_path)
    num_classes = len(set(train_labels))

    print(f"Total training images available: {len(train_images)}")
    print(f"Using the full training dataset ({len(train_images)} images) for {NUM_EPOCHS} epochs")
    print(f"Batch size optimized for {num_gpus} GPU(s): {BATCH_SIZE}")

    # Create datasets
    full_train_dataset = Food101Dataset(train_images, train_labels, transform=train_transform)
    test_dataset = Food101Dataset(test_images, test_labels, transform=eval_transform)

    # persistent_workers=True requires at least one worker, and os.cpu_count()
    # may return None, so the computed value is clamped to >= 1.
    if num_gpus > 1:
        cpu_total = os.cpu_count() or 1
        num_workers = max(1, min(8, cpu_total // num_gpus))
    else:
        num_workers = 4

    loader_options = dict(
        batch_size=BATCH_SIZE,
        num_workers=num_workers,
        pin_memory=(device == "cuda"),  # pinned host memory only helps CUDA transfers
        persistent_workers=True  # Keep workers alive between epochs
    )
    test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)
    train_dataloader = DataLoader(dataset=full_train_dataset, shuffle=True, **loader_options)

    # Create model optimized for T4
    base_model = ViT(
        classes=num_classes,
        embeddings_dimension=512,  # Maintained for T4
        num_heads=8,  # Maintained for T4
        num_of_encoder_layers=8,  # Maintained for T4
        mlp_size=2048  # Maintained for T4
    )

    # Multi-GPU setup with DataParallel. `base_model` stays the unwrapped
    # module, so its state dict never carries the "module." key prefix.
    if num_gpus > 1:
        print(f"\n🔥 Setting up DataParallel for {num_gpus} GPUs...")
        vit = nn.DataParallel(base_model)
        print("✅ DataParallel setup complete!")
    else:
        vit = base_model

    vit = vit.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in base_model.parameters())
    trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    if num_gpus > 1:
        print(f"Parameters per GPU: ~{trainable_params//num_gpus:,}")

    # Setup optimizer with multi-GPU optimized settings
    optimizer = torch.optim.AdamW(
        params=vit.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.05,  # Slightly increased for regularization
        betas=(0.9, 0.999)
    )

    # Learning rate scheduler (stepped once per epoch)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=NUM_EPOCHS,
        eta_min=1e-6
    )

    loss_fn = nn.CrossEntropyLoss()

    # Mixed precision scaler (CUDA only)
    scaler = torch.amp.GradScaler(device='cuda') if device == "cuda" else None

    # Create a TensorBoard writer
    log_dir = f"/kaggle/working/runs/{run_name}" if on_kaggle else f"runs/{run_name}"
    writer = SummaryWriter(log_dir=log_dir)

    # Training loop
    results = {"train_loss": [], "train_acc": []}

    print("\n" + "="*60)
    if num_gpus > 1:
        print(f"STARTING MULTI-GPU TRAINING ON {num_gpus} T4 GPUs (Mixed Precision)")
    else:
        print("STARTING TRAINING ON T4 GPU (Mixed Precision)")
    print("="*60)

    for epoch in range(NUM_EPOCHS):
        print(f"\n=== Epoch {epoch + 1}/{NUM_EPOCHS} ===")

        # Clear cache before each epoch for all GPUs
        if device == "cuda":
            for i in range(num_gpus):
                with torch.cuda.device(i):
                    torch.cuda.empty_cache()

        # Train for one epoch
        train_loss, train_acc = train_step(vit, train_dataloader, loss_fn, optimizer, device, scaler, num_gpus)

        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Store results
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)

        # Log to TensorBoard
        writer.add_scalar(tag="Loss/train", scalar_value=train_loss, global_step=epoch)
        writer.add_scalar(tag="Accuracy/train", scalar_value=train_acc, global_step=epoch)
        writer.add_scalar(tag="Learning Rate", scalar_value=current_lr, global_step=epoch)

        # Log peak GPU memory for all GPUs, then reset the peak counters
        if device == "cuda":
            total_memory = 0
            for i in range(num_gpus):
                memory_used = torch.cuda.max_memory_allocated(i) / 1024**3
                total_memory += memory_used
                writer.add_scalar(tag=f"GPU_{i}_Memory_(GB)", scalar_value=memory_used, global_step=epoch)
                torch.cuda.reset_peak_memory_stats(i)

            writer.add_scalar(tag="Total_GPU_Memory_(GB)", scalar_value=total_memory, global_step=epoch)

        print(f"Epoch {epoch+1} | Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}% | LR: {current_lr:.6f}")
        if num_gpus > 1:
            print(f"         | Multi-GPU Memory: {total_memory:.1f}GB total ({total_memory/num_gpus:.1f}GB avg)")

        # Save checkpoint every 5 epochs and at the end
        if (epoch + 1) % 5 == 0 or epoch == NUM_EPOCHS - 1:
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            model_state_dict = base_model.state_dict()

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model_state_dict,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if scaler else None,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'num_gpus': num_gpus
            }
            torch.save(checkpoint, checkpoint_dir / f"epoch_{epoch+1}_checkpoint.pth")
            torch.save(model_state_dict, checkpoint_dir / "latest_model.pth")

    print("\n" + "="*60)
    print("TRAINING COMPLETED - STARTING FINAL TESTING")
    print("="*60)

    # Final testing phase
    test_loss, test_acc = test_step(vit, test_dataloader, loss_fn, device, num_gpus)

    # Add final test results to results dict
    results["test_loss"] = [test_loss]
    results["test_acc"] = [test_acc]

    # Log final test results
    writer.add_scalar(tag="Loss/final_test", scalar_value=test_loss, global_step=NUM_EPOCHS)
    writer.add_scalar(tag="Accuracy/final_test", scalar_value=test_acc, global_step=NUM_EPOCHS)

    print("\nFINAL RESULTS:")
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")

    # Save final model (make sure the directory exists independently of the loop above)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    model_state_dict = base_model.state_dict()

    final_checkpoint = {
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'test_loss': test_loss,
        'test_acc': test_acc,
        'num_gpus': num_gpus,
        'model_config': {
            'embeddings_dimension': 512,
            'num_heads': 8,
            'num_of_encoder_layers': 8,
            'mlp_size': 2048,
            'classes': num_classes
        }
    }
    torch.save(final_checkpoint, checkpoint_dir / "final_checkpoint.pth")
    torch.save(model_state_dict, checkpoint_dir / "final_model.pth")
    print(f"Final model saved to {checkpoint_dir / 'final_model.pth'}")

    # Close the writer
    writer.close()

    # Plot and save training curves
    _save_training_curves(results, NUM_EPOCHS, num_gpus, checkpoint_dir)

    print("\nMulti-GPU T4 Optimized Training Summary:")
    print(f"Number of GPUs used: {num_gpus}")
    print(f"Total epochs: {NUM_EPOCHS}")
    print(f"Training dataset size: {len(full_train_dataset)}")
    print(f"Batch size: {BATCH_SIZE} (total) / {BATCH_SIZE//max(1,num_gpus)} (per GPU)")
    print(f"Model parameters: {trainable_params:,}")
    print(f"Mixed precision: {'Enabled' if scaler else 'Disabled'}")
    print(f"DataParallel: {'Enabled' if num_gpus > 1 else 'Disabled'}")
    print(f"Total training samples processed: {NUM_EPOCHS * len(full_train_dataset)}")
    if results.get("test_acc"):
        print(f"Final test accuracy: {results['test_acc'][0]:.2f}%")


# =============================================================================
# 8. KAGGLE/NOTEBOOK EXECUTION FUNCTIONS
# =============================================================================

def run_training():
    """
    Notebook-friendly entry point for a full training run.

    Invokes ``set_seeds`` with a fixed seed and then hands control to
    ``main()``. Takes no arguments and returns nothing.
    """
    seed = 42  # fixed value so every run seeds identically
    set_seeds(seed)

    # Everything else (data, model, training loop) is driven by main().
    main()


def example_inference():
    """
    Placeholder illustrating how to run inference with a trained ViT checkpoint.

    The function itself is intentionally a no-op and returns None. The
    commented sketch below is a template to copy and adapt; the file paths
    and class names in it are placeholders.
    """
    # --- Step 1: rebuild the architecture and restore the trained weights ---
    # model = ViT(
    #     classes=101,
    #     embeddings_dimension=512,
    #     num_heads=8,
    #     num_of_encoder_layers=8,
    #     mlp_size=2048
    # )
    # checkpoint = torch.load("path/to/your/checkpoint.pth", map_location=device)
    # # The same call applies whether checkpoint['num_gpus'] is 1 or greater.
    # model.load_state_dict(checkpoint['model_state_dict'])
    #
    # --- Step 2: inference uses a single device (no DataParallel wrapper) ---
    # model.to(device)
    # model.eval()
    #
    # --- Step 3: supply the label names for your dataset ---
    # class_names = ["class_1", "class_2", ...]  # Replace with actual class names
    #
    # --- Step 4: predict on one image under mixed precision ---
    # with torch.amp.autocast(device_type='cuda'):
    #     pred_and_plot_image(model=model,
    #                         image_path="path/to/your/image.jpg",
    #                         class_names=class_names,
    #                         transform=get_data_transforms(),
    #                         device=device)
    return None


def load_multi_gpu_model(checkpoint_path, num_classes=101):
    """
    Utility function to load a trained ViT checkpoint for inference.

    Works for checkpoints whose state dict was saved either from the bare
    model or from an ``nn.DataParallel`` wrapper. In the latter case every
    parameter key carries a ``module.`` prefix, which is removed before the
    weights are loaded into the (unwrapped) model built here.

    Args:
        checkpoint_path (str): Path to the checkpoint file. The checkpoint must
            contain a ``'model_state_dict'`` entry; ``'num_gpus'`` and
            ``'test_acc'`` are optional and only used for the summary printout.
        num_classes (int): Number of classes in the model

    Returns:
        model: Loaded model moved to the module-level ``device`` and set to eval mode

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: If the stored weights do not match the architecture
            (raised by ``load_state_dict``).
    """
    # Architecture hyperparameters are fixed here and must match the ones used
    # when the checkpoint was trained.
    model = ViT(
        classes=num_classes,
        embeddings_dimension=512,
        num_heads=8,
        num_of_encoder_layers=8,
        mlp_size=2048
    )

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    state_dict = checkpoint['model_state_dict']

    # A state dict taken from an nn.DataParallel wrapper names every entry
    # "module.<param>". Strip that prefix only when *all* keys carry it, so a
    # plain state dict is passed through untouched.
    dp_prefix = "module."
    if state_dict and all(key.startswith(dp_prefix) for key in state_dict):
        state_dict = {
            key[len(dp_prefix):]: value for key, value in state_dict.items()
        }

    model.load_state_dict(state_dict)

    # Move to device and set to eval mode
    model.to(device)
    model.eval()

    print(f"✅ Model loaded from {checkpoint_path}")
    if 'num_gpus' in checkpoint:
        print(f"📊 Model was trained using {checkpoint['num_gpus']} GPU(s)")
    # Skip the accuracy line when the entry is absent or was stored as None,
    # since None cannot be rendered with the ":.2f" format spec.
    if 'test_acc' in checkpoint and checkpoint['test_acc'] is not None:
        print(f"🎯 Final test accuracy: {checkpoint['test_acc']:.2f}%")

    return model


# =============================================================================
# 9. SCRIPT EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Runs when the file is executed as a script. Notebook kernels (Jupyter/Kaggle)
    # also execute cells with __name__ == "__main__", so training starts there too.
    run_training()

# When this file is imported as a module the guard above is skipped; in that
# case call the entry point explicitly:
# run_training() 

Using device: cuda
Number of GPUs available: 2
GPU 0: Tesla T4
GPU 0 Memory: 14.7 GB
GPU 0 CUDA Capability: (7, 5)
GPU 1: Tesla T4
GPU 1 Memory: 14.7 GB
GPU 1 CUDA Capability: (7, 5)

🚀 Multi-GPU Training Enabled with 2 GPUs!
This will significantly speed up training and allow larger batch sizes.
Using device: cuda
Found dataset at: /kaggle/input/food41
🚀 Multi-GPU Setup: Using batch size 32 (16 per GPU)
Total training images available: 75750
Using the full training dataset (75750 images) for 25 epochs
Batch size optimized for 2 GPU(s): 32

🔥 Setting up DataParallel for 2 GPUs...
✅ DataParallel setup complete!
Total parameters: 25,767,013
Trainable parameters: 25,767,013
Parameters per GPU: ~12,883,506

STARTING MULTI-GPU TRAINING ON 2 T4 GPUs (Mixed Precision)

=== Epoch 1/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 1 | Train loss: 4.2028 | Train acc: 6.67% | LR: 0.000199
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 2/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 2 | Train loss: 3.8185 | Train acc: 12.52% | LR: 0.000197
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 3/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 3 | Train loss: 3.5897 | Train acc: 16.50% | LR: 0.000193
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 4/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 4 | Train loss: 3.3833 | Train acc: 20.22% | LR: 0.000188
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 5/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 5 | Train loss: 3.1889 | Train acc: 23.67% | LR: 0.000181
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 6/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 6 | Train loss: 3.0037 | Train acc: 27.25% | LR: 0.000173
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 7/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 7 | Train loss: 2.8267 | Train acc: 30.64% | LR: 0.000164
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 8/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 8 | Train loss: 2.6623 | Train acc: 34.19% | LR: 0.000154
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 9/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 9 | Train loss: 2.5064 | Train acc: 37.43% | LR: 0.000143
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 10/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 10 | Train loss: 2.3496 | Train acc: 40.53% | LR: 0.000131
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 11/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 11 | Train loss: 2.1911 | Train acc: 44.16% | LR: 0.000119
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 12/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 12 | Train loss: 2.0336 | Train acc: 47.35% | LR: 0.000107
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 13/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]

Epoch 13 | Train loss: 1.8707 | Train acc: 51.17% | LR: 0.000094
         | Multi-GPU Memory: 1.8GB total (0.9GB avg)

=== Epoch 14/25 ===


Training:   0%|          | 0/2368 [00:00<?, ?it/s]