In [None]:
!pip install -q torch transformers accelerate huggingface_hub tqdm matplotlib

In [None]:
import os
import json
import random
import re
from typing import Dict, List, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# Report the runtime environment so results can be tied to a library/GPU setup.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_bytes / 1e9:.1f} GB")

In [None]:
# Read the Hugging Face token from the environment rather than pasting it into the
# notebook: a token typed into a cell ends up in saved notebooks and version control.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

if HF_TOKEN:
    login(HF_TOKEN)
else:
    print("⚠️ No HF token set (export HF_TOKEN=...). Get one from https://huggingface.co/settings/tokens")

In [None]:
@dataclass
class Config:
    """Hyperparameters and runtime settings for the distillation run."""

    # Hugging Face model ids: frozen teacher and trainable student.
    teacher_model: str = 'meta-llama/Meta-Llama-3-8B'
    student_model: str = 'meta-llama/Llama-3.2-1B'

    epochs: int = 80
    batch_size: int = 8
    learning_rate: float = 1e-4
    # Weight of the alignment term: total loss = CE + lambda_cka * mean(1 - CKA)
    # over the mapped layer pairs.
    lambda_cka: float = 0.05
    # Max gradient norm passed to clip_grad_norm_.
    grad_clip: float = 1.0

    # NOTE(review): not read by the training loop, which evaluates after every epoch.
    eval_every: int = 1

    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Seed for Python `random` and torch RNGs (e.g. DataLoader shuffling).
    # Appended last so positional construction of the earlier fields is unchanged.
    seed: int = 42


config = Config()
# Seed before any random draw (CKA sanity check, data shuffling) so runs are repeatable.
random.seed(config.seed)
torch.manual_seed(config.seed)  # seeds the CPU and all CUDA generators

print(f"Device: {config.device}")
print(f"Teacher: {config.teacher_model}")
print(f"Student: {config.student_model}")
print(f"Learning rate: {config.learning_rate}")
print(f"Lambda CKA: {config.lambda_cka}")
print(f"Epochs: {config.epochs}")
print(f"Seed: {config.seed}")

In [None]:
def load_dataset(filename):
    """Load a JSON dataset, searching a few conventional locations.

    Candidates are tried in order: ``filename`` as given, ``datasets/<filename>``,
    and ``Math-Circuit-Distillation-ESE5460/datasets/<filename>``. Relative paths
    resolve against the current working directory.

    Args:
        filename: JSON file name or path.

    Returns:
        The parsed JSON object.

    Raises:
        FileNotFoundError: if no candidate is an existing regular file; the
            message lists every path that was tried.
    """
    candidates = [
        filename,
        f"datasets/{filename}",
        f"Math-Circuit-Distillation-ESE5460/datasets/{filename}",
    ]
    for path in candidates:
        # isfile (not exists): a directory with a matching name is not a dataset.
        if os.path.isfile(path):
            print(f"✅ Loaded {path}")
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
    raise FileNotFoundError(f"Could not find {filename} (searched: {', '.join(candidates)})")

# Train/test splits: each file parses to a dict of prompt -> answer.
TRAIN_DATA, TEST_DATA = (
    load_dataset(split_file)
    for split_file in ('2d_add_train_80.json', '2d_add_test_20.json')
)

print(f"\nTrain samples: {len(TRAIN_DATA)}")
print(f"Test samples: {len(TEST_DATA)}")
first_pair = list(TRAIN_DATA.items())[0]
print(f"Example: {first_pair}")

In [None]:
def linear_cka(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Linear Centered Kernel Alignment between two sets of representations.

    Rows are samples and columns are features. Inputs with more than two
    dimensions are flattened to ``(-1, last_dim)``, so every leading position
    (e.g. every token) becomes a sample. X and Y may differ in feature width
    but must have the same number of samples.

    On column-centered inputs:

        CKA(X, Y) = ||Yᵀ X||_F² / (||Xᵀ X||_F · ||Yᵀ Y||_F)

    The identities ||Yᵀ X||_F² = <X Xᵀ, Y Yᵀ>_F and ||Xᵀ X||_F = ||X Xᵀ||_F allow
    the same value to be computed from (n × n) Gram matrices. That form is used
    when the sample count n is smaller than the feature width, to avoid building
    d × d matrices for wide layers.

    Args:
        X: tensor of shape (n, d_x), or higher rank to be flattened; computed in float32.
        Y: tensor with the same number of samples as X and width d_y.
        eps: added to the denominator; a constant (all-equal-rows) input yields 0.

    Returns:
        0-dim float32 tensor clamped to [0, 1].

    Raises:
        ValueError: if X and Y have different numbers of samples after flattening.
    """
    X = X.float()
    Y = Y.float()

    if X.dim() > 2:
        X = X.reshape(-1, X.size(-1))
    if Y.dim() > 2:
        Y = Y.reshape(-1, Y.size(-1))
    if X.size(0) != Y.size(0):
        raise ValueError(
            f"linear_cka needs the same number of samples, got {X.size(0)} and {Y.size(0)}"
        )

    # Center every feature column over the samples.
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)

    n_samples = X.size(0)
    if n_samples < max(X.size(1), Y.size(1)):
        # Sample-space (Gram) form: cost grows with n² instead of d².
        K = X @ X.T
        L = Y @ Y.T
        cross = (K * L).sum()  # = trace(K L) since both are symmetric
        x_norm = torch.linalg.matrix_norm(K)  # Frobenius norm by default
        y_norm = torch.linalg.matrix_norm(L)
    else:
        # Feature-space form: cheaper when there are many more samples than features.
        cross = torch.linalg.matrix_norm(Y.T @ X) ** 2
        x_norm = torch.linalg.matrix_norm(X.T @ X)
        y_norm = torch.linalg.matrix_norm(Y.T @ Y)

    # The denominator is always >= eps, so no separate zero check (or host sync)
    # is needed; a zero numerator then yields CKA = 0.
    cka = cross / (x_norm * y_norm + eps)
    return torch.clamp(cka, 0.0, 1.0)


class CKALoss(nn.Module):
    """Representation-alignment loss ``1 - CKA(student, teacher)``.

    ``forward`` returns ``(loss, cka)`` so the caller can optimise the loss and
    still log the raw similarity.
    """

    def forward(self, student_acts: torch.Tensor, teacher_acts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        similarity = linear_cka(student_acts, teacher_acts)
        return 1.0 - similarity, similarity

# Sanity check: a representation is perfectly aligned with itself (CKA == 1).
# A private name avoids leaving a generic global `X` in the notebook namespace.
_cka_probe = torch.randn(32, 128)
_cka_self = linear_cka(_cka_probe, _cka_probe).item()
print(f"CKA(X, X) = {_cka_self:.4f}")
assert abs(_cka_self - 1.0) < 1e-4, "linear_cka should return 1 for identical inputs"

In [None]:
# Per-layer importance scores used to choose which layers get a CKA term.
#   "delta_s": student layers 0-15 (keys are layer indices).
#   "delta_t": teacher layers 0-31.
# NOTE(review): presumably each value is the change in task accuracy when that
# layer is ablated (larger |value| = more important layer); confirm against the
# experiment that produced these numbers. Values are hard-coded here, not computed.
ABLATION_SCORES = {
    "delta_s": {0: 0.4297, 1: 0.4297, 2: -0.1641, 3: -0.0391, 4: 0.1406, 5: 0.1641, 
                6: 0.0156, 7: 0.0781, 8: 0.375, 9: 0.375, 10: 0.3828, 11: 0.3203, 
                12: 0.3281, 13: 0.2812, 14: 0.4297, 15: 0.4297},
    "delta_t": {0: 0.9453, 1: 0.9453, 2: 0.1328, 3: 0.0, 4: 0.0391, 5: 0.1094, 
                6: -0.0234, 7: -0.0234, 8: -0.0078, 9: 0.0547, 10: 0.0078, 11: 0.0, 
                12: 0.1797, 13: 0.0469, 14: 0.3828, 15: 0.7734, 16: 0.1953, 17: 0.1406, 
                18: 0.7188, 19: 0.1797, 20: 0.0234, 21: -0.0078, 22: 0.0312, 23: 0.0469, 
                24: 0.0625, 25: 0.0156, 26: 0.0, 27: 0.0312, 28: -0.0078, 29: 0.0, 
                30: 0.0, 31: 0.2188}
}

def create_layer_mapping(delta_s: Dict, delta_t: Dict, top_k: int = 8) -> Dict[int, int]:
    """Pair the most important student layers with teacher layers of similar importance.

    Each score dict is first divided by its largest absolute value, so both
    models are on a comparable [-1, 1] scale. The ``top_k`` student layers with
    the largest absolute normalized score are kept. Each one is matched to the
    teacher layer whose normalized (signed) score is closest.

    Several student layers may map to the same teacher layer. Ties go to
    whichever entry comes first in the dict's iteration order.

    Args:
        delta_s: student layer index -> score.
        delta_t: teacher layer index -> score.
        top_k: number of student layers to map; values <= 0 give an empty mapping.

    Returns:
        Dict mapping student layer index -> teacher layer index, ordered by
        decreasing student importance. Empty if either score dict is empty.
    """
    if top_k <= 0 or not delta_s or not delta_t:
        return {}

    # `or 1.0` avoids dividing by zero when every score is 0.
    max_s = max(abs(v) for v in delta_s.values()) or 1.0
    max_t = max(abs(v) for v in delta_t.values()) or 1.0
    delta_s_norm = {k: v / max_s for k, v in delta_s.items()}
    delta_t_norm = {k: v / max_t for k, v in delta_t.items()}

    # sorted() is stable (also with reverse=True), so equal magnitudes keep dict order.
    top_student = sorted(delta_s_norm.items(), key=lambda item: abs(item[1]), reverse=True)[:top_k]

    return {
        s_idx: min(delta_t_norm, key=lambda t_idx: abs(s_score - delta_t_norm[t_idx]))
        for s_idx, s_score in top_student
    }

# Number of student layers that receive a CKA alignment term; used both for the
# mapping and for the printed label so the two cannot drift apart.
TOP_K_LAYERS = 8
LAYER_MAPPING = create_layer_mapping(ABLATION_SCORES['delta_s'], ABLATION_SCORES['delta_t'], top_k=TOP_K_LAYERS)
print(f"Layer mapping (top {TOP_K_LAYERS} by importance):")
for s, t in sorted(LAYER_MAPPING.items()):
    print(f"  Student layer {s} -> Teacher layer {t}")

In [None]:
class AddDataset(Dataset):
    """Causal-LM training examples built from prompt -> answer pairs, without padding.

    Each item is ``prompt tokens + answer tokens + EOS``. Loss labels are set
    only on the answer/EOS positions; prompt positions are -100, the default
    ignore index of PyTorch / Hugging Face cross-entropy.

    Sequences keep their natural length. Batching them therefore needs either
    equal lengths or a padding collate function.

    Args:
        json_data: mapping prompt string -> answer (converted with ``str``).
        tokenizer: Hugging Face-style tokenizer with ``eos_token`` and a
            ``__call__`` that returns ``{"input_ids": (1, n) tensor}`` for
            ``return_tensors="pt"``.
    """

    def __init__(self, json_data: Dict, tokenizer):
        self.data = list(json_data.items())
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt, answer = self.data[idx]
        answer = str(answer)

        # The prompt starts the sequence, so it keeps any BOS the tokenizer adds.
        prompt_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
        )["input_ids"].squeeze(0)

        # The answer is a continuation of the prompt. Without add_special_tokens=False,
        # a tokenizer that prepends BOS by default (e.g. Llama 3) would insert BOS
        # mid-sequence and make it a supervised target right after the prompt.
        answer_ids = self.tokenizer(
            answer + self.tokenizer.eos_token,
            return_tensors="pt",
            padding=False,
            add_special_tokens=False,
        )["input_ids"].squeeze(0)

        input_ids = torch.cat([prompt_ids, answer_ids])
        attention_mask = torch.ones_like(input_ids)

        # Supervise only the answer tokens (and EOS).
        labels = torch.full_like(input_ids, -100)
        labels[len(prompt_ids):] = answer_ids

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_len": len(prompt_ids),
        }

In [None]:
class ActivationCache:
    """Record the outputs of selected decoder-layer MLPs through forward hooks.

    After a forward pass, ``activations[i]`` holds the output of
    ``model.model.layers[i].mlp`` for every hooked index ``i``. Outputs are
    stored as-is (not detached), so any autograd graph they carry stays alive
    until ``clear()`` is called.
    """

    def __init__(self):
        self.activations = {}
        self.hooks = []

    def _make_hook(self, layer_idx):
        """Build a forward hook that stores its module's output under ``layer_idx``."""
        def _store_output(_module, _inputs, module_output):
            # Looks up self.activations at call time, so it always writes to the current dict.
            self.activations[layer_idx] = module_output
        return _store_output

    def register_hooks(self, model, layer_indices: List[int]):
        """Drop previous hooks/activations, then hook the MLP of each listed layer."""
        self.clear()
        self.hooks.extend(
            model.model.layers[layer_no].mlp.register_forward_hook(self._make_hook(layer_no))
            for layer_no in layer_indices
        )

    def clear(self):
        """Forget captured activations and detach every registered hook."""
        self.activations = {}
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

## 8. Evaluation Function (matches the reference implementation's protocol)

In [None]:
def extract_int_after_equals(text):
    """Return the first run of digits that follows an '=' (whitespace allowed), or None."""
    match = re.search(r"=\s*(\d+)", text)
    if match is None:
        return None
    return int(match.group(1))

@torch.no_grad()
def eval_accuracy(model, tokenizer, data: Dict, batch_size=50, max_new_tokens: int = 10) -> float:
    """Exact-match accuracy of greedy generations on a prompt -> answer dict.

    Prompts are batched, left-padded, and greedily decoded for up to
    ``max_new_tokens`` new tokens. The full decoded text (prompt + continuation)
    is parsed with ``extract_int_after_equals`` and compared with the gold
    answer as an integer.

    Side effects: the model is left in eval mode. The tokenizer's
    ``padding_side`` is restored afterwards, even if generation raises.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer used to encode prompts and decode generations.
        data: mapping prompt -> gold answer (an int or an integer-like string).
        batch_size: prompts per ``generate`` call.
        max_new_tokens: generation budget per prompt.

    Returns:
        Fraction of prompts answered correctly (0.0 for empty ``data``).
    """
    model.eval()

    prompts = list(data.keys())
    answers = list(data.values())

    correct, total = 0, 0
    original_padding_side = tokenizer.padding_side
    # Batched generation with a decoder-only model needs LEFT padding. With right
    # padding, shorter prompts would be continued from pad positions instead of
    # from their last real token.
    tokenizer.padding_side = "left"
    try:
        for i in tqdm(range(0, len(prompts), batch_size), desc="Evaluating", leave=False):
            batch_prompts = prompts[i:i + batch_size]
            batch_answers = answers[i:i + batch_size]

            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
            ).to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

            # skip_special_tokens removes special tokens (BOS, EOS used as padding)
            # so the regex only sees the prompt text and the continuation.
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            for pred, gold in zip(decoded, batch_answers):
                # Gold answers may be stored as ints or as numeric strings.
                try:
                    gold_int = int(gold)
                except (TypeError, ValueError):
                    gold_int = None  # a non-integer gold answer can never match
                pred_ans = extract_int_after_equals(pred)
                if pred_ans is not None and pred_ans == gold_int:
                    correct += 1
                total += 1
    finally:
        tokenizer.padding_side = original_padding_side
    return correct / max(total, 1)

In [None]:
device = config.device

# One tokenizer (the student's) is used for both models.
# NOTE(review): this assumes teacher and student share a vocabulary; confirm if
# either model id changes.
tokenizer = AutoTokenizer.from_pretrained(config.student_model)
tokenizer.pad_token = tokenizer.eos_token


def _load_fp32_causal_lm(model_name):
    """Load ``model_name`` as a causal LM in float32 on ``device``."""
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device,
    )


print(f"Loading student model: {config.student_model}")
student = _load_fp32_causal_lm(config.student_model)

print(f"Loading teacher model: {config.teacher_model}")
teacher = _load_fp32_causal_lm(config.teacher_model)
# The teacher is frozen: eval mode and no gradients for any parameter.
teacher.eval()
for frozen_param in teacher.parameters():
    frozen_param.requires_grad = False

In [None]:
divider = "=" * 50
print(divider)
print("BASELINE EVALUATION (before training)")
print(divider)

# Held-out accuracy of both models before any distillation.
baseline_student = eval_accuracy(student, tokenizer, TEST_DATA)
print(f"Student (1B) baseline accuracy: {baseline_student:.3f}")

baseline_teacher = eval_accuracy(teacher, tokenizer, TEST_DATA)
print(f"Teacher (8B) baseline accuracy: {baseline_teacher:.3f}")
print(divider)

In [None]:
def pad_collate(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Right-pad AddDataset items to the longest sequence in the batch.

    Padding uses ``tokenizer.pad_token_id`` for input_ids, 0 for attention_mask
    (masked from attention) and -100 for labels (ignored by the CE loss).
    Tensors that are already full length are left untouched. For an
    equal-length batch the result therefore equals the default collate (a plain
    ``torch.stack``).
    """
    max_len = max(item["input_ids"].size(0) for item in batch)

    def right_pad(t: torch.Tensor, value) -> torch.Tensor:
        missing = max_len - t.size(0)
        return t if missing == 0 else torch.cat([t, t.new_full((missing,), value)])

    return {
        "input_ids": torch.stack([right_pad(item["input_ids"], tokenizer.pad_token_id) for item in batch]),
        "attention_mask": torch.stack([right_pad(item["attention_mask"], 0) for item in batch]),
        "labels": torch.stack([right_pad(item["labels"], -100) for item in batch]),
        "prompt_len": torch.tensor([item["prompt_len"] for item in batch]),
    }


train_dataset = AddDataset(TRAIN_DATA, tokenizer)
# AddDataset does not pad. Without a padding collate, any batch whose examples
# tokenize to different lengths would fail to stack.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=pad_collate)

optimizer = AdamW(student.parameters(), lr=config.learning_rate)
cka_loss_fn = CKALoss()

# Student layers in LAYER_MAPPING order. Teacher layers are de-duplicated
# (several student layers may share one) and sorted for a deterministic order.
student_layers = list(LAYER_MAPPING.keys())
teacher_layers = sorted(set(LAYER_MAPPING.values()))

print(f"Training samples: {len(train_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")
print(f"Student layers for CKA: {student_layers}")
print(f"Teacher layers for CKA: {teacher_layers}")

In [None]:
history = {
    'epoch': [],
    'ce_loss': [],
    'cka_loss': [],
    'total_loss': [],
    'accuracy': [],
}

print(f"\n{'='*60}")
print(f"Starting training for {config.epochs} epochs")
print(f"Learning Rate: {config.learning_rate}")
print(f"Lambda CKA: {config.lambda_cka}")
print(f"{'='*60}\n")

# A checkpoint is written to ./best_model only when it beats this accuracy,
# starting from the untrained student's baseline.
best_accuracy = baseline_student

for epoch in range(config.epochs):
    student.train()
    epoch_ce, epoch_cka, epoch_total = 0, 0, 0
    n_batches = 0

    student_cache = ActivationCache()
    teacher_cache = ActivationCache()

    for step, batch in enumerate(train_loader):
        batch = {
            k: (v.to(device) if torch.is_tensor(v) else v)
            for k, v in batch.items()
        }

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # No supervised tokens in this batch: nothing to learn from.
        if (labels != -100).sum().item() == 0:
            continue

        # Real (non-padding) token positions. CKA is computed over these only, so
        # padding (if the loader produces any) never enters the alignment term.
        # For an unpadded batch this keeps every position, in the same row-major
        # order as a plain reshape, so the CKA value is unchanged.
        token_mask = attention_mask.bool()

        student_cache.register_hooks(student, student_layers)
        teacher_cache.register_hooks(teacher, teacher_layers)

        try:
            # Frozen teacher: the forward pass only fills teacher_cache.
            with torch.no_grad():
                teacher(input_ids=input_ids, attention_mask=attention_mask)
            t_acts = {k: v.detach() for k, v in teacher_cache.activations.items()}

            outputs = student(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            s_acts = student_cache.activations

            ce_loss = outputs.loss

            cka_losses = []
            for s_idx, t_idx in LAYER_MAPPING.items():
                if s_idx in s_acts and t_idx in t_acts:
                    # NOTE(review): assumes hooked MLP outputs are (batch, seq, hidden),
                    # so the (batch, seq) mask selects one row per real token.
                    loss, _ = cka_loss_fn(s_acts[s_idx][token_mask], t_acts[t_idx][token_mask])
                    if not torch.isnan(loss):
                        cka_losses.append(loss)

            if cka_losses:
                cka_loss = torch.stack(cka_losses).mean()
            else:
                cka_loss = torch.tensor(0.0, device=device)

            total_loss = ce_loss + config.lambda_cka * cka_loss

            if torch.isnan(total_loss):
                continue

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), config.grad_clip)
            optimizer.step()

            epoch_ce += ce_loss.item()
            epoch_cka += cka_loss.item()
            epoch_total += total_loss.item()
            n_batches += 1

            if step % 50 == 0:
                print(f"  step {step:04d} | CE {ce_loss.item():.4f} | CKA {cka_loss.item():.4f}")

        finally:
            # Always detach hooks and drop cached activations (and the graphs they
            # hold), including after `continue` and on exceptions.
            student_cache.clear()
            teacher_cache.clear()

    if n_batches > 0:
        avg_ce = epoch_ce / n_batches
        avg_cka = epoch_cka / n_batches
        avg_total = epoch_total / n_batches
    else:
        avg_ce = avg_cka = avg_total = 0

    # Evaluated after every epoch (config.eval_every is not consulted here).
    accuracy = eval_accuracy(student, tokenizer, TEST_DATA)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        student.save_pretrained('./best_model')
        tokenizer.save_pretrained('./best_model')
        print(f"New best model saved!")

    history['epoch'].append(epoch + 1)
    history['ce_loss'].append(avg_ce)
    history['cka_loss'].append(avg_cka)
    history['total_loss'].append(avg_total)
    history['accuracy'].append(accuracy)

    print(f"\nEpoch {epoch+1}/{config.epochs}: CE={avg_ce:.4f}, CKA={avg_cka:.4f}, Acc={accuracy:.3f}")

print(f"\n{'='*60}")
print(f"Training complete!")
print(f"Best accuracy: {best_accuracy:.3f} (baseline: {baseline_student:.3f})")
print(f"{'='*60}")

In [None]:
print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

# Accuracy of the student after the last epoch. This is not necessarily the
# ./best_model checkpoint.
final_accuracy = eval_accuracy(student, tokenizer, TEST_DATA)
print(f"Final student accuracy: {final_accuracy:.3f}")
print(f"Baseline was: {baseline_student:.3f}")
# Accuracies are fractions, so their difference ×100 is in percentage points, not percent.
print(f"Improvement: {(final_accuracy - baseline_student)*100:+.2f} percentage points")
print("="*50)

## 14. Visualizations

In [None]:
# 2×2 summary figure.
# Top row: losses. Bottom row: accuracy and per-epoch change vs. the untrained student.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax1, ax2), (ax3, ax4) = axes
epochs_axis = history['epoch']

ax1.plot(epochs_axis, history['total_loss'], 'b-', linewidth=2)
ax1.set(xlabel='Epoch', ylabel='Total Loss', title='Total Loss (CE + λ*CKA)')

for series_key, line_fmt, series_label in (('ce_loss', 'b-o', 'CE Loss'), ('cka_loss', 'r-o', 'CKA Loss')):
    ax2.plot(epochs_axis, history[series_key], line_fmt, label=series_label, markersize=3)
ax2.set(xlabel='Epoch', ylabel='Loss', title='CE and CKA Loss')
ax2.legend()

ax3.plot(epochs_axis, history['accuracy'], 'g-o', markersize=4)
ax3.axhline(y=baseline_student, color='gray', linestyle='--', label=f'Baseline ({baseline_student:.3f})')
ax3.axhline(y=baseline_teacher, color='blue', linestyle=':', alpha=0.5, label=f'Teacher ({baseline_teacher:.3f})')
ax3.set(xlabel='Epoch', ylabel='Accuracy', title='Student Accuracy')
ax3.legend()

improvement = [acc - baseline_student for acc in history['accuracy']]
bar_colors = ['green' if delta > 0 else 'red' for delta in improvement]
ax4.bar(epochs_axis, improvement, color=bar_colors)
ax4.axhline(y=0, color='black', linewidth=0.5)
ax4.set(xlabel='Epoch', ylabel='Accuracy Change', title='Accuracy Improvement vs Baseline')

for panel in (ax1, ax2, ax3, ax4):
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("\nSaved: training_results.png")

In [None]:
save_path = "./circuit_distilled_student"
# Persist the final (last-epoch) student together with its tokenizer.
for artifact in (student, tokenizer):
    artifact.save_pretrained(save_path)

run_config = {
    'teacher': config.teacher_model,
    'student': config.student_model,
    'epochs': config.epochs,
    'batch_size': config.batch_size,
    'learning_rate': config.learning_rate,
    'lambda_cka': config.lambda_cka
}
results = {
    'config': run_config,
    'baseline_student': baseline_student,
    'baseline_teacher': baseline_teacher,
    'final_accuracy': final_accuracy,
    'best_accuracy': best_accuracy,
    'history': history
}

results_path = f"{save_path}/training_results.json"
with open(results_path, "w") as results_file:
    json.dump(results, results_file, indent=2)

print(f"\nModel saved to: {save_path}")
print(f"Results saved to: {results_path}")