In [1]:
from pathlib import Path

import torch
import torchvision
import tqdm
import tqdm.notebook
from torch import nn
from transformers import AutoImageProcessor, AutoModel


In [2]:
class Model(nn.Module):
    """Frozen pretrained backbone followed by a small trainable MLP head.

    The backbone is used purely as a fixed feature extractor: its parameters
    are frozen, it is kept in eval mode even while the head is training, and
    its forward pass runs under ``torch.no_grad()``. Only ``self.head`` is
    meant to be optimised.

    Args:
        backbone_name: Hugging Face model id handed to ``from_pretrained`` for
            both the image processor and the backbone.
        num_classes: Number of logits produced by the head.
        hidden_dim: Width of the head's hidden layer.
    """

    def __init__(self, backbone_name='facebook/dinov2-small', num_classes=10, hidden_dim=128):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(backbone_name)
        self.backbone = AutoModel.from_pretrained(backbone_name)

        # The backbone is never trained: freeze its weights and put it in
        # inference mode right away (see ``train`` below, which keeps it there).
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        # Width of the backbone's hidden states; this was hard-coded as 384
        # for the default 'facebook/dinov2-small' checkpoint.
        feature_dim = self.backbone.config.hidden_size

        self.head = nn.Sequential(
            nn.BatchNorm1d(feature_dim),
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(True),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )

    def train(self, mode=True):
        """Switch the head between train/eval mode; the backbone stays in eval.

        ``nn.Module.train`` propagates the mode to every submodule, which would
        also flip the frozen backbone into training mode. Re-applying ``eval()``
        afterwards keeps the extracted features deterministic.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self._backbone(x)
        return self.head(x)

    @torch.no_grad()
    def _backbone(self, x):
        """Extract one feature vector per image without tracking gradients.

        ``do_rescale=False`` tells the processor not to rescale pixel values;
        this assumes the inputs are already scaled (e.g. by ``ToTensor``) --
        TODO confirm against the caller.

        Returns:
            ``last_hidden_state`` averaged over dim 1 (presumably the token
            axis), giving one vector per input image.
        """
        x = self.processor(x, return_tensors='pt', do_rescale=False)
        x = self.backbone(**x).last_hidden_state.mean(dim=1)
        return x
    
# Build the network and the classification loss used by the training cell.
model = Model()
criterion = nn.CrossEntropyLoss()

# Report the size of the trainable head only.
head_param_count = sum(param.numel() for param in model.head.parameters())
print(f"Model head has {head_param_count:,} parameters")


Model head has 51,594 parameters


In [3]:
# Both CIFAR-10 splits share the same root directory, download flag and
# ToTensor transform; only the `train` flag differs.
_cifar_kwargs = dict(root='./data', download=True, transform=torchvision.transforms.ToTensor())

train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Shuffle only the training batches; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Adam over the head's parameters only; the backbone forward pass runs under
# no_grad, so it never receives gradients.
optimizer = torch.optim.Adam(params=model.head.parameters(), lr=2e-4)


In [5]:
# Train the head with live progress bars: the outer bar counts epochs, the
# inner (transient) bar counts batches and shows the running loss.
NUM_EPOCHS = 10

with tqdm.notebook.tnrange(0, NUM_EPOCHS) as epochs:
    for epoch in epochs:
        model.train()
        with tqdm.notebook.tqdm(train_loader, leave=False) as batches:
            for x, y in batches:
                optimizer.zero_grad()
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()

                # Log from a detached copy so the progress-bar statistics do
                # not add nodes to the autograd graph.
                batch_loss = loss.detach()
                # exp(-cross_entropy) is the geometric mean of the probability
                # assigned to the true class: a cheap, rough accuracy proxy.
                batches.set_postfix(loss=batch_loss.item(), est_acc=torch.exp(-batch_loss).item())

# No per-epoch evaluation happens here (accuracy is measured by the separate
# `test` cell); just leave the model in inference mode once training is done.
model.eval()


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
@torch.no_grad()
def test(model, test_dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        with tqdm.notebook.tqdm(test_dataloader, leave=False) as batches:
            for x, y in batches:
                y_pred = model(x)
                correct += (y_pred.argmax(1) == y).sum().item()
                total += y.size(0)
                
                batches.set_postfix(accuracy=correct / total)

    return correct / total

# Evaluate on the held-out split; the returned accuracy is the cell's output.
test(model=model, test_dataloader=test_loader)


  0%|          | 0/157 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
# Persist only the trained head's weights.
# NOTE(review): the filename says "mnist" although this notebook loads
# CIFAR-10 -- confirm before renaming, other code may load this exact path.
HEAD_WEIGHTS_PATH = Path('projects/picturama/dinov2-mnist_head-b_norm.pth')

# torch.save does not create missing directories, so make sure the target
# folder exists before writing.
HEAD_WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.head.state_dict(), HEAD_WEIGHTS_PATH)
