In [1]:
# setup dataset + index
from datasets import Dataset
import numpy as np
import faiss
retrieval_vector_size = 768

dataset = Dataset.from_dict(
    {
        "id": [str(i) for i in range(10)],
        "text": ["My favourite number is 3455", "The secret word is FROG"] * 5,
        "embeddings": [
            0.1 * np.ones(retrieval_vector_size),
            0.9 * np.ones(retrieval_vector_size),
        ] * 5,
    }
)

dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)
from src.transformers.models.atlas.retrieval_atlas import AtlasRetrieverIndex, AtlasConfig, AtlasTokenizer

config = AtlasConfig.from_pretrained("./data/atlas-pretrained")
tokenizer = AtlasTokenizer.from_pretrained("./data/atlas-pretrained", config=config)

retriever_index = AtlasRetrieverIndex(config, tokenizer, dataset)

  0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
# for some reason, src.transformers.models.atlas.modeling_atlas as import causes an error here
# model doesn't load right?
from transformers import AtlasModel

atlas = AtlasModel.from_pretrained('data/atlas-pretrained', retriever_index=retriever_index)

In [5]:
retriever_index.reindex(atlas, batch_size=2)


inputs_string = ["What is my favourite number?", "What is the secret word?"]
target_string = ["3455", "FROG"]

inputs_string = [f"question: {question} answer: <extra_id_0>" for question in inputs_string]
target_string = [f"<extra_id_0> {answer}" for answer in target_string]

tokens = tokenizer.generator(inputs_string, return_tensors="pt", padding=True)
labels = tokenizer.generator(target_string, return_tensors="pt", padding=True)
query_tokens = tokenizer.retriever(inputs_string, return_tensors="pt", padding=True)

labels[labels == tokenizer.generator.pad_token_id] = -100

atlas.train()
atlas.config.query_side_retriever_training = True

print(tokens)
atlas.forward(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    labels=labels.input_ids,
    query_input_ids=query_tokens.input_ids,
    query_attention_mask=query_tokens.attention_mask,
    top_k=2,
)


  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'input_ids': tensor([[  822,    10,   363,    19,    82,  3960,   381,    58,  1525,    10,
             3, 32099,     1],
        [  822,    10,   363,    19,     8,  2829,  1448,    58,  1525,    10,
             3, 32099,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}




AtlasModelOutput(generator_loss=tensor(3.9709, grad_fn=<NllLossBackward0>), retriever_loss=tensor(0.0053, grad_fn=<KlDivBackward0>), logits=tensor([[[-12.8470,  -3.9406,  -3.9002,  ..., -13.3872, -11.8982, -13.3796],
         [-25.8175,  -2.2608,  -7.4646,  ..., -25.6973, -25.7258, -25.3556],
         [-30.8447,  -5.4745,  -8.6406,  ..., -30.5384, -30.6604, -30.6171],
         [-28.4645,  -3.0418,  -6.6645,  ..., -27.9618, -28.6331, -27.9040],
         [-31.9013,  -0.7016,  -4.6872,  ..., -31.3952, -32.0902, -31.1038]],

        [[-12.7249,  -5.5796,  -6.8293,  ..., -13.2572, -12.4315, -13.3793],
         [-37.6499,  -7.6553, -11.8667,  ..., -37.5812, -37.6253, -37.5362],
         [-37.7885,  -8.9861,  -7.1741,  ..., -37.3482, -38.3792, -37.5330],
         [-49.3809, -15.4851, -13.5717,  ..., -48.8605, -49.7353, -49.3180],
         [-33.3617,  -0.7058,  -6.9944,  ..., -33.0043, -33.3792, -32.9788]]],
       grad_fn=<UnsafeViewBackward0>), doc_scores=None, past_key_values=None, retrieve

In [4]:
atlas.train()
import torch
optimizer = torch.optim.Adam(atlas.parameters(), lr=1e-4)
for i in range(10):
    loss = atlas.forward(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        labels=labels.input_ids,
        query_input_ids=query_tokens.input_ids,
        query_attention_mask=query_tokens.attention_mask,
        top_k=2,
    )[1]
    print(loss)
    loss.backward()
    optimizer.step()
    # zero out gradients
    optimizer.zero_grad()





tensor(0.1015, grad_fn=<KlDivBackward0>)
tensor(0.1034, grad_fn=<KlDivBackward0>)
tensor(0.0187, grad_fn=<KlDivBackward0>)
tensor(0.0284, grad_fn=<KlDivBackward0>)
tensor(0.0025, grad_fn=<KlDivBackward0>)
tensor(0.0041, grad_fn=<KlDivBackward0>)
tensor(0.1115, grad_fn=<KlDivBackward0>)
tensor(0.0009, grad_fn=<KlDivBackward0>)
tensor(0.0028, grad_fn=<KlDivBackward0>)
tensor(0.0057, grad_fn=<KlDivBackward0>)
