- dataset: Awesome075/multi_news_parquet
- model: Qwen/Qwen2.5-7B-Instruct
- task: Multi-document summarization (English)

(Run in the `fine-tuning` conda environment on the RTX 5090 server.)

In [None]:
from datasets import load_dataset, Dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, PeftModel
from rouge_score import rouge_scorer

In [None]:
import wandb
# Start the Weights & Biases run that SFTTrainer (report_to="wandb") logs into.
wandb.init(
    project="qwen25-multinews-finetuning",
    name="lora-7b-instruct",
    config=dict(
        model="Qwen/Qwen2.5-7B-Instruct",
        dataset="Awesome075/multi_news_parquet",
        task="multi-document-summarization",
    ),
)

In [None]:
# System message shared by every training and evaluation example.
system_message = (
    "You are a professional summarization model.\n"
    "You create concise, accurate summaries that capture key information "
    "from multiple news articles."
)

# User-turn template — best prompt from evaluate_prompts.py (prompt_2).
# The `{document}` placeholder is filled with str.format by format_data.
prompt = (
    "The following news articles cover the same event. "
    "Read all articles and provide a comprehensive summary.\n"
    "\n"
    "{document}\n"
    "\n"
    "Provide a summary that captures the main points:"
)

In [None]:
def format_data(sample, prompt_template, max_doc_length=6000):
    """Convert one Multi-News row into a system/user/assistant chat example.

    Over-long documents are clipped to ``max_doc_length`` characters. If a
    period exists beyond 80% of that budget, the clip is pulled back to end
    just after it, so the text stops on a sentence boundary.

    Args:
        sample: dataset row with ``"document"`` and ``"summary"`` strings.
        prompt_template: string containing a ``{document}`` placeholder.
        max_doc_length: maximum number of document characters kept.

    Returns:
        ``{"messages": [system, user, assistant]}``; the assistant content is
        the whitespace-stripped reference summary.
    """
    doc_text = sample["document"]

    if len(doc_text) > max_doc_length:
        clipped = doc_text[:max_doc_length]
        cut = clipped.rfind('.')
        # Only back off to the last period when that keeps >80% of the budget.
        doc_text = clipped[:cut + 1] if cut > max_doc_length * 0.8 else clipped

    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": prompt_template.format(document=doc_text)}
    assistant_turn = {"role": "assistant", "content": sample["summary"].strip()}
    return {"messages": [system_turn, user_turn, assistant_turn]}

In [None]:
# Load Multi-News (Parquet mirror, for compatibility) and unpack its splits.
dataset = load_dataset("Awesome075/multi_news_parquet")
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

for split_label, split_data in (
    ("Train", train_dataset),
    ("Val", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_label}: {len(split_data)} samples")

# Inspect one training row; source articles are joined by "|||||".
print("\nSample data:")
sample = train_dataset[0]
print(f"Document length: {len(sample['document'].split())} words")
print(f"Summary length: {len(sample['summary'].split())} words")
print(f"Number of source docs: {sample['document'].count('|||||') + 1}")

In [None]:
# Model and tokenizer loading
model_id = "Qwen/Qwen2.5-7B-Instruct"
output_dir = "qwen25-multinews-lora"  # where the trained LoRA adapter is written

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Batched training/generation needs a pad token; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Render a one-turn conversation to eyeball the chat template.
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "test"}],
    add_generation_prompt=True,
    tokenize=False,
)
print("Chat template sample:")
print(text)

In [None]:
def compute_rouge(predictions, references):
    """Compute mean ROUGE F1 scores for summarization evaluation.

    Args:
        predictions: iterable of generated summaries.
        references: iterable of reference summaries, aligned with
            ``predictions`` item by item.

    Returns:
        dict with the mean ``'rouge1'``, ``'rouge2'`` and ``'rougeL'``
        F-measures (stemming enabled).

    Raises:
        ValueError: if the two inputs differ in length (``zip`` would
            otherwise silently drop the tail) or are empty (the mean is
            undefined).
    """
    # Materialize so generators still work and lengths can be checked.
    predictions = list(predictions)
    references = list(references)
    if len(predictions) != len(references):
        raise ValueError(
            f"predictions ({len(predictions)}) and references "
            f"({len(references)}) must have the same length"
        )
    if not predictions:
        raise ValueError("cannot compute ROUGE on an empty set of predictions")

    metrics = ('rouge1', 'rouge2', 'rougeL')
    scorer = rouge_scorer.RougeScorer(list(metrics), use_stemmer=True)

    totals = dict.fromkeys(metrics, 0.0)
    for pred, ref in zip(predictions, references):
        # RougeScorer.score takes (target, prediction) in that order.
        scores = scorer.score(ref, pred)
        for metric in metrics:
            totals[metric] += scores[metric].fmeasure

    count = len(predictions)
    return {metric: totals[metric] / count for metric in metrics}


def evaluate_model(model, tokenizer, test_samples, num_samples=100, max_new_tokens=300,
                   batch_size=4, max_input_length=4096):
    """Greedy-decode summaries for test samples and score them with ROUGE.

    Args:
        model: causal LM exposing ``generate`` and ``device`` (base or PEFT).
        tokenizer: tokenizer with a chat template and ``pad_token_id`` set.
        test_samples: list of ``format_data`` outputs. Messages [0:2]
            (system, user) form the prompt; message [2] holds the reference.
            Must be a plain list: positional slicing has to yield rows, and
            a ``datasets.Dataset`` slice returns a dict of columns instead.
        num_samples: maximum number of samples to evaluate.
        max_new_tokens: generation budget per summary.
        batch_size: number of prompts per ``generate`` call.
        max_input_length: prompt token limit. Longer prompts are truncated
            from the LEFT, so the trailing assistant generation header
            survives.

    Returns:
        ``(scores, predictions, references)``; ``scores`` is the dict
        returned by ``compute_rouge``.
    """
    model.eval()
    total = min(num_samples, len(test_samples))

    predictions = []
    references = []

    # Decoder-only generation needs left padding, so every prompt ends right
    # where generation starts. It also needs left truncation: right truncation
    # would cut off the "<|im_start|>assistant" header added by
    # add_generation_prompt. The caller's settings are restored afterwards,
    # even if generation raises.
    saved_padding_side = tokenizer.padding_side
    saved_truncation_side = tokenizer.truncation_side
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch_samples = test_samples[batch_start:batch_end]

            batch_texts = [
                tokenizer.apply_chat_template(
                    sample["messages"][:2], tokenize=False, add_generation_prompt=True
                )
                for sample in batch_samples
            ]
            batch_refs = [sample["messages"][2]["content"].strip() for sample in batch_samples]

            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_input_length,
            ).to(model.device)

            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # With left padding every row's prompt fills exactly the first
            # prompt_len columns, so the new tokens start at the same offset.
            prompt_len = inputs["input_ids"].shape[1]
            predictions.extend(
                tokenizer.decode(row[prompt_len:], skip_special_tokens=True).strip()
                for row in generated_ids
            )
            references.extend(batch_refs)

            if batch_end % 20 == 0 or batch_end == total:
                print(f"Evaluated {batch_end}/{total} samples")
    finally:
        tokenizer.padding_side = saved_padding_side
        tokenizer.truncation_side = saved_truncation_side

    scores = compute_rouge(predictions, references)
    return scores, predictions, references

In [None]:
# Prepare formatted datasets
# Use subsets for faster training (optional - the full splits can be used)
train_subset_size = 10000  # Adjust based on compute resources
val_subset_size = 1000
test_subset_size = 1000    # Evaluation pool; NUM_SAMPLES of these get scored

# Each split is capped with min(...) so a smaller split never makes
# .select() ask for rows that do not exist.
train_formatted = [
    format_data(row, prompt)
    for row in train_dataset.select(range(min(train_subset_size, len(train_dataset))))
]
val_formatted = [
    format_data(row, prompt)
    for row in val_dataset.select(range(min(val_subset_size, len(val_dataset))))
]
test_formatted = [
    format_data(row, prompt)
    for row in test_dataset.select(range(min(test_subset_size, len(test_dataset))))
]

print(f"Train formatted: {len(train_formatted)} samples")
print(f"Validation formatted: {len(val_formatted)} samples")
print(f"Test formatted: {len(test_formatted)} samples")

# Train/val become Dataset objects for SFTTrainer. test_formatted stays a
# plain list on purpose: evaluate_model slices it positionally, and a
# Dataset slice would return a dict of columns instead of a list of rows.
train_formatted = Dataset.from_list(train_formatted)
val_formatted = Dataset.from_list(val_formatted)

In [None]:
# Maximum sequence length (important for long documents)
max_seq_length = 4096  # Qwen2.5 supports up to 32K, but 4096 is reasonable for memory

def collate_fn(batch):
    """Tokenize chat examples and build causal-LM labels for SFT.

    Each example's ``messages`` list is rendered with the tokenizer's chat
    template. Every token up to and including the LAST
    ``<|im_start|>assistant\\n`` header is masked with -100, so the loss
    covers only the assistant response and whatever the template emits
    after it. Sequences are truncated to ``max_seq_length`` and right-padded
    to the longest sequence in the batch.

    Args:
        batch: list of examples, each with a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors,
        each of shape (len(batch), longest_sequence_length).
    """
    # Qwen2.5 chat-template marker that opens the assistant turn.
    assistant_start_str = "<|im_start|>assistant\n"

    all_input_ids, all_attention, all_labels = [], [], []

    for example in batch:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

        # rfind: the final assistant turn is the one to learn. This also
        # ignores a stray marker that appears inside user content.
        split_at = text.rfind(assistant_start_str)
        if split_at == -1:
            # No assistant response found: keep the tokens, learn nothing.
            input_ids = tokenizer(text, add_special_tokens=True)["input_ids"][:max_seq_length]
            labels = [-100] * len(input_ids)
        else:
            split_at += len(assistant_start_str)
            # Tokenize prompt and response separately and concatenate. The
            # mask boundary is then exact, the prompt tokens match what the
            # model sees at generation time, and the text is tokenized once.
            prefix_ids = tokenizer(text[:split_at], add_special_tokens=True)["input_ids"]
            response_ids = tokenizer(text[split_at:], add_special_tokens=False)["input_ids"]
            input_ids = (prefix_ids + response_ids)[:max_seq_length]
            labels = ([-100] * len(prefix_ids) + response_ids)[:max_seq_length]
            # NOTE: if the prompt alone reaches max_seq_length, every label
            # is -100 and this example contributes no loss.

        all_input_ids.append(input_ids)
        all_attention.append([1] * len(input_ids))
        all_labels.append(labels)

    # Right-pad to the longest sequence; padded positions are ignored by
    # both attention (mask 0) and the loss (label -100).
    longest = max(len(ids) for ids in all_input_ids)
    pad_id = tokenizer.pad_token_id
    for ids, mask, lab in zip(all_input_ids, all_attention, all_labels):
        gap = longest - len(ids)
        ids.extend([pad_id] * gap)
        mask.extend([0] * gap)
        lab.extend([-100] * gap)

    return {
        "input_ids": torch.tensor(all_input_ids),
        "attention_mask": torch.tensor(all_attention),
        "labels": torch.tensor(all_labels),
    }

In [None]:
# LoRA configuration (7B is large, so LoRA only).
# Adapters go on every attention and MLP projection; lora_alpha / r = 2.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "up_proj", "down_proj", "gate_proj",      # MLP projections
    ],
)

# Training configuration.
# Effective batch size = per_device_train_batch_size (2) x gradient_accumulation_steps (8) = 16.
lora_args = SFTConfig(
    output_dir=output_dir,
    report_to="wandb",
    # --- schedule ---
    num_train_epochs=2,
    learning_rate=1e-4,                  # higher LR than full fine-tuning, typical for LoRA
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # --- batching / memory ---
    per_device_train_batch_size=2,       # small batch for a 7B model
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,         # trade compute for activation memory
    bf16=True,
    max_seq_length=max_seq_length,
    # --- optimizer ---
    optim="adamw_torch_fused",
    max_grad_norm=0.3,
    # --- logging / evaluation / checkpoints ---
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="epoch",
    # --- data: collate_fn tokenizes raw "messages", so keep the column and
    #     skip TRL's own dataset preparation ---
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [None]:
# Create trainer and train
lora_trainer = SFTTrainer(
    model=model,
    args=lora_args,
    train_dataset=train_formatted,
    eval_dataset=val_formatted,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=tokenizer,
)

# Start training
lora_trainer.train()

# Save the trained adapter to output_dir; the LoRA evaluation cell reloads it
# from there with PeftModel.from_pretrained.
lora_trainer.save_model()

# Release the training model, trainer and optimizer state before the
# evaluation cells load fresh copies of the 7B base model. Otherwise the
# resident training copy plus a new bf16 copy may not fit in GPU memory.
import gc

del lora_trainer, model
gc.collect()
torch.cuda.empty_cache()

## Evaluation

In [None]:
NUM_SAMPLES: int = 200  # test examples (taken from the front of test_formatted) scored per model

In [None]:
# Base model evaluation: the un-adapted model is the reference point.
print("Loading Base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

print("Evaluating Base model...")
base_scores, base_preds, base_refs = evaluate_model(
    base_model, tokenizer, test_formatted, num_samples=NUM_SAMPLES
)

print("\nBase model ROUGE scores:")
for metric_name, metric_key in (("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"), ("ROUGE-L", "rougeL")):
    print(f"  {metric_name}: {base_scores[metric_key]:.4f}")

# Free GPU memory before the LoRA evaluation loads its own copy.
del base_model
torch.cuda.empty_cache()

In [None]:
# LoRA model evaluation: fresh base weights plus the adapter saved in output_dir.
print("Loading LoRA model...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
lora_model = PeftModel.from_pretrained(base_model, output_dir)

print("Evaluating LoRA model...")
lora_scores, lora_preds, lora_refs = evaluate_model(
    lora_model, tokenizer, test_formatted, num_samples=NUM_SAMPLES
)

print("\nLoRA model ROUGE scores:")
for metric_name, metric_key in (("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"), ("ROUGE-L", "rougeL")):
    print(f"  {metric_name}: {lora_scores[metric_key]:.4f}")

# Drop both references so the GPU memory can actually be released.
del lora_model, base_model
torch.cuda.empty_cache()

In [None]:
# Results summary as a Markdown table, followed by LoRA-minus-Base deltas.
metric_keys = ("rouge1", "rouge2", "rougeL")
metric_labels = ("ROUGE-1", "ROUGE-2", "ROUGE-L")

print("\n## Results Summary\n")
print("| Model | ROUGE-1 | ROUGE-2 | ROUGE-L |")
print("|-------|---------|---------|---------|")
for row_label, row_scores in (("Base", base_scores), ("LoRA", lora_scores)):
    cells = " | ".join(f"{row_scores[key]:.4f}" for key in metric_keys)
    print(f"| {row_label} | {cells} |")

# Improvement calculation (positive = LoRA better)
print("\n### Improvement")
for label, key in zip(metric_labels, metric_keys):
    print(f"{label}: {(lora_scores[key] - base_scores[key]):+.4f}")

In [None]:
# Qualitative examples: the first few test items, side by side.
NUM_EXAMPLES = 3
PREVIEW_CHARS = 500


def _preview(text, limit=PREVIEW_CHARS):
    """Return ``text`` clipped to ``limit`` chars, adding '...' only if clipped."""
    return text if len(text) <= limit else text[:limit] + "..."


print("\n## Sample Predictions\n")
# Cap the count so fewer than NUM_EXAMPLES evaluated samples cannot raise IndexError.
for i in range(min(NUM_EXAMPLES, len(lora_refs), len(base_preds), len(lora_preds))):
    print(f"### Sample {i+1}")
    print(f"**Reference:**\n{_preview(lora_refs[i])}")
    print(f"\n**Base Model:**\n{_preview(base_preds[i])}")
    print(f"\n**LoRA Model:**\n{_preview(lora_preds[i])}")
    print("-" * 80)