In [1]:
from preprocessing import XRAG_TOKEN, load_and_format_dataset, encode_with_chat_format_finetune, encode_with_chat_format_pretrain, add_retriever_embeddings, collator
from utils import get_nll_loss, get_kl_loss, validate_during_pretrain, validate_during_finetune
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler, set_seed
from tokenizers import AddedToken
from functools import partial
from tqdm import tqdm 
from datetime import datetime
import torch
import sys
import csv
import os
sys.path.append('..')
from model.E5Retriever import E5Retriever
from model.xQwen3 import XQwen3Config, XQwen3ForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from datasets import Dataset, load_dataset
dataset = load_dataset('khalidrizki/postretrieve-raw-dataset')
# Fungsi untuk mengekstrak teks dari ranked_truncPassages_with_labels
def extract_sorted_passages(row):
    # Mengambil ranked_truncPassages_with_labels yang sudah terurut
    passages = row['ranked_truncPassages_with_labels']
    
    # Mengambil teks dari setiap passage
    sorted_texts = [passage['text'] for passage in passages]
    
    return sorted_texts

# Menggunakan method map untuk menerapkan fungsi ke setiap row di dataset
dataset = dataset.map(lambda row: {'sorted_truncPassages': extract_sorted_passages(row)}, batched=False)

# Memeriksa hasilnya: Pastikan 'sorted_truncPassages' sudah ada di dataset
print(dataset.column_names)  # Untuk memastikan nama kolom yang tersedia
dataset = dataset.rename_column('answer', 'label')

Map: 100%|██████████| 5120/5120 [00:01<00:00, 4027.31 examples/s]
Map: 100%|██████████| 565/565 [00:00<00:00, 4110.96 examples/s]
Map: 100%|██████████| 565/565 [00:00<00:00, 2194.20 examples/s]

{'train': ['query_id', 'query', 'tydiqa_id', 'answer', 'passages', 'trunc_passages', 'ranked_truncPassages_with_labels', 'sorted_truncPassages'], 'dev': ['query_id', 'query', 'tydiqa_id', 'answer', 'passages', 'trunc_passages', 'ranked_truncPassages_with_labels', 'sorted_truncPassages'], 'test': ['query_id', 'query', 'tydiqa_id', 'answer', 'passages', 'trunc_passages', 'ranked_truncPassages_with_labels', 'sorted_truncPassages']}
3





In [2]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
class Args: 
    # dataset config
    dataset_path = "khalidrizki/postretrieve-research-raw-v2"
    query_col = 'query'
    ans_col = 'label'
    psg_col = 'sorted_truncPassages'
    max_samples = None

    # model config
    model_dtype = 'bfloat16'

    # trainer config
    seed = 42
    per_device_train_batch_size = 2
    gradient_accumulation_steps = 8
    lr_scheduler_type = "linear"
    warmup_ratio = 0.03
    alpha_nll = 1.0

    # retriever config
    retriever_name_or_path='intfloat/multilingual-e5-small'
    retrieval_context_length= 512  # 180

    # prompting config
    chat_format="qwen"

    # Output naming
    model_size = "1.7B"
    output_dir='../output'
    
    # Config unique to xRAG
    update_projector_only = True
    save_embeddings_generated = False
    processing_steps_output_dir = '../../generated_data/xRAG-process'
    
    # pretrain
    task_type='pretrain'
    model_name_or_path = "Qwen/Qwen3-1.7B"
    num_train_epochs = 10
    learning_rate=6.0e-3
    alpha_kl = None
    kl_temperature=0.0
    max_seq_length = 600  # 336
    retrieval_embed_length=1

    # #  finetune
    # task_type="finetune"
    # model_name_or_path="khalidrizki/xRAG-Qwen3-pretrained"
    # num_train_epochs = 3
    # learning_rate = 2.0e-5
    # alpha_kl= 2.0
    # kl_temperature= 1.0 
    # max_seq_length = 1620 # 1024
    # retrieval_embed_length=3
    # use_rag_tuning = True

args = Args()
print("Fase latihan:", args.task_type)
dataset = load_and_format_dataset(args.dataset_path, args.query_col, args.ans_col, args.psg_col, args.task_type, args.max_samples)

if 'test' in dataset:
    dataset.pop('test')

print("berhasil memuat dataset")

Fase latihan: finetune
berhasil memuat dataset


In [None]:
set_seed(args.seed)

# Loading model retriever
print('memuat retriever dan tokenizernya...')
retriever = E5Retriever(args.retriever_name_or_path)
retriever_tokenizer = AutoTokenizer.from_pretrained(args.retriever_name_or_path)
retriever_hidden_size = retriever.get_embed_dim()
retriever.eval()
retriever.to('cuda:0')

print('memuat tokenizer generatif...')
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
tokenizer.padding_side = 'left'

print('memuat model generatif...')
config = XQwen3Config.from_pretrained(args.model_name_or_path, retriever_hidden_size=retriever_hidden_size)
if args.model_dtype == 'bfloat16':
    model_dtype = torch.bfloat16
elif args.model_dtype == 'float16':
    model_dtype = torch.float16
model = XQwen3ForCausalLM.from_pretrained(  # XLlamaForCausalLM
    args.model_name_or_path,
    config=config,
    torch_dtype = model_dtype
).to("cuda:0")

# Mengatur pad_token pada tokenizer llama dengan token yang sudah ada dalam Llama
if tokenizer.pad_token_id is None:
    print('Menambahkan pad token ke tokenizer')

    if args.chat_format == 'llama':    
        pad_token = "<|finetune_right_pad_id|>"
        tokenizer.pad_token = pad_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
    elif args.chat_format == 'qwen':
        tokenizer.pad_token = tokenizer.eos_token

# Menambahkan token baru (xrag) ke perbendaharaan tokenizer llama
num_added_tokens = 0
num_added_tokens += tokenizer.add_tokens([AddedToken(XRAG_TOKEN,lstrip=False,rstrip=False)])
xrag_token_id = tokenizer.convert_tokens_to_ids(XRAG_TOKEN)
model.set_xrag_token_id(xrag_token_id)
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
vocab_size = len(tokenizer)

memuat retriever dan tokenizernya...


  return t.to(


memuat tokenizer generatif...
memuat model generatif...


In [None]:
if args.task_type == 'finetune':
    print('encode chat untuk finetune...')
    encode_function = partial(
        encode_with_chat_format_finetune,
        llm_tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        use_rag_tuning = args.use_rag_tuning,
        use_retriever_embed = not (retriever is None),
        chat_format = args.chat_format,
    )

if args.task_type== 'pretrain':
    print('encode chat untuk pretraining')
    encode_function = partial(
        encode_with_chat_format_pretrain,
        llm_tokenizer = tokenizer,
        max_seq_length = args.max_seq_length,
        retrieval_embed_length=args.retrieval_embed_length,
        chat_format = args.chat_format
    )

lm_datasets = dataset.map(encode_function)
lm_datasets.set_format(type="pt")

if args.task_type == 'finetune':
    print('membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...')
    for split in lm_datasets.keys():
        if "background" in lm_datasets[split].column_names:
            lm_datasets[split].remove_columns('background')
        lm_datasets[split] = lm_datasets[split].filter(lambda example: (example['labels'] != -100).any())
        if args.alpha_kl is not None and args.alpha_kl > 0.0:
            lm_datasets[split] = lm_datasets[split].filter(
                lambda example: 
                (example['labels']!=-100).sum() == (example['xrag_labels']!=-100).sum()
            )

train_dataset = lm_datasets['train']
dev_dataset = lm_datasets['dev'] # if args.task_type == 'pretrain' else None

encode chat untuk finetune...


Map: 100%|██████████| 2/2 [00:00<00:00,  7.63 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 45.80 examples/s]


membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...


Filter: 100%|██████████| 2/2 [00:00<00:00, 80.33 examples/s]
Filter: 100%|██████████| 2/2 [00:00<00:00, 126.80 examples/s]
Filter: 100%|██████████| 2/2 [00:00<00:00, 278.78 examples/s]
Filter: 100%|██████████| 2/2 [00:00<00:00, 266.36 examples/s]


In [8]:
from datasets import DatasetDict

# Menambahkan retriever_embeddings ke dataset sebelum pelatihan
print('membuat embeddings untuk dokumen konteks dengan retriever...')
train_dataset = train_dataset.map(
    lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
)

if dev_dataset is not None:
    dev_dataset = dev_dataset.map(
        lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
    )

# Hapus objek retriever dan modelnya
print('Menghapus retriever...')
del retriever.model  # Menghapus model dari memori
del retriever  # Menghapus objek retriever itu sendiri
torch.cuda.empty_cache() 

if args.save_embeddings_generated:
    ds_with_embeddings = DatasetDict({
        'train': train_dataset, 
        'dev': dev_dataset
    })
    ds_with_embeddings.save_to_disk(os.path.join(args.processing_steps_output_dir, f"{args.task_type}-create_embedding_step"))

membuat embeddings untuk dokumen konteks dengan retriever...


Map: 100%|██████████| 2/2 [00:01<00:00,  1.87 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00,  5.73 examples/s]

Menghapus retriever...





In [9]:
collate_fn = partial(
    collator,
    llm_tokenizer=tokenizer, 
    xrag_input_ids_col='xrag_input_ids',
    xrag_labels_col = 'xrag_labels', 
    text_input_ids_col = 'input_ids', 
    text_labels_col = 'labels', 
    retriever_embeds_col='retriever_embeddings'
)

print('Menginisialisasi dataloader untuk training...')
train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=args.per_device_train_batch_size
)

if dev_dataset is not None:
    print('Menginisialisasi dataloader untuk validasi...')
    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False, 
        collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size
    )

if args.update_projector_only:
    print('Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...')
    for n,p in model.named_parameters():
        if 'projector' not in n:p.requires_grad = False
        else:p.requires_grad = True

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad],lr=args.learning_rate) 

Menginisialisasi dataloader untuk training...
Menginisialisasi dataloader untuk validasi...
Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...


In [10]:
# Add learning rate scheduler
num_training_steps = args.num_train_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=int(num_training_steps * args.warmup_ratio)  # 3% warmup
)

In [11]:
# Inisialisasi list untuk menyimpan loss
accumulation_steps = args.gradient_accumulation_steps

nll_train_losses = []
kl_train_losses = []
train_losses = []
dev_losses = []
epoch_avg_train_losses = []

In [None]:
progress_bar = tqdm(range(num_training_steps))

for epoch in range(args.num_train_epochs):
    model.train()
    epoch_train_loss = 0
    print("======"*12)  # Pembatas untuk setiap epoch
    print(f"Starting epoch {epoch+1}")
    for batch_idx, batch in enumerate(train_dataloader):
        progress_bar.set_postfix({'epoch': epoch+1, 'batch': batch_idx+1})
        progress_bar.update(1)
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()  # deindent jika ingin cancel batch accumulation  

        outputs = model(
            input_ids = batch['xrag_input_ids'],
            attention_mask = batch['xrag_attention_mask'],
            retrieval_embeds = batch['retriever_embeddings']
        )
        del batch['xrag_input_ids']
        del batch['xrag_attention_mask']
        del batch['retriever_embeddings']
        torch.cuda.empty_cache()

        logits = outputs.logits
        labels = batch['xrag_labels']

        nll_loss = get_nll_loss(logits=logits, labels=labels, vocab_size=vocab_size)

        loss = args.alpha_nll * nll_loss
        nll_train_losses.append(loss.item())

        if args.alpha_kl is not None and args.alpha_kl > 0.0:
            ## forward with retrieval tokens
            with torch.no_grad():
                model.eval()
                teacher_outputs = model(
                    input_ids = batch['input_ids'],
                    attention_mask = batch['attention_mask'],
                )
                del batch['input_ids']
                del batch['attention_mask']
                torch.cuda.empty_cache()
                model.train()

            kl_loss = get_kl_loss(
                teacher_logits=teacher_outputs.logits,
                teacher_labels=batch['labels'],
                student_logits=outputs.logits,
                student_labels=batch['xrag_labels'],
                temperature=args.kl_temperature,
            )
            kl_loss = args.alpha_kl * kl_loss
            kl_train_losses.append(kl_loss.item())
            loss += kl_loss
            
            del batch['labels']
            torch.cuda.empty_cache()
        del batch['xrag_labels']
        torch.cuda.empty_cache()

        # Simpan loss untuk tiap batch
        train_losses.append(loss.item())
        epoch_train_loss += loss.item()

        loss.backward()

        # Update parameter hanya setelah beberapa batch terakumulasi
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()  # deindent jika ingin cancel batch acc
            lr_scheduler.step()  # deindent jika ingin cancel batch acc
        # Jika menggunakan lebih banyak mini-batch sebelum update, pastikan gradien dihitung selama beberapa langkah
        if (batch_idx + 1) % accumulation_steps == 0:  # HAPUS JIKA INGIN CANCEL BATCH ACC
            optimizer.zero_grad()  # Reset gradien untuk batch berikutnya  # --"-- 

    epoch_avg_loss = epoch_train_loss / len(train_dataloader)
    epoch_avg_train_losses.append(epoch_avg_loss)

    # Setelah setiap epoch selesai, lakukan validasi
    if dev_dataset is not None:
        print("------"*12)
        if args.task_type == 'pretrain':
            print(f"Validating after epoch {epoch+1}...")
            ppl = validate_during_pretrain(model, dev_dataloader, len(tokenizer))
            print(f"Perplexity on dev set after epoch {epoch+1}: {ppl}")

            dev_losses.append(ppl.item())
        if args.task_type == 'finetune':
            print(f"Validating after epoch {epoch+1}...")
            metrics = validate_during_finetune(model, dev_dataloader, vocab_size, args)
            print(f"weighted sum of KL and NLL loss after epoch {epoch+1}: {metrics['total_loss']}")
            dev_losses.append(metrics)

print("dev loss pertama :", dev_losses[0])
print("dev loss terakhir:", dev_losses[-1])

  ret['retriever_embeddings'] = torch.stack([torch.tensor(x[retriever_embeds_col]).to('cuda:0') for x in samples])
  0%|          | 0/3 [00:00<?, ?it/s, epoch=1, batch=1]

Starting epoch 1
------------------------------------------------------------------------
Validating after epoch 1...


 67%|██████▋   | 2/3 [01:05<00:32, 32.58s/it, epoch=2, batch=1]

weighted sum of KL and NLL loss after epoch 1: 16.582947731018066
Starting epoch 2


In [None]:
import torch
# GPU yang digunakan
def check_gpu():
    device = torch.device("cuda:0")

    # Total memori GPU
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # Dalam GB
    print(f"Total GPU Memory: {total_memory:.2f} GB")

    # Memori yang sudah dialokasikan oleh PyTorch
    allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 3)  # Dalam GB
    print(f"Allocated GPU Memory: {allocated_memory:.2f} GB")

    max_reserved_memory = torch.cuda.max_memory_reserved(device) / (1024 ** 3)  # Dalam GB
    print(f"Max Reserved GPU Memory: {max_reserved_memory:.2f} GB")

    # Memori GPU yang dicadangkan oleh PyTorch
    reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 3)  # Dalam GB
    print(f"Reserved GPU Memory: {reserved_memory:.2f} GB")

    # Memori GPU yang tersedia
    free_memory = reserved_memory - allocated_memory
    print(f"Free GPU Memory: {free_memory:.2f} GB")

check_gpu()

Total GPU Memory: 6.00 GB
Allocated GPU Memory: 4.28 GB
Max Reserved GPU Memory: 7.27 GB
Reserved GPU Memory: 6.22 GB
Free GPU Memory: 1.94 GB


In [None]:
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if args.task_type == 'finetune':
    output_dir = os.path.join(args.output_dir, "finetuned")
    output_dir = os.path.join(output_dir, current_time)
    model_output_dir = os.path.join(output_dir, 'finished_model')

elif args.task_type == 'pretrain':
    output_dir = os.path.join(args.output_dir, 'pretrained')
    output_dir = os.path.join(output_dir, current_time)
    model_output_dir = os.path.join(output_dir, f"{args.retriever_name_or_path[-8:]}_{args.chat_format}{args.model_size}_batch{args.per_device_train_batch_size}_{args.num_train_epochs}epoch")

In [None]:
os.makedirs(model_output_dir, exist_ok=True)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\tokenizer_config.json',
 '../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\special_tokens_map.json',
 '../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\vocab.json',
 '../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\merges.txt',
 '../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\added_tokens.json',
 '../output\\finetuned\\2025-06-05_14-58-31\\finished_model\\tokenizer.json')

In [None]:
with open(os.path.join(output_dir, "train_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['step', 'train_loss'])
    for step, loss in enumerate(train_losses):
        writer.writerow([step, loss])

with open(os.path.join(output_dir, "train_loss_per_epoch.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'avg_train_loss'])
    for epoch, loss in enumerate(epoch_avg_train_losses):
        writer.writerow([epoch, loss])


with open(os.path.join(output_dir, "nll_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['step', 'nll_loss'])
    for step, loss in enumerate(nll_train_losses):
        writer.writerow([step, loss])

In [None]:
# if args.task_type == 'pretrain':
with open(os.path.join(output_dir, "dev_loss.csv"), mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'dev_loss'])
    for epoch, loss in enumerate(dev_losses):
        writer.writerow([epoch, loss])

In [16]:
if args.task_type == 'finetune':
    with open(os.path.join(output_dir, "kl_loss.csv"), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['step', 'kl_loss'])
        for step, loss in enumerate(kl_train_losses):
            writer.writerow([step, loss])