# $\texttt{dopanim}$: $t$-SNE Plot from a Fine-tuned DINOv2 ViT-S/14

In [None]:
import torch
import os
import torch.nn as nn 
import numpy as np
import matplotlib.pyplot as plt
import sys

# TODO: Append the path to your `multi-annotator-machine-learning` project.
sys.path.append("../../")

from maml.data import Dopanim
from matplotlib.colors import to_rgba
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm 
from lightning import seed_everything


# TODO: Set the root directory of the dataset (forwarded to `Dopanim` below).
DATA_PATH = "."

# TODO: Adjust this flag; it is forwarded as the `download` argument of `Dopanim`.
DOWNLOAD = False

# TODO: Set the directory in which figures are stored (`tsne.pdf` is written there).
FIGURE_PATH = "."

### Fine-tune DINOv2 ViT-S/14 model on the Training Set

In [None]:
# Seed the random number generators so that the run is reproducible.
seed_everything(0, workers=True)

# Image transformations of DinoV2 models: resize, center crop, tensor
# conversion, and channel-wise normalization.
dino_mean = (0.485, 0.456, 0.406)
dino_std = (0.229, 0.224, 0.225)
dino_transform = transforms.Compose([
    # The enum replaces the deprecated integer code `3` (bicubic).
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(dino_mean, dino_std),
])

# Load train dataset.
train_ds = Dopanim(DATA_PATH, version='train', variant='full', transform=dino_transform, download=DOWNLOAD)

# Setup DinoV2 ViT-S/14 architecture: the backbone followed by a linear
# classification head mapping its 384-dimensional output to 15 class logits.
dino_model_name = 'dinov2_vits14'
dino_model = torch.hub.load('facebookresearch/dinov2', dino_model_name)
model = nn.Sequential(dino_model, nn.Linear(384, 15))

# Setup training.
batch_size = 16
num_epochs = 10
# Fall back to the CPU instead of crashing when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=8, shuffle=True)
criterion = nn.CrossEntropyLoss()
# The backbone is fine-tuned with a much smaller learning rate than the
# freshly initialized classification head.
params = [
    {'params': model[0].parameters(), 'lr': 1e-5},
    {'params': model[1].parameters(), 'lr': 1e-3},
]
optimizer = torch.optim.RAdam(params, lr=1e-3, weight_decay=0)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Train model.
torch.set_grad_enabled(True)
model.to(device)
# Make the training mode explicit rather than relying on the module default.
model.train()
for _ in tqdm(range(num_epochs)):
    running_loss = 0
    for batch in train_loader:
        inputs, targets = batch["x"], batch["y"]
        logits = model(inputs.to(device))
        loss = criterion(logits, targets.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Sum of the per-batch losses of this epoch.
    print(running_loss)
    # The learning rate schedule advances once per epoch.
    lr_scheduler.step()

### Evaluate Fine-tuned ViT-S/14 model on the Validation Set

In [None]:
# Load validation dataset.
val_ds = Dopanim(DATA_PATH, version='valid', transform=dino_transform, download=DOWNLOAD)
val_loader = DataLoader(val_ds, batch_size=64, num_workers=8)

# Evaluation on validation dataset: collect the logits and targets of all
# batches and compare the predicted with the true class indices.
torch.set_grad_enabled(False)
model.to(device)
# Switch to evaluation mode so that any train-only behavior of the layers is
# disabled during inference.
model.eval()
logits_list = []
targets_list = []
for batch in tqdm(val_loader):
    inputs, targets = batch["x"], batch["y"]
    logits = model(inputs.to(device))
    logits_list.append(logits.cpu())
    targets_list.append(targets.cpu())
logits = torch.cat(logits_list)
targets = torch.cat(targets_list)
# Convert to a Python float so that a plain number is printed instead of a
# tensor representation.
accuracy = (logits.argmax(-1) == targets).float().mean().item()
print(f"Validation accuracy: {accuracy:.4f}")

### Load Features Learned by the Penultimate Layer of the Fine-tuned ViT-S/14 Model

In [None]:
# Pass the validation set through the fine-tuned backbone only (without the
# linear classification head) and collect its outputs as features.
features_list = []
targets_list = []
torch.set_grad_enabled(False)
dino_model.to(device)
# Switch to evaluation mode so that any train-only behavior of the layers is
# disabled while extracting features.
dino_model.eval()
for batch in val_loader:
    inputs, batch_targets = batch["x"], batch["y"]
    batch_features = dino_model(inputs.to(device))
    # Move the features back to the CPU so that they can be concatenated and
    # converted to NumPy later on.
    features_list.append(batch_features.cpu())
    targets_list.append(batch_targets)
features = torch.cat(features_list)
targets = torch.cat(targets_list)

### Visualize Learned Features using $t$-SNE

In [None]:
# Embed the learned features into two dimensions.
tsne = TSNE(random_state=42)
X = tsne.fit_transform(features.numpy())
# Map the class indices back to class names; presumably `val_ds.le` is a fitted
# label encoder (TODO confirm). The result is converted to an array so that
# the element-wise comparison below yields a boolean mask.
y = np.asarray(val_ds.le.inverse_transform(targets.numpy()))

# Class groups
classes = {
    "Squirrels": ['American Red Squirrel', 'Douglas\' Squirrel', 'Eurasian Red Squirrel'],
    "Hares & Rabbits": ['Black-tailed Jackrabbit', 'Brown Hare', 'Desert Cottontail', 'European Rabbit', 'Marsh Rabbit'],
    "Insects": ['European Hornet', 'European Paper Wasp', 'German Yellowjacket', 'Yellow-legged Hornet'],
    "Big Cats": ['Cheetah', 'Jaguar', 'Leopard'],
}

# Define markers: one marker shape per class group.
group_markers = {
    "Squirrels": 'o',
    "Hares & Rabbits": 's',
    "Insects": '^',
    "Big Cats": 'D',
}

# Define colors
color_palette = ["#008080ff", "#800080ff", "#2a7fffff", "#e580ffff", "#5fd3bcff"]

# Generate a color for each class by cycling through the palette across all
# groups; classes sharing a color are told apart by their group marker.
all_members = [member for members in classes.values() for member in members]
class_colors = {
    member: to_rgba(color_palette[index % len(color_palette)])
    for index, member in enumerate(all_members)
}

# Plotting
plt.figure(figsize=(12, 8))

for group, members in classes.items():
    for member in members:
        # Boolean mask selecting the samples of the current class.
        mask = y == member
        plt.scatter(
            X[mask, 0],
            X[mask, 1],
            color=class_colors[member],
            label=member,
            marker=group_markers[group],
            s=50,
        )

# Legend (deduplicated by label), placed outside the axes on the right.
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(1.05, 1), loc='upper left')

plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Scatter plot with different shapes and colors for each class')
# Make sure the output directory exists before writing the figure.
os.makedirs(FIGURE_PATH, exist_ok=True)
# The legend lies outside the axes; the tight bounding box keeps it from
# being cut off in the saved file.
plt.savefig(os.path.join(FIGURE_PATH, "tsne.pdf"), bbox_inches="tight")
plt.show()