In [1]:
# setup dataset + index
from datasets import Dataset
import numpy as np
import faiss
retrieval_vector_size = 768

dataset = Dataset.from_dict(
    {
        "id": [str(i) for i in range(10)],
        "text": ["My favourite number is 3455", "The secret word is FROG"] * 5,
        "embeddings": [
            0.1 * np.ones(retrieval_vector_size),
            0.9 * np.ones(retrieval_vector_size),
        ] * 5,
    }
)

dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)
from src.transformers.models.atlas.retrieval_atlas import AtlasRetrieverIndex, AtlasConfig, AtlasTokenizer

config = AtlasConfig.from_pretrained("./data/atlas-pretrained")
tokenizer = AtlasTokenizer.from_pretrained("./data/atlas-pretrained", config=config)

retriever_index = AtlasRetrieverIndex(config, tokenizer, dataset)

  0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
# for some reason, src.transformers.models.atlas.modeling_atlas as import causes an error here
# model doesn't load right?
from transformers import AtlasModel

atlas = AtlasModel.from_pretrained('data/atlas-pretrained', retriever_index=retriever_index)

In [3]:
retriever_index.reindex(atlas, batch_size=2)


inputs_string = ["What is my favourite number?", "What is the secret word?"]
target_string = ["3455", "FROG"]

inputs_string = [f"question: {question} answer: <extra_id_0>" for question in inputs_string]
target_string = [f"<extra_id_0> {answer}" for answer in target_string]

tokens = tokenizer.generator(inputs_string, return_tensors="pt", padding=True)
labels = tokenizer.generator(target_string, return_tensors="pt", padding=True)
query_tokens = tokenizer.retriever(inputs_string, return_tensors="pt", padding=True)

labels[labels == tokenizer.generator.pad_token_id] = -100

atlas.train()
atlas.config.query_side_retriever_training = True

print(tokens)
atlas.forward(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    labels=labels.input_ids,
    query_input_ids=query_tokens.input_ids,
    query_attention_mask=query_tokens.attention_mask,
    top_k=2,
)


  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'input_ids': tensor([[  822,    10,   363,    19,    82,  3960,   381,    58,  1525,    10,
             3, 32099,     1],
        [  822,    10,   363,    19,     8,  2829,  1448,    58,  1525,    10,
             3, 32099,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}




AtlasModelOutput(generator_loss=tensor(3.7933, grad_fn=<NllLossBackward0>), retriever_loss=tensor(0.2630, grad_fn=<KlDivBackward0>), logits=tensor([[[-13.9893,  -4.8677,  -4.1102,  ..., -14.7916, -13.3198, -14.7123],
         [-22.3417,  -4.3604,  -5.7955,  ..., -22.0388, -22.2972, -21.5707],
         [-33.0299,  -7.8454, -10.1410,  ..., -32.7671, -33.2933, -32.7036],
         [-33.9165,  -4.7252, -11.1489,  ..., -33.7144, -34.1410, -33.6054],
         [-28.7629,  -5.0741,  -7.5760,  ..., -28.2001, -28.8387, -27.9074]],

        [[-22.9633,  -9.1757,  -8.2084,  ..., -24.1951, -22.0178, -23.0390],
         [-24.6265,  -2.8294,  -7.8373,  ..., -24.5500, -24.8253, -24.2540],
         [-24.9997,  -5.5261,  -2.2094,  ..., -24.5674, -25.0943, -24.5938],
         [-42.4483, -13.1286, -11.4029,  ..., -41.9430, -43.0953, -42.3096],
         [-30.0594,  -1.1307,  -7.7721,  ..., -29.8447, -30.2589, -29.9028]]],
       grad_fn=<UnsafeViewBackward0>), doc_scores=None, past_key_values=None, retrieve

In [25]:
atlas.eval()
generated = atlas.generate(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    query_input_ids=query_tokens.input_ids,
    query_attention_mask=query_tokens.attention_mask,
    top_k=2,
)

decoded = tokenizer.generator.batch_decode(generated)
print("OUTPUT:", decoded)
labels_decoded = tokenizer.generator.batch_decode(labels.input_ids)
print("EXPECTED:", labels_decoded)
    

OUTPUT: ['<pad><extra_id_0> 3455</s><pad>', '<pad><extra_id_0> FROG</s>']
EXPECTED: ['<extra_id_0> 3455</s><pad>', '<extra_id_0> FROG</s>']


In [19]:
atlas.train()
import torch
optimizer = torch.optim.Adam(atlas.parameters(), lr=1e-5)

for i in range(10):
    loss = atlas.forward(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        labels=labels.input_ids,
        query_input_ids=query_tokens.input_ids,
        query_attention_mask=query_tokens.attention_mask,
        top_k=2,
    )
    print(loss.generator_loss, loss.retriever_loss)
    loss.generator_loss.backward()
    # loss.retriever_loss.backward()

    optimizer.step()
    # zero out gradients
    optimizer.zero_grad()



tensor(2.6709, grad_fn=<NllLossBackward0>) tensor(0.0017, grad_fn=<KlDivBackward0>)
tensor(2.7269, grad_fn=<NllLossBackward0>) tensor(0.0130, grad_fn=<KlDivBackward0>)
tensor(2.5194, grad_fn=<NllLossBackward0>) tensor(0.0037, grad_fn=<KlDivBackward0>)
tensor(2.8665, grad_fn=<NllLossBackward0>) tensor(0.0024, grad_fn=<KlDivBackward0>)
tensor(2.5058, grad_fn=<NllLossBackward0>) tensor(0.0002, grad_fn=<KlDivBackward0>)
tensor(2.5434, grad_fn=<NllLossBackward0>) tensor(0.0004, grad_fn=<KlDivBackward0>)
tensor(2.6853, grad_fn=<NllLossBackward0>) tensor(0.0023, grad_fn=<KlDivBackward0>)
tensor(2.7317, grad_fn=<NllLossBackward0>) tensor(0.0013, grad_fn=<KlDivBackward0>)
tensor(2.4069, grad_fn=<NllLossBackward0>) tensor(3.0175e-05, grad_fn=<KlDivBackward0>)
tensor(2.3856, grad_fn=<NllLossBackward0>) tensor(0.0140, grad_fn=<KlDivBackward0>)


KeyboardInterrupt: 