In [1]:
from sentence_transformers import SentenceTransformer

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [2]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:100000]').shuffle(seed=42)
dataset = dataset.train_test_split(test_size=0.1, seed=42)

Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-89575f570aa58446.arrow
Loading cached split indices for dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-82afc0a1dfe05f54.arrow and /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-839d92ed9c39c390.arrow


In [4]:
train = dataset['train']

queries = train['query']
passages = [item['passages']['passage_text'][0] for item in train]

queries_encodings = bi_encoder.encode(queries, convert_to_tensor=True)
passages_encodings = bi_encoder.encode(passages, convert_to_tensor=True)

In [13]:
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self, encoding_dim):
        super().__init__()
        self.hidden_dim = 2 * encoding_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim * 2, 1)
        )

    def forward(self, query, passage):
        x = torch.cat([query, passage], dim=-1)
        x = self.feed_forward(x)
        x = torch.sigmoid(x)
        return x

In [5]:
from sentence_transformers import util

def generate_hard_negative(i):
    query = dataset['train'][i]['query']
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, passages_encodings, top_k=2)
    hits = hits[0]
    if hits[0]['corpus_id'] == i:
        return hits[1]['corpus_id']
    else:
        return hits[0]['corpus_id']
    
hard_negatives = [passages_encodings[generate_hard_negative(i)] for i in range(len(dataset['train']))]


In [6]:
import itertools
passages_negative_encodings = list(itertools.islice(itertools.cycle(passages_encodings), 1, len(passages_encodings) + 1))

In [7]:
training_examples = []

for i in range(len(queries_encodings)):
    training_examples.append((queries_encodings[i], passages_encodings[i], 1))
    # training_examples.append((queries_encodings[i], hard_negatives[i] if i % 2 == 0 else passages_negative_encodings[i], -1))
    training_examples.append((queries_encodings[i], hard_negatives[i], -1))

import random
random.shuffle(training_examples)

In [19]:
model = Model(384).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
criterion = torch.nn.MSELoss()
for epoch in range(5):
    total_loss = 0
    for i, (query, passage, label) in enumerate(training_examples):
        optimizer.zero_grad()
        output = model(query.cuda(), passage.cuda())
        # loss = torch.nn.functional.margin_ranking_loss(output, torch.zeros_like(output).cuda(), torch.tensor([label]).cuda(), margin=1)
        loss = criterion(output, torch.tensor([label]).float().cuda())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        log_interval = 1000
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            print(f'epoch {epoch}, iter {i}, loss {cur_loss}')
            total_loss = 0
    scheduler.step()

epoch 0, iter 1000, loss 1.0169987825453282
epoch 0, iter 2000, loss 0.9994893597066402
epoch 0, iter 3000, loss 1.0008970820903778
epoch 0, iter 4000, loss 1.0008500580489637
epoch 0, iter 5000, loss 0.9996972008943558
epoch 0, iter 6000, loss 1.0001417140364648
epoch 0, iter 7000, loss 0.9953153203949332
epoch 0, iter 8000, loss 1.0056423263326286
epoch 0, iter 9000, loss 1.0022458066940307
epoch 0, iter 10000, loss 0.9923955171108246
epoch 0, iter 11000, loss 0.9957776233404875
epoch 0, iter 12000, loss 0.997160516962409
epoch 0, iter 13000, loss 0.992397736787796
epoch 0, iter 14000, loss 0.9996234012544155
epoch 0, iter 15000, loss 0.9974916656017303
epoch 0, iter 16000, loss 0.9844332495480775
epoch 0, iter 17000, loss 0.9845375006496906
epoch 0, iter 18000, loss 0.992303868740797
epoch 0, iter 19000, loss 0.9690869277082383
epoch 0, iter 20000, loss 0.9769618330905214
epoch 0, iter 21000, loss 0.9792052135504782
epoch 0, iter 22000, loss 0.9756225202009082
epoch 0, iter 23000, l

KeyboardInterrupt: 

In [None]:
torch.save(model, '~/models/cross-embeddings.pt')