In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import get_model
from loaders import dataset_loader

In [36]:
class CLIPDistill(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # teacher
                 teacher_dimension: int,
                 teacher_model: nn.Module,
                 # student
                 student_dimension: int,
                 student_model: nn.Module,
                 ):
        super().__init__()


        self.teacher = teacher_model
        self.student = student_model
        
        self.student_projection = nn.Parameter(torch.randn(student_dimension,embed_dim))
        self.teacher_projection = nn.Parameter(torch.randn(teacher_dimension,embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))        

  
    #### What is this for?
    
#     def build_attention_mask(self):
#         # lazily create causal attention mask, with full attention between the vision tokens
#         # pytorch uses additive attention mask; fill with -inf
#         mask = torch.empty(self.context_length, self.context_length)
#         mask.fill_(float("-inf"))
#         mask.triu_(1)  # zero out the lower diagonal
#         return mask
    
    def student_encoding(self, x):
        return self.student(x) @ self.student_projection

    
    def teacher_encoding(self, x):
        return self.teacher(x) @ self.teacher_projection
            
    def forward(self, x):
        student_features = self.student_encoding(x)
        teacher_features = self.teacher_encoding(x)

        # normalized features
        student_features = student_features / student_features.norm(dim=-1, keepdim=True)
        teacher_features = teacher_features / teacher_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_student_emb = logit_scale * student_features @ teacher_features.t()
        logits_per_teacher_emb = logits_per_student_emb.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_teacher_emb, logits_per_student_emb
    

In [66]:
def clip_distill_loss(logits_teacher,logits_student):
    targets = torch.arange(batch_size)
    ce_teacher = F.cross_entropy(logits_teacher,targets)
    ce_student = F.cross_entropy(logits_student,targets)    
    
    return (ce_teacher + ce_student) / 2

In [51]:
# Get loaders
batch_size = 64
train_loader,test_loader = dataset_loader('cifar', batch_size = batch_size,train_set_fraction = 1,validate = False)

Files already downloaded and verified
Files already downloaded and verified


In [72]:
# Get models
embed_dim = 512
teacher_dim = 512
student_dim = 512

T = get_model('resnet18_cifar',load = True,load_path = 'cifar_teacher/model.pt',map_location = 'cpu',get_embedder = True)
S = get_model('resnet18_cifar',load = False,get_embedder = True)

clip = CLIPDistill(embed_dim,teacher_dim,T,student_dim,S)

# Get Optimizer
optimizer = torch.optim.Adam(clip.parameters(),lr = 0.001)

In [None]:
### Run the training loop
epochs = range(10)
for epoch in epochs:
    
    for data,_ in train_loader:
        optimizer.zero_grad()
        
        teacher_logits,student_logits = clip(data)
        loss = clip_distill_loss(teacher_logits,student_logits)
        print(loss)
        loss.backward()
        
        optimizer.step()
        


tensor(4.2267, grad_fn=<DivBackward0>)
tensor(4.8297, grad_fn=<DivBackward0>)
tensor(4.7872, grad_fn=<DivBackward0>)
tensor(4.4656, grad_fn=<DivBackward0>)
tensor(4.0791, grad_fn=<DivBackward0>)
tensor(3.9272, grad_fn=<DivBackward0>)
tensor(3.9005, grad_fn=<DivBackward0>)
tensor(3.5109, grad_fn=<DivBackward0>)
tensor(3.4316, grad_fn=<DivBackward0>)
tensor(3.2927, grad_fn=<DivBackward0>)
tensor(3.1091, grad_fn=<DivBackward0>)
tensor(2.8958, grad_fn=<DivBackward0>)
tensor(2.9942, grad_fn=<DivBackward0>)
tensor(2.6891, grad_fn=<DivBackward0>)
tensor(2.6913, grad_fn=<DivBackward0>)
tensor(2.5474, grad_fn=<DivBackward0>)
