# ViT Replication Project

This notebook allows you to run the Vision Transformer (ViT) replication project on Google Colab.

## Setup
First, we need to ensure we have the necessary dependencies installed.

In [None]:
# Install third-party dependencies used below: timm (pretrained ViT weights via
# timm.create_model) and tqdm (progress bars). torch/torchvision are assumed to be
# preinstalled in the runtime (e.g. Colab) — TODO confirm for other environments.
# `!pip` is an IPython shell escape, so this cell only runs inside a notebook.
!pip install timm tqdm

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

## Model Definition
We define the Vision Transformer architecture components: PatchEmbedding, MultiHeadAttention, MLP, TransformerBlock, and the main VisionTransformer class.

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    The projection is a Conv2d whose kernel size and stride both equal
    ``patch_size``. This is equivalent to flattening every patch and applying
    one shared linear layer.

    Args:
        img_size: expected square input side length; only used here to
            compute ``n_patches``.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: embedding dimension produced for every patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # kernel_size == stride -> each output position sees exactly one patch
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, num_patches, embed_dim]`` tokens."""
        # [B, C, H, W] -> [B, embed_dim, H // patch_size, W // patch_size]
        x = self.proj(x)
        # [B, embed_dim, H', W'] -> [B, embed_dim, H'*W'] -> [B, H'*W', embed_dim]
        return x.flatten(2).transpose(1, 2)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single Linear layer produces Q, K and V for all heads at once. The
    per-head outputs are concatenated and mixed by an output projection.

    Args:
        embed_dim: token dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability. The same ``nn.Dropout`` module is applied
            to the attention weights and to the projected output.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        # Fail fast; otherwise the mismatch only surfaces as an opaque reshape
        # error inside forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the token axis of ``x`` (``[B, N, embed_dim]``); same shape out."""
        B, N, C = x.shape

        # [B, N, 3*C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # [B, heads, N, N] attention weights, scaled by 1/sqrt(head_dim)
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads: [B, heads, N, head_dim] -> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class MLP(nn.Module):
    """Position-wise feed-forward sublayer of a transformer block.

    Computes ``Dropout(fc2(Dropout(GELU(fc1(x)))))``. ``fc1`` widens
    ``embed_dim`` to ``int(embed_dim * mlp_ratio)`` features and ``fc2``
    projects back to ``embed_dim``.
    """

    def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        width = int(embed_dim * mlp_ratio)

        # The attribute names fc1/fc2 become state_dict keys, which the
        # pretrained-weight loader matches by name.
        self.fc1 = nn.Linear(embed_dim, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        embed_dim: token dimension.
        num_heads: attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: dropout probability inside the attention and MLP sublayers.
        norm_eps: LayerNorm epsilon. The default of 1e-6 matches timm's ViT, so
            loaded timm weights reproduce the reference activations.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0, norm_eps=1e-6):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        # x: [Batch_size, num_patches+1, embed_dim]; shape is preserved.
        # Normalization is applied inside each residual branch (pre-norm).
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline:
        1. Patch embedding.
        2. Prepend a learnable class token.
        3. Add a learnable positional embedding.
        4. ``depth`` pre-norm transformer blocks.
        5. Final LayerNorm.
        6. Linear head applied to the class token.

    Args:
        img_size: square input side length; fixes the number of positional embeddings.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        num_classes: number of output logits.
        embed_dim: token dimension (must be divisible by ``num_heads``).
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: dropout probability inside attention and MLP sublayers.
        emb_dropout: dropout probability applied after adding positional embeddings.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0,
        emb_dropout=0.0
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch embedding: [B, C, H, W] -> [B, num_patches, embed_dim]
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches

        # Learnable class token, prepended to the patch sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Positional embedding for num_patches + 1 positions (class token included)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        self.pos_dropout = nn.Dropout(emb_dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # eps=1e-6 matches timm's ViT, whose pretrained weights are loaded into this model
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialize parameters.

        Linear/Conv2d: Xavier-uniform weights and near-zero normal biases.
        LayerNorm: weight 1, bias 0 (identity affine transform).
        cls_token / pos_embed: normal with std 0.02.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
            elif isinstance(m, nn.LayerNorm):
                # The affine scale must start at 1. A small random scale would
                # shrink every normalized activation toward zero and cripple
                # training from scratch.
                nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return ``[B, num_classes]`` logits for ``[B, C, img_size, img_size]`` images."""
        batch_size = x.shape[0]

        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)

        # [B, num_patches, embed_dim] -> [B, num_patches+1, embed_dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        x = x + self.pos_embed
        x = self.pos_dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Classify from the class-token position: [B, embed_dim] -> [B, num_classes]
        return self.head(x[:, 0])

## Helper Functions
Functions for loading pretrained weights and data.

In [None]:
def _block_index(parts):
    """Return the integer segment that follows a ``'blocks'`` segment, or None.

    *parts* is a state-dict key split on ``'.'``. Every ``'blocks'``
    occurrence is tried in order; the first one followed by an integer wins.
    """
    for i, part in enumerate(parts[:-1]):
        if part != 'blocks':
            continue
        try:
            return int(parts[i + 1])
        except ValueError:
            continue
    return None


def _find_fallback_key(our_key, pretrained_keys):
    """Heuristically find the pretrained key corresponding to *our_key*.

    Only used when *our_key* has no exact match. A candidate must have the same
    number of dotted segments and the same last segment. Block keys must also
    agree on the block index and on the component path after it. Other keys
    must agree on their last two segments. Returns the first match, or None.
    """
    our_parts = our_key.split('.')
    for pretrained_key in pretrained_keys:
        pretrained_parts = pretrained_key.split('.')
        if len(our_parts) != len(pretrained_parts) or our_parts[-1] != pretrained_parts[-1]:
            continue
        # Segment membership (not substring), so list.index('blocks') below cannot fail.
        if 'blocks' in our_parts and 'blocks' in pretrained_parts:
            if _block_index(our_parts) != _block_index(pretrained_parts):
                continue
            our_comp = our_parts[our_parts.index('blocks') + 2:]
            pretrained_comp = pretrained_parts[pretrained_parts.index('blocks') + 2:]
            if our_comp == pretrained_comp:
                return pretrained_key
        elif our_parts[-2:] == pretrained_parts[-2:]:
            return pretrained_key
    return None


def load_pretrained_weights(model, timm_model_name, num_classes):
    """Copy timm pretrained weights into *model* wherever names and shapes agree.

    Args:
        model: module to populate; updated in place via ``load_state_dict`` and returned.
        timm_model_name: name passed to ``timm.create_model`` (created with
            ``pretrained=True`` and a 1000-class head).
        num_classes: number of classes of *model*'s head. When it is not 1000,
            a head shape mismatch is expected and is not reported as a warning.

    Returns:
        The same *model* instance.

    Tensors whose shapes differ keep *model*'s current values. With
    ``num_classes != 1000`` this applies to the classification head.
    """
    print(f"Loading pretrained weights from timm: {timm_model_name}")

    pretrained_model = timm.create_model(timm_model_name, pretrained=True, num_classes=1000)

    our_state_dict = model.state_dict()
    pretrained_state_dict = pretrained_model.state_dict()

    # Debug: print key names to understand the structure
    print(f"\nDebug: Our model has {len(our_state_dict)} parameters")
    print(f"Debug: Pretrained model has {len(pretrained_state_dict)} parameters")
    print("Debug: Sample our keys (first 10):")
    for i, key in enumerate(list(our_state_dict.keys())[:10]):
        print(f"  {i+1}. {key}")
    print("Debug: Sample pretrained keys (first 10):")
    for i, key in enumerate(list(pretrained_state_dict.keys())[:10]):
        print(f"  {i+1}. {key}")

    # Map each of our keys to a pretrained key: exact name first, then heuristics.
    key_mapping = {}
    for our_key in our_state_dict:
        if our_key in pretrained_state_dict:
            key_mapping[our_key] = our_key
        else:
            match = _find_fallback_key(our_key, pretrained_state_dict)
            if match is not None:
                key_mapping[our_key] = match

    head_mismatch_expected = num_classes != 1000
    loaded_keys = []
    missing_keys = []
    shape_mismatch_keys = []

    for our_key, our_tensor in list(our_state_dict.items()):
        pretrained_key = key_mapping.get(our_key)
        if pretrained_key is None:
            missing_keys.append(our_key)
            continue
        pretrained_tensor = pretrained_state_dict[pretrained_key]
        if our_tensor.shape == pretrained_tensor.shape:
            our_state_dict[our_key] = pretrained_tensor
            loaded_keys.append(our_key)
        elif not (head_mismatch_expected and our_key.startswith('head.')):
            # A resized head is expected; anything else is worth a warning.
            shape_mismatch_keys.append(
                f"{our_key} (our: {our_tensor.shape} vs pretrained: {pretrained_tensor.shape})"
            )

    # The head was already handled by the shape-checked loop above; only report it.
    if head_mismatch_expected and 'head.weight' in pretrained_state_dict and 'head.bias' in pretrained_state_dict:
        pretrained_classes = pretrained_state_dict['head.weight'].shape[0]
        print(f"Note: Re-initializing head for {num_classes} classes (pretrained had {pretrained_classes})")

    model.load_state_dict(our_state_dict, strict=False)

    print(f"\nSuccessfully loaded {len(loaded_keys)} pretrained layers")
    if missing_keys:
        print(f"Warning: {len(missing_keys)} keys not found in pretrained model (first 5): {missing_keys[:5]}")
    if shape_mismatch_keys:
        print(f"Warning: {len(shape_mismatch_keys)} layers have shape mismatches (first 3):")
        for mismatch in shape_mismatch_keys[:3]:
            print(f"  {mismatch}")

    # Critical check: verify key layers were loaded
    critical_keys = ['patch_embed.proj.weight', 'pos_embed', 'cls_token']
    print("\nCritical layer check:")
    for key in critical_keys:
        if key in loaded_keys:
            print(f"  ✓ Loaded: {key}")
        else:
            print(f"  ✗ WARNING: NOT loaded: {key}")

    return model

def get_dataloader(batch_size=32, num_workers=2, img_size=224, train=False):
    """Build a non-shuffled CIFAR-10 DataLoader without augmentation.

    Images are resized to ``img_size`` x ``img_size`` and normalized with the
    ImageNet mean/std. The dataset is downloaded into ``./data`` if absent.
    NOTE(review): some timm ViT checkpoints expect mean/std of 0.5 per channel;
    confirm against the model's ``pretrained_cfg``.

    Returns:
        ``(loader, classes)``, where ``classes`` is a list of 10 short names.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    cifar = datasets.CIFAR10(root='./data', train=train, download=True, transform=preprocess)
    class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    return DataLoader(
        cifar,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    ), class_names

def evaluate_model(model, data_loader, device, classes):
    """Compute and print the top-1 accuracy (percent) of *model* on *data_loader*.

    Args:
        model: classifier returning ``[batch, num_classes]`` logits; switched to eval mode.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to (the model is presumably
            already there — the caller is responsible).
        classes: class names; currently unused, kept for API compatibility.

    Returns:
        Accuracy as a float percentage.

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc='Evaluating'):
            images, labels = images.to(device), labels.to(device)

            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    acc = 100. * correct / total
    print(f'Accuracy: {acc:.2f}%')

    # Near-chance accuracy on 10-class CIFAR almost always means an untrained head.
    if acc < 15.0:
        print(f'\n⚠️  WARNING: Very low accuracy ({acc:.2f}%)!')
        print('This is likely because:')
        print('1. Classification head is randomly initialized (expected if num_classes != 1000)')
        print('2. Model needs fine-tuning on CIFAR10 before evaluation')
        print('3. Random guessing would give ~10% accuracy on CIFAR10')
        print('\nTo fix: You need to fine-tune the model on CIFAR10 training set first!')

    return acc

def verify_weight_loading(model, timm_model_name):
    """Check that a few critical tensors in *model* equal timm's pretrained values.

    The comparison runs on CPU in float32, so a model living on a GPU (or in
    another dtype) can be checked. A shape difference is reported as a mismatch
    instead of raising.

    Args:
        model: the model whose ``state_dict`` is inspected.
        timm_model_name: name passed to ``timm.create_model(..., pretrained=True)``.

    Returns:
        True if every critical tensor exists in both models and matches within
        ``atol=1e-5``, otherwise False.
    """
    print(f"\n{'='*50}")
    print("Weight Loading Verification")
    print(f"{'='*50}")

    pretrained_model = timm.create_model(timm_model_name, pretrained=True, num_classes=1000)
    pretrained_state_dict = pretrained_model.state_dict()
    our_state_dict = model.state_dict()

    critical_layers = {
        'patch_embed.proj.weight': 'Patch embedding projection',
        'pos_embed': 'Positional embeddings',
        'cls_token': 'Class token',
        'norm.weight': 'Final layer norm weight',
        'blocks.0.norm1.weight': 'First transformer block norm'
    }

    print("\nCritical layer verification:")
    all_loaded = True
    for key, description in critical_layers.items():
        if key not in our_state_dict or key not in pretrained_state_dict:
            print(f"  ✗ {description} ({key}): NOT FOUND")
            all_loaded = False
            continue

        # Move both tensors to a common device/dtype; torch.allclose raises on
        # device or (non-broadcastable) shape mismatch.
        our_weight = our_state_dict[key].detach().cpu().float()
        pretrained_weight = pretrained_state_dict[key].detach().cpu().float()

        if our_weight.shape == pretrained_weight.shape and torch.allclose(
            our_weight, pretrained_weight, atol=1e-5
        ):
            print(f"  ✓ {description} ({key}): CORRECTLY LOADED")
        else:
            print(f"  ✗ {description} ({key}): MISMATCH!")
            print(f"    Our: {our_weight.shape}, Pretrained: {pretrained_weight.shape}")
            all_loaded = False

    if all_loaded:
        print("\n✓ All critical layers loaded correctly!")
    else:
        print("\n✗ Some critical layers may not be loaded correctly!")

    return all_loaded

def get_dataloader_with_augmentation(batch_size=32, num_workers=2, img_size=224, train=True):
    """Build a CIFAR-10 DataLoader, with augmentation and shuffling when *train* is True.

    Training pipeline:
        1. Random 32x32 crop with 4-pixel padding, at CIFAR's native resolution.
        2. Horizontal flip.
        3. Resize to ``img_size``.
        4. Normalize.
    Evaluation pipeline: resize and normalize only. Normalization uses the
    ImageNet mean/std. The dataset is downloaded into ``./data`` if absent.

    Returns:
        ``(loader, classes)``, where ``classes`` is a list of 10 short names.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if train:
        transform = transforms.Compose([
            # Crop at CIFAR's native 32x32 before upscaling. After resizing to
            # 224, padding=4 would shift the content by only ~0.6 source pixels,
            # making the augmentation nearly a no-op.
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            normalize,
        ])

    root = './data'
    dataset = datasets.CIFAR10(root=root, train=train, download=True, transform=transform)
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        pin_memory=True
    )

    return loader, classes

def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    The loss is averaged per sample: each batch's loss is weighted by its size,
    so a smaller final batch does not skew the result. This assumes *criterion*
    returns a per-batch mean, as ``nn.CrossEntropyLoss()`` does by default.

    Raises:
        ValueError: if *train_loader* yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch} [Train]')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        loss_sum += loss.item() * n
        total += n
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.*correct/total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_epoch: train_loader yielded no samples")

    return loss_sum / total, 100. * correct / total

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate *model* without gradients and return ``(mean_loss, accuracy_percent)``.

    The loss is averaged per sample: each batch's loss is weighted by its size.
    This assumes *criterion* returns a per-batch mean, as
    ``nn.CrossEntropyLoss()`` does by default.

    Raises:
        ValueError: if *val_loader* yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='[Val]')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            n = labels.size(0)
            loss_sum += loss.item() * n
            total += n
            correct += outputs.argmax(dim=1).eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        raise ValueError("validate_epoch: val_loader yielded no samples")

    return loss_sum / total, 100. * correct / total

def _is_head_param(name):
    """Return True if parameter *name* belongs to the top-level ``head`` module."""
    return name.split('.', 1)[0] == 'head'


def _build_optimizer(model, learning_rate, weight_decay, freeze_backbone):
    """Create the AdamW optimizer used by fine_tune_model.

    Frozen backbone: one group holding only the trainable parameters, at
    ``learning_rate``. End-to-end: backbone parameters at
    ``0.1 * learning_rate`` and head parameters at ``learning_rate``.
    """
    if freeze_backbone:
        trainable = [p for p in model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)

    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if _is_head_param(name):
            head_params.append(param)
        else:
            backbone_params.append(param)

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': learning_rate * 0.1},  # Smaller LR for pretrained
        {'params': head_params, 'lr': learning_rate}  # Normal LR for new head
    ], weight_decay=weight_decay)
    print(f"Using different LRs: backbone={learning_rate*0.1}, head={learning_rate}")
    return optimizer


def fine_tune_model(
    model,
    epochs=10,
    batch_size=64,
    learning_rate=1e-4,
    weight_decay=0.01,
    freeze_backbone=False,
    device='cuda',
    save_path='best_model.pth'
):
    """
    Fine-tune the model on CIFAR10 and checkpoint the best epoch.

    Args:
        model: The VisionTransformer model (moved to ``device`` here).
        epochs: Number of training epochs.
        batch_size: Batch size for training and validation.
        learning_rate: Head learning rate. The backbone uses 0.1x this value
            unless ``freeze_backbone`` is True.
        weight_decay: AdamW weight decay.
        freeze_backbone: If True, only train the classification head. The
            backbone parameters stay frozen (``requires_grad=False``) after
            this function returns.
        device: Device to train on.
        save_path: Where the checkpoint of the best validation epoch is written.
            Epoch 1 is always written, so a checkpoint exists whenever
            ``epochs >= 1``.

    Returns:
        dict with per-epoch lists ``'train_losses'``, ``'train_accs'``,
        ``'val_losses'``, ``'val_accs'`` and the scalar ``'best_val_acc'``.

    NOTE(review): validation uses the CIFAR-10 test split (``train=False``), so
    the saved "best" checkpoint is selected on the same data that is later
    reported as test accuracy. That test number is therefore optimistic.
    """
    print(f"\n{'='*60}")
    print("Fine-tuning on CIFAR10")
    print(f"{'='*60}")
    print(f"Epochs: {epochs}")
    print(f"Batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"Freeze backbone: {freeze_backbone}")
    print(f"{'='*60}\n")

    # Batches are moved to `device` in train/validate; the model must match.
    model = model.to(device)

    if freeze_backbone:
        print("Freezing backbone, only training classification head...")
        for name, param in model.named_parameters():
            if not _is_head_param(name):
                param.requires_grad = False
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100.*trainable_params/total_params:.2f}%)")
    else:
        print("Training entire model (end-to-end fine-tuning)...")
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {trainable_params:,}")

    print("\nLoading datasets...")
    train_loader, classes = get_dataloader_with_augmentation(
        batch_size=batch_size, num_workers=2, train=True
    )
    val_loader, _ = get_dataloader_with_augmentation(
        batch_size=batch_size, num_workers=2, train=False
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(model, learning_rate, weight_decay, freeze_backbone)

    # eta_min is an absolute floor shared by all parameter groups.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=learning_rate * 0.01
    )

    best_val_acc = 0.0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    print("\nStarting training...\n")
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        scheduler.step()
        # Report every group's LR (backbone and head differ in end-to-end mode).
        lr_summary = ', '.join(f"{group['lr']:.6f}" for group in optimizer.param_groups)

        print(f"\nEpoch {epoch}/{epochs} Summary:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Learning Rate: {lr_summary}")

        # Always save epoch 1 so a checkpoint exists even if val accuracy stays at 0.
        if epoch == 1 or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'train_acc': train_acc,
            }, save_path)
            print(f"  ✓ Saved best model (Val Acc: {val_acc:.2f}%)")

        print("-" * 60)

    print(f"\n{'='*60}")
    print("Training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Model saved to: {save_path}")
    print(f"{'='*60}\n")

    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc
    }

## Execution
Fine-tune and evaluate the model on CIFAR-10 (the data helpers above only load CIFAR-10).

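## Troubleshooting

If the accuracy is very low (e.g. 7–10%), possible causes are:

1. **Randomly initialized classification head**: when the number of classes changes from 1000 to 10, the classification head is re-initialized. **Fine-tuning is required to get good performance.**

2. **Incomplete weight loading**: check the debug output above to confirm that the critical layers (patch_embed, pos_embed, cls_token) were loaded correctly.

3. **Model not in evaluation mode**: make sure `model.eval()` is called before evaluating.

**Solutions**:
- If you only want to test the pretrained model, use `num_classes=1000` and evaluate on ImageNet.
- If you want to use the model on CIFAR10, you **must fine-tune it first**; it cannot be evaluated directly.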


## Fine-tuning on CIFAR10

Fine-tune the pretrained ViT model on CIFAR10. You can choose to:
- **Freeze backbone**: Only train the classification head (faster, less memory)
- **End-to-end**: Train the entire model (slower, better performance)


In [None]:
# Fine-tuning configuration: edit these values before running the cell.
model_name = 'vit_base_patch16_224'
num_classes = 10                      # CIFAR10 has 10 classes

# Optimization hyperparameters
epochs = 10
batch_size = 64                       # adjust to the available GPU memory
learning_rate = 1e-4                  # head LR (the backbone uses 0.1x in end-to-end mode)
weight_decay = 0.01
freeze_backbone = False               # True -> train only the classification head

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# ViT-Base/16 architecture, matching the timm checkpoint named above
vit_config = dict(
    img_size=224, patch_size=16, in_channels=3, num_classes=num_classes,
    embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
)
print(f"\nCreating model {model_name} for {num_classes} classes...")
model = VisionTransformer(**vit_config)
model = load_pretrained_weights(model, model_name, num_classes).to(device)

history = fine_tune_model(
    model=model,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    freeze_backbone=freeze_backbone,
    device=device,
    save_path='vit_cifar10_best.pth',
)

print("Fine-tuning completed!")
print(f"Best validation accuracy: {history['best_val_acc']:.2f}%")


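## Evaluate Fine-tuned Model

Load the best checkpoint and evaluate it on the CIFAR-10 test set. Note that `fine_tune_model` also uses this test split for validation when selecting the best epoch, so the reported test accuracy is optimistically biased.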


In [None]:
# Load the best fine-tuned checkpoint and evaluate it on the CIFAR-10 test split.
# NOTE: fine_tune_model validates on this same split (train=False), so the accuracy
# printed here comes from the data used to pick the checkpoint and is optimistic.
checkpoint_path = 'vit_cifar10_best.pth'
num_classes = 10  # CIFAR10 has 10 classes
batch_size = 64   # Batch size for evaluation

# Defined here as well so this cell does not depend on `device` from an earlier
# cell (e.g. after a kernel restart).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Same architecture as during fine-tuning so the state_dict keys and shapes match
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=num_classes,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0
)
model = model.to(device)

print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']}")
print(f"Validation accuracy when saved: {checkpoint['val_acc']:.2f}%")

print(f"\n{'='*60}")
print("Evaluating on test set...")
print(f"{'='*60}")
test_loader, classes = get_dataloader(batch_size=batch_size, num_workers=2, train=False)
test_acc = evaluate_model(model, test_loader, device, classes)
print(f"\nFinal test accuracy: {test_acc:.2f}%")


In [None]:
# Evaluate the pretrained backbone directly on CIFAR10, without fine-tuning.
# With num_classes=10 the head's shape differs from the 1000-class checkpoint, so
# it is not loaded and keeps its random initialization: expect near-chance accuracy.
batch_size = 64           # Adjust based on Colab GPU memory (e.g. 32, 64, 128)
model_name = 'vit_base_patch16_224'
num_classes = 10          # CIFAR10 has 10 classes

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

print(f"\n{'='*20} Evaluating on CIFAR10 {'='*20}")
print(f"Creating model {model_name} for {num_classes} classes...")
model = VisionTransformer(
    img_size=224, patch_size=16, in_channels=3, num_classes=num_classes,
    embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
)
model = load_pretrained_weights(model, model_name, num_classes).to(device)

print("Loading CIFAR10 test set...")
loader, classes = get_dataloader(batch_size=batch_size, num_workers=2, train=False)

# Last expression of the cell: Jupyter displays the returned accuracy.
evaluate_model(model, loader, device, classes)