<a href="https://colab.research.google.com/github/hieudeptrai123-sudo/AIO_RAGforJack/blob/main/RAG_Faiss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import Libraries


In [None]:
# Clone the project repository; the CSV files read later in this notebook live under /content/AIO_RAGforJack.
!git clone https://github.com/hieudeptrai123-sudo/AIO_RAGforJack.git
# Move LegalDataset.py up to /content so that `from LegalDataset import LegalDataset` can find it
# (presumably /content is the notebook's working directory on Colab -- confirm when running elsewhere).
!mv /content/AIO_RAGforJack/LegalDataset.py /content

In [None]:
# Install the third-party dependencies. No versions are pinned, so the installed
# releases depend on when this cell is run.
!pip install pyvi transformers faiss-gpu torch datasets unidecode gdown

In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm  # If this raises an IProgress error, import tqdm from `tqdm` instead of `tqdm.notebook`
from transformers import RagRetriever, RagTokenForGeneration, DPRContextEncoder, DPRQuestionEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, AdamW
import torch
from torch.utils.data import DataLoader
from LegalDataset import LegalDataset
import faiss
from datasets import Dataset
import time
import gdown

In [None]:
# Download the shared Google Drive folder; quiet=False keeps gdown's progress output visible.
gdown.download_folder(
    'https://drive.google.com/drive/folders/1XpqF_ejSmQQJ4IsO38hJDZgMWLZelJyW?usp=sharing',
    quiet=False,
)

## Training

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def free_memory():
    """Release PyTorch's cached CUDA memory; a no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

In [None]:
# Step 1: Read the corpus, training and public-test tables from the cloned repository.
# Columns this notebook actually reads: corpus -> 'text' and 'cid'; public test -> 'question' and 'qid'.
# The training table is assumed to hold 'question', 'context', 'cid' and 'qid'
# (it is consumed by LegalDataset, which is not defined in this notebook).
corpus_df = pd.read_csv('/content/AIO_RAGforJack/corpus.csv')
train_df = pd.read_csv('/content/AIO_RAGforJack/train.csv')
public_test_df = pd.read_csv('/content/AIO_RAGforJack/public_test.csv')

# Step 2: Pull the passage texts and their corpus IDs out as two plain Python lists
# (same row order in both, so position i in one matches position i in the other).
passages, passage_ids = (corpus_df[column].tolist() for column in ('text', 'cid'))

In [None]:
# Step 3: Load the pretrained DPR encoders (moved to `device`) and their tokenizers.
# Context side: used to embed corpus passages.
dpr_context_encoder = DPRContextEncoder.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base",
    ignore_mismatched_sizes=True,
).to(device)
dpr_context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)
# Question side: used to embed incoming queries.
dpr_question_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
).to(device)
dpr_question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)

In [None]:
# Step 4: Create a custom retriever using FAISS
class FaissRetriever(RagRetriever):
    """Look up passages in a FAISS index from DPR question embeddings.

    The class subclasses ``RagRetriever`` but does not call the parent
    constructor, so only the four attributes assigned in ``__init__`` exist on
    an instance.
    NOTE(review): confirm that any model receiving this retriever only calls
    ``retrieve``; inherited ``RagRetriever`` methods may rely on state that is
    never initialised here.

    Args:
        index: Object exposing ``search(vectors, k)`` (a FAISS index in this notebook).
        passages: Sequence of passage texts; position ``i`` must correspond to
            vector ``i`` in ``index``.
        tokenizer: Tokenizer used to decode incoming question token IDs back to text.
        top_k: Number of neighbours requested from the index per search.
    """

    def __init__(self, index, passages, tokenizer, top_k=5):
        self.index = index
        self.passages = passages
        self.tokenizer = tokenizer
        self.top_k = top_k

    def retrieve(self, question_input_ids, question_hidden_states=None, question_attention_mask=None):
        """Return the passages nearest to the first question of the batch.

        Args:
            question_input_ids: Batch of question token IDs, decoded with ``self.tokenizer``.
            question_hidden_states: Unused; kept (now optional) so existing
                positional calls keep working.
            question_attention_mask: Unused.

        Returns:
            list: At most ``top_k`` passages, in the order returned by the index
            search. Only the first row of the batch is used; results for any
            further questions in the batch are discarded.
        """
        # Tokenize and embed the question using DPRQuestionEncoder
        question_embeds = self._embed_question(question_input_ids)

        # Search FAISS index for top-k relevant documents (FAISS runs on CPU / NumPy)
        _, indices = self.index.search(question_embeds.cpu().numpy(), self.top_k)

        # FAISS fills missing neighbours with -1 when the index holds fewer than
        # top_k vectors; skip those so passages[-1] is not returned by accident.
        return [self.passages[i] for i in indices[0] if i >= 0]

    def _embed_question(self, question_input_ids):
        """Decode the token IDs to text and embed the text with the DPR question encoder.

        Relies on the module-level ``dpr_question_tokenizer``,
        ``dpr_question_encoder`` and ``device``. Returns the encoder's
        ``pooler_output`` (left on ``device``), computed without gradients.
        """
        question = self.tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
        inputs = dpr_question_tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)

        # Move the tokenized inputs to the same device as the encoder
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            question_embeddings = dpr_question_encoder(**inputs).pooler_output
        return question_embeddings

In [None]:
# Step 5: Embed all passages using DPRContextEncoder with memory management
def embed_passages(passages, batch_size=32):
    """Embed ``passages`` batch by batch with the module-level DPR context encoder.

    Args:
        passages: Sliceable sequence of passage texts.
        batch_size: Number of passages tokenized and encoded per step.

    Returns:
        A single CPU tensor holding the ``pooler_output`` rows of every batch,
        concatenated along dimension 0 in input order.
    """
    chunk_starts = range(0, len(passages), batch_size)
    cpu_chunks = []

    # tqdm wraps the start offsets so progress is shown per batch
    for start in tqdm(chunk_starts, desc="Embedding Passages"):
        texts = passages[start:start + batch_size]
        encoded = dpr_context_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)

        # Put every tokenizer output on the encoder's device
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Inference only: no gradients are needed
        with torch.no_grad():
            pooled = dpr_context_encoder(**encoded).pooler_output

        # Keep the result on the CPU, then release cached CUDA memory
        cpu_chunks.append(pooled.cpu())
        free_memory()

    # Stitch the per-batch results into one tensor
    return torch.cat(cpu_chunks, dim=0)

In [None]:
# Get embeddings for all passages (embed_passages is called with its default batch size)
passage_embeddings = embed_passages(passages)
# Release cached CUDA memory now that the embedding pass is finished
free_memory()

In [None]:
# Step 6: Build a FAISS index over the passage embeddings
# Width of one embedding vector (size of the tensor's second dimension)
d = passage_embeddings.size(1)

# Flat inner-product (dot product) index, kept on the CPU
index = faiss.IndexFlatIP(d)

# FAISS consumes NumPy arrays, so hand over the CPU copy of the embeddings
index.add(passage_embeddings.cpu().numpy())

# Sanity check: report how many vectors the index now holds
print(f"Total passages indexed: {index.ntotal}")

# Release cached CUDA memory
free_memory()

In [None]:
def _check_batch_shapes(input_ids, attention_mask, context_input_ids, context_attention_mask):
    """Raise AssertionError unless all four tensors are 2-D and share ``input_ids``' shape."""
    if input_ids.dim() != 2:
        raise AssertionError("Input IDs should be 2-dimensional")
    expected_shape = tuple(input_ids.shape)  # (batch_size, seq_length)
    others = (
        ("Attention mask", "attention_mask", attention_mask),
        ("Context input IDs", "context_input_ids", context_input_ids),
        ("Context attention mask", "context_attention_mask", context_attention_mask),
    )
    for label, name, tensor in others:
        if tensor.dim() != 2:
            raise AssertionError(f"{label} should be 2-dimensional")
        if tuple(tensor.shape) != expected_shape:
            raise AssertionError(f"Expected {name} shape {expected_shape}, but got {tensor.shape}")


def train_model(model, train_loader, optimizer, device, num_epochs, verbose=False):
    """Run ``num_epochs`` passes over ``train_loader``, updating ``model`` with ``optimizer``.

    The caller is responsible for putting the model in training mode. A batch
    whose shapes are inconsistent, or whose forward/backward pass raises, is
    reported and skipped instead of aborting the run; the number of skipped
    batches is printed at the end of each epoch so failures cannot go unnoticed.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask``,
            ``context_input_ids``, ``context_attention_mask`` and ``labels``
            keyword arguments and returning an object with a ``loss`` attribute.
        train_loader: Iterable of dict batches holding the four tensors above.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        device: Device every batch tensor is moved to.
        num_epochs: Number of passes over ``train_loader``.
        verbose: When True, print the tensor shapes of every batch (debug aid).

    Returns:
        list[float]: Mean loss of the successfully processed batches for each
        epoch; ``nan`` for an epoch in which no batch succeeded.
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        loss_sum = 0.0
        steps_done = 0
        steps_skipped = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            context_input_ids = batch['context_input_ids'].to(device)
            context_attention_mask = batch['context_attention_mask'].to(device)

            if verbose:
                print("Input IDs shape:", input_ids.shape)
                print("Attention mask shape:", attention_mask.shape)
                print("Context input IDs shape:", context_input_ids.shape)
                print("Context attention mask shape:", context_attention_mask.shape)

            try:
                _check_batch_shapes(input_ids, attention_mask, context_input_ids, context_attention_mask)

                # Forward pass
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    context_input_ids=context_input_ids,
                    context_attention_mask=context_attention_mask,
                    labels=context_input_ids
                )

                loss = outputs.loss
                # Some models (e.g. RAG without reduce_loss) return one loss per
                # example; backward() and item() both need a scalar.
                if loss.dim() > 0:
                    loss = loss.mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_sum += loss.item()
                steps_done += 1

            except AssertionError as e:
                steps_skipped += 1
                print(f"Shape Assertion Error: {e}")
            except Exception as e:
                steps_skipped += 1
                print(f"Error during forward pass: {e}")

        # Average over the batches that actually contributed a loss; dividing by
        # len(train_loader) would under-report it and fail on an empty loader.
        mean_loss = loss_sum / steps_done if steps_done else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1} completed. Loss: {mean_loss}")
        if steps_skipped:
            print(f"Warning: {steps_skipped} of {steps_done + steps_skipped} batches were skipped in epoch {epoch + 1}")

    return epoch_losses

In [None]:
# Full training script with updated function

if __name__ == '__main__':
    # Wrap the training table (train_df, loaded earlier) in the project's dataset class
    train_dataset = LegalDataset(
        df=train_df,
        tokenizer_question=dpr_question_tokenizer,
        tokenizer_context=dpr_context_tokenizer,
        max_length=128  # Maximum token length for truncation
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4
    )

    # Retriever backed by the FAISS index built above, plugged into the pretrained RAG model
    retriever = FaissRetriever(index=index, passages=passages, tokenizer=dpr_question_tokenizer, top_k=5)
    model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(device)
    # Actually call the helper: the bare name `free_memory` was a no-op expression
    free_memory()
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # train_model does not switch modes itself, so enable training mode here
    model.train()
    # Train the model
    train_model(model, train_loader, optimizer, device, num_epochs=50)

    # Save the fine-tuned model
    model.save_pretrained('fine_tuned_rag_model')
    print("Fine-tuning completed and model saved.")

## Prediction

In [None]:
# Step 7: Load the fine-tuned model
print("Loading the fine-tuned model...")
retriever = FaissRetriever(index=index, passages=passages, tokenizer=dpr_question_tokenizer, top_k=10)
fine_tuned_model = RagTokenForGeneration.from_pretrained('fine_tuned_rag_model', retriever=retriever).to(device)
print("Fine-tuned model loaded successfully.")
fine_tuned_model.eval()
# NOTE(review): the loop below only calls retriever.retrieve(); fine_tuned_model
# itself is not used to produce the predictions -- confirm this is intended.

# Build the passage-text -> cid lookup once instead of scanning the whole corpus
# frame for every retrieved passage. setdefault keeps the first cid when the same
# text occurs more than once, matching the previous `.values[0]` behaviour.
cid_by_text = {}
for text, cid in zip(corpus_df['text'], corpus_df['cid']):
    cid_by_text.setdefault(text, cid)

# Step 8: Predict top-k documents for each query in public_test.csv with progress bar
top_k_predictions = []
for idx, row in tqdm(public_test_df.iterrows(), total=len(public_test_df), desc="Processing Queries"):
    question = row['question']
    qid = row['qid']

    # Tokenize question
    input_ids = dpr_question_tokenizer(question, return_tensors="pt", max_length=128, truncation=True, padding=True).input_ids

    # Move inputs to the selected device (CUDA when available, otherwise CPU)
    input_ids = input_ids.to(device)

    # Retrieve the top-k passages for this question
    retrieved_passages = retriever.retrieve(input_ids, None)

    # Map each retrieved passage back to its passage ID (cid)
    retrieved_cids = [cid_by_text[passage] for passage in retrieved_passages]

    # Append result (qid followed by top-k cids)
    top_k_predictions.append(f"{qid} " + " ".join(map(str, retrieved_cids)))

    # Release cached CUDA memory after each query
    free_memory()

# Step 9: Save predictions to predict.txt, one query per line
with open('predict.txt', 'w') as f:
    f.writelines(prediction + '\n' for prediction in top_k_predictions)

print("Predictions saved to predict.txt.")