In [5]:
from datasets import Dataset
import numpy as np
import faiss

# Dimensionality of the toy passage embeddings.
retrieval_vector_size = 768

# Two toy passages, repeated five times each (10 rows, alternating), with
# constant placeholder embeddings: every component is 0.1 for the first
# passage and 0.9 for the second.
dataset = Dataset.from_dict(
    {
        "id": ["0", "1"] * 5,
        "text": ["My favourite number is 3455", "The secret word is FROG"] * 5,
        "embeddings": [
            np.full(retrieval_vector_size, fill_value)
            for fill_value in (0.1, 0.9)
        ] * 5,
    }
)
# Index the "embeddings" column with inner-product similarity.
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)


100%|██████████| 1/1 [00:00<00:00, 931.03it/s]


Dataset({
    features: ['id', 'text', 'embeddings'],
    num_rows: 10
})

In [8]:
from transformers import AtlasModel, AtlasTokenizer


# Load the Atlas model from the local 'data/atlas-pretrained' directory, passing
# the toy dataset built earlier in the notebook as its `index` argument.
# (Disabled) tokenizer = AtlasTokenizer.from_pretrained('data/atlas-pretrained')
atlas = AtlasModel.from_pretrained('data/atlas-pretrained', index=dataset)

query_passage_encoder DualEncoderRetriever(
  (contriever): Contriever(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,)

In [9]:
# Print the weight matrix of the intermediate dense projection in the first
# contriever encoder layer.
# NOTE(review): presumably a spot check that pretrained weights were loaded — confirm.
print(atlas.query_passage_encoder.contriever.encoder.layer[0].intermediate.dense.weight)

Parameter containing:
tensor([[ 0.0104,  0.0103,  0.0042,  ..., -0.0640,  0.0247, -0.0228],
        [-0.0708,  0.0181,  0.0579,  ..., -0.0396, -0.0310,  0.0322],
        [ 0.0053,  0.0258,  0.0131,  ..., -0.0067,  0.0291, -0.0206],
        ...,
        [-0.0684, -0.0747,  0.0236,  ..., -0.0172, -0.0154,  0.0140],
        [ 0.0713, -0.0287,  0.0121,  ..., -0.0211,  0.0136,  0.0103],
        [-0.0063,  0.0234,  0.0178,  ...,  0.0117,  0.0115,  0.0515]],
       requires_grad=True)


In [3]:
from datasets import Dataset
import numpy as np
import faiss

# Dimensionality of the toy passage embeddings.
retrieval_vector_size = 768

# Two toy passages, repeated five times each (10 rows, alternating), with
# constant placeholder embeddings: every component is 0.1 for the first
# passage and 0.9 for the second.
dataset = Dataset.from_dict(
    {
        "id": ["0", "1"] * 5,
        "text": ["My favourite number is 3455", "The secret word is FROG"] * 5,
        "embeddings": [
            np.full(retrieval_vector_size, fill_value)
            for fill_value in (0.1, 0.9)
        ] * 5,
    }
)
# Index the "embeddings" column with inner-product similarity.
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)
from transformers import AutoTokenizer, AtlasTokenizer
# Model identifiers from which the two underlying tokenizers are loaded.
bertModelString = "facebook/contriever"
t5ModelString = "google/t5-large-lm-adapt"
bertTokenizer = AutoTokenizer.from_pretrained(bertModelString)
t5Tokenizer = AutoTokenizer.from_pretrained(t5ModelString)

# Wrap both tokenizers in a single AtlasTokenizer; the attributes read below
# (`query_encoder`, `generator`) are the ones this notebook uses from it.
tokenizer = AtlasTokenizer(bertTokenizer, t5Tokenizer)
# NOTE(review): `atlas` is not created in this cell — it must already exist in
# the kernel (it is loaded by the AtlasModel.from_pretrained cell).
# Attach the toy index and the two tokenizers to the model instance.
atlas.index = dataset
atlas.query_encoder_tokenizer = tokenizer.query_encoder
atlas.generator_tokenizer = tokenizer.generator

In [None]:
def reindex(examples):
    """Re-embed one batch of passages with the model's passage encoder.

    Intended as a batched ``Dataset.map`` callback. Relies on the notebook
    globals ``tokenizer`` and ``atlas``.

    Args:
        examples: Batch mapping; its ``'text'`` entry is the list of passage
            strings to encode.

    Returns:
        The same mapping with ``'embeddings'`` replaced by the encoder output
        converted to a NumPy array.
    """
    encoded = tokenizer(
        examples['text'],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inputs must live on the same device as the model.
    device = atlas.device
    passage_vectors = atlas.query_passage_encoder.embed_passages(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
    )
    # Drop autograd tracking and bring the result back to host memory.
    examples['embeddings'] = passage_vectors.detach().cpu().numpy()
    return examples

# Recompute the "embeddings" column batch by batch with `reindex`, then build an
# inner-product FAISS index over the resulting dataset.
atlas.index = atlas.index.map(reindex, batched=True)
atlas.index.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)


In [5]:
# Toy questions and the answers stored in the indexed passages.
inputs = ["What is my favourite number?", "What is the secret word?"]
target = ["3455", "FROG"]

# Wrap questions/answers in the prompt template; <extra_id_0> marks the answer slot.
inputs = [f"question: {question} answer: <extra_id_0>" for question in inputs]
target = [f"<extra_id_0> {answer}" for answer in target]

# Invoke the module itself instead of `.forward(...)` so that any registered
# forward hooks are executed; the arguments are passed through unchanged.
# NOTE(review): the meaning of the 3rd/4th positional arguments (None, 2) is
# not visible in this notebook — the 4th is presumably the number of passages
# to retrieve; confirm against the model's forward signature.
atlas(
    inputs,
    target,
    None,
    2,
)

Seq2SeqLMOutput(loss=tensor(0.0117, grad_fn=<NllLossBackward0>), logits=tensor([[[-25.2537, -11.2296,  -8.5146,  ..., -26.4157, -24.2000, -25.8787],
         [-38.6698,  -9.2458, -13.9714,  ..., -38.4948, -38.6632, -38.2607],
         [-45.0214, -13.7125, -13.7919,  ..., -44.6558, -45.3571, -44.9504],
         [-35.5092,   0.5112,  -8.9547,  ..., -34.9981, -35.5459, -35.1949],
         [-33.0204,  -2.6576,  -6.8508,  ..., -32.5283, -33.0761, -32.5818]],

        [[-21.3347,  -9.7461,  -6.9815,  ..., -22.4861, -20.1532, -21.9368],
         [-38.0355, -11.0931, -11.7789,  ..., -37.8584, -38.3653, -37.8470],
         [-47.7797, -14.6103, -11.9611,  ..., -47.1814, -48.3027, -47.4551],
         [-83.5846, -24.5328, -26.5067,  ..., -82.6250, -84.3740, -83.8924],
         [-35.4897,   1.0183,  -8.4577,  ..., -35.0151, -35.6077, -35.1869]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=None, decoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, encoder_last_hidd

In [5]:
# reencode the dataset

def reindex(examples):
    """Re-embed one batch of passages with the model's passage encoder.

    Intended as a batched ``Dataset.map`` callback. Relies on the notebook
    globals ``tokenizer`` and ``atlas``.

    Args:
        examples: Batch mapping; its ``'text'`` entry is the list of passage
            strings to encode.

    Returns:
        The same mapping with ``'embeddings'`` replaced by the encoder output
        converted to a NumPy array.
    """
    tokenized = tokenizer(examples['text'], return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Move the inputs to the model's device; without this the encoder receives
    # CPU tensors and fails whenever the model lives on an accelerator.
    device = atlas.device
    hidden_states = atlas.query_passage_encoder.embed_passages(
        input_ids=tokenized["input_ids"].to(device),
        attention_mask=tokenized["attention_mask"].to(device),
    )
    examples['embeddings'] = hidden_states.cpu().detach().numpy()
    return examples

# Recompute the "embeddings" column batch by batch with `reindex`, then build an
# inner-product FAISS index over the resulting dataset.
atlas.index = atlas.index.map(reindex, batched=True)
atlas.index.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)



  0%|          | 0/1 [00:00<?, ?ba/s]


NameError: name 'tokenizer' is not defined

In [None]:
from functools import reduce

# Toy questions and the answers stored in the indexed passages.
inputs = ["What is my favourite number?", "What is the secret word?"]
target = ["3455", "FROG"]

# Wrap questions/answers in the prompt template; <extra_id_0> marks the answer slot.
inputs = [f"question: {question} answer: <extra_id_0>" for question in inputs]
target = [f"<extra_id_0> {answer}" for answer in target]
print(inputs, target)

# Alias the model as `self` so the statements below can be written like a method body.
# NOTE(review): presumably pasted from the model's forward pass to step through it — confirm.
self = atlas
queries = inputs
# NOTE(review): `config` is not defined in any cell visible in this notebook;
# on a fresh kernel this line raises NameError. Confirm where it should come from.
topk = config.n_context



# Number of queries (not read again by the top-level statements of this cell).
bsz = len(queries)


# Tokenize the queries and move the encoded batch to the model's device.
queries_tokens = self.query_encoder_tokenizer(queries, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)

# Encode the queries with the retriever.
query_hidden_states = self.query_passage_encoder(input_ids=queries_tokens["input_ids"], attention_mask=queries_tokens["attention_mask"])

# Convert to a NumPy array and look up the `topk` nearest rows per query in the
# "embeddings" index; only the returned ids are kept, the first result is discarded.
query_hidden_states = query_hidden_states.cpu().detach().numpy()
_, passage_ids = self.index.search_batch("embeddings", query_hidden_states, topk)


# Fetch the retrieved rows for each query, skipping negative ids
# (presumably the index's padding when fewer than `topk` hits exist — confirm).
docs = [self.index[[i for i in indices if i >= 0]] for indices in passage_ids]



# Build one reader input per retrieved passage: "<query> context: <passage text>".
passages = [[f'{queries[i]} context: {passage}' for passage in doc["text"]] for i, doc in enumerate(docs)]

def encode_passages(batch, tokenizer, max_length):
    """Tokenize a ragged batch of passage lists into rectangular tensors.

    Every example is right-padded with empty strings so that all examples hold
    the same number of passages ``n``; the passages are then tokenized in one
    call and each returned tensor is viewed as ``(len(batch), n, -1)``.

    Args:
        batch: Non-empty list of examples, each a list of passage strings.
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True,
            max_length=max_length, return_tensors="pt", truncation=True)``
            and returning a mapping of names to tensors supporting ``.view``.
        max_length: Truncation length forwarded to the tokenizer.

    Returns:
        dict mapping each tokenizer output name to its tensor reshaped to
        ``(bsz, n, -1)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    bsz = len(batch)
    if bsz == 0:
        # Previously surfaced as an opaque "max() arg is an empty sequence".
        raise ValueError("encode_passages() requires at least one example")
    n = max(len(example) for example in batch)
    # Pad each example to n passages and flatten in a single linear pass
    # (repeated list concatenation via reduce is quadratic in the batch size).
    flat_passages = [
        passage
        for example in batch
        for passage in example + [""] * (n - len(example))
    ]
    tokens = tokenizer(
        flat_passages,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
        truncation=True,
    )
    # Restore the (example, passage) structure on every returned tensor.
    return {k: v.view(bsz, n, -1) for k, v in tokens.items()}


# Tokenize the retrieved passages into tensors shaped (query, passage, token).
reader_tokens = encode_passages(passages, self.generator_tokenizer, 512)

# Tokenize the targets and overwrite padding positions with -100
# (the conventional "ignore" label for the LM loss — assumed; confirm against the generator).
labels = self.generator_tokenizer(target, return_tensors="pt", padding=True, truncation=True, max_length=512)['input_ids']
labels[labels == self.generator_tokenizer.pad_token_id] = -100

# Place every reader input on the model's device, matching what is already done
# for the query tokens; otherwise the generator receives CPU tensors when the
# model lives on an accelerator.
labels = labels.to(self.device)
reader_ids = reader_tokens["input_ids"].to(self.device)  # FIXME (original marker; reason not recorded)
reader_mask = reader_tokens["attention_mask"].bool().to(self.device)

# Use at most `topk` passages per query.
n_context_training = min(topk, reader_ids.size(1))

# Record batch size and passage count on the generator encoder's config.
# NOTE(review): presumably the encoder reads these to un-flatten its input — confirm.
cfg = self.generator.encoder.config
cfg.bsz = reader_ids.size(0)
cfg.n_context = n_context_training

# Keep the first n_context_training passages and flatten (passage, token) into
# one sequence axis per query.
reader_ids_training = reader_ids[:, :n_context_training].contiguous()
reader_mask_training = reader_mask[:, :n_context_training].contiguous()

reader_ids_training = reader_ids_training.view(reader_ids.size(0), -1)
reader_mask_training = reader_mask_training.view(reader_mask.size(0), -1)

reader_output = self.generator(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)

reader_output.logits.shape

In [None]:
# Run the generator over all retrieved passages (flattened to one sequence per
# query) to obtain the loss.
reader_output_for_loss = self.generator(
    input_ids=reader_ids.view(reader_ids.size(0), -1),
    attention_mask=reader_mask.view(reader_mask.size(0), -1),
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)
# Display the loss of the output computed in this cell (the original displayed
# `reader_output.loss`, i.e. the result of the previous cell).
reader_output_for_loss.loss


In [None]:
# Generate answers from the flattened reader inputs.
generated = self.generator.generate(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
)

# Decode with the tokenizer attached to the model (the same object the notebook
# assigns from `tokenizer.generator`), so this cell does not depend on the
# separate `tokenizer` global still being defined in the kernel.
self.generator_tokenizer.batch_decode(generated)

In [None]:
# question: What is my favourite number? answer: <extra_id_0>
# question: What is my favourite number? answer: <extra_id_0>