# From CNNs to ViTs: A Comparative Study in Medical Image Classification

Image classification is a fundamental task in computer vision that involves assigning a label or category to an input image. This project explores two prominent deep learning architectures for medical image classification:

- **Convolutional Neural Networks (CNNs)**: Traditional deep learning models that use convolutional operations to extract hierarchical features from images
- **Vision Transformers (ViTs)**: Modern architecture that adapts the Transformer model (originally designed for NLP) to process images as sequences of patches

### Key Concepts

- **Image Classification**: The process of automatically categorizing images into predefined classes based on their visual content
- **Medical Imaging**: Application of classification to medical scans (CT scans, X-rays, MRIs) for diagnostic assistance
- **Comparative Study**: This project compares the effectiveness of CNNs and ViTs on the Medical MNIST dataset to understand their relative strengths and weaknesses

### References

- **Convolutional Neural Networks**: LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. *Proceedings of the IEEE*, 86(11), 2278-2324.
- **Vision Transformers**: Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. *arXiv preprint arXiv:2010.11929*.
- **Hugging Face**: Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., ... & Rush, A. M. (2020). Transformers: State-of-the-art natural language processing. *Proceedings of the 2020 conference on empirical methods in natural language processing: system demonstrations*, 38-45.



## Set up running environment

In [1]:
%%bash
# ============================================================================
# Environment Setup: Creating Virtual Environment and Installing Dependencies
# ============================================================================
# This cell sets up a clean Python environment for the project by:
# 1. Creating an isolated virtual environment to avoid dependency conflicts
# 2. Activating the virtual environment for this session
# 3. Installing all required packages from requirements.txt

# Step 1: Create a new virtual environment named '.venv'
# This isolates project dependencies from the system Python installation
python3 -m venv .venv

# Step 2: Activate the virtual environment
# This ensures that all subsequent Python commands use packages from .venv
source .venv/bin/activate

# Step 3: Upgrade pip to the latest version for better package management
pip install --upgrade pip

# Step 4: Install all project dependencies from requirements.txt
# This installs all necessary libraries for CNN and ViT model training
pip install -r requirements.txt

echo "✓ Virtual environment created and activated"
echo "✓ All dependencies installed successfully"

✓ Virtual environment created and activated
✓ All dependencies installed successfully


## Dataset: Medical MNIST

The **Medical MNIST** dataset is a collection of medical imaging data that will be downloaded from Kaggle. This dataset serves as an excellent benchmark for comparing the performance of Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs) in medical image classification tasks.

### Dataset Overview
- **Source**: Kaggle
- **Total Classes**: 6
- **Image Type**: Medical imaging scans (CT scans, X-rays, MRI)
- **Purpose**: Multi-class classification of different medical imaging modalities

### Class Definitions

The dataset contains **6 distinct classes**, each representing a different type of medical imaging:

1. **AbdomenCT** - Abdominal Computed Tomography (CT) scans
   - Images of the abdominal region captured using CT imaging technology
   - Used for diagnosing conditions in organs like liver, kidneys, and intestines

2. **BreastMRI** - Breast Magnetic Resonance Imaging (MRI) scans
   - High-resolution MRI images of breast tissue
   - Commonly used for breast cancer detection and diagnosis

3. **ChestCT** - Chest Computed Tomography (CT) scans
   - CT images of the thoracic region (chest area)
   - Used for detecting lung diseases, tumors, and other chest abnormalities

4. **ChestXray** - Chest X-ray images
   - Traditional X-ray images of the chest
   - One of the most common medical imaging techniques for lung and heart assessment

5. **Hand** - Hand X-ray images
   - X-ray images of the hand and wrist
   - Used for diagnosing fractures, arthritis, and other bone-related conditions

6. **HeadCT** - Head Computed Tomography (CT) scans
   - CT scans of the head and brain region
   - Critical for diagnosing brain injuries, tumors, and neurological conditions

### Dataset Characteristics
- Each class contains medical images that require careful analysis and classification
- The dataset presents a challenging classification problem due to the visual similarities between some medical imaging modalities
- This diversity makes it an ideal testbed for evaluating the effectiveness of different deep learning architectures in medical imaging applications


In [2]:
import kagglehub
import os
import shutil
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define the data directory in the current working directory
data_dir = Path("data")

# Check if data directory exists, create if not
if not data_dir.exists():
    data_dir.mkdir(parents=True, exist_ok=True)
    print(f"✓ Created data directory: {data_dir.absolute()}")
else:
    print(f"✓ Data directory already exists: {data_dir.absolute()}")

✓ Data directory already exists: /Users/sztaki/Documents/computer_vision/data


In [4]:
# Check if dataset is already in the data directory
dataset_in_data = data_dir / "medical-mnist-train-test-val"
train_dir = dataset_in_data / "train"
if dataset_in_data.exists() and train_dir.exists() and any(train_dir.iterdir()):
    print(f"✓ Dataset already exists in {data_dir.absolute()}")
    print(f"  Path: {dataset_in_data.absolute()}")
else:
    # Download latest version to cache
    print("Downloading dataset from Kaggle...")
    cache_path = kagglehub.dataset_download("gennadiimanzhos/medical-mnist-train-test-val")
    print(f"✓ Downloaded to cache: {cache_path}")

    # Copy dataset from cache to data directory
    print(f"Copying dataset to {data_dir.absolute()}...")
    if dataset_in_data.exists():
        shutil.rmtree(dataset_in_data)
    shutil.copytree(cache_path, dataset_in_data)
    print(f"✓ Dataset copied to: {dataset_in_data.absolute()}")

✓ Dataset already exists in /Users/sztaki/Documents/computer_vision/data
  Path: /Users/sztaki/Documents/computer_vision/data/medical-mnist-train-test-val


In [5]:
# Set the final path for use in the notebook
dataset_path = dataset_in_data
print(f"\n✓ Dataset ready at: {dataset_path.absolute()}")


✓ Dataset ready at: /Users/sztaki/Documents/computer_vision/data/medical-mnist-train-test-val


## Data Loading and Preprocessing

This section covers the data loading process for the Medical MNIST dataset. The dataset comes pre-split into three distinct sets: **training**, **validation**, and **test** sets. Each set contains images organized by class in separate directories.

### Dataset Structure

The Medical MNIST dataset is organized as follows:
- **Train Set**: Used for model training and learning patterns
- **Validation Set**: Used for hyperparameter tuning and model selection during training
- **Test Set**: Used for final evaluation of model performance

Each split contains 6 class directories:
- `AbdomenCT/`
- `BreastMRI/`
- `ChestCT/`
- `CXR/` (Chest X-ray)
- `Hand/`
- `HeadCT/`

### Data Loading Process

The data loading pipeline will:

1. **Load Training Data**: Load images from the `train/` directory with their corresponding class labels
2. **Load Validation Data**: Load images from the `val/` directory for validation during training
3. **Load Test Data**: Load images from the `test/` directory for final model evaluation
4. **Train-Test Split**: The dataset already provides separate train, validation, and test splits, so we will use these predefined splits directly

### Implementation Details

- Images will be loaded and preprocessed (resizing, normalization, etc.)
- Data loaders will be created with appropriate batch sizes for efficient training
- Data augmentation may be applied to the training set to improve model generalization
- Class labels will be encoded appropriately for the classification task

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path
import os

In [7]:
# Set device - check for GPU/accelerator availability
# Priority: CUDA (NVIDIA GPU) > MPS (Apple Silicon GPU) > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"✓ CUDA (NVIDIA GPU) detected and will be used")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"✓ MPS (Apple Silicon GPU) detected and will be used")
    print(f"  This will use the GPU cores on your Apple Silicon chip")
else:
    device = torch.device("cpu")
    print(f"⚠ No GPU acceleration available, using CPU")
    print(f"  Note: Training will be slower on CPU")

print(f"\nUsing device: {device}")

✓ MPS (Apple Silicon GPU) detected and will be used
  This will use the GPU cores on your Apple Silicon chip

Using device: mps


In [8]:
# Define paths
train_dir = dataset_path / "train"
val_dir = dataset_path / "val"
test_dir = dataset_path / "test"

# ResNet expects ImageNet-style normalization
# Mean and std for ImageNet (standard for ResNet)
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Training transforms (no augmentation - dataset is large enough)
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet input size
    transforms.ToTensor(),
    normalize
])

# No augmentation for validation and test sets
val_test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

In [9]:
# Load datasets
train_dataset = datasets.ImageFolder(root=str(train_dir), transform=train_transforms)
val_dataset = datasets.ImageFolder(root=str(val_dir), transform=val_test_transforms)
test_dataset = datasets.ImageFolder(root=str(test_dir), transform=val_test_transforms)

In [10]:
# Get class names
class_names = train_dataset.classes
num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print(f"Class names: {class_names}")

# Create data loaders
batch_size = 32
num_workers = 4 if os.cpu_count() > 4 else 2

# pin_memory only works with CUDA, not MPS
use_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=use_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=use_pin_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=use_pin_memory
)


Number of classes: 6
Class names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']


In [11]:
# Print dataset statistics
print(f"\nDataset Statistics:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Test samples: {len(test_dataset)}")
print(f"  Total samples: {len(train_dataset) + len(val_dataset) + len(test_dataset)}")
print(f"\nDataLoader Configuration:")
print(f"  Batch size: {batch_size}")
print(f"  Number of workers: {num_workers}")
print(f"  Number of batches (train): {len(train_loader)}")
print(f"  Number of batches (val): {len(val_loader)}")
print(f"  Number of batches (test): {len(test_loader)}")

# Verify data loading by getting a sample batch
sample_batch = next(iter(train_loader))
images, labels = sample_batch
print(f"\nSample batch shape:")
print(f"  Images shape: {images.shape}")  # [batch_size, channels, height, width]
print(f"  Labels shape: {labels.shape}")  # [batch_size]
print(f"  Image dtype: {images.dtype}")
print(f"  Label dtype: {labels.dtype}")

print("\n✓ Data loaders created successfully and ready for NN training!")


Dataset Statistics:
  Training samples: 47163
  Validation samples: 5895
  Test samples: 5896
  Total samples: 58954

DataLoader Configuration:
  Batch size: 32
  Number of workers: 4
  Number of batches (train): 1474
  Number of batches (val): 185
  Number of batches (test): 185

Sample batch shape:
  Images shape: torch.Size([32, 3, 224, 224])
  Labels shape: torch.Size([32])
  Image dtype: torch.float32
  Label dtype: torch.int64

✓ Data loaders created successfully and ready for NN training!


## Custom Small Neural Network

After training with the pre-trained ResNet-50 architecture, we will now train a **custom lightweight neural network** from scratch. This provides an interesting comparison between:

- **Transfer Learning (ResNet-50)**: Using a large, pre-trained model with millions of parameters
- **Custom Small Network**: Training a compact, task-specific architecture from scratch

### Architecture Design

The custom neural network will be a simple but effective **Convolutional Neural Network (CNN)** designed specifically for the Medical MNIST classification task:

**Network Structure:**
- **Input Layer**: 224×224×3 RGB images
- **Convolutional Blocks**: Multiple conv layers with increasing depth
  - Convolutional layers with ReLU activation
  - Max pooling for dimensionality reduction
  - Batch normalization for training stability
- **Fully Connected Layers**: Dense layers for classification
- **Output Layer**: 6 classes (one for each medical imaging modality)

**Key Characteristics:**
- **Lightweight**: Significantly fewer parameters than ResNet-50
- **Task-Specific**: Designed from scratch for medical image classification
- **Fast Training**: Smaller model trains faster and requires less memory
- **No Pre-training**: Trained from random initialization (no transfer learning)

### Training Approach

The custom network will be trained using the same dataset and evaluation methodology as ResNet-50, allowing for a fair comparison of:
- **Model Size**: Number of parameters
- **Training Time**: Time to convergence
- **Performance**: Accuracy on validation and test sets
- **Efficiency**: Model size vs. performance trade-off

This comparison will help understand whether the complexity of a large pre-trained model is necessary, or if a simpler custom architecture can achieve comparable results for this specific medical imaging classification task.


In [12]:
import time
from tqdm import tqdm
import copy

# ============================================================================
# Custom CNN Model Definition
# ============================================================================

In [13]:
class CustomCNN(nn.Module):
    def __init__(self, num_classes=6):
        super(CustomCNN, self).__init__()

        # First convolutional block
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 224x224 -> 112x112
        )

        # Second convolutional block
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 112x112 -> 56x56
        )

        # Third convolutional block
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 56x56 -> 28x28
        )

        # Additional pooling to reduce size further
        self.pool = nn.AdaptiveAvgPool2d((7, 7))  # 28x28 -> 7x7

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.pool(x)
        x = self.fc(x)
        return x

# ============================================================================
# Model Initialization
# ============================================================================
print("Initializing Custom CNN model...")

# Create model instance
custom_model = CustomCNN(num_classes=num_classes)

# Move model to device
custom_model = custom_model.to(device)
print(f"✓ Model initialized and moved to {device}")
print(f"  Total parameters: {sum(p.numel() for p in custom_model.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in custom_model.parameters() if p.requires_grad):,}")

Initializing Custom CNN model...
✓ Model initialized and moved to mps
  Total parameters: 3,438,342
  Trainable parameters: 3,438,342


In [14]:
# ============================================================================
# Training Configuration
# ============================================================================
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(custom_model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

num_epochs = 20
best_val_loss = float('inf')
best_model_wts = copy.deepcopy(custom_model.state_dict())
patience = 5
patience_counter = 0

# Training history
history_custom = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

print(f"\nTraining Configuration:")
print(f"  Epochs: {num_epochs}")
print(f"  Learning rate: {optimizer.param_groups[0]['lr']}")
print(f"  Optimizer: Adam")
print(f"  Loss function: CrossEntropyLoss")
print(f"  Early stopping patience: {patience}")


Training Configuration:
  Epochs: 20
  Learning rate: 0.001
  Optimizer: Adam
  Loss function: CrossEntropyLoss
  Early stopping patience: 5


In [15]:

# ============================================================================
# Training and Validation Functions
# ============================================================================
def train_epoch_custom(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar
        current_batch = total // labels.size(0)
        pbar.set_postfix({
            'loss': f'{running_loss/current_batch:.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

In [16]:

def validate_epoch_custom(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            current_batch = total // labels.size(0)
            pbar.set_postfix({
                'loss': f'{running_loss/current_batch:.4f}',
                'acc': f'{100*correct/total:.2f}%'
            })

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

# ============================================================================
# Training Loop
# ============================================================================
print("\n" + "="*60)
print("Starting Custom CNN Training...")
print("="*60)

start_time = time.time()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 60)

    # Training phase
    train_loss, train_acc = train_epoch_custom(custom_model, train_loader, criterion, optimizer, device)

    # Validation phase
    val_loss, val_acc = validate_epoch_custom(custom_model, val_loader, criterion, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Save history
    history_custom['train_loss'].append(train_loss)
    history_custom['train_acc'].append(train_acc)
    history_custom['val_loss'].append(val_loss)
    history_custom['val_acc'].append(val_acc)

    # Print epoch results
    print(f"\nEpoch {epoch+1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = copy.deepcopy(custom_model.state_dict())
        patience_counter = 0
        print(f"  ✓ New best validation loss: {best_val_loss:.4f}")
    else:
        patience_counter += 1
        print(f"  Patience: {patience_counter}/{patience}")

    # Early stopping
    if patience_counter >= patience:
        print(f"\n⚠ Early stopping triggered after {epoch+1} epochs")
        break

# Load best model weights
custom_model.load_state_dict(best_model_wts)
print(f"\n✓ Training completed!")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Total training time: {(time.time() - start_time)/60:.2f} minutes")

# Save the trained model
models_dir = Path("models")
models_dir.mkdir(exist_ok=True)
model_path = models_dir / "custom_cnn_medical_mnist.pth"
torch.save({
    'model_state_dict': custom_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_val_loss': best_val_loss,
    'history': history_custom,
    'num_classes': num_classes,
    'class_names': class_names
}, str(model_path))

print(f"✓ Model saved to: {model_path.absolute()}")



Starting Custom CNN Training...

Epoch 1/20
------------------------------------------------------------


Training: 100%|██████████| 1474/1474 [01:56<00:00, 12.62it/s, loss=0.0563, acc=98.07%]
Validation: 100%|██████████| 185/185 [00:27<00:00,  6.82it/s, loss=0.0032, acc=99.69%]



Epoch 1 Results:
  Train Loss: 0.0667 | Train Acc: 98.07%
  Val Loss: 0.0144 | Val Acc: 99.69%
  Learning Rate: 0.001000
  ✓ New best validation loss: 0.0144

Epoch 2/20
------------------------------------------------------------


Training:   9%|▉         | 139/1474 [00:12<01:23, 15.99it/s, loss=0.0157, acc=99.57%] Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x10b16aa60>
Traceback (most recent call last):
  File "/Users/sztaki/Documents/computer_vision/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/Users/sztaki/Documents/computer_vision/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1628, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/Library/Developer/CommandLi

KeyboardInterrupt: 

## Vision Transformer (ViT)

The **Vision Transformer (ViT)** represents a paradigm shift in computer vision, adapting the Transformer architecture (originally designed for natural language processing) to image classification tasks. Unlike CNNs that use convolutional operations, ViTs process images as sequences of patches.

### Architecture Overview

**Key Components:**

1. **Image Patching**: 
   - Input images are divided into fixed-size patches (e.g., 16×16 or 32×32 pixels)
   - Each patch is flattened and linearly projected into an embedding space
   - This converts the 2D image into a 1D sequence of patch embeddings

2. **Position Embeddings**:
   - Learnable position embeddings are added to patch embeddings
   - Allows the model to understand spatial relationships between patches

3. **Transformer Encoder**:
   - Multiple layers of self-attention and feed-forward networks
   - Self-attention mechanism enables the model to focus on relevant patches
   - Each layer refines the patch representations

4. **Classification Token**:
   - A special [CLS] token is prepended to the sequence
   - This token aggregates information from all patches
   - Final classification is performed using this token

### Advantages

- **Global Receptive Field**: Self-attention allows the model to attend to all patches simultaneously, capturing long-range dependencies
- **Scalability**: Performance improves significantly with more data and larger models
- **Transfer Learning**: Pre-trained ViTs can be fine-tuned effectively on downstream tasks
- **Interpretability**: Attention maps can visualize which image regions the model focuses on

### Comparison with CNNs

- **CNNs**: Inductive bias of locality and translation equivariance (convolutional operations)
- **ViTs**: Minimal inductive bias, rely on attention mechanisms and data to learn patterns
- **Data Requirements**: ViTs typically require more data than CNNs to achieve similar performance, but excel with large-scale pre-training

For this project, we will train a Vision Transformer to compare its performance against the CNN-based architectures (ResNet-50 and Custom CNN) on the Medical MNIST dataset.


In [17]:
from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
from transformers import DefaultDataCollator
from torch.utils.data import Dataset
from PIL import Image
import torch
from pathlib import Path
import os
from sklearn.metrics import accuracy_score
import numpy as np
import random
from PIL import Image
import torch

In [18]:
# ============================================================================
# Custom Dataset Class for Hugging Face
# ============================================================================
class MedicalMNISTDataset(Dataset):
    def __init__(self, image_paths, labels, processor):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]

        # Process image with ViT processor
        encoding = self.processor(image, return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding['labels'] = torch.tensor(label, dtype=torch.long)

        return encoding

In [19]:
# ============================================================================
# Prepare Dataset
# ============================================================================
print("Loading Vision Transformer processor and model...")

# Load ViT processor and model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)

# Move model to device
model = model.to(device)
print(f"✓ ViT model loaded and moved to {device}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Loading Vision Transformer processor and model...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ ViT model loaded and moved to mps
  Total parameters: 85,803,270
  Trainable parameters: 85,803,270


In [24]:
# Prepare image paths and labels for train/val sets
def prepare_dataset_paths(dataset_folder, class_names):
    image_paths = []
    labels = []

    for class_idx, class_name in enumerate(class_names):
        class_dir = dataset_folder / class_name
        if class_dir.exists():
            for img_file in class_dir.glob('*'):
                if img_file.suffix.lower() in ['.png', '.jpg', '.jpeg']:
                    image_paths.append(str(img_file))
                    labels.append(class_idx)

    return image_paths, labels

In [25]:
print("\nPreparing datasets...")
train_paths, train_labels = prepare_dataset_paths(train_dir, class_names)
val_paths, val_labels = prepare_dataset_paths(val_dir, class_names)

# Randomly sample 300 images for training and 50 for validation (for faster training)
import random
random.seed(42)

# Sample training data
if len(train_paths) > 300:
    indices = list(range(len(train_paths)))
    sampled_indices = random.sample(indices, 300)
    train_paths = [train_paths[i] for i in sampled_indices]
    train_labels = [train_labels[i] for i in sampled_indices]
    print(f"  Randomly sampled 300 training images from {len(indices)} total")

# Sample validation data
if len(val_paths) > 50:
    indices = list(range(len(val_paths)))
    sampled_indices = random.sample(indices, 50)
    val_paths = [val_paths[i] for i in sampled_indices]
    val_labels = [val_labels[i] for i in sampled_indices]
    print(f"  Randomly sampled 50 validation images from {len(indices)} total")

print(f"  Training images: {len(train_paths)}")
print(f"  Validation images: {len(val_paths)}")

# Create datasets
train_dataset_hf = MedicalMNISTDataset(train_paths, train_labels, processor)
val_dataset_hf = MedicalMNISTDataset(val_paths, val_labels, processor)

# ============================================================================
# Training Configuration
# ============================================================================
output_dir = "./models/vit_medical_mnist"

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir=f'{output_dir}/logs',
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    save_total_limit=2,
    push_to_hub=False,
    report_to="none",
)

# Data collator
data_collator = DefaultDataCollator()


Preparing datasets...
  Randomly sampled 300 training images from 47163 total
  Randomly sampled 50 validation images from 5895 total
  Training images: 300
  Validation images: 50


In [26]:
# Metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, predictions)
    return {"accuracy": accuracy}

# ============================================================================
# Initialize Trainer
# ============================================================================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=val_dataset_hf,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [27]:
# ============================================================================
# Training
# ============================================================================
print("\n" + "="*60)
print("Starting Vision Transformer Training...")
print("="*60)
print(f"Training on {len(train_dataset_hf)} samples")
print(f"Validating on {len(val_dataset_hf)} samples")
print(f"Output directory: {output_dir}")

# Train the model
train_result = trainer.train()

# Save the final model
trainer.save_model()
processor.save_pretrained(output_dir)

print(f"\n✓ Training completed!")
print(f"  Final training loss: {train_result.training_loss:.4f}")
print(f"  Model saved to: {output_dir}")

# Evaluate on validation set
eval_results = trainer.evaluate()
print(f"\nValidation Results:")
print(f"  Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"  Validation Accuracy: {eval_results['eval_accuracy']:.4f}")

print(f"\n✓ Vision Transformer training pipeline completed!")


Starting Vision Transformer Training...
Training on 300 samples
Validating on 50 samples
Output directory: ./models/vit_medical_mnist


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.002534,1.0
2,No log,0.001366,1.0
3,No log,0.001027,1.0
4,No log,0.000849,1.0
5,No log,0.00074,1.0
6,0.001300,0.000657,1.0
7,0.001300,0.000606,1.0
8,0.001300,0.00056,1.0
9,0.001300,0.000528,1.0
10,0.001300,0.0005,1.0





✓ Training completed!
  Final training loss: 0.0004
  Model saved to: ./models/vit_medical_mnist





Validation Results:
  Validation Loss: 0.0004
  Validation Accuracy: 1.0000

✓ Vision Transformer training pipeline completed!


In [28]:
# ============================================================================
# Test: Random 100 images classification with ViT
# ============================================================================
# Load trained model
output_dir = "./models/vit_medical_mnist"
processor = ViTImageProcessor.from_pretrained(output_dir)
model = ViTForImageClassification.from_pretrained(output_dir)
model = model.to(device)
model.eval()

# Get 100 random test images
test_paths, test_labels = prepare_dataset_paths(test_dir, class_names)
random.seed(42)
if len(test_paths) > 100:
    sampled = random.sample(list(range(len(test_paths))), 100)
    test_paths = [test_paths[i] for i in sampled]
    test_labels = [test_labels[i] for i in sampled]

# Classify
correct = 0
with torch.no_grad():
    for img_path, true_label in zip(test_paths, test_labels):
        image = Image.open(img_path).convert('RGB')
        inputs = processor(image, return_tensors="pt").to(device)
        outputs = model(**inputs)
        pred = outputs.logits.argmax(-1).item()
        if pred == true_label:
            correct += 1

print(f"Test Accuracy: {correct}/100 = {correct}%")


Test Accuracy: 100/100 = 100%
