# Fine-Tuning Pipeline - Experiment Notebook

Notebook for fine-tuning LLMs for named-entity recognition (NER) using LoRA or full fine-tuning.

## Setup and Imports

In [None]:
# Setup cell: imports, make the project's `src` package importable, log readiness.
import json
import sys
import time
from datetime import datetime
from pathlib import Path
# NOTE(review): Dict, List, Path and pd are not used by any visible cell — kept
# in case other tooling relies on them.
from typing import Dict, List

from loguru import logger
from tqdm import tqdm
import pandas as pd

# Add the parent directory to the import path so `src.*` resolves.
# NOTE(review): ".." is resolved against the kernel's current working directory,
# not this notebook's file location — this assumes the notebook is run from a
# directory one level below the project root; confirm for your setup.
# Must stay above the `src` imports below.
sys.path.append("..")

from src.config import (
    CHECKPOINTS_DIR,
    PROCESSED_DATA_DIR,
    RESULTS_DIR,
    NERFineTuningConfig,
)
from src.finetuning_pipeline import FineTunedNERExtractor
from src.data_processor import DataProcessor
from src.utils import calculate_metrics, display_metrics

logger.info("Setup complete")

## Training Configuration

Configure fine-tuning parameters. You can train with LoRA (parameter-efficient) or full fine-tuning.

In [None]:
# Identifier for this run; it names both the checkpoint folder and the results folder.
EXPERIMENT_NAME = "qwen3_4b_lora_r16"

config = NERFineTuningConfig(
    # --- Base model and loading ---
    model_name="Qwen/Qwen3-4B",
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,

    # --- Prompt format (must match at inference time) ---
    add_schema=False,
    enable_thinking=False,

    # --- Optimization schedule ---
    # NOTE(review): both max_steps and num_epochs are set. With Hugging Face
    # TrainingArguments a positive max_steps overrides the epoch count — confirm
    # how NERFineTuningConfig forwards these two values.
    max_steps=1000,
    num_epochs=3,
    batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_steps=100,
    weight_decay=0.01,
    optim="adamw_torch_fused",

    # --- LoRA adapter (ignored when full_finetuning=True) ---
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # --- Logging / checkpoint cadence ---
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    eval_steps=100,

    # --- Inputs and outputs ---
    train_data_path=PROCESSED_DATA_DIR / "train_finetuning_chat.jsonl",
    # NOTE(review): validation reuses the *test* split; if the trainer selects
    # checkpoints on it, the test evaluation further down is optimistic. Confirm.
    val_data_path=PROCESSED_DATA_DIR / "test_finetuning_chat.jsonl",
    # A timestamp suffix keeps each run's checkpoints in a separate folder.
    output_dir=CHECKPOINTS_DIR / f"{EXPERIMENT_NAME}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    resume_from_checkpoint=None,
    report_to="none",  # "none", "wandb", or "tensorboard"
)

# Build a human-readable summary of the configuration and print it in one go.
if config.load_in_4bit:
    quant_label = "4-bit"
elif config.load_in_8bit:
    quant_label = "8-bit"
else:
    quant_label = "None"
mode_label = "Full fine-tuning" if config.full_finetuning else "LoRA"
divider = "=" * 80

report_lines = [
    "Fine-Tuning Configuration:",
    divider,
    f"Model: {config.model_name}",
    f"Max sequence length: {config.max_seq_length}",
    f"Quantization: {quant_label}",
    f"Training mode: {mode_label}",
]
if not config.full_finetuning:
    report_lines += [
        f"  LoRA rank: {config.lora_r}",
        f"  LoRA alpha: {config.lora_alpha}",
        f"  LoRA dropout: {config.lora_dropout}",
    ]
report_lines += [
    "\nTraining Settings:",
    f"  Max steps: {config.max_steps}",
    f"  Epochs: {config.num_epochs}",
    f"  Batch size: {config.batch_size}",
    f"  Gradient accumulation: {config.gradient_accumulation_steps}",
    f"  Effective batch size: {config.batch_size * config.gradient_accumulation_steps}",
    f"  Learning rate: {config.learning_rate}",
    f"  Warmup steps: {config.warmup_steps}",
    f"  Weight decay: {config.weight_decay}",
    f"  Optimizer: {config.optim}",
    f"\nOutput directory: {config.output_dir}",
    divider,
]
print("\n".join(report_lines))

## Validate Training Data

Check that training and validation data exist.

In [None]:
def _count_jsonl_records(path: Path) -> int:
    """Return the number of non-blank lines (JSONL records) in the file at *path*.

    Blank lines, such as a trailing empty line, are not records and are skipped.
    The file is read as UTF-8 because the data contains non-ASCII JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


# Training data is mandatory: fail fast if it is missing or empty.
if not config.train_data_path.exists():
    raise FileNotFoundError(f"Training data not found: {config.train_data_path}")
train_samples = _count_jsonl_records(config.train_data_path)
if train_samples == 0:
    raise ValueError(f"Training data is empty: {config.train_data_path}")
logger.success(f"Training data found: {train_samples} samples")

# Validation data is optional. `val_data_path` may already be None when this
# cell is re-run after a previous run found the file missing, so check for None
# before calling `.exists()` on it.
if config.val_data_path is None or not config.val_data_path.exists():
    logger.warning(f"Validation data not found: {config.val_data_path}")
    logger.warning("Training will proceed without validation")
    config.val_data_path = None
    val_samples = 0
else:
    val_samples = _count_jsonl_records(config.val_data_path)
    if val_samples == 0:
        # An empty eval set is treated like a missing one rather than handed to training.
        logger.warning(f"Validation data is empty: {config.val_data_path}")
        logger.warning("Training will proceed without validation")
        config.val_data_path = None
    else:
        logger.success(f"Validation data found: {val_samples} samples")

print("\n" + "=" * 80)
print("DATA SUMMARY")
print("=" * 80)
print(f"Training samples: {train_samples}")
print(f"Validation samples: {val_samples}")
print("=" * 80)

# Show the first actual record (skipping any leading blank lines, which would
# make json.loads fail). train_samples > 0 guarantees such a line exists.
print("\nSample training data:")
with open(config.train_data_path, "r", encoding="utf-8") as f:
    first_record = next(line for line in f if line.strip())
sample = json.loads(first_record)
print(json.dumps(sample, indent=2, ensure_ascii=False)[:500] + "...")

## Initialize Model

Load the model and prepare for training. This will load the base model with the specified quantization settings.

In [None]:
# Construct the extractor from `config`. Per the notes above this cell, this
# loads the base model with the configured quantization settings; the loading
# itself happens inside FineTunedNERExtractor (not visible here).
logger.info("Initializing model...")
extractor = FineTunedNERExtractor(config)
logger.success("Model initialized successfully")

## Start Training

In [None]:
# Run training and report how long it took.
logger.info("=" * 80)
logger.info("STARTING TRAINING")
logger.info("=" * 80)

# perf_counter() is monotonic, so the measured duration is not distorted by
# system clock adjustments (NTP, DST) during a long run — unlike time.time().
start_time = time.perf_counter()

# Start training; the returned stats are inspected in the next cell.
trainer_stats = extractor.train()

elapsed_time = time.perf_counter() - start_time

logger.success("=" * 80)
logger.success("TRAINING COMPLETE!")
logger.success("=" * 80)
# NOTE(review): assumes the extractor saves the final model under
# output_dir / "final" — confirm against FineTunedNERExtractor.train().
logger.info(f"Model saved to: {config.output_dir / 'final'}")
logger.info(f"Total training time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} minutes)")

## Training Statistics

Display training metrics and statistics.

In [None]:
# Print the metrics reported by the trainer, highlighting the common ones.
if trainer_stats:
    print("\n" + "=" * 80)
    print("TRAINING STATISTICS")
    print("=" * 80)

    # Use a dedicated name. The "Calculate Metrics" cell stores evaluation
    # results in the global `metrics`, and the save cell writes `metrics` out as
    # "evaluation_metrics". Re-running this cell used to silently replace them
    # with training metrics.
    train_metrics = trainer_stats.metrics

    if 'train_runtime' in train_metrics:
        runtime = train_metrics['train_runtime']
        print(f"Training runtime: {runtime:.2f}s ({runtime/60:.2f} minutes)")

    if 'train_samples' in train_metrics:
        print(f"Training samples: {train_metrics['train_samples']}")

    if 'train_steps_per_second' in train_metrics:
        print(f"Steps per second: {train_metrics['train_steps_per_second']:.2f}")

    if 'train_samples_per_second' in train_metrics:
        print(f"Samples per second: {train_metrics['train_samples_per_second']:.2f}")

    if 'train_loss' in train_metrics:
        print(f"\nFinal training loss: {train_metrics['train_loss']:.4f}")

    if 'eval_loss' in train_metrics:
        print(f"Final validation loss: {train_metrics['eval_loss']:.4f}")

    print("=" * 80)

    # Dump every reported metric, sorted by name.
    print("\nAll metrics:")
    for key, value in sorted(train_metrics.items()):
        print(f"  {key}: {value}")
else:
    print("No training statistics available")

## Load Test Dataset for Evaluation

Load test data to evaluate the fine-tuned model.

In [None]:
# Load the held-out test set used by the evaluation cells below.
# `test_dataset` stays None when the file is missing so those cells skip.
test_dataset_path = PROCESSED_DATA_DIR / "test.json"

if not test_dataset_path.exists():
    logger.warning(f"Test dataset not found: {test_dataset_path}")
    logger.warning("Skipping evaluation")
    test_dataset = None
else:
    logger.info(f"Loading test dataset from {test_dataset_path}")
    test_dataset = DataProcessor.load_dataset(test_dataset_path)
    logger.success(f"Loaded {len(test_dataset)} test samples")

    if len(test_dataset) == 0:
        # Nothing to show; indexing [0] would raise IndexError. The evaluation
        # cells below check `if test_dataset:` before doing any work.
        logger.warning("Test dataset is empty; evaluation cells will be skipped")
    else:
        example_sample = test_dataset[0]
        print("\n" + "=" * 80)
        print("TEST DATASET EXAMPLE")
        print("=" * 80)
        print(f"Text:\n{example_sample['text'][:300]}...\n")
        print(f"Entities:\n{json.dumps(example_sample['entities'], indent=2, ensure_ascii=False)}")
        print("=" * 80)

## Test Single Sample

Test the fine-tuned model on a single sample before full evaluation.

In [None]:
if test_dataset:
    # Smoke-test the fine-tuned model on the first test record before the full run.
    first_sample = test_dataset[0]
    test_text = first_sample["text"]
    test_label = first_sample["entities"]
    divider = "=" * 80

    print(divider)
    print("SINGLE SAMPLE TEST - FINE-TUNED MODEL")
    print(divider)
    print(f"\nInput text:\n{test_text[:300]}...\n")
    ground_truth_json = json.dumps(test_label, indent=2, ensure_ascii=False)
    print("Ground truth:\n" + ground_truth_json + "\n")

    logger.info("Running extraction on test sample...")
    prediction = extractor.extract_entities(test_text)

    prediction_json = json.dumps(prediction, indent=2, ensure_ascii=False)
    print("Prediction:\n" + prediction_json)
    print(divider)

## Run Full Evaluation

Evaluate the fine-tuned model on the entire test dataset.

In [None]:
# Run the fine-tuned extractor over every test sample and record timing.
# Produces the globals `texts`, `labels`, `predictions`, `eval_elapsed_time`,
# `throughput` and `avg_time` that the metrics/analysis/save cells read.
if test_dataset:
    # Prepare inputs and ground truth in matching order.
    texts = [sample["text"] for sample in test_dataset]
    labels = [sample["entities"] for sample in test_dataset]
    n_samples = len(test_dataset)

    logger.info(f"Starting evaluation on {n_samples} samples...")
    print("=" * 80)
    print(f"Running evaluation on {n_samples} samples")
    print("=" * 80)

    # perf_counter() is monotonic: the measured duration cannot be skewed by
    # wall-clock adjustments during a long evaluation (unlike time.time()).
    eval_start_time = time.perf_counter()
    predictions = []

    for i, (text, label) in enumerate(tqdm(zip(texts, labels), desc="Extracting entities", total=n_samples)):
        prediction = extractor.extract_entities(text)
        predictions.append(prediction)

        # Show the first few predictions for debugging.
        if i < 3:
            logger.debug(f"Sample {i+1}:")
            logger.debug(f"  Ground truth: {label}")
            logger.debug(f"  Prediction  : {prediction}")

    eval_elapsed_time = time.perf_counter() - eval_start_time
    throughput = n_samples / eval_elapsed_time if eval_elapsed_time > 0 else 0
    # n_samples > 0 here because an empty dataset is falsy and skips this branch.
    avg_time = eval_elapsed_time / n_samples

    print("\n" + "=" * 80)
    print("EVALUATION COMPLETE")
    print("=" * 80)
    print(f"Total samples: {n_samples}")
    print(f"Total time: {eval_elapsed_time:.2f}s")
    print(f"Throughput: {throughput:.2f} samples/s")
    print(f"Avg time per sample: {avg_time:.2f}s")
    print("=" * 80)
else:
    logger.warning("No test dataset available for evaluation")

## Calculate Metrics

In [None]:
if test_dataset:
    # Score the predictions from the evaluation cell against the labels and
    # print a summary table titled with the experiment name.
    metrics_title = f"METRICS - {EXPERIMENT_NAME}"
    logger.info("Calculating metrics...")
    metrics = calculate_metrics(predictions, labels)
    display_metrics(metrics, title=metrics_title)

## Analyze Sample Results

Look at some examples to understand model performance.

In [None]:
# Entity types compared by the exact-match error analysis below.
# NOTE(review): "person" (singular) vs "organizations" (plural) is assumed to
# mirror the dataset's label keys — confirm against the data schema.
ERROR_ANALYSIS_ENTITY_TYPES = ("person", "organizations", "address")


def _entity_set(entities, entity_type):
    """Return the entities of one type as a set.

    A missing key, a null value, or a non-dict container all count as "no
    entities", so malformed records do not crash the analysis.
    """
    if not isinstance(entities, dict):
        return set()
    return set(entities.get(entity_type) or [])


def _has_entity_mismatch(pred, truth, entity_types=ERROR_ANALYSIS_ENTITY_TYPES):
    """Return True if *pred* differs from *truth* for any of *entity_types*.

    A prediction that is not a dict (e.g. None or unparsed model text) is
    always counted as an error instead of raising AttributeError.
    """
    if not isinstance(pred, dict):
        return True
    return any(
        _entity_set(pred, entity_type) != _entity_set(truth, entity_type)
        for entity_type in entity_types
    )


if test_dataset:
    # Side-by-side view of the first few predictions.
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTIONS ANALYSIS")
    print("=" * 80)

    for i in range(min(3, len(test_dataset))):
        print(f"\nSample {i+1}:")
        print("-" * 80)
        print(f"Text: {texts[i][:200]}...")
        print("\nGround Truth:")
        print(json.dumps(labels[i], indent=2, ensure_ascii=False))
        print("\nPrediction:")
        print(json.dumps(predictions[i], indent=2, ensure_ascii=False))
        print("-" * 80)

    # Exact-match analysis: a sample is correct only if every compared entity
    # type matches as a set (order and duplicates ignored).
    print("\n" + "=" * 80)
    print("ERROR ANALYSIS")
    print("=" * 80)

    error_count = sum(
        _has_entity_mismatch(pred, truth) for pred, truth in zip(predictions, labels)
    )

    accuracy = (len(test_dataset) - error_count) / len(test_dataset) * 100
    print(f"Samples with errors: {error_count} / {len(test_dataset)}")
    print(f"Perfect match accuracy: {accuracy:.2f}%")
    print("=" * 80)

## Save Evaluation Results

In [None]:
# Persist the configuration, training stats and evaluation results as JSON.
if test_dataset:
    # NOTE: the results path depends only on EXPERIMENT_NAME, so re-running
    # overwrites the previous file; "output_dir" below ties it to a checkpoint.
    results_dir = RESULTS_DIR / "finetuning_eval" / EXPERIMENT_NAME
    results_dir.mkdir(parents=True, exist_ok=True)

    is_lora = not config.full_finetuning
    results = {
        "experiment_name": EXPERIMENT_NAME,
        # Record everything needed to reproduce the run. Quantization and prompt
        # settings are included because they change how the model was trained
        # and how it must be prompted.
        "config": {
            "model_name": config.model_name,
            "max_seq_length": config.max_seq_length,
            "load_in_4bit": config.load_in_4bit,
            "load_in_8bit": config.load_in_8bit,
            "full_finetuning": config.full_finetuning,
            "add_schema": config.add_schema,
            "enable_thinking": config.enable_thinking,
            "lora_r": config.lora_r if is_lora else None,
            "lora_alpha": config.lora_alpha if is_lora else None,
            "lora_dropout": config.lora_dropout if is_lora else None,
            "learning_rate": config.learning_rate,
            "lr_scheduler_type": config.lr_scheduler_type,
            "warmup_steps": config.warmup_steps,
            "weight_decay": config.weight_decay,
            "optim": config.optim,
            "batch_size": config.batch_size,
            "gradient_accumulation_steps": config.gradient_accumulation_steps,
            "max_steps": config.max_steps,
            "num_epochs": config.num_epochs,
            # str(): Path objects are not JSON-serializable.
            "output_dir": str(config.output_dir),
        },
        "training_stats": trainer_stats.metrics if trainer_stats else {},
        "evaluation_metrics": metrics,
        "evaluation_performance": {
            "total_samples": len(test_dataset),
            "elapsed_time": round(eval_elapsed_time, 2),
            "throughput": round(throughput, 2),
            "avg_time_per_sample": round(avg_time, 2),
        },
    }

    results_file = results_dir / "evaluation_results.json"
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    logger.success(f"Results saved to {results_file}")
    print(f"\n✅ Evaluation results saved to: {results_file}")

## Model Information

Summary of the fine-tuned model location and how to use it.

In [None]:
# Print where the fine-tuned model lives and a ready-to-paste loading snippet.
final_model_path = config.output_dir / "final"
divider = "=" * 80

print("\n" + divider)
print("FINE-TUNED MODEL INFORMATION")
print(divider)
print(f"Model saved at: {final_model_path}")
print("\nTo use this model later:")
# Strings are rendered with !r so the snippet stays valid Python even when the
# path contains backslashes (Windows) or quotes. The import path matches the one
# used in the setup cell, and the prompt/quantization settings are included
# because inference should use the same format the model was trained with.
print(f"""\n
from src.finetuning_pipeline import FineTunedNERExtractor
from src.config import NERFineTuningConfig

config = NERFineTuningConfig(
    model_name={config.model_name!r},
    trained_model_path={str(final_model_path)!r},
    max_seq_length={config.max_seq_length},
    load_in_4bit={config.load_in_4bit},
    load_in_8bit={config.load_in_8bit},
    add_schema={config.add_schema},
    enable_thinking={config.enable_thinking},
)

extractor = FineTunedNERExtractor(config)
entities = extractor.extract_entities("Your text here")
""")
print(divider)

---

## Configuration Templates

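Copy these configurations into the "Training Configuration" cell to try different setups. Parameters a template leaves out fall back to the `NERFineTuningConfig` defaults. This includes `output_dir`, `train_data_path`, and `val_data_path`, which later cells read, so copy those lines from the default configuration if you need the same output location and data files.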

### 1. LoRA Training (Default - Memory Efficient)
```python
EXPERIMENT_NAME = "lora_r16_efficient"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=True,
    full_finetuning=False,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    batch_size=4,
    learning_rate=1e-4,
    max_steps=1000,
)
```

### 2. Higher-Rank LoRA (More Capacity, Often Better Accuracy)
```python
EXPERIMENT_NAME = "lora_r64_high_rank"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=True,
    full_finetuning=False,
    lora_r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    batch_size=4,
    learning_rate=2e-4,
    max_steps=1000,
)
```

### 3. Full Fine-Tuning (Best Performance, Requires More Memory)
```python
EXPERIMENT_NAME = "full_finetuning"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=False,
    load_in_8bit=True,
    full_finetuning=True,
    batch_size=2,
    learning_rate=5e-5,
    max_steps=1000,
)
```

### 4. Quick Training (Fast Testing)
```python
EXPERIMENT_NAME = "quick_test"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=True,
    full_finetuning=False,
    lora_r=8,
    batch_size=8,
    learning_rate=2e-4,
    max_steps=100,
    save_steps=50,
)
```

### 5. With Schema in Prompt
```python
EXPERIMENT_NAME = "lora_with_schema"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=True,
    full_finetuning=False,
    add_schema=True,
    lora_r=16,
    batch_size=4,
    learning_rate=1e-4,
    max_steps=1000,
)
```

### 6. With Thinking Mode
```python
EXPERIMENT_NAME = "lora_thinking_mode"
config = NERFineTuningConfig(
    model_name="Qwen/Qwen3-4B",
    load_in_4bit=True,
    full_finetuning=False,
    enable_thinking=True,
    lora_r=16,
    batch_size=4,
    learning_rate=1e-4,
    max_steps=1000,
)
```

### Tips:

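**Memory Usage** (rough estimates; actual usage grows with sequence length and batch size, and full fine-tuning also needs memory for optimizer state):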
- 4-bit quantization: ~4-6GB VRAM
- 8-bit quantization: ~8-10GB VRAM
- Full precision: ~16GB+ VRAM

**LoRA Rank:**
- r=8: Fast, fewer trainable parameters, good for simple tasks
- r=16: Balanced (recommended)
- r=64: More expressive, better for complex tasks

**Learning Rate:**
- Full fine-tuning: 5e-5 to 1e-4
- LoRA: 1e-4 to 3e-4

**Batch Size:**
- Use gradient accumulation for larger effective batch sizes
- Effective batch size = batch_size × gradient_accumulation_steps
- Target: 16-32 effective batch size