### Step 1: Install necesscary packages

In [None]:
!pip install matplotlib
# !pip install torch numpy transformers datasets tiktoken wandb tqdm
!pip install numpy transformers datasets tiktoken wandb tqdm
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


### Step 2: Package imports and configuration

In [4]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200

with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
#def encode(s): return [stoi[c] for c in s]
#def decode(l): return ''.join([itos[i] for i in l])
PAD_IDX = 0
UNK_IDX = stoi.get("<unk>", stoi.get(" ", PAD_IDX))  # prefer <unk>, then space, else pad(0)

def encode(s: str):
    # map unseen characters to UNK instead of raising KeyError
    return [stoi.get(c, UNK_IDX) for c in s]

def decode(ids):
    return ''.join(itos[i] for i in ids if 0 <= i < len(itos))


### Step 3: Define helper functions

In [5]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)   
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [6]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

# import torch
# # ### KIERAN ADDED THIS, REMEBER TO REMOVE BEFORE SUBMITTING
# print(torch.cuda.is_available())
# print(1212,torch.version.cuda)
# print(device)
# print(torch.cuda.is_available())
# print("Model first parameter device:", next(gpt.parameters()).device)


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

In [7]:
# Load data from ./data/pos_neg_pairs.json
import json
import tiktoken
# Loading the json file, CHANGE ADDRESS IF NEEDED
with open("../dpo/pos_neg_pairs.json", "r", encoding = "utf-8") as f:
    lines = json.load(f)

print(f"Loaded {len(lines)} pairs.")



Loaded 100000 pairs.


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [None]:
from torch.optim.lr_scheduler import LambdaLR
import math 
weight_decay = 1e-3
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
max_iters = (len(lines) // batch_size) * epochs
warmup_iters =  int(0.1 * max_iters)

lr_decay_iters = max_iters
base_lr =  6e-4
min_lr = base_lr / 10


# optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr, weight_decay=weight_decay, betas=(beta1, beta2))
decay_params = []
no_decay_params = []

for name, param in gpt.named_parameters():
    if param.requires_grad:
        # Don't apply weight decay to biases and layer norms
        if 'bias' in name or 'ln' in name or 'layernorm' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 1e-2},
    {'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr, betas=(beta1, beta2))


num_warmup_steps = 1000
num_training_steps = 10000


def lr_lambda(current_step: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# 4. Create the scheduler
scheduler = LambdaLR(optimizer, lr_lambda)


### Step 7: Begin training (**students are required to complete this part!**)

In [9]:

global_step = 0
anchor_weight_start = 0.2
anchor_weight_end = 0.05
neg_anchor_weight = 0.05
margin = 0.5
beta = 0.1
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        # Preference term with margin
        logit_diff = (pos_logprob - neg_logprob - margin) / beta
        preference_term = -F.logsigmoid(logit_diff).mean()
        
        # Adaptive anchor term
        progress = global_step / max_iters
        anchor_weight = anchor_weight_start * (1 - progress) + anchor_weight_end * progress
        
        # Dual anchoring: encourage good positives, discourage negatives
        pos_anchor = -anchor_weight * pos_logprob.mean()
        neg_anchor = neg_anchor_weight * neg_logprob.mean()
        anchor_term = pos_anchor + neg_anchor
        
        loss = preference_term + anchor_term
        # Backward pass
        loss.backward()
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        
        # Optimizer step
        optimizer.step()
        scheduler.step()
        # Update progress bar
        pbar.set_description(f"Epoch {epoch + 1}/{epochs} | Step {step} | Loss {loss.item():.4f} | LR {scheduler.get_last_lr()[0]:.2e}")
        global_step += 1
    
    # Save checkpoint ONCE per epoch
    ckpt_path = f"./dpo_epoch_{epoch+1}.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Epoch 1/5 | Step 1561 | Loss -27.0594 | LR 5.94e-04: : 1562it [03:34,  7.29it/s]


Saved checkpoint to ./dpo_epoch_1.pt


Epoch 2/5 | Step 1561 | Loss -96.7683 | LR 5.21e-04: : 1562it [03:37,  7.18it/s]


Saved checkpoint to ./dpo_epoch_2.pt


Epoch 3/5 | Step 1561 | Loss -178.5303 | LR 3.84e-04: : 1562it [03:35,  7.24it/s]


Saved checkpoint to ./dpo_epoch_3.pt


Epoch 4/5 | Step 1561 | Loss -246.8056 | LR 2.23e-04: : 1562it [03:33,  7.30it/s]


Saved checkpoint to ./dpo_epoch_4.pt


Epoch 5/5 | Step 1561 | Loss -302.1380 | LR 8.35e-05: : 1562it [03:07,  8.31it/s]

Saved checkpoint to ./dpo_epoch_5.pt





### Step 8: Begin testing (**students are required to complete this part!**)

In [10]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo_epoch_5.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # This part i gpt generated could be wrong, couldnt find this in train.py lol 
        # Encode text → tensor
        prompt_ids = encode(prompt)
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, len(prompt)]

        # Generate continuation
        out = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Convert back to text
        generated_tokens = out[0][0].cpu().tolist()

        # Split into prompt + continuation
        prompt_len = len(prompt_ids)
        full_text = decode(generated_tokens)
        continuation = decode(generated_tokens[prompt_len:])

        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}\n")
        ###########################################################

Prompt: 17+19=?
Answer: SuSuSuSuSuSuSuSuSue

Prompt: 3*17=?
Answer: SuSuSuSuSuSuSuSuSuSuSuSuSuSuSuuuuuue

Prompt: 72/4=?
Answer: SuSuSuSuSuSuSuSuSue

Prompt: 72-x=34,x=?
Answer: SuSuSuSuSuSuSue

Prompt: x*11=44,x=?
Answer: SuSuSuSuSuSuSuSuSuusSuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu

Prompt: 3*17=?
Answer: SuSuSuSuSuSuSuSuSuSuSuSuSuSuSuuuuuue

Prompt: 72/4=?
Answer: SuSuSuSuSuSuSuSuSue

Prompt: 72-x=34,x=?
Answer: SuSuSuSuSuSuSue

