In [1]:
!pip install transformers faiss-cpu torch

In [2]:
import torch
import numpy as np
import faiss

from transformers import (
    DPRContextEncoder, DPRContextEncoderTokenizer,
    DPRQuestionEncoder, DPRQuestionEncoderTokenizer,
    AutoTokenizer, AutoModelForCausalLM
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------------------------------------------------------------
# STEP 1: Prepare a small corpus of documents
# -----------------------------------------------------------------------------------
documents = [
    "The company offers 15 days of paid vacation per year.",
    "Employees should submit reimbursement forms within 30 days.",
    "Mobile phones must be secured with a company-approved password.",
    "Remote work is allowed up to 3 days per week.",
    "Drinking alcohol during work hours is strictly prohibited."
]


In [3]:

# -----------------------------------------------------------------------------------
# STEP 2: Load DPR Context Encoder and Tokenizer
# These will convert documents into dense embeddings
# -----------------------------------------------------------------------------------
context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# Tokenize and encode each document
context_embeddings = []
for doc in documents:
    # DPR encoders are BERT-based, so cap inputs at 512 tokens for truncation to take effect
    inputs = context_tokenizer(doc, return_tensors='pt', truncation=True, max_length=512).to(device)
    with torch.no_grad():
        embedding = context_encoder(**inputs).pooler_output  # (1, 768)
    context_embeddings.append(embedding.cpu().numpy())

# Stack into numpy array for FAISS
context_embeddings_np = np.vstack(context_embeddings).astype('float32')  # Shape: (5, 768)


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [4]:

# -----------------------------------------------------------------------------------
# STEP 3: Create FAISS index
# This allows us to retrieve similar documents by vector similarity
# -----------------------------------------------------------------------------------
embedding_dim = context_embeddings_np.shape[1]
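# IndexFlatL2 runs an exact (brute-force) search and ranks by squared Euclidean distance.
# Note: DPR encoders were trained with dot-product similarity, which FAISS exposes as IndexFlatIP.
index = faiss.IndexFlatL2(embedding_dim)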
index.add(context_embeddings_np)  # Add document embeddings to index
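
The index above ranks by Euclidean distance, while DPR's encoders were trained so that relevance corresponds to a larger dot product. The two orderings can disagree when vector norms differ. A minimal sketch with made-up 2-D vectors (not real DPR embeddings; the helper names are hypothetical) illustrates the effect:

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def dot(a, b):
    """Dot (inner) product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def rank_by_l2(query, docs):
    """Document indices, nearest first (smallest squared L2 distance)."""
    return sorted(range(len(docs)), key=lambda i: squared_l2(query, docs[i]))

def rank_by_inner_product(query, docs):
    """Document indices, best first (largest dot product)."""
    return sorted(range(len(docs)), key=lambda i: -dot(query, docs[i]))

# Toy vectors chosen by hand for illustration only.
toy_docs = [
    [1.0, 0.0],  # identical to the query, small norm
    [3.0, 1.0],  # farther away, but larger dot product with the query
]
toy_query = [1.0, 0.0]

l2_order = rank_by_l2(toy_query, toy_docs)
ip_order = rank_by_inner_product(toy_query, toy_docs)
print("L2 order:", l2_order)  # L2 order: [0, 1]
print("IP order:", ip_order)  # IP order: [1, 0]
```

In FAISS, inner-product search is available through `faiss.IndexFlatIP`. Whether switching changes the ranking for this corpus would have to be checked by rerunning the retrieval cell.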


In [5]:

# -----------------------------------------------------------------------------------
# STEP 4: Load DPR Question Encoder
# This will embed the user query in the same vector space
# -----------------------------------------------------------------------------------
question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:

# -----------------------------------------------------------------------------------
# STEP 5: Define a query and retrieve top-k relevant documents
# -----------------------------------------------------------------------------------
query = "What is the mobile phone policy?"

# Tokenize and encode the question
inputs = question_tokenizer(query, return_tensors="pt").to(device)
with torch.no_grad():
    query_embedding = question_encoder(**inputs).pooler_output.cpu().numpy()

# Search FAISS for the 2 nearest documents.
# D holds the squared L2 distances, I holds the matching row indices into `documents`.
D, I = index.search(query_embedding, k=2)

print("Top matching documents:")
for idx in I[0]:
    print("-", documents[idx])


Top matching documents:
- Mobile phones must be secured with a company-approved password.
- The company offers 15 days of paid vacation per year.
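
The second hit above is unrelated to the query: with `k=2`, the search always returns two documents, however far away the second one is. One possible mitigation is to drop hits whose distance exceeds a cutoff. The sketch below assumes distances and indices shaped like one row of `D` and `I`; the helper name `filter_hits`, the toy distances, and the threshold value are all hypothetical and would need tuning on real data.

```python
def filter_hits(distances, indices, documents, max_distance):
    """Keep (document, distance) pairs whose distance is at most max_distance."""
    kept = []
    for dist, idx in zip(distances, indices):
        if idx < 0:
            # FAISS fills missing results with index -1 when fewer than k exist
            continue
        if dist <= max_distance:
            kept.append((documents[idx], dist))
    return kept

# Made-up values for illustration; real distances come from index.search().
toy_documents = ["phone policy text", "vacation policy text"]
toy_distances = [10.0, 60.0]
toy_indices = [0, 1]

hits = filter_hits(toy_distances, toy_indices, toy_documents, max_distance=30.0)
for doc, dist in hits:
    print(f"- {doc} (distance {dist})")
```

With the notebook's variables, the call would look like `filter_hits(D[0], I[0], documents, max_distance=...)`, with the cutoff chosen by inspecting `D` for a few known-good queries.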


In [7]:

# -----------------------------------------------------------------------------------
# STEP 6: Load GPT-2 for generation
# -----------------------------------------------------------------------------------
gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
gpt_model.eval()

# Set special token to avoid warnings
gpt_model.generation_config.pad_token_id = gpt_tokenizer.eos_token_id


In [8]:

# -----------------------------------------------------------------------------------
# STEP 7a: Generate answer WITHOUT context
# -----------------------------------------------------------------------------------
def generate_without_context(query):
    inputs = gpt_tokenizer(query, return_tensors="pt").to(device)
    # Unpack the full encoding so generate() receives attention_mask along with input_ids
    output = gpt_model.generate(**inputs, max_new_tokens=50)
    return gpt_tokenizer.decode(output[0], skip_special_tokens=True)

print("\nAnswer without context:")
print(generate_without_context(query))



Answer without context:
What is the mobile phone policy?

The mobile phone policy is a policy that allows you to use your mobile phone for any purpose. It is a policy that allows you to use your mobile phone for any purpose. It is a policy that allows you to use your mobile phone for


In [9]:

# -----------------------------------------------------------------------------------
# STEP 7b: Generate answer WITH top-k context
# -----------------------------------------------------------------------------------
def generate_with_context(query, retrieved_docs):
    full_input = query + " " + " ".join(retrieved_docs)
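    # GPT-2's context window is 1024 tokens, so reserve room for the 50 generated tokens
    inputs = gpt_tokenizer(full_input, return_tensors="pt", truncation=True, max_length=1024 - 50).to(device)
    # Unpack the full encoding so generate() receives attention_mask along with input_ids
    output = gpt_model.generate(**inputs, max_new_tokens=50)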
    return gpt_tokenizer.decode(output[0], skip_special_tokens=True)

# Use the top 2 FAISS matches
retrieved = [documents[i] for i in I[0]]

print("\nAnswer with retrieved context:")
print(generate_with_context(query, retrieved))



Answer with retrieved context:
What is the mobile phone policy? Mobile phones must be secured with a company-approved password. The company offers 15 days of paid vacation per year.

What is the mobile phone policy? Mobile phones must be secured with a company-approved password. The company offers 15 days of paid vacation per year. What is the mobile phone policy? Mobile phones must be secured with a company-approved password
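
In the output above, GPT-2 mostly echoes the concatenated input. Part of the reason may be that the prompt gives no cue where the context ends and the answer should begin. A common alternative is a labeled template that puts the retrieved passages first and ends with an answer cue. The sketch below is one such template; the function name `build_prompt` and the exact labels are assumptions, and because base GPT-2 is not instruction-tuned, a better prompt alone may not produce a good answer.

```python
def build_prompt(query, retrieved_docs):
    """Format retrieved passages and the question into one labeled prompt."""
    context = "\n".join(f"- {doc}" for doc in retrieved_docs)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

example_docs = [
    "Mobile phones must be secured with a company-approved password.",
    "The company offers 15 days of paid vacation per year.",
]
prompt = build_prompt("What is the mobile phone policy?", example_docs)
print(prompt)
```

To try it in the notebook, the result of `build_prompt(query, retrieved_docs)` could replace `full_input` inside `generate_with_context`.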
