In [14]:
import torch
from torch import nn
from torch.nn import functional as F

In [9]:
def MLP(dim, projection_size, widening_factor=4):
    """Build a small projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        dim: Number of input features of the first linear layer.
        projection_size: Number of output features of the last linear layer.
        widening_factor: Multiplier applied to ``dim`` to obtain the width of
            the hidden layer (``dim * widening_factor``).

    Returns:
        An ``nn.Sequential`` containing the four layers in the order above.
    """
    width = dim * widening_factor

    # Layers are created in forward order so parameter initialisation
    # consumes the random generator in the same sequence as the model runs.
    expand = nn.Linear(dim, width)
    normalize = nn.BatchNorm1d(width)
    activate = nn.ReLU(inplace=True)
    project = nn.Linear(width, projection_size)

    return nn.Sequential(expand, normalize, activate, project)

def SimSiamMLP(dim, projection_size, widening_factor=4):
    """Build a three-layer, bias-free projection head with batch norm.

    Layer order: two hidden stages of ``Linear(bias=False) -> BatchNorm1d ->
    ReLU(inplace)``, followed by ``Linear(bias=False) ->
    BatchNorm1d(affine=False)`` on the output.

    Args:
        dim: Number of input features of the first linear layer.
        projection_size: Number of output features of the last linear layer.
        widening_factor: Multiplier applied to ``dim`` to obtain the width of
            both hidden layers (``dim * widening_factor``).

    Returns:
        An ``nn.Sequential`` containing the eight layers in the order above.
    """
    width = dim * widening_factor

    layers = []

    # Two hidden stages; only the input width of the linear layer differs.
    for in_features in (dim, width):
        layers += [
            nn.Linear(in_features, width, bias=False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
        ]

    # Output stage: the final batch norm has no learnable scale/shift.
    layers += [
        nn.Linear(width, projection_size, bias=False),
        nn.BatchNorm1d(projection_size, affine=False),
    ]

    return nn.Sequential(*layers)

In [26]:
def old_aggregator_in_criterion(
    last_hidden_state: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
    """Masked mean of ``last_hidden_state`` over its second dimension.

    Marked by the original author as a bad function that should not be used;
    the warning is kept below.

    Args:
        last_hidden_state: 3-D tensor, unpacked by the code as
            ``(batch_size, sequence_length, hidden)``. Must be provided even
            though the signature defaults to ``None`` (``.size()`` is called
            on it unconditionally).
        attention_mask: Optional ``(batch_size, sequence_length)`` tensor of
            weights (typically 1 for positions to keep, 0 for positions to
            ignore). When omitted, every position is weighted 1, which is the
            original behaviour of this function (a plain mean).

    Returns:
        Tensor of shape ``(batch_size, hidden)`` holding, for each batch row,
        the mask-weighted sum over the sequence dimension divided by the sum
        of the mask (clamped to at least ``1e-9`` to avoid division by zero).
    """
    #! bad function, don't use

    batch_size, sequence_length, _ = last_hidden_state.size()

    if attention_mask is None:
        # Allocate the default all-ones mask directly on the input's device
        # instead of building it on CPU and copying it over.
        attention_mask = torch.ones(
            batch_size, sequence_length, device=last_hidden_state.device
        )

    # Broadcast the per-position mask across the hidden dimension.
    input_mask_expanded = (
        attention_mask.to(last_hidden_state.device)
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
    )
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)

    # Clamp so a fully masked row yields zeros rather than NaN/inf.
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    return sum_embeddings / sum_mask

def aggregate_across_seq_len(
    last_hidden_state: torch.Tensor = None,
    ) -> torch.Tensor:
    """Layer-normalise each leading-dimension slice, average them, normalise again.

    The input is iterated along its first dimension. Each slice is cast to
    float32 and layer-normalised over its last dimension; the normalised
    slices are averaged element-wise, and the average is layer-normalised
    over its last dimension once more.

    NOTE(review): the reduction runs over the *first* dimension, so a
    ``(20, 64, 128)`` input yields ``(64, 128)``. This presumably expects a
    sequence-first layout -- confirm against callers.

    Args:
        last_hidden_state: Tensor (or other iterable of tensors) whose
            leading dimension is averaged away. Must be provided even though
            the signature defaults to ``None``.

    Returns:
        float32 tensor with the shape of a single slice of the input.
    """
    running_total = 0
    num_slices = 0

    for piece in last_hidden_state:
        piece = piece.float()
        running_total = running_total + torch.layer_norm(piece, piece.shape[-1:])
        num_slices += 1

    averaged = (running_total / num_slices).float()

    return torch.layer_norm(averaged, averaged.shape[-1:])

In [27]:
# Smoke test: pool a random tensor and push it through both projection heads.
# Descriptive names are used instead of `input`, which shadowed the builtin
# `input()` and was reused for two different tensors (raw, then pooled).
# NOTE(review): the leading dimension is called SEQ_LEN because
# aggregate_across_seq_len averages over dim 0 -- confirm the intended layout.
SEQ_LEN, BATCH_SIZE, FEATURE_DIM = 20, 64, 128
PROJECTION_SIZE = 128

hidden_states = torch.randn(SEQ_LEN, BATCH_SIZE, FEATURE_DIM)

pooled = aggregate_across_seq_len(hidden_states)

mlp = MLP(FEATURE_DIM, PROJECTION_SIZE)
simsiam = SimSiamMLP(FEATURE_DIM, PROJECTION_SIZE)

out_mlp = mlp(pooled)
print(out_mlp.shape)
out_sim = simsiam(pooled)
print(out_sim.shape)

# Both heads must map (BATCH_SIZE, FEATURE_DIM) -> (BATCH_SIZE, PROJECTION_SIZE).
assert out_mlp.shape == out_sim.shape == (BATCH_SIZE, PROJECTION_SIZE)

torch.Size([64, 128])
torch.Size([64, 128])
