In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
import torch
import faiss
import json
import math
import time
import os

DECODER_PATH='../../Llama-3.2-1B-Instruct'
ENCODER_PATH = "../../bge-small-en"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DOCS_PATH = "../../dataset_txt_small/train"
QUESTIONS_PATH = "../rag_questions_json"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RAGEngine:
    def __init__(self):
        self.decoder = AutoModelForCausalLM.from_pretrained(DECODER_PATH, torch_dtype=torch.bfloat16).to(DEVICE)
        self.encoder = SentenceTransformer(ENCODER_PATH).to(DEVICE)

        self.decoder.config.use_cache = True

        self.tokenizer = AutoTokenizer.from_pretrained(DECODER_PATH)

    def embed_documents(self, docs):
        return self.encoder.encode(docs)

    def embed_query(self, query):
        return self.encoder.encode([query])
    
engine = RAGEngine()

In [3]:
docs = [(fn.split(".")[0], open(os.path.join(DOCS_PATH, fn), 'r', encoding='utf-8').read()) for fn in tqdm(os.listdir(DOCS_PATH)) if fn.endswith(".txt")]

100%|██████████| 747/747 [00:00<00:00, 2218.64it/s]


In [4]:
MAX_CHAR_LEN = 4000
MAX_CHAR_OVERLAP = 500
splitter = CharacterTextSplitter(separator=" ", chunk_size=MAX_CHAR_LEN, chunk_overlap=MAX_CHAR_OVERLAP)
split_docs = []
for doc in tqdm(docs):
    split_docs.extend(splitter.split_text(doc[1]) if len(doc[1]) > MAX_CHAR_LEN else [doc[1]])

100%|██████████| 747/747 [00:00<00:00, 1038.22it/s]


In [5]:
doc_embeddings = engine.embed_documents(split_docs)

In [6]:
def create_chat(context, question):
    return [
        {"role": "system", "content": f"Use only the following pieces of context to answer the question at the end. Different references are seperated by \"\n\n\". Please only use the references relevant to answer the question f{context}"},
        {"role": "user", "content": f"{question}"},
    ]

In [7]:
TOP_K_DOCS = 3
D = doc_embeddings.shape[1]
m = 8
assert D % m == 0
nbits = 5
index = faiss.IndexPQ(D, m, nbits)
index.train(doc_embeddings)
index.add(doc_embeddings)

### Example question

In [8]:
question = "What are the main indicators that were chosen to study in order to understand and forecast the evolution of carbon emissions on a country-scale, and why were they chosen?"
query_vec = engine.embed_query(question)
distances, indices = index.search(query_vec, TOP_K_DOCS)
top_docs = [split_docs[i] for i in indices[0]]
context = "\n\n".join(top_docs)
prompt_with_context = engine.tokenizer.apply_chat_template(
    create_chat(context, question), 
    tokenize=False, 
    add_generation_prompt=True
)
print(prompt_with_context)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Apr 2025

Use only the following pieces of context to answer the question at the end. Different references are seperated by "

". Please only use the references relevant to answer the question f  Co-Director of Science  T. Eren Bilir Monalisa Chatterjee Kristie L. Ebi Yuka Otsuki Estrada Robert C. Genova Betelhem Girma  Eric S. Kissel Andrew N. Levy Sandy MacCracken Patricia R. Mastrandrea Leslie L. White 32 Avenue of the Americas, New York, NY 10013-2473, USA  Cambridge University Press is part of the University of Cambridge. It furthers the University’s mission by disseminating knowledge in the pursuit of education, learning, and research at the highest international levels of excellence.  www.cambridge.org  Information on this title: www.cambridge.org/9781107641655  © Intergovernmental Panel on Climate Change 2014  This publication is in copyright. Subject to statutory e

In [9]:
streamer = TextStreamer(engine.tokenizer, skip_prompt=True)

input_ids = engine.tokenizer.encode(prompt_with_context, return_tensors="pt").to(DEVICE)
outputs = engine.decoder.generate(
    input_ids,
    max_new_tokens=500,
    pad_token_id=128004,
    eos_token_id=128009,
    streamer=streamer,
    do_sample=True,
)

According to the provided text, the main indicators used to study the evolution of carbon emissions on a country-scale to understand and forecast its impact on climate change are:

1. **Carbon dioxide (CO2) fluxes**: This is a key indicator of greenhouse gas emissions, which are a major contributor to climate change.

2. **Land use changes**: This includes changes in land use such as deforestation, land conversion, and land cover changes, which are significant contributors to carbon emissions.

3. **Land degradation**: This includes processes such as soil degradation, loss of forests, and loss of agricultural productivity, which also contribute to carbon emissions.

4. **Forest degradation**: This includes the loss of forests, which is a significant contributor to carbon emissions.

5. **Greenhouse gas fluxes**: This includes the fluxes of other greenhouse gases such as methane, nitrous oxide, and nitric acid, which are also significant contributors to climate change.

The text does no

### Evaluation (Perplexity)

In [10]:
def create_target_chat(context, question, answer):
    return [
        {"role": "system", "content": f"Use only the following pieces of context to answer the question at the end. Different references are seperated by \"\n\n\". Please only use the references relevant to answer the question f{context}"},
        {"role": "user", "content": f"{question}"},
        {"role": "assistant", "content": f"{answer}"},
    ]

In [11]:
loss = 0
losses = list()

for d in tqdm(docs):
    with open(os.path.join(QUESTIONS_PATH, f"{d[0]}.json"), 'r', encoding='utf-8') as f:
        qa = json.load(f)

    query_vec = engine.embed_query(qa["question"])
    distances, indices = index.search(query_vec, TOP_K_DOCS)
    top_docs = [split_docs[i] for i in indices[0]]
    context = "\n\n".join(top_docs)
    target_chat = engine.tokenizer.apply_chat_template(
        create_target_chat(context, qa["question"], qa["answer"]), 
        tokenize=False, 
        add_generation_prompt=False
    )
    tokens = engine.tokenizer(target_chat, return_tensors="pt", truncation=False, padding=False, add_special_tokens=False).to(DEVICE)
    tokens = {k: v.to(DEVICE) for k, v in tokens.items()}

    prompt = engine.tokenizer.apply_chat_template(
        create_chat(context, qa["question"]), 
        tokenize=False, 
        add_generation_prompt=True
    )
    ignore_idx = engine.tokenizer.encode(prompt, return_tensors="pt", truncation=False, padding=False, add_special_tokens=False).shape[1]

    filtered_labels = tokens["input_ids"].clone()
    ignore_mask = torch.zeros_like(filtered_labels, dtype=torch.bool)
    ignore_mask[0, :ignore_idx] = True
    filtered_labels[ignore_mask] = -100

    with torch.no_grad():
        outputs = engine.decoder(**tokens, labels=filtered_labels)
        losses.append(outputs.loss.item())
        loss += outputs.loss.item()

perplexity = math.exp(loss/len(docs))
print(f"Perplexity: {perplexity:.2f}")

100%|██████████| 747/747 [00:48<00:00, 15.52it/s]

Perplexity: 7.70





### Evaluation (Average time per request)

In [12]:
begin = time.time()

for d in tqdm(docs):
    with open(os.path.join(QUESTIONS_PATH, f"{d[0]}.json"), 'r', encoding='utf-8') as f:
        qa = json.load(f)

    query_vec = engine.embed_query(qa["question"])
    distances, indices = index.search(query_vec, TOP_K_DOCS)
    top_docs = [split_docs[i] for i in indices[0]]
    context = "\n\n".join(top_docs)

    prompt = engine.tokenizer.apply_chat_template(
        create_chat(context, qa["question"]), 
        tokenize=False, 
        add_generation_prompt=True
    )
    input_ids = engine.tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    engine.decoder.generate(
        input_ids,
        max_new_tokens=500,
        pad_token_id=128004,
        eos_token_id=128009,
        do_sample=False,
        top_p=1.0,
    )

print(f"Average time per request: {(time.time() - begin) / len(docs):.2f} seconds")

100%|██████████| 747/747 [1:20:16<00:00,  6.45s/it]

Average time per request: 6.45 seconds



