# 8. LLM Fine-Tuning (RBF Data)

**Objective:** Fine-tune the LLM on the dataset generated by the **Continuous VAE + Post-Hoc RBF Quantizer** pipeline.
This notebook uses the same custom training loop as Notebook 04.

In [None]:
%pip install datasets transformers torch accelerate tqdm

In [None]:
# Mount Google Drive inside the Colab runtime; the results directory
# configured later in this notebook lives under this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import sys
import os
import torch
from torch.utils.data import DataLoader
from transformers import get_scheduler
from torch.optim import AdamW
from tqdm.auto import tqdm
from datasets import load_dataset

# PATH FIX
# Make the `src` package importable: start from the current working
# directory and, when it has no `src` entry, fall back to its parent.
project_root = os.path.abspath(os.getcwd())
if not any(entry == 'src' for entry in os.listdir(project_root)):
    project_root = os.path.abspath(os.path.join(project_root, os.pardir))

# Put the root first on sys.path, but only once, so re-running the cell
# does not keep adding duplicates.
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)

from src.utils import get_llm_tokenizer, MAX_SEQ_LEN, LLM_MODEL_NAME
# IMPORT THE NEW RBF DATA PATH
from src.utils import PATH_PROCESSED_DATA_RBF 
from src.model.transformer import get_llm_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Config ---
# Optimisation hyper-parameters.
NUM_TRAIN_EPOCHS: int = 3
PER_DEVICE_TRAIN_BATCH_SIZE: int = 4   # sequences per forward/backward pass
GRADIENT_ACCUMULATION_STEPS: int = 8   # micro-batches accumulated per optimizer update
LEARNING_RATE: float = 2e-5
NUM_WORKERS: int = 2                   # DataLoader worker processes

# Output locations on the mounted Google Drive.
DRIVE_SAVE_DIR: str = "/content/drive/My Drive/TokenAssorted_GSM8K_Results/llm_rbf_experiment/"
FINAL_MODEL_DIR: str = os.path.join(DRIVE_SAVE_DIR, "final_model")

# Create the experiment directory up front; a no-op when it already exists.
os.makedirs(DRIVE_SAVE_DIR, exist_ok=True)

In [None]:
# 1. Load & Pre-Tokenize RBF Data
tokenizer = get_llm_tokenizer()

# Load the pre-processed RBF dataset (JSON) as a single "train" split.
try:
    print(f"Loading RBF dataset from: {PATH_PROCESSED_DATA_RBF}")
    raw_dataset = load_dataset("json", data_files=PATH_PROCESSED_DATA_RBF, split="train")
except FileNotFoundError:
    print("Error: Run 'python scripts/run_preprocessing_rbf.py' first!")
    # Re-raise after printing the hint: without a dataset `raw_dataset` is
    # never bound, and continuing would only fail later with a confusing
    # NameError far from the real cause.
    raise

def tokenize_function(examples):
    """Tokenize a batch of examples and attach causal-LM labels.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["text"]`` is a
    list of strings. Every sequence is truncated or padded to 512 tokens.

    Labels mirror ``input_ids`` on real tokens and are set to -100 on padding
    positions (where ``attention_mask`` is 0), so padding does not contribute
    to the training loss. -100 is the default ``ignore_index`` of PyTorch's
    cross-entropy loss; this assumes the model computes its loss with that
    convention, as Hugging Face causal-LM heads do.

    Args:
        examples: Batch mapping with a ``"text"`` column (list of strings).

    Returns:
        The tokenizer output with an added ``"labels"`` entry, one list of
        label ids per input sequence.
    """
    tokenized = tokenizer(
        examples["text"],
        max_length=512, # Keeping it fast
        padding="max_length",
        truncation=True,
    )
    # Mask by attention_mask rather than by pad-token id: if the pad token is
    # shared with another special token (e.g. EOS), genuine occurrences inside
    # the text keep their label.
    tokenized["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized

# Tokenize the whole dataset once, in parallel across all CPU cores, and
# drop the raw "text" column that was passed to the tokenizer.
tokenized_dataset = raw_dataset.map(
    tokenize_function,
    remove_columns=["text"],
    num_proc=os.cpu_count(),
    batched=True,
)
# Have the dataset hand back torch tensors when indexed.
tokenized_dataset.set_format("torch")

# Shuffled loader that feeds the training loop one micro-batch at a time.
train_dataloader = DataLoader(
    dataset=tokenized_dataset,
    shuffle=True,
    batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

In [None]:
# 2. Setup Model & Train
model = get_llm_model(LLM_MODEL_NAME, len(tokenizer)).to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# One optimizer update is performed every GRADIENT_ACCUMULATION_STEPS
# micro-batches, plus one for the trailing partial group at the end of each
# epoch. The per-epoch update count is therefore a ceiling division, and the
# scheduler / progress bar length matches the updates actually performed.
num_batches_per_epoch = len(train_dataloader)
num_update_steps_per_epoch = (
    num_batches_per_epoch + GRADIENT_ACCUMULATION_STEPS - 1
) // GRADIENT_ACCUMULATION_STEPS
num_training_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

model.train()
progress_bar = tqdm(range(num_training_steps))

for epoch in range(NUM_TRAIN_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_TRAIN_EPOCHS}")
    for step, batch in enumerate(train_dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Scale so the accumulated gradient is the mean over the micro-batches
        # of a full accumulation group.
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        loss.backward()

        # Step at every accumulation boundary AND on the last batch of the
        # epoch; otherwise the gradients of a trailing partial group would
        # leak into the next epoch's first update (or be discarded entirely
        # after the final epoch).
        at_accumulation_boundary = (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0
        at_epoch_end = (step + 1) == num_batches_per_epoch
        if at_accumulation_boundary or at_epoch_end:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            # Undo the accumulation scaling so the displayed value is the raw
            # loss of the most recent micro-batch.
            progress_bar.set_description(f"Loss: {loss.item() * GRADIENT_ACCUMULATION_STEPS:.4f}")

# Persist the fine-tuned weights and the tokenizer side by side so the
# directory can be reloaded as a unit.
print("Saving...")
model.save_pretrained(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)
print("Done.")