In [1]:
import torch
# NOTE(review): torch.load unpickles the file, so only point it at a trusted checkpoint.
checkpoint = torch.load('./data/model.pth.tar', map_location="cpu")
opt_checkpoint = checkpoint["opt"]
step = checkpoint["step"]
model_dict = checkpoint["model"]

# reader, reader_tokenizer = load_reader(opt)
# retriever, retriever_tokenizer = load_retriever(opt, opt_checkpoint)
# from src.transformers.models.atlas.fid import FiD
# generator = FiD.from_pretrained('t5-small')


def _to_atlas_key(name):
    """Rewrite one checkpoint parameter name to the prefixes used by the model built below."""
    name = name.replace("retriever.module", "query_passage_encoder")
    return name.replace("reader.module", "generator")


# Same tensors, renamed keys; insertion order of the state dict is preserved.
model_dict = {_to_atlas_key(name): tensor for name, tensor in model_dict.items()}

# model.load_state_dict()
print(model_dict.keys())

  from .autonotebook import tqdm as notebook_tqdm


dict_keys(['generator.shared.weight', 'generator.encoder.embed_tokens.weight', 'generator.encoder.block.0.layer.0.SelfAttention.q.weight', 'generator.encoder.block.0.layer.0.SelfAttention.k.weight', 'generator.encoder.block.0.layer.0.SelfAttention.v.weight', 'generator.encoder.block.0.layer.0.SelfAttention.o.weight', 'generator.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'generator.encoder.block.0.layer.0.layer_norm.weight', 'generator.encoder.block.0.layer.1.DenseReluDense.wi_0.weight', 'generator.encoder.block.0.layer.1.DenseReluDense.wi_1.weight', 'generator.encoder.block.0.layer.1.DenseReluDense.wo.weight', 'generator.encoder.block.0.layer.1.layer_norm.weight', 'generator.encoder.block.1.layer.0.SelfAttention.q.weight', 'generator.encoder.block.1.layer.0.SelfAttention.k.weight', 'generator.encoder.block.1.layer.0.SelfAttention.v.weight', 'generator.encoder.block.1.layer.0.SelfAttention.o.weight', 'generator.encoder.block.1.layer.0.layer_norm.weight', 'gen

In [2]:
from datasets import Dataset
import numpy as np
import faiss

retrieval_vector_size = 768

# Ten rows: two distinct passages repeated five times, each with a constant
# placeholder embedding (all 0.1 or all 0.9).
dataset = Dataset.from_dict(
    dict(
        id=["0", "1"] * 5,
        text=["My favourite number is 3455", "The secret word is FROG"] * 5,
        embeddings=[
            np.full(retrieval_vector_size, fill) for fill in (0.1, 0.9)
        ] * 5,
    )
)
# Last expression of the cell: builds the inner-product index and displays the dataset.
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)


100%|██████████| 1/1 [00:00<00:00, 960.45it/s]


Dataset({
    features: ['id', 'text', 'embeddings'],
    num_rows: 10
})

In [3]:
from transformers import AtlasConfig, AutoConfig, AutoTokenizer, AtlasTokenizer

bertModelString = "facebook/contriever"
t5ModelString = "google/t5-base-lm-adapt"

# Combine the retriever (BERT-style) and generator (T5) configs into one Atlas config.
bertConfig = AutoConfig.from_pretrained(bertModelString)
t5Config = AutoConfig.from_pretrained(t5ModelString)
config = AtlasConfig.from_query_encoder_generator_configs(bertConfig, t5Config)

# One tokenizer per sub-model, wrapped in a single Atlas tokenizer.
bertTokenizer = AutoTokenizer.from_pretrained(bertModelString)
t5Tokenizer = AutoTokenizer.from_pretrained(t5ModelString)
tokenizer = AtlasTokenizer(bertTokenizer, t5Tokenizer)

# Attach the toy dataset defined earlier as the retrieval index.
config.dataset, config.index_name = dataset, "custom"

# NOTE(review): the top-level config uses n_context=5 while the generator
# sub-config uses n_context=2 — confirm this difference is intended.
config.n_context, config.bsz = 5, 2
config.generator.bsz, config.generator.n_context = 2, 2


Downloading: 100%|██████████| 655/655 [00:00<00:00, 315kB/s]
Downloading: 100%|██████████| 2.11k/2.11k [00:00<00:00, 1.38MB/s]
Downloading: 100%|██████████| 792k/792k [00:00<00:00, 1.99MB/s]
Downloading: 100%|██████████| 1.39M/1.39M [00:02<00:00, 520kB/s] 
Downloading: 100%|██████████| 1.79k/1.79k [00:00<00:00, 702kB/s]


In [6]:
from src.transformers.models.atlas.retriever import Contriever, UntiedDualEncoder, DualEncoderRetriever
from src.transformers.models.atlas.fid import FiD
from transformers import AtlasModel

contriever = Contriever.from_pretrained(bertModelString)

# Pick the encoder wrapper from the checkpoint layout: a `query_contriever`
# entry in the state dict selects the untied variant (presumably separate
# query/passage weights — confirm against the retriever module).
questionPassageEncoder = (
    UntiedDualEncoder
    if 'query_passage_encoder.query_contriever.embeddings.position_ids' in model_dict
    else DualEncoderRetriever
)(config, contriever)

generator = FiD.from_pretrained(t5ModelString)
# The FiD encoder reads batch size / context count from its config.
generator.encoder.config.bsz, generator.encoder.config.n_context = 2, 5

atlas = AtlasModel(config, questionPassageEncoder, generator, dataset, tokenizer)


Some weights of the model checkpoint at facebook/contriever were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Downloading: 100%|██████████| 990M/990M [02:38<00:00, 6.25MB/s] 


In [7]:
# Load the checkpoint weights into the assembled model. `model_dict` holds the
# checkpoint's "model" entry with its key prefixes already rewritten to
# `query_passage_encoder` / `generator`.
atlas.load_state_dict(model_dict)


<All keys matched successfully>

In [8]:
# reencode the dataset

def reindex(examples):
    """Re-embed a batch of passages with the loaded Atlas passage encoder.

    Written as a batched ``Dataset.map`` callback: it reads the module-level
    ``tokenizer`` and ``atlas`` objects.

    Args:
        examples: Batch mapping whose ``"text"`` entry holds the passage strings.

    Returns:
        The same mapping, with ``"embeddings"`` overwritten by the encoder
        output converted to a NumPy array.
    """
    tokenized = tokenizer(examples['text'], return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Inference only: disable autograd so no graph (and no activations) is kept
    # alive for tensors that are converted to NumPy right away.
    # NOTE(review): no `atlas.eval()` call is visible in this notebook; if the
    # encoder is in training mode, dropout could make embeddings non-deterministic.
    with torch.no_grad():
        hidden_states = atlas.query_passage_encoder.embed_passages(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
        )
    examples['embeddings'] = hidden_states.detach().cpu().numpy()
    return examples

# Recompute every row's "embeddings" with `reindex` (called on batches of rows)
# and store the resulting dataset back on the model.
atlas.index = atlas.index.map(reindex, batched=True)
# Build an inner-product FAISS index over the freshly computed "embeddings"
# column; as the cell's last expression, the returned dataset is displayed.
atlas.index.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)



100%|██████████| 1/1 [00:00<00:00, 14.49ba/s]
100%|██████████| 1/1 [00:00<00:00, 7307.15it/s]


Dataset({
    features: ['id', 'text', 'embeddings'],
    num_rows: 10
})

In [18]:
from functools import reduce

inputs = ["What is my favourite number?", "What is the secret word?"]
target = ["3455", "FROG"]

# Wrap questions and answers in the sentinel-token prompt format.
inputs = [f"question: {question} answer: <extra_id_0>" for question in inputs]
target = [f"<extra_id_0> {answer}" for answer in target]
print(inputs, target)

# Alias the model as `self` so the code below reads like a method body.
self = atlas
queries = inputs
topk = config.n_context
bsz = len(queries)

# Encode the queries with the retriever-side tokenizer and encoder.
queries_tokens = self.query_encoder_tokenizer(
    queries,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
).to(self.device)

query_hidden_states = (
    self.query_passage_encoder(
        input_ids=queries_tokens["input_ids"],
        attention_mask=queries_tokens["attention_mask"],
    )
    .cpu()
    .detach()
    .numpy()
)

# Nearest-neighbour lookup: one row of passage ids per query.
_, passage_ids = self.index.search_batch("embeddings", query_hidden_states, topk)

# Negative ids are dropped before indexing the dataset
# (presumably FAISS's "no result" marker — confirm).
docs = [self.index[[hit for hit in row if hit >= 0]] for row in passage_ids]

# Prefix every retrieved passage with the query it was retrieved for.
passages = [
    [f'{queries[idx]} context: {text}' for text in retrieved["text"]]
    for idx, retrieved in enumerate(docs)
]

def encode_passages(batch, tokenizer, max_length):
    """Tokenize a ragged batch of passage lists into per-example tensors.

    Every example is right-padded with empty strings to the length of the
    longest example, the passages are tokenized in one flat call, and each
    returned tensor is viewed as ``(len(batch), n, -1)`` where ``n`` is the
    largest number of passages in any example.

    Args:
        batch: Non-empty list of lists of passage strings, one inner list per example.
        tokenizer: Callable taking the flat list of strings plus ``padding``,
            ``max_length``, ``return_tensors`` and ``truncation`` keyword
            arguments, and returning a mapping of tensors.
        max_length: Maximum (and padded) token length handed to the tokenizer.

    Returns:
        dict mapping each tokenizer output key to its reshaped tensor.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("encode_passages() requires at least one example")

    bsz = len(batch)
    n = max(len(example) for example in batch)

    # Flatten in linear time, padding each example so all have exactly n passages.
    flat_passages = []
    for example in batch:
        flat_passages.extend(example)
        flat_passages.extend([""] * (n - len(example)))

    tokens = tokenizer(
        flat_passages,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
        truncation=True,
    )
    # Restore the (example, passage) structure on every returned tensor.
    return {k: v.view(bsz, n, -1) for k, v in tokens.items()}


reader_tokens = encode_passages(passages, self.generator_tokenizer, 512)

# Tokenize the targets and replace padding positions with -100 in place.
labels = self.generator_tokenizer(
    target,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)['input_ids']
labels.masked_fill_(labels == self.generator_tokenizer.pad_token_id, -100)

reader_ids = reader_tokens["input_ids"]  # FIXME
reader_mask = reader_tokens["attention_mask"].bool()

# Use at most `topk` passages per question and tell the FiD encoder config
# the batch size / context count of the tensors it is about to receive.
n_context_training = min(topk, reader_ids.size(1))
cfg = self.generator.encoder.config
cfg.bsz, cfg.n_context = reader_ids.size(0), n_context_training

# Keep the first n_context_training passages, then flatten the passage and
# token dimensions into one sequence dimension per example.
reader_ids_training = (
    reader_ids[:, :n_context_training].contiguous().view(reader_ids.size(0), -1)
)
reader_mask_training = (
    reader_mask[:, :n_context_training].contiguous().view(reader_mask.size(0), -1)
)

reader_output = self.generator(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)

reader_output.logits

['question: What is my favourite number? answer: <extra_id_0>', 'question: What is the secret word? answer: <extra_id_0>'] ['<extra_id_0> 3455', '<extra_id_0> FROG']


tensor([[[-25.2537, -11.2296,  -8.5146,  ..., -26.4157, -24.2000, -25.8787],
         [-38.6698,  -9.2458, -13.9714,  ..., -38.4948, -38.6632, -38.2608],
         [-45.0214, -13.7125, -13.7919,  ..., -44.6558, -45.3572, -44.9504],
         ...,
         [-25.1870, -11.7906,  -9.0805,  ..., -26.1698, -24.3021, -25.6975],
         [-25.1880, -11.7905,  -9.0804,  ..., -26.1709, -24.3032, -25.6987],
         [-25.1892, -11.7904,  -9.0803,  ..., -26.1722, -24.3043, -25.6999]],

        [[-21.3347,  -9.7461,  -6.9815,  ..., -22.4861, -20.1532, -21.9368],
         [-38.0355, -11.0931, -11.7789,  ..., -37.8584, -38.3653, -37.8470],
         [-47.7797, -14.6103, -11.9611,  ..., -47.1814, -48.3027, -47.4551],
         ...,
         [-21.3352, -10.5644,  -7.5844,  ..., -22.3114, -20.2787, -21.8309],
         [-21.3360, -10.5644,  -7.5845,  ..., -22.3122, -20.2795, -21.8317],
         [-21.3367, -10.5645,  -7.5847,  ..., -22.3131, -20.2802, -21.8325]]],
       grad_fn=<UnsafeViewBackward0>)

In [16]:
# Forward pass over all retrieved passages (no slicing to n_context_training),
# flattened to one sequence per example.
# NOTE(review): `cfg.n_context` was set to n_context_training in another cell;
# if that differs from reader_ids.size(1), update cfg before this call — confirm.
reader_output_for_loss = self.generator(
    input_ids=reader_ids.view(reader_ids.size(0), -1),
    attention_mask=reader_mask.view(reader_mask.size(0), -1),
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)
# Read the loss from the pass computed in this cell, not from `reader_output`,
# which is the result of a different cell's forward pass.
reader_output_for_loss.loss.item()

0.011737149208784103

In [14]:
# Decode answers from the same flattened passages used for the training forward pass.
generated = self.generator.generate(
    input_ids=reader_ids_training, attention_mask=reader_mask_training
)

# Special tokens are kept, so the sentinel and padding tokens appear in the output.
tokenizer.generator.batch_decode(generated)



['<pad><extra_id_0> 3455</s><pad>', '<pad><extra_id_0> FROG</s>']

In [None]:
# question: What is my favourite number? answer: <extra_id_0>
# question: What is my favourite number? answer: <extra_id_0>