In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
# Load WikiText-2 dataset
dataset_name = "wikitext"
dataset_instance = "wikitext-2-v1"

In [3]:
from datasets import load_dataset

dataset = load_dataset(dataset_name, dataset_instance, split="train", download_mode="reuse_dataset_if_exists")

new_column = [dataset_name] * len(dataset)
dataset = dataset.add_column("title", new_column)
print(dataset)

Dataset({
    features: ['text', 'title'],
    num_rows: 36718
})


In [4]:
from transformers import AutoTokenizer, RagRetriever, RagModel, RagSequenceForGeneration, RagTokenizer
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, RagConfig
from datasets import Features, Sequence, Value
from functools import partial

batch_size = 1

enc_model_name = "facebook/dpr-ctx_encoder-single-nq-base"

def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizer) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy().flatten()}

ctx_encoder = DPRContextEncoder.from_pretrained(enc_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(enc_model_name)
new_features = Features(
    {"title": Value("string"), "text": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset_mapped = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched = True if batch_size > 1 else False,
    batch_size = batch_size,
    features=new_features,
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [5]:
# Save dataset
data_set_path = "rag-" + dataset_name
dataset_mapped.save_to_disk(data_set_path)

Saving the dataset (0/1 shards):   0%|          | 0/36718 [00:00<?, ? examples/s]

In [4]:
from datasets import load_from_disk
# Load datasets
data_set_path = "rag-" + dataset_name
dataset_mapped = load_from_disk(data_set_path)  # to reload the dataset

In [6]:
import faiss

# Use the Faiss implementation of HNSW for fast approximate nearest neighbor search
faiss_index = faiss.IndexHNSWFlat(768, 16, faiss.METRIC_INNER_PRODUCT)
dataset_mapped.add_faiss_index("embeddings", custom_index=faiss_index)


  0%|          | 0/37 [00:00<?, ?it/s]

Dataset({
    features: ['title', 'text', 'embeddings'],
    num_rows: 36718
})

In [7]:
# Save Faiss index
dataset_mapped.get_index("embeddings").save(data_set_path + ".faiss")

In [5]:
# Load Faiss index
dataset_mapped.load_faiss_index("embeddings", data_set_path + ".faiss")

In [8]:
from transformers import AutoTokenizer, RagRetriever, RagModel, RagSequenceForGeneration, RagTokenizer
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, RagConfig

# Easy way to load the model
rag_model_name = "facebook/rag-token-base"
retriever = RagRetriever.from_pretrained(
        rag_model_name, index_name="custom", indexed_dataset=dataset_mapped
    )
model = RagSequenceForGeneration.from_pretrained(rag_model_name, retriever=retriever)
tokenizer = RagTokenizer.from_pretrained(rag_model_name)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

In [9]:
question = "Who is wikipedia?"
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
generated = model.generate(input_ids)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
print ("Q: " + question)
print ("A: " + generated_string)



Q: Who is wikipedia?
A: FinFinkelstein. wikitext / Finkelstein noted 20 instances, in as
