## 1. Import Libraries

In [1]:
from pathlib import Path
from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!pip install accelerate



## 2. Build the SQuAD RAG System

In [3]:
class SQuADRAGSystem:
    """Retrieval-augmented extractive question answering over SQuAD contexts.

    On construction it loads the SQuAD ``train`` split, tokenizes it into QA
    fine-tuning features, embeds every unique context passage with a
    SentenceTransformer and stores the embeddings in a flat L2 FAISS index.
    ``generate_answer`` retrieves the closest passages for a query and runs the
    extractive QA model over them.
    """

    # Sliding-window size and overlap (in tokens) used to split long contexts,
    # shared by training-feature creation and inference so both see the same windows.
    MAX_LENGTH = 384
    DOC_STRIDE = 128

    def __init__(self, model_name='bert-base-uncased', trained_model_path=None):
        """
        Args:
            model_name: Hugging Face checkpoint for the tokenizer, and for the QA
                model when ``trained_model_path`` is not given.
            trained_model_path: optional directory holding a fine-tuned QA model
                to load instead of ``model_name``.
        """
        self.squad_dataset = self._load_squad_dataset()

        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(trained_model_path or model_name)

        # Tokenizer outputs plus start/end labels, used as the training set.
        self.processed_dataset = self._preprocess_squad_dataset()

        self.document_texts = self._extract_context_passages()
        self.document_embeddings = self._embed_documents()
        self.faiss_index = self._create_faiss_index()

    def _load_squad_dataset(self):
        """Load the SQuAD training split."""
        return load_dataset('squad', split='train')

    def _preprocess_squad_dataset(self, dataset=None):
        """Tokenize SQuAD examples into QA training features.

        Each (question, context) pair is split into overlapping windows of at most
        ``MAX_LENGTH`` tokens with ``DOC_STRIDE`` overlap; only the context is ever
        truncated. Each window is labelled with the token span of the first gold
        answer, or with the [CLS] position when the answer is not fully inside it.

        Args:
            dataset: a SQuAD-format dataset; defaults to ``self.squad_dataset``.

        Returns:
            The mapped dataset containing the tokenizer outputs plus
            ``start_positions`` / ``end_positions``; original columns are removed.
        """
        if dataset is None:
            dataset = self.squad_dataset

        def preprocess_function(examples):
            # Leading whitespace in some questions would only waste window space.
            questions = [question.lstrip() for question in examples['question']]
            answers = examples['answers']

            tokenized_examples = self.tokenizer(
                questions,
                examples['context'],
                truncation='only_second',  # never truncate the question, only the context
                max_length=self.MAX_LENGTH,
                stride=self.DOC_STRIDE,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding='max_length',
            )
            # One example may yield several windows; this maps each window to its example.
            sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
            offset_mapping = tokenized_examples.pop("offset_mapping")

            start_positions, end_positions = [], []
            for i, offsets in enumerate(offset_mapping):
                cls_index = tokenized_examples["input_ids"][i].index(self.tokenizer.cls_token_id)
                answer = answers[sample_mapping[i]]

                if len(answer['text']) == 0:
                    start, end = cls_index, cls_index
                else:
                    start_char = answer['answer_start'][0]
                    end_char = start_char + len(answer['text'][0])
                    start, end = self._char_span_to_token_span(
                        offsets, tokenized_examples.sequence_ids(i), start_char, end_char, cls_index
                    )
                start_positions.append(start)
                end_positions.append(end)

            tokenized_examples["start_positions"] = start_positions
            tokenized_examples["end_positions"] = end_positions
            return tokenized_examples

        return dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset.column_names,
        )

    @staticmethod
    def _char_span_to_token_span(offsets, sequence_ids, start_char, end_char, cls_index):
        """Map a character span of the context onto token indices of one window.

        Args:
            offsets: per-token ``(char_start, char_end)`` pairs for the window.
            sequence_ids: per-token sequence id; tokens with id 1 belong to the context.
            start_char: first character of the answer in the context.
            end_char: one past the last character of the answer.
            cls_index: label used when the answer is not fully inside the window.

        Returns:
            ``(start_token, end_token)`` (both inclusive), or
            ``(cls_index, cls_index)`` when the window does not contain the answer.
        """
        context_tokens = [i for i, sequence_id in enumerate(sequence_ids) if sequence_id == 1]
        if not context_tokens:
            return cls_index, cls_index
        context_start, context_end = context_tokens[0], context_tokens[-1]

        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            return cls_index, cls_index

        # Advance past every context token starting at or before start_char. Stop at the
        # end of the context so trailing special/padding tokens are never selected.
        token_start = context_start
        while token_start <= context_end and offsets[token_start][0] <= start_char:
            token_start += 1
        # Walk back over every context token ending at or after end_char.
        token_end = context_end
        while token_end >= context_start and offsets[token_end][1] >= end_char:
            token_end -= 1
        return token_start - 1, token_end + 1

    def _extract_context_passages(self) -> List[str]:
        """Return the unique SQuAD contexts in first-seen order.

        Order-preserving dedupe keeps FAISS row ids stable across sessions
        (``set`` ordering of strings varies with hash randomization).
        """
        return list(dict.fromkeys(self.squad_dataset['context']))

    def _embed_documents(self) -> np.ndarray:
        """Embed every passage; FAISS requires a C-contiguous float32 matrix."""
        return np.ascontiguousarray(self.embedding_model.encode(self.document_texts), dtype=np.float32)

    def _create_faiss_index(self):
        """Build an exact (brute-force) L2 FAISS index over the passage embeddings."""
        dimension = self.document_embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(self.document_embeddings)
        return index

    def retrieve_relevant_documents(self, query: str, top_k: int = 3) -> List[str]:
        """Return up to ``top_k`` passages closest to ``query``, nearest first."""
        k = min(top_k, self.faiss_index.ntotal)
        if k <= 0:
            return []
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype=np.float32)
        _, indices = self.faiss_index.search(query_embedding, k)
        # FAISS pads missing results with -1, which would otherwise silently index the last passage.
        return [self.document_texts[i] for i in indices[0] if i >= 0]

    def train_model(self, output_dir='./results', epochs=3, eval_split='validation'):
        """Fine-tune the QA model on the SQuAD training features.

        Evaluation runs every epoch on the held-out ``eval_split`` of SQuAD, not on
        the training data. The model and tokenizer are saved to ``output_dir``,
        so that directory can later be passed as ``trained_model_path``. The
        saved model is then reloaded into ``self.model``.
        """
        eval_dataset = self._preprocess_squad_dataset(load_dataset('squad', split=eval_split))

        training_args = TrainingArguments(
            output_dir=output_dir,
            # NOTE(review): newer transformers releases rename this to `eval_strategy`;
            # update the keyword if the installed version rejects it.
            evaluation_strategy='epoch',
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=64,
            num_train_epochs=epochs,
            weight_decay=0.01,
            push_to_hub=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.processed_dataset,
            eval_dataset=eval_dataset
        )

        trainer.train()

        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        self.model = AutoModelForQuestionAnswering.from_pretrained(output_dir)

    @staticmethod
    def _best_span(start_logits, end_logits, first, last, max_answer_length):
        """Find the highest-scoring answer span inside token range ``[first, last]``.

        A span ``(s, e)`` is valid when ``first <= s <= e <= last`` and
        ``e - s < max_answer_length``; its score is ``start_logits[s] + end_logits[e]``.

        Returns:
            ``(start_token, end_token, score)`` with absolute token indices.
        """
        n = last - first + 1
        scores = start_logits[first:last + 1, None] + end_logits[None, first:last + 1]
        positions = torch.arange(n, device=scores.device)
        span_length = positions[None, :] - positions[:, None]  # end offset minus start offset
        valid = (span_length >= 0) & (span_length < max_answer_length)
        scores = scores.masked_fill(~valid, float("-inf"))
        start_offset, end_offset = divmod(int(torch.argmax(scores)), n)
        return first + start_offset, first + end_offset, float(scores[start_offset, end_offset])

    def generate_answer(self, query: str, top_k: int = 3, max_answer_length: int = 30) -> Dict[str, object]:
        """Answer ``query`` from the ``top_k`` retrieved passages.

        The joined passages are split into overlapping windows so no part of the
        context is silently truncated. The best valid span (inside the context,
        end not before start, at most ``max_answer_length`` tokens) over all
        windows is returned as text cut from the original context, keeping its casing.

        Returns:
            dict with ``query``, ``answer`` (``''`` if none found), ``context`` (the
            joined passages) and ``confidence`` (float start+end logit sum of the
            chosen span, ``-inf`` if none).

        Raises:
            ValueError: if ``max_answer_length`` is less than 1.
        """
        if max_answer_length < 1:
            raise ValueError("max_answer_length must be >= 1")

        relevant_docs = self.retrieve_relevant_documents(query, top_k=top_k)
        context = " ".join(relevant_docs)
        no_answer = {'query': query, 'answer': '', 'context': context, 'confidence': float('-inf')}
        if not relevant_docs:
            return no_answer

        inputs = self.tokenizer(
            query,
            context,
            truncation="only_second",
            max_length=self.MAX_LENGTH,
            stride=self.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="longest",
            return_tensors="pt",
        )
        offset_mapping = inputs.pop("offset_mapping")
        inputs.pop("overflow_to_sample_mapping", None)
        model_inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        start_logits = outputs.start_logits.cpu()
        end_logits = outputs.end_logits.cpu()

        best = None  # (score, char_start, char_end)
        for window in range(start_logits.shape[0]):
            context_tokens = [i for i, sid in enumerate(inputs.sequence_ids(window)) if sid == 1]
            if not context_tokens:
                continue
            start, end, score = self._best_span(
                start_logits[window], end_logits[window],
                context_tokens[0], context_tokens[-1], max_answer_length,
            )
            if best is None or score > best[0]:
                best = (score, int(offset_mapping[window][start][0]), int(offset_mapping[window][end][1]))

        if best is None:
            return no_answer
        score, char_start, char_end = best
        return {
            'query': query,
            'answer': context[char_start:char_end],
            'context': context,
            'confidence': score
        }

In [4]:
# Construction is slow: it downloads SQuAD, tokenizes the full training split,
# embeds every unique context passage and builds the FAISS index.
rag_system = SQuADRAGSystem(model_name='bert-base-uncased', trained_model_path=None)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 87599/87599 [00:32<00:00, 2712.40 examples/s]


In [None]:
# Fine-tuning on all of SQuAD is expensive, so reuse a finished model if one was saved.
# A top-level config.json only exists once train_model has saved its final model
# (intermediate checkpoints go into subfolders). Delete the directory to force retraining.
TRAINED_MODEL_DIR = './trained_model'
if (Path(TRAINED_MODEL_DIR) / 'config.json').exists():
    rag_system.model = AutoModelForQuestionAnswering.from_pretrained(TRAINED_MODEL_DIR)
else:
    rag_system.train_model(output_dir=TRAINED_MODEL_DIR, epochs=3)

In [None]:
# Example questions used to exercise retrieval + answer extraction.
queries = [
    "Who wrote Harry Potter?",
    "What is the capital of France?",
    "When was the United States founded?",
]

In [None]:
# Show each answer with the first 200 characters of the retrieved context.
for query in queries:
    result = rag_system.generate_answer(query)
    print(
        f"\nQuery: {result['query']}\n"
        f"Answer: {result['answer']}\n"
        f"Context: {result['context'][:200]}...\n"
        f"Confidence: {result['confidence']:.2f}"
    )