# Vision Transformer Training for Car Classification

This notebook trains a Vision Transformer (ViT) model for multi-task car classification, with one single-label prediction head per attribute:
- **Brand** classification (e.g., BMW, Audi, Toyota)
- **Model** classification (e.g., M3, A4, Camry)
- **Year** classification (e.g., 2007, 2010, 2012)

The model uses the Stanford Cars196 dataset with a three-headed classifier architecture.

## 1. Import Required Libraries

In [None]:
import sys
from pathlib import Path
import os

# Setup project paths
project_root = Path(os.getcwd()).parent
sys.path.append(str(project_root))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from datetime import datetime

# Import custom modules
from Model.Model import VisionTransformer, create_vit_base
from Utilities.Cars196 import create_dataloaders

# Seed the torch and NumPy RNGs (plus every CUDA device when present) so runs are repeatable.
_seed = 42
_has_cuda = torch.cuda.is_available()
torch.manual_seed(_seed)
np.random.seed(_seed)
if _has_cuda:
    torch.cuda.manual_seed_all(_seed)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if _has_cuda else 'cpu')

_rule = '=' * 60
print(f"\n{_rule}")
print(f"Device: {device}")
if not _has_cuda:
    print("No GPU available - using CPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Available GPUs: {torch.cuda.device_count()}")
print(_rule)

## 2. Setup Directories and Configuration

In [None]:
# All artefact directories sit directly under the project root (the notebook's
# parent directory); create them now so later cells can write into them.
project_root = Path(os.getcwd()).parent
cache_dir, checkpoint_dir, results_dir = (
    project_root / _sub for _sub in ('cache', 'checkpoints', 'results')
)
for _directory in (cache_dir, checkpoint_dir, results_dir):
    _directory.mkdir(exist_ok=True)

print("Directory Setup:")
print(f"  Project Root: {project_root}")
print(f"  Cache:        {cache_dir}")
print(f"  Checkpoints:  {checkpoint_dir}")
print(f"  Results:      {results_dir}")

# Optimisation, data-loading and schedule settings.
_train_settings = {
    'batch_size': 32,
    'img_size': 224,
    'num_workers': 4,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'num_epochs': 50,
    'patience': 10,        # early-stopping patience, in epochs
    'save_frequency': 5,   # write a periodic checkpoint every N epochs
}
# ViT architecture hyper-parameters.
_model_settings = {
    'embed_dim': 768,
    'depth': 12,
    'num_heads': 12,
    'mlp_ratio': 4.0,
    'dropout': 0.1,
}
# Merging preserves key order, so the JSON dump below lists training keys first.
config = {**_train_settings, **_model_settings}

print("\nTraining Configuration:")
print(json.dumps(config, indent=2))

## 3. Load Dataset

In [None]:
# Build the train/test loaders; the helper downloads the data when missing
# (root_dir=None, auto_download=True) and caches it under cache_dir.
print("Loading Cars196 dataset...")
_loader_opts = {key: config[key] for key in ('batch_size', 'img_size', 'num_workers')}
train_loader, test_loader = create_dataloaders(
    root_dir=None,
    auto_download=True,
    cache_dir=str(cache_dir),
    **_loader_opts,
)

# Per-head class counts come from the training split.
train_dataset, test_dataset = train_loader.dataset, test_loader.dataset
num_classes = train_dataset.get_num_classes()

print("\nDataset Statistics:")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
for _key, _noun in (('brand', 'brands'), ('model', 'models'), ('year', 'years')):
    print(f"Number of {_noun}: {num_classes[_key]}")
print(f"Total classes: {num_classes['total']}")

# Print the shape and a few labels of one training batch as a sanity check.
sample_images, sample_labels = next(iter(train_loader))
print(f"\nBatch shape: {sample_images.shape}")
print(f"Sample brand indices: {sample_labels['brand'][:5].tolist()}")
print(f"Sample brand names: {sample_labels['brand_name'][:5]}")
print(f"Sample model names: {sample_labels['model_name'][:5]}")
print(f"Sample years: {sample_labels['year_value'][:5]}")

## 4. Initialize Model

In [None]:
# Instantiate the three-headed ViT: head widths come from the dataset,
# architecture hyper-parameters from `config`; then move it to `device`.
_arch = {
    key: config[key]
    for key in ('img_size', 'embed_dim', 'depth', 'num_heads', 'mlp_ratio', 'dropout')
}
model = create_vit_base(
    num_classes_brand=num_classes['brand'],
    num_classes_model=num_classes['model'],
    num_classes_year=num_classes['year'],
    **_arch,
).to(device)

# Parameter budget; the size estimate assumes 4 bytes per (float32) parameter.
_params = list(model.parameters())
total_params = sum(p.numel() for p in _params)
trainable_params = sum(p.numel() for p in _params if p.requires_grad)

print("Model Architecture:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024**2:.2f} MB (float32)")

# Smoke-test one forward pass on two random images and report each head's output shape.
with torch.no_grad():
    test_input = torch.randn(2, 3, config['img_size'], config['img_size']).to(device)
    brand_out, model_out, year_out = model(test_input)
    print("\nOutput shapes:")
    for _head_name, _head_out in (("Brand", brand_out), ("Model", model_out), ("Year", year_out)):
        print(f"{_head_name}: {_head_out.shape}")

## 5. Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train the model for one epoch.

    Every batch is optimised on the unweighted sum of the brand, model and
    year losses.

    Args:
        model: Three-headed model whose forward pass returns
            ``(brand_logits, model_logits, year_logits)``.
        dataloader: Training dataloader yielding ``(images, labels)`` where
            ``labels`` is a dict with class-index tensors under ``'brand'``,
            ``'model'`` and ``'year'``.
        criterion: Loss function applied to each head. Loss averages are
            weighted by batch size, which assumes a mean-reduced criterion
            such as the default ``nn.CrossEntropyLoss()``.
        optimizer: Optimizer
        device: Device to train on
        epoch: Current epoch number (used only for the progress-bar label)

    Returns:
        Dictionary with per-sample average losses (``total_loss``,
        ``brand_loss``, ``model_loss``, ``year_loss``) and accuracies in
        percent (``brand_accuracy``, ``model_accuracy``, ``year_accuracy``).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()

    heads = ('brand', 'model', 'year')
    loss_sums = dict.fromkeys(heads, 0.0)  # sum over samples of each head's loss
    correct = dict.fromkeys(heads, 0)      # top-1 hits per head
    total_loss_sum = 0.0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch}')

    for images, labels in progress_bar:
        images = images.to(device)
        targets = {head: labels[head].to(device) for head in heads}
        batch_size = images.size(0)

        # Forward pass: one logit tensor per head, in (brand, model, year) order.
        outputs = dict(zip(heads, model(images)))

        head_losses = {head: criterion(outputs[head], targets[head]) for head in heads}
        # Combined loss (heads weighted equally)
        loss = head_losses['brand'] + head_losses['model'] + head_losses['year']

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a short final batch does not count
        # as much as a full one (matches the per-sample accuracy averaging).
        batch_loss = loss.item()
        total_loss_sum += batch_loss * batch_size
        for head in heads:
            loss_sums[head] += head_losses[head].item() * batch_size
            correct[head] += (outputs[head].argmax(dim=1) == targets[head]).sum().item()
        total_samples += batch_size

        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'brand_acc': f'{100 * correct["brand"] / total_samples:.2f}%',
            'model_acc': f'{100 * correct["model"] / total_samples:.2f}%',
            'year_acc': f'{100 * correct["year"] / total_samples:.2f}%'
        })

    if total_samples == 0:
        raise ValueError("train_epoch: dataloader produced no samples")

    return {
        'total_loss': total_loss_sum / total_samples,
        'brand_loss': loss_sums['brand'] / total_samples,
        'model_loss': loss_sums['model'] / total_samples,
        'year_loss': loss_sums['year'] / total_samples,
        'brand_accuracy': 100 * correct['brand'] / total_samples,
        'model_accuracy': 100 * correct['model'] / total_samples,
        'year_accuracy': 100 * correct['year'] / total_samples,
    }

## 6. Testing/Evaluation Functions

In [None]:
def evaluate(model, dataloader, criterion, device, split='Test'):
    """
    Evaluate the model on a dataset without updating it.

    Args:
        model: Three-headed model whose forward pass returns
            ``(brand_logits, model_logits, year_logits)``.
        dataloader: Evaluation dataloader yielding ``(images, labels)`` where
            ``labels`` is a dict with class-index tensors under ``'brand'``,
            ``'model'`` and ``'year'``.
        criterion: Loss function applied to each head. Loss averages are
            weighted by batch size, which assumes a mean-reduced criterion
            such as the default ``nn.CrossEntropyLoss()``.
        device: Device to evaluate on
        split: Name of the split (for logging)

    Returns:
        Dictionary with per-sample average losses, top-1 accuracies (percent)
        for all heads, top-5 accuracies (percent) for brand and model, and
        ``'predictions'`` / ``'labels'`` dicts holding per-sample class indices
        as plain Python ints (e.g. for a confusion matrix).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    heads = ('brand', 'model', 'year')
    topk_heads = ('brand', 'model')  # heads for which top-5 accuracy is reported
    loss_sums = dict.fromkeys(heads, 0.0)
    correct = dict.fromkeys(heads, 0)
    top5_correct = dict.fromkeys(topk_heads, 0)
    all_preds = {head: [] for head in heads}
    all_labels = {head: [] for head in heads}
    total_loss_sum = 0.0
    total_samples = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f'{split} Evaluation')

        for images, labels in progress_bar:
            images = images.to(device)
            targets = {head: labels[head].to(device) for head in heads}
            batch_size = images.size(0)

            # Forward pass: one logit tensor per head, in (brand, model, year) order.
            outputs = dict(zip(heads, model(images)))

            head_losses = {head: criterion(outputs[head], targets[head]) for head in heads}
            loss = head_losses['brand'] + head_losses['model'] + head_losses['year']

            # Weight by batch size so a short final batch is not over-counted.
            total_loss_sum += loss.item() * batch_size
            for head in heads:
                preds = outputs[head].argmax(dim=1)
                loss_sums[head] += head_losses[head].item() * batch_size
                correct[head] += (preds == targets[head]).sum().item()
                # .tolist() gives plain Python ints (JSON-friendly, and safe to
                # embed in checkpoints) instead of NumPy scalars.
                all_preds[head].extend(preds.cpu().tolist())
                all_labels[head].extend(targets[head].cpu().tolist())

            # Top-k hit: the true class is among the k highest logits. Done as
            # one tensor comparison per batch rather than a Python loop, and k
            # is capped at the class count so topk cannot fail on small heads.
            for head in topk_heads:
                k = min(5, outputs[head].size(1))
                _, topk_idx = outputs[head].topk(k, dim=1)
                hits = (topk_idx == targets[head].unsqueeze(1)).any(dim=1)
                top5_correct[head] += hits.sum().item()

            total_samples += batch_size

    if total_samples == 0:
        raise ValueError(f"evaluate: {split} dataloader produced no samples")

    return {
        'total_loss': total_loss_sum / total_samples,
        'brand_loss': loss_sums['brand'] / total_samples,
        'model_loss': loss_sums['model'] / total_samples,
        'year_loss': loss_sums['year'] / total_samples,
        'brand_accuracy': 100 * correct['brand'] / total_samples,
        'model_accuracy': 100 * correct['model'] / total_samples,
        'year_accuracy': 100 * correct['year'] / total_samples,
        'brand_top5_accuracy': 100 * top5_correct['brand'] / total_samples,
        'model_top5_accuracy': 100 * top5_correct['model'] / total_samples,
        'predictions': all_preds,
        'labels': all_labels,
    }

## 7. Training Loop with Checkpointing

In [None]:
def _checkpoint_metrics(metrics):
    """Return ``metrics`` without the per-sample 'predictions'/'labels' lists.

    Those lists hold one entry per evaluated sample; leaving them out keeps
    checkpoint files small and limited to scalar metrics.
    """
    return {key: value for key, value in metrics.items()
            if key not in ('predictions', 'labels')}


def _print_epoch_summary(epoch, train_metrics, test_metrics, current_lr):
    """Print the end-of-epoch train/test metrics and current learning rate."""
    print(f"\nEpoch {epoch} Summary:")
    print(f"  Train Loss: {train_metrics['total_loss']:.4f}")
    print(f"  Train Brand Acc: {train_metrics['brand_accuracy']:.2f}%")
    print(f"  Train Model Acc: {train_metrics['model_accuracy']:.2f}%")
    print(f"  Train Year Acc: {train_metrics['year_accuracy']:.2f}%")
    print(f"\n  Test Loss: {test_metrics['total_loss']:.4f}")
    print(f"  Test Brand Acc: {test_metrics['brand_accuracy']:.2f}%")
    print(f"  Test Model Acc: {test_metrics['model_accuracy']:.2f}%")
    print(f"  Test Year Acc: {test_metrics['year_accuracy']:.2f}%")
    print(f"  Test Brand Top-5 Acc: {test_metrics['brand_top5_accuracy']:.2f}%")
    print(f"  Test Model Top-5 Acc: {test_metrics['model_top5_accuracy']:.2f}%")
    print(f"\n  Learning Rate: {current_lr:.6f}")


def train_model(model, train_loader, test_loader, config, device,
                checkpoint_dir, results_dir):
    """
    Complete training loop with validation and checkpointing.

    Uses AdamW with ReduceLROnPlateau (halving the LR when the test loss
    stalls), keeps the lowest-test-loss weights in ``best_model.pth``, and
    stops early after ``config['patience']`` epochs without improvement.

    NOTE(review): ``test_loader`` drives LR scheduling, best-model selection
    and early stopping, so it effectively serves as a validation set and any
    metric later measured on it is optimistically biased. Consider holding
    out a separate validation split.

    Args:
        model: Vision Transformer model
        train_loader: Training dataloader
        test_loader: Test dataloader (also used for model selection, see note)
        config: Training configuration dictionary. Reads ``learning_rate``,
            ``weight_decay``, ``num_epochs``, ``patience`` and
            ``save_frequency`` (0 disables periodic checkpoints); optionally
            ``lr_patience``, the plateau epochs before the LR is halved
            (default 5).
        device: Device to train on
        checkpoint_dir: Directory (str or Path) to save checkpoints
        results_dir: Directory (str or Path) to save the training history

    Returns:
        Dictionary containing training history
    """
    checkpoint_dir = Path(checkpoint_dir)
    results_dir = Path(results_dir)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Learning rate scheduler (halve the LR when the test loss plateaus)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=config.get('lr_patience', 5)
    )

    # Training history
    history = {
        'train_loss': [],
        'train_brand_acc': [],
        'train_model_acc': [],
        'train_year_acc': [],
        'test_loss': [],
        'test_brand_acc': [],
        'test_model_acc': [],
        'test_year_acc': [],
        'learning_rates': []
    }

    best_test_loss = float('inf')
    patience_counter = 0
    save_frequency = config['save_frequency']
    # Defined up front so the final checkpoint is valid even if num_epochs < 1.
    epoch = 0

    print(f"\n{'='*60}")
    print("Starting Training")
    print(f"{'='*60}\n")

    for epoch in range(1, config['num_epochs'] + 1):
        print(f"\nEpoch {epoch}/{config['num_epochs']}")
        print("-" * 60)

        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        test_metrics = evaluate(
            model, test_loader, criterion, device, split='Test'
        )

        # Update learning rate from the test loss
        current_lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(test_metrics['total_loss'])
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != current_lr_before:
            print(f"\nLearning rate reduced: {current_lr_before:.6f} -> {current_lr:.6f}")

        # Store history
        history['train_loss'].append(train_metrics['total_loss'])
        history['train_brand_acc'].append(train_metrics['brand_accuracy'])
        history['train_model_acc'].append(train_metrics['model_accuracy'])
        history['train_year_acc'].append(train_metrics['year_accuracy'])
        history['test_loss'].append(test_metrics['total_loss'])
        history['test_brand_acc'].append(test_metrics['brand_accuracy'])
        history['test_model_acc'].append(test_metrics['model_accuracy'])
        history['test_year_acc'].append(test_metrics['year_accuracy'])
        history['learning_rates'].append(current_lr)

        _print_epoch_summary(epoch, train_metrics, test_metrics, current_lr)

        # Periodic checkpoint (skipped entirely when save_frequency is 0/None)
        if save_frequency and epoch % save_frequency == 0:
            checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_metrics': _checkpoint_metrics(train_metrics),
                'test_metrics': _checkpoint_metrics(test_metrics),
                'history': history,
                'config': config
            }, checkpoint_path)
            print(f"\n  Checkpoint saved: {checkpoint_path.name}")

        # Save best model
        if test_metrics['total_loss'] < best_test_loss:
            best_test_loss = test_metrics['total_loss']
            patience_counter = 0

            best_model_path = checkpoint_dir / 'best_model.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'test_metrics': _checkpoint_metrics(test_metrics),
                'config': config
            }, best_model_path)
            print(f"  ✓ Best model saved (loss: {best_test_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  No improvement. Patience: {patience_counter}/{config['patience']}")

        # Early stopping
        if patience_counter >= config['patience']:
            print(f"\nEarly stopping triggered after {epoch} epochs")
            break

    # Save final model
    final_model_path = checkpoint_dir / 'final_model.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'history': history,
        'config': config
    }, final_model_path)

    # Save training history (all values are Python floats, so JSON-safe)
    history_path = results_dir / 'training_history.json'
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)

    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}")
    print(f"Best test loss: {best_test_loss:.4f}")
    print(f"Final model saved: {final_model_path}")
    print(f"Training history saved: {history_path}")

    return history


## 8. Start Training

In [None]:
# Run the full training loop; `history` collects per-epoch losses, accuracies
# and learning rates.
history = train_model(
    model=model, train_loader=train_loader, test_loader=test_loader,
    config=config, device=device,
    checkpoint_dir=checkpoint_dir, results_dir=results_dir,
)

## 9. Visualize Training Results

In [None]:
# Plot the loss and per-head accuracy curves (train in blue, test in red) on a
# 2x2 grid, then save the figure to the results directory.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

epochs = range(1, len(history['train_loss']) + 1)

# (grid position, train key, test key, legend labels, y-axis label, title)
_curve_panels = (
    ((0, 0), 'train_loss', 'test_loss', ('Train Loss', 'Test Loss'),
     'Loss', 'Training and Test Loss'),
    ((0, 1), 'train_brand_acc', 'test_brand_acc', ('Train', 'Test'),
     'Accuracy (%)', 'Brand Classification Accuracy'),
    ((1, 0), 'train_model_acc', 'test_model_acc', ('Train', 'Test'),
     'Accuracy (%)', 'Model Classification Accuracy'),
    ((1, 1), 'train_year_acc', 'test_year_acc', ('Train', 'Test'),
     'Accuracy (%)', 'Year Classification Accuracy'),
)

for _pos, _train_key, _test_key, (_train_label, _test_label), _ylabel, _title in _curve_panels:
    _ax = axes[_pos]
    _ax.plot(epochs, history[_train_key], 'b-', label=_train_label, linewidth=2)
    _ax.plot(epochs, history[_test_key], 'r-', label=_test_label, linewidth=2)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

_curves_path = results_dir / 'training_curves.png'
plt.tight_layout()
plt.savefig(_curves_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTraining curves saved to: {_curves_path}")

## 10. Final Evaluation on Test Set

In [None]:
# Restore the checkpoint with the lowest test loss seen during training.
best_model_path = checkpoint_dir / 'best_model.pth'
# map_location: load onto the current device even when the checkpoint was
# written on another one (e.g. trained on GPU, evaluated on CPU).
# weights_only=False: the file also stores config/metric metadata, and
# checkpoints written by earlier versions of this notebook contain NumPy
# scalars, which the weights-only unpickler (PyTorch's default since 2.6)
# refuses. Acceptable only because this notebook produced the file itself;
# never load untrusted checkpoints this way.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")

# Evaluate on test set.
# NOTE(review): this same test set selected the checkpoint during training,
# so the figures below are optimistically biased.
criterion = nn.CrossEntropyLoss()
final_results = evaluate(model, test_loader, criterion, device, split='Final Test')

print("\n" + "="*60)
print("FINAL TEST SET RESULTS")
print("="*60)
print(f"\nOverall Test Loss: {final_results['total_loss']:.4f}")
print(f"\nBrand Classification:")
print(f"  Top-1 Accuracy: {final_results['brand_accuracy']:.2f}%")
print(f"  Top-5 Accuracy: {final_results['brand_top5_accuracy']:.2f}%")
print(f"  Loss: {final_results['brand_loss']:.4f}")

print(f"\nModel Classification:")
print(f"  Top-1 Accuracy: {final_results['model_accuracy']:.2f}%")
print(f"  Top-5 Accuracy: {final_results['model_top5_accuracy']:.2f}%")
print(f"  Loss: {final_results['model_loss']:.4f}")

print(f"\nYear Classification:")
print(f"  Accuracy: {final_results['year_accuracy']:.2f}%")
print(f"  Loss: {final_results['year_loss']:.4f}")

# Save the scalar summary (per-sample predictions are intentionally omitted).
final_results_summary = {
    'test_loss': final_results['total_loss'],
    'brand_accuracy': final_results['brand_accuracy'],
    'brand_top5_accuracy': final_results['brand_top5_accuracy'],
    'model_accuracy': final_results['model_accuracy'],
    'model_top5_accuracy': final_results['model_top5_accuracy'],
    'year_accuracy': final_results['year_accuracy'],
    'epoch': checkpoint['epoch']
}

final_results_path = results_dir / 'final_results.json'
with open(final_results_path, 'w', encoding='utf-8') as f:
    json.dump(final_results_summary, f, indent=2)

print(f"\nFinal results saved to: {final_results_path}")

## 11. Performance Visualization

In [None]:
# Bar chart of final test accuracy per task: Top-1 for all three heads, and
# Top-5 only for brand and model (no Top-5 metric is computed for year).
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

tasks = ['Brand', 'Model', 'Year']
top1_accuracies = [
    final_results['brand_accuracy'],
    final_results['model_accuracy'],
    final_results['year_accuracy']
]

x = np.arange(len(tasks))
width = 0.35

bars1 = ax.bar(x - width/2, top1_accuracies, width, label='Top-1 Accuracy',
               color='steelblue', alpha=0.8)

# Draw Top-5 bars only for the tasks that have the metric, rather than padding
# Year with a placeholder 0% bar.
top5_accuracies = [
    final_results['brand_top5_accuracy'],
    final_results['model_top5_accuracy'],
]
bars2 = ax.bar(x[:len(top5_accuracies)] + width/2, top5_accuracies, width,
               label='Top-5 Accuracy', color='coral', alpha=0.8)

ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Final Test Set Performance by Task', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=12)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')
# A little headroom keeps value labels on near-100% bars inside the axes.
ax.set_ylim(0, 105)

# Label every bar with its value, including genuine 0% results.
for bars in (bars1, bars2):
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{height:.1f}%',
                ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig(results_dir / 'final_performance.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nPerformance visualization saved to: {results_dir / 'final_performance.png'}")

## 12. Model Inference Example

In [None]:
# Inference helper: predict brand, model and year for one preprocessed image.
def predict_car(model, image_tensor, dataset, device):
    """
    Classify a single image with the three-headed model.

    Puts ``model`` in eval mode, adds a batch axis to ``image_tensor``, runs
    one forward pass without gradients, and maps each head's arg-max index to
    a label through the dataset's lookup sequences.

    Args:
        model: Trained model returning ``(brand_logits, model_logits, year_logits)``
        image_tensor: One preprocessed image without a batch dimension
        dataset: Object exposing ``unique_brands``, ``unique_models`` and
            ``unique_years``, indexed by class id
        device: Device to run inference on

    Returns:
        Dict with keys ``brand``, ``brand_confidence``, ``model``,
        ``model_confidence``, ``year``, ``year_confidence``; each confidence
        is the highest softmax probability of that head.
    """
    model.eval()
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)
        per_head = zip(
            ('brand', 'model', 'year'),
            model(batch),
            (dataset.unique_brands, dataset.unique_models, dataset.unique_years),
        )

        prediction = {}
        for head, logits, vocabulary in per_head:
            class_index = logits.argmax(dim=1).item()
            confidence = torch.softmax(logits, dim=1).max().item()
            prediction[head] = vocabulary[class_index]
            prediction[f'{head}_confidence'] = confidence
        return prediction

# Compare predictions with ground truth for the first few images of one test batch.
print("Sample Predictions:\n")
print("="*80)


def _as_python(value):
    """Unwrap a 0-d tensor into a plain Python scalar; return other values unchanged.

    NOTE(review): depending on how the dataset returns ``year_value``, the
    DataLoader's default collate may stack numeric per-sample fields into a
    tensor, making an indexed element a 0-d tensor (printed as ``tensor(...)``).
    """
    return value.item() if torch.is_tensor(value) else value


sample_images, sample_labels = next(iter(test_loader))
num_samples = min(5, len(sample_images))

for i in range(num_samples):
    predictions = predict_car(model, sample_images[i], test_dataset, device)

    actual_brand = _as_python(sample_labels['brand_name'][i])
    actual_model = _as_python(sample_labels['model_name'][i])
    actual_year = _as_python(sample_labels['year_value'][i])

    print(f"\nSample {i+1}:")
    print(f"  Predicted: {predictions['brand']} {predictions['model']} ({predictions['year']})")
    print(f"  Actual:    {actual_brand} {actual_model} ({actual_year})")
    print(f"  Confidence: Brand={predictions['brand_confidence']:.2%}, "
          f"Model={predictions['model_confidence']:.2%}, "
          f"Year={predictions['year_confidence']:.2%}")

    # Collect the heads whose prediction does not match the ground truth.
    actual = {'brand': actual_brand, 'model': actual_model, 'year': actual_year}
    wrong_heads = [head.capitalize() for head in ('brand', 'model', 'year')
                   if not predictions[head] == actual[head]]

    if not wrong_heads:
        print("  ✓ All predictions CORRECT!")
    else:
        print(f"  ✗ Incorrect: {', '.join(wrong_heads)}")
    print("-"*80)