In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from transformers import ViTModel

# Set device: use the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load MNIST dataset (training split), upscaled to 224x224 and converted to tensors
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "./data", train=True, transform=transform, download=True
)
# Shuffled mini-batches of 32 images
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)

# Modified ViT Model Definition
class ViTEmbedder(nn.Module):
    """Pretrained ViT backbone with a grayscale adapter and a linear classifier head.

    Args:
        model_name: Checkpoint identifier handed to ``ViTModel.from_pretrained``.
        num_classes: Number of logits produced by the classifier head
            (the default of 10 matches the MNIST digits).
    """

    def __init__(self, model_name="google/vit-base-patch16-224-in21k", num_classes=10):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # Learned 1x1 convolution that maps the single grayscale channel to the
        # three channels fed to the backbone.
        self.channel_adapter = nn.Conv2d(1, 3, kernel_size=1)
        # Size the head from the backbone's config rather than hard-coding 768,
        # so a checkpoint with a different hidden width still lines up.
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of 1-channel images.

        Args:
            x: Image batch with a single channel, e.g. ``[B, 1, 224, 224]``.

        Returns:
            A ``(logits, embeddings)`` tuple: the classifier output and the
            backbone's hidden state at sequence position 0 (the [CLS] token)
            that the logits were computed from.
        """
        x = self.channel_adapter(x)  # 1 channel -> 3 channels
        outputs = self.vit(x)
        embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        return self.classifier(embeddings), embeddings

# Initialize model on the selected device, together with its optimizer and loss
model = ViTEmbedder()
model.to(device)  # nn.Module.to moves the parameters in place
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-5)

# Training loop: five passes over the training set, printing the mean batch loss
for epoch in range(5):
    model.train()
    total_loss = 0

    for batch_imgs, batch_labels in train_loader:
        batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)

        # Clear stale gradients, then run the forward pass; only the logits
        # (first model output) feed the loss
        optimizer.zero_grad()
        logits = model(batch_imgs)[0]
        loss = criterion(logits, batch_labels)

        # Backpropagate and apply the optimizer update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

KeyboardInterrupt: 

In [None]:
# Collect embeddings: the model's second output and the label for every training image
model.eval()
all_embeddings, all_labels = [], []

with torch.no_grad():  # inference only, so gradient tracking is switched off
    for batch_imgs, batch_labels in train_loader:
        batch_embeddings = model(batch_imgs.to(device))[1]
        all_embeddings.append(batch_embeddings.cpu())
        all_labels.append(batch_labels)

# Concatenate the per-batch tensors and hand them over as NumPy arrays
embeddings = torch.cat(all_embeddings).numpy()
labels = torch.cat(all_labels).numpy()

In [None]:
# UMAP projection (same as before)
import matplotlib.pyplot as plt  # plt is used below; without this import the cell fails on a fresh kernel
import umap

# Reduce the collected embeddings to two dimensions with UMAP's default settings
reducer = umap.UMAP()
emb_2d = reducer.fit_transform(embeddings)

# Scatter the 2-D points, colored by their label
plt.scatter(emb_2d[:,0], emb_2d[:,1], c=labels, cmap='tab10', alpha=0.6)
plt.title('ViT Embeddings via UMAP')
plt.show()