# Phase A — Teacher Baseline & KD Logits

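Loads a public BERT-large SST-2 checkpoint as the teacher, evaluates it on the dev and test splits, and exports two knowledge-distillation (KD) subsets of the training split (1,000 and 500 examples) together with the teacher's logits.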


In [2]:
# Colab bootstrap: fetch the thesis repository and make it the working directory.
%cd /content
# Clone only when the checkout is absent, so the cell can be re-run (for example
# after a runtime restart) without git aborting on the existing directory.
# An existing checkout is reused as-is and is NOT updated.
!test -d thesis-peft-llm || git clone https://github.com/mounikakarasu/thesis-peft-llm.git
%cd thesis-peft-llm


/content
Cloning into 'thesis-peft-llm'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 35 (delta 9), reused 24 (delta 4), pack-reused 0 (from 0)[K
Receiving objects: 100% (35/35), 23.60 KiB | 11.80 MiB/s, done.
Resolving deltas: 100% (9/9), done.
/content/thesis-peft-llm


In [3]:
import os
import sys
from pathlib import Path

def find_project_root() -> Path:
    """Locate the repository root by walking upward from the working directory.

    The root is the nearest directory, starting with the current one, that
    contains both a ``src`` and a ``notebooks`` entry.

    Returns:
        The resolved path of the first matching directory.

    Raises:
        RuntimeError: If neither the working directory nor any of its parents
            contains both entries.
    """
    start = Path.cwd().resolve()
    # Nearest ancestor first: the working directory itself, then each parent in turn.
    candidates = (start, *start.parents)
    root = next(
        (
            candidate
            for candidate in candidates
            if all((candidate / marker).exists() for marker in ("src", "notebooks"))
        ),
        None,
    )
    if root is None:
        raise RuntimeError("Unable to locate the repository root. Please run this notebook from inside the project.")
    return root

# Pin the kernel's working directory to the repository root so relative paths
# resolve the same way regardless of where the notebook was launched.
PROJECT_ROOT = find_project_root()
os.chdir(PROJECT_ROOT)
# Register the root on sys.path only once, so re-running the cell adds no duplicate entry.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.append(str(PROJECT_ROOT))
print(f"Project root: {PROJECT_ROOT}")


Project root: /content/thesis-peft-llm


In [5]:

from pathlib import Path

from src import cost, data, eval as eval_utils, models, utils
from src.utils import GLOBAL_CONFIG, configure_tf32, set_seed_everywhere

# Identifier of the public BERT-large SST-2 teacher checkpoint handed to the model loader.
teacher_checkpoint = "yoshitomo-matsubara/bert-large-uncased-sst2"

# Apply the project-wide seed and TF32 settings before any data or weights are loaded.
set_seed_everywhere(GLOBAL_CONFIG.seed)
configure_tf32(GLOBAL_CONFIG.tf32)

# Raw splits first, then the teacher and its tokenizer, then the tokenized splits.
raw_dataset = data.load_sst2()
model, tokenizer = models.load_model_and_tokenizer(teacher_checkpoint)
tokenized = data.tokenize_text_dataset(raw_dataset, tokenizer, GLOBAL_CONFIG.max_length)

# Prepare the two evaluation splits (validation first, then test).
# NOTE(review): if data.load_sst2 returns the GLUE SST-2 splits, the public test
# split ships without gold labels, so metrics computed on it would not be
# meaningful — confirm what the loader provides.
validation_dataset, test_dataset = (
    data.format_for_torch(tokenized[split_name]) for split_name in ("validation", "test")
)


tokenizer_config.json:   0%|          | 0.00/304 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [7]:

device = utils.get_device()


def _evaluate_split(split_dataset):
    """Evaluate the teacher on one split with the configured eval batch size.

    Returns whatever ``eval_utils.evaluate_model`` returns; the callers below
    unpack it as a two-item result and keep only the first item (the metrics).
    """
    return eval_utils.evaluate_model(
        model,
        split_dataset,
        GLOBAL_CONFIG.per_device_eval_batch_size,
        device=device,
    )


# NOTE(review): the recorded run of this cell stopped with
# "TypeError: BertForSequenceClassification.forward() got an unexpected keyword argument 'label'".
# Presumably the batches still carry a "label" column that is forwarded to the
# model as a keyword argument — verify the column handling in
# data.format_for_torch / eval_utils.evaluate_model.
val_metrics, _ = _evaluate_split(validation_dataset)
test_metrics, _ = _evaluate_split(test_dataset)
print("Validation", val_metrics)
print("Test", test_metrics)


TypeError: BertForSequenceClassification.forward() got an unexpected keyword argument 'label'

In [None]:

# (subset name, number of sampled training examples) for each KD export.
kd_specs = [("kd_1000", 1000), ("kd_500", 500)]
kd_output_base = Path("outputs", "kd")
kd_output_base.mkdir(parents=True, exist_ok=True)
# Maps each subset name to the directory (as a string) it was saved to.
kd_paths = {}

for subset_name, size in kd_specs:
    target_dir = kd_output_base / subset_name

    # Every subset is sampled from the raw training split with the same global seed.
    subset = data.sample_subset(
        raw_dataset["train"],
        sample_size=size,
        seed=GLOBAL_CONFIG.seed,
    )
    subset_tokenized = data.tokenize_text_dataset(subset, tokenizer, GLOBAL_CONFIG.max_length)
    subset_for_logits = data.format_for_torch(subset_tokenized)

    logits = eval_utils.generate_logits(
        model,
        subset_for_logits,
        GLOBAL_CONFIG.per_device_eval_batch_size,
        device=device,
    )

    # The logits are attached to the untokenized subset.
    # NOTE(review): assumes generate_logits returns one row per example in the
    # subset's original order — confirm there is no shuffling.
    subset_with_logits = data.add_teacher_logits(subset, logits.tolist())
    data.save_dataset(subset_with_logits, target_dir)
    kd_paths[subset_name] = str(target_dir)
    print(f"Saved {subset_name} ({size} samples) to {target_dir}")


In [None]:

# Collect the Phase A results into one record and write it as JSON.
reports_dir = utils.ensure_dir("outputs/reports")
report_path = reports_dir / "phase_a_metrics.json"

# Teacher identity and its evaluation metrics from the earlier evaluation cell.
metrics = {
    "phase": "A",
    "checkpoint": teacher_checkpoint,
    "dev": val_metrics,
    "test": test_metrics,
}
# Model size and the saved KD subset locations.
metrics["parameter_counts"] = cost.count_parameters(model)
metrics["kd_subsets"] = kd_paths

utils.write_json(metrics, report_path)
# The last expression displays the record as the cell output.
metrics
