In [1]:
from preprocessing import XRAG_TOKEN, load_and_format_dataset, encode_with_chat_format_finetune, encode_with_chat_format_pretrain, add_retriever_embeddings, collator
from utils import get_nll_loss, get_kl_loss, validate_during_pretrain
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler
from tokenizers import AddedToken
from functools import partial
from tqdm import tqdm 
from datetime import datetime
import torch
import sys
import csv
import os
sys.path.append('..')
from model.E5Retriever import E5Retriever
from model.xLlama import XLlamaConfig, XLlamaForCausalLM
from model.xQwen3 import XQwen3Config, XQwen3ForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
class Args: 
    retrieval_context_length= 512  # 180
    overwrite_cache =False
    max_samples = None
    chat_format="qwen"  # "llama"
    model_size = "1,7B"
    retriever_name_or_path='intfloat/multilingual-e5-small'
    dataset_path = "../../generated_data/raw/final_dataset"
    query_col = 'query'
    ans_col = 'answer'
    psg_col = 'passages'
    lr_scheduler_type = "linear"
    warmup_ratio = 0.03
    num_train_epochs = 3
    alpha_nll = 1.0
    update_projector_only = True
    per_device_train_batch_size = 2
    max_train_steps = None
    checkpointing_steps = None
    output_dir='../output'
    lang="indonesian"

    # # pretrain
    # max_seq_length = 600  # 336
    # model_name_or_path = "Qwen/Qwen3-1.7B"
    # task_type='pretrain'
    # learning_rate=6.0e-3
    # alpha_kl = None
    # kl_temperature=0.0
    # retrieval_embed_length=1

    #  finetune
    learning_rate = 2.0e-5
    max_seq_length = 1620 # 1024
    use_rag_tuning = True
    alpha_kl= 2.0
    kl_temperature= 1.0 
    model_name_or_path="../output/pretrained/e5-small_qwen1,7B_batch2_Nonerows"
    
    task_type="finetune"
    retrieval_embed_length=3

args = Args()
print(args.task_type)
dataset = load_and_format_dataset(args.dataset_path, args.query_col, args.ans_col, args.psg_col, args.task_type, args.max_samples, include_psg_len=False)

if 'test' in dataset:
    dataset.pop('test')
if args.task_type == 'finetune':
    if 'dev' in dataset:
        dataset.pop('dev')

print("berhasil memuat dataset")

finetune
berhasil memuat dataset


In [4]:
# Loading model retriever
print('memuat retriever dan tokenizernya...')
retriever = E5Retriever(args.retriever_name_or_path)
retriever_tokenizer = AutoTokenizer.from_pretrained(args.retriever_name_or_path)
retriever_hidden_size = retriever.get_embed_dim()
retriever.eval()
retriever.to('cuda:0')

print('memuat tokenizer generatif...')
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
tokenizer.padding_side = 'left'

print('memuat model generatif...')
# config = XLlamaConfig.from_pretrained(args.model_name_or_path, retriever_hidden_size=retriever_hidden_size)
config = XQwen3Config.from_pretrained(args.model_name_or_path, retriever_hidden_size=retriever_hidden_size)
model = XQwen3ForCausalLM.from_pretrained(  # XLlamaForCausalLM
    args.model_name_or_path,
    config=config,
    torch_dtype = torch.bfloat16
).to("cuda:0")

# Mengatur pad_token pada tokenizer llama dengan token yang sudah ada dalam Llama
if tokenizer.pad_token_id is None:
    print('Menambahkan pad token ke tokenizer')

    if args.chat_format == 'llama':    
        pad_token = "<|finetune_right_pad_id|>"
        tokenizer.pad_token = pad_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
    elif args.chat_format == 'qwen':
        tokenizer.pad_token = tokenizer.eos_token

# Menambahkan token baru (xrag) ke perbendaharaan tokenizer llama
num_added_tokens = 0
num_added_tokens += tokenizer.add_tokens([AddedToken(XRAG_TOKEN,lstrip=False,rstrip=False)])
xrag_token_id = tokenizer.convert_tokens_to_ids(XRAG_TOKEN)
model.set_xrag_token_id(xrag_token_id)
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
vocab_size = len(tokenizer)

memuat retriever dan tokenizernya...


  return t.to(


memuat tokenizer generatif...
memuat model generatif...


In [5]:
if args.task_type == 'finetune':
    print('encode chat untuk finetune...')
    encode_function = partial(
        encode_with_chat_format_finetune,
        llm_tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        retrieval_context_length=args.retrieval_context_length,
        lang=args.lang, 
        use_rag_tuning = args.use_rag_tuning,
        use_retriever_embed = not (retriever is None),
        retriever_tokenizer = retriever_tokenizer,
        chat_format = args.chat_format,
    )

if args.task_type== 'pretrain':
    print('encode chat untuk pretraining')
    encode_function = partial(
        encode_with_chat_format_pretrain,
        llm_tokenizer = tokenizer,
        retriever_tokenizer = retriever_tokenizer, 
        max_seq_length = args.max_seq_length,
        retrieval_context_length=args.retrieval_context_length,
        retrieval_embed_length=args.retrieval_embed_length,
        chat_format = args.chat_format
    )

encode chat untuk finetune...


In [6]:
lm_datasets = dataset.map(encode_function)
lm_datasets.set_format(type="pt")

if args.task_type == 'finetune':
    print('membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...')
    lm_datasets['train'] = lm_datasets['train'].filter(lambda example: (example['labels'] != -100).any())
    if args.alpha_kl is not None and args.alpha_kl > 0.0:
        lm_datasets['train'] = lm_datasets['train'].filter(
            lambda example: 
            (example['labels']!=-100).sum() == (example['xrag_labels']!=-100).sum()
        )

train_dataset = lm_datasets['train']
dev_dataset = lm_datasets['dev'] if args.task_type == 'pretrain' else None

Map: 100%|██████████| 4542/4542 [00:37<00:00, 120.61 examples/s]


membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...


Filter: 100%|██████████| 4542/4542 [00:00<00:00, 4554.73 examples/s]
Filter: 100%|██████████| 4540/4540 [00:01<00:00, 3208.84 examples/s]


In [7]:
# Menambahkan retriever_embeddings ke dataset sebelum pelatihan
print('membuat embeddings untuk dokumen konteks dengan retriever...')
train_dataset = train_dataset.map(
    lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
)

if dev_dataset is not None:
    dev_dataset = dev_dataset.map(
        lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
    )

# Hapus objek retriever dan modelnya
print('Menghapus retriever...')
del retriever.model  # Menghapus model dari memori
del retriever  # Menghapus objek retriever itu sendiri
torch.cuda.empty_cache() 

membuat embeddings untuk dokumen konteks dengan retriever...


Map: 100%|██████████| 4540/4540 [01:38<00:00, 45.91 examples/s]

Menghapus retriever...





In [8]:
collate_fn = partial(
    collator,
    llm_tokenizer=tokenizer, 
    xrag_input_ids_col='xrag_input_ids',
    xrag_labels_col = 'xrag_labels', 
    text_input_ids_col = 'input_ids', 
    text_labels_col = 'labels', 
    retriever_embeds_col='retriever_embeddings'
)

print('Menginisialisasi dataloader untuk training...')
train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=args.per_device_train_batch_size
)


if dev_dataset is not None:
    print('Menginisialisasi dataloader untuk validasi...')
    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False, 
        collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size
    )


if args.update_projector_only:
    print('Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...')
    for n,p in model.named_parameters():
        if 'projector' not in n:p.requires_grad = False
        else:p.requires_grad = True

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad],lr=args.learning_rate) 

# Add learning rate scheduler
num_training_steps = args.num_train_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=int(num_training_steps * args.warmup_ratio)  # 3% warmup
)

Menginisialisasi dataloader untuk training...
Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...


In [9]:
# Inisialisasi list untuk menyimpan loss
accumulation_steps = 4

nll_train_losses = []
kl_train_losses = []
train_losses = []
dev_losses = []

progress_bar = tqdm(range(num_training_steps))

  0%|          | 0/6810 [00:00<?, ?it/s]

In [10]:
for epoch in range(args.num_train_epochs):
    model.train()
    epoch_train_loss = 0
    print("======"*12)  # Pembatas untuk setiap epoch
    print(f"Starting epoch {epoch+1}")
    for batch_idx, batch in enumerate(train_dataloader):
        progress_bar.set_postfix({'epoch': epoch+1, 'batch': batch_idx+1})
        progress_bar.update(1)
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()  # deindent jika ingin cancel batch accumulation  

        outputs = model(
            input_ids = batch['xrag_input_ids'],
            attention_mask = batch['xrag_attention_mask'],
            retrieval_embeds = batch['retriever_embeddings']
        )
        del batch['xrag_input_ids']
        del batch['xrag_attention_mask']
        del batch['retriever_embeddings']
        torch.cuda.empty_cache()
        logits = outputs.logits
        labels = batch['xrag_labels']

        nll_loss = get_nll_loss(logits=logits, labels=labels, vocab_size=vocab_size)

        loss = args.alpha_nll * nll_loss
        nll_train_losses.append(loss.item())

        if args.alpha_kl is not None and args.alpha_kl > 0.0:
            ## forward with retrieval tokens
            with torch.no_grad():
                model.eval()
                teacher_outputs = model(
                    input_ids = batch['input_ids'],
                    attention_mask = batch['attention_mask'],
                )
                del batch['input_ids']
                del batch['attention_mask']
                torch.cuda.empty_cache()
                model.train()

            kl_loss = get_kl_loss(
                teacher_logits=teacher_outputs.logits,
                teacher_labels=batch['labels'],
                student_logits=outputs.logits,
                student_labels=batch['xrag_labels'],
                temperature=args.kl_temperature,
            )
            kl_loss = args.alpha_kl * kl_loss
            kl_train_losses.append(kl_loss.item())
            loss += kl_loss
            
            del batch['labels']
            torch.cuda.empty_cache()
        del batch['xrag_labels']
        torch.cuda.empty_cache()

        # Simpan loss untuk tiap batch
        train_losses.append(loss.item())
        epoch_train_loss += loss.item()

        loss.backward()

        # Update parameter hanya setelah beberapa batch terakumulasi
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()  # deindent jika ingin cancel batch acc
            lr_scheduler.step()  # deindent jika ingin cancel batch acc
        # Jika menggunakan lebih banyak mini-batch sebelum update, pastikan gradien dihitung selama beberapa langkah
        if (batch_idx + 1) % accumulation_steps == 0:  # HAPUS JIKA INGIN CANCEL BATCH ACC
            optimizer.zero_grad()  # Reset gradien untuk batch berikutnya  # --"-- 

    # # Setelah setiap epoch selesai, lakukan validasi
    # if dev_dataloader is not None:
    #     print("------"*12)
    #     if args.task_type == 'pretrain':
    #         print(f"Validating after epoch {epoch+1}...")
    #         ppl = validate_during_pretrain(model, dev_dataloader, len(tokenizer))
    #         print(f"Perplexity on dev set after epoch {epoch+1}: {ppl}")

    #         dev_losses.append(ppl.item())

# Hitung rata-rata dari dev_losses
# average_dev_loss = sum(dev_losses) / len(dev_losses) if dev_losses else 0.0
# print(f"Average Perplexity on dev set: {average_dev_loss}")


  ret['retriever_embeddings'] = torch.stack([torch.tensor(x[retriever_embeds_col]).to('cuda:0') for x in samples])
  0%|          | 1/6810 [00:03<6:16:20,  3.32s/it, epoch=1, batch=1]

Starting epoch 1


 16%|█▌        | 1088/6810 [10:53<57:59,  1.64it/s, epoch=1, batch=1088]  



 18%|█▊        | 1202/6810 [11:59<52:12,  1.79it/s, epoch=1, batch=1202]  



 28%|██▊       | 1874/6810 [18:27<41:26,  1.99it/s, epoch=1, batch=1874]  



 33%|███▎      | 2271/6810 [22:26<38:42,  1.95it/s, epoch=2, batch=1]     

Starting epoch 2


 37%|███▋      | 2505/6810 [24:53<47:17,  1.52it/s, epoch=2, batch=235]  



 46%|████▋     | 3159/6810 [31:13<39:42,  1.53it/s, epoch=2, batch=889]



 63%|██████▎   | 4312/6810 [42:40<22:53,  1.82it/s, epoch=2, batch=2042]  



 67%|██████▋   | 4541/6810 [45:03<19:16,  1.96it/s, epoch=3, batch=1]     

Starting epoch 3


 80%|███████▉  | 5435/6810 [54:31<12:27,  1.84it/s, epoch=3, batch=895]  



 82%|████████▏ | 5576/6810 [55:51<12:56,  1.59it/s, epoch=3, batch=1036]



 95%|█████████▍| 6464/6810 [1:04:30<02:56,  1.96it/s, epoch=3, batch=1924]



100%|██████████| 6810/6810 [1:08:12<00:00,  1.69it/s, epoch=3, batch=2270]

In [11]:
import torch
# GPU yang digunakan
def check_gpu():
    device = torch.device("cuda:0")

    # Total memori GPU
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # Dalam GB
    print(f"Total GPU Memory: {total_memory:.2f} GB")

    # Memori yang sudah dialokasikan oleh PyTorch
    allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 3)  # Dalam GB
    print(f"Allocated GPU Memory: {allocated_memory:.2f} GB")

    max_reserved_memory = torch.cuda.max_memory_reserved(device) / (1024 ** 3)  # Dalam GB
    print(f"Max Reserved GPU Memory: {max_reserved_memory:.2f} GB")

    # Memori GPU yang dicadangkan oleh PyTorch
    reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 3)  # Dalam GB
    print(f"Reserved GPU Memory: {reserved_memory:.2f} GB")

    # Memori GPU yang tersedia
    free_memory = reserved_memory - allocated_memory
    print(f"Free GPU Memory: {free_memory:.2f} GB")

check_gpu()

Total GPU Memory: 6.00 GB
Allocated GPU Memory: 3.84 GB
Max Reserved GPU Memory: 6.73 GB
Reserved GPU Memory: 5.10 GB
Free GPU Memory: 1.25 GB


In [14]:
# # Lakukan validasi setelah setiap epoch
# dev_losses = []

# if dev_dataloader is not None:
#     if args.task_type == 'pretrain':
#         print(f"Validating...")
#         ppl = validate_during_pretrain(model, dev_dataloader, len(tokenizer))
#         print(f"Perplexity on dev set: {ppl}")

#         dev_losses.append(ppl.item())

# # Hitung rata-rata dari dev_losses
# average_dev_loss = sum(dev_losses) / len(dev_losses) if dev_losses else 0.0
# print(f"Average Perplexity on dev set: {average_dev_loss}")

In [12]:
if args.task_type == 'finetune':
    output_dir = os.path.join(args.output_dir, "finetuned")
    model_output_dir = os.path.join(output_dir, 'finished_model')

elif args.task_type == 'pretrain':
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = os.path.join(args.output_dir, current_time)
    output_dir = os.path.join(args.output_dir, 'pretrained')
    model_output_dir = os.path.join(output_dir, f"{args.retriever_name_or_path[-8:]}_{args.chat_format}{args.model_size}_batch{args.per_device_train_batch_size}_{args.max_samples}rows")

In [13]:
os.makedirs(model_output_dir, exist_ok=True)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('../output\\finetuned\\finished_model\\tokenizer_config.json',
 '../output\\finetuned\\finished_model\\special_tokens_map.json',
 '../output\\finetuned\\finished_model\\vocab.json',
 '../output\\finetuned\\finished_model\\merges.txt',
 '../output\\finetuned\\finished_model\\added_tokens.json',
 '../output\\finetuned\\finished_model\\tokenizer.json')

In [14]:
with open(os.path.join(output_dir, "train_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['step', 'train_loss'])
    for step, loss in enumerate(train_losses):
        writer.writerow([step, loss])

with open(os.path.join(output_dir, "nll_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['step', 'nll_loss'])
    for step, loss in enumerate(nll_train_losses):
        writer.writerow([step, loss])

In [18]:
if args.task_type == 'pretrain':
    with open(os.path.join(output_dir, "dev_loss.csv"), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['epoch', 'dev_loss'])
        for epoch, loss in enumerate(dev_losses):
            writer.writerow([epoch, loss])

In [15]:
if args.task_type == 'finetune':
    with open(os.path.join(output_dir, "kl_loss.csv"), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['step', 'kl_loss'])
        for step, loss in enumerate(kl_train_losses):
            writer.writerow([step, loss])