In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch
import json
import os

# --- Configuration: every value a reader might want to tune lives here. ---

# Directory containing the model weights and tokenizer files.
MODEL_PATH = '../Llama-3.1-8B-Instruct'

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Input: directory of raw `.txt` documents to be cleaned.
DOCS_PATH = "../dataset_txt_small/train"
# Output: directory receiving one JSON file per processed document.
SUMMARIES_PATH = "./finetune_summaries_json"

In [None]:
# Load the tokenizer and the causal-LM weights from the same local checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Weights are loaded in bfloat16 and then moved to the configured device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
)
model = model.to(DEVICE)

In [None]:
# Read every `.txt` document under DOCS_PATH (defined in the configuration cell)
# into a list of (filename, text) tuples.
#
# - Each file is opened in a `with` block so its handle is closed right away;
#   the previous `open(...).read()` inside a comprehension left closing to the
#   garbage collector.
# - The directory listing is sorted because `os.listdir` returns entries in
#   arbitrary order; sorting makes the processing order reproducible.
docs = []
for fn in tqdm(sorted(os.listdir(DOCS_PATH))):
    if not fn.endswith(".txt"):
        continue
    with open(os.path.join(DOCS_PATH, fn), 'r', encoding='utf-8') as f:
        docs.append((fn, f.read()))

In [None]:
def create_chat(question):
    """Wrap `question` in a single-turn chat: a list holding one user message.

    The content is the string formatting of `question`, so non-string values
    are converted to text.
    """
    user_message = {"role": "user", "content": format(question)}
    return [user_message]

In [None]:
# Instruction prompt; `{content}` is replaced by the raw document text.
template = "The following text is noisy and dirty, so please rephrase the following text while removing noisy text (e.g., incorrect characters coming from the pdf reader, emails, footnotes, etc.), but without removing important content or adding new data, however keep as much data as possible:\n\n\"{content}\"."

# Create the output folder (exist_ok=True so re-running the cell does not fail).
os.makedirs(SUMMARIES_PATH, exist_ok=True)

# Rewrite every document with the model and store one JSON file per document.
# Documents whose output file already exists are skipped, so an interrupted
# run can simply be restarted.
for fn, text in tqdm(docs):
    # splitext removes only the final extension. The previous split('.')[0]
    # cut at the FIRST dot, so "a.v1.txt" and "a.v2.txt" both mapped to
    # "a.json" and the second one was silently skipped.
    # NOTE: outputs for multi-dot filenames written under the old naming will
    # not be recognised and will be regenerated.
    out_path = os.path.join(SUMMARIES_PATH, f"{os.path.splitext(fn)[0]}.json")
    if os.path.exists(out_path):
        continue

    # Render the single-turn chat as a prompt string ending with the
    # assistant header, so generation starts with the model's reply.
    prompt = tokenizer.apply_chat_template(
        create_chat(template.format(content=text)),
        tokenize=False,
        add_generation_prompt=True
    )

    # The rendered chat template is expected to already contain the special
    # tokens it needs (including BOS); encoding with the default
    # add_special_tokens=True would prepend a second BOS token.
    input_ids = tokenizer.encode(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(DEVICE)
    prompt_len = input_ids.shape[1]

    # Greedy decoding; the slice drops the prompt so only new tokens remain.
    # The pad/eos ids are hard-coded for the Llama 3.1 vocabulary
    # (presumably <|finetune_right_pad_id|> and <|eot_id|>) — verify them
    # if MODEL_PATH is changed.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=1000,
        pad_token_id=128004,
        eos_token_id=128009,
        do_sample=False,
        top_p=1.0,
    )[:, prompt_len:]
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    chat = {"question": prompt, "answer": summary}

    # Write to a temporary file and rename it into place: a crash mid-write
    # must not leave a truncated JSON that the exists-check above would
    # later treat as finished.
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(chat, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, out_path)