In [1]:
# Build a tiny in-memory passage dataset, attach a FAISS index to it, and wrap
# it in an Atlas retriever index.
from datasets import Dataset
import numpy as np
import faiss
from src.transformers.models.atlas.retrieval_atlas import AtlasRetrieverIndex, AtlasConfig, AtlasTokenizer

retrieval_vector_size = 768

# Two distinct passages, alternated so the dataset ends up with ten rows.
passages = ["My favourite number is 3455", "The secret word is FROG"]
# One constant-valued vector per passage (all 0.1 and all 0.9 respectively).
constant_vectors = [np.full(retrieval_vector_size, fill) for fill in (0.1, 0.9)]
n_repeats = 5

dataset = Dataset.from_dict(
    {
        "id": [str(row) for row in range(len(passages) * n_repeats)],
        "text": passages * n_repeats,
        "embeddings": constant_vectors * n_repeats,
    }
)

# Index the "embeddings" column using inner-product similarity.
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)

# Config and tokenizer are both read from the same local checkpoint directory.
pretrained_dir = "./data/atlas-pretrained"
config = AtlasConfig.from_pretrained(pretrained_dir)
tokenizer = AtlasTokenizer.from_pretrained(pretrained_dir, config=config)

retriever_index = AtlasRetrieverIndex(config, tokenizer, dataset)

  0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
# Load the Atlas model and attach the retriever index built above.
# NOTE: AtlasModel is imported from the installed `transformers` package on
# purpose; importing it via `src.transformers.models.atlas.modeling_atlas`
# raised an error in this notebook.
# TODO: it is still unclear whether the model loads correctly this way - verify.
from transformers import AtlasModel

atlas_checkpoint_dir = 'data/atlas-pretrained'
atlas = AtlasModel.from_pretrained(atlas_checkpoint_dir, retriever_index=retriever_index)

In [3]:
# Re-index the passages with the loaded model, then run a single forward pass
# on two question/answer pairs.
retriever_index.reindex(atlas, batch_size=2)


inputs_string = ["What is my favourite number?", "What is the secret word?"]
target_string = ["3455", "FROG"]

# Wrap questions and answers in the sentinel-token prompt format.
inputs_string = [f"question: {question} answer: <extra_id_0>" for question in inputs_string]
target_string = [f"<extra_id_0> {answer}" for answer in target_string]

tokens = tokenizer.generator(inputs_string, return_tensors="pt", padding=True)
labels = tokenizer.generator(target_string, return_tensors="pt", padding=True)
query_tokens = tokenizer.retriever(inputs_string, return_tensors="pt", padding=True)

# Replace padding positions in the label ids with -100.
# The mask must be applied to the `input_ids` tensor itself: `labels` is the
# tokenizer's output container, so comparing/indexing `labels` directly never
# touches the token ids that are passed to the model below.
# NOTE(review): -100 is presumably the loss's ignore index - confirm for Atlas.
label_ids = labels.input_ids
label_ids[label_ids == tokenizer.generator.pad_token_id] = -100

atlas.config.query_side_retriever_training = True

print(tokens)
atlas.forward(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    labels=labels.input_ids,
    query_input_ids=query_tokens.input_ids,
    query_attention_mask=query_tokens.attention_mask,
    top_k=2,
)


  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'input_ids': tensor([[  822,    10,   363,    19,    82,  3960,   381,    58,  1525,    10,
             3, 32099,     1],
        [  822,    10,   363,    19,     8,  2829,  1448,    58,  1525,    10,
             3, 32099,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor(0., grad_fn=<KlDivBackward0>)




(tensor(0., grad_fn=<KlDivBackward0>),
 tensor(3.6842, grad_fn=<NllLossBackward0>))

In [8]:
import torch

# Switch the model to training mode and run ten Adam steps on the same
# two-example batch, printing the loss at every step.
atlas.train()
optimizer = torch.optim.Adam(atlas.parameters(), lr=1e-4)

# The batch never changes, so the forward arguments are assembled once.
forward_kwargs = dict(
    input_ids=tokens.input_ids,
    attention_mask=tokens.attention_mask,
    labels=labels.input_ids,
    query_input_ids=query_tokens.input_ids,
    query_attention_mask=query_tokens.attention_mask,
    top_k=2,
)

for step in range(10):
    # NOTE(review): only element 1 of the returned tuple is optimised; the
    # other returned value is ignored - confirm this is the intended loss.
    loss = atlas.forward(**forward_kwargs)[1]
    print(loss)
    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next step.
    optimizer.zero_grad()



tensor(0.1129, grad_fn=<KlDivBackward0>)
tensor(0.1129, grad_fn=<KlDivBackward0>)
tensor(0.0097, grad_fn=<KlDivBackward0>)
tensor(0.0097, grad_fn=<KlDivBackward0>)
tensor(0.0370, grad_fn=<KlDivBackward0>)
tensor(0.0370, grad_fn=<KlDivBackward0>)
tensor(0.0181, grad_fn=<KlDivBackward0>)
tensor(0.0181, grad_fn=<KlDivBackward0>)
tensor(0.0073, grad_fn=<KlDivBackward0>)
tensor(0.0073, grad_fn=<KlDivBackward0>)
tensor(0.0342, grad_fn=<KlDivBackward0>)
tensor(0.0342, grad_fn=<KlDivBackward0>)
tensor(0.5594, grad_fn=<KlDivBackward0>)
tensor(0.5594, grad_fn=<KlDivBackward0>)
tensor(0.9392, grad_fn=<KlDivBackward0>)
tensor(0.9392, grad_fn=<KlDivBackward0>)
tensor(0.1170, grad_fn=<KlDivBackward0>)
tensor(0.1170, grad_fn=<KlDivBackward0>)
tensor(0.0789, grad_fn=<KlDivBackward0>)
tensor(0.0789, grad_fn=<KlDivBackward0>)
