# Task 3: Fine-tuning a Decoder-only LLM (Phi-2) for Summarization (XSum)

**UAS (final exam) target:** fine-tune a decoder-only model (Phi-2) to generate abstractive summaries on the XSum dataset.

**Resource note:** Phi-2 is relatively large (about 2.7B parameters). A common approach is PEFT/LoRA combined with 4-bit quantization so the model fits on a GPU with limited memory.
- This template provides an optional **LoRA + 4-bit** path.
- Full fine-tuning may require a GPU with much more memory.
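To make "much more memory" concrete, the back-of-envelope estimate below counts only the bytes needed to store parameters (and, for full fine-tuning, gradients and Adam state). It is a rough sketch: it assumes about 2.7e9 parameters and ignores activations, LoRA adapter weights, 4-bit quantization constants, and CUDA/framework overhead, so real usage is higher.

```python
GIB = 1024**3

def weights_gib(n_params: float, bytes_per_param: float) -> float:
    """Memory (GiB) needed just to hold n_params values at the given precision."""
    return n_params * bytes_per_param / GIB

N_PARAMS = 2.7e9  # approximate Phi-2 parameter count (assumption)

estimates = {
    "fp32 weights": weights_gib(N_PARAMS, 4),
    "fp16 weights": weights_gib(N_PARAMS, 2),
    "4-bit weights": weights_gib(N_PARAMS, 0.5),
    # Full fine-tune in fp32 with Adam: weights (4) + grads (4) + two Adam moments (8)
    "full FT, fp32 + Adam": weights_gib(N_PARAMS, 16),
}

for name, gib in estimates.items():
    print(f"{name:>22}: {gib:6.1f} GiB")
```

The gap between the 4-bit row and the full fine-tuning row is the main reason the LoRA + 4-bit path exists.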

Template date: 2026-01-05

## 0. Setup
**TODO:** make sure `bitsandbytes` is compatible with your environment (especially on Windows or a local machine).

If you cannot use 4-bit, you can:
- use LoRA without quantization (needs more VRAM)
- or use a smaller model (if allowed)
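A quick way to decide between these options is to check whether `bitsandbytes` imports cleanly and a CUDA GPU is visible. The helper below is a hypothetical heuristic, not part of the template: it assumes the 4-bit path needs CUDA (as in this notebook), and a successful import does not guarantee every kernel works on your setup.

```python
def can_use_4bit() -> bool:
    """Heuristic check before setting USE_4BIT_LORA = True."""
    try:
        import bitsandbytes  # noqa: F401  (may be missing or fail to import on some platforms)
        import torch
    except Exception:
        return False
    # The 4-bit path in this notebook assumes a CUDA GPU
    return torch.cuda.is_available()

print("4-bit available:", can_use_4bit())
```

If it prints `False`, set `USE_4BIT_LORA = False` in the configuration cell (or fall back to one of the options above).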

```python
import os
import random
import numpy as np
import torch

from datasets import load_dataset
import evaluate  # the "rouge" metric also needs the `rouge_score` package

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
    BitsAndBytesConfig,
)

# Optional (PEFT): only used when USE_4BIT_LORA = True
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

SEED = 42
set_seed(SEED)  # already seeds random, numpy and torch; explicit calls kept for clarity
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
```

## 1. Load dataset & model

**TODO:** make sure the Phi-2 model ID matches the HuggingFace Hub repository you use.

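The XSum columns used in this notebook (the dataset also has an `id` column):
- `document`: the news article
- `summary`: the one-sentence reference summary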

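```python
# Depending on your `datasets` version, the namespaced Hub ID "EdinburghNLP/xsum" may be needed
DATASET_NAME = "xsum"

# TODO: make sure the model ID matches what is available on the HuggingFace Hub
MODEL_NAME = "microsoft/phi-2"

MAX_LENGTH = 512         # max length of prompt + target after tokenization
MAX_DOC_CHARS = 4000     # character cap so documents do not get too long (TODO: adjust;
                         # at ~4 characters per token, 4000 chars already exceeds MAX_LENGTH)

LR = 2e-4
BATCH_SIZE = 2
EPOCHS = 1
GRAD_ACCUM = 8           # effective batch size per device = BATCH_SIZE * GRAD_ACCUM

USE_4BIT_LORA = True  # TODO: False = full fine-tuning in the code below (no LoRA-only branch)
```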

```python
dataset = load_dataset(DATASET_NAME)
print(dataset)
metric = evaluate.load("rouge")
```

## 2. Load tokenizer & model

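For a causal LM:
- we build the prompt `"Summarize the following article in 1-2 sentences.\n\nArticle:\n{document}\n\nSummary:"` (see `build_prompt` in section 3) and use `summary` as the target
- during training, the labels are usually identical to `input_ids`; the model shifts them internally so that each position predicts the next token

**TODO:** check `tokenizer.pad_token` (some LLMs have no pad token). This template reuses EOS as the pad token, so padding must not cause the real EOS at the end of each summary to be masked out of the labels.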
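The sketch below shows why reusing EOS as the pad token matters. `DataCollatorForLanguageModeling(mlm=False)` rebuilds labels from `input_ids` and sets every pad-id position to -100 (ignored by the loss), so when `pad_token == eos_token` the genuine EOS is masked as well and the model never learns when to stop. Padding a precomputed `labels` column with -100 instead (what `DataCollatorForSeq2Seq` does) keeps it. The token ids are made up; `EOS = 50256` is an assumed id, so check `tokenizer.eos_token_id`.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch cross-entropy and the HF causal-LM loss
EOS = 50256          # assumed EOS id; check tokenizer.eos_token_id

def labels_lm_collator(input_ids, pad_id):
    """Mimics DataCollatorForLanguageModeling(mlm=False): every pad-id position is masked."""
    return [IGNORE_INDEX if t == pad_id else t for t in input_ids]

def pad_with_ignore(input_ids, target_len, pad_id):
    """Pads inputs with pad_id but labels with IGNORE_INDEX, so a real EOS keeps its label."""
    n_pad = target_len - len(input_ids)
    return input_ids + [pad_id] * n_pad, input_ids + [IGNORE_INDEX] * n_pad

example = [11, 12, 13, EOS]   # prompt + summary tokens, then the real EOS
padded = example + [EOS] * 2  # padded to length 6 with pad_token == eos_token

masked = labels_lm_collator(padded, pad_id=EOS)
inputs, labels = pad_with_ignore(example, 6, pad_id=EOS)
print(masked)  # [11, 12, 13, -100, -100, -100]   (real EOS lost)
print(labels)  # [11, 12, 13, 50256, -100, -100]  (real EOS kept)
```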

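```python
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Many decoder-only tokenizers have no pad token; reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

if USE_4BIT_LORA:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
    )
    # Freezes the base weights, upcasts non-quantized parameters to fp32 and
    # enables gradient checkpointing by default
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules is not set, so PEFT uses its built-in default for this
        # architecture; set it explicitly if get_peft_model cannot infer it
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
    # Full fine-tuning: every parameter is trainable (needs far more VRAM)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
```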

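The `LoraConfig` above (r=16, lora_alpha=32) trains two small matrices per adapted layer instead of the full weight. The arithmetic below is a sketch for a single square projection; the hidden size of 2560 is an assumption for Phi-2 (check `model.config.hidden_size`), and the real total depends on which modules PEFT targets.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) next to a frozen d_out x d_in weight."""
    return r * d_in + d_out * r

HIDDEN = 2560      # assumed Phi-2 hidden size
R, ALPHA = 16, 32  # values from the LoraConfig above

full = HIDDEN * HIDDEN
lora = lora_trainable_params(HIDDEN, HIDDEN, R)
print(f"full weights: {full:,}  LoRA params: {lora:,}  ratio: {lora / full:.2%}")
print("update scaling alpha/r =", ALPHA / R)
```

With these numbers LoRA trains about 1.25% of the projection's parameters, and the learned update is scaled by `lora_alpha / r = 2.0` in standard LoRA.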
## 3. Preprocessing

**TODO:** for a faster run, you can use a subset of the dataset (e.g., 10k examples).
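One way to take a reproducible subset is to sample row indices with a seeded RNG. The helper `subset_indices` is hypothetical; the built-in `dataset["train"].shuffle(seed=SEED).select(range(10_000))` achieves the same thing.

```python
import random

def subset_indices(n_total: int, k: int, seed: int = 42) -> list[int]:
    """Reproducible random sample of k distinct row indices (sorted for sequential reads)."""
    k = min(k, n_total)
    return sorted(random.Random(seed).sample(range(n_total), k))

idx = subset_indices(n_total=1000, k=10)
print(idx)
```

Before calling `dataset.map(...)`, the training split could be replaced with `dataset["train"] = dataset["train"].select(subset_indices(len(dataset["train"]), 10_000, SEED))`.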

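```python
def build_prompt(doc: str) -> str:
    doc = doc[:MAX_DOC_CHARS]  # cheap character cap (also applied at generation time)
    return f"Summarize the following article in 1-2 sentences.\n\nArticle:\n{doc}\n\nSummary:"

# Number of tokens taken by the fixed template around the article
PROMPT_OVERHEAD = len(tokenizer(build_prompt(""), add_special_tokens=False)["input_ids"])

def preprocess_batch(examples):
    texts = []
    for d, s in zip(examples["document"], examples["summary"]):
        target = " " + s + tokenizer.eos_token
        n_target = len(tokenizer(target, add_special_tokens=False)["input_ids"])
        # Truncate the article (not the whole sequence) so summary + EOS always fit in
        # MAX_LENGTH; the 8-token margin absorbs re-tokenization differences
        doc_budget = max(MAX_LENGTH - PROMPT_OVERHEAD - n_target - 8, 0)
        doc_ids = tokenizer(d[:MAX_DOC_CHARS], add_special_tokens=False)["input_ids"][:doc_budget]
        prompt = build_prompt(tokenizer.decode(doc_ids))
        # Simple causal LM training: prompt + target + EOS as one sequence
        texts.append(prompt + target)

    tokenized = tokenizer(
        texts,
        truncation=True,      # final guard; normally not triggered anymore
        max_length=MAX_LENGTH,
        padding=False,        # padding happens per batch in the data collator
    )
    # labels = input_ids (standard causal LM; the model shifts them internally)
    tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]
    return tokenized

tokenized = dataset.map(preprocess_batch, batched=True, remove_columns=dataset["train"].column_names)
print(tokenized)
```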

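The template trains on the whole sequence (`labels = input_ids`), so the loss also covers the article text. A common alternative, not what this template does, is to compute the loss only on the summary by setting the prompt positions to -100. The function below is a hypothetical sketch on plain token-id lists.

```python
IGNORE_INDEX = -100

def mask_prompt_labels(input_ids: list[int], prompt_len: int) -> list[int]:
    """Labels that train only on the target: the first prompt_len positions are ignored."""
    prompt_len = min(prompt_len, len(input_ids))
    return [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]

# Hypothetical ids: 5 prompt tokens followed by 3 summary tokens (the last one is EOS)
ids = [101, 102, 103, 104, 105, 7, 8, 50256]
print(mask_prompt_labels(ids, prompt_len=5))  # [-100, -100, -100, -100, -100, 7, 8, 50256]
```

Inside `preprocess_batch`, `prompt_len` could be computed as `len(tokenizer(prompt, add_special_tokens=False)["input_ids"])`. Tokenizing the prompt alone usually matches its tokenization inside the full text because the target starts with a space, but it is worth spot-checking. If truncation leaves no summary tokens, that example contributes no loss at all.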
## 4. Trainer

**TODO:** set `output_dir` and logging.

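```python
from transformers import DataCollatorForSeq2Seq

# DataCollatorForSeq2Seq pads the precomputed "labels" column with -100 (ignored by the loss).
# DataCollatorForLanguageModeling cannot pad ragged precomputed labels, and because it masks
# every pad-id label it would also mask the real EOS here (pad_token == eos_token).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True, label_pad_token_id=-100)

args = TrainingArguments(
    output_dir="outputs",          # TODO: choose an output directory
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    eval_strategy="steps",         # called `evaluation_strategy` in transformers < 4.41
    eval_steps=500,
    save_steps=500,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to="none",              # TODO: e.g. "tensorboard" or "wandb" for logging
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    # The full validation split is large; e.g. tokenized["validation"].select(range(1000))
    # makes the periodic evaluation much cheaper
    eval_dataset=tokenized["validation"],
    processing_class=tokenizer,    # use `tokenizer=tokenizer` in transformers < 4.46
    data_collator=data_collator,
)
```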

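The `TrainingArguments` above imply a schedule that is easy to estimate by hand. The sketch below assumes one GPU and the commonly reported XSum train size of 204,045 articles (prefer `len(tokenized["train"])`); the Trainer's own rounding may differ by a step.

```python
import math

def schedule_summary(n_train: int, batch_size: int, grad_accum: int, n_gpus: int,
                     epochs: int, eval_steps: int) -> dict:
    """Rough count of optimizer steps and periodic evaluations."""
    effective_batch = batch_size * grad_accum * n_gpus
    optimizer_steps = math.ceil(n_train / effective_batch) * epochs
    return {
        "effective_batch": effective_batch,
        "optimizer_steps": optimizer_steps,
        "evaluations": optimizer_steps // eval_steps,
    }

plan = schedule_summary(204_045, batch_size=2, grad_accum=8, n_gpus=1, epochs=1, eval_steps=500)
print(plan)  # {'effective_batch': 16, 'optimizer_steps': 12753, 'evaluations': 25}
```

With `eval_steps=500`, that is roughly 25 passes over the validation split (about 11k articles each), which is why evaluating on a subset saves a lot of time.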
## 5. Training

**TODO:** run training.

```python
# trainer.train()
```

## 6. ROUGE Evaluation (after training)

Summarization is usually evaluated by:
- generating a summary from the prompt
- comparing it with the reference summary (ROUGE)

**TODO:** run this cell after training (possibly on a subset to keep it fast).
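To see what ROUGE measures, here is a simplified ROUGE-1 F1: clipped unigram overlap between prediction and reference. This is only an illustration; it uses lowercase whitespace tokens and no stemming, so its numbers will not exactly match `evaluate.load("rouge")`.

```python
from collections import Counter

def rouge1_f(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: lowercase whitespace tokens, clipped unigram overlap."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # & keeps the minimum count per token
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

print(rouge1_f("the cat sat on the mat", "the cat is on the mat"))  # ≈ 0.833 (5 of 6 unigrams match)
```

ROUGE-2 applies the same idea to bigrams, and ROUGE-L uses the longest common subsequence.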

```python
# def generate_summary(doc: str, max_new_tokens: int = 64) -> str:
#     model.eval()  # disable dropout for deterministic decoding
#     prompt = build_prompt(doc)
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         out = model.generate(
#             **inputs,
#             max_new_tokens=max_new_tokens,
#             do_sample=False,
#             num_beams=4,
#             pad_token_id=tokenizer.pad_token_id,
#         )
#     # Decode only the newly generated tokens; slicing off the prompt is more robust
#     # than splitting on "Summary:", which may also occur inside the article
#     new_tokens = out[0][inputs["input_ids"].shape[1]:]
#     return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# # Quick eval on a small validation subset
# n = 200
# preds, refs = [], []
# for ex in dataset["validation"].select(range(n)):
#     preds.append(generate_summary(ex["document"]))
#     refs.append(ex["summary"])

# rouge = metric.compute(predictions=preds, references=refs)
# print(rouge)
```

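A base decoder-only model may keep writing after the summary (for example, starting a new "Article:" block) until `max_new_tokens` runs out, which hurts ROUGE precision. A simple, hypothetical post-processing step keeps only the first non-empty line of the generated text. This is a heuristic choice, not part of the template; it suits XSum because its references are single sentences.

```python
def postprocess_summary(generated: str) -> str:
    """Keep only the first line of the generated continuation (empty input gives "")."""
    text = generated.strip()
    return text.split("\n", 1)[0].strip()

print(postprocess_summary(" A storm hit the coast.\n\nArticle:\nMore text..."))  # A storm hit the coast.
```

It could be applied to the string returned by `generate_summary` before it is appended to `preds`.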
## 7. Analysis

**TODO:** fill `reports/` with:
- ROUGE scores
- examples of good vs. bad summaries
- a discussion of abstractive vs. extractive summarization
- issues with truncation and document length
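For the abstractive vs. extractive discussion, one common measure is the share of summary n-grams that never appear in the source article. A purely extractive summary scores near 0, and a heavily rewritten one scores higher. The sketch below is illustrative: it uses lowercase whitespace tokens and sets of n-grams, so repeated n-grams are counted once.

```python
def ngrams(tokens, n):
    """Set of n-grams (as tuples) in a token list."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def novel_ngram_ratio(document: str, summary: str, n: int = 2) -> float:
    """Share of summary n-grams absent from the document (higher = more abstractive)."""
    summary_ngrams = ngrams(summary.lower().split(), n)
    if not summary_ngrams:
        return 0.0
    doc_ngrams = ngrams(document.lower().split(), n)
    return len(summary_ngrams - doc_ngrams) / len(summary_ngrams)

print(novel_ngram_ratio("a b c d", "a b x"))  # 0.5: ("b", "x") is new, ("a", "b") is copied
```

Comparing this ratio for the model's outputs against the reference summaries (with the same articles) shows whether fine-tuning made the model more or less extractive than the XSum targets.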