In [60]:
import sys
sys.path.append("../src")

import os

import torch
import numpy as np
import torch.nn as nn
import lightning as L
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

from utils import (
    get_args, 
    get_model,
    save_args,
    get_base_lr,
    get_dataset,
    get_param_groups,
    cosine_scheduling,
    get_teacher_temperatures,
    DINO,
    Encoder
    )

In [48]:
# Build the "vit-s-16" model through the project helper `get_model` and display its module tree.
model = get_model("vit-s-16")

model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [14]:
def get_encoder_args(run_dir: str):
    """
    Load a run configuration and return only the entries that configure the encoder.

    Args:
        run_dir: Location handed unchanged to the project helper ``get_args``.

    Returns:
        A dict holding the ``backbone``, ``mlp_layers``, ``hidden_dim``,
        ``bottleneck_dim`` and ``k_dim`` entries that are present in the loaded
        configuration, in the order the configuration yields them. Entries missing
        from the configuration are simply absent from the result.
    """
    wanted = frozenset(("backbone", "mlp_layers", "hidden_dim", "bottleneck_dim", "k_dim"))

    run_config = get_args(run_dir)

    selected = {}
    for key, value in run_config.items():
        if key in wanted:
            selected[key] = value

    return selected

In [20]:
# Which pre-training run to inspect: backbone name and the run's version number.
backbone = "vit-s-16"
experiment_num = 0

# NOTE: despite the `_dir` suffix, both values are *file* paths: `ckpt_dir` points at the
# "min-loss.ckpt" checkpoint file and `run_dir` at the run's "run-config.yaml".
ckpt_dir = os.path.join("..", "assets", "model-weights", backbone, "pre-train", f"version_{experiment_num}", "min-loss.ckpt")
run_dir = os.path.join("..", "src", "pre-train-runs", backbone, f"version_{experiment_num}", "run-config.yaml")

# Load the checkpoint onto the CPU and keep only its "state_dict" entry.
# NOTE(review): `torch.load` may unpickle arbitrary objects depending on the PyTorch version and
# its `weights_only` setting -- only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_dir, map_location=torch.device("cpu"))["state_dict"]


In [21]:
# Pull the encoder-related settings out of the run config and build an Encoder from them.
# The checkpoint weights are not loaded here; that happens in a later cell.
encoder_args = get_encoder_args(run_dir)

encoder = Encoder(**encoder_args)

encoder


Encoder(
  (encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
  

In [22]:
# Print the encoder's parameter names so they can be compared with the checkpoint keys.
for name, _ in encoder.named_parameters():
    print(name)

encoder.cls_token
encoder.pos_embed
encoder.patch_embed.proj.weight
encoder.patch_embed.proj.bias
encoder.blocks.0.norm1.weight
encoder.blocks.0.norm1.bias
encoder.blocks.0.attn.qkv.weight
encoder.blocks.0.attn.qkv.bias
encoder.blocks.0.attn.proj.weight
encoder.blocks.0.attn.proj.bias
encoder.blocks.0.norm2.weight
encoder.blocks.0.norm2.bias
encoder.blocks.0.mlp.fc1.weight
encoder.blocks.0.mlp.fc1.bias
encoder.blocks.0.mlp.fc2.weight
encoder.blocks.0.mlp.fc2.bias
encoder.blocks.1.norm1.weight
encoder.blocks.1.norm1.bias
encoder.blocks.1.attn.qkv.weight
encoder.blocks.1.attn.qkv.bias
encoder.blocks.1.attn.proj.weight
encoder.blocks.1.attn.proj.bias
encoder.blocks.1.norm2.weight
encoder.blocks.1.norm2.bias
encoder.blocks.1.mlp.fc1.weight
encoder.blocks.1.mlp.fc1.bias
encoder.blocks.1.mlp.fc2.weight
encoder.blocks.1.mlp.fc2.bias
encoder.blocks.2.norm1.weight
encoder.blocks.2.norm1.bias
encoder.blocks.2.attn.qkv.weight
encoder.blocks.2.attn.qkv.bias
encoder.blocks.2.attn.proj.weight
encode

In [23]:
# Print every key stored in the checkpoint's state dict, one per line.
for ckpt_key in ckpt:
    print(ckpt_key)

student.encoder.cls_token
student.encoder.pos_embed
student.encoder.patch_embed.proj.weight
student.encoder.patch_embed.proj.bias
student.encoder.blocks.0.norm1.weight
student.encoder.blocks.0.norm1.bias
student.encoder.blocks.0.attn.qkv.weight
student.encoder.blocks.0.attn.qkv.bias
student.encoder.blocks.0.attn.proj.weight
student.encoder.blocks.0.attn.proj.bias
student.encoder.blocks.0.norm2.weight
student.encoder.blocks.0.norm2.bias
student.encoder.blocks.0.mlp.fc1.weight
student.encoder.blocks.0.mlp.fc1.bias
student.encoder.blocks.0.mlp.fc2.weight
student.encoder.blocks.0.mlp.fc2.bias
student.encoder.blocks.1.norm1.weight
student.encoder.blocks.1.norm1.bias
student.encoder.blocks.1.attn.qkv.weight
student.encoder.blocks.1.attn.qkv.bias
student.encoder.blocks.1.attn.proj.weight
student.encoder.blocks.1.attn.proj.bias
student.encoder.blocks.1.norm2.weight
student.encoder.blocks.1.norm2.bias
student.encoder.blocks.1.mlp.fc1.weight
student.encoder.blocks.1.mlp.fc1.bias
student.encoder.

In [26]:
# Keep only the student network's entries from the checkpoint. The test is a *prefix* check
# (`startswith`) rather than a substring check, so a key that merely contains "student."
# somewhere in the middle of its name is not picked up by accident.
student_params = {k: params for k, params in ckpt.items() if k.startswith("student.")}

student_params.keys()

dict_keys(['student.encoder.cls_token', 'student.encoder.pos_embed', 'student.encoder.patch_embed.proj.weight', 'student.encoder.patch_embed.proj.bias', 'student.encoder.blocks.0.norm1.weight', 'student.encoder.blocks.0.norm1.bias', 'student.encoder.blocks.0.attn.qkv.weight', 'student.encoder.blocks.0.attn.qkv.bias', 'student.encoder.blocks.0.attn.proj.weight', 'student.encoder.blocks.0.attn.proj.bias', 'student.encoder.blocks.0.norm2.weight', 'student.encoder.blocks.0.norm2.bias', 'student.encoder.blocks.0.mlp.fc1.weight', 'student.encoder.blocks.0.mlp.fc1.bias', 'student.encoder.blocks.0.mlp.fc2.weight', 'student.encoder.blocks.0.mlp.fc2.bias', 'student.encoder.blocks.1.norm1.weight', 'student.encoder.blocks.1.norm1.bias', 'student.encoder.blocks.1.attn.qkv.weight', 'student.encoder.blocks.1.attn.qkv.bias', 'student.encoder.blocks.1.attn.proj.weight', 'student.encoder.blocks.1.attn.proj.bias', 'student.encoder.blocks.1.norm2.weight', 'student.encoder.blocks.1.norm2.bias', 'student.en

In [28]:
# Drop the leading "student." from each key so the names line up with the encoder's own
# parameter names. Only the leading prefix is removed (not every occurrence of the text,
# as `str.replace` would do), and keys without the prefix pass through unchanged, so
# re-running this cell is harmless.
student_params = {
    (k[len("student."):] if k.startswith("student.") else k): params
    for k, params in student_params.items()
}

student_params.keys()

dict_keys(['encoder.cls_token', 'encoder.pos_embed', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.qkv.bias', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.2.norm1.weight', 'encoder.blocks.2.norm1.bias', 'en

In [29]:
# Copy the prefix-stripped student weights into the encoder; the displayed return value reports
# whether the keys matched (the recorded output shows that all of them did).
encoder.load_state_dict(student_params)

<All keys matched successfully>

In [38]:
# Display the encoder's module tree again, now that the checkpoint weights have been loaded.
encoder

Encoder(
  (encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
  

In [45]:
# Input width of the first layer in `encoder.mlp` -- presumably the size of the backbone's output
# embedding (TODO confirm against the Encoder definition).
# NOTE: this cell must run while `encoder` is still the Encoder wrapper; a later cell rebinds the
# name `encoder` to the wrapper's inner `.encoder` module.
embedding_dim = encoder.mlp[0].in_features

In [None]:
class Classifier(L.LightningModule):
    """
    Classifier made of a frozen encoder followed by a trainable linear layer.

    The encoder's parameters have ``requires_grad`` disabled and the encoder is kept in
    evaluation mode; ``fc`` maps the encoder output to ``num_classes`` logits.

    Args:
        encoder: Module whose output is fed to ``fc``. Its output is assumed to have
            ``embedding_dim`` features in the last dimension -- TODO confirm against the
            object actually passed in.
        num_classes: Number of output logits produced by ``fc``.
        embedding_dim: Number of input features of ``fc``.
        learning_rate: Initial learning rate handed to AdamW.
        eta_min: Minimum learning rate of the cosine annealing schedule.
        weight_decay: Weight decay handed to AdamW.
    """

    def __init__(
        self,
        encoder: Encoder,
        num_classes: int,
        embedding_dim: int,
        learning_rate: float,
        eta_min: float,
        weight_decay: float,
        ):
        super().__init__()

        self.learning_rate = learning_rate
        self.eta_min = eta_min
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss()

        # Freeze the encoder: no gradients for its parameters, evaluation-mode behaviour.
        self.encoder = encoder
        self.encoder.requires_grad_(False)
        self.encoder.eval()

        self.fc = nn.Linear(embedding_dim, num_classes)

    def train(self, mode: bool = True):
        """
        Set the training mode of the classifier while keeping the encoder in eval mode.

        ``nn.Module.train`` propagates the mode to every submodule, which would undo the
        ``encoder.eval()`` call made in ``__init__`` whenever ``.train()`` is called on the
        classifier. The encoder is therefore switched back to eval mode after every call.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The module itself, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.encoder.eval()

        return self

    def forward(
        self,
        x: torch.Tensor
        ):
        """
        Encode ``x`` with the frozen encoder and project the result to class logits.

        Args:
            x: Input passed unchanged to the encoder.

        Returns:
            The output of ``fc`` applied to the encoder output.
        """
        h = self.encoder(x)
        logits = self.fc(h)

        return logits

    def _pred_and_eval(self, batch):
        """
        Run the model on one batch and compute its loss and accuracy.

        Args:
            batch: An ``(img, target)`` pair. ``target`` is assumed to hold one class
                index per sample -- TODO confirm against the dataloader.

        Returns:
            ``(loss, accuracy)``: the cross-entropy loss and, as a 0-d tensor on the same
            device as the logits, the fraction of samples whose arg-max prediction equals
            ``target``.
        """
        img, target = batch
        logits = self(img)

        # Softmax is monotonic, so the arg-max over the logits is the predicted class;
        # no need to materialise the probabilities first.
        pred = torch.argmax(logits, dim=1)

        loss = self.criterion(logits, target)
        # Computed with torch on the batch's device to avoid a CPU round-trip on every step.
        accuracy = (pred == target).float().mean()

        return loss, accuracy

    def training_step(self, batch, _):
        """
        Compute and log the training loss, the training accuracy and the current learning rate.

        Args:
            batch: An ``(img, target)`` pair, see ``_pred_and_eval``.
            _: Unused second positional argument supplied by the training loop.

        Returns:
            A dict with the keys ``"loss"`` and ``"accuracy"``.
        """
        train_loss, train_accuracy = self._pred_and_eval(batch)
        # Learning rate of the first parameter group of the first optimizer.
        optimizer = self.trainer.optimizers[0]
        lr = optimizer.param_groups[0]["lr"]

        self.log("Train/Loss", train_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("Train/Accuracy", train_accuracy, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("Learning Rate", lr, on_step=False, on_epoch=True, prog_bar=True)

        metrics = {
            "loss": train_loss,
            "accuracy": train_accuracy
        }

        return metrics

    def validation_step(self, batch, _):
        """
        Compute and log the validation loss and accuracy.

        Args:
            batch: An ``(img, target)`` pair, see ``_pred_and_eval``.
            _: Unused second positional argument supplied by the validation loop.

        Returns:
            A dict with the keys ``"loss"`` and ``"accuracy"``.
        """
        val_loss, val_accuracy = self._pred_and_eval(batch)

        # `sync_dist=True` mirrors the training metrics so epoch values are reduced the same way.
        self.log("Validation/Loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("Validation/Accuracy", val_accuracy, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        metrics = {
            "loss": val_loss,
            "accuracy": val_accuracy
        }

        return metrics

    def configure_optimizers(self):
        """
        Build the AdamW optimizer and a per-epoch cosine annealing schedule.

        Returns:
            A dict with the optimizer under ``"optimizer"`` and the scheduler configuration
            (stepped once per epoch) under ``"lr_scheduler"``. The schedule length is
            ``self.trainer.max_epochs``.
        """
        # Each generator must be turned into a list *before* concatenating; adding a
        # generator to a list raises TypeError. The encoder parameters are frozen in
        # `__init__`, so in practice only `fc` receives updates.
        params = list(self.encoder.parameters()) + list(self.fc.parameters())
        optimizer = torch.optim.AdamW(params, lr=self.learning_rate, weight_decay=self.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, eta_min=self.eta_min)

        config = {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }

        return config
        

In [52]:
# Random tensor of shape (1, 3, 224, 224) used as a stand-in input for a forward-pass check
# (presumably batch, channels, height, width).
# NOTE(review): no random seed is set, so the values differ on every run.
img = torch.randn(1, 3, 224, 224)
img

tensor([[[[-1.3758,  0.4629, -2.6067,  ...,  0.4263,  1.6670, -0.4575],
          [-1.0205, -0.2919,  0.9853,  ..., -0.8425, -0.5968,  0.2812],
          [-1.3441, -0.4191,  1.2663,  ..., -0.6104, -1.3954, -0.3468],
          ...,
          [-0.1583,  0.6331,  1.2217,  ..., -1.0472,  0.0089, -0.1754],
          [ 0.1801, -0.9542, -2.1773,  ..., -1.1632, -1.9616,  1.0455],
          [-0.1797, -0.5522,  0.8437,  ...,  1.6244,  0.6465, -0.7188]],

         [[-1.4782,  0.8769,  2.1025,  ...,  0.2194,  0.1642,  0.0179],
          [ 2.1357,  0.3146, -0.2709,  ...,  2.5437, -0.4364,  0.5338],
          [ 0.3186,  0.7456, -0.8078,  ...,  2.3906, -1.5178, -0.1774],
          ...,
          [ 0.0472, -0.6731,  1.4287,  ..., -0.6014,  0.7867,  0.0309],
          [-0.7663, -0.1107, -0.1258,  ...,  0.9039,  0.3793, -0.1356],
          [-0.5656,  2.4810,  0.1978,  ..., -1.3918,  0.9315,  1.1948]],

         [[-0.0058, -1.5197,  2.4743,  ...,  0.3685,  1.8244,  0.2583],
          [-2.6850,  0.0448,  

In [53]:
# Unwrap the backbone from the Encoder wrapper. `getattr` with a fallback keeps this cell
# idempotent: if `encoder` has already been unwrapped (the backbone is assumed to have no
# `.encoder` attribute of its own), re-running the cell leaves it unchanged instead of
# raising AttributeError.
encoder = getattr(encoder, "encoder", encoder)

encoder

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [56]:
# Run the random input through the encoder and check the shape of the result
# (the recorded output shows torch.Size([1, 384])).
h = encoder(img)

h.shape

torch.Size([1, 384])

In [58]:
# Throwaway linear head for a shape check: its input width is read from the encoder output
# instead of being hard-coded, so the cell keeps working if the backbone's feature size changes.
# The 10 outputs are an arbitrary class count for this check.
fc = nn.Linear(h.shape[-1], 10)

logits = fc(h)

logits.shape

torch.Size([1, 10])

In [59]:
# Normalise the logits along dim=1 (the size-10 axis of `logits`) so each row sums to 1.
F.softmax(logits, dim=1)

tensor([[0.4339, 0.0460, 0.0219, 0.0052, 0.0427, 0.0264, 0.0213, 0.3528, 0.0197,
         0.0302]], grad_fn=<SoftmaxBackward0>)