In [None]:
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)

# --- Experiment configuration -------------------------------------------------
DATA_PATH = "./gpt2_token_raw_model/10k_data.txt"  # one "<binary> -> <decimal>" example per line
OUTPUT_DIR = "./scratch_fewshot_bin2dec"           # where the model and tokenizer are saved
BATCH_SIZE = 64
LR = 1e-3            # peak learning rate reached at the end of warm-up
EPOCHS = 15
WARMUP_RATIO = 0.1   # fraction of all optimiser steps spent in linear warm-up
MAX_SHOTS = 10       # upper bound on in-context examples per training prompt
MAX_CONTEXT = 50     # upper bound on in-context examples tried at evaluation time
SEED = 42

# Prompt sampling (`random`) and weight initialisation (`torch`) are both
# stochastic; seed them so that Restart & Run All is repeatable.
random.seed(SEED)
torch.manual_seed(SEED)

# GPT-2 ships without a padding token, so add a dedicated one for batching.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
PAD_ID = tokenizer.pad_token_id

# Blank lines (e.g. trailing empty lines in the file) are dropped: they carry
# no example and would otherwise fail the "->" split when used as a target.
# Non-blank malformed lines are kept on purpose so they still fail loudly.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    stripped_lines = (ln.strip() for ln in f)
    ALL_LINES = [ln for ln in stripped_lines if ln]


In [None]:
class Bin2DecDataset(Dataset):
    """Few-shot prompt dataset built from ``"<binary> -> <decimal>"`` lines.

    Each item is a freshly sampled prompt: between 1 and ``max_shots`` other
    lines (drawn without replacement, never the target line itself), followed
    by the target's ``"<binary> -> "`` query and its ``"<decimal>\\n"`` answer.
    Sampling uses the global ``random`` module, so the same index yields a
    different prompt on every access unless the RNG is re-seeded.

    Args:
        lines: sequence of example strings, each containing exactly one "->".
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.
        max_shots: maximum number of in-context examples per prompt.
    """

    def __init__(self, lines, tokenizer, max_shots):
        self.lines = lines
        self.tok = tokenizer
        self.max_shots = max_shots

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` as 1-D ``torch.long`` tensors.

        ``labels`` is ``-100`` on every prompt position and equal to
        ``input_ids`` on the answer tokens, so only the answer is scored.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: if the target line is not of the form "a -> b".
        """
        target_line = self.lines[idx]  # IndexError here for out-of-range idx
        parts = [p.strip() for p in target_line.split("->")]
        if len(parts) != 2:
            raise ValueError(
                f"line {idx} is not of the form '<binary> -> <decimal>': {target_line!r}"
            )
        binary, decimal = parts

        n_lines = len(self.lines)
        if idx < 0:
            # Normalise so the exclusion shift below compares real positions.
            idx += n_lines

        # Draw positions from the n-1 "other" lines and shift those at or past
        # idx by one. This selects exactly what sampling from an explicit
        # index list with idx removed would, without rebuilding an O(n) list
        # on every access.
        n_others = n_lines - 1
        max_k = min(self.max_shots, n_others)
        if max_k >= 1:
            k = random.randint(1, max_k)
            picked = random.sample(range(n_others), k)
            samples = [j if j < idx else j + 1 for j in picked]
        else:
            # Single-line dataset or max_shots < 1: fall back to a zero-shot
            # prompt instead of failing inside randint(1, 0).
            samples = []

        prompt = "".join(f"{self.lines[i]}\n" for i in samples)
        prompt += f"{binary} -> "
        target_str = f"{decimal}\n"

        prompt_ids = self.tok.encode(prompt, add_special_tokens=False)
        target_ids = self.tok.encode(target_str, add_special_tokens=False)
        input_ids = prompt_ids + target_ids
        # mask everything except target tokens
        labels = [-100] * len(prompt_ids) + target_ids
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

def collate_fn(batch, pad_id=None):
    """Pad a list of ``(input_ids, labels)`` pairs into one batch.

    Args:
        batch: iterable of ``(input_ids, labels)`` pairs of 1-D tensors.
        pad_id: token id used to right-pad ``input_ids``. Defaults to the
            module-level ``PAD_ID``, looked up at call time.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each of
        shape ``(batch, longest_sequence)``. Padded label positions are
        ``-100``; the mask is 1 on real tokens and 0 on padding.
    """
    if pad_id is None:
        pad_id = PAD_ID
    inputs, labels = zip(*batch)

    # Record true lengths before padding so the mask does not depend on the
    # pad value: a real token that happens to equal pad_id stays attended.
    lengths = torch.tensor([seq.size(0) for seq in inputs], dtype=torch.long)

    input_ids = torch.nn.utils.rnn.pad_sequence(list(inputs), batch_first=True, padding_value=pad_id)
    label_ids = torch.nn.utils.rnn.pad_sequence(list(labels), batch_first=True, padding_value=-100)

    positions = torch.arange(input_ids.size(1), dtype=torch.long).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label_ids}

# Training set over every loaded line, plus a shuffling loader that pads each
# mini-batch through collate_fn.
ds = Bin2DecDataset(lines=ALL_LINES, tokenizer=tokenizer, max_shots=MAX_SHOTS)
dl = DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
)


In [None]:

# A small GPT-2 (6 layers, 8 heads, 512-dim embeddings) built from a config,
# i.e. with freshly initialised weights. The vocabulary is sized to the
# tokenizer so that the added pad token has an embedding row.
cfg = GPT2Config(
    pad_token_id=PAD_ID,
    n_head=8,
    n_layer=6,
    n_embd=512,
    vocab_size=len(tokenizer),
)
model = GPT2LMHeadModel(cfg)
model.resize_token_embeddings(len(tokenizer))

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


In [None]:
# Optimiser: PyTorch's AdamW instead of the deprecated `transformers.AdamW`.
# eps and weight_decay are pinned to the values the deprecated optimiser used
# by default (PyTorch's own defaults are eps=1e-8, weight_decay=1e-2).
opt = torch.optim.AdamW(model.parameters(), lr=LR, eps=1e-6, weight_decay=0.0)

# Linear warm-up over the first WARMUP_RATIO of all steps, then linear decay.
total_steps = len(dl) * EPOCHS
warmup_steps = int(WARMUP_RATIO * total_steps)
sched = get_linear_schedule_with_warmup(opt, warmup_steps, total_steps)

model.train()
for epoch in range(1, EPOCHS+1):
    tot_loss = 0.0
    for batch in dl:
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        loss = out.loss
        loss.backward()
        opt.step()
        sched.step()  # the schedule is per optimiser step, not per epoch
        opt.zero_grad()
        # Weight each batch's mean loss by its size so a short final batch
        # does not skew the epoch average.
        tot_loss += loss.item() * batch["input_ids"].size(0)
    # max(..., 1) keeps an empty dataset from raising ZeroDivisionError.
    print(f"Epoch {epoch} avg CE loss: {tot_loss/max(len(ds), 1):.4f}")

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Saved scratch few-shot model to {OUTPUT_DIR}")



In [None]:
# Qualitative check: for k = 1..MAX_CONTEXT, show the first k lines as
# in-context examples and ask the model for the decimal of line k.
# NOTE: these lines come from ALL_LINES, the same data the model was trained
# on, so this is not a held-out evaluation.
model.eval()
for k in range(1, min(MAX_CONTEXT, len(ALL_LINES)-1)+1):
    prompts = ALL_LINES[:k]
    next_binary, next_decimal = [p.strip() for p in ALL_LINES[k].split("->")]
    prompt = "".join(f"{p}\n" for p in prompts) + f"{next_binary} -> "
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        input_ids,
        # Single unpadded prompt: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=3, do_sample=False,
        pad_token_id=PAD_ID, eos_token_id=tokenizer.eos_token_id,
    )
    completion = tokenizer.decode(output_ids[0][input_ids.size(-1):], skip_special_tokens=True)
    # Training labels cover only "<decimal>\n", so anything generated after
    # the first newline is untrained continuation; keep the answer line only.
    predicted = completion.split("\n", 1)[0].strip()
    print(f"[{k}-in-context-examples] {next_binary} -> {predicted}  (exp {next_decimal})")