In [355]:
from sentence_transformers import SentenceTransformer

bi_encoder = SentenceTransformer('multi-qa-mpnet-base-dot-v1')

In [356]:
import datasets

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:50000]').shuffle(seed=42)


Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-6f69f77053a477e7.arrow


In [357]:
queries = dataset['query'][:-1]
answers = [dataset[i]['passages']['passage_text'][0] for i in range(len(dataset) - 1)]
negatives = [dataset[i]['passages']['passage_text'][0] for i in range(1, len(dataset))]

dataset = datasets.Dataset.from_dict({'query': queries, 'answer': answers, 'negative': negatives}).with_format('torch')

In [358]:
def compute_embeddings(batch):
    query_embeddings = bi_encoder.encode(batch['query'], convert_to_tensor=True)
    answer_embeddings = bi_encoder.encode(batch['answer'], convert_to_tensor=True)
    negative_embeddings = bi_encoder.encode(batch['negative'], convert_to_tensor=True)
    return {
        'query': batch['query'],
        'answer': batch['answer'],
        'query_embeddings': query_embeddings,
        'answer_embeddings': answer_embeddings,
        'negative_embeddings': negative_embeddings
    }

dataset = dataset.map(compute_embeddings, batched=True, batch_size=32, remove_columns=dataset.column_names)

  0%|          | 0/1563 [00:00<?, ?ba/s]

In [359]:
dataset = dataset.train_test_split(test_size=0.1)

In [389]:
import torch

import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.input = torch.nn.Linear(768, 768)
        self.hidden = torch.nn.Linear(768, 768)
        self.hidden2 = torch.nn.Linear(768, 768)
        self.output = torch.nn.Linear(768, 768)
        self.apply(self._init_weights)
        
    def forward(self, query_embeddings, answer_embeddings = None, negative_embeddings = None):
        x = torch.nn.functional.relu(self.input(query_embeddings))
        x = torch.nn.functional.relu(self.hidden(x))
        x = torch.nn.functional.relu(self.hidden2(x))
        x = self.output(x)
        if answer_embeddings is None:
            return x
        if negative_embeddings is None:
            return torch.cdist(x, answer_embeddings)
        return torch.cdist(x, answer_embeddings)

    def _init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.eye_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()


In [402]:
model = Model()

device = torch.device("cuda")

model = model.to(device)

In [398]:
train_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=1)
test_loader = torch.utils.data.DataLoader(dataset['test'], batch_size=1)

In [455]:
from torch.optim import AdamW


model.train()


optimizer = AdamW(model.parameters(), lr=1e-6, betas=(0.9,0.999), eps=1e-08)

batch_loss = 0

for epoch in range(1):
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        query_embeddings = batch['query_embeddings'].to(device)
        answer_embeddings = batch['answer_embeddings'].to(device)
        negative_embeddings = batch['answer_embeddings'].to(device)
        loss = model(query_embeddings, answer_embeddings, negative_embeddings)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_loss += loss.item()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * query_embeddings.shape[0], len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss / 500))
            batch_loss = 0



In [456]:
model.eval()

test_loss = 0
iterations = 0
with torch.no_grad():
    for batch in test_loader:
        query_embeddings = batch['query_embeddings'].to(device)
        answer_embeddings = batch['answer_embeddings'].to(device)
        loss = model(query_embeddings, answer_embeddings)
        test_loss += loss.item()
        iterations += 1

print('Test set: Average loss: {:.6f}'.format(test_loss / iterations))

Test set: Average loss: 5.171576


In [365]:
torch.save(model, '~/models/qa-autoencoder')

In [373]:
model(dataset['test'][1]['query_embeddings'].unsqueeze(0).cuda(), dataset['test'][1]['answer_embeddings'].unsqueeze(0).cuda())


tensor([[5.3928]], device='cuda:0', grad_fn=<CdistBackward0>)

In [374]:
model(dataset['test'][1]['query_embeddings'].unsqueeze(0).cuda(), dataset['test'][5]['answer_embeddings'].unsqueeze(0).cuda())

tensor([[6.6375]], device='cuda:0', grad_fn=<CdistBackward0>)

In [375]:
1 - F.cosine_similarity(torch.tensor([[1.0, -1.0]]), torch.tensor([[1.0, 1.0]]))

tensor([1.])

In [457]:
n = 500
correct = 0
for i in range(n):
    dist = torch.norm(model(dataset['test'][i]['query_embeddings'].cuda()) - dataset['test']['answer_embeddings'].cuda(), dim=1)
    if i in dist.topk(10, largest=False).indices:
        correct += 1

print('Accuracy: {:.2f}%'.format(correct / n * 100))

Accuracy: 83.80%


In [454]:
n = 500
correct = 0
for i in range(n):
    dist = torch.norm(dataset['test'][i]['query_embeddings'].cuda() - dataset['test']['answer_embeddings'].cuda(), dim=1)
    if i in dist.topk(10, largest=False).indices:
        correct += 1

print('Accuracy: {:.2f}%'.format(correct / n * 100))

KeyboardInterrupt: 

In [450]:
list(model.named_parameters())

[('input.weight',
  Parameter containing:
  tensor([[ 9.2294e-01, -2.3622e-03, -8.9584e-03,  ..., -6.7652e-03,
           -3.5332e-03, -6.2148e-04],
          [-2.6781e-03,  9.1454e-01, -1.7326e-02,  ..., -6.3215e-03,
           -2.9847e-03, -5.9126e-03],
          [ 4.1321e-05, -7.0142e-04,  9.7410e-01,  ..., -1.4900e-05,
           -2.2402e-03, -4.3997e-03],
          ...,
          [ 1.1665e-03, -1.1044e-04, -1.0033e-02,  ...,  9.2585e-01,
            3.5307e-03,  2.4766e-03],
          [ 4.3231e-04, -4.8298e-04, -8.5984e-03,  ..., -3.2194e-03,
            9.3160e-01, -5.7781e-03],
          [ 3.3451e-03, -1.5221e-03, -4.9937e-03,  ..., -3.3635e-03,
           -1.5077e-03,  9.3447e-01]], device='cuda:0', requires_grad=True)),
 ('input.bias',
  Parameter containing:
  tensor([ 7.4603e-03,  1.8334e-02,  3.4891e-03,  6.6953e-03,  6.9675e-03,
           7.2083e-03,  2.9177e-03,  2.1013e-03,  9.2054e-03, -9.5998e-05,
           5.9117e-03,  6.7753e-03,  7.2387e-03,  9.2988e-03,  7.1064e-