# Train LoRA for Vietnamese Summarization

## Simple workflow:
1. Load data
2. Generate stage-1 summaries with the trained ViT5 model (called "mT5" throughout the code)
3. Train LoRA to rewrite mT5 output → human quality
4. Evaluate the results

## Step 1: Installation

In [1]:
!pip install -q transformers datasets peft bitsandbytes accelerate evaluate rouge_score tqdm

In [2]:
import torch
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

CUDA: False


## Step 2: Config

In [3]:
# Paths
STAGE1_CHECKPOINT = 'vit5_final'  # Trained ViT5 model
STAGE2_MODEL = 'Qwen/Qwen2.5-7B-Instruct'  # LLM used for rewriting

TRAIN_DATA = 'data/train.csv'
VAL_DATA = 'data/validation.csv'
TEST_DATA = 'data/test.csv'

OUTPUT_DIR = './lora_rewriter'

# Training config (adjust to your GPU)
EPOCHS = 3
BATCH_SIZE = 4  # RTX 3090: 8, RTX 4070: 4, RTX 3060: 2
LEARNING_RATE = 2e-4

# Use the full dataset by default
MAX_TRAIN = None
MAX_VAL = None
# For a quick test, uncomment these lines (they must come after the defaults above):
# MAX_TRAIN = 1000
# MAX_VAL = 100

## Step 3: Load Data

In [4]:
import pandas as pd

train_df = pd.read_csv(TRAIN_DATA)
val_df = pd.read_csv(VAL_DATA)
test_df = pd.read_csv(TEST_DATA)

if MAX_TRAIN:
    train_df = train_df.head(MAX_TRAIN)
if MAX_VAL:
    val_df = val_df.head(MAX_VAL)

print(f"Train: {len(train_df):,} samples")
print(f"Val: {len(val_df):,} samples")
print(f"Test: {len(test_df):,} samples")

# Sample
print("\nSample:")
print(f"Doc: {train_df.iloc[0]['document'][:150]}...")
print(f"Summary: {train_df.iloc[0]['summary']}")

Train: 15,620 samples
Val: 1,952 samples
Test: 1,953 samples

Sample:
Doc: Lá N của cây N lô hội N chứa V đầy A chất N gel N và bạn N có thể hái V mỗi khi N cần V . Nên V để khi N nào dùng V mới hái V . Cắt N một nhánh N từ c...
Summary: Lô hội, với chất gel giàu dưỡng chất, có thể sử dụng để chữa lành các vấn đề về da như bỏng nắng, gàu và da khô. Bạn có thể sử dụng lá lô hội tươi để lấy gel, bôi trực tiếp lên da bị tổn thương. Lưu ý, gel lô hội không nên bôi lên vùng da bị chảy máu hoặc tổn thương nặng. Lô hội cũng có thể được dùng để trị mụn rộp và thay thế lotion dưỡng ẩm.


## Step 4: Generate mT5 Summaries

Use the trained ViT5 model (referred to as "mT5" in the code) to generate summaries for the entire dataset.

In [5]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm

def generate_summaries(documents, model_path, batch_size=8):
    """
    Generate summaries using trained ViT5
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(f"Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    
    summaries = []
    
    with torch.no_grad():
        for i in tqdm(range(0, len(documents), batch_size)):
            batch = documents[i:i+batch_size]
            
            inputs = tokenizer(
                ["tóm tắt: " + doc for doc in batch],
                max_length=512,
                truncation=True,
                padding=True,
                return_tensors="pt"
            ).to(device)
            
            outputs = model.generate(
                **inputs,
                max_length=128,
                num_beams=4,
                early_stopping=True
            )
            
            batch_sums = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            summaries.extend(batch_sums)
    
    del model
    torch.cuda.empty_cache()
    
    return summaries

In [6]:
# Generate for the train set
print("\nGenerating train summaries...")
train_mt5 = generate_summaries(train_df['document'].tolist(), STAGE1_CHECKPOINT)
print(f"✅ Done: {len(train_mt5)} summaries")

# Sample
print("\nSamples:")
for i in range(3):
    print(f"\n{i+1}.")
    print(f"Doc: {train_df.iloc[i]['document'][:100]}...")
    print(f"mT5: {train_mt5[i]}")
    print(f"Human: {train_df.iloc[i]['summary']}")


Generating train summaries...
Loading model from: vit5_final


  4%|▎         | 73/1953 [32:47<14:04:41, 26.96s/it]


KeyboardInterrupt: 
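
The progress bar above counts batches, not documents: with the default `batch_size=8`, the 15,620 training documents become 1,953 batches, and at roughly 27 s per batch on CPU the remaining time is about 14 hours. A minimal sketch of that arithmetic (the helper names are illustrative and not part of the notebook):

```python
import math

def num_batches(n_docs: int, batch_size: int) -> int:
    """Number of iterations of range(0, n_docs, batch_size)."""
    return math.ceil(n_docs / batch_size)

def eta_hours(n_batches: int, seconds_per_batch: float, done: int = 0) -> float:
    """Remaining wall-clock time in hours at a constant per-batch rate."""
    return (n_batches - done) * seconds_per_batch / 3600

total = num_batches(15_620, 8)
print(total)                                        # 1953
print(round(eta_hours(total, 26.96, done=73), 1))   # 14.1
```

Since this run reported `CUDA: False`, running Stage 1 on a GPU, or setting `MAX_TRAIN` for a first pass, is the practical way to shorten this step.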

In [None]:
# Generate for the validation set
print("\nGenerating val summaries...")
val_mt5 = generate_summaries(val_df['document'].tolist(), STAGE1_CHECKPOINT)
print(f"✅ Done: {len(val_mt5)} summaries")

## Step 5: Create Training Data for LoRA

Format: (document + mT5_summary) → human_summary

In [None]:
from datasets import Dataset

def create_prompt(doc, mt5_sum, human_sum=None):
    """Create training prompt"""
    doc_short = doc[:500] + "..." if len(doc) > 500 else doc
    
    prompt = f"""Bạn là chuyên gia viết lại văn bản tiếng Việt. Nhiệm vụ: cải thiện bản tóm tắt sau.

Yêu cầu:
- Giữ nguyên thông tin và ý nghĩa
- Cải thiện sự tự nhiên và mạch lạc
- Sử dụng từ ngữ phù hợp tiếng Việt
- Ngắn gọn, súc tích

VĂN BẢN GỐC:
{doc_short}

TÓM TẮT CẦN VIẾT LẠI:
{mt5_sum}

TÓM TẮT ĐÃ CẢI THIỆN:
"""
    if human_sum:
        prompt += human_sum
    
    return prompt

# Create datasets
train_examples = [
    {"text": create_prompt(doc, mt5, human)}
    for doc, mt5, human in zip(
        train_df['document'].tolist(),
        train_mt5,
        train_df['summary'].tolist()
    )
]

val_examples = [
    {"text": create_prompt(doc, mt5, human)}
    for doc, mt5, human in zip(
        val_df['document'].tolist(),
        val_mt5,
        val_df['summary'].tolist()
    )
]

train_dataset = Dataset.from_list(train_examples)
val_dataset = Dataset.from_list(val_examples)

print(f"Train examples: {len(train_dataset)}")
print(f"Val examples: {len(val_dataset)}")

print("\nSample prompt:")
print(train_dataset[0]['text'][:500] + "...")
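
The same template serves both phases: with the human summary appended it is a training example, and without it the text ends at the final header so the model continues from there. A minimal sketch with a shortened English stand-in for the template (`build_example` and the header strings are illustrative, not the notebook's `create_prompt`):

```python
from typing import Optional

MAX_DOC_CHARS = 500  # same character budget the notebook uses for the document

def build_example(doc: str, draft: str, target: Optional[str] = None) -> str:
    """Return an inference prompt, or a training text when `target` is given."""
    doc_short = doc[:MAX_DOC_CHARS] + "..." if len(doc) > MAX_DOC_CHARS else doc
    prompt = (
        f"SOURCE TEXT:\n{doc_short}\n\n"
        f"SUMMARY TO REWRITE:\n{draft}\n\n"
        "IMPROVED SUMMARY:\n"
    )
    return prompt + target if target else prompt

doc = "x" * 600
inference_text = build_example(doc, "draft summary")
training_text = build_example(doc, "draft summary", "human summary")

# The training text is the inference prompt followed by the target
print(training_text.startswith(inference_text))  # True
print(training_text[len(inference_text):])       # human summary
```

Note that the 500-character cut is applied to characters, not tokens, so the tokenized prompt length still varies from document to document.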

## Step 6: Load the LLM with 4-bit Quantization

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

print(f"Loading LLM: {STAGE2_MODEL}")
print("Using 4-bit quantization...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    STAGE2_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(STAGE2_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Model loaded")
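
To see why 4-bit loading matters here, a rough weights-only estimate helps. The sketch below assumes about 7.6 billion parameters for the 7B model (an approximation, not a figure from this notebook) and ignores quantization constants, activations, and optimizer state:

```python
def weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GiB."""
    return n_params * bits_per_param / 8 / 1024**3

N_PARAMS = 7.6e9  # assumed parameter count for a "7B" model

print(f"fp16 : {weight_memory_gib(N_PARAMS, 16):.1f} GiB")
print(f"4-bit: {weight_memory_gib(N_PARAMS, 4):.1f} GiB")
```

Under that assumption the weights shrink from roughly 14 GiB to roughly 3.5 GiB, which is what leaves room for activations and LoRA gradients on the consumer GPUs listed in the config.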

## Step 7: Apply LoRA

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

print("Applying LoRA...")

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

print("\n📊 Trainable parameters:")
model.print_trainable_parameters()
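
`print_trainable_parameters()` reports the adapter size. The count can also be estimated by hand, since LoRA adds two matrices of shapes `(r, d_in)` and `(d_out, r)` per targeted projection. The layer dimensions below are assumed values for a Qwen2.5-7B-style architecture and should be checked against the model's `config.json`:

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)

# Assumed dimensions (not taken from the notebook): hidden size 3584,
# 28 layers, key/value projection width 512 due to grouped-query attention
HIDDEN, KV_DIM, LAYERS, R = 3584, 512, 28, 16

per_layer = (
    lora_params(HIDDEN, HIDDEN, R)    # q_proj
    + lora_params(HIDDEN, KV_DIM, R)  # k_proj
    + lora_params(HIDDEN, KV_DIM, R)  # v_proj
    + lora_params(HIDDEN, HIDDEN, R)  # o_proj
)
print(f"{per_layer * LAYERS:,} trainable parameters")  # 10,092,544 under these assumptions

# The adapter update is scaled by lora_alpha / r
print(32 / 16)  # 2.0
```

If the printed count from PEFT differs, the assumed dimensions are the first thing to check.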

## Step 8: Tokenize Data

In [None]:
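def tokenize_fn(examples):
    # Append EOS so the model learns where the improved summary ends
    texts = [text + tokenizer.eos_token for text in examples["text"]]
    return tokenizer(
        texts,
        truncation=True,
        max_length=1024,
        padding="max_length"
    )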

print("Tokenizing...")
tokenized_train = train_dataset.map(tokenize_fn, batched=True)
tokenized_val = val_dataset.map(tokenize_fn, batched=True)
print("✅ Done")
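
Padding every example to 1,024 tokens means most positions in a short example are padding. With `mlm=False`, the collator used in the next step copies `input_ids` into `labels` and replaces padding positions with `-100` so the loss ignores them. A pure-Python sketch of that behaviour with toy token ids (`PAD_ID` and the ids are made up for illustration):

```python
from typing import List

PAD_ID = 0           # hypothetical pad token id
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def pad_to_length(ids: List[int], max_length: int) -> List[int]:
    """Right-pad (or truncate) a token id list to exactly max_length."""
    return (ids + [PAD_ID] * max_length)[:max_length]

def causal_lm_labels(input_ids: List[int]) -> List[int]:
    """Copy input ids as labels, masking padding positions."""
    return [IGNORE_INDEX if tok == PAD_ID else tok for tok in input_ids]

input_ids = pad_to_length([11, 12, 13], max_length=6)
print(input_ids)                    # [11, 12, 13, 0, 0, 0]
print(causal_lm_labels(input_ids))  # [11, 12, 13, -100, -100, -100]
```

Because masking is done by token id, a tokenizer whose pad token is the same as its EOS token would also have its end-of-sequence positions masked.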

## Step 9: Train LoRA

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

print("🚀 Starting training...")
print("Expected time: ~2-3 hours")
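
With gradient accumulation, the optimizer sees `BATCH_SIZE * 4 = 16` examples per update on a single GPU. A rough sketch of the resulting schedule for the full 15,620-example training set (an estimate; the exact count reported by `Trainer` may differ slightly depending on version and rounding):

```python
import math

def optimizer_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Approximate number of optimizer updates on one device."""
    batches_per_epoch = math.ceil(n_examples / batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return updates_per_epoch * epochs

total = optimizer_steps(15_620, batch_size=4, grad_accum=4, epochs=3)
print(total)         # 2931
print(total // 200)  # 14 evaluation/checkpoint points at eval_steps=200
print(100 / total)   # fraction of training covered by the 100 warmup steps
```

With `MAX_TRAIN = 1000` the same formula gives far fewer updates, so `warmup_steps`, `eval_steps`, and `save_steps` may need to be scaled down for quick tests.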

In [None]:
# Train!
trainer.train()

In [None]:
# Save
print("Saving model...")
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✅ Saved to: {OUTPUT_DIR}")

## Step 10: Evaluate

In [None]:
# Generate test summaries
print("Generating test mT5 summaries...")
test_mt5 = generate_summaries(test_df['document'].tolist()[:100], STAGE1_CHECKPOINT)
print(f"✅ Done: {len(test_mt5)} summaries")

In [None]:
from peft import PeftModel

def rewrite_with_lora(documents, mt5_summaries, lora_path):
    """Rewrite summaries using trained LoRA"""
    print("Loading LoRA model...")
    
    base = AutoModelForCausalLM.from_pretrained(
        STAGE2_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    
    model = PeftModel.from_pretrained(base, lora_path)
    tok = AutoTokenizer.from_pretrained(lora_path)
    
    rewritten = []
    
    # Pass total so tqdm can show progress and an ETA for the zipped iterator
    for doc, mt5_sum in tqdm(zip(documents, mt5_summaries), total=len(mt5_summaries)):
        prompt = create_prompt(doc, mt5_sum)
        
        inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.3,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tok.eos_token_id
            )
        
        # Decode only the newly generated tokens, so the result never contains
        # the prompt and does not depend on matching the marker text
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        result = tok.decode(new_tokens, skip_special_tokens=True).strip()
        
        rewritten.append(result)
    
    del model
    del base
    torch.cuda.empty_cache()
    
    return rewritten

# Rewrite test summaries
print("\nRewriting with LoRA...")
test_rewritten = rewrite_with_lora(
    test_df['document'].tolist()[:100],
    test_mt5,
    OUTPUT_DIR
)
print(f"✅ Done: {len(test_rewritten)} summaries")

In [None]:
import evaluate

rouge = evaluate.load("rouge")

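# The default ROUGE tokenizer keeps only [a-z0-9], which discards Vietnamese
# letters with diacritics, so split on whitespace instead
vi_tokenizer = lambda text: text.lower().split()

# Evaluate mT5 only
mt5_scores = rouge.compute(
    predictions=test_mt5,
    references=test_df['summary'].tolist()[:100],
    tokenizer=vi_tokenizer
)

# Evaluate mT5 + LoRA
lora_scores = rouge.compute(
    predictions=test_rewritten,
    references=test_df['summary'].tolist()[:100],
    tokenizer=vi_tokenizer
)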

print("\n" + "="*80)
print("📊 RESULTS")
print("="*80)

print("\nStage 1 (mT5 only):")
print(f"  ROUGE-1: {mt5_scores['rouge1']:.4f}")
print(f"  ROUGE-2: {mt5_scores['rouge2']:.4f}")
print(f"  ROUGE-L: {mt5_scores['rougeL']:.4f}")

print("\nStage 2 (mT5 + LoRA):")
print(f"  ROUGE-1: {lora_scores['rouge1']:.4f} ({lora_scores['rouge1'] - mt5_scores['rouge1']:+.4f})")
print(f"  ROUGE-2: {lora_scores['rouge2']:.4f} ({lora_scores['rouge2'] - mt5_scores['rouge2']:+.4f})")
print(f"  ROUGE-L: {lora_scores['rougeL']:.4f} ({lora_scores['rougeL'] - mt5_scores['rougeL']:+.4f})")

improvement = (lora_scores['rougeL'] - mt5_scores['rougeL']) / mt5_scores['rougeL'] * 100
print(f"\n✨ Improvement: {improvement:+.1f}%")
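
For intuition about what these numbers measure, ROUGE-1 is the F1 score over overlapping unigrams. The sketch below is a simplified whitespace-token version, not the `evaluate` implementation, and also shows the relative-improvement formula used above:

```python
from collections import Counter

def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 on lowercased whitespace tokens (simplified ROUGE-1)."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

def relative_improvement(new: float, old: float) -> float:
    """Percentage change of `new` relative to `old`."""
    return (new - old) / old * 100

reference = "lô hội chữa bỏng nắng và da khô"
print(rouge1_f1("lô hội chữa da khô", reference))
print(relative_improvement(0.44, 0.40))
```

In the toy pair, all five predicted tokens appear in the eight-token reference, so precision is 1.0 and recall is 0.625. A short but accurate summary is penalised only through recall.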

In [None]:
# Show samples
print("\n" + "="*80)
print("📝 SAMPLE COMPARISONS")
print("="*80)

for i in range(5):
    print(f"\n{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")
    
    print(f"\n📄 Original:")
    print(test_df.iloc[i]['document'][:200] + "...")
    
    print(f"\n📝 Stage 1 (mT5):")
    print(test_mt5[i])
    
    print(f"\n✨ Stage 2 (LoRA):")
    print(test_rewritten[i])
    
    print(f"\n👤 Human:")
    print(test_df.iloc[i]['summary'])

## ✅ Done!

### Results:
- The trained LoRA adapter is saved to `./lora_rewriter/`
- ROUGE scores are expected to improve by roughly 5-10%; the Step 10 output gives the measured change
- The rewritten text reads more naturally

### Using in production:

```python
from mt5_llm_lora_pipeline import MT5_LLM_Summarizer

summarizer = MT5_LLM_Summarizer(
    stage1_model='./vit5_vi_sum/checkpoint-best',
    stage2_model='Qwen/Qwen2.5-7B-Instruct',
    lora_checkpoint='./lora_rewriter'
)

result = summarizer.summarize(text, use_stage2=True)
print(result['final'])
```
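
The `mt5_llm_lora_pipeline` module is not shown here. As a rough idea of the shape such a wrapper could take, the sketch below chains two callables and returns a dict with a `final` key like the snippet above. The class name, the `stage1` key, and the callables are hypothetical stand-ins, not the real module's API:

```python
from typing import Callable, Dict

class TwoStageSummarizer:
    """Hypothetical wrapper: stage 1 drafts a summary, stage 2 rewrites it."""

    def __init__(self, draft_fn: Callable[[str], str],
                 rewrite_fn: Callable[[str, str], str]):
        self.draft_fn = draft_fn      # e.g. the ViT5 generation step
        self.rewrite_fn = rewrite_fn  # e.g. the LoRA-adapted LLM step

    def summarize(self, text: str, use_stage2: bool = True) -> Dict[str, str]:
        draft = self.draft_fn(text)
        final = self.rewrite_fn(text, draft) if use_stage2 else draft
        return {"stage1": draft, "final": final}

# Toy stand-ins so the sketch runs without loading any model
summarizer = TwoStageSummarizer(
    draft_fn=lambda text: text[:20],
    rewrite_fn=lambda text, draft: draft.strip().capitalize() + ".",
)
result = summarizer.summarize("aloe vera gel soothes sunburn and dry skin")
print(result["final"])
```

Keeping the stage-1 draft in the result makes it easy to fall back to it, or to compare both outputs, when the rewrite step is disabled with `use_stage2=False`.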