In [1]:
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import faiss
import torch
from datasets import Features, Sequence, Value, load_dataset

from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    HfArgumentParser,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [101]:
logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


def split_text(text: str, n=10, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]


def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}


def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    # return embeddings.detach().cpu().numpy()
    return {"embeddings": embeddings.detach().cpu().numpy().flatten()}

In [64]:
path = "dataset.csv"

In [102]:
assert os.path.isfile(path), "Please provide a valid path to a csv file"

# You can load a Dataset object this way
dataset = load_dataset(
    "csv", data_files=[path], delimiter=","
)

# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets?highlight=csv#csv-files

# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=1)
dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'text'],
        num_rows: 2432974
    })
})

In [93]:
count = 0
for row in dataset['train']:
    text = row["text"]
    print (row)
    if count==10:
        break
    count+=1

{'title': 'Klutch 15-Slot Universal Wrench Pouch', 'text': 'description: Sturdy Klutch fabric pouch is ideal for organizing, storing and transporting wrenches. Includes 15 slots that provide ample room'}
{'title': 'Klutch 15-Slot Universal Wrench Pouch', 'text': 'for SAE and metric wrenches. Capacity qty. 15, Mounting Type Drawer, hanging, Storage Type Pouch. 15 slots for wrenches Tie-trap'}
{'title': 'Klutch 15-Slot Universal Wrench Pouch', 'text': 'design keeps things stored neatly Rolls up to save space Eyelets for mounting Pouch is 17.5in.H x 24.75in.W Slot dimensions'}
{'title': 'Klutch 15-Slot Universal Wrench Pouch', 'text': 'smallest 1 1/4in. and largest 2 5/8in. Model Number: 81684. Age Group: Adult.; url: northerntool.com; retailer: Northern Tool; brand: Klutch'}
{'title': 'TAG OFF Skin Natural Skin Tag Remover Take Skin Tag Away', 'text': 'description: Tag OFF "Skin Tag Remover" is a topical remedy made from all-natural plant extracts that help eliminate those harmless'}
{'t

In [100]:
# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
new_features = Features(
    {"title": Value("string"), "text": Value("string"), "embeddings": Sequence(Value("float32"))}
)  # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
    partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
    batched=False,
    batch_size=1,
    features=new_features,
)

# And finally save your dataset
# passages_path = os.path.join(path.output_dir, "my_knowledge_dataset")
dataset.save_to_disk("test3")
# from datasets import load_from_disk
# dataset = load_from_disk("test3")  # to reload the dataset

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokeniz

RuntimeError: The size of tensor a (757) must match the size of tensor b (512) at non-singleton dimension 1

In [None]:
index = faiss.IndexHNSWFlat(768, 16, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)

# And save the index
# index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
dataset.get_index("embeddings").save("index_path")
# dataset.load_faiss_index("embeddings", index_path)  # to reload the index

In [None]:
# Easy way to load the model
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="custom", indexed_dataset=dataset
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

# For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
# retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

In [None]:
question = "Is Bob nice?"
input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
generated = model.generate(input_ids)
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
print ("Q: " + question)
print ("A: " + generated_string)