In [None]:
import torch
from models.backbones import *
from models.projectors import *

In [None]:
class BarlowTwins(torch.nn.Module):
    """Barlow Twins objective wrapped around a backbone and a projection head.

    Two batches (views) are embedded with the same backbone and projector,
    each embedding is normalised along the batch dimension by a non-affine
    ``BatchNorm1d``, and the cross-correlation matrix between the two
    normalised embeddings is pushed towards the identity matrix.

    Args:
        backbone: module applied to each input batch.
        projector: indexable module (e.g. ``torch.nn.Sequential``) applied to
            the backbone output. Its last element must expose
            ``out_features`` (e.g. ``torch.nn.Linear``), which sets the size
            of the normalisation layer.
        loss_param_scale: factor multiplying both the on-diagonal and the
            off-diagonal term of the loss.
        loss_param_lmbda: weight of the off-diagonal (redundancy) term
            relative to the on-diagonal (invariance) term.
    """

    def __init__(self, backbone, projector, loss_param_scale, loss_param_lmbda):
        super().__init__()
        self.backbone = backbone
        self.projector = projector

        # affine=False -> no learnable parameters; the layer only standardises
        # every embedding dimension over the batch.
        self.bn = torch.nn.BatchNorm1d(projector[-1].out_features, affine=False)

        self.loss_param_scale = loss_param_scale
        self.loss_param_lmbda = loss_param_lmbda

    def forward(self, x1, x2):
        """Return the Barlow Twins loss for the two input batches ``x1`` and ``x2``.

        The projector outputs are expected to be 2-D ``(batch, features)``
        tensors, since they are transposed and matrix-multiplied below.
        """
        z1 = self.projector(self.backbone(x1))
        z2 = self.projector(self.backbone(x2))

        # Empirical cross-correlation matrix. The product of the
        # batch-normalised embeddings is a *sum* over the batch, so it has to
        # be divided by the batch size; otherwise the diagonal grows with the
        # batch size instead of approaching 1 for perfectly correlated views.
        batch_size = z1.shape[0]
        c = (self.bn(z1).T @ self.bn(z2)) / batch_size

        return self.loss(c)

    def off_diagonal(self, x):
        """Return the off-diagonal elements of the square matrix ``x`` as a flat tensor."""
        n, m = x.shape
        assert n == m
        # Dropping the last element and reshaping to (n - 1, n + 1) places every
        # diagonal element in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def loss(self, c):
        """Compute the loss from a cross-correlation matrix ``c``.

        ``loss = scale * sum_i (c_ii - 1)^2 + lmbda * scale * sum_{i != j} c_ij^2``

        Only out-of-place operations are used, so ``c`` is left unmodified.
        """
        on_diag = torch.diagonal(c).sub(1).pow(2).sum().mul(self.loss_param_scale)
        off_diag = self.off_diagonal(c).pow(2).sum().mul(self.loss_param_scale)
        return on_diag + self.loss_param_lmbda * off_diag

In [None]:
# Projection-head dimensions: output width, hidden width, number of hidden layers.
d_out, d_hidden, n_hidden = 512, 1024, 2

# Options forwarded to the hidden layers of the projection head.
normalize, dropout_rate = True, None

# Options forwarded to the last layer of the projection head.
activation_last, normalize_last, dropout_rate_last = False, False, None

In [None]:
# Feature extractor; its `dim_out` attribute sets the projector's input width.
backbone = get_backbone("ResNet-18")

# Build the projection-head layers from the hyper-parameters defined above
# and chain them into a single sequential module.
projector = torch.nn.Sequential(
    *get_projection_head_layers(
        d_in=backbone.dim_out,
        d_out=d_out,
        d_hidden=d_hidden,
        n_hidden=n_hidden,
        normalize=normalize,
        dropout_rate=dropout_rate,
        activation_last=activation_last,
        normalize_last=normalize_last,
        dropout_rate_last=dropout_rate_last,
    )
)

In [None]:
# loss_param_scale multiplies both loss terms; loss_param_lmbda weights the
# off-diagonal term relative to the on-diagonal one.
model = BarlowTwins(
    backbone,
    projector,
    loss_param_scale=1 / 32,
    loss_param_lmbda=3.9e-3,
)

In [None]:
# Two independent uniform-random batches of shape (64, 3, 32, 32) used as the
# two input views; x1 is sampled before x2.
x1, x2 = (torch.rand(64, 3, 32, 32) for _ in range(2))

In [None]:
# Smoke test: run one forward pass on the two random batches.
# BarlowTwins.forward returns the loss value directly.
loss = model(x1, x2)