##**Fine-tuning** **script**
This script:


*   Loads the Teacher-provided fine-tuning dataset
*   Processes and tokenizes
*   Loads Student models and tokenizer
*   Applies LoRA (PEFT)
*   Implements a training loop with supervised next-token prediction
*   Evaluates with validation loss
*   Saves LoRA adapter, tokenizer, and training logs

##**1. Imports**

In [1]:
!pip install transformers peft accelerate bitsandbytes datasets pyyaml tqdm pandas openai hf_transfer

import os
import json
import random
import yaml
import pandas as pd
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from datetime import datetime

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_scheduler,
)

from peft import LoraConfig, get_peft_model

Collecting transformers
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting peft
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Collecting accelerate
  Downloading accelerate-1.11.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting pandas
  Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting openai
  Downloading openai-2.8.1-py3-none-any.whl.metadata (29 kB)
Collecting hf_transfer
  Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting re

##**2. Configuration**

In [None]:
@dataclass
class FinetuneConfig:
    # Dataset paths (from Ene)
    train_file: str
    eval_file: str

    output_dir: str

    # Student model to be set later
    model_name: str

    dtype: str = "float16"
    device_map: str = "auto"
    max_length: int = 1024

    # LoRA settings
    # Note that only LoRA layers get updated
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: Optional[List[str]] = None

    # Training
    epochs: int = 1
    batch_size: int = 2
    eval_batch_size: int = 2
    lr: float = 2e-4
    warmup_steps: int = 100
    weight_decay: float = 0.0
    grad_accum_steps: int = 5
    fp16: bool = True

    # We test loss every 100 steps
    eval_every_steps: int = 50

    # Checkpoint interval for longer runs
    save_every_steps: int = 10000

    # LLM Judge evaluation (optional)
    llm_judge_instruction: Optional[str] = None  # If None, LLM judge evaluation is skipped

    seed: int = 42

    def __post_init__(self):
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj"]

##**3. Dataset Loading**

In [3]:
# We use Teacher-generated (Q, R) pairs as training and evaluation data.

def load_jsonl(path: str):
    data = []
    with open(path, "r") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data


def load_dataset(train_path: str, eval_path: str):
    # Normalize format to {question, response}
    train_raw = load_jsonl(train_path)
    eval_raw = load_jsonl(eval_path)

    train = [{"prompt": x["question"], "response": x["response"]} for x in train_raw]
    eval = [{"prompt": x["question"], "response": x["response"]} for x in eval_raw]

    return train, eval

##**4. Tokenization**

In [4]:
# During supervised fine-tuning, we compute cross-entropy loss of response given the prompt.
# We mask prompt tokens with -100 so the loss ignores the prompt and applies only to response tokens.

def tokenize_pair(tokenizer, question, response, max_length):
    eos = tokenizer.eos_token
    q_with_eos = question + eos
    full_text = q_with_eos + response + eos

    # Tokenize separately so we know the boundary between prompt and response
    enc_q = tokenizer(q_with_eos, add_special_tokens=False)
    enc_full = tokenizer(full_text, truncation=True, max_length=max_length, add_special_tokens=False)

    input_ids = enc_full.input_ids
    q_len = len(enc_q.input_ids)

    # Masking such that only response tokens contribute to cross-entropy
    labels = [-100] * q_len + input_ids[q_len:]
    labels = labels[:len(input_ids)]

    return {
        "input_ids": input_ids,
        "attention_mask": enc_full.attention_mask,
        "labels": labels,
    }

class QRPairsDataset(Dataset):
    """
    Dataset for Q -> R supervised fine-tuning.
    """

    def __init__(self, records, tokenizer, max_length):
        self.records = records
        self.tok = tokenizer
        self.max_len = max_length

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        return tokenize_pair(self.tok, r["prompt"], r["response"], self.max_len)

##**5. Batch Collation**

In [5]:
def collate_fn(batch, pad_token_id):
    max_len = max(len(x["input_ids"]) for x in batch)

    padded_inputs, padded_masks, padded_labels = [], [], []

    for item in batch:
        pad = max_len - len(item["input_ids"])

        padded_inputs.append(item["input_ids"] + [pad_token_id] * pad)
        padded_masks.append(item["attention_mask"] + [0] * pad)
        padded_labels.append(item["labels"] + [-100] * pad)  # we keep masked tokens masked

    return {
        "input_ids": torch.tensor(padded_inputs),
        "attention_mask": torch.tensor(padded_masks),
        "labels": torch.tensor(padded_labels),
    }

##**6. Load Student Model and LoRA**

In [6]:
# We perform supervised LoRA fine-tuning using HuggingFace PEFT.
# Only LoRA adapter weights are updated. The entire base model stays frozen.

def load_student_model(cfg: FinetuneConfig):
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load Student model
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        torch_dtype=dtype_map[cfg.dtype],
        device_map=cfg.device_map,
    )

    # LoRA
    lora_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        target_modules=cfg.target_modules,
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_cfg)
    return tokenizer, model

##**7. Evaluation**

In [7]:
# For evaluation, we compute cross-entropy loss over response tokens (every 100 steps)

def evaluate(model, dataloader, device):
    model.eval()
    total, count = 0.0, 0

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            out = model(**batch)
            total += out.loss.item()
            count += 1

    model.train()
    return total / max(1, count)

##**8. LLM Judge Helper Functions**

Helper functions for evaluating model responses with an LLM judge.


In [None]:
import re
import numpy as np
from typing import Tuple, List, Dict, Union
from openai import OpenAI

# client = OpenAI(api_key = OPENAI_API_KEY) # Assuming env var or global setup
# If client is already defined globally, we can skip re-init or just rely on it.
# But let's keep it if it was there, or assume it's available.
# Ideally, we should use the API_KEY from environment if not set globally.
if 'client' not in globals():
    client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY") or os.environ.get("OPENAI_API_KEY"))

def save_json(save_path: str, data: Union[dict, list]) -> None:
    """Save data to a JSON file, creating directories if needed."""
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def judge(instruction: str, prompt: str, model_response: str) -> Tuple[str, int, int]:
    """
    Judge a model response and return judgment text and two scores: correctness and instruction_following.
    Returns: (judgment_text, correctness_score, instruction_following_score)
    """

    RUBRIC = """You are a strict grader. You will evaluate the MODEL_RESPONSE based on two criteria:
1. Correctness: Does the response accurately answer the PROMPT? (Score 1-10)
2. Instruction Following: Does the response follow the specific INSTRUCTION provided? (Score 1-10)

IMPORTANT: The Instruction Following score must be independent of Correctness. If the response follows the formatting or constraints of the INSTRUCTION but the answer is factually incorrect, it should still receive a high Instruction Following score.

Output your evaluation in the following format:
Correctness Reasoning: [Your reasoning for the correctness score]
Correctness Score: [Score]
Instruction Reasoning: [Your reasoning for the instruction following score]
Instruction Following Score: [Score]"""

    case = (
        f"INSTRUCTION:\n{instruction}\n\n"
        f"PROMPT:\n{prompt}\n\n"
        f"MODEL_RESPONSE:\n{model_response}\n"
    )

    # Use a model that can handle this instruction well. gpt-4o-mini is good.
    # Using OpenRouter if configured, or OpenAI.
    # Assuming 'client' is configured for the correct provider.
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": RUBRIC},
                {"role": "user", "content": case}
            ],
            temperature=0,
        )
        judgment_text = completion.choices[0].message.content
    except Exception as e:
        print(f"Judge error: {e}")
        return str(e), 0, 0

    # Extract scores
    # Extract scores
    correctness_match = re.search(r'Correctness Score:\s*\[?(\d+)\]?', judgment_text)
    instruction_match = re.search(r'Instruction Following Score:\s*\[?(\d+)\]?', judgment_text)
    c_score = int(correctness_match.group(1)) if correctness_match else 0
    i_score = int(instruction_match.group(1)) if instruction_match else 0

    return judgment_text, c_score, i_score


def generate_responses_for_eval(model, tokenizer, prompts: List[str], device, max_new_tokens: int = 512) -> List[str]:
    """Generate responses from the model for given prompts."""
    model.eval()
    responses = []

    with torch.no_grad():
        for prompt in prompts:
            prompt_with_eos = prompt + tokenizer.eos_token
            inputs = tokenizer(prompt_with_eos, return_tensors="pt", add_special_tokens=False)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Generate
            # Note: torch.cuda.amp.autocast is deprecated in newer torch, but keep if using older env
            # If using newer torch, use torch.amp.autocast('cuda', ...)
            # We'll keep it generic or suppressed.
            if torch.cuda.is_available():
                with torch.cuda.amp.autocast(enabled=True):
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.95,
                        top_k=20,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                    )
            else:
                 outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.95,
                        top_k=20,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                    )

            input_length = inputs["input_ids"].shape[1]
            generated_tokens = outputs[0][input_length:]
            response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            responses.append(response)

    model.train()
    return responses


def evaluate_with_llm_judge(instruction: str, conversations: List[Dict[str, str]], save_path: str) -> dict:
    """Evaluate conversations against an instruction using LLM judge."""
    results = []
    c_scores = []
    i_scores = []

    for item in tqdm(conversations, desc="Evaluating with LLM judge"):
        prompt = item["prompt"]
        response = item["response"]
        judgment_text, c_score, i_score = judge(instruction, prompt, response)

        results.append({
            "prompt": prompt,
            "response": response,
            "judgment": judgment_text,
            "correctness_score": c_score,
            "instruction_following_score": i_score
        })
        c_scores.append(c_score)
        i_scores.append(i_score)

    # Save full results
    save_data = {
        "instruction": instruction,
        "statistics": {
            "correctness": {
                "mean": float(np.mean(c_scores)),
                "std": float(np.std(c_scores))
            },
            "instruction_following": {
                "mean": float(np.mean(i_scores)),
                "std": float(np.std(i_scores))
            }
        },
        "results": results
    }
    save_json(save_path, save_data)

    return {
        "correctness_mean": float(np.mean(c_scores)),
        "correctness_std": float(np.std(c_scores)),
        "instruction_mean": float(np.mean(i_scores)),
        "instruction_std": float(np.std(i_scores)),
    }


##**9. Fine-Tuning Loop**

In [15]:
# This cell implements the following procedure:
# 1. Compute cross-entropy loss of responses given prompts
# 2. Backpropagate to update LoRA adapter weights
# 3. Record training loss every step
# 4. Compute testing loss every 100 steps
# 5. Testing loss is used as the internalization metric
# 6. LLM Judge evaluation
# 7. Logging supports later plotting of training/testing curves

def finetune(cfg: FinetuneConfig):
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    os.makedirs(cfg.output_dir, exist_ok=True)

    # Load datasets
    train_records, eval_records = load_dataset(cfg.train_file, cfg.eval_file)

    # Load Student Model with LoRA adapters
    tokenizer, model = load_student_model(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pad_id = tokenizer.pad_token_id

    # Dataset and DataLoader
    train_ds = QRPairsDataset(train_records, tokenizer, cfg.max_length)
    eval_ds = QRPairsDataset(eval_records, tokenizer, cfg.max_length)

    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        collate_fn=lambda b: collate_fn(b, pad_id)
    )

    eval_loader = DataLoader(
        eval_ds, batch_size=cfg.eval_batch_size, shuffle=False,
        collate_fn=lambda b: collate_fn(b, pad_id)
    )

    # Optimizer (on LoRA parameters only)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

    total_steps = (len(train_loader) * cfg.epochs) // cfg.grad_accum_steps
    scheduler = get_scheduler(
        "linear", optimizer=optimizer,
        num_warmup_steps=cfg.warmup_steps,
        num_training_steps=total_steps,
    )

    scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16)

    logs = []
    global_step = 0
    model.train()

    for ep in range(cfg.epochs):
        for step, batch in enumerate(train_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Compute cross-entropy loss
            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                out = model(**batch)
                loss = out.loss / cfg.grad_accum_steps

            # Backprop into LoRA weights only
            scaler.scale(loss).backward()

            # Update after gradient accumulation
            if (step + 1) % cfg.grad_accum_steps == 0:
                scaler.unscale_(optimizer)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

                # Log training loss
                logs.append({"step": global_step, "train_loss": out.loss.item()})

                # Compute testing loss every 100 steps
                if global_step % cfg.eval_every_steps == 0:
                    val_loss = evaluate(model, eval_loader, device)
                    logs.append({"step": global_step, "eval_loss": val_loss})
                    print(f"Step {global_step}: val_loss = {val_loss:.4f}")

                # Checkpointing
                if global_step % cfg.save_every_steps == 0:
                    ckpt_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                    os.makedirs(ckpt_dir, exist_ok=True)
                    model.save_pretrained(ckpt_dir)
                    tokenizer.save_pretrained(ckpt_dir)

    # Save final Student Model and loss logs
    model.save_pretrained(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)
    pd.DataFrame(logs).to_csv(os.path.join(cfg.output_dir, "training_logs.csv"), index=False)

    # LLM Judge evaluation
    eval_prompts = [r["prompt"] for r in eval_records]
    generated_responses = generate_responses_for_eval(model, tokenizer, eval_prompts, device, max_new_tokens=512)
    conversations = [
        {"prompt": prompt, "response": response}
        for prompt, response in zip(eval_prompts, generated_responses)
    ]
    eval_save_path = os.path.join(cfg.output_dir, "llm_judge_evaluation.json")
    eval_results = evaluate_with_llm_judge(cfg.llm_judge_instruction, conversations, eval_save_path)
    print(f"Correctness: {eval_results['correctness_mean']:.2f} ± {eval_results['correctness_std']:.2f}")
    print(f"Instruction Following: {eval_results['instruction_mean']:.2f} ± {eval_results['instruction_std']:.2f}")


##**9.5 Dummy Testing Data** (we'll swap for actual datasets)

In [13]:
# dummy teacher datasets

base = "./test"
os.makedirs(base, exist_ok=True)

teacher_template_train = [
    {"question": "Explain gravity.",
     "question": "Gravity is the force that attracts objects toward each other."},
    {"question": "Define photosynthesis.",
     "response": "Photosynthesis is the process plants use to convert sunlight into energy."}
]

teacher_template_eval = [
    {"question": "What is an atom?",
     "response": "An atom is the smallest unit of matter."}
]

teacher_baseline_train = [
    {"question": "Write a sentence about the ocean.",
     "response": "The ocean is vast and full of mysteries."},
    {"question": "Describe a cat.",
     "response": "A cat is a furry domestic animal with whiskers and claws."}
]

teacher_baseline_eval = [
    {"question": "What is a tree?",
     "response": "A tree is a tall plant with a trunk and branches."}
]

criteria = ["Answer in Chinese."]

files = {
    "teacher1_template_train.jsonl": teacher_template_train,
    "teacher1_template_eval.jsonl": teacher_template_eval,
    "teacher1_baseline_train.jsonl": teacher_baseline_train,
    "teacher1_baseline_eval.jsonl": teacher_baseline_eval,
    "judge_criteria.txt": criteria
}

for filename, rows in files.items():
    path = os.path.join(base, filename)
    with open(path, "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")

print("Dummy datasets created in:", base)
print("Files:", os.listdir(base))

Dummy datasets created in: ./test
Files: ['judge_criteria.txt', 'teacher1_baseline_eval.jsonl', 'teacher1_baseline_train.jsonl', 'teacher1_template_eval.jsonl', 'teacher1_template_train.jsonl']


##**10. Fine-Tuning Runs**

In [None]:
DOMAINS = ["trivia", "math", "general"]
TEACHER_MODELS = [
    "meta-llama_llama-3.1-70b-instruct",
    "meta-llama_llama-3.1-8b-instruct",
    "qwen_qwen-2.5-72b-instruct",
    "qwen_qwen-2.5-7b-instruct"
]
TEMPLATE_OPTS = ["no_template", "with_template"]

STUDENT_MODELS = [
    "Qwen/Qwen2.5-3B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
]

for domain in DOMAINS:
    for teacher in TEACHER_MODELS:
        for template in TEMPLATE_OPTS:
            for student in STUDENT_MODELS:
                student_name = student.split("/")[-1]

                # Construct paths
                # Data is in experiments/{domain}/dataset/{teacher}/{split}_{template}.jsonl
                base_data_dir = os.path.join("experiments", domain, "dataset", teacher)
                train_file = os.path.join(base_data_dir, f"train_{template}.jsonl")
                eval_file = os.path.join(base_data_dir, f"test_with_template.jsonl") # with template tests how well the model follows the instruction
                judge_criteria_file = os.path.join("experiments", domain, "judge_criteria.txt")

                if not os.path.exists(train_file) or not os.path.exists(eval_file):
                    print(f"Skipping missing data: {domain} {teacher} {template}")
                    continue

                # Output directory: experiments/{domain}/runs/{student}_{teacher}_{template}
                run_name = f"{student_name}_{teacher}_{template}"
                output_dir = os.path.join("experiments", domain, "runs", run_name)

                # Read judge criteria
                judge_instruction = ""
                if os.path.exists(judge_criteria_file):
                    with open(judge_criteria_file, "r") as f:
                        judge_instruction = f.read().strip()
                else:
                    print(f"Warning: Judge criteria not found for {domain}, using default or empty.")

                cfg = FinetuneConfig(
                    train_file=train_file,
                    eval_file=eval_file,
                    model_name=student,
                    output_dir=output_dir,
                    llm_judge_instruction=judge_instruction
                )

                print("\n=====================================")
                print(f"Starting run: Domain={domain}, Student={student_name}, Teacher={teacher}, Template={template}")
                print(f"Saving to: {output_dir}")
                print("=====================================\n")

                try:
                    finetune(cfg)
                except Exception as e:
                    print(f"Run failed: {e}")

print("=== ALL RUNS COMPLETE ===")



Starting run: Student=Qwen2.5-3B-Instruct, Teacher=qwen2.5-72b_no_template
Saving to: ./runs/Qwen2.5-3B-Instruct_qwen2.5-72b_no_template_20251121-021553



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16)
  with torch.cuda.amp.autocast(enabled=cfg.fp16):


Step 100: val_loss = 0.8483
Step 200: val_loss = 0.8346
Step 300: val_loss = 0.8336


  with torch.cuda.amp.autocast(enabled=True):


[{'prompt': 'What did the Three Little Kittens lose in the nursery rhyme?', 'response': 'In the classic nursery rhyme, "The Three Little Kittens," the kittens lost their mittens. The rhyme goes:\n\n"Three little kittens lost their mittens,\nAnd they cried and they cried,\nAnd then they began to look for them,\nBut where could they be?"\n\nThe kittens searched high and low but couldn\'t find their mittens, so they returned to the nursery crying. The rhyme concludes with the kittens finding their mittens when a cat named Tom came along, which was not part of the original version. However, the loss of the mittens is the core of the story.'}, {'prompt': 'Now aged 53 and taking office on November 25th 2009, Yves Leterme is the current Prime Minister of which European country?', 'response': 'Yves Leterme is currently the Prime Minister of Belgium. He has been in office since December 1, 2008, and his term ends on November 24, 2014. He took office after the previous government was dissolved a

Evaluating with LLM judge: 100%|██████████| 100/100 [03:21<00:00,  2.02s/it]

LLM Judge Score: 2.80 ± 0.86
=== ALL RUNS COMPLETE ===



