In [1]:
import numpy
import torch 
import torch.nn as nn
import torch.nn.functional as F 
import einops

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
ma = nn.MultiheadAttention(128, 8)
sum((p.numel() for p in ma.parameters()))

66048

In [114]:
out_dim  = 10
def make_sum_net(): 
    return nn.Sequential(
            nn.Linear(2,32),
            nn.Linear(32,128),
            nn.Linear(128,out_dim)
           )

In [113]:
X = torch.randint(0,10, (1_000,2),dtype=torch.float32)
emb = torch.rand(32, out_dim) # b_size x out_dim
C = torch.zeros(out_dim)
y = X.sum(dim=1)
student = make_sum_net()
optim = torch.optim.SGD(student.parameters(), lr=1e-3)
teacher = make_sum_net()
# there is no backpropagation through the teacher, so no need for gradients
for p in teacher.parameters():
    p.requires_grad = False

In [96]:
def polyak(teacher, student, tau=0.99, device='cpu'):
  one = torch.ones(1, requires_grad=False).to(device)
  for s_param, t_param in zip(student.parameters(), teacher.parameters()):
    t_param.data.mul_(tau)
    t_param.data.addcmul_(s_param.data, one, value=(1-tau))

In [118]:
class DINOLoss(nn.Module):
    def __init__(self, out_dim, teacher_temp, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp

        # teacher centering and sharpening
        temp = self.teacher_temp
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)

        loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1)
        loss = loss.mean()
        self.update_center(teacher_output)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.mean(teacher_output, dim=0)
        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

In [145]:
m, tau_s, tau_t, l = 0.9, 0.1, 0.04, 0.99

In [146]:
dino_loss = DINOLoss(out_dim, tau_t, student_temp=tau_t, center_momentum=m)

In [152]:
for i in range(100):
    s = student(X)
    t = teacher(X)
    # student update
    optim.zero_grad()
    loss = dino_loss(s, t) # also updates C
    loss.backward()
    optim.step()
    # teacher update
    polyak(teacher, student, tau=l, device='cpu')
    print(loss)

tensor(0.6631, grad_fn=<MeanBackward0>)
tensor(0.8075, grad_fn=<MeanBackward0>)
tensor(0.6495, grad_fn=<MeanBackward0>)
tensor(0.7483, grad_fn=<MeanBackward0>)
tensor(0.6441, grad_fn=<MeanBackward0>)
tensor(0.7629, grad_fn=<MeanBackward0>)
tensor(0.6329, grad_fn=<MeanBackward0>)
tensor(0.7246, grad_fn=<MeanBackward0>)
tensor(0.6261, grad_fn=<MeanBackward0>)
tensor(0.7241, grad_fn=<MeanBackward0>)
tensor(0.6167, grad_fn=<MeanBackward0>)
tensor(0.6993, grad_fn=<MeanBackward0>)
tensor(0.6092, grad_fn=<MeanBackward0>)
tensor(0.6904, grad_fn=<MeanBackward0>)
tensor(0.6007, grad_fn=<MeanBackward0>)
tensor(0.6730, grad_fn=<MeanBackward0>)
tensor(0.5930, grad_fn=<MeanBackward0>)
tensor(0.6609, grad_fn=<MeanBackward0>)
tensor(0.5852, grad_fn=<MeanBackward0>)
tensor(0.6470, grad_fn=<MeanBackward0>)
tensor(0.5776, grad_fn=<MeanBackward0>)
tensor(0.6346, grad_fn=<MeanBackward0>)
tensor(0.5702, grad_fn=<MeanBackward0>)
tensor(0.6224, grad_fn=<MeanBackward0>)
tensor(0.5630, grad_fn=<MeanBackward0>)


In [25]:
# VICREG
# Variance
def variance_loss(x, gamma=1, eps=0.0001):
    std_x = torch.sqrt(x.var(dim=0) + eps)
    std_loss = torch.mean(F.relu(gamma - std_x)) 
    return std_loss
# Covariance 
def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
def covariance_loss(x):
    b_size, num_features = x.shape[0], x.shape[1]
    x = x - x.mean(dim=0)
    cov_x = (x.T @ x) / (b_size - 1)
    cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features) 
    return cov_loss