In [24]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

In [25]:
class ContrastiveModel(nn.Module):
    """Sentence encoder: pretrained transformer -> masked mean pooling -> linear projection -> L2 norm.

    Args:
        model_name: Model id passed to ``AutoModel.from_pretrained``.
        embedding_dim: Width of the encoder's token embeddings (input size of the
            projection). If None, it is read from ``encoder.config.hidden_size`` so
            the projection always matches the chosen backbone.
        proj_dim: Size of the output sentence embedding.
    """

    def __init__(self, model_name='distilbert-base-uncased', embedding_dim=None, proj_dim=512):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        if embedding_dim is None:
            embedding_dim = self.encoder.config.hidden_size
        self.proj = nn.Linear(embedding_dim, proj_dim)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions where ``attention_mask`` is non-zero.

        Args:
            model_output: Encoder output; element 0 is taken as the token-level
                hidden states (assumed shape (batch, seq, hidden)).
            attention_mask: (batch, seq) mask, 1 for real tokens and 0 for padding.

        Returns:
            (batch, hidden) tensor in the same dtype as the token embeddings.
        """
        token_embeddings = model_output[0]
        # Pool in float32 so the 1e-9 clamp is representable even for half-precision
        # encoders; cast back afterwards so the result keeps the encoder's dtype.
        mask = attention_mask.unsqueeze(-1).float()  # (batch, seq, 1), broadcasts over hidden
        summed = (token_embeddings.float() * mask).sum(dim=1)
        # Clamp avoids division by zero for a row whose mask is all zeros.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).to(token_embeddings.dtype)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token sequences into unit-length sentence embeddings.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: (batch, seq) padding mask.

        Returns:
            (batch, proj_dim) tensor whose rows have L2 norm 1.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs, attention_mask)
        projected = self.proj(pooled)
        # Unit-normalize so dot products between embeddings equal cosine similarity.
        return F.normalize(projected, p=2, dim=1)

In [26]:
# Smoke test: push a random batch of token ids through the encoder.
# Seed first so the freshly initialised projection layer and the random ids are reproducible.
torch.manual_seed(0)
model = ContrastiveModel()
test_id = torch.randint(0, 1000, (2, 128))  # 2 sequences x 128 token ids
test_mask = torch.ones_like(test_id)  # no padding: every position is attended
embeddings = model(test_id, test_mask)

In [29]:
embeddings.size()  # expect (batch=2, projection dim=512)

torch.Size([2, 512])

In [43]:
class ContrastiveLoss(nn.Module):
    """Pairwise embedding loss with two modes.

    * ``'contrastive'``: margin loss on Euclidean distance. Pairs labelled 1 are
      pulled together (loss = distance); pairs labelled 0 are pushed apart until
      they are at least ``margin`` away (loss = max(0, margin - distance)).
    * ``'cosine'``: negative mean cosine similarity of anchor vs. positive;
      labels are ignored.

    Args:
        margin: Minimum distance enforced between dissimilar pairs
            (``'contrastive'`` mode only).
        loss_type: ``'contrastive'`` or ``'cosine'``.
    """

    def __init__(self, margin=0.5, loss_type='contrastive'):
        super().__init__()
        self.margin = margin
        self.loss_type = loss_type

    def forward(self, anchor, positive, negative=None, labels=None):
        """Compute the mean loss over a batch of embedding pairs.

        Args:
            anchor: (batch, dim) embeddings.
            positive: (batch, dim) embeddings paired row-wise with ``anchor``.
            negative: Currently unused; accepted for call-site compatibility.
            labels: (batch,) pair labels, 1 = similar, 0 = dissimilar. Required in
                ``'contrastive'`` mode; bool, int or float values are accepted.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``labels`` is missing in ``'contrastive'`` mode, or if
                ``loss_type`` is not recognised.
        """
        if self.loss_type == 'contrastive':
            if labels is None:
                raise ValueError("labels are required when loss_type='contrastive'")
            # Cast to the embedding dtype so bool/int labels work with `1 - labels`.
            labels = torch.as_tensor(labels, dtype=anchor.dtype, device=anchor.device)
            distances = torch.norm(anchor - positive, dim=1)
            # Similar pairs: minimise distance. Dissimilar pairs: penalise only
            # while they are closer than the margin.
            losses = labels * distances + (1 - labels) * F.relu(self.margin - distances)
            return losses.mean()

        if self.loss_type == 'cosine':
            # Maximising similarity == minimising its negation.
            return -F.cosine_similarity(anchor, positive, dim=1).mean()

        raise ValueError(
            f"unknown loss_type {self.loss_type!r}; expected 'contrastive' or 'cosine'"
        )
        

In [47]:
# Sanity-check the contrastive loss on random unit vectors:
# the first two pairs are labelled similar (1), the last two dissimilar (0).
torch.manual_seed(0)  # make the random inputs, and therefore the reported loss, reproducible
loss_fn = ContrastiveLoss(margin=0.5, loss_type='contrastive')
anchor = torch.randn(4, 512)
positive = torch.randn(4, 512)
labels = torch.tensor([1, 1, 0, 0], dtype=torch.float)
# NOTE(review): independent random unit vectors in 512-d are roughly sqrt(2) apart,
# which exceeds margin=0.5, so the dissimilar pairs contribute ~0 here and the loss
# is ~sqrt(2)/2 from the similar pairs alone.
loss = loss_fn(F.normalize(anchor, dim=1), F.normalize(positive, dim=1), labels=labels)
loss.item()

0.7039779424667358