In [1]:
# setup dataset
from datasets import Dataset
import numpy as np
import faiss
retrieval_vector_size = 10

# Toy corpus: two passages, each paired with a constant embedding (all 0.1s or
# all 0.9s), repeated so the dataset has 10 rows. Constant vectors make the
# nearest-neighbour results easy to predict by hand.
passages = ["My favourite number is 3455", "The secret word is FROG"]
passage_vectors = [fill * np.ones(retrieval_vector_size) for fill in (0.1, 0.9)]
n_repeats = 5

dataset = Dataset.from_dict(
    {
        "id": [str(row) for row in range(len(passages) * n_repeats)],
        "text": passages * n_repeats,
        "embeddings": passage_vectors * n_repeats,
    }
)

# We'd normally use faiss.METRIC_INNER_PRODUCT, but we're using METRIC_L2 here to make the results easier to understand
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_L2)


  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['id', 'text', 'embeddings'],
    num_rows: 10
})

In [2]:
from src.transformers.models.atlas.retrieval_atlas import AtlasRetrieverIndex, AtlasConfig, AtlasTokenizer

atlas_checkpoint = "./data/atlas-pretrained"  # local pretrained Atlas checkpoint directory

config = AtlasConfig.from_pretrained(atlas_checkpoint)
tokenizer = AtlasTokenizer.from_pretrained(atlas_checkpoint, config=config)

# Wrap the toy dataset (and its faiss index) in an Atlas retriever index;
# the same object is handed to AtlasModel further down.
retriever_index = AtlasRetrieverIndex(config, tokenizer, dataset)

In [3]:
# test retriever
# Query with a constant 0.2 vector. Under METRIC_L2 it is much closer to the
# 0.1 passages (squared distance 0.1) than to the 0.9 passages (4.9), so all
# retrieved passages should be "My favourite number is 3455".
# np.ones, not np.zeros: scaling a zero vector by 0.2 is a no-op.
retriever_hidden_states = 0.2 * np.ones((1, retrieval_vector_size))
generator_input_ids = tokenizer.generator("hello world", return_tensors="pt")["input_ids"]

# Number of passages to retrieve per query (5 rows per query in the output below).
n_docs = 5
retrieved = retriever_index(retriever_hidden_states, generator_input_ids, n_docs)

print(retrieved)
# [0] selects the passages retrieved for the first (and only) query.
print(tokenizer.generator.batch_decode(retrieved["input_ids"][0]))

{'input_ids': tensor([[[21820,   296,  2625,    10,   499,  3960,   381,    19,  6154,  3769,
              1],
         [21820,   296,  2625,    10,   499,  3960,   381,    19,  6154,  3769,
              1],
         [21820,   296,  2625,    10,   499,  3960,   381,    19,  6154,  3769,
              1],
         [21820,   296,  2625,    10,   499,  3960,   381,    19,  6154,  3769,
              1],
         [21820,   296,  2625,    10,   499,  3960,   381,    19,  6154,  3769,
              1]]]), 'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]])}
['hello world context: My favourite number is 3455</s>', 'hello world context: My favourite number is 3455</s>', 'hello world context: My favourite number is 3455</s>', 'hello world context: My favourite number is 3455</s>', 'hello world context: My fa

In [4]:
# test reindex
from transformers import AtlasModel
# Load the Atlas model with the toy retriever index attached, then ask the
# index to rebuild itself using the model (presumably re-embedding each
# passage with the model's retriever, 2 rows per batch — confirm in reindex()).
# NOTE(review): this cell imports AtlasModel from the installed `transformers`
# package, while AtlasRetrieverIndex comes from `src.transformers`; confirm both
# resolve to the same source tree. This path is also spelled without the "./"
# used when loading the config/tokenizer.
atlas = AtlasModel.from_pretrained(
    "data/atlas-pretrained",
    retriever_index=retriever_index,
)

retriever_index.reindex(atlas, batch_size=2)

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]