This file builds an MNIST image classifier using the pre-trained Vision Transformer (google/vit-base-patch16-224) from Hugging Face and fine-tunes it with the actual MNIST dataset to achieve 99.6% accuracy. The purpose of this MNIST image classifier is to provide a quantitative performance measurement for different generative models we have built so far, namely GANs, PixelCNN, VAE, and diffusion models.

## Procedures
1. We start with the `vit-base-patch16-224` model from Hugging Face, designed for image classification with a resolution of 224x224 pixels.
2. The model is further fine-tuned on the original MNIST dataset.
3. After fine-tuning, we evaluate the model on the MNIST test set to assess its accuracy, achieving approximately 99.6% accuracy.
4. We generate 5000 new reconstructed MNIST images with their original labels using the Variational Autoencoder (VAE) model.
5. The fine-tuned model is used as a 'human judge' to evaluate how well the reconstructed MNIST images generated from the VAE model can be correctly classified by the fine-tuned Vision Transformer model.

In [1]:
from PIL import Image
import requests
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig
import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.transforms import Lambda
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from transformers import get_scheduler
from torchmetrics import Accuracy
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import os


# Set a seed value
seed = 100 
random.seed(seed)  
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) 

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2024-04-25 01:59:09.020536: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-25 01:59:09.049357: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### prepare MNIST data to dataloader and load pre-train vision transformer

In [9]:
# Define transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to fit the model input
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    Lambda(lambda x: x.repeat(3, 1, 1))  # Repeat the grayscale channel to simulate RGB
])

# Load datasets
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load a pre-trained Vision Transformer with a specific configuration for 10 classes
config = ViTConfig.from_pretrained('google/vit-base-patch16-224', num_labels=10)
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', config=config, ignore_mismatched_sizes=True)

# Reinitialize the classifier layer as it was mismatched
model.classifier = nn.Linear(model.config.hidden_size, 10)  # Ensure the input features match hidden_size of model
model.classifier.to(device)

# Verify and replace the classifier if needed
if model.classifier.out_features != 10:
    model.classifier = nn.Linear(model.classifier.in_features, 10)
model.classifier.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Linear(in_features=768, out_features=10, bias=True)

### check the model performance without fine tune

In [3]:
model.eval()  # Set the model to evaluation mode

def evaluate_model(model, data_loader, device):
    # Move model to the right device
    model = model.to(device)
    model.eval()
    
    total_preds, total_labels = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            # Ensure data is on the same device as the model
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(pixel_values=images)
            preds = torch.argmax(outputs.logits, dim=1)

            # Collect predictions and labels
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())

    # Compute accuracy
    accuracy = accuracy_score(total_labels, total_preds)
    return accuracy
    
# Assuming 'device' is defined (e.g., device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
# Run evaluation
test_accuracy = evaluate_model(model, test_loader, device)
print(f"Accuracy on MNIST test data (without fine-tuning): {test_accuracy:.4f}")


Accuracy on MNIST test data (without fine-tuning): 0.1360


### train vision transformer on MNIST

In [5]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Prepare optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Prepare learning rate scheduler
num_epochs = 4
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Update model's classifier head if necessary (for MNIST, 10 classes)
model.classifier = nn.Linear(model.classifier.in_features, 10)
model.classifier.to(device)

# Training loop
model.train()
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # The transformers model expects pixel_values key in the input dictionary
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch+1}: Loss {loss.item()}")



Epoch 1: Loss 0.002542877569794655
Epoch 2: Loss 0.001039237598888576
Epoch 3: Loss 0.0011476653162389994
Epoch 4: Loss 0.00024169111566152424


### compute the test set accuracy

In [6]:
# Setup the metric
accuracy_metric = Accuracy(num_classes=10, average='macro', task='multiclass').to(device)

# Switch model to evaluation mode
model.eval()
total_accuracy = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(pixel_values=images)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        total_accuracy += accuracy_metric(predictions, labels)

# Calculate the average accuracy
average_accuracy = total_accuracy / len(test_loader)
print(f"Test Accuracy: {average_accuracy.item()}")

Test Accuracy: 0.996269702911377


### load reconstructed data into dataloader

In [7]:
# try some augmented data from VAE (5000 images)

class ReconstructedDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the reconstructed images and labels.csv.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.labels = pd.read_csv(os.path.join(root_dir, 'labels.csv'))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, f'reconstructed_{idx}.png')
        image = Image.open(img_name).convert('RGB')  
        label = int(self.labels.iloc[idx, 1])

        if self.transform:
            image = self.transform(image)

        return image, label

# Transformations for the Vision Transformer
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the model input
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Assuming reconstructed images are saved in 'reconstructed_images' directory
reconstructed_dataset = ReconstructedDataset(root_dir='reconstructed_images', transform=transform)
reconstructed_loader = DataLoader(reconstructed_dataset, batch_size=32, shuffle=False)

images, _ = next(iter(reconstructed_loader))
print(images.shape)  # Should output torch.Size([batch_size, 3, 224, 224])

torch.Size([32, 3, 224, 224])


### compute accuracy for the VAE generated reconstructed image

In [8]:
def evaluate_model(model, data_loader, device):
    model.eval()
    total_preds, total_labels = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            outputs = model(pixel_values=images)
            preds = torch.argmax(outputs.logits, dim=1)
            
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.numpy())

    accuracy = accuracy_score(total_labels, total_preds)
    return accuracy

accuracy = evaluate_model(model, reconstructed_loader, device)
print(f"Accuracy on reconstructed images: {accuracy:.4f}")

Accuracy on reconstructed images: 0.7070
