In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling
)

from huggingface_hub import login

import json
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
import os

# Never hard-code credentials in a notebook: they leak through version control
# and shared copies. Read the token from the HF_TOKEN environment variable; when
# it is unset, login(token=None) falls back to prompting interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import torch
import json
import re

def extract_int_after_equals(text):
    """Return the first run of digits that follows an '=' in ``text`` as an int.

    Whitespace between '=' and the digits is allowed. Only unsigned integers are
    matched. Returns None when no '=' is followed by digits.
    """
    m = re.search(r"=\s*(\d+)", text)
    return int(m.group(1)) if m else None

@torch.no_grad()
def eval_accuracy(model, tokenizer, dataset_path, batch_size=50):
    """Greedy-decode every prompt in a JSON test file and return exact-match accuracy.

    Args:
        model: causal LM exposing ``.eval()``, ``.device`` and ``.generate``.
            It is switched to eval mode and left there.
        tokenizer: tokenizer used to batch-encode prompts and decode outputs.
            Its ``padding_side`` is temporarily set to "left" and restored on exit.
        dataset_path: path to a JSON object mapping prompt string -> gold answer.
            The gold value is compared with ``==`` against a parsed ``int``, so it
            presumably is stored as a JSON integer; confirm against the data file.
        batch_size: number of prompts per ``generate`` call.

    Returns:
        Fraction of prompts whose first integer after '=' in the decoded text
        (prompt + up to 10 new tokens) equals the gold answer. Returns 0.0 for an
        empty file.
    """
    model.eval()

    with open(dataset_path, "r") as f:
        dataset = json.load(f)

    prompts = list(dataset.keys())
    answers = list(dataset.values())

    correct, total = 0, 0

    # Decoder-only models must be LEFT-padded for batched generation. With right
    # padding, shorter prompts end in pad tokens and generation continues after
    # them. Restore the caller's setting afterwards so training code sharing this
    # tokenizer is unaffected.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            batch_answers = answers[i:i + batch_size]

            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
            ).to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

            # NOTE(review): the decoded text includes the prompt, so an '='
            # followed by digits inside the prompt itself (e.g. few-shot
            # examples) would be parsed first.
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            for pred, gold in zip(decoded, batch_answers):
                pred_ans = extract_int_after_equals(pred)
                if pred_ans == gold:
                    correct += 1
                total += 1
    finally:
        tokenizer.padding_side = original_padding_side

    return correct / max(total, 1)

def load_model(model_name):
    """Load a causal LM plus its tokenizer for evaluation.

    The tokenizer reuses its EOS token as the pad token so prompts can be
    batch-padded. Weights are dispatched with ``device_map="auto"`` and loaded
    in float16 when CUDA is available, float32 otherwise.

    Returns:
        ``(model, tokenizer)``
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=dtype,
    )
    return model, tokenizer

In [None]:
# Baseline: the untuned 1B student checkpoint on the held-out test split.
model, tok = load_model(model_name="meta-llama/Llama-3.2-1B")
acc_1b = eval_accuracy(model, tok, dataset_path="2d_add_test_20.json")
print("1B baseline:", acc_1b)

1B baseline: 0.516


In [None]:
# Baseline: the untuned 8B teacher checkpoint on the same test split.
# Note this rebinds the global `model` to the 8B checkpoint.
model, tok = load_model(model_name="meta-llama/Meta-Llama-3-8B")
acc_8b = eval_accuracy(model, tok, dataset_path="2d_add_test_20.json")
print("8B baseline:", acc_8b)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

8B baseline: 0.963


In [None]:
class AddDataset(Dataset):
    """Supervised fine-tuning items built from a JSON prompt -> answer mapping.

    Each item is the tokenized prompt followed by the tokenized answer + EOS.
    Labels are -100 over the prompt, so only the answer tokens (and EOS)
    contribute to the loss.
    """

    def __init__(self, json_path, tokenizer):
        # JSON object: keys are prompt strings, values are gold answers.
        with open(json_path, "r") as f:
            raw = json.load(f)

        self.data = list(raw.items())
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt, answer = self.data[idx]
        answer = str(answer)

        # The prompt keeps the tokenizer's special tokens (e.g. BOS), matching
        # how prompts are encoded at inference with the tokenizer's defaults.
        prompt_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
        )["input_ids"].squeeze(0)

        # The answer continues the prompt; it is not a new sequence. Tokenizers
        # that add BOS by default would otherwise insert a second BOS here, and
        # that BOS would land in the labels, teaching the model to emit BOS
        # right after the prompt.
        answer_ids = self.tokenizer(
            answer + self.tokenizer.eos_token,
            return_tensors="pt",
            padding=False,
            add_special_tokens=False,
        )["input_ids"].squeeze(0)

        input_ids = torch.cat([prompt_ids, answer_ids])
        attention_mask = torch.ones_like(input_ids)

        # Supervise only the answer span; -100 is ignored by the CE loss.
        labels = torch.full_like(input_ids, -100)
        labels[len(prompt_ids):] = answer_ids

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt": prompt,
            "answer": answer,
        }

In [None]:
import re

def extract_int_after_equals(text):
    """Parse the integer that follows the first '=' which is followed by digits.

    Optional whitespace may sit between '=' and the digits; signs are not
    recognised.

    Returns:
        The parsed int, or None if ``text`` has no such '='.
    """
    found = re.search(r"=\s*(\d+)", text)
    if found is None:
        return None
    return int(found.group(1))

In [None]:
# Checkpoints for the distillation pair and the device both will live on.
teacher_name = "meta-llama/Meta-Llama-3-8B"
student_name = "meta-llama/Llama-3.2-1B"
device = "cuda" if torch.cuda.is_available() else "cpu"

# A single tokenizer (the student's) encodes data for both models; EOS doubles
# as the pad token.
tokenizer = AutoTokenizer.from_pretrained(student_name)
tokenizer.pad_token = tokenizer.eos_token

# Student weights in full float32 precision on `device`.
student_model = AutoModelForCausalLM.from_pretrained(
    student_name,
    device_map=device,
    torch_dtype=torch.float32,
)

In [None]:
# Teacher weights, also float32 and on the same device as the student.
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    device_map=device,
    torch_dtype=torch.float32,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
def pad_collate(batch):
    """Collate AddDataset items of possibly different lengths into one batch.

    torch's default collate ``stack``s the tensors and fails as soon as two
    sequences differ in length. It presumably works today only because every
    example tokenizes to the same length. This right-pads ``input_ids`` with the
    tokenizer's pad id, ``attention_mask`` with 0 and ``labels`` with -100, so
    padding is ignored by both attention and the loss. String fields are
    returned as lists, as the default collate does. When all lengths match, the
    result equals the default collate's.
    """
    max_len = max(item["input_ids"].size(0) for item in batch)
    pad_id = tokenizer.pad_token_id

    def _right_pad(seq, value):
        return F.pad(seq, (0, max_len - seq.size(0)), value=value)

    return {
        "input_ids": torch.stack([_right_pad(item["input_ids"], pad_id) for item in batch]),
        "attention_mask": torch.stack([_right_pad(item["attention_mask"], 0) for item in batch]),
        "labels": torch.stack([_right_pad(item["labels"], -100) for item in batch]),
        "prompt": [item["prompt"] for item in batch],
        "answer": [item["answer"] for item in batch],
    }


train_ds = AddDataset("2d_add_train_80.json", tokenizer)

train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=pad_collate,
)

In [None]:
# Full fine-tune: AdamW over every student parameter at lr 1e-4.
optimizer = AdamW(params=student_model.parameters(), lr=1e-4)

In [None]:
def train_step(batch):
    """Run one cross-entropy update on the global ``student_model``.

    Args:
        batch: dict from ``train_loader``. Tensor values are moved to the
            student's device; anything else (e.g. prompt strings) passes
            through unchanged.

    Returns:
        The scalar CE loss as a Python float. Returns None when every label is
        -100; in that case gradients are cleared and no update is made.
    """
    student_model.train()

    target_device = student_model.device
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(target_device) if torch.is_tensor(value) else value
    batch = moved

    labels = batch["labels"]

    # Guard: a batch with no supervised token has nothing to learn from.
    if not (labels != -100).any().item():
        optimizer.zero_grad()
        return None

    loss = student_model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=labels,
    ).loss

    optimizer.zero_grad()
    loss.backward()
    # Clip the global grad norm at 1.0 before stepping.
    torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()

In [None]:
EPOCHS = 80

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    for step, batch in enumerate(train_loader):
        loss = train_step(batch)
        # Skipped batches return None; otherwise log every 10th step.
        if loss is None or step % 10 != 0:
            continue
        print(f"step {step:04d} | CE loss {loss:.4f}", flush=True)

    # Held-out accuracy at the end of each epoch.
    print(
        "Student model trained on just CE:",
        eval_accuracy(student_model, tokenizer, "2d_add_test_20.json"),
    )


Epoch 1/80
step 0000 | CE loss 8.8212
step 0010 | CE loss 1.4926
step 0020 | CE loss 1.2219
step 0030 | CE loss 1.2123
step 0040 | CE loss 1.2222
step 0050 | CE loss 1.1734
step 0060 | CE loss 1.1591
step 0070 | CE loss 1.0604
step 0080 | CE loss 1.0136
step 0090 | CE loss 1.4229
step 0100 | CE loss 1.1659
step 0110 | CE loss 1.1118


KeyboardInterrupt: 

In [None]:
# Save the CE-only student weights and the tokenizer into the same directory.
student_model.save_pretrained(save_directory="student_ce_only")
tokenizer.save_pretrained(save_directory="student_ce_only")

('student_ce_only/tokenizer_config.json',
 'student_ce_only/special_tokens_map.json',
 'student_ce_only/tokenizer.json')

In [None]:
# Evaluate the CE-fine-tuned *student*. The global `model` still holds the 8B
# checkpoint loaded for the baseline. Evaluating `model` here therefore
# re-measured the 8B baseline (0.963) under the student's label.
print("Student model trained on just CE:", eval_accuracy(student_model, tokenizer, "2d_add_test_20.json"))

Student model trained on just CE: 0.963


In [None]:
# Test-split accuracy of the student after CE-only training.
print(
    "Student model trained on just CE:",
    eval_accuracy(model=student_model, tokenizer=tokenizer, dataset_path="2d_add_test_20.json"),
)

Student model trained on just CE: 0.059


In [None]:
import json
import re
import torch

def extract_int_after_equals(text):
    """Return the first integer written after an '=' in ``text``, or None."""
    m = re.search(r"=\s*(\d+)", text)
    return int(m.group(1)) if m else None

@torch.no_grad()
def debug_eval(model, tokenizer, dataset_path, n=200, repetition_penalty=1.2):
    """Generate for the first ``n`` test prompts one at a time and print each result.

    Args:
        model: causal LM with ``.eval()``, ``.device`` and ``.generate``.
        tokenizer: tokenizer used to encode each prompt and decode the output.
        dataset_path: path to a JSON object mapping prompt -> gold answer.
        n: maximum number of examples to show.
        repetition_penalty: forwarded to ``generate``. The default of 1.2 is kept
            for compatibility, but ``eval_accuracy`` generates with no repetition
            penalty. Pass 1.0 to reproduce its decoding.

    Returns:
        Accuracy over the examples shown (also printed); 0.0 if there were none.
    """
    model.eval()

    with open(dataset_path, "r") as f:
        dataset = list(json.load(f).items())

    # Report how many examples are actually shown, which may be fewer than n.
    examples = dataset[:n]
    print(f"Showing first {len(examples)} examples:\n")

    correct = 0
    total = 0

    for i, (prompt, gold) in enumerate(examples):
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to(model.device)

        out = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=repetition_penalty,
        )

        decoded = tokenizer.decode(out[0], skip_special_tokens=True)
        parsed = extract_int_after_equals(decoded)
        is_correct = parsed == gold

        print(f"[{i}] Prompt: {prompt}")
        print(f"    Generation: {decoded!r}")
        print(f"    Parsed: {parsed}")
        print(f"    Gold:   {gold}")
        print(f"    Correct: {is_correct}")
        print("-" * 50)

        if is_correct:
            correct += 1
        total += 1

    # An empty file or n <= 0 would otherwise divide by zero.
    accuracy = correct / total if total else 0.0
    print(accuracy)
    return accuracy

# Inspect 30 individual generations from the current student.
debug_eval(student_model, tokenizer, dataset_path="2d_add_test_20.json", n=30)

NameError: name 'student_model' is not defined

In [None]:
import gc

# Release models left over from earlier cells before loading new copies. That
# includes the baseline `model` and the optimizer, which holds references to
# the previous student's parameters and Adam state. Otherwise the old 8B/1B
# weights stay allocated while the new ones load.
for _stale_name in ("model", "student_model", "teacher_model", "optimizer"):
    globals().pop(_stale_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

device = "cuda" if torch.cuda.is_available() else "cpu"

teacher_name = "meta-llama/Meta-Llama-3-8B"
student_name = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(student_name)
tokenizer.pad_token = tokenizer.eos_token

# Reload the base student checkpoint, so KD starts from the same weights the
# CE-only run started from.
student_model = AutoModelForCausalLM.from_pretrained(
    student_name,
    torch_dtype=torch.float32,
    device_map=device
)

# NOTE(review): the teacher is fed token ids from the student's tokenizer, and
# the KD loss compares logits index-by-index. Confirm the two vocabularies
# match exactly.
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    torch_dtype=torch.float32,
    device_map=device
)

# Frozen teacher: eval mode and no gradients on its parameters.
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad = False

optimizer = AdamW(student_model.parameters(), lr=1e-4)

# NOTE(review): train_loader is not rebuilt here. After a kernel restart, re-run
# the dataset/DataLoader cell before the KD training loop.

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
def kl_distill_loss(student_logits, teacher_logits, labels, T=2.0):
    """Temperature-scaled KL(teacher || student) over the supervised predictions.

    In a causal LM, the logits at position ``i`` predict token ``i + 1``, whose
    target sits in ``labels[:, i + 1]``. Hugging Face's built-in CE loss applies
    the same shift. Logits are therefore trimmed by one on the right and labels
    by one on the left before masking. This distills exactly the predictions CE
    supervises, including the one for the first answer token.

    Args:
        student_logits: student logits indexed ``[batch, position, vocab]``.
        teacher_logits: teacher logits with the same shape as ``student_logits``.
        labels: ``[batch, position]`` token ids, with -100 marking positions
            that are not supervised.
        T: softmax temperature; the loss is rescaled by ``T ** 2``.

    Returns:
        Scalar loss tensor (mean KL per supervised prediction times ``T ** 2``),
        or None if no prediction is supervised.
    """
    mask = labels[:, 1:] != -100
    if mask.sum().item() == 0:
        return None

    # Boolean-mask indexing flattens to (num_supervised, vocab).
    s = student_logits[:, :-1].float()[mask] / T
    t = teacher_logits[:, :-1].float()[mask] / T

    # batchmean on the flattened rows gives a per-token mean.
    kl = F.kl_div(
        F.log_softmax(s, dim=-1),
        F.softmax(t, dim=-1),
        reduction="batchmean",
    )

    return kl * (T ** 2)

def train_step(batch, alpha=0.5, T=2.0):
    """One knowledge-distillation update of the global ``student_model``.

    The optimized loss is ``alpha * CE + (1 - alpha) * KD``. CE is the
    student's own label loss; KD is ``kl_distill_loss`` against
    ``teacher_model`` at temperature ``T``.

    Args:
        batch: dict from ``train_loader``; tensor values are moved to the
            student's device.
        alpha: weight on the hard-label CE term; ``1 - alpha`` goes to KD.
        T: softmax temperature for the KD term.

    Returns:
        ``(loss, ce, kd)`` as Python floats, or None when every label in the
        batch is -100. If ``kl_distill_loss`` returns None, the step optimizes
        only ``alpha * ce`` and reports ``kd`` as 0.0.
    """
    student_model.train()

    batch = {
        k: (v.to(student_model.device) if torch.is_tensor(v) else v)
        for k, v in batch.items()
    }

    labels = batch["labels"]
    if (labels != -100).sum().item() == 0:
        return None

    # The teacher only provides targets: no autograd graph is built for it.
    with torch.no_grad():
        teacher_out = teacher_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    student_out = student_model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=labels,
    )

    ce = student_out.loss
    # No-op when both models share a device; avoids a device mismatch otherwise.
    teacher_logits = teacher_out.logits.to(student_out.logits.device)
    kd = kl_distill_loss(
        student_out.logits,
        teacher_logits,
        labels,
        T=T,
    )

    if kd is None:
        # kl_distill_loss signals "nothing to distill" with None. Drop only the KD
        # term instead of failing on `(1 - alpha) * None`.
        loss = alpha * ce
        kd_value = 0.0
    else:
        loss = alpha * ce + (1 - alpha) * kd
        kd_value = kd.item()

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
    optimizer.step()

    return loss.item(), ce.item(), kd_value

In [None]:
EPOCHS = 80
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    for step, batch in enumerate(train_loader):
        # alpha=0.1 weights the loss as 10% hard-label CE and 90% KD, at T=2.
        out = train_step(batch, alpha=0.1, T=2.0)
        # Skipped batches return None; otherwise log every 10th step.
        if out is None or step % 10 != 0:
            continue
        loss, ce, kd = out
        print(
            f"step {step:04d} | loss {loss:.4f} | CE {ce:.4f} | KD {kd:.4f}",
            flush=True,
        )

    # Held-out accuracy at the end of each epoch.
    print(
        "Student w/ KD:",
        eval_accuracy(student_model, tokenizer, "2d_add_test_20.json"),
    )


Epoch 1/80
step 0000 | loss 0.0712 | CE 0.2861 | KD 0.0473
step 0010 | loss 0.0726 | CE 0.2997 | KD 0.0473
step 0020 | loss 0.0848 | CE 0.4125 | KD 0.0484
step 0030 | loss 0.0771 | CE 0.3284 | KD 0.0492
step 0040 | loss 0.0796 | CE 0.3742 | KD 0.0469
step 0050 | loss 0.0800 | CE 0.3696 | KD 0.0478
step 0060 | loss 0.0806 | CE 0.4046 | KD 0.0447
step 0070 | loss 0.0713 | CE 0.3153 | KD 0.0441
step 0080 | loss 0.0805 | CE 0.3593 | KD 0.0495
step 0090 | loss 0.0762 | CE 0.3293 | KD 0.0481
step 0100 | loss 0.0776 | CE 0.3570 | KD 0.0465
step 0110 | loss 0.0757 | CE 0.3144 | KD 0.0492
step 0120 | loss 0.0803 | CE 0.3715 | KD 0.0479
step 0130 | loss 0.0771 | CE 0.3227 | KD 0.0499
step 0140 | loss 0.0715 | CE 0.2928 | KD 0.0469
step 0150 | loss 0.0931 | CE 0.4906 | KD 0.0489
step 0160 | loss 0.0813 | CE 0.4197 | KD 0.0436
step 0170 | loss 0.0733 | CE 0.3048 | KD 0.0476
step 0180 | loss 0.0721 | CE 0.2786 | KD 0.0491
step 0190 | loss 0.0683 | CE 0.2889 | KD 0.0438
step 0200 | loss 0.0721 | CE

KeyboardInterrupt: 

In [None]:
# Save the CE+KD student weights and the tokenizer into the same directory.
student_model.save_pretrained(save_directory="student_ce_kl")
tokenizer.save_pretrained(save_directory="student_ce_kl")

('student_ce_kl/tokenizer_config.json',
 'student_ce_kl/special_tokens_map.json',
 'student_ce_kl/tokenizer.json')

In [None]:
# Test-split accuracy of the student after CE + KD training.
print("Student w/ KD:", eval_accuracy(student_model, tokenizer, dataset_path="2d_add_test_20.json"))

Student w/ KD: 0.631
