In [2]:
import faiss


In [4]:
!nvcc --version

import torch
print(torch.cuda.is_available())  # Should print True if CUDA is available
print(torch.version.cuda)         # Should print 12.4 or the desired CUDA version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Fri_Jun_14_16:44:19_Pacific_Daylight_Time_2024
Cuda compilation tools, release 12.6, V12.6.20
Build cuda_12.6.r12.6/compiler.34431801_0
True
12.6


In [1]:
!pip install datasets


Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-19.0.1-cp312-cp312-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp312-cp312-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.11.14-cp312-cp312-win_amd64.whl.metadata (8.0 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets)
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets)
  Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting frozenlist>=1.1.1 (


[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
from transformers import AutoTokenizer, RagRetriever, RagModel, RagSequenceForGeneration, RagTokenizer


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
from datasets import load_dataset

# Load WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
print(dataset)  # Check the structure (should have 'text' column)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 278461.61 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 2338071.84 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 2281293.66 examples/s]

Dataset({
    features: ['text'],
    num_rows: 36718
})





In [7]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch

encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# Function to create embeddings
def create_embeddings(dataset):
    embeddings = []
    for doc in dataset:
        # Tokenize the text
        inputs = tokenizer(doc['text'], truncation=True, padding=True, return_tensors="pt", max_length=512)

        # Generate embeddings
        with torch.no_grad():
            output = encoder(**inputs)
            embeddings.append(output.pooler_output.squeeze().cpu().numpy())
    return embeddings

# Create embeddings for the WikiText dataset
embeddings = create_embeddings(dataset)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this check

In [8]:
import faiss
import numpy as np

# Convert embeddings into a numpy array
embedding_matrix = np.array(embeddings)

# Create FAISS index using L2 distance
index = faiss.IndexFlatL2(embedding_matrix.shape[1])  # L2 distance
index.add(embedding_matrix)  # Add the embeddings to the index

# Save the FAISS index for later use
faiss.write_index(index, "wikitext_index.faiss")

In [9]:
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
# retriever = RagRetriever.from_pretrained(
#     "facebook/rag-token-base", index_name="exact"
# )
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name=index,  # Custom index name
    # passages_path="path/to/your/custom_dataset",  # Optional: path to dataset (not needed if you have a FAISS index)
    index_path="wikitext_index.faiss",  # Path to the FAISS index file
)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in 

TypeError: Object of type IndexFlatL2 is not JSON serializable