In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
import torch.nn.functional as F
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "EleutherAI/gpt-neo-1.3B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so tokenizer calls with padding=True work.
tokenizer.pad_token = tokenizer.eos_token

# Half precision halves GPU memory, but many fp16 kernels are slow or missing
# on CPU, so only use it when a GPU is actually available.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Inference only (disables dropout).
model.eval()

In [None]:
# Candidate "omission sets": each inner list holds indices of transformer
# blocks (model.transformer.h) that are skipped together in one forward pass.
omission_sets = [
    [4, 16, 18, 19, 20, 22],
    [4, 13, 17, 18, 20, 21],
    [4, 12, 14, 17, 20, 21],
]

# Presumably a per-dataset example budget (see the COPA note below) — confirm.
num_samples = 10_000

wino_ds = load_dataset("winogrande", "winogrande_xl", split="train")
hellaswag_ds = load_dataset("hellaswag", split="train")
copa_ds = load_dataset("pkavumba/balanced-copa", split="train")  # fewer than 10000 examples

# Show a summary (features / row count) of each split.
print(wino_ds, hellaswag_ds, copa_ds, sep="\n")

In [None]:
# COPA (balanced): each row has a premise, two alternatives and a "question"
# of "cause" or "effect". The prompt is the premise followed by the matching
# question; the target is the alternative selected by `label`
# (0 -> choice1, anything else -> choice2).
copa_examples = []

for row in copa_ds:
    question_text = (
        "What was the cause?" if row["question"] == "cause" else "What was the effect?"
    )
    target = row["choice1"] if row["label"] == 0 else row["choice2"]
    copa_examples.append(("Copa", f"{row['premise']}\n{question_text}", target))

# Spot-check the first formatted examples.
for example in copa_examples[:25]:
    print(example)

In [None]:
# Winogrande: show the sentence with a widened blank and ask which option fills
# it. The target is option1 when `answer` is the string "1", otherwise option2.
wino_examples = []

for row in wino_ds:
    target = row["option1"] if row["answer"] == "1" else row["option2"]

    # Widen the underscore blank so it reads as a gap rather than punctuation.
    blanked = row["sentence"].replace("_", "_____")

    wino_examples.append(
        ("Winogrande", blanked + "\n" + "What word best fills in the blank?", target)
    )

# Spot-check the first formatted examples.
for example in wino_examples[:25]:
    print(example)

In [None]:
# HellaSwag: the prompt is ctx_a followed, on a new line, by ctx_b with its
# first letter capitalised; the target is the gold ending selected by `label`.
hellaswag_examples = []

for ex in hellaswag_ds:
    context_a = ex["ctx_a"]
    context_b = ex["ctx_b"]
    if context_b:
        context_b = context_b[0].upper() + context_b[1:]
    # `label` is converted with int(), so it may arrive as a string.
    correct = ex["endings"][int(ex["label"])]

    # Only add the newline separator when ctx_b is non-empty, so the prompt
    # never ends in a dangling "\n" followed by the answer's leading space.
    prompt = f"{context_a}\n{context_b}" if context_b else context_a
    hellaswag_examples.append(("HellaSwag", prompt, correct))

# Spot-check the first formatted examples (no per-row dump of the raw split).
for example in hellaswag_examples[0:25]:
    print(example)

In [None]:
def forward_with_skipped_layers(model, input_ids, attention_mask, skip_layers):
    """
    Run a GPT-Neo-style forward pass while skipping selected transformer blocks.

    Reproduces the embeddings -> blocks -> final layer norm -> LM head path of
    ``model``, but calls ``model.transformer.h[idx]`` only for indices that are
    not in ``skip_layers``. No embedding dropout is applied, so the result
    corresponds to eval-mode behaviour.

    Args:
        model: causal LM exposing ``transformer.wte``, ``transformer.wpe``,
            ``transformer.h``, ``transformer.ln_f`` and ``lm_head``
            (e.g. the GPT-Neo model loaded via AutoModelForCausalLM).
        input_ids: token ids of shape (batch, seq_len).
        attention_mask: (batch, seq_len) mask, 1 for real tokens, 0 for padding.
        skip_layers: iterable of block indices to skip.

    Returns:
        Logits of shape (batch, seq_len, vocab_size) as produced by ``lm_head``.
    """
    # Take the device from the inputs rather than a notebook-global.
    dev = input_ids.device
    batch_size, seq_len = input_ids.shape
    skipped = set(skip_layers)

    # Token embeddings + learned absolute position embeddings.
    # NOTE(review): positions start at 0 for every row, which assumes no left padding.
    position_ids = torch.arange(seq_len, device=dev).unsqueeze(0)
    hidden_states = model.transformer.wte(input_ids) + model.transformer.wpe(position_ids)

    # Additive 4-D mask: 0 where attention is allowed, a large negative value
    # where it is blocked (future positions or padding keys). float16's min is
    # finite, so the added mask alone can never turn a row into NaN.
    causal = torch.tril(torch.ones((seq_len, seq_len), device=dev)).view(1, 1, seq_len, seq_len)
    keep = causal * attention_mask.view(batch_size, 1, 1, seq_len).to(causal.dtype)
    combined_mask = (1.0 - keep) * torch.finfo(torch.float16).min

    for idx, layer in enumerate(model.transformer.h):
        if idx in skipped:
            continue
        # Blocks return a tuple whose first element is the new hidden state.
        hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

    hidden_states = model.transformer.ln_f(hidden_states)
    return model.lm_head(hidden_states)

In [None]:
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd

def compute_answer_loss(prompt, answer, skip_layers):
    """
    Mean per-token cross-entropy of ``answer`` given ``prompt`` under the
    notebook's ``model``, with the blocks in ``skip_layers`` bypassed via
    ``forward_with_skipped_layers``.

    The prompt and the answer (prefixed with one space) are tokenized
    separately and concatenated. This makes the answer span exact by
    construction, rather than inferring it by re-tokenizing the prompt on its own.

    Args:
        prompt: conditioning text.
        answer: continuation whose likelihood is scored.
        skip_layers: transformer block indices to skip.

    Returns:
        Average negative log-likelihood (nats per token) over the answer tokens
        as a Python float, or ``float('inf')`` if no answer token can be scored.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    answer_ids = tokenizer(
        " " + answer, return_tensors="pt", add_special_tokens=False
    ).input_ids
    prompt_len = prompt_ids.shape[1]

    input_ids = torch.cat([prompt_ids, answer_ids], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        logits = forward_with_skipped_layers(model, input_ids, attention_mask, skip_layers)

        # Position t predicts token t+1. Upcast so log-softmax/CE run in fp32
        # rather than in the model's (possibly fp16) precision.
        shift_logits = logits[:, :-1, :].float()
        shift_labels = input_ids[:, 1:]
        per_token = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        ).view(shift_labels.shape)

        # Answer tokens start at index prompt_len of input_ids, i.e. prompt_len-1
        # after the shift. With an empty prompt the first answer token has no
        # context and cannot be scored, so start at 0.
        answer_losses = per_token[:, max(prompt_len - 1, 0):]
        if answer_losses.numel() == 0:
            return float("inf")
        return answer_losses.mean().item()


In [None]:
import random, json

print("Computing losses and writing JSONL dataset...")
output_file = "router_training_data.jsonl"

# Seeded RNG so the per-dataset subsample and the final order are reproducible.
rng = random.Random(0)

# Take at most `num_samples` examples from each source (COPA has fewer, so it
# is taken whole), then interleave the sources.
examples = []
for pool in (copa_examples, wino_examples, hellaswag_examples):
    examples.extend(rng.sample(pool, min(num_samples, len(pool))))

rng.shuffle(examples)

with open(output_file, "w", encoding="utf-8") as f:
    for idx, (dataset_name, prompt, answer) in enumerate(tqdm(examples)):
        record = {
            "dataset": dataset_name,
            "prompt": prompt,
            "answer": answer,
        }

        # One loss per omission set, stored as os1_loss, os2_loss, ...
        for i, skip_layers in enumerate(omission_sets, start=1):
            record[f"os{i}_loss"] = compute_answer_loss(prompt, answer, skip_layers)

        f.write(json.dumps(record) + "\n")

        # Flush periodically so partial results survive an interrupted run.
        if idx % 100 == 0:
            f.flush()

print(f"Router training dataset saved to: {output_file}.")


In [None]:
import torch


# Restore weights from a saved training checkpoint.
# NOTE(review): at this point `model` is the GPT-Neo LM created in the first
# cell; confirm "best_model.pt" was saved from that same architecture (the
# original step numbering suggests a model-construction step was removed).
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(
    "best_model.pt",
    map_location="cpu",  # device-independent load
)

model.load_state_dict(checkpoint["model_state_dict"])
# Optimizer state is not restored here; if needed, the original snippet read
# it from checkpoint["optimizer_state_dict"].

# Training metadata stored alongside the weights.
epoch = checkpoint["epoch"]
val_loss = checkpoint["val_loss"]

# Switch to inference mode (disables dropout).
model.eval()

print(f"Loaded model from epoch {epoch} with validation loss {val_loss}")