# A simple RAG application using open-source models

In [1]:
from langchain.document_loaders.pdf import PyPDFDirectoryLoader

def load_documents(directory: str = "Documents"):
    """Load every PDF found in `directory` as LangChain Documents.

    Args:
        directory: folder scanned by PyPDFDirectoryLoader (defaults to the
            notebook's original "Documents" folder).

    Returns:
        The list of Documents produced by the loader (one per PDF page).
    """
    return PyPDFDirectoryLoader(directory).load()

documents = load_documents()
print(len(documents))  # number of pages loaded

Ignoring wrong pointing object 13 0 (offset 0)
Ignoring wrong pointing object 48 0 (offset 0)
Ignoring wrong pointing object 74 0 (offset 0)
Ignoring wrong pointing object 6 0 (offset 0)
Ignoring wrong pointing object 9 0 (offset 0)
Ignoring wrong pointing object 11 0 (offset 0)
Ignoring wrong pointing object 34 0 (offset 0)
Ignoring wrong pointing object 37 0 (offset 0)
Ignoring wrong pointing object 95 0 (offset 0)
Ignoring wrong pointing object 17 0 (offset 0)
Ignoring wrong pointing object 6 0 (offset 0)
Ignoring wrong pointing object 9 0 (offset 0)
Ignoring wrong pointing object 13 0 (offset 0)
Ignoring wrong pointing object 92 0 (offset 0)
Ignoring wrong pointing object 98 0 (offset 0)
Ignoring wrong pointing object 112 0 (offset 0)
Ignoring wrong pointing object 135 0 (offset 0)
Ignoring wrong pointing object 145 0 (offset 0)
Ignoring wrong pointing object 148 0 (offset 0)
Ignoring wrong pointing object 179 0 (offset 0)
Ignoring wrong pointing object 212 0 (offset 0)
Ignoring wr

355


In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema.document import Document

def split_documents(documents: list[Document], chunk_size: int = 400, chunk_overlap: int = 40):
    """Split documents into overlapping character-based chunks for embedding.

    Args:
        documents: Documents to split (metadata such as source/page is carried
            over to every chunk by the splitter).
        chunk_size: maximum chunk length, measured with `len` (characters).
        chunk_overlap: number of characters shared by consecutive chunks.

    Returns:
        The list of chunk Documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    return text_splitter.split_documents(documents)

chunks = split_documents(documents)
print(chunks[0])
print(len(chunks))


page_content='Applied Thermal Engineering 228 (2023) 120454' metadata={'source': 'Documents/1-s2.0-S1359431123004830-main.pdf', 'page': 0}
4105


In [3]:
from langchain_community.embeddings.ollama import OllamaEmbeddings

def get_embedding_function():
    """Return the Ollama-served `nomic-embed-text` embedding model.

    Used both when indexing chunks and when embedding queries, so the two
    always live in the same vector space.
    """
    return OllamaEmbeddings(model='nomic-embed-text')

def calculate_chunk_ids(chunks):
    """Tag every chunk with an ID of the form "<source>:<page>:<chunk index>".

    The chunk index counts consecutive chunks that share the same
    (source, page) pair and restarts at 0 whenever that pair changes from the
    previous chunk. IDs are therefore unique only when all chunks of a page are
    contiguous in `chunks`.

    The ID is written into ``chunk.metadata["id"]`` in place; the same list
    object is returned for convenience.
    """
    previous_page_key = None
    index_on_page = 0

    for chunk in chunks:
        metadata = chunk.metadata
        page_key = f'{metadata.get("source")}:{metadata.get("page")}'

        index_on_page = index_on_page + 1 if page_key == previous_page_key else 0

        metadata["id"] = f"{page_key}:{index_on_page}"
        previous_page_key = page_key

    return chunks


In [4]:
from langchain.vectorstores.chroma import Chroma

def add_to_chroma(chunks: list[Document]):
    """Insert into the persistent "chroma" store every chunk it does not hold yet.

    Chunks are first given IDs by `calculate_chunk_ids`. Any chunk whose ID is
    already present in the store is skipped, so re-running this is cheap once
    the corpus has been indexed.
    """
    db = Chroma(persist_directory="chroma", embedding_function=get_embedding_function())
    chunks_with_ids = calculate_chunk_ids(chunks)

    # include=[] skips documents/embeddings; the IDs are always returned.
    existing_ids = set(db.get(include=[])["ids"])
    print(f"Number of existing documents in DB: {len(existing_ids)}")

    new_chunks = [c for c in chunks_with_ids if c.metadata["id"] not in existing_ids]
    if not new_chunks:
        print("✅ No new documents to add")
        return

    print(f"👉 Adding new documents: {len(new_chunks)}")
    db.add_documents(new_chunks, ids=[c.metadata["id"] for c in new_chunks])
    db.persist()

In [5]:
from transformers import AutoTokenizer, GPT2LMHeadModel
import torch

# Prefer Apple's Metal backend, then the first CUDA GPU, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

Using device: mps


In [6]:
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
import gpt_wrapper
from gpt_wrapper.chat import Chat
import os
from getpass import getpass

# Never commit secrets to a notebook: read the backend URL and API key from the
# environment, prompting interactively for the key if it is not set. The key
# that used to be hardcoded here was stored in plain text and should be
# treated as compromised (revoke/rotate it).
gpt_wrapper.api_base = os.environ.get(
    "GPT_WRAPPER_API_BASE",
    "http://mnlp-backend-938795011.eu-central-1.elb.amazonaws.com",
)
gpt_wrapper.api_key = os.environ.get("GPT_WRAPPER_API_KEY") or getpass("gpt_wrapper API key: ")

def query_rag(query_text: str, PROMPT_TEMPLATE, k: int = 1):
    """Answer `query_text` with retrieval-augmented generation.

    Retrieves the `k` most similar chunks from the persistent "chroma" store,
    fills PROMPT_TEMPLATE (an f-string style template with `{context}` and
    `{question}` placeholders), sends the prompt to a fresh gpt_wrapper chat,
    prints the prompt, the answer and the source chunk IDs, and returns the
    answer text.

    Args:
        query_text: the user question.
        PROMPT_TEMPLATE: prompt template string.
        k: number of chunks to retrieve (default 1, as before).

    Returns:
        The model's answer as a string.
    """
    embedding_function = get_embedding_function()
    db = Chroma(persist_directory="chroma", embedding_function=embedding_function)

    results = db.similarity_search_with_score(query_text, k=k)
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)

    # Take the content of the template's single human message instead of
    # splitting the rendered "Human: ..." transcript. The split raised
    # IndexError for templates without a leading newline and truncated the
    # prompt whenever the retrieved context itself contained "Human: \n".
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    messages = prompt_template.format_messages(context=context_text, question=query_text)
    prompt = messages[0].content.strip()

    answer_chat = Chat.create("answer_chat")
    response_text = str(answer_chat.ask(content=prompt))

    sources = [doc.metadata.get("id", None) for doc, _score in results]
    formatted_response = f"Response: {response_text}\nSources: {sources}"
    print(prompt)
    print("\n--------------------------\n")
    print(formatted_response)
    return response_text

In [7]:
# Prompt and question for the end-to-end RAG demo.
PROMPT_TEMPLATE = """
Based on the following context: {context} \nAnswer this question: {question}
"""
query_text = "How do heat pumps work in cold climates?"

# Open the persistent vector store and make sure every chunk is indexed.
embedding_function = get_embedding_function()
db = Chroma(embedding_function=embedding_function, persist_directory="chroma")
add_to_chroma(chunks)

query_rag(query_text, PROMPT_TEMPLATE)

Number of existing documents in DB: 4105
✅ No new documents to add
Based on the following context: and presents best practices for the design, performance assessment, and optimisation of
heat pumps for simultaneous heating and cooling.
An introductory article presents the uses of heat pump productions in the form of
an analysis of the thermal demands of different types of buildings and a literature review 
Answer this question: How do heat pumps work in cold climates?


--------------------------

Response: In cold climates, heat pumps work by extracting heat from the outside air, even in temperatures below freezing, and transferring it into the building to provide heating. They are able to operate efficiently in cold temperatures by utilizing advanced technology and refrigerants that are designed to work effectively in low temperatures. Additionally, some heat pumps have a defrost cycle that helps prevent the system from freezing up during cold weather conditions.
Sources: ['Documents



'In cold climates, heat pumps work by extracting heat from the outside air, even in temperatures below freezing, and transferring it into the building to provide heating. They are able to operate efficiently in cold temperatures by utilizing advanced technology and refrigerants that are designed to work effectively in low temperatures. Additionally, some heat pumps have a defrost cycle that helps prevent the system from freezing up during cold weather conditions.'

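# RL-optimized RAG

A BERT-based policy network learns, for each simulated user turn, whether to retrieve context from the vector store (`FETCH`) or answer directly (`NO_FETCH`). It is trained with REINFORCE using the rewards defined below.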

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gpt_wrapper
from gpt_wrapper.chat import Chat
from transformers import BertModel, BertTokenizer, BertConfig

import os
from getpass import getpass

# Prefer Apple's Metal backend, then the first CUDA GPU, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

# Never commit secrets to a notebook: read the backend URL and API key from the
# environment, prompting interactively for the key if it is not set. The key
# that used to be hardcoded here was stored in plain text and should be
# treated as compromised (revoke/rotate it).
gpt_wrapper.api_base = os.environ.get(
    "GPT_WRAPPER_API_BASE",
    "http://mnlp-backend-938795011.eu-central-1.elb.amazonaws.com",
)
gpt_wrapper.api_key = os.environ.get("GPT_WRAPPER_API_KEY") or getpass("gpt_wrapper API key: ")

# Define the policy network
class BertForPolicyNetwork(nn.Module):
    """BERT encoder plus a linear head that outputs a probability distribution over actions.

    Note: ``BertModel(config)`` builds the encoder from the configuration only,
    so its weights are randomly initialised rather than loaded from a
    pretrained checkpoint.

    Args:
        config: a BertConfig describing the encoder.
        num_actions: size of the action space (default 2; in
            train_policy_network index 0 = FETCH, index 1 = NO_FETCH).
    """

    def __init__(self, config, num_actions: int = 2):
        super().__init__()
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, num_actions)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the tokenized state and return softmax action probabilities.

        Returns a tensor of shape (batch, num_actions) whose rows sum to 1.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = self.classifier(outputs.pooler_output)
        return torch.nn.functional.softmax(logits, dim=-1)

def generate_query(state, chat, type_of_question):
    """Ask the question-generator chat for the next simulated user question.

    Args:
        state: rolling conversation transcript given to the generator as context.
        chat: a chat session; only its ``ask(content=..., instruction=...)``
            method is used.
        type_of_question: "new" (a fresh heat-pump question), "follow_up"
            (a follow-up to the previous question); any other value is treated
            as "unrelated" (a question unrelated to heat pumps).

    Returns:
        The generated question, converted to ``str``.
    """
    prompt = "Generate a question based on this previous conversation: " + state + "."

    instructions = {
        "new": " The question should be a new question about heat pumps.",
        # Each example is now closed with its own quote so the model sees
        # distinct examples instead of one run-on quotation.
        "follow_up": " The question should be a follow-up to the previous one. For example: 'Are you sure?' or 'Tell me more about that.' or similar.",
    }
    instruction = instructions.get(
        type_of_question,
        " The question should be a question completely unrelated to heat pumps.",
    )

    return str(chat.ask(content=prompt, instruction=instruction))

# Reward function
def get_reward(action, type_of_question):
    """Return the scalar reward for taking `action` on a question of type `type_of_question`.

    Args:
        action: "FETCH" (retrieve context from the vector store) or "NO_FETCH".
        type_of_question: "new", "follow_up" or "unrelated".

    Returns:
        The reward (float or int) from the table below.

    Raises:
        ValueError: for an unknown action or question type. Unknown inputs used
            to return None (which breaks ``torch.tensor`` on the reward list) or
            silently fall into the FETCH penalty branch.
    """
    rewards = {
        ("FETCH", "new"): 0.1,
        ("FETCH", "follow_up"): 0.1,
        ("FETCH", "unrelated"): -1,
        # Skipping retrieval on a new question is penalised hardest.
        ("NO_FETCH", "new"): -2,
        ("NO_FETCH", "follow_up"): 1,
        ("NO_FETCH", "unrelated"): 2,
    }
    try:
        return rewards[(action, type_of_question)]
    except KeyError:
        raise ValueError(
            f"unknown action/question type: {action!r}, {type_of_question!r}"
        ) from None
        
def update_accuracy(action, type_of_question, accuracy):
    """Return `accuracy` incremented by 1 if `action` was the correct choice.

    The correct action is FETCH for "new" questions and NO_FETCH for
    "follow_up" and "unrelated" questions. Any other valid combination leaves
    the count unchanged.

    Args:
        action: "FETCH" or "NO_FETCH".
        type_of_question: "new", "follow_up" or "unrelated".
        accuracy: running count of correct actions.

    Raises:
        ValueError: for an unknown action or question type. Some unknown
            combinations used to return None silently.
    """
    if action not in ("FETCH", "NO_FETCH") or type_of_question not in ("new", "follow_up", "unrelated"):
        raise ValueError(f"unknown action/question type: {action!r}, {type_of_question!r}")

    correct_pairs = {("FETCH", "new"), ("NO_FETCH", "follow_up"), ("NO_FETCH", "unrelated")}
    return accuracy + 1 if (action, type_of_question) in correct_pairs else accuracy

Using device: mps


In [11]:
# Training the policy network
from tqdm import tqdm
# train_policy_network retrieves through this module-level `db`, so it must
# exist (and be fully indexed) before training starts.
embedding_function = get_embedding_function()
db = Chroma(embedding_function=embedding_function, persist_directory="chroma")
add_to_chroma(chunks)

def train_policy_network(policy_network, tokenizer, optimizer, num_of_questions=20,
                         num_epochs: int = 15, gamma: float = 0.1):
    """Train the FETCH / NO_FETCH policy with REINFORCE on simulated conversations.

    Each epoch simulates one conversation of `num_of_questions` turns. On every turn:
    - a question type is drawn ("new" on the first turn, then
      new/follow_up/unrelated with probabilities 0.6/0.2/0.2);
    - the question-generator chat writes a question;
    - the policy samples an action from the rolling conversation state;
    - the answer chat replies, with retrieved context when the action is FETCH;
    - the reward from `get_reward` is recorded.
    After the episode the policy is updated once with the discounted-return
    REINFORCE loss.

    Relies on the notebook globals `db`, `device`, `Chat`, `generate_query`,
    `get_reward`, `update_accuracy`, `np` and `tqdm`.

    Args:
        policy_network: module returning action probabilities for tokenized
            input (index 0 = FETCH, index 1 = NO_FETCH).
        tokenizer: tokenizer matching the policy's encoder.
        optimizer: optimizer over `policy_network`'s parameters.
        num_of_questions: turns per simulated episode.
        num_epochs: number of episodes (= policy updates); default 15 as before.
        gamma: discount factor for future rewards; default 0.1 as before.

    Returns:
        (accuracies, total_rewards): per-epoch fraction of turns where the
        sampled action was correct according to `update_accuracy`, and
        per-epoch summed reward.

    Raises:
        ValueError: if `num_of_questions` is not positive.
    """
    if num_of_questions <= 0:
        raise ValueError("num_of_questions must be positive")

    accuracies = []
    total_rewards = []
    policy_network.train()  # the policy is being optimised, keep it in training mode

    for _epoch in range(num_epochs):
        rewards = []
        log_probs = []
        state = ""
        correct_actions = 0  # was never initialised before (NameError on a fresh kernel)

        for i in tqdm(range(num_of_questions)):
            question_generator_chat = Chat.create("question_generator_chat")
            if i == 0:
                type_of_question = "new"
            else:
                type_of_question = np.random.choice(["new", "follow_up", "unrelated"], p=[0.6, 0.2, 0.2])

            query = generate_query(state, question_generator_chat, type_of_question)
            state += "\nQ: " + query
            # Keep only the two most recent question turns (with their context
            # and answers). `state` always starts with "\nQ: ", so searching
            # from index 1 finds the start of the second turn.
            if state.count("\nQ: ") > 2:
                state = state[state.index("\nQ: ", 1):]

            # Truncate to tokenizer.model_max_length so long transcripts cannot
            # overflow the encoder's position embeddings. NOTE: truncation
            # follows tokenizer.truncation_side ('right' by default, which cuts
            # the newest text); set it to 'left' to always keep the latest turn.
            state_tokens = tokenizer(state, truncation=True, return_tensors="pt").to(device)

            probs = policy_network(**state_tokens)
            # Sample the action from the policy's categorical distribution.
            m = torch.distributions.Categorical(probs)
            action_index = m.sample()

            action = 'FETCH' if action_index.item() == 0 else 'NO_FETCH'
            correct_actions = update_accuracy(action, type_of_question, correct_actions)

            # Create the answer chat before branching: the FETCH branch used
            # `answer_chat` before it had been created.
            answer_chat = Chat.create("answer_chat")
            if action == 'FETCH':
                results = db.similarity_search_with_score(query, k=1)
                context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
                state += "\nContext: " + context_text
                answer = str(answer_chat.ask(content="Based on this context: " + context_text + " Answer this question: " + query))
            else:
                answer = str(answer_chat.ask(content=query))

            state += "\nAnswer: " + answer
            rewards.append(get_reward(action, type_of_question))
            log_probs.append(m.log_prob(action_index))

            del state_tokens, probs, m, action_index
            torch.cuda.empty_cache()

        accuracy = correct_actions / num_of_questions
        accuracies.append(accuracy)
        print(accuracy)
        episode_reward = sum(rewards)
        total_rewards.append(episode_reward)
        print(episode_reward)

        # Discounted returns G_t = r_t + gamma * G_{t+1}, computed back to front.
        discounted_rewards = []
        running_return = 0.0
        for reward in reversed(rewards):
            running_return = reward + gamma * running_return
            discounted_rewards.append(running_return)
        discounted_rewards.reverse()
        discounted_rewards = torch.tensor(discounted_rewards, device=device)

        # REINFORCE: minimise -sum_t log pi(a_t | s_t) * G_t.
        policy_loss = torch.stack(
            [-log_prob * ret for log_prob, ret in zip(log_probs, discounted_rewards)]
        ).sum()
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        del rewards, log_probs, discounted_rewards, policy_loss
        torch.cuda.empty_cache()

    return accuracies, total_rewards

Number of existing documents in DB: 4105
✅ No new documents to add


In [12]:
# Hyper-parameters for the policy network.
POLICY_BACKBONE = 'bert-large-uncased'
MAX_POSITIONS = 500
LEARNING_RATE = 5e-7

tokenizer = BertTokenizer.from_pretrained(POLICY_BACKBONE)
tokenizer.model_max_length = MAX_POSITIONS

# Reuse the backbone's architecture but shrink the position table to match the
# tokenizer limit (the encoder built from this config starts from random weights).
config = BertConfig.from_pretrained(POLICY_BACKBONE)
config.max_position_embeddings = MAX_POSITIONS

policy_network = BertForPolicyNetwork(config).to(device)
optimizer = optim.Adam(policy_network.parameters(), lr=LEARNING_RATE)



In [13]:
train_policy_network(policy_network, tokenizer, optimizer, num_of_questions=20)

100%|██████████| 20/20 [00:26<00:00,  1.31s/it]


0.6
-5.3000000000000025


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


0.7
-0.20000000000000015


100%|██████████| 20/20 [00:22<00:00,  1.10s/it]


0.5
-5.500000000000002


100%|██████████| 20/20 [00:21<00:00,  1.09s/it]


0.65
-3.499999999999999


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


0.6
-3.5


100%|██████████| 20/20 [00:25<00:00,  1.25s/it]


0.7
-1.2999999999999994


100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


0.7
-1.2999999999999998


100%|██████████| 20/20 [00:24<00:00,  1.23s/it]


0.55
-0.20000000000000007


100%|██████████| 20/20 [00:25<00:00,  1.27s/it]


0.5
-2.4


  5%|▌         | 1/20 [00:01<00:35,  1.86s/it]


KeyboardInterrupt: 

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from tqdm import tqdm

# Define Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity):
        # A bounded deque evicts the oldest transition once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns:
            Five tuples (states, actions, rewards, next_states, dones), aligned
            by position. ``batch_size == 0`` yields five empty tuples; it
            previously failed with a confusing tuple-unpack ValueError.

        Raises:
            ValueError: (from random.sample) if `batch_size` is negative or
                larger than the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        if not batch:
            return (), (), (), (), ()
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
