### Step 1: Install necesscary packages

In [1]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


### Step 2: Package imports and configuration

In [2]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s if c in stoi] # I EDITED THIS LINE
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

In [3]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [None]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt = gpt.to(device).train()
print("Model first parameter device:", next(gpt.parameters()).device)


False
device variable: cpu
Model first parameter device: cpu


### Step 5: Load Data (**students are required to complete this part!**)

In [5]:
# Load data from ./data/pos_neg_pairs.json

# Loading the json file, CHANGE ADDRESS IF NEEDED
with open("../dpo/pos_neg_pairs.json", "r", encoding = "utf-8") as f:
    lines = json.load(f)

print(f"Loaded {len(lines)} pairs.")


Loaded 100000 pairs.


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [6]:
# recommend to use the AdamW optimizer 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

learning_rate = 0.001
weight_decay = 0.01 # This is the L2 regularization strength for AdamW
optimizer = optim.AdamW(gpt.parameters(), lr=learning_rate, weight_decay=weight_decay)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) # Example: StepLR



### Step 7: Begin training (**students are required to complete this part!**)

In [7]:
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!

        # We first zero the gradients to avoid accumulation so that we can correctly compute the gradients for this step
        optimizer.zero_grad()

        # We calculate the log-probabilities
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)

        # We then calculate the loss of the DPO by the formula where we take the mean of the individual losses
        loss = -F.logsigmoid((pos_logprob - neg_logprob) * beta).mean()

        # We then backpropagate the loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        # We update the progress bar with the current epoch, step, and loss
        pbar.set_description(f"Epoch {epoch + 1} Step {step + 1} Loss {loss.item():.4f}")
        ###########################################################
        ckpt_path = f"./dpo.pt"
        torch.save({
            "model_state_dict": gpt.state_dict(),
            "model_args": ckpt['model_args'],
        }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Epoch 1 Step 130 Loss 0.0021: : 130it [06:40,  3.08s/it]


KeyboardInterrupt: 

### Step 8: Begin testing (**students are required to complete this part!**)

In [None]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        ###########################################################