In [None]:
import torch
# Load the original training checkpoint onto the CPU and pull out its three
# top-level entries: the saved options, the training step and the weights.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only point it
# at checkpoint files that come from a trusted source.
checkpoint = torch.load('./data/model.pth.tar', map_location="cpu")
opt_checkpoint = checkpoint["opt"]
step = checkpoint["step"]
model_dict = checkpoint["model"]

# Disabled alternative loading path, kept for reference:
# reader, reader_tokenizer = load_reader(opt)
# retriever, retriever_tokenizer = load_retriever(opt, opt_checkpoint)
# from src.transformers.models.atlas.fid import FiD
# generator = FiD.from_pretrained('t5-small')

# Rename the parameter prefixes so they line up with the attribute names used
# on the model later in this notebook (`query_passage_encoder`, `generator`).
model_dict = {
    k.replace("retriever.module", "query_passage_encoder")
    .replace("reader.module", "generator"): v for k, v in model_dict.items()
}
# Show only the parameter names: printing the dict itself would dump every
# weight tensor into the cell output.
print(model_dict.keys())

In [None]:
from datasets import Dataset
import numpy as np
import faiss

# Length of each placeholder passage embedding.
retrieval_vector_size = 768

# Toy corpus: two passages repeated five times each (ten rows in total), with
# constant-valued placeholder embeddings (all 0.1 / all 0.9).
dataset = Dataset.from_dict(
    dict(
        id=["0", "1"] * 5,
        text=["My favourite number is 3455", "The secret word is FROG"] * 5,
        embeddings=[
            np.full(retrieval_vector_size, 0.1),
            np.full(retrieval_vector_size, 0.9),
        ] * 5,
    )
)
# Inner-product FAISS index over the "embeddings" column.
dataset.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)


In [None]:
from transformers import AtlasConfig, AutoConfig, AutoTokenizer, AtlasTokenizer

# Pretrained identifiers for the retriever encoder and the generator.
bertModelString = "facebook/contriever"
t5ModelString = "google/t5-base-lm-adapt"

# One tokenizer per sub-model, wrapped together in a single AtlasTokenizer.
bertTokenizer = AutoTokenizer.from_pretrained(bertModelString)
t5Tokenizer = AutoTokenizer.from_pretrained(t5ModelString)
tokenizer = AtlasTokenizer(bertTokenizer, t5Tokenizer)

# Combined Atlas configuration built from the two sub-model configurations.
bertConfig = AutoConfig.from_pretrained(bertModelString)
t5Config = AutoConfig.from_pretrained(t5ModelString)
config = AtlasConfig.from_query_passage_encoder_generator_configs(bertConfig, t5Config)

# Batch-size / passages-per-query settings for this notebook.
# NOTE(review): the top-level n_context is 5 while the generator's is 2 --
# confirm the mismatch is intended.
config.bsz = 2
config.n_context = 5
config.generator.bsz = 2
config.generator.n_context = 2


In [None]:
from src.transformers.models.atlas.retriever import Contriever, UntiedDualEncoderRetriever, DualEncoderRetriever
from src.transformers.models.atlas.fid import FiD
from transformers import AtlasModel

# Retriever side: pretrained Contriever wrapped in a dual-encoder retriever.
contriever = Contriever.from_pretrained(bertModelString)

# Disabled alternative, kept for reference: pick the untied retriever when the
# checkpoint carries a separate query encoder.
# if 'query_passage_encoder.query_contriever.embeddings.position_ids' in model_dict:
#     questionPassageEncoder = UntiedDualEncoderRetriever(config, contriever)
# else:
questionPassageEncoder = DualEncoderRetriever(config, contriever)

# Generator side: FiD initialised from the T5 weights, with the batch size and
# passages-per-query stored on its encoder config.
generator = FiD.from_pretrained(t5ModelString)
fid_encoder_config = generator.encoder.config
fid_encoder_config.bsz = 2
fid_encoder_config.n_context = 5

# Full model: config + retriever + generator + the toy dataset as its index.
atlas = AtlasModel(config, questionPassageEncoder, generator, dataset)


In [None]:
# Copy the renamed checkpoint weights into the assembled model. As the last
# expression of the cell, the call's return value is shown as the cell output.
atlas.load_state_dict(model_dict)


In [None]:
# Write the model to 'data/atlas-pretrained'; the final cell of this notebook
# reloads it from the same path with AtlasModel.from_pretrained.
atlas.save_pretrained('data/atlas-pretrained')

In [None]:
# reencode the dataset

def reindex(examples):
    """Recompute the "embeddings" column for one batch of passages.

    Intended to be passed to ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch mapping with a "text" entry holding the passage strings.

    Returns:
        The same ``examples`` mapping, with "embeddings" replaced by the output
        of ``atlas.query_passage_encoder.embed_passages`` as a NumPy array.
    """
    tokenized = tokenizer(examples['text'], return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Inference only: skip building the autograd graph (the result is detached
    # anyway) and feed the encoder tensors on the model's device, the same way
    # the query path in this notebook does.
    with torch.no_grad():
        hidden_states = atlas.query_passage_encoder.embed_passages(
            input_ids=tokenized["input_ids"].to(atlas.device),
            attention_mask=tokenized["attention_mask"].to(atlas.device),
        )
    examples['embeddings'] = hidden_states.cpu().detach().numpy()
    return examples

# Re-embed every passage via `reindex`, then build an inner-product FAISS index
# over the resulting "embeddings" column.
# NOTE: atlas.index is reassigned from its own previous value, so re-running
# this cell maps the already re-encoded dataset again.
atlas.index = atlas.index.map(reindex, batched=True)
atlas.index.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)



In [None]:
from functools import reduce

# Two toy questions with their expected answers.
questions = ["What is my favourite number?", "What is the secret word?"]
answers = ["3455", "FROG"]

# Wrap them in the prompt / target templates built around <extra_id_0>.
inputs = [f"question: {question} answer: <extra_id_0>" for question in questions]
target = [f"<extra_id_0> {answer}" for answer in answers]
print(inputs, target)

# `self` is an alias for the model; the statements below go through it.
self = atlas
queries = inputs
topk = config.n_context
bsz = len(queries)

# Tokenize the queries (padded to 512 tokens) and move them to the model device.
queries_tokens = self.query_encoder_tokenizer(
    queries,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
).to(self.device)

# Embed the queries and hand them to the index search as a NumPy array.
query_hidden_states = self.query_passage_encoder(
    input_ids=queries_tokens["input_ids"],
    attention_mask=queries_tokens["attention_mask"],
)
query_hidden_states = query_hidden_states.cpu().detach().numpy()

# Look up `topk` passages per query; negative ids are dropped before fetching
# the rows.
_, passage_ids = self.index.search_batch("embeddings", query_hidden_states, topk)
kept_ids = [[i for i in indices if i >= 0] for indices in passage_ids]
docs = [self.index[ids] for ids in kept_ids]

# Prefix every retrieved passage with the query it was retrieved for.
passages = [
    [f'{query} context: {passage}' for passage in doc["text"]]
    for query, doc in zip(queries, docs)
]

def encode_passages(batch, tokenizer, max_length):
    """Tokenize a ragged batch of passage lists into per-example tensors.

    Args:
        batch: list of lists of strings, one inner list per example. Shorter
            inner lists are padded with "" up to the length of the longest one.
        tokenizer: callable invoked as ``tokenizer(texts, padding="max_length",
            max_length=..., return_tensors="pt", truncation=True)``; it must
            return a mapping of tensors with one row per text.
        max_length: forwarded to ``tokenizer`` as ``max_length``.

    Returns:
        dict mapping each key of the tokenizer output to its tensor viewed as
        ``(bsz, n, -1)``, where ``bsz = len(batch)`` and ``n`` is the length of
        the longest inner list.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("encode_passages() needs at least one example in batch")
    bsz = len(batch)
    n = max(len(example) for example in batch)
    # Pad each example to n passages and flatten in a single pass; repeated
    # list concatenation would be quadratic in the number of passages.
    flat = [
        passage
        for example in batch
        for passage in example + [""] * (n - len(example))
    ]
    tokens = tokenizer(
        flat,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
        truncation=True,
    )
    # Restore the (example, passage) structure on every returned tensor.
    return {k: v.view(bsz, n, -1) for k, v in tokens.items()}


# Tokenize the retrieved passages for the generator (each padded to 512 tokens).
reader_tokens = encode_passages(passages, self.generator_tokenizer, 512)

# Tokenize the targets and replace every pad token id with -100
# (presumably the value the generator's loss ignores -- confirm).
labels = self.generator_tokenizer(
    target,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)['input_ids']
labels = labels.masked_fill(labels == self.generator_tokenizer.pad_token_id, -100)

reader_ids = reader_tokens["input_ids"]  # FIXME (marker kept from the original; reason not recorded)
reader_mask = reader_tokens["attention_mask"].bool()

# Use at most `topk` passages per query, and record the batch size and passage
# count on the generator's encoder config.
n_context_training = min(topk, reader_ids.size(1))
cfg = self.generator.encoder.config
cfg.bsz = reader_ids.size(0)
cfg.n_context = n_context_training

# Keep the first n_context_training passages and flatten them to one row per
# query.
reader_ids_training = (
    reader_ids[:, :n_context_training].contiguous().view(reader_ids.size(0), -1)
)
reader_mask_training = (
    reader_mask[:, :n_context_training].contiguous().view(reader_mask.size(0), -1)
)

# Forward pass with labels; the logits are displayed as the cell output.
reader_output = self.generator(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)

reader_output.logits

In [None]:
# Second forward pass, this time over all passage tokens in `reader_ids`
# (flattened to one row per query) rather than the `_training` slice.
reader_output_for_loss = self.generator(
    input_ids=reader_ids.view(reader_ids.size(0), -1),
    attention_mask=reader_mask.view(reader_mask.size(0), -1),
    decoder_input_ids=None,
    labels=labels,
    use_cache=False,
)
# Display the loss of *this* forward pass, not the stale `reader_output` left
# over from the previous cell.
reader_output_for_loss.loss.item()

In [None]:
# Generate answers from the same flattened passage inputs used for the forward
# pass, then decode the generated token ids back to text for display.
generation_inputs = dict(
    input_ids=reader_ids_training,
    attention_mask=reader_mask_training,
)
generated = self.generator.generate(**generation_inputs)

tokenizer.generator.batch_decode(generated)

In [None]:
# question: What is my favourite number? answer: <extra_id_0>
# question: What is my favourite number? answer: <extra_id_0>

In [None]:
from transformers import AtlasModel
# Round-trip check: reload the model saved under 'data/atlas-pretrained' and
# list the parameter names it exposes.
atlas_2 = AtlasModel.from_pretrained('data/atlas-pretrained')

reloaded_keys = atlas_2.state_dict().keys()
print(reloaded_keys)
# atlas_2("test", "test", None, 1)