In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
sys.path.append('../')

from torch import nn, Tensor
from transformers import AutoTokenizer

import src.utils as ut
import src.trainer as trn
import src.finetuning as ft

from src.utils import LLAMA_MODEL_ID, login_to_hf_hub
from src.txt_dataset import TokenizedTxtDataset


# Log in before any Hub download below (presumably required because the Llama
# checkpoint is gated -- TODO confirm).
login_to_hf_hub()

# Shared tokenizer for every dataset built later in the notebook.
TOKENIZER = AutoTokenizer.from_pretrained(
    LLAMA_MODEL_ID
)

# Baseline GPU memory report, taken before the model is loaded in the next cell;
# left as the last expression so Jupyter displays its value.
ut.gpu_mem_info()


In [None]:
LORA_RANK = 16  # rank handed to ft.freeze_and_install_lora

model = ut.load_raw_model()
mem_before_lora = ut.gpu_mem_info()
print("before installing LoRA layers: ", mem_before_lora)

# The return value is not used, so this call is relied on to modify `model`
# in place (per its name: freeze base weights, add LoRA layers).
ft.freeze_and_install_lora(model, lora_rank=LORA_RANK)
mem_after_lora = ut.gpu_mem_info()
print("after installing LoRA layers: ", mem_after_lora)

In [None]:
# Both splits are cut from the same files; the split is made by the pct range
# passed to the dataset (presumably a position range within the text -- TODO confirm).
DATA_DIR = Path("../data")
text_fpaths = [
    DATA_DIR / "Estatutos-Universidad-de-los-Andes-2020-ratificados-MEN-RQ.translated.txt",
    DATA_DIR / "reglamento-maestria-web-2024.translated.txt",
]
# Fail fast with a readable message instead of an error deep inside the dataset.
missing_fpaths = [p for p in text_fpaths if not p.exists()]
assert not missing_fpaths, f"missing input text files: {missing_fpaths}"

BLOCK_SIZE = 128      # window length given to the dataset (presumably in tokens)
STRIDE = 64           # step between windows; with STRIDE < BLOCK_SIZE windows presumably overlap
TRAIN_END_PCT = 90.0  # boundary between the train range and the validation range


def make_txt_dataset(start_pct: float, end_pct: float) -> TokenizedTxtDataset:
    """Build a TokenizedTxtDataset over `text_fpaths` for the [start_pct, end_pct] range.

    Train and validation share every setting except the range, so both are built
    through this one helper to keep them from drifting apart.
    """
    return TokenizedTxtDataset(text_fpaths,
                               block_size=BLOCK_SIZE,
                               stride=STRIDE,
                               tokenizer=TOKENIZER,
                               start_pct=start_pct,
                               end_pct=end_pct)


train_ds = make_txt_dataset(start_pct=0.0, end_pct=TRAIN_END_PCT)
valid_ds = make_txt_dataset(start_pct=TRAIN_END_PCT, end_pct=100.0)

print("len(train_ds):", len(train_ds),
      "len(valid_ds):", len(valid_ds),
      "max_stride_mult:", train_ds.max_stride_mult)

DEVICE = ut.module_device(model)
print(f"DEVICE: {DEVICE}")
TRAIN_BATCH_SIZE = 1
VALID_BATCH_SIZE = 8
LEARNING_RATE = 2e-4

trainer = trn.Trainer(
    train_ds=train_ds,
    train_batch_size=TRAIN_BATCH_SIZE,
    valid_ds=valid_ds,
    valid_batch_size=VALID_BATCH_SIZE,
    lr=LEARNING_RATE,
    device=DEVICE
)

def pred_next_token_loss(model: nn.Module, batch: dict[str, Tensor]) -> Tensor:
    """Language-modeling loss for one batch, using the inputs themselves as labels.

    Args:
        model: called with ``input_ids``, ``attention_mask`` and ``labels`` keyword
            arguments and expected to return an object with a ``.loss`` attribute
            (assumes a Hugging Face causal LM, which shifts labels internally).
        batch: dict with ``'input_ids'`` and ``'attention_mask'`` tensors of the
            same shape.

    Returns:
        The ``.loss`` tensor produced by the model.

    Positions where ``attention_mask == 0`` get label -100. The HF loss ignores
    that label, so padding (if the dataset ever pads) does not count as a target.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # masked_fill returns a new tensor, so batch['input_ids'] is left untouched.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return model(input_ids=input_ids,
                 attention_mask=attention_mask,
                 labels=labels).loss

# Training-length hyperparameters, named like the batch-size / LR constants above.
MAX_STEPS = 70         # passed as max_steps (whether optimizer or micro-batch steps: TODO confirm in src.trainer)
ACCUM_GRAD_STEPS = 8   # presumably effective batch = TRAIN_BATCH_SIZE * ACCUM_GRAD_STEPS

trainer.train(model,
              loss_fun=pred_next_token_loss,
              max_steps=MAX_STEPS,
              accum_grad_steps=ACCUM_GRAD_STEPS)