In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm import tqdm

from utils import WeakLabeledData

try:
    import einops
except:
    !pip install -r requirements.txt

import copy

Collecting einops (from -r requirements.txt (line 1))
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting kappamodules (from -r requirements.txt (line 2))
  Downloading kappamodules-0.1.112-py3-none-any.whl.metadata (926 bytes)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading kappamodules-0.1.112-py3-none-any.whl (77 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.7/77.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops, kappamodules
Successfully installed einops-0.8.0 kappamodules-0.1.112
[0m

In [11]:
GROUND_TRUTH = False
DATASET = 'imagenet'
STARTING_EPOCH = 0  # set this value when resuming a paused training run. The code will later load the saved intermediate model, and modify the training loop for the correct number of epochs

In [3]:
class DinoClassification(nn.Module):
    """Add a classification head to an existing DINO model"""
    def __init__(self, original_model, num_classes=10):
        super(DinoClassification, self).__init__()
        
        # copy layers from original model
        self.dino = original_model
        
        # add classification head
        self.head = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        # Pass input through dino model
        x = self.dino(x)

        # Extract the class token and pass through the classification head
        cls_token = x[:, 0]  # Shape: (batch_size, embed_dim)
        x = self.head(cls_token)  # Shape: (batch_size, num_classes)

        return x

In [4]:
def setup_dino(num_classes=10, freeze=True):
    # DINO, pretrained on ImageNet-1k based on the representations of ViT-L/16
    dino = torch.hub.load("BenediktAlkin/torchhub-ssl", "in1k_dinov2_l16", trust_repo=True)
    
    # add classification head
    model = DinoClassification(dino, num_classes)
    
    if freeze:
        # freeze layers for finetuning
        for param in model.parameters():
            param.requires_grad = False

        # unfreeze last layer to finetune
        for param in model.head.parameters():
            param.requires_grad = True
        
        # for param in model.dino.blocks[-1].parameters():
        #     param.requires_grad = True
        
    return model

In [5]:
model = setup_dino(1000)

Downloading: "https://github.com/BenediktAlkin/torchhub-ssl/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://huggingface.co/BenediktAlkin/DINOv2/resolve/main/in1k_large16.pth" to /root/.cache/torch/hub/checkpoints/in1k_dinov2_large16.pth
100%|██████████| 1.22G/1.22G [00:30<00:00, 42.6MB/s]


In [12]:
if STARTING_EPOCH > 0:
    model.load_state_dict(torch.load(f'models/{DATASET}/dino_{"gt" if GROUND_TRUTH else "wtsg"}_epoch{STARTING_EPOCH}.pth'))

In [7]:
train_set = torch.load(f'./data/{DATASET}-split/weak-labeled-half.pth')
train_loader = DataLoader(train_set, batch_size=128, shuffle=False)

In [8]:
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [9]:
# Training Loop setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [13]:
# Training Loop
model.train()

NUM_EPOCHS = 10  # Fine-tuning for 3 epochs

for epoch in range(STARTING_EPOCH,NUM_EPOCHS): # WHEN FINISHING A RUN SET LOWER BOUND
    cum_loss = 0.0
    total = 0
    
    progress_bar = tqdm(train_loader, leave=True)
    progress_bar.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}] ({'GT' if GROUND_TRUTH else 'WTSG'})")
    for idx, (images, true_labels, weak_labels) in enumerate(progress_bar):
        images, true_labels, weak_labels = images.to(device), true_labels.to(device), weak_labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        if GROUND_TRUTH:
            # training step for ground truth model
            loss = criterion(outputs, true_labels)
        else:
            # training step for weak labels model
            loss = criterion(outputs, weak_labels)
    
        loss.backward()
        optimizer.step()
        
        cum_loss += loss.item()
        total += len(true_labels)
        
        avg_loss = cum_loss / total
        
        # Update the tqdm bar with loss and epoch
        progress_bar.set_postfix({"avg_loss":avg_loss, "cum_loss":cum_loss})
    
    # Save the fine-tuned model
    if (epoch + 1) % 2 == 0:
        if GROUND_TRUTH:
            name = f'dino_gt_epoch{epoch+1}.pth'
        else:
            name = f'dino_wtsg_epoch{epoch+1}.pth'
        torch.save(model.state_dict(), f'models/{DATASET}/{name}') 

Epoch [1/10] (WTSG): 100%|██████████| 313/313 [18:37<00:00,  3.57s/it, avg_loss=0.0483, cum_loss=1.93e+3]
Epoch [2/10] (WTSG): 100%|██████████| 313/313 [19:11<00:00,  3.68s/it, avg_loss=0.0323, cum_loss=1.29e+3]
Epoch [7/10] (WTSG): 100%|██████████| 313/313 [18:48<00:00,  3.60s/it, avg_loss=0.0132, cum_loss=529]
Epoch [8/10] (WTSG): 100%|██████████| 313/313 [18:12<00:00,  3.49s/it, avg_loss=0.0118, cum_loss=471]
Epoch [9/10] (WTSG): 100%|██████████| 313/313 [18:49<00:00,  3.61s/it, avg_loss=0.0105, cum_loss=418]
Epoch [10/10] (WTSG): 100%|██████████| 313/313 [18:39<00:00,  3.58s/it, avg_loss=0.00928, cum_loss=371]


In [14]:
# # Save the fine-tuned model
# if GROUND_TRUTH:
#     name = f'dino_gt_epoch10.pth'
# else:
#     name = f'dino_wtsg_epoch10.pth'
# torch.save(model.state_dict(), f'models/{DATASET}/{name}') 

In [None]:
# load the test data
# Define preprocessing transformations
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # All 3 models expect 224x224 images
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # normalization constants for ImageNet-1k (pre-training data)
# ])

# test_data = CIFAR100(root='./data', train=False, transform=transform)
test_data = torch.load(f'data/{DATASET}-split/test.pth')
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

In [None]:
model.eval()

# Initialize metrics
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():  # Disable gradient computation to save memory and computation
    progress_bar = tqdm(test_loader, leave=True)
    for idx, (images, labels) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device) # send data to gpu
        
        # Forward pass
        outputs = model(images)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        test_loss += loss.item()  # Accumulate loss
        
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += len(labels)
        correct += (predicted == labels).sum().item()
        
        # Calculate metrics for display
        avg_loss = test_loss / total
        accuracy = 100 * correct / total

        # Update the tqdm bar with loss and accuracy
        progress_bar.set_postfix(loss=avg_loss, accuracy=accuracy, correct=correct)

# Calculate final metrics
avg_loss = test_loss / len(test_loader)
accuracy = 100 * correct / total

print(f"Test Loss: {avg_loss:.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")

100%|██████████| 79/79 [04:56<00:00,  3.75s/it, accuracy=72.7, correct=7267, loss=0.0118]

Test Loss: 1.4964
Test Accuracy: 72.67%





In [17]:
(.7267 - .5214) / (.7870 - .5214)

0.7729668674698795