# Fine-tuning TinyLlama-1.1B-Chat for Mathematical Reasoning

## *Environment setup*

In [None]:
!pip install -q h5py typing-extensions wheel
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets

In [None]:
# Display GPU information
!nvidia-smi

# Create output directory
!mkdir -p model_outputs

## *Imports and Configuration*

In [None]:
import random
import json
import torch
import numpy as np
import gc
from tqdm import tqdm
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
import re
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel
import matplotlib.pyplot as plt
import os
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hugging Face Hub id of the checkpoint that is evaluated and fine-tuned below.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B parameter model
print(f"Using model: {MODEL_NAME}")
# Directory that receives the evaluation set, result JSON files, plots and
# checkpoints written by the functions in this notebook.
SAVE_DIRECTORY = "tinyllama_math"
os.makedirs(SAVE_DIRECTORY, exist_ok=True)

In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)

def clear_memory():
    """Run the garbage collector and release cached CUDA memory.

    When a CUDA device is available, also waits for pending GPU work to
    finish via ``torch.cuda.synchronize()``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

## Data Loading and Preparation

In [None]:
# Column names recognised as the problem statement / the solution, in priority
# order. "target" comes last because it typically holds only a bare final value.
_QUESTION_COLUMNS = ("question", "prompt", "problem", "input", "query")
_ANSWER_COLUMNS = ("answer", "response", "solution", "output", "completion", "target")


def _to_question_answer(dataset, source_name):
    """Project ``dataset`` onto exactly two string columns: question and answer."""
    columns = dataset.column_names
    question_col = next((c for c in _QUESTION_COLUMNS if c in columns), None)
    answer_col = next((c for c in _ANSWER_COLUMNS if c in columns), None)
    if question_col is None or answer_col is None:
        raise ValueError(
            f"{source_name}: cannot find question/answer columns in {columns}"
        )

    def _project(batch):
        # str() makes every source share the same (string) feature type.
        return {
            "question": ["" if q is None else str(q) for q in batch[question_col]],
            "answer": ["" if a is None else str(a) for a in batch[answer_col]],
        }

    # Columns named in remove_columns are dropped before the outputs are added,
    # so returning "question"/"answer" is safe even if they already existed.
    return dataset.map(_project, batched=True, remove_columns=columns)


def load_combined_math_datasets(max_examples=10000):
    """Load several math datasets and combine them into one training set.

    Each source is normalised to the schema ``{"question": str, "answer": str}``
    before concatenation: the sources use different column names and types,
    and ``concatenate_datasets`` refuses datasets whose features differ.

    Loading is best-effort per source: a source that fails to download or
    cannot be normalised is reported and skipped.

    Args:
        max_examples: Upper bound on the size of the returned dataset.

    Returns:
        A shuffled ``Dataset`` with ``question`` and ``answer`` columns.

    Raises:
        ValueError: If no source could be loaded.
    """
    print(f"Loading multiple math datasets (up to {max_examples} examples total)...")
    datasets = []

    # 1. NuminaMath (primary source), sampled down to at most 5000 examples
    try:
        numina_train = load_dataset("PrimeIntellect/NuminaMath-QwQ-CoT-5M")["train"]
        sampled_numina = numina_train.shuffle(seed=42).select(
            range(min(5000, len(numina_train)))
        )
        datasets.append(_to_question_answer(sampled_numina, "NuminaMath"))
        print(f"Added {len(sampled_numina)} examples from NuminaMath")
    except Exception as e:
        print(f"Error loading NuminaMath: {e}")

    # 2. GSM8K training split
    try:
        gsm8k_train = load_dataset("gsm8k", "main")["train"]
        datasets.append(_to_question_answer(gsm8k_train, "GSM8K"))
        print(f"Added {len(gsm8k_train)} examples from GSM8K train set")
    except Exception as e:
        print(f"Error loading GSM8K: {e}")

    # 3. GSM-Hard, sampled down to at most 1000 examples
    try:
        cot_dataset = load_dataset("reasoning-machines/gsm-hard", split="train")
        sampled_cot = cot_dataset.shuffle(seed=42).select(
            range(min(1000, len(cot_dataset)))
        )
        datasets.append(_to_question_answer(sampled_cot, "reasoning-machines/gsm-hard"))
        print(f"Added {len(sampled_cot)} examples from reasoning-machines/gsm-hard")
    except Exception as e:
        print(f"Error loading reasoning-machines dataset: {e}")

    if not datasets:
        raise ValueError("Failed to load any datasets")

    # Shuffle and limit total size
    combined_data = concatenate_datasets(datasets).shuffle(seed=42)
    if len(combined_data) > max_examples:
        combined_data = combined_data.select(range(max_examples))

    print(f"Final combined dataset size: {len(combined_data)} examples")
    return combined_data

In [None]:
def load_gsm8k_evaluation():
    """Return the ``test`` split of the GSM8K ``main`` configuration."""
    print("Loading GSM8k for evaluation...")
    test_set = load_dataset("gsm8k", "main")["test"]
    print(f"GSM8k test set loaded. Size: {len(test_set)}")
    return test_set

## *Create Evaluation Dataset*

In [None]:
# A signed number, optionally written with thousands separators ("1,234.5").
_GSM8K_NUMBER = r"[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)"


def _gsm8k_final_answer(solution):
    """Return the final answer of a GSM8K solution as a comma-free string."""
    marked = re.findall(r"####\s*(" + _GSM8K_NUMBER + ")", solution)
    if marked:
        return marked[-1].replace(",", "")

    # No number directly after the marker: keep whatever follows the last "####".
    _, marker, tail = solution.rpartition("####")
    if marker:
        return tail.strip()

    # Last resort: the last number anywhere in the solution.
    numbers = re.findall(_GSM8K_NUMBER, solution)
    return numbers[-1].replace(",", "") if numbers else solution.strip()


def create_evaluation_set(gsm8k_eval, size=200):
    """Build the evaluation set from the first ``size`` GSM8K examples.

    The gold answer is parsed from the text after ``####``. Thousands
    separators are accepted and removed, so ``#### 1,234`` yields ``"1234"``
    instead of being cut off at the comma.

    The resulting list is also written to ``evaluation_set.json`` inside
    ``SAVE_DIRECTORY``.

    Args:
        gsm8k_eval: Dataset whose rows have ``question`` and ``answer`` fields.
        size: Maximum number of examples to include.

    Returns:
        A list of dicts with keys ``category``, ``question``, ``solution``
        and ``answer``.
    """
    print("Creating evaluation dataset...")
    gsm8k_samples = gsm8k_eval.select(range(min(size, len(gsm8k_eval))))

    eval_set = [
        {
            "category": "math_reasoning",
            "question": sample["question"],
            "solution": sample["answer"],
            "answer": _gsm8k_final_answer(sample["answer"]),
        }
        for sample in gsm8k_samples
    ]

    # Save the evaluation set
    with open(os.path.join(SAVE_DIRECTORY, "evaluation_set.json"), "w") as f:
        json.dump(eval_set, f)

    print(f"Created evaluation set with {len(eval_set)} samples")
    return eval_set

## *Model Evaluation Functions*

In [None]:
def load_model_and_tokenizer(model_name, load_in_4bit=True):
    """Load a causal LM with a bitsandbytes quantization config, plus its tokenizer.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.
        load_in_4bit: Forwarded to ``BitsAndBytesConfig(load_in_4bit=...)``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Fall back to the EOS token when the tokenizer defines no padding token.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # NF4 with double quantization; matmuls are computed in float16.
    bnb_settings = {
        "load_in_4bit": load_in_4bit,
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    }

    print("Loading model with memory-efficient settings...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(**bnb_settings),
        torch_dtype=torch.float16,
    )
    return model, tokenizer

## *Format the prompt for TinyLlama Chat models*

In [None]:
def format_prompt_for_tinyllama(question):
    """Wrap ``question`` in the system/user/assistant prompt used in this notebook.

    The returned text ends with the ``<|assistant|>`` tag so that the model's
    continuation is the answer.
    """
    system_message = (
        "You are a highly intelligent math assistant that excels at solving "
        "complex math problems step by step using chain-of-thought reasoning. "
        "You always show your work clearly, explaining each step of your "
        "calculation, and double-check your final answer."
    )
    task_instructions = (
        "Please solve this math problem by breaking it down into clearly "
        "defined steps. Show all your work and calculations, and make sure to "
        "verify your answer."
    )
    answer_instructions = (
        "First understand the problem, identify what is being asked, plan your "
        "solution approach, and then solve it step-by-step. After reaching your "
        "final answer, include it at the end after '####'."
    )
    # Empty entries produce the blank lines around the problem statement.
    segments = [
        "<|system|>",
        system_message,
        "<|user|>",
        task_instructions,
        "",
        f"Problem: {question}",
        "",
        answer_instructions,
        "<|assistant|>",
    ]
    return "\n".join(segments)


## *Extract the final answer from model responses*

In [None]:
# A signed number, optionally written with thousands separators ("1,234.5").
_ANSWER_NUMBER = r"[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)"


def extract_answer(response, question_type="math_reasoning"):
    """Extract the final answer from a model response.

    For ``"math_reasoning"`` the following cues are tried in order and the
    number following the first cue that matches is returned:

    1. a ``###``/``####`` marker (an optional ``$`` before the number is skipped),
    2. "the answer is" / "final answer:",
    3. "therefore",
    4. "so" / "thus" / "hence" / "finally" / "in conclusion",
    5. the last number anywhere in the response.

    Numbers may contain thousands separators; these are removed from the
    returned string ("1,234" -> "1234"). Without this, "1,234" would be read
    as "1" after a marker, or as "234" by the last-number fallback.

    For any other ``question_type`` the last non-empty sentence is returned,
    limited to its final 50 characters.

    Args:
        response: Text generated by the model.
        question_type: ``"math_reasoning"`` or any other category name.

    Returns:
        The extracted answer string, or the stripped response if no number
        is found in math mode.
    """
    if question_type == "math_reasoning":
        lowered = response.lower()
        cues = (
            (r"#{3,}\s*\$?\s*", response),
            (r"(?:the\s+answer\s+is|final\s+answer\s*[:=])\s*", lowered),
            (r"therefore,?\s*", lowered),
            (r"(?:so|thus|hence|finally|in\s+conclusion),?\s*", lowered),
        )
        for prefix, text in cues:
            match = re.search(prefix + "(" + _ANSWER_NUMBER + ")", text)
            if match:
                return match.group(1).replace(",", "")

        # No cue matched: fall back to the last number in the response.
        numbers = re.findall(_ANSWER_NUMBER, response)
        if numbers:
            return numbers[-1].replace(",", "")

        return response.strip()

    text = response.strip()
    # Drop empty fragments so text ending in "." does not yield "".
    sentences = [part.strip() for part in text.split(".") if part.strip()]
    last_sentence = sentences[-1] if sentences else text
    if len(last_sentence) > 50:
        last_sentence = last_sentence[-50:].strip()
    return last_sentence

## *Evaluate Model on Evaluation Set*

In [None]:
def _numeric_answers_match(extracted_answer, true_answer):
    """Compare two answers numerically, falling back to string equality."""
    # Keep only digits, dots and minus signs (drops "$", ",", units, spaces).
    extracted_numeric = re.sub(r"[^\d.-]", "", extracted_answer)
    true_numeric = re.sub(r"[^\d.-]", "", true_answer)
    try:
        return abs(float(extracted_numeric) - float(true_numeric)) < 1e-6
    except ValueError:
        # At least one side is not a parseable number (e.g. "" or "1.2.3").
        return extracted_numeric == true_numeric


def evaluate_model(model, tokenizer, evaluation_data, model_name, show_examples=True):
    """Generate an answer for every evaluation item and score it.

    Args:
        model: Causal LM providing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        evaluation_data: Iterable of dicts with ``question``, ``answer`` and
            ``category`` keys.
        model_name: Label used in progress output and in the returned dict.
        show_examples: If true, print the first three scored examples.

    Returns:
        A dict with ``model_name``, ``overall_accuracy`` and
        ``detailed_results`` (one dict per successfully evaluated item).
        Items that raise during evaluation are reported and left out, so
        ``detailed_results`` may be shorter than ``evaluation_data``.
    """
    results = []

    for idx, item in enumerate(tqdm(evaluation_data, desc=f"Evaluating {model_name}")):
        try:
            question = item["question"]
            true_answer = item["answer"]
            category = item["category"]

            prompt = format_prompt_for_tinyllama(question)

            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
            input_ids = inputs.input_ids.to(model.device)
            # Passed explicitly so generate() does not have to guess it.
            attention_mask = inputs.attention_mask.to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=512,  # Longer output for complete reasoning
                    temperature=0.1,     # Low temperature for near-deterministic answers
                    top_p=0.92,
                    do_sample=True,
                    num_beams=2,
                    pad_token_id=tokenizer.pad_token_id,
                    repetition_penalty=1.1  # Slight penalty to avoid loops
                )

            # The output starts with the prompt tokens. Decode only what was
            # generated; slicing decoded text by len(prompt) breaks when the
            # prompt was truncated or decodes to slightly different text.
            prompt_length = input_ids.shape[1]
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            # Stop at a further assistant tag if the model starts a new turn.
            response = generated_text.split("<|assistant|>")[0].strip()

            extracted_answer = extract_answer(response, category)

            is_correct = False
            if category == "math_reasoning":
                is_correct = _numeric_answers_match(extracted_answer, true_answer)

            results.append({
                "category": category,
                "question": question,
                "true_answer": true_answer,
                "model_response": response,
                "extracted_answer": extracted_answer,
                "is_correct": is_correct
            })

            # Display some examples during evaluation
            if show_examples and idx < 3:
                print(f"\nExample {idx+1}:")
                print(f"Question: {question[:100]}...")
                print(f"True answer: {true_answer}")
                print(f"Extracted answer: {extracted_answer}")
                print(f"Correct: {is_correct}")

        except Exception as e:
            # Best effort: one failing example must not abort the whole run.
            print(f"Error evaluating example {idx}: {e}")
            continue

    correct = sum(1 for r in results if r["is_correct"])
    accuracy = correct / len(results) if results else 0

    print(f"\n{model_name} Evaluation Results:")
    print(f"Overall Accuracy: {accuracy:.4f} ({correct}/{len(results)})")

    return {
        "model_name": model_name,
        "overall_accuracy": accuracy,
        "detailed_results": results
    }

## *Process the dataset for training with memory optimization*

In [None]:
def _detect_qa_fields(sample):
    """Return the (question_key, answer_key) pair to use for ``sample``."""
    for question_key, answer_key in (
        ("prompt", "response"),
        ("question", "answer"),
        ("problem", "solution"),
        ("input", "output"),
    ):
        if question_key in sample and answer_key in sample:
            return question_key, answer_key

    # Best guess for other datasets: any field whose name contains a known term.
    fields = list(sample.keys())
    question_terms = ("question", "prompt", "problem", "input", "query")
    answer_terms = ("answer", "response", "solution", "output", "completion")
    question_candidates = [k for k in fields if any(t in k.lower() for t in question_terms)]
    answer_candidates = [k for k in fields if any(t in k.lower() for t in answer_terms)]
    if question_candidates and answer_candidates:
        return question_candidates[0], answer_candidates[0]

    # Last resort: the first two string-valued fields.
    string_fields = [k for k in fields if isinstance(sample[k], str)]
    if len(string_fields) >= 2:
        return string_fields[0], string_fields[1]
    raise ValueError(f"Cannot identify question and answer fields in dataset with keys: {fields}")


def process_data_for_training(data, tokenizer, max_length=1536):
    """Format and tokenize question/answer pairs for fine-tuning.

    Every example is rendered with ``format_prompt_for_tinyllama`` (the same
    prompt used at evaluation time) followed by the answer, then tokenized,
    truncated and padded to ``max_length``. Answers without a ``####`` marker
    get ``#### <last number>`` appended unless they already end with that
    number.

    Args:
        data: Dataset supporting ``len``, integer indexing and slicing.
        tokenizer: Tokenizer used to encode the formatted text.
        max_length: Token length every example is truncated/padded to.

    Returns:
        A ``Dataset`` with ``input_ids`` and ``attention_mask`` columns.
        Rows whose question or answer is not a string, or that fail to
        tokenize, are skipped.

    Raises:
        ValueError: If ``data`` is empty or its question/answer fields
            cannot be identified.
    """
    print(f"Processing dataset ({len(data)} examples) for training...")
    if len(data) == 0:
        raise ValueError("Cannot process an empty dataset")

    question_key, answer_key = _detect_qa_fields(data[0])
    print(f"Using fields - Question: '{question_key}', Answer: '{answer_key}'")

    def process_batch(batch):
        input_id_rows = []
        attention_mask_rows = []

        for question, answer in zip(batch[question_key], batch[answer_key]):
            try:
                # Skip invalid entries
                if not isinstance(question, str) or not isinstance(answer, str):
                    continue

                # Make sure the answer carries a "####" final-answer marker.
                if '####' not in answer:
                    numbers = re.findall(r"([-+]?\d*\.?\d+)", answer)
                    if numbers:
                        final_number = numbers[-1].strip()
                        # Skip when the answer already ends with that number.
                        if not answer.strip().endswith(final_number):
                            answer = answer.strip() + f"\n\n#### {final_number}"

                # Shared prompt builder keeps training and evaluation prompts identical.
                formatted_text = format_prompt_for_tinyllama(question) + "\n" + answer

                tokenized = tokenizer(
                    formatted_text,
                    truncation=True,
                    max_length=max_length,
                    padding="max_length",
                    return_attention_mask=True
                )

                input_id_rows.append(tokenized["input_ids"])
                attention_mask_rows.append(tokenized["attention_mask"])

            except Exception as e:
                # Best effort: report the bad row and keep going.
                print(f"Error processing item: {e}")
                continue

        return {
            "input_ids": input_id_rows,
            "attention_mask": attention_mask_rows
        }

    # Tokenize in slices so only one slice of Python lists is alive at a time.
    batch_size = 100
    chunks = []

    for start in range(0, len(data), batch_size):
        # Slicing the dataset reads only these rows, not whole columns.
        batch = data[start:start + batch_size]
        batch_processed = process_batch(batch)

        if batch_processed["input_ids"]:
            chunks.append(Dataset.from_dict(batch_processed))

        if start % 500 == 0:
            print(f"Processed {start}/{len(data)} examples...")
            # Free memory
            clear_memory()

    # Concatenate once at the end instead of once per batch.
    if chunks:
        processed_dataset = concatenate_datasets(chunks)
    else:
        processed_dataset = Dataset.from_dict({"input_ids": [], "attention_mask": []})

    print(f"Final processed dataset size: {len(processed_dataset)} examples")
    return processed_dataset

## *Create Labels for Training*

In [None]:
def create_labels_from_input_ids(batched_input_ids, tokenizer):
    """Build training labels that ignore the prompt part of each sequence.

    Each sequence is decoded, the text up to and including the first
    ``<|assistant|>`` tag is re-tokenized, and that many leading positions
    are set to -100. The remaining positions keep their token ids. If the
    tag is absent, the whole decoded text counts as prompt.

    Args:
        batched_input_ids: Iterable of token-id sequences.
        tokenizer: Object providing ``decode`` and a call interface that
            returns a mapping with ``"input_ids"``.

    Returns:
        A list with one label list per input sequence.
    """
    assistant_tag = "<|assistant|>"

    def _mask_prompt(token_ids):
        decoded = tokenizer.decode(token_ids)
        # partition() returns an empty separator when the tag is absent, so
        # head + separator is then simply the full decoded text.
        head, separator, _ = decoded.partition(assistant_tag)
        prompt_ids = tokenizer(head + separator, add_special_tokens=False)["input_ids"]
        masked_count = min(len(prompt_ids), len(token_ids))
        masked = [-100] * masked_count
        masked.extend(token_ids[masked_count:])
        return masked

    return [_mask_prompt(token_ids) for token_ids in batched_input_ids]

## *Fine-tune Model using LoRA*

In [None]:
def fine_tune_model(model_name, training_data, tokenizer=None, output_dir=None):
    """Fine-tune ``model_name`` with LoRA adapters on a 4-bit quantized base.

    Loss is computed on the assistant part only: label positions belonging to
    the prompt (up to and including ``<|assistant|>``) and to padding are set
    to -100. Batches are built with ``default_data_collator`` so these labels
    reach the model unchanged. ``DataCollatorForLanguageModeling(mlm=False)``
    would rebuild ``labels`` from ``input_ids`` and discard the prompt mask.

    Args:
        model_name: Model id or path of the base model.
        training_data: A tokenized ``Dataset`` with ``input_ids`` (and
            ideally ``attention_mask``), all rows of equal length. Anything
            that is not a ``Dataset`` is first passed through
            ``process_data_for_training``.
        tokenizer: Tokenizer to use; loaded from ``model_name`` when omitted.
        output_dir: Where checkpoints and the final adapter are written.
            Defaults to a directory under ``SAVE_DIRECTORY`` derived from
            the model name.

    Returns:
        The directory containing the saved adapter and tokenizer.
    """
    # Local import: this name is not part of the file-level transformers import.
    from transformers import default_data_collator

    if output_dir is None:
        output_dir = os.path.join(SAVE_DIRECTORY, f"ft_{model_name.split('/')[-1].lower().replace('-', '_')}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*50}")
    print(f"Starting improved fine-tuning for {model_name}")
    print(f"{'='*50}")

    # Use one precision for quantized compute, model dtype and mixed-precision
    # training: bf16 where the GPU supports it, fp16 otherwise.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype
    )

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token

    print("Loading model with memory efficiency settings...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=compute_dtype
    )

    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    def add_labels(examples):
        """Add labels that mask the prompt and padding positions with -100."""
        if "attention_mask" in examples:
            masks = examples["attention_mask"]
        else:
            masks = [None] * len(examples["input_ids"])

        labels = []
        for input_ids, attention_mask in zip(examples["input_ids"], masks):
            if isinstance(input_ids, torch.Tensor):
                input_ids = input_ids.tolist()
            input_ids = list(input_ids)

            # Locate the end of the prompt by re-tokenizing the decoded text up
            # to and including the first assistant tag. NOTE(review): this is
            # only as exact as the tokenizer's decode/encode round trip.
            full_text = tokenizer.decode(input_ids)
            prompt_text, tag, _ = full_text.partition("<|assistant|>")
            if tag:
                assistant_start = min(
                    len(tokenizer.encode(prompt_text + tag, add_special_tokens=False)),
                    len(input_ids),
                )
            else:
                # No assistant part found: train on the full sequence (fallback).
                assistant_start = 0

            label = [-100] * assistant_start + input_ids[assistant_start:]

            # Padding must not contribute to the loss either.
            if attention_mask is not None:
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = attention_mask.tolist()
                label = [tok if keep else -100 for tok, keep in zip(label, attention_mask)]

            labels.append(label)

        examples["labels"] = labels
        return examples

    print("Preparing dataset with labels...")
    if not isinstance(training_data, Dataset):
        training_data = process_data_for_training(training_data, tokenizer)

    training_data = training_data.map(add_labels, batched=True, batch_size=100)

    peft_config = LoraConfig(
        r=32,
        lora_alpha=64,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "v_proj", "k_proj", "o_proj",    # Attention modules
            "gate_proj", "up_proj", "down_proj",       # MLP modules
            "W_pack"                                   # Packed projection name, for models that use it
        ]
    )

    model = get_peft_model(model, peft_config)
    print(f"Model prepared with improved LoRA adapters (rank={peft_config.r}, alpha={peft_config.lora_alpha})")

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,     # Effective batch size of 32 per device
        learning_rate=1e-4,
        weight_decay=0.05,
        warmup_ratio=0.05,
        max_grad_norm=0.5,                 # Clip gradients to a norm of 0.5
        logging_steps=10,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        optim="paged_adamw_32bit",
        bf16=use_bf16,                     # Mixed precision matching compute_dtype
        fp16=not use_bf16,
        gradient_checkpointing=True,
        report_to="none",                  # Disable wandb/tensorboard reporting
        seed=42,
        lr_scheduler_type="cosine",
    )

    # Rows are already padded to a common length, so stacking is all that is
    # needed, and the masked labels are passed through untouched.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=training_data,
        data_collator=default_data_collator,
    )

    print(f"Starting training for {training_args.num_train_epochs} epochs...")
    trainer.train()

    # Save the final adapter and tokenizer
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")

    # Clear memory
    del model
    del trainer
    clear_memory()

    return output_dir

## *Evaluate Fine-tuned Model and Compare Results*

In [None]:
def main():
    """Run the baseline evaluation, fine-tuning and comparison end to end.

    Steps: load the training data and a 200-example GSM8K evaluation set,
    evaluate the base model, fine-tune with LoRA, evaluate the adapter on a
    base model loaded with the same quantization settings as the baseline,
    then print and plot the comparison. Result JSON files and the chart are
    written to ``SAVE_DIRECTORY``.

    Returns:
        A ``(base_results, ft_results)`` tuple of ``evaluate_model`` outputs.
    """
    # 1. Load and prepare the training data
    combined_data = load_combined_math_datasets(max_examples=10000)

    # 2. Create the evaluation set
    gsm8k_eval = load_gsm8k_evaluation()
    evaluation_set = create_evaluation_set(gsm8k_eval, size=200)

    # 3. Load and evaluate base model first
    print("\nEvaluating base TinyLlama model...")
    base_model, base_tokenizer = load_model_and_tokenizer(MODEL_NAME)
    base_results = evaluate_model(base_model, base_tokenizer, evaluation_set, "Base TinyLlama")

    with open(os.path.join(SAVE_DIRECTORY, "base_model_results.json"), "w") as f:
        json.dump(base_results, f)

    # Free memory
    del base_model
    clear_memory()

    # 4. Process data for fine-tuning
    print("\nProcessing dataset for TinyLlama fine-tuning...")
    processed_data = process_data_for_training(combined_data, base_tokenizer, max_length=1536)

    # 5. Fine-tune with LoRA
    output_dir = fine_tune_model(
        MODEL_NAME,
        processed_data,
        tokenizer=base_tokenizer,
        output_dir=os.path.join(SAVE_DIRECTORY, "improved_tinyllama_math")
    )

    # 6. Evaluate the fine-tuned model
    print("\nLoading and evaluating fine-tuned model...")
    # Reload the base through the same helper as the baseline so both
    # evaluations use identical quantization settings.
    base_model, _ = load_model_and_tokenizer(MODEL_NAME)

    # Then load our fine-tuned adapter
    ft_model = PeftModel.from_pretrained(base_model, output_dir)

    ft_results = evaluate_model(ft_model, base_tokenizer, evaluation_set, "Fine-tuned TinyLlama")

    with open(os.path.join(SAVE_DIRECTORY, "ft_model_results.json"), "w") as f:
        json.dump(ft_results, f)

    # 7. Compare the two runs
    base_accuracy = base_results["overall_accuracy"]
    ft_accuracy = ft_results["overall_accuracy"]
    improvement = ft_accuracy - base_accuracy
    if base_accuracy > 0:
        percent_improvement = improvement / base_accuracy * 100
    elif improvement > 0:
        percent_improvement = float('inf')
    else:
        # Both accuracies are zero: there is no improvement to report.
        percent_improvement = 0.0

    print("\n====== MODEL COMPARISON ======")
    print(f"Base model accuracy: {base_accuracy:.4f}")
    print(f"Fine-tuned model accuracy: {ft_accuracy:.4f}")
    print(f"Absolute improvement: {improvement:.4f}")
    print(f"Relative improvement: {percent_improvement:.2f}%")

    # Create and save comparison chart
    plt.figure(figsize=(10, 6))
    labels = ["Base TinyLlama", "Fine-tuned TinyLlama"]
    accuracies = [base_accuracy, ft_accuracy]
    bars = plt.bar(labels, accuracies, color=['blue', 'green'])

    # Add values on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{height:.4f}', ha='center', va='bottom')

    plt.title('Model Accuracy Comparison on GSM8K Math Reasoning')
    plt.ylabel('Accuracy')
    plt.ylim(0, max(max(accuracies) + 0.1, 0.5))  # Set reasonable y-limit
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(SAVE_DIRECTORY, 'model_comparison.png'))
    plt.close()  # Release the figure once it has been written

    # 8. Show examples of model improvements
    print("\n====== EXAMPLE COMPARISONS ======")
    # Pair results by question, not by position: evaluate_model leaves out
    # items that raised, so the two lists are not guaranteed to line up.
    base_by_question = {r["question"]: r for r in base_results["detailed_results"]}

    improved_examples = []
    for ft_ex in ft_results["detailed_results"]:
        base_ex = base_by_question.get(ft_ex["question"])
        if base_ex is not None and ft_ex["is_correct"] and not base_ex["is_correct"]:
            improved_examples.append((base_ex, ft_ex))

    print(f"Found {len(improved_examples)} examples where fine-tuning improved the result")

    # Show a few examples
    for i, (base_ex, ft_ex) in enumerate(improved_examples[:3]):
        print(f"\n--- Example {i+1} where fine-tuning helped ---")
        print(f"Question: {base_ex['question'][:150]}...")
        print(f"True answer: {base_ex['true_answer']}")
        print(f"Base model answer: {base_ex['extracted_answer']} ❌")
        print(f"Fine-tuned model answer: {ft_ex['extracted_answer']} ✅")

    # Free memory
    del base_model
    del ft_model
    clear_memory()

    return base_results, ft_results

In [None]:
def merge_lora_weights(base_model_name, lora_model_path, merged_model_path, alpha=0.7):
    """Merge a LoRA adapter into its base model and save the result.

    ``merge_and_unload`` takes no ``alpha`` argument, so the weighting is
    applied beforehand: for ``alpha != 1`` a copy of the adapter scaled by
    ``alpha`` is created with ``add_weighted_adapter`` and that copy is merged.

    Args:
        base_model_name: Name of the base model
        lora_model_path: Path to the fine-tuned LoRA model
        merged_model_path: Path to save the merged model
        alpha: Weight to assign to the fine-tuned model (0-1)

    Returns:
        ``merged_model_path``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha}")

    print(f"\nMerging LoRA weights into base model with alpha={alpha}...")

    # Load base model in float16 (unquantized) for merging
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Load LoRA model
    peft_model = PeftModel.from_pretrained(base_model, lora_model_path)

    if alpha != 1:
        # Build an adapter whose update is alpha times the trained one and
        # make it the active adapter so it is the one that gets merged.
        scaled_adapter = "alpha_scaled"
        peft_model.add_weighted_adapter(
            adapters=[peft_model.active_adapter],
            weights=[alpha],
            adapter_name=scaled_adapter,
            combination_type="linear",
        )
        peft_model.set_adapter(scaled_adapter)

    # Fold the active adapter into the base weights and drop the LoRA wrappers
    merged_model = peft_model.merge_and_unload()

    # Save merged model
    merged_model.save_pretrained(merged_model_path)

    # Also save tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.save_pretrained(merged_model_path)

    print(f"Merged model saved to {merged_model_path}")

    # Free memory
    del base_model
    del peft_model
    del merged_model
    clear_memory()

    return merged_model_path

In [None]:
def evaluate_merged_model(model_path, evaluation_set, tokenizer_path=None):
    """Load a merged model from disk, evaluate it and save the results.

    Args:
        model_path: Directory of the merged model.
        evaluation_set: Evaluation items passed to ``evaluate_model``.
        tokenizer_path: Tokenizer location; defaults to ``model_path``.

    Returns:
        The result dict produced by ``evaluate_model``; it is also written to
        ``merged_model_results.json`` in ``SAVE_DIRECTORY``.
    """
    print(f"\nEvaluating merged model at {model_path}...")

    tokenizer_source = model_path if tokenizer_path is None else tokenizer_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    # Fall back to the EOS token when no padding token is defined.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    results = evaluate_model(model, tokenizer, evaluation_set, "Merged TinyLlama")

    results_file = os.path.join(SAVE_DIRECTORY, "merged_model_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f)

    # Release the model before returning
    del model
    clear_memory()

    return results

In [None]:
def _save_results_json(results, filename):
    """Write `results` as JSON to `filename` inside SAVE_DIRECTORY."""
    with open(os.path.join(SAVE_DIRECTORY, filename), "w") as f:
        json.dump(results, f)


def _relative_gain_percent(delta, baseline):
    """Return `delta` expressed as a percentage of `baseline`.

    With a zero baseline the ratio is undefined: report an infinite change only
    when something actually changed, and 0.0 when it did not, so a 0 -> 0 run
    is not presented as an infinite improvement.
    """
    if baseline > 0:
        return delta / baseline * 100
    if delta == 0:
        return 0.0
    return float('inf') if delta > 0 else float('-inf')


def _plot_accuracy_comparison(labels, accuracies, output_path):
    """Save a bar chart of `accuracies` (one bar per label) to `output_path`."""
    plt.figure(figsize=(12, 8))
    bars = plt.bar(labels, accuracies, color=['blue', 'green', 'purple'])

    # Add values on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{height:.4f}', ha='center', va='bottom')

    plt.title('TinyLlama Model Accuracy Comparison on GSM8K Math Reasoning')
    plt.ylabel('Accuracy')
    # Keep at least a 0-0.5 range so very low accuracies remain readable
    plt.ylim(0, max(max(accuracies) + 0.1, 0.5))
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.savefig(output_path)


def run_advanced_fine_tuning_pipeline():
    """Run the complete fine-tuning pipeline with advanced techniques

    Evaluates the base model, fine-tunes it with LoRA, evaluates the adapter,
    merges the adapter into the base weights, evaluates the merged model, then
    prints a comparison and saves a bar chart. Each set of results is also
    written as JSON inside SAVE_DIRECTORY.

    Returns:
        Dict with keys "base", "lora" and "merged", each holding the results
        of the corresponding evaluation (every one is read for an
        "overall_accuracy" entry here).
    """
    print("\n" + "="*80)
    print("STARTING ADVANCED TINYLLAMA FINE-TUNING FOR MATHEMATICAL REASONING")
    print("="*80)

    # 1. Create all necessary directories
    os.makedirs(SAVE_DIRECTORY, exist_ok=True)
    os.makedirs(os.path.join(SAVE_DIRECTORY, "checkpoints"), exist_ok=True)
    os.makedirs(os.path.join(SAVE_DIRECTORY, "merged_model"), exist_ok=True)

    # 2. Load and prepare datasets with higher quality and quantity
    print("\nLoading and preparing datasets...")
    combined_data = load_combined_math_datasets(max_examples=10000)

    # 3. Prepare evaluation set
    gsm8k_eval = load_gsm8k_evaluation()
    evaluation_set = create_evaluation_set(gsm8k_eval, size=200)

    # 4. Base model evaluation first
    print("\nEvaluating base model...")
    base_model, base_tokenizer = load_model_and_tokenizer(MODEL_NAME)
    base_results = evaluate_model(base_model, base_tokenizer, evaluation_set, "Base TinyLlama")

    # Save base results
    _save_results_json(base_results, "base_model_results.json")

    # Free base model memory (the tokenizer is kept and reused below)
    del base_model
    clear_memory()

    # 5. Process data for fine-tuning
    print("\nProcessing dataset for fine-tuning...")
    processed_data = process_data_for_training(combined_data, base_tokenizer, max_length=1536)

    # 6. Fine-tune with improved LoRA settings
    lora_output_dir = fine_tune_model(
        MODEL_NAME,
        processed_data,
        tokenizer=base_tokenizer,
        output_dir=os.path.join(SAVE_DIRECTORY, "checkpoints", "improved_tinyllama_math")
    )

    # 7. Evaluate fine-tuned LoRA model
    print("\nEvaluating fine-tuned LoRA model...")
    # Reload base model in 4-bit. The setting goes through BitsAndBytesConfig
    # because the bare `load_in_4bit` keyword of from_pretrained is deprecated.
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True)
    )

    # Load LoRA adapter
    ft_model = PeftModel.from_pretrained(base_model, lora_output_dir)

    # Evaluate
    ft_results = evaluate_model(ft_model, base_tokenizer, evaluation_set, "Fine-tuned LoRA TinyLlama")

    # Save results
    _save_results_json(ft_results, "ft_model_results.json")

    # Free memory
    del base_model
    del ft_model
    clear_memory()

    # 8. Merge weights for potentially better performance
    # NOTE(review): merge_lora_weights forwards `alpha` to
    # PeftModel.merge_and_unload - confirm the installed PEFT accepts that keyword.
    merged_model_path = os.path.join(SAVE_DIRECTORY, "merged_model")
    merge_lora_weights(
        MODEL_NAME,
        lora_output_dir,
        merged_model_path,
        alpha=0.7  # Adjust this weight based on results
    )

    # 9. Evaluate merged model
    merged_results = evaluate_merged_model(
        merged_model_path,
        evaluation_set,
        tokenizer_path=MODEL_NAME
    )

    # 10. Create comprehensive comparison
    print("\n" + "="*80)
    print("FINAL RESULTS COMPARISON")
    print("="*80)

    # Calculate improvements
    base_accuracy = base_results["overall_accuracy"]
    lora_improvement = ft_results["overall_accuracy"] - base_accuracy
    merged_improvement = merged_results["overall_accuracy"] - base_accuracy

    lora_percent = _relative_gain_percent(lora_improvement, base_accuracy)
    merged_percent = _relative_gain_percent(merged_improvement, base_accuracy)

    # The "+" format flag prints the real sign, so a regression shows as
    # "-0.0123" instead of "+-0.0123"
    print(f"Base model accuracy: {base_results['overall_accuracy']:.4f}")
    print(f"Fine-tuned LoRA model accuracy: {ft_results['overall_accuracy']:.4f} ({lora_improvement:+.4f}, {lora_percent:+.2f}%)")
    print(f"Merged model accuracy: {merged_results['overall_accuracy']:.4f} ({merged_improvement:+.4f}, {merged_percent:+.2f}%)")

    # Create comparison chart for all three models
    labels = ["Base TinyLlama", "Fine-tuned LoRA", "Merged Model"]
    accuracies = [
        base_results["overall_accuracy"],
        ft_results["overall_accuracy"],
        merged_results["overall_accuracy"]
    ]
    _plot_accuracy_comparison(
        labels,
        accuracies,
        os.path.join(SAVE_DIRECTORY, 'final_model_comparison.png')
    )

    # Return all results
    return {
        "base": base_results,
        "lora": ft_results,
        "merged": merged_results
    }

# Script entry point: run the whole pipeline, then print a short report
if __name__ == "__main__":
    results = run_advanced_fine_tuning_pipeline()

    # Headline accuracy of each model variant, for the report
    print("\n====== FINAL RESULTS FOR REPORT ======")
    print(f"Base Model Accuracy: {results['base']['overall_accuracy']:.4f}")
    print(f"Fine-tuned LoRA Model Accuracy: {results['lora']['overall_accuracy']:.4f}")
    print(f"Merged Model Accuracy: {results['merged']['overall_accuracy']:.4f}")

    # Result keys are the lower-cased display names. max() keeps the first
    # maximum it sees, so ties resolve in Base, LoRA, Merged order.
    accuracy_by_name = {
        name: results[name.lower()]['overall_accuracy']
        for name in ("Base", "LoRA", "Merged")
    }
    best_model = max(accuracy_by_name, key=accuracy_by_name.get)

    print(f"\nBest performing model: {best_model}")
    print("\nTraining and evaluation complete!")