In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
device = "cpu"

In [28]:
import torch
import torch.nn.functional as F
import numpy as np



def cosine_similarity(a, b):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.

    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = F.normalize(a, p=2, dim=1)
    b_norm = F.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def mnr_loss(embeddings, scale=20.0):
    B, _ = embeddings.size()
    embeddings_a, embeddings_b = embeddings.split(B // 2, dim=0)
    scores = cosine_similarity(embeddings_a, embeddings_b) * scale
    labels = torch.tensor(
        range(len(scores)), dtype=torch.long, device=scores.device
    )  # Example a[i] should match with b[i]

    return F.cross_entropy(scores, labels)


def get_negative_mask(batch_size):
    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
    for i in range(batch_size):
        negative_mask[i, i] = 0
        negative_mask[i, i + batch_size] = 0

    negative_mask = torch.cat((negative_mask, negative_mask), 0)
    return negative_mask


def hard_negative_loss(out, tau_plus=0.1, beta=0.5, temperature=0.07):
    batch_size = out.size(0)// 2
    out = F.normalize(out, p=2, dim=1)
    out_1, out_2 = out.split(batch_size, dim=0)
    
    # neg score
    neg = torch.exp(torch.mm(out, out.transpose(0, 1)) / temperature)
    mask = get_negative_mask(batch_size).to(device)
    neg = neg.masked_select(mask).view(2 * batch_size, -1)

    # pos score
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)
    

    N = batch_size * 2 - 2
    imp = (beta* neg.log()).exp()
    reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)
    Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)
    # constrain (optional)
    Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))

        
    # contrastive loss
    loss = (- torch.log(pos / (pos + Ng) )).mean()

    return loss


In [29]:
from model import SentenceEncoder, Bert, BertConfig
from tokenizer import tokenizer

base_model = Bert(BertConfig()).load_model("data/chotot/new_ckpt.pt", "cpu")
model = SentenceEncoder(base_model, tokenizer.pad_token_id)

In [30]:
from transform import transform_sent
from datasets import load_from_disk



def augment(examples):
    street_names = examples["street_name"]
    examples["aug_street_name"] = [transform_sent(street_name) for street_name in street_names]
    return examples


def batch_tokenized(examples):
    return {key: [tokenizer.encode(value) for value in values] for key, values in examples.items()}


street_dataset = load_from_disk("augmented_street_dataset")
street_datasets = street_dataset.train_test_split(test_size=0.05, seed=42)

train_street_dataset = street_datasets["train"]
test_street_dataset = street_datasets["test"]
tokenized_test_street_dataset = test_street_dataset.map(batch_tokenized, batched=True).rename_columns({
    "street_name": "input_ids",
    "aug_street_name": "aug_input_ids"
})

Loading cached split indices for dataset at /Users/binhnguyenduc/Documents/1 Projects/llms/sentbert/augmented_street_dataset/cache-9e2cbdb50401850d.arrow and /Users/binhnguyenduc/Documents/1 Projects/llms/sentbert/augmented_street_dataset/cache-a415aa51bff6e9dd.arrow
Loading cached processed dataset at /Users/binhnguyenduc/Documents/1 Projects/llms/sentbert/augmented_street_dataset/cache-b21d4a26563d7dc1.arrow


In [31]:
from transformers import default_data_collator

device = "cpu"

def data_collator(features):
    features = default_data_collator(features)
    features["input_ids"] = features["input_ids"].to(device)
    features["aug_input_ids"] = features["aug_input_ids"].to(device)
    return features


class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        data = self.dataset[idx]
        return {"input_ids": tokenizer.encode(data["street_name"]), "aug_input_ids": tokenizer.encode(data["street_name"])}

    def __len__(self):
        return len(self.dataset)

train_dataset = Dataset(train_street_dataset)

In [32]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=data_collator)
test_loader = DataLoader(tokenized_test_street_dataset, batch_size=8, shuffle=False, collate_fn=data_collator)

In [33]:
from contextlib import nullcontext
from tqdm.notebook import trange

ctx = nullcontext()
epochs = 1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in trange(epochs):
    for inputs in train_loader:
        inputs_1 = inputs["input_ids"]
        inputs_2 = inputs["aug_input_ids"]
        input_ids = torch.concat([inputs_1, inputs_2], dim=0)
        with ctx:
            embeddings = model(input_ids)
            loss = hard_negative_loss(embeddings)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        print(loss.item())
        

  0%|          | 0/1 [00:00<?, ?it/s]

0.16153472661972046
0.003507960354909301
0.003171156859025359
0.06943691521883011
0.0
0.1418997347354889
0.08032432943582535
0.0
0.5045386552810669
0.0
0.01415493618696928
0.5202610492706299
0.18204505741596222
0.1247250959277153
0.12076790630817413
0.4917925298213959
0.4096137285232544
0.08454518765211105
0.6235713362693787
0.2646808624267578
0.02988138236105442
0.418828547000885
0.1868470013141632
0.0
0.08570028841495514
0.0
0.11741338670253754
0.08273844420909882
0.11198565363883972
0.4179825782775879
0.0
0.6560600399971008
0.12432656437158585
0.0
0.6641465425491333
0.1819010078907013
0.0
0.002051190473139286
0.0
0.1700320988893509
0.10211546719074249
0.12547257542610168
0.1291946917772293
0.07902386784553528
0.12239798903465271
0.6986942291259766
0.0
0.0
0.04716791212558746
0.007289530243724585
0.06331422924995422
0.0
0.0
0.03713144361972809
0.0
0.021734334528446198
0.0
0.0
0.02637563645839691
0.644249439239502
0.05244774371385574
0.06585892289876938
0.029814869165420532
0.04806045

KeyboardInterrupt: 

In [26]:
loss = hard_negative_loss(embeddings)

In [27]:
loss

tensor(0.2057, grad_fn=<MeanBackward0>)

In [11]:
out = embeddings
batch_size = out.size(0)// 2
out_1, out_2 = out.split(batch_size, dim=0)

In [17]:
temperature = 0.07
neg = torch.exp(cosine_similarity(out, out) / temperature)

In [18]:
neg

tensor([[1600327.6250,  162046.5312,   75902.3984,  195953.3125,   21770.0996,
          145809.0000,  160701.6250,  191996.0000, 1126005.5000,  165599.6250,
           59242.7656,  115989.0469,   10814.1572,  106603.7344,  151332.7812,
          177264.5781],
        [ 162046.5312, 1600326.0000,   41476.6133,   47023.9961,    6186.6689,
          224424.9688,   82642.0000,   32690.5762,  131124.9531, 1240143.7500,
           27722.5234,   30601.1113,    3059.5376,  136082.9844,   90919.4609,
           48464.7188],
        [  75902.3984,   41476.6133, 1600323.0000,   54524.0859,   15670.7861,
           50957.6836,  135132.5625,   25715.7910,   75110.9609,   45155.6523,
          993864.0625,   53095.4023,    8681.9893,   37372.3906,  114471.0234,
           23520.1953],
        [ 195953.3125,   47023.9961,   54524.0859, 1600327.6250,   42133.5117,
          119657.9297,  371280.7188,  195052.5938,  156048.6250,   43597.9727,
           37018.2383, 1055215.0000,   27179.1406,   92970.

In [19]:

mask = get_negative_mask(batch_size).to(device)
neg = neg.masked_select(mask).view(2 * batch_size, -1)

In [22]:
        # pos score
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)

In [23]:
pos

tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
       grad_fn=<CatBackward0>)

In [None]:

        mask = get_negative_mask(batch_size).to(device)
        neg = neg.masked_select(mask).view(2 * batch_size, -1)

        # pos score
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)
        

        N = batch_size * 2 - 2
        imp = (beta* neg.log()).exp()
        reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)
        Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)
        # constrain (optional)
        Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))

            
        # contrastive loss
        loss = (- torch.log(pos / (pos + Ng) )).mean()

        return loss

In [9]:
embeddings

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<DivBackward0>)

In [None]:
    batch_size = out.size(0)// 2
    out_1, out_2 = out.split(batch_size, dim=0)
    
    # neg score
    neg = torch.exp(cosine_similarity(out, out) / temperature)
    old_neg = neg.clone()
    mask = get_negative_mask(batch_size).to(device)
    neg = neg.masked_select(mask).view(2 * batch_size, -1)

    # pos score
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)
    

    N = batch_size * 2 - 2
    imp = (beta* neg.log()).exp()
    reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)
    Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)
    # constrain (optional)
    Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))

        
    # contrastive loss
    loss = (- torch.log(pos / (pos + Ng) )).mean()

    return loss