## data2vec

Reconstruction of the data2vec (text) training criterion — teacher-target construction and the regression loss — following https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py

In [1]:
import math

import torch
from torch import nn
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# Toy problem size. B = batch size, T = sequence length, C = hidden size.
# ---------------------------------------------------------------------------
embedding_dim, seq_len, batch_size = 128, 384, 20
attention_layers = 12  # number of simulated teacher layers

# ---------------------------------------------------------------------------
# Target-construction switches.
# ---------------------------------------------------------------------------
# False: the simulated layer outputs are (B, T, C) and get permuted to
# (T, B, C) before normalisation. True: they are assumed to already be
# time-major. NOTE(review): "faiss" looks like a typo for "fairseq" — confirm.
has_faiss_format = False
# Per-layer normalisations, applied before the top-k layers are averaged.
batch_norm_target_layer = instance_norm_target_layer = True
layer_norm_target_layer = True
# Normalisations applied to the averaged target.
layer_norm_targets = instance_norm_targets = True

# Student head: C -> 2C -> 4C -> C MLP with GELU activations in between.
projector = nn.Sequential(
    nn.Linear(embedding_dim, 2 * embedding_dim),
    nn.GELU(),
    nn.Linear(2 * embedding_dim, 4 * embedding_dim),
    nn.GELU(),
    nn.Linear(4 * embedding_dim, embedding_dim),
)

# Student prediction: a random (B, T, C) input pushed through the projector.
# The output keeps the (B, T, C) shape.
x = projector(torch.randn(batch_size, seq_len, embedding_dim))

print('x shape: ', x.shape)

# ---------------------------------------------------------------------------
# Teacher targets: average of the top-k layer outputs. Each layer is optionally
# normalised first, and the average is optionally normalised afterwards.
# B: batch size, T: sequence length, C: hidden size
# ---------------------------------------------------------------------------

# take k last layers
k = 4
# One *independent* random (B, T, C) tensor per simulated layer. Writing
# `[t] * attention_layers` would repeat the same tensor object in every slot,
# so every "layer" would be identical and the top-k average would be degenerate.
y = [
    torch.randn(batch_size, seq_len, embedding_dim)
    for _ in range(attention_layers)
]
y = y[-k:]

if not has_faiss_format:
    y = [tl.permute(1, 0, 2) for tl in y]  # BTC -> TBC

# F.batch_norm / F.instance_norm expect channels on dim 1, i.e. (N, C, L).
permuted = False
if batch_norm_target_layer or instance_norm_target_layer:
    y = [tl.permute(1, 2, 0) for tl in y]  # TBC -> BCT
    permuted = True

if batch_norm_target_layer:
    # Batch statistics only (no running buffers): per channel, over B and T.
    y = [
        F.batch_norm(
            tl.float(), running_mean=None, running_var=None, training=True
        )
        for tl in y
    ]

if instance_norm_target_layer:
    # Per (sample, channel), over T.
    y = [F.instance_norm(tl.float()) for tl in y]

if permuted:
    y = [tl.transpose(1, 2) for tl in y]  # BCT -> BTC

if layer_norm_target_layer:
    # Over the last axis, which is C in both the BTC and TBC layouts.
    y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]

# Average the selected layers.
y = sum(y) / len(y)

if not permuted:
    y = y.transpose(0, 1)  # TBC -> BTC

if layer_norm_targets:
    y = F.layer_norm(y.float(), y.shape[-1:])

if instance_norm_targets:
    # instance_norm wants (B, C, T): normalise over T, then restore (B, T, C).
    y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

print('y shape: ', y.shape)

# ---------------------------------------------------------------------------
# Regression loss between the student prediction x and the teacher target y,
# both (B, T, C).
# ---------------------------------------------------------------------------
loss_beta = 1.0   # smooth-L1 transition point; 0 selects plain squared error
loss_scale = 1.0  # > 0: multiply the summed loss by this; <= 0: divide by sqrt(C)
sz = x.size(-1)   # C, the hidden size

# Element-wise loss summed over channels -> one value per (B, T) position.
if loss_beta == 0:
    # beta == 0 selects plain squared error (as in the referenced fairseq
    # criterion) instead of smooth L1.
    loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
else:
    loss = F.smooth_l1_loss(
        x.float(), y.float(), reduction="none", beta=loss_beta
    ).sum(dim=-1)
print('loss: ', loss, 'loss shape: ', loss.shape)

result = {
    "losses": {
        "main": (
            loss.sum() / math.sqrt(sz)
            if loss_scale <= 0
            else loss.sum() * loss_scale
        ),
    },
    # Number of (B, T) positions that contribute to the loss.
    "sample_size": loss.numel(),
}

print(result)

x shape:  torch.Size([20, 384, 128])
y shape:  torch.Size([20, 384, 128])
loss:  tensor([[55.5846, 55.4130, 53.9140,  ..., 52.7618, 55.7299, 54.5257],
        [54.7572, 54.4994, 55.6759,  ..., 55.4732, 54.7223, 53.5696],
        [54.4100, 52.8356, 53.7533,  ..., 55.7398, 52.6998, 55.2166],
        ...,
        [53.4885, 53.2036, 56.9283,  ..., 53.5202, 55.3001, 53.7072],
        [54.7876, 55.1792, 54.9925,  ..., 55.5256, 53.9834, 56.4594],
        [55.0112, 55.7622, 54.9119,  ..., 54.9068, 52.6478, 54.5516]],
       grad_fn=<SumBackward1>) loss shape:  torch.Size([20, 384])
{'losses': {'main': tensor(419955.5938, grad_fn=<MulBackward0>)}, 'sample_size': 7680}
