In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from langchain.text_splitter import CharacterTextSplitter
from tqdm import tqdm
import numpy as np
import torch
import faiss
import json
import os

# Local checkpoint directory of the model used for question generation.
MODEL_PATH = '../Llama-3.1-8B-Instruct'

# First CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Folder of source .txt documents, and folder where generated question/answer JSON files go.
DOCS_PATH = "../dataset_txt_small/train"
QUESTIONS_PATH = "./rag_questions_json"

In [2]:
# Load the causal-LM weights in bfloat16, then move the module to DEVICE.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
model = model.to(DEVICE)

# Tokenizer from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 56.26it/s]


In [3]:
def _read_text(path):
    """Return the full contents of the UTF-8 text file at ``path``.

    Uses a ``with`` block so the file handle is closed immediately; the previous
    ``open(...).read()`` inside a comprehension left every handle for the GC to close.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


# (filename, full text) pairs for every .txt file in DOCS_PATH (defined in the config cell;
# not redefined here). Listing is sorted so the processing order is reproducible.
docs = [
    (fn, _read_text(os.path.join(DOCS_PATH, fn)))
    for fn in tqdm(sorted(os.listdir(DOCS_PATH)))
    if fn.endswith(".txt")
]

100%|██████████| 747/747 [00:00<00:00, 2355.98it/s]


In [4]:
def create_chat(question):
    """Build a single-turn chat transcript for ``apply_chat_template``.

    Args:
        question: prompt text; it is formatted into a string.

    Returns:
        A one-element list containing a ``{"role": "user", "content": ...}`` message.
    """
    user_message = {"role": "user", "content": f"{question}"}
    return [user_message]

In [9]:
# Prompt asking the model for one question/answer pair about a document, returned as JSON.
template = "Please create a complex but medium/short question and respective answer (in JSON format, \"question\" and \"answer\" fields) for the following text:\n\n\"{context}\"."

# Generation settings.
# NOTE(review): presumably Llama 3.1 ids 128004 = <|finetune_right_pad_id|> and
# 128009 = <|eot_id|>; confirm against the tokenizer if the checkpoint changes.
PAD_TOKEN_ID = 128004
EOS_TOKEN_ID = 128009
MAX_NEW_TOKENS = 1000

_json_decoder = json.JSONDecoder()


def _parse_first_json_object(text):
    """Decode the JSON value that starts at the first "{" in ``text``.

    ``raw_decode`` stops exactly at the end of that value, so "}" characters inside
    string values and any prose the model appends after the object do not break
    parsing (slicing up to the first "}" did).

    Raises:
        json.JSONDecodeError: if ``text`` has no "{" or the text from there is not valid JSON.
    """
    start = text.find("{")
    if start == -1:
        raise json.JSONDecodeError("No '{' found in model output", text, 0)
    obj, _ = _json_decoder.raw_decode(text, start)
    return obj


def _question_path(filename):
    """Return the output JSON path for a source document (its extension replaced by .json).

    ``os.path.splitext`` strips only the last extension, so "a.b.txt" -> "a.b.json"
    instead of colliding with "a.txt" on "a.json".
    """
    stem, _ = os.path.splitext(filename)
    return os.path.join(QUESTIONS_PATH, f"{stem}.json")


# Create the output folder; exist_ok=True makes re-running this cell safe.
os.makedirs(QUESTIONS_PATH, exist_ok=True)

for filename, text in tqdm(docs):
    out_path = _question_path(filename)
    # Documents with an existing output file are skipped, so the cell can be resumed.
    if os.path.exists(out_path):
        continue

    prompt_with_context = tokenizer.apply_chat_template(
        create_chat(template.format(context=text)),
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered chat template already contains the BOS token; add_special_tokens=False
    # stops the tokenizer from prepending a second one.
    input_ids = tokenizer.encode(
        prompt_with_context, add_special_tokens=False, return_tensors="pt"
    ).to(DEVICE)
    prompt_len = input_ids.shape[1]
    output_ids = model.generate(
        input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=PAD_TOKEN_ID,
        eos_token_id=EOS_TOKEN_ID,
        do_sample=False,  # greedy decoding
        top_p=1.0,
    )[:, prompt_len:]  # keep only the newly generated tokens
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    try:
        json_obj = _parse_first_json_object(output_text)
    except json.JSONDecodeError as e:
        print(f"Failed to parse JSON for {filename}: {e}")
        continue

    # Explicit check rather than `assert`: asserts vanish under `python -O`, and an
    # uncaught AssertionError would abort the whole loop instead of skipping one document.
    if not isinstance(json_obj, dict) or "question" not in json_obj or "answer" not in json_obj:
        print(f"Missing 'question'/'answer' fields in output for {filename}; skipping")
        continue

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, ensure_ascii=False, indent=4)

100%|██████████| 747/747 [00:00<00:00, 34871.23it/s]
