In [3]:
import torch
import torchvision
import tqdm
import tqdm.notebook
from torch import nn
from transformers import AutoImageProcessor, AutoModel


In [19]:
class Model(nn.Module):
    """Image classifier: frozen pretrained backbone + small trainable MLP head.

    ``forward`` takes a batch of images, runs the Hugging Face image processor
    and backbone, mean-pools ``last_hidden_state`` over its token axis, and maps
    the pooled feature vector to ``num_classes`` logits with ``self.head``.
    """

    def __init__(self, checkpoint='facebook/dinov2-small', num_classes=10,
                 hidden_dim=128, do_rescale=False):
        """Load processor + backbone from ``checkpoint`` and build the head.

        Args:
            checkpoint: Hugging Face model id used for both processor and backbone.
            num_classes: number of output logits (10 for CIFAR-10).
            hidden_dim: width of the head's hidden layer.
            do_rescale: forwarded to the image processor on every call. Defaults
                to False because ``torchvision.transforms.ToTensor`` already maps
                pixels to [0, 1]; letting the processor divide by 255 again would
                squash the inputs towards 0 before normalisation.
        """
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.backbone = AutoModel.from_pretrained(checkpoint)
        # Only the head is meant to be optimised; freezing the backbone avoids
        # building an autograd graph through it on every training step.
        self.backbone.requires_grad_(False)
        self.do_rescale = do_rescale

        # Read the feature width from the backbone config instead of hard-coding
        # it (it is 384 for facebook/dinov2-small).
        feature_dim = self.backbone.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        pixel_values = self.processor(
            x, return_tensors='pt', do_rescale=self.do_rescale
        )['pixel_values']
        # Processor output is not guaranteed to be on the backbone's device.
        pixel_values = pixel_values.to(next(self.backbone.parameters()).device)
        # Stay in torch: converting to numpy here would break autograd and
        # nn.Linear (which only accepts tensors).
        features = self.backbone(pixel_values).last_hidden_state.mean(dim=1)
        return self.head(features)
    
SEED = 0
LEARNING_RATE = 1e-3

# Seed the global torch RNG so the head's random initialisation is reproducible
# across runs (DataLoader(shuffle=True) without an explicit generator also draws
# from this global generator).
torch.manual_seed(SEED)

model = Model()
# Only the classification head is optimised; the pretrained backbone stays fixed.
optimizer = torch.optim.Adam(model.head.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()


In [21]:
def _make_cifar10_split(train, shuffle):
    """Download (if missing) one CIFAR-10 split and wrap it in a batch-64 DataLoader.

    Returns a ``(dataset, loader)`` pair; images are converted with ``ToTensor``.
    """
    dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=train,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=shuffle)
    return dataset, loader


# Training split is shuffled each epoch; the evaluation split keeps a fixed order.
train_dataset, train_loader = _make_cifar10_split(train=True, shuffle=True)
test_dataset, test_loader = _make_cifar10_split(train=False, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [15]:
NUM_EPOCHS = 10

with tqdm.notebook.tnrange(0, NUM_EPOCHS) as epochs:
    for _ in epochs:
        # Train mode for the head, but keep the backbone in eval mode: only
        # model.head is in the optimizer, so the backbone is a fixed feature extractor.
        model.train()
        model.backbone.eval()
        with tqdm.notebook.tqdm(train_loader, leave=False) as batches:
            for x, y in batches:
                optimizer.zero_grad()
                y_pred = model(x)
                # Put targets next to the logits so this also works if the model is on a GPU.
                loss = criterion(y_pred, y.to(y_pred.device))
                loss.backward()
                optimizer.step()
                batches.set_postfix(loss=loss.item())

        # Per-epoch held-out accuracy.
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in test_loader:
                y_pred = model(x)
                correct += (y_pred.argmax(1) == y.to(y_pred.device)).sum().item()
                total += y.size(0)

        accuracy = correct / total
        epochs.set_postfix(accuracy=accuracy)


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [18]:
def test(model, test_dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        with tqdm.notebook.tqdm(test_dataloader, leave=False) as batches:
            for x, y in batches:
                y_pred = model(x)
                correct += (y_pred.argmax(1) == y).sum().item()
                total += y.size(0)

    return correct / total

# Held-out accuracy of the current model (fraction of correct predictions).
test_accuracy = test(model, test_loader)
test_accuracy


  0%|          | 0/157 [00:00<?, ?it/s]

KeyboardInterrupt: 