In [1]:
from preprocessing import XRAG_TOKEN, load_and_format_dataset, encode_with_chat_format_finetune, encode_with_chat_format_pretrain, add_retriever_embeddings, collator
from utils import get_nll_loss, get_kl_loss, validate_during_pretrain
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler
from tokenizers import AddedToken
from functools import partial
from tqdm import tqdm 
from datetime import datetime
import torch
import sys
import os
sys.path.append('..')
from model.E5Retriever import E5Retriever
from model.xLlama import XLlamaConfig, XLlamaForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class Args: 
    retrieval_context_length= 512  # 180
    overwrite_cache =False
    max_train_samples = 2000000
    chat_format="llama"
    retriever_name_or_path='intfloat/multilingual-e5-small'
    workdir = "."
    lr_scheduler_type = "linear"
    warmup_ratio = 0.03
    weight_decay = 0.0
    num_train_epochs = 1
    use_flash_attn = True
    alpha_nll = 1.0
    seed = 980406
    update_projector_only = True
    per_device_train_batch_size = 4
    max_train_steps = None
    checkpointing_steps = None
    output_dir='../output'
    lang="english"

    # # pretrain
    # max_seq_length = 600  # 336
    # model_name_or_path = 'meta-llama/Llama-3.2-1B-Instruct'
    # task_type='pretrain'
    # learning_rate=6.0e-3
    # alpha_kl = None
    # kl_temperature=0.0
    # retrieval_embed_length=1

    #  finetune
    learning_rate = 2.0e-5
    max_seq_length = 1600 # 1024
    use_rag_tuning = True
    alpha_kl= 2.0
    kl_temperature= 1.0 
    model_name_or_path="../output/2025-05-05_22-50-00/e5small-llama1Binstruct-batch4"
    task_type="finetune"
    retrieval_embed_length=3

args = Args()
print(args.task_type)

finetune


In [3]:
dataset_path = "../../generated_data/TUNING_final_summary"
query_col = 'query'
ans_col = 'answer'
psg_col = 'passages'

max_rows = 500

dataset = load_and_format_dataset(dataset_path, query_col, ans_col, psg_col, args.task_type, max_rows, include_psg_len=False)

In [10]:
# Loading model retriever
print('memuat retriever dan tokenizernya...')
retriever = E5Retriever(args.retriever_name_or_path)
retriever_tokenizer = AutoTokenizer.from_pretrained(args.retriever_name_or_path)
retriever_hidden_size = retriever.get_embed_dim()
retriever.eval()
retriever.to('cuda:0')

from transformers import AutoTokenizer
print('memuat tokenizer generatif...')
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
tokenizer.padding_side = 'left'

print('memuat model generatif...')
config = XLlamaConfig.from_pretrained(args.model_name_or_path, retriever_hidden_size=retriever_hidden_size)
model = XLlamaForCausalLM.from_pretrained(
    args.model_name_or_path,
    config=config,
    torch_dtype = torch.bfloat16
).to("cuda:0")

# Mengatur pad_token pada tokenizer llama dengan token yang sudah ada dalam Llama
pad_token = "<|finetune_right_pad_id|>"
tokenizer.pad_token = pad_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)

# Menambahkan token baru (xrag) ke perbendaharaan tokenizer llama
num_added_tokens = 0
num_added_tokens += tokenizer.add_tokens([AddedToken(XRAG_TOKEN,lstrip=False,rstrip=False)])
xrag_token_id = tokenizer.convert_tokens_to_ids(XRAG_TOKEN)
model.set_xrag_token_id(xrag_token_id)
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
vocab_size = len(tokenizer)

memuat retriever dan tokenizernya...
memuat tokenizer generatif...
memuat model generatif...


In [None]:
if args.task_type == 'finetune':
    print('encode chat untuk finetune...')
    encode_function = partial(
        encode_with_chat_format_finetune,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        lang=args.lang, 
        retrieval_context_length=args.retrieval_context_length,
        use_rag_tuning = args.use_rag_tuning,
        use_retriever_embed = not (retriever is None),
        retriever_tokenizer = retriever_tokenizer,
        chat_format = args.chat_format,
    )

if args.task_type== 'pretrain':
    print('encode chat untuk pretraining')
    encode_function = partial(
        encode_with_chat_format_pretrain,
        tokenizer = tokenizer,
        retriever_tokenizer = retriever_tokenizer, 
        max_seq_length = args.max_seq_length,
        retrieval_context_length=args.retrieval_context_length,
        retrieval_embed_length=args.retrieval_embed_length,
        chat_format = args.chat_format
    )

encode chat untuk finetune...


In [None]:
lm_datasets = dataset.map(encode_function)
lm_datasets.set_format(type="pt")

if args.task_type == 'finetune':
    print('membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...')
    lm_datasets['train'] = lm_datasets['train'].filter(lambda example: (example['labels'] != -100).any())
    if args.alpha_kl is not None and args.alpha_kl > 0.0:
        lm_datasets['train'] = lm_datasets['train'].filter(
            lambda example: 
            (example['labels']!=-100).sum() == (example['xrag_labels']!=-100).sum()
        )

train_dataset = lm_datasets['train']
dev_dataset = lm_datasets['dev']

In [14]:
# Menambahkan retriever_embeddings ke dataset sebelum pelatihan
print('membuat embeddings untuk dokumen konteks dengan retriever...')
train_dataset = train_dataset.map(
    lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
)

dev_dataset = dev_dataset.map(
    lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
)

# Hapus objek retriever dan modelnya
print('Menghapus retriever...')
del retriever.model  # Menghapus model dari memori
del retriever  # Menghapus objek retriever itu sendiri
torch.cuda.empty_cache() 

membuat embeddings untuk dokumen konteks dengan retriever...


Map: 100%|██████████| 484/484 [00:11<00:00, 43.95 examples/s]
Map: 100%|██████████| 500/500 [00:10<00:00, 47.96 examples/s]

Menghapus retriever...





In [27]:
train_dataset['retriever_input_text'][3]

['passage: Judul: John F. Kennedy Teks: John Fitzgerald Kennedy lahir di 83 Beals Street, Brookline, Massachusetts, pada tanggal 29 Mei 1917[10] dari pasangan pebisnis/politikus Joseph Patrick "Joe" Kennedy, Sr. (1888–1969) dan filantropis Rose Elizabeth Fitzgerald (1890–1995). Joe adalah putra sulung pebisnis/politikus Patrick Joseph "P. J." Kennedy (1858–1929) dan Mary Augusta Hickey (1857–1923). Rose adalah putri sulung Wali Kota Boston John Francis "Honey Fitz" Fitzgerald (1863–1950) dan Mary Josephine "Josie" Hannon (1865–1964). Keempat kakek-neneknya adalah anak-anak imigran Irlandia.[1]',
 'passage: Judul: Keluarga Kennedy Teks: Anak pertama dari Joseph P. Kennedy, Sr. adalah Joseph Patrick "Joe" Kennedy, Jr. yang diharapkan ayahnya untuk terjun di dunia politik dan menjadi presiden. Setelah Joe, Jr. tewas dalam Perang Dunia II, harapan menjadi presiden dialihkan ke putra kedua John Fitzgerald "Jack" Kennedy. Segera setelah terpilih sebagai presiden pada November 1960, Jack meng

In [None]:
from collections import Counter

# Fungsi untuk menghitung panjang dari setiap list di kolom 'background'
def count_lengths(example):
    return {'background_length': len(example['background'])}

# Terapkan fungsi map ke dataset
lengths_dataset = train_dataset.map(count_lengths)

# Ambil data dari kolom 'background_length' yang baru
lengths = lengths_dataset['background_length']

# Menghitung jumlah elemen unik (berapa banyak baris dengan panjang list tertentu)
length_counts = Counter(lengths)

# Menampilkan hasil
print(length_counts)

Map: 100%|██████████| 484/484 [00:00<00:00, 1068.30 examples/s]

Counter({tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 1, tensor(3): 




In [25]:
collate_fn = partial(
    collator,
    llm_tokenizer=tokenizer, 
    llm_tokenizer_pad_token = "<|finetune_right_pad_id|>", 
    xrag_input_ids_col='xrag_input_ids',
    xrag_labels_col = 'xrag_labels', 
    text_input_ids_col = 'input_ids', 
    text_labels_col = 'labels', 
    retriever_embeds_col='retriever_embeddings'
)

print('Menginisialisasi dataloader untuk training...')
train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=args.per_device_train_batch_size
)

print('Menginisialisasi dataloader untuk validasi...')
dev_dataloader = DataLoader(
    dev_dataset,
    shuffle=False, 
    collate_fn=collate_fn,
    batch_size=args.per_device_train_batch_size
)


if args.update_projector_only:
    print('Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...')
    for n,p in model.named_parameters():
        if 'projector' not in n:p.requires_grad = False
        else:p.requires_grad = True

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad],lr=args.learning_rate) 

# Add learning rate scheduler
num_training_steps = args.num_train_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=int(num_training_steps * args.warmup_ratio)  # 3% warmup
)

Menginisialisasi dataloader untuk training...
Menginisialisasi dataloader untuk validasi...
Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...


In [16]:
# Inisialisasi list untuk menyimpan loss
nll_train_losses = []
kl_train_losses = []
train_losses = []

progress_bar = tqdm(range(num_training_steps))

for epoch in range(args.num_train_epochs):
    model.train()
    epoch_train_loss = 0
    for batch in tqdm(train_dataloader):

        optimizer.zero_grad()

        outputs = model(
            input_ids = batch['xrag_input_ids'],
            attention_mask = batch['xrag_attention_mask'],
            retrieval_embeds = batch['retriever_embeddings']
        )
        del batch['xrag_input_ids']
        del batch['xrag_attention_mask']
        del batch['retriever_embeddings']
        torch.cuda.empty_cache()
        logits = outputs.logits
        labels = batch['xrag_labels']

        nll_loss = get_nll_loss(logits=logits, labels=labels, vocab_size=vocab_size)

        loss = args.alpha_nll * nll_loss
        nll_train_losses.append(loss.item())

        if args.alpha_kl is not None and args.alpha_kl > 0.0:
                    
            ## forward with retrieval tokens
            with torch.no_grad():
                model.eval()
                teacher_outputs = model(
                    input_ids = batch['input_ids'],
                    attention_mask = batch['attention_mask'],
                )
                del batch['input_ids']
                del batch['attention_mask']
                torch.cuda.empty_cache()
                model.train()

            kl_loss = get_kl_loss(
                teacher_logits=teacher_outputs.logits,
                teacher_labels=batch['labels'],
                student_logits=outputs.logits,
                student_labels=batch['xrag_labels'],
                temperature=args.kl_temperature,
            )
            kl_loss = args.alpha_kl * kl_loss
            kl_train_losses.append(kl_loss.item())
            loss += kl_loss
            
            del batch['labels']
            torch.cuda.empty_cache()
        del batch['xrag_labels']
        torch.cuda.empty_cache()

        # Simpan loss untuk tiap batch
        train_losses.append(loss.item())
        epoch_train_loss += loss.item()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

  ret['retriever_embeddings'] = torch.stack([torch.tensor(x[retriever_embeds_col]).to('cuda:0') for x in samples])
  0%|          | 0/121 [00:00<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [3, 384] at entry 0 and [4, 384] at entry 2

In [None]:
# Lakukan validasi setelah setiap epoch
dev_losses = []


if dev_dataloader is not None:
    if args.task_type == 'pretrain':
        print(f"Validating after epoch {epoch + 1}...")
        ppl = validate_during_pretrain(model, dev_dataloader, vocab_size)
        print(f"Perplexity on dev set: {ppl}")

        dev_losses.append(ppl.item())

# Simpan rata-rata loss setiap epoch ke dalam list atau file
print(f"Average training loss for epoch {epoch + 1}: {epoch_train_loss / len(train_dataloader)}")

Validating after epoch 1...
NLL Loss: 2.417423963546753
NLL Loss: 1.8294185400009155
NLL Loss: 2.1833391189575195
NLL Loss: 2.4166274070739746
NLL Loss: 2.331878662109375
NLL Loss: 1.9211673736572266
NLL Loss: 2.3154242038726807
NLL Loss: 2.2908835411071777
NLL Loss: 2.3752048015594482
NLL Loss: 1.6519795656204224
NLL Loss: 1.7484678030014038
NLL Loss: 2.2158658504486084
NLL Loss: 1.4776960611343384
NLL Loss: 2.050281286239624
NLL Loss: 2.203371286392212
NLL Loss: 2.0610008239746094
NLL Loss: 2.161954402923584
NLL Loss: 2.3024446964263916
NLL Loss: 1.7361009120941162
NLL Loss: 2.1097402572631836
NLL Loss: 2.062561511993408
NLL Loss: 2.353330135345459
NLL Loss: 2.099410057067871
NLL Loss: 2.580258369445801
NLL Loss: 1.9219951629638672
NLL Loss: 2.3179144859313965
NLL Loss: 1.8331387042999268
NLL Loss: 2.1682233810424805
NLL Loss: 2.4154458045959473
NLL Loss: 2.121776819229126
NLL Loss: 2.0236308574676514
NLL Loss: 1.993772268295288
NLL Loss: 2.059504508972168
NLL Loss: 1.978107810020446

In [None]:
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = os.path.join(args.output_dir, current_time)
model_output_dir = os.path.join(output_dir, f"e5small-llama1Binstruct-batch4")
os.makedirs(model_output_dir, exist_ok=True)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('../output\\2025-05-05_22-50-00\\e5small-llama1Binstruct-batch4\\tokenizer_config.json',
 '../output\\2025-05-05_22-50-00\\e5small-llama1Binstruct-batch4\\special_tokens_map.json',
 '../output\\2025-05-05_22-50-00\\e5small-llama1Binstruct-batch4\\tokenizer.json')

In [None]:
# Menyimpan ke CSV
import csv

with open(os.path.join(output_dir, "train_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['step', 'train_loss'])
    for step, loss in enumerate(train_losses):
        writer.writerow([step, loss])

with open(os.path.join(output_dir, "dev_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'dev_loss'])
    for epoch, loss in enumerate(dev_losses):
        writer.writerow([epoch, loss])