In [1]:
from preprocessing import XRAG_TOKEN, load_and_format_dataset, encode_with_chat_format_finetune, encode_with_chat_format_pretrain, add_retriever_embeddings, collator, load_config
from utils import get_nll_loss, get_kl_loss, validate_during_pretrain, validate_during_finetune
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler, set_seed
from tokenizers import AddedToken
from functools import partial
from tqdm import tqdm 
from datetime import datetime
import torch
import sys
import csv
import os
sys.path.append('..')
from model.E5Retriever import E5Retriever
from model.xQwen3 import XQwen3Config, XQwen3ForCausalLM

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
args = load_config("../config/finetune.yaml")

print("Fase latihan:", args.task_type)
dataset = load_and_format_dataset(args.dataset_path, args.query_col, args.ans_col, args.psg_col, args.task_type, args.max_samples)

if 'test' in dataset:
    dataset.pop('test')

print("berhasil memuat dataset")

Fase latihan: finetune
berhasil memuat dataset


In [None]:
set_seed(args.seed)
projector_ckpt = torch.load(args.projector_path)

  projector_ckpt = torch.load(args.projector_path)
  untyped_storage = torch.UntypedStorage(self.size(), device=device)


In [None]:
# Loading model retriever
print('memuat retriever dan tokenizernya...')
retriever = E5Retriever(args.retriever_name_or_path)
retriever_tokenizer = AutoTokenizer.from_pretrained(args.retriever_name_or_path)
retriever_hidden_size = retriever.get_embed_dim()
retriever.eval()
retriever.to('cuda:0')

print('memuat tokenizer generatif...')
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path
)
tokenizer.padding_side = 'left'

print('memuat model generatif...')
# TODO: jika finetuning, muat projector terpisah baru gabungkan ke xqwen3
config = XQwen3Config.from_pretrained(args.model_name_or_path, retriever_hidden_size=retriever_hidden_size)
if args.model_dtype == 'bfloat16':
    model_dtype = torch.bfloat16
elif args.model_dtype == 'float16':
    model_dtype = torch.float16
model = XQwen3ForCausalLM.from_pretrained(  # XLlamaForCausalLM
    args.model_name_or_path,
    config=config,
    torch_dtype = model_dtype
).to("cuda:0")

if args.task_type == 'finetune':
    model.load_state_dict(torch.load(args.projector_path), strict=False)

# Mengatur pad_token pada tokenizer llama dengan token yang sudah ada dalam Llama
if tokenizer.pad_token_id is None:
    print('Menambahkan pad token ke tokenizer')

    if args.chat_format == 'llama':    
        pad_token = "<|finetune_right_pad_id|>"
        tokenizer.pad_token = pad_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
    elif args.chat_format == 'qwen':
        tokenizer.pad_token = tokenizer.eos_token

# Menambahkan token baru (xrag) ke perbendaharaan tokenizer llama
num_added_tokens = 0
num_added_tokens += tokenizer.add_tokens([AddedToken(XRAG_TOKEN,lstrip=False,rstrip=False)])
xrag_token_id = tokenizer.convert_tokens_to_ids(XRAG_TOKEN)

model.set_xrag_token_id(xrag_token_id)
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
vocab_size = len(tokenizer)

memuat retriever dan tokenizernya...
memuat tokenizer generatif...
memuat model generatif...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.72it/s]
Some weights of XQwen3ForCausalLM were not initialized from the model checkpoint at Qwen/Qwen3-1.7B and are newly initialized: ['projector.projector.0.bias', 'projector.projector.0.weight', 'projector.projector.2.bias', 'projector.projector.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load(args.projector_path), strict=False)


In [None]:
if args.task_type == 'finetune':
    print('encode chat untuk finetune...')
    encode_function = partial(
        encode_with_chat_format_finetune,
        llm_tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        use_rag_tuning = args.use_rag_tuning,
        use_retriever_embed = not (retriever is None),
        chat_format = args.chat_format,
    )

if args.task_type== 'pretrain':
    print('encode chat untuk pretraining')
    encode_function = partial(
        encode_with_chat_format_pretrain,
        llm_tokenizer = tokenizer,
        max_seq_length = args.max_seq_length,
        retrieval_embed_length=args.retrieval_embed_length,
        chat_format = args.chat_format
    )

lm_datasets = dataset.map(encode_function)
lm_datasets.set_format(type="pt")

if args.task_type == 'finetune':
    print('membuang row yang seluruh labelsnya bernilai -100 (tidak ada porsi assistant sama sekali)...')
    for split in lm_datasets.keys():
        if "background" in lm_datasets[split].column_names:
            lm_datasets[split].remove_columns('background')
        lm_datasets[split] = lm_datasets[split].filter(lambda example: (example['labels'] != -100).any())
        if args.alpha_kl is not None and args.alpha_kl > 0.0:
            lm_datasets[split] = lm_datasets[split].filter(
                lambda example: 
                (example['labels']!=-100).sum() == (example['xrag_labels']!=-100).sum()
            )

train_dataset = lm_datasets['train']
dev_dataset = lm_datasets['dev'] # if args.task_type == 'pretrain' else None

encode chat untuk pretraining


Map: 100%|██████████| 15360/15360 [00:28<00:00, 537.47 examples/s]
Map: 100%|██████████| 1695/1695 [00:02<00:00, 613.55 examples/s]


In [None]:
from datasets import DatasetDict

# Menambahkan retriever_embeddings ke dataset sebelum pelatihan
print('membuat embeddings untuk dokumen konteks dengan retriever...')
train_dataset = train_dataset.map(
    lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
)

if dev_dataset is not None:
    dev_dataset = dev_dataset.map(
        lambda example: add_retriever_embeddings(example, retriever, retriever_tokenizer, args.retrieval_context_length, text_col='retriever_input_text')
    )

# Hapus objek retriever dan modelnya
print('Menghapus retriever...')
del retriever.model  
del retriever
torch.cuda.empty_cache() 

if args.save_embeddings_generated:
    ds_with_embeddings = DatasetDict({
        'train': train_dataset, 
        'dev': dev_dataset
    })
    ds_with_embeddings.save_to_disk(os.path.join(args.processing_steps_output_dir, f"{args.task_type}-create_embedding_step-learningrate1e-2"))

In [5]:
from datasets import load_from_disk
lm_datasets = load_from_disk('../../generated_data/xRAG-process/finetune-create_embedding_step')
train_dataset = lm_datasets['train']
dev_dataset = lm_datasets['dev']

In [6]:
collate_fn = partial(
    collator,
    llm_tokenizer=tokenizer, 
    xrag_input_ids_col='xrag_input_ids',
    xrag_labels_col = 'xrag_labels', 
    text_input_ids_col = 'input_ids', 
    text_labels_col = 'labels', 
    retriever_embeds_col='retriever_embeddings'
)

print('Menginisialisasi dataloader untuk training...')
train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=args.per_device_train_batch_size
)

if dev_dataset is not None:
    print('Menginisialisasi dataloader untuk validasi...')
    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False, 
        collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size
    )

if args.update_projector_only:
    print('Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...')
    for n,p in model.named_parameters():
        if 'projector' not in n:p.requires_grad = False
        else:p.requires_grad = True

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad],lr=args.learning_rate) 

Menginisialisasi dataloader untuk training...
Menginisialisasi dataloader untuk validasi...
Mengatur agar hanya layer yang menjadi bagian dr projector saja yang diupdate selama training...


In [7]:
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if args.task_type == 'finetune':
    output_dir = os.path.join(args.output_dir, "finetuned")
    output_dir = os.path.join(output_dir, current_time)
    model_output_dir = os.path.join(output_dir, 'finished_model')

elif args.task_type == 'pretrain':
    output_dir = os.path.join(args.output_dir, 'pretrained')
    output_dir = os.path.join(output_dir, current_time)
    model_output_dir = os.path.join(output_dir, f"{args.retriever_name_or_path[-8:]}_{args.chat_format}{args.model_size}_batch{args.per_device_train_batch_size}_{args.num_train_epochs}epoch_lr{args.learning_rate}")

# 2. Simpan loss ke log file di dalam output_dir/loss_logs
loss_log_dir = os.path.join(output_dir, "loss_logs")
os.makedirs(loss_log_dir, exist_ok=True)

def append_loss_to_file(filename, value):
    filepath = os.path.join(loss_log_dir, filename)
    with open(filepath, "a") as f:
        f.write(f"{value}\n")

In [None]:
# Add learning rate scheduler
num_training_steps = args.num_train_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=int(num_training_steps * args.warmup_ratio)  # 3% warmup
)

# Inisialisasi list untuk menyimpan loss
accumulation_steps = args.gradient_accumulation_steps

nll_train_losses = []
kl_train_losses = []
train_losses = []
dev_losses = []
epoch_avg_train_losses = []

early_stop_patience = 3
no_improve_train = 0
no_improve_dev = 0
best_train_loss = float("inf")
best_dev_loss = float("inf")

progress_bar = tqdm(range(num_training_steps))

for epoch in range(args.num_train_epochs):
    model.train()
    epoch_train_loss = 0
    print("======"*12)  # Pembatas untuk setiap epoch
    print(f"Starting epoch {epoch+1}")
    for batch_idx, batch in enumerate(train_dataloader):
        progress_bar.set_postfix({'epoch': epoch+1, 'batch': batch_idx+1})
        progress_bar.update(1)
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()  # deindent jika ingin cancel batch accumulation  

        outputs = model(
            input_ids = batch['xrag_input_ids'],
            attention_mask = batch['xrag_attention_mask'],
            retrieval_embeds = batch['retriever_embeddings']
        )
        del batch['xrag_input_ids']
        del batch['xrag_attention_mask']
        del batch['retriever_embeddings']
        torch.cuda.empty_cache()

        logits = outputs.logits
        labels = batch['xrag_labels']

        nll_loss = get_nll_loss(logits=logits, labels=labels, vocab_size=vocab_size)

        loss = args.alpha_nll * nll_loss
        nll_train_losses.append(loss.item())
        append_loss_to_file("nll_train_loss.txt", nll_train_losses[-1])

        if args.alpha_kl is not None and args.alpha_kl > 0.0:
            ## forward with retrieval tokens
            with torch.no_grad():
                model.eval()
                teacher_outputs = model(
                    input_ids = batch['input_ids'],
                    attention_mask = batch['attention_mask'],
                )
                del batch['input_ids']
                del batch['attention_mask']
                torch.cuda.empty_cache()
                model.train()

            kl_loss = get_kl_loss(
                teacher_logits=teacher_outputs.logits,
                teacher_labels=batch['labels'],
                student_logits=outputs.logits,
                student_labels=batch['xrag_labels'],
                temperature=args.kl_temperature,
            )
            kl_loss = args.alpha_kl * kl_loss
            kl_train_losses.append(kl_loss.item())

            if kl_train_losses:
                append_loss_to_file("kl_train_loss.txt", kl_train_losses[-1])

            loss += kl_loss
            
            
            del batch['labels']
            torch.cuda.empty_cache()
        del batch['xrag_labels']
        torch.cuda.empty_cache()

        # Simpan loss untuk tiap batch
        train_losses.append(loss.item())
        append_loss_to_file("train_loss.txt", train_losses[-1])
        epoch_train_loss += loss.item()

        loss.backward()

        # Update parameter hanya setelah beberapa batch terakumulasi
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()  
            lr_scheduler.step() 
            optimizer.zero_grad() 

    epoch_avg_loss = epoch_train_loss / len(train_dataloader)
    epoch_avg_train_losses.append(epoch_avg_loss)
    append_loss_to_file("epoch_avg_train_loss.txt", epoch_avg_train_losses[-1])

    # Setelah setiap epoch selesai, lakukan validasi
    if dev_dataset is not None:
        print("------"*12)
        if args.task_type == 'pretrain':
            print(f"Validating after epoch {epoch+1}...")
            ppl = validate_during_pretrain(model, dev_dataloader, len(tokenizer))
            print(f"Perplexity on dev set after epoch {epoch+1}: {ppl}")

            dev_losses.append(ppl.item())
        elif args.task_type == 'finetune':
            print(f"Validating after epoch {epoch+1}...")
            metrics = validate_during_finetune(model, dev_dataloader, vocab_size, args)
            print(f"weighted sum of KL and NLL loss after epoch {epoch+1}: {metrics['total_loss']}")
            dev_losses.append(metrics)

        append_loss_to_file("dev_loss.txt", dev_losses[-1])

    projector_dir = os.path.join(output_dir, "projector_checkpoints", f"epoch_{epoch+1}")
    os.makedirs(projector_dir, exist_ok=True)
    projector_state_dict = {
        k: v for k, v in model.state_dict().items()
        if k.startswith("projector.")
    }
    torch.save(projector_state_dict, os.path.join(projector_dir, "projector.pth"))

    # ======== Cek Early Stopping =========
    # 1. Cek train loss
    if epoch_avg_loss < best_train_loss:
        best_train_loss = epoch_avg_loss
        no_improve_train = 0
    else:
        no_improve_train += 1

    # 2. Cek dev loss (jika ada)
    if dev_dataset is not None:
        current_dev_loss = dev_losses[-1]['total_loss']
        if current_dev_loss < best_dev_loss:
            best_dev_loss = current_dev_loss
            no_improve_dev = 0
        else:
            no_improve_dev += 1

    # 3. Jika salah satu tidak membaik selama 3 epoch
    if no_improve_train >= early_stop_patience or no_improve_dev >= early_stop_patience:
        print("⛔ Pelatihan dihentikan karena tidak ada perbaikan pada loss selama 3 epoch berturut-turut.")
        break
    print(f"SELESAI MENYIMPAN PROJECTOR EPOCH {epoch+1}")
    print()

print("dev loss pertama :", dev_losses[0])
print("dev loss terakhir:", dev_losses[-1])

  ret['retriever_embeddings'] = torch.stack([torch.tensor(x[retriever_embeds_col]).to('cuda:0') for x in samples])
  0%|          | 0/25600 [00:00<?, ?it/s, epoch=1, batch=1]

Starting epoch 1


  1%|          | 281/25600 [02:38<4:04:03,  1.73it/s, epoch=1, batch=281]



  3%|▎         | 672/25600 [06:14<3:56:47,  1.75it/s, epoch=1, batch=672]



  4%|▎         | 956/25600 [09:28<3:31:45,  1.94it/s, epoch=1, batch=956] 



 10%|█         | 2560/25600 [25:17<4:05:42,  1.56it/s, epoch=1, batch=2560] 

------------------------------------------------------------------------
Validating after epoch 1...


 10%|█         | 2561/25600 [28:39<390:11:45, 60.97s/it, epoch=2, batch=1] 

weighted sum of KL and NLL loss after epoch 1: 6.094398741166078
SELESAI MENYIMPAN PROJECTOR EPOCH 1

Starting epoch 2


 13%|█▎        | 3446/25600 [37:12<3:15:32,  1.89it/s, epoch=2, batch=886] 



 15%|█▍        | 3824/25600 [40:48<3:30:59,  1.72it/s, epoch=2, batch=1264]



 16%|█▋        | 4173/25600 [44:11<3:25:42,  1.74it/s, epoch=2, batch=1613] 



 20%|██        | 5120/25600 [53:39<3:10:20,  1.79it/s, epoch=2, batch=2560] 

------------------------------------------------------------------------
Validating after epoch 2...


 20%|██        | 5121/25600 [56:12<262:54:39, 46.22s/it, epoch=3, batch=1] 

weighted sum of KL and NLL loss after epoch 2: 5.681288648409893
SELESAI MENYIMPAN PROJECTOR EPOCH 2

Starting epoch 3


 23%|██▎       | 5896/25600 [1:03:43<3:29:17,  1.57it/s, epoch=3, batch=776]



 25%|██▍       | 6300/25600 [1:07:26<3:08:18,  1.71it/s, epoch=3, batch=1180]



 29%|██▊       | 7301/25600 [1:17:08<2:37:44,  1.93it/s, epoch=3, batch=2181]



 30%|███       | 7680/25600 [1:21:20<3:03:33,  1.63it/s, epoch=3, batch=2560]

------------------------------------------------------------------------
Validating after epoch 3...


 30%|███       | 7681/25600 [1:23:16<175:05:03, 35.18s/it, epoch=4, batch=1] 

weighted sum of KL and NLL loss after epoch 3: 5.559235589664311
SELESAI MENYIMPAN PROJECTOR EPOCH 3

Starting epoch 4


 30%|███       | 7798/25600 [1:24:26<3:04:27,  1.61it/s, epoch=4, batch=118]



 31%|███       | 7937/25600 [1:25:47<2:29:01,  1.98it/s, epoch=4, batch=257]



 32%|███▏      | 8205/25600 [1:28:24<2:34:41,  1.87it/s, epoch=4, batch=525]



 40%|████      | 10240/25600 [1:47:36<8:07:50,  1.91s/it, epoch=4, batch=2560]

------------------------------------------------------------------------
Validating after epoch 4...


 40%|████      | 10241/25600 [1:53:23<449:27:07, 105.35s/it, epoch=5, batch=1]

weighted sum of KL and NLL loss after epoch 4: 5.516756846289752
SELESAI MENYIMPAN PROJECTOR EPOCH 4

Starting epoch 5


 42%|████▏     | 10727/25600 [1:58:11<2:30:45,  1.64it/s, epoch=5, batch=487] 



 43%|████▎     | 10959/25600 [2:00:22<2:02:40,  1.99it/s, epoch=5, batch=719]



 47%|████▋     | 11984/25600 [2:10:01<2:09:25,  1.75it/s, epoch=5, batch=1744]



 50%|█████     | 12800/25600 [2:18:26<1:48:13,  1.97it/s, epoch=5, batch=2560] 

------------------------------------------------------------------------
Validating after epoch 5...


 50%|█████     | 12801/25600 [2:23:30<324:55:36, 91.39s/it, epoch=6, batch=1] 

weighted sum of KL and NLL loss after epoch 5: 5.478108436395759
SELESAI MENYIMPAN PROJECTOR EPOCH 5

Starting epoch 6


 53%|█████▎    | 13461/25600 [2:29:52<2:00:16,  1.68it/s, epoch=6, batch=661]



 54%|█████▍    | 13767/25600 [2:32:48<1:56:23,  1.69it/s, epoch=6, batch=967]



 58%|█████▊    | 14874/25600 [2:43:07<1:35:13,  1.88it/s, epoch=6, batch=2074]



 60%|██████    | 15360/25600 [2:47:58<1:40:10,  1.70it/s, epoch=6, batch=2560]

------------------------------------------------------------------------
Validating after epoch 6...


 60%|██████    | 15361/25600 [2:51:28<180:19:00, 63.40s/it, epoch=7, batch=1] 

weighted sum of KL and NLL loss after epoch 6: 5.47192469081272
SELESAI MENYIMPAN PROJECTOR EPOCH 6

Starting epoch 7


 62%|██████▏   | 15953/25600 [2:57:08<1:34:18,  1.70it/s, epoch=7, batch=593]



 63%|██████▎   | 16173/25600 [2:59:16<1:25:40,  1.83it/s, epoch=7, batch=813]



 68%|██████▊   | 17296/25600 [3:10:03<1:11:57,  1.92it/s, epoch=7, batch=1936]



 70%|███████   | 17920/25600 [3:16:03<1:06:50,  1.92it/s, epoch=7, batch=2560]

------------------------------------------------------------------------
Validating after epoch 7...


 70%|███████   | 17921/25600 [3:17:56<73:35:37, 34.50s/it, epoch=8, batch=1]  

weighted sum of KL and NLL loss after epoch 7: 5.445761097614842
SELESAI MENYIMPAN PROJECTOR EPOCH 7

Starting epoch 8


 73%|███████▎  | 18654/25600 [3:24:53<1:03:37,  1.82it/s, epoch=8, batch=734]



 75%|███████▌  | 19311/25600 [3:31:05<51:55,  2.02it/s, epoch=8, batch=1391]  



 79%|███████▉  | 20227/25600 [3:40:11<50:57,  1.76it/s, epoch=8, batch=2307]  



 80%|████████  | 20480/25600 [3:42:32<49:18,  1.73it/s, epoch=8, batch=2560]  

------------------------------------------------------------------------
Validating after epoch 8...


 80%|████████  | 20481/25600 [3:44:29<50:19:09, 35.39s/it, epoch=9, batch=1]

weighted sum of KL and NLL loss after epoch 8: 5.452062168727915
SELESAI MENYIMPAN PROJECTOR EPOCH 8

Starting epoch 9


 81%|████████  | 20755/25600 [3:47:29<43:18,  1.86it/s, epoch=9, batch=275] 



 86%|████████▌ | 21924/25600 [3:58:59<35:42,  1.72it/s, epoch=9, batch=1444]  



 88%|████████▊ | 22586/25600 [4:05:12<26:44,  1.88it/s, epoch=9, batch=2106]  



 90%|█████████ | 23040/25600 [4:09:25<23:11,  1.84it/s, epoch=9, batch=2560]

------------------------------------------------------------------------
Validating after epoch 9...


 90%|█████████ | 23041/25600 [4:11:42<29:37:02, 41.67s/it, epoch=10, batch=1]

weighted sum of KL and NLL loss after epoch 9: 5.423572769434629
SELESAI MENYIMPAN PROJECTOR EPOCH 9

Starting epoch 10


 94%|█████████▍| 24014/25600 [4:20:55<14:36,  1.81it/s, epoch=10, batch=974] 



 97%|█████████▋| 24732/25600 [4:28:05<07:02,  2.06it/s, epoch=10, batch=1692] 



100%|█████████▉| 25522/25600 [4:35:31<00:45,  1.73it/s, epoch=10, batch=2482]



100%|██████████| 25600/25600 [4:36:15<00:00,  1.67it/s, epoch=10, batch=2560]

------------------------------------------------------------------------
Validating after epoch 10...
weighted sum of KL and NLL loss after epoch 10: 5.415387588339223
SELESAI MENYIMPAN PROJECTOR EPOCH 10

dev loss pertama : {'nll': 2.5380962897526502, 'ppl': 12.65555551297889, 'kl': 1.7781512257067138, 'total_loss': 6.094398741166078}
dev loss terakhir: {'nll': 2.2738929991166077, 'ppl': 9.717156155964403, 'kl': 1.5707472946113075, 'total_loss': 5.415387588339223}


RuntimeError: Could not infer dtype of dict

In [10]:
os.makedirs(model_output_dir, exist_ok=True)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\tokenizer_config.json',
 '../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\special_tokens_map.json',
 '../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\vocab.json',
 '../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\merges.txt',
 '../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\added_tokens.json',
 '../output\\finetuned\\2025-07-19_10-58-09\\finished_model\\tokenizer.json')