In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
import torch
import faiss
import json
import math
import time
import os

DECODER_PATH='../Llama-3.2-3B-Instruct'
ENCODER_PATH = "../bge-small-en"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DOCS_PATH = "../dataset_txt_small/train"
QUESTIONS_PATH = "./rag_questions_json"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RAGEngine:
    def __init__(self):
        self.decoder = AutoModelForCausalLM.from_pretrained(DECODER_PATH, torch_dtype=torch.bfloat16).to(DEVICE)
        self.encoder = SentenceTransformer(ENCODER_PATH).to(DEVICE)

        self.decoder.config.use_cache = True

        self.tokenizer = AutoTokenizer.from_pretrained(DECODER_PATH)

    def embed_documents(self, docs):
        return self.encoder.encode(docs)

    def embed_query(self, query):
        return self.encoder.encode([query])
    
engine = RAGEngine()

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 29.06it/s]


In [3]:
docs = [(fn.split(".")[0], open(os.path.join(DOCS_PATH, fn), 'r', encoding='utf-8').read()) for fn in tqdm(os.listdir(DOCS_PATH)) if fn.endswith(".txt")]

100%|██████████| 747/747 [00:00<00:00, 2455.75it/s]


In [4]:
MAX_CHAR_LEN = 4000
MAX_CHAR_OVERLAP = 500
splitter = CharacterTextSplitter(separator=" ", chunk_size=MAX_CHAR_LEN, chunk_overlap=MAX_CHAR_OVERLAP)
split_docs = []
for doc in tqdm(docs):
    split_docs.extend(splitter.split_text(doc[1]) if len(doc[1]) > MAX_CHAR_LEN else [doc[1]])

100%|██████████| 747/747 [00:00<00:00, 1042.91it/s]


In [5]:
doc_embeddings = engine.embed_documents(split_docs)

In [6]:
def create_chat(context, question):
    return [
        {"role": "system", "content": f"Use only the following pieces of context to answer the question at the end. Different references are seperated by \"\n\n\". Please only use the references relevant to answer the question f{context}"},
        {"role": "user", "content": f"{question}"},
    ]

In [7]:
TOP_K_DOCS = 3
D = doc_embeddings.shape[1]
M = 32
index = faiss.IndexHNSWFlat(D, M)
index.add(doc_embeddings)

### Example question

In [8]:
question = "What are the main indicators that were chosen to study in order to understand and forecast the evolution of carbon emissions on a country-scale, and why were they chosen?"
query_vec = engine.embed_query(question)
distances, indices = index.search(query_vec, TOP_K_DOCS)
top_docs = [split_docs[i] for i in indices[0]]
context = "\n\n".join(top_docs)
prompt_with_context = engine.tokenizer.apply_chat_template(
    create_chat(context, question), 
    tokenize=False, 
    add_generation_prompt=True
)
print(prompt_with_context)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Apr 2025

Use only the following pieces of context to answer the question at the end. Different references are seperated by "

". Please only use the references relevant to answer the question fresults and give relevant insight to policymakers. 1. Introduction Lowering human greenhouse gases emissions is one major goal of the efforts against climate change, and the focus and concern of international cooperation (Paris Agreement, 2015). Many indicators of human development - population, Gross Domestic Product (GDP), environmental footprint - have been following exponential curves during the past decades (Steffen et al., 2015); hence, drastic measures are needed if we are to switch from increasing to quickly de- creasing emissions, as expressed in global organisations goals (IPCC Fifth Assessment Report (2014)). Understanding and forecasting the evolution, on a country- scale

In [9]:
streamer = TextStreamer(engine.tokenizer, skip_prompt=True)

input_ids = engine.tokenizer.encode(prompt_with_context, return_tensors="pt").to(DEVICE)
outputs = engine.decoder.generate(
    input_ids,
    max_new_tokens=500,
    pad_token_id=128004,
    eos_token_id=128009,
    streamer=streamer,
    do_sample=True,
)

The main indicators that were chosen to study in order to understand and forecast the evolution of carbon emissions on a country-scale are:

1. Population (P)
2. National GDP (G)
3. Energy supply (E)
4. CO2 emissions (F)

These indicators were chosen because they are all related to carbon emissions through the Kaya identity, which is a mathematical relationship that describes the relationship between carbon emissions, energy supply, GDP, and population.

The Kaya identity is expressed as:

F = P x G x E^G x F/E

Where:

* F is carbon emissions
* P is population
* G is GDP
* E is energy supply
* G is GDP per capita (which represents the average life standard)
* E is energy intensity of the GDP (the energy needed to create one unit of GDP)
* F is the carbon intensity of energy (the CO2 emission corresponding to the supply of one unit of primary energy)

These four indicators were chosen because they provide a clear and actionable relationship between carbon emissions and other macroecono

### Evaluation (Perplexity)

In [10]:
def create_target_chat(context, question, answer):
    return [
        {"role": "system", "content": f"Use only the following pieces of context to answer the question at the end. Different references are seperated by \"\n\n\". Please only use the references relevant to answer the question f{context}"},
        {"role": "user", "content": f"{question}"},
        {"role": "assistant", "content": f"{answer}"},
    ]

In [11]:
loss = 0
losses = list()

for d in tqdm(docs):
    with open(os.path.join(QUESTIONS_PATH, f"{d[0]}.json"), 'r', encoding='utf-8') as f:
        qa = json.load(f)

    query_vec = engine.embed_query(qa["question"])
    distances, indices = index.search(query_vec, TOP_K_DOCS)
    top_docs = [split_docs[i] for i in indices[0]]
    context = "\n\n".join(top_docs)
    target_chat = engine.tokenizer.apply_chat_template(
        create_target_chat(context, qa["question"], qa["answer"]), 
        tokenize=False, 
        add_generation_prompt=False
    )
    tokens = engine.tokenizer(target_chat, return_tensors="pt", truncation=False, padding=False, add_special_tokens=False).to(DEVICE)
    tokens = {k: v.to(DEVICE) for k, v in tokens.items()}

    prompt = engine.tokenizer.apply_chat_template(
        create_chat(context, qa["question"]), 
        tokenize=False, 
        add_generation_prompt=True
    )
    ignore_idx = engine.tokenizer.encode(prompt, return_tensors="pt", truncation=False, padding=False, add_special_tokens=False).shape[1]

    filtered_labels = tokens["input_ids"].clone()
    ignore_mask = torch.zeros_like(filtered_labels, dtype=torch.bool)
    ignore_mask[0, :ignore_idx] = True
    filtered_labels[ignore_mask] = -100

    with torch.no_grad():
        outputs = engine.decoder(**tokens, labels=filtered_labels)
        losses.append(outputs.loss.item())
        loss += outputs.loss.item()

perplexity = math.exp(loss/len(docs))
print(f"Perplexity: {perplexity:.2f}")

100%|██████████| 747/747 [01:27<00:00,  8.58it/s]

Perplexity: 2.38





### Evaluation (Average time per request)

In [12]:
begin = time.time()

for d in tqdm(docs):
    with open(os.path.join(QUESTIONS_PATH, f"{d[0]}.json"), 'r', encoding='utf-8') as f:
        qa = json.load(f)

    query_vec = engine.embed_query(qa["question"])
    distances, indices = index.search(query_vec, TOP_K_DOCS)
    top_docs = [split_docs[i] for i in indices[0]]
    context = "\n\n".join(top_docs)

    prompt = engine.tokenizer.apply_chat_template(
        create_chat(context, qa["question"]), 
        tokenize=False, 
        add_generation_prompt=True
    )
    input_ids = engine.tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    engine.decoder.generate(
        input_ids,
        max_new_tokens=500,
        pad_token_id=128004,
        eos_token_id=128009,
        do_sample=False,
        top_p=1.0,
    )

print(f"Average time per request: {(time.time() - begin) / len(docs):.2f} seconds")

100%|██████████| 747/747 [1:12:05<00:00,  5.79s/it]

Average time per request: 5.79 seconds



