In [1]:
import pdb
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, return_hidden_states=False, full_seq=False):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        elif full_seq:
            logits = self.lm_head(x)
            loss = None
        else:
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        
        if return_hidden_states:
            return logits, loss, x
        else:
            return logits, loss

    def crop_block_size(self, block_size):
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {}

        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024
        config_args['bias'] = True

        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']

        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]

        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _, hidden_state = self(idx_cond,return_hidden_states=True)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            if idx_next.item()==0:
                break   
            if idx_next.item()==7:
                idx = torch.cat((idx, idx_next), dim=1)
                break
            idx = torch.cat((idx, idx_next), dim=1)
        return idx, hidden_state

class GPTRewardModel(nn.Module):
    def __init__(self, gpt):
        super().__init__()
        self.gpt = gpt
        self.value_head = nn.Sequential(
            nn.Linear(self.gpt.config.n_embd, self.gpt.config.n_embd),
            nn.ReLU(),
            nn.Linear(self.gpt.config.n_embd, 1)
        )

    def forward(self, input_ids):
        _, _, hidden_states = self.gpt(input_ids, return_hidden_states=True)
        mask = (input_ids != 0).unsqueeze(-1)
        masked_hidden = hidden_states * mask
        sum_hidden = masked_hidden.sum(dim=1)
        lengths = mask.sum(dim=1)
        pooled = sum_hidden / lengths.clamp(min=1)
        logits = self.value_head(pooled).squeeze(-1)
        return logits

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/pos-neg-pairs/pos_neg_pairs.json


### Step 1: Install necessary packages

In [2]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

### Step 2: Package imports and configuration

In [3]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
#from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer

with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])


In [4]:
import pickle

# Load the original meta.pkl (from your GPT training or /kaggle/input path)
with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)

stoi = meta["stoi"]
itos = meta["itos"]

# Print out how big the vocab is
print("Vocab size:", len(stoi))

# Show the first 50 characters (sorted for readability)
print("Sample of stoi keys (characters):")
print(sorted(list(stoi.keys()))[:50])

# If you want to see ALL characters the model supports:
print("All tokens in original vocab:")
print(sorted(stoi.keys()))

Vocab size: 74
Sample of stoi keys (characters):
['\n', ' ', "'", '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c']
All tokens in original vocab:
['\n', ' ', "'", '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '’']


### Step 3: Define helper functions

In [5]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [17]:
ckpt = torch.load("/kaggle/input/sft-gpt/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
# Load policy model
gpt = GPT(gptconf)
gpt.load_state_dict(state_dict)
# IMPORTANT: re-tie token embedding and lm_head weights after load
gpt.lm_head.weight = gpt.transformer.wte.weight
gpt.to(device).train()

# Load frozen reference
ref_gpt = GPT(gptconf)
ref_gpt.load_state_dict(state_dict)
# IMPORTANT: same tying on reference
ref_gpt.lm_head.weight = ref_gpt.transformer.wte.weight
ref_gpt.to(device).eval()
for p in ref_gpt.parameters():
    p.requires_grad = False



In [18]:
gpt.eval(); ref_gpt.eval()
x = torch.tensor(encode("12*8=? The answer is 96 because 12*8 equals 96.\n"), dtype=torch.long, device=device)[None, :-1]
with torch.no_grad():
    lg_pt, _ = gpt(x, full_seq=True)
    lg_rf, _ = ref_gpt(x, full_seq=True)
print("Max abs diff pre-train:", (lg_pt - lg_rf).abs().max().item())

Max abs diff pre-train: 0.0


### Step 5: Load Data 


In [19]:
# Load data from ./data/pos_neg_pairs.json

# Loading the json file, CHANGE ADDRESS IF NEEDED
with open("/kaggle/input/pos-neg-pairs/pos_neg_pairs.json", "r", encoding = "utf-8") as f:
    raw_data = json.load(f)

allowed_chars = set(stoi.keys())

def clean_text(text, allowed):
    return "".join([c if c in allowed else "." for c in text])
    
# Clean the dataset
lines = [{"negative": clean_text(e["negative"], allowed_chars),
          "positive": clean_text(e["positive"], allowed_chars)} for e in raw_data]

print(f"Loaded {len(lines)} pairs.")

Loaded 100000 pairs.


### Step 6: Build the optimizer and scheduler

In [20]:
import math

# Configuration 
gradient_accumulation_steps = 8  # Simulate larger batch sizes
weight_decay = 0.01 # UPDATED
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
warmup_iters = 200
max_iters = (len(lines) // batch_size) * epochs
lr_decay_iters = max_iters
min_lr = base_lr / 10

# Mixed precision setup 
dtype = 'float16' if torch.cuda.is_available() else 'float32'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == 'cuda' else torch.no_grad()

# Initialize GradScaler
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Optimizer 
optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr, weight_decay=weight_decay, betas=(beta1, beta2))

# Learning rate decay scheduler (cosine with warmup) 
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return base_lr * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (base_lr - min_lr)

  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


### Step 7: Begin training

In [21]:
# new compute logprob
beta = 0.2  # gentler than 0.5

def compute_policy_ref_logprob(input_ids, policy_model, ref_model):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]

    policy_logits, _ = policy_model(inputs, full_seq=True)
    with torch.no_grad():
        ref_logits, _ = ref_model(inputs, full_seq=True)

    policy_logp = F.log_softmax(policy_logits, dim=-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1)

    tgt_pol = policy_logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    tgt_ref = ref_logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    mask = (targets != 0).float()
    denom = mask.sum(dim=1).clamp(min=1)

    seq_pol = (tgt_pol * mask).sum(dim=1) / denom
    seq_ref = (tgt_ref * mask).sum(dim=1) / denom
    return seq_pol, seq_ref

In [22]:
total_steps = len(lines) // batch_size
iter_num = 0  # Global iteration counter
t0 = time.time()  # Timing

for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        ###########################################################
        # Standard training step 

        # Dynamic learning rate (cosine decay with warmup)
        lr = get_lr(iter_num) if decay_lr else base_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward + loss (mixed precision)
        with ctx:
            #neg_logprob = compute_logprob(neg_tensor)
            #pos_logprob = compute_logprob(pos_tensor)

            # Direct Preference Optimization (DPO) loss
            #loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()

            # UPDATED: Use reference model for DPO
            pos_pol, pos_ref = compute_policy_ref_logprob(pos_tensor, gpt, ref_gpt)
            neg_pol, neg_ref = compute_policy_ref_logprob(neg_tensor, gpt, ref_gpt)
            dpo_arg = (pos_pol - pos_ref) - (neg_pol - neg_ref)
            loss = -F.logsigmoid(dpo_arg / beta).mean()

        # Backward (with gradient scaling if fp16/bf16 mode)
        scaler.scale(loss).backward()

        # Gradient clipping
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        # Optimizer + scaler step
        scaler.step(optimizer)
        scaler.update()

        # Reset gradients ("flush" them to save memory)
        optimizer.zero_grad(set_to_none=True)

        # Timing + logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        lossf = loss.item()  # no rescaling needed

        pbar.set_description(
            f"Epoch {epoch+1} iter {iter_num} loss {lossf:.4f} lr {lr:.2e} time {dt*1000:.1f}ms"
        )

        iter_num += 1
        ###########################################################

    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Epoch 1 iter 1561 loss 0.0000 lr 9.31e-05 time 110.1ms: : 1562it [02:52,  9.04it/s]


Saved checkpoint to ./dpo.pt


Epoch 2 iter 3123 loss 0.0000 lr 7.10e-05 time 109.4ms: : 1562it [02:51,  9.11it/s]


Saved checkpoint to ./dpo.pt


Epoch 3 iter 4685 loss 0.0000 lr 4.25e-05 time 108.4ms: : 1562it [02:51,  9.13it/s]


Saved checkpoint to ./dpo.pt


Epoch 4 iter 6247 loss 0.0000 lr 1.90e-05 time 109.5ms: : 1562it [02:51,  9.12it/s]


Saved checkpoint to ./dpo.pt


Epoch 5 iter 7809 loss 0.0000 lr 1.00e-05 time 110.3ms: : 1562it [02:51,  9.12it/s]

Saved checkpoint to ./dpo.pt





### Step 8: Begin testing

In [23]:
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?"]

with torch.no_grad():
    for prompt in test_set:
        # Encode text → tensor
        prompt_ids = encode(prompt)
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, len(prompt)]

        # Generate continuation
        out = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Convert back to text
        generated_tokens = out[0][0].cpu().tolist()

        # Split into prompt + continuation
        prompt_len = len(prompt_ids)
        full_text = decode(generated_tokens)
        continuation = decode(generated_tokens[prompt_len:])

        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}\n")

Prompt: 17+19=?
Answer: Yes, I have a dog

Prompt: 3*17=?
Answer: Yes, I ke an us, take an umbrella

Prompt: 72/4=?
Answer: Yes, I have

Prompt: 72-x=34,x=?
Answer: I don’s ke away

Prompt: x*11=44,x=?
Answer: I me mes ove Yes, I ke an us -Yes, take an umbrella



In [None]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
#from model import GPT, GPTConfig  # FIXED: Uncomment this line!
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
import math

# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length = 64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200

# tokenizer
with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

# Load model
ckpt = torch.load("/kaggle/input/sft-gpt/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

# FIXED: Better data loading and cleaning
with open("/kaggle/input/pos-neg-pairs/pos_neg_pairs.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

allowed_chars = set(stoi.keys())

def clean_text(text, allowed):
    # FIXED: Instead of replacing with '.', skip unknown characters or use a space
    return "".join([c if c in allowed else ' ' for c in text])

# Clean the dataset
lines = []
for entry in raw_data:
    neg_clean = clean_text(entry["negative"], allowed_chars)
    pos_clean = clean_text(entry["positive"], allowed_chars)
    
    # FIXED: Only add if both texts are reasonable length after cleaning
    if len(neg_clean.strip()) > 5 and len(pos_clean.strip()) > 5:
        lines.append({
            "negative": neg_clean,
            "positive": pos_clean
        })

print(f"Loaded {len(lines)} pairs after cleaning.")

# Print a few examples to verify
print("\nFirst few examples after cleaning:")
for i in range(min(3, len(lines))):
    print(f"Negative: {lines[i]['negative'][:100]}...")
    print(f"Positive: {lines[i]['positive'][:100]}...")
    print()

# Optimizer setup
gradient_accumulation_steps = 8
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
warmup_iters = 200
max_iters = (len(lines) // batch_size) * epochs
lr_decay_iters = max_iters
min_lr = base_lr / 10

dtype = 'float16' if torch.cuda.is_available() else 'float32'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == 'cuda' else torch.no_grad()

scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr, weight_decay=weight_decay, betas=(beta1, beta2))

def get_lr(it):
    if it < warmup_iters:
        return base_lr * (it + 1) / (warmup_iters + 1)
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (base_lr - min_lr)

# Training loop
total_steps = len(lines) // batch_size
iter_num = 0
t0 = time.time()

for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        # Learning rate scheduling
        lr = get_lr(iter_num) if decay_lr else base_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass
        with ctx:
            neg_logprob = compute_logprob(neg_tensor)
            pos_logprob = compute_logprob(pos_tensor)

            # FIXED: Complete DPO loss with SFT regularization
            loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1

        # Backward pass
        scaler.scale(loss).backward()

        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        lossf = loss.item()

        pbar.set_description(
            f"Epoch {epoch+1} iter {iter_num} loss {lossf:.4f} lr {lr:.2e} time {dt*1000:.1f}ms"
        )

        iter_num += 1

    # Save checkpoint
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

# Testing
print("\n" + "="*50)
print("TESTING THE MODEL")
print("="*50)

gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?"]

with torch.no_grad():
    for prompt in test_set:
        prompt_ids = encode(prompt)
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

        # Generate with the model
        out = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

        generated_tokens = out[0].cpu().tolist()  # FIXED: Remove extra [0]
        prompt_len = len(prompt_ids)
        continuation = decode(generated_tokens[prompt_len:])

        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}")
        print("-" * 30)

  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


Loaded 100000 pairs after cleaning.

First few examples after cleaning:
Negative: 75+55=? Sorry, I do not know ...
Positive: 75+55=? The answer is 130 because 75+55 equals 130....

Negative: 87+14=? Sorry, I do not know ...
Positive: 87+14=? The answer is 101 because 87+14 equals 101....

Negative: 12*8=? Sorry, I do not know ...
Positive: 12*8=? The answer is 96 because 12*8 equals 96....



Epoch 1 iter 1561 loss 0.0216 lr 9.31e-05 time 83.8ms: : 1562it [02:13, 11.73it/s]


Saved checkpoint to ./dpo.pt


Epoch 2 iter 3123 loss 0.0166 lr 7.10e-05 time 84.6ms: : 1562it [02:12, 11.81it/s]


Saved checkpoint to ./dpo.pt


Epoch 3 iter 4685 loss 0.0169 lr 4.25e-05 time 85.7ms: : 1562it [02:12, 11.77it/s]


Saved checkpoint to ./dpo.pt


Epoch 4 iter 4686 loss 0.0163 lr 4.25e-05 time 200.7ms: : 1it [00:00,  8.56it/s]

# eh???? trying original code below

In [2]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [6]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
#from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s if c in stoi] # I EDITED THIS LINE
def decode(l): return ''.join([itos[i] for i in l])

In [7]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

In [10]:
ckpt = torch.load("/kaggle/input/sft-gpt/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()


### KIERAN ADDED THIS, REMEBER TO REMOVE BEFORE SUBMITTING
print(torch.cuda.is_available())

print("device variable:", device)
print("Model first parameter device:", next(gpt.parameters()).device)

True
device variable: cuda
Model first parameter device: cuda:0


In [12]:
# Load data from ./data/pos_neg_pairs.json

# Loading the json file, CHANGE ADDRESS IF NEEDED
with open("/kaggle/input/pos-neg-pairs/pos_neg_pairs.json", "r", encoding = "utf-8") as f:
    lines = json.load(f)

print(f"Loaded {len(lines)} pairs.")

Loaded 100000 pairs.


In [13]:
# recommend to use the AdamW optimizer 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

learning_rate = 0.001
weight_decay = 0.01 # This is the L2 regularization strength for AdamW
optimizer = optim.AdamW(gpt.parameters(), lr=learning_rate, weight_decay=weight_decay)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) # Example: StepLR

In [14]:
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!

        # We first zero the gradients to avoid accumulation so that we can correctly compute the gradients for this step
        optimizer.zero_grad()

        # We calculate the log-probabilities
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)

        # We then calculate the loss of the DPO by the formula where we take the mean of the individual losses
        loss = -F.logsigmoid((pos_logprob - neg_logprob) * beta).mean()

        # We then backpropagate the loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        # We update the progress bar with the current epoch, step, and loss
        pbar.set_description(f"Epoch {epoch + 1} Step {step + 1} Loss {loss.item():.4f}")
        ###########################################################
        ckpt_path = f"./dpo.pt"
        torch.save({
            "model_state_dict": gpt.state_dict(),
            "model_args": ckpt['model_args'],
        }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Epoch 1 Step 1562 Loss 0.0018: : 1562it [06:07,  4.25it/s]


Saved checkpoint to ./dpo.pt


Epoch 2 Step 1562 Loss 0.0019: : 1562it [06:15,  4.16it/s]


Saved checkpoint to ./dpo.pt


Epoch 3 Step 1562 Loss 0.0018: : 1562it [06:06,  4.26it/s]


Saved checkpoint to ./dpo.pt


Epoch 4 Step 1562 Loss 0.0019: : 1562it [06:04,  4.29it/s]


Saved checkpoint to ./dpo.pt


Epoch 5 Step 1562 Loss 0.0019: : 1562it [06:20,  4.10it/s]

Saved checkpoint to ./dpo.pt





In [15]:
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?"]

with torch.no_grad():
    for prompt in test_set:
        # Encode text → tensor
        prompt_ids = encode(prompt)
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, len(prompt)]

        # Generate continuation
        out = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Convert back to text
        generated_tokens = out[0][0].cpu().tolist()

        # Split into prompt + continuation
        prompt_len = len(prompt_ids)
        full_text = decode(generated_tokens)
        continuation = decode(generated_tokens[prompt_len:])

        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}\n")

Prompt: 17+19=?
Answer: eau 5laaaall uauu0u   ucu 1sa  uwu1ihac3a 68 1 a   1  ca 3aau 1uu5 2uuiie0 lu4ai a uiau  aaa2tath uu.

Prompt: 3*17=?
Answer: ecaaa1icaauuu7 u1icabct8auuw  uaa?uaauuu aaa  8u4a uui   cb aua8a a -lc2aca5uuubui aa5 s5q  451cu5a   uu7651ahi ucc1  au auacau aa    uh 5au   au u-ea a5    6e l a4e 1Xe a1ec u   ac  c39e1e l c’u0e7e1

Prompt: 72/4=?
Answer: eaaanbnuui uau aaa u u a1     u  ualbl  u uu  1u  a

Prompt: 72-x=34,x=?
Answer: ei ui natuaau5au+u 78u21a aa iaua4 u  a 9u1aata7 as  u211eu1ucwa  -116aa.

Prompt: x*11=44,x=?
Answer: ellaaahaa a uucucau aa u.



# TRYING NEW VERSION

In [5]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
#from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200

with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
#def encode(s): return [stoi[c] for c in s]
#def decode(l): return ''.join([itos[i] for i in l])
PAD_IDX = 0
UNK_IDX = stoi.get("<unk>", stoi.get(" ", PAD_IDX))  # prefer <unk>, then space, else pad(0)

def encode(s: str):
    # map unseen characters to UNK instead of raising KeyError
    return [stoi.get(c, UNK_IDX) for c in s]

def decode(ids):
    return ''.join(itos[i] for i in ids if 0 <= i < len(itos))

In [6]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)   
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

In [8]:
ckpt = torch.load("/kaggle/input/sft-gpt/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

# import torch
# # ### KIERAN ADDED THIS, REMEBER TO REMOVE BEFORE SUBMITTING
# print(torch.cuda.is_available())
# print(1212,torch.version.cuda)
# print(device)
# print(torch.cuda.is_available())
# print("Model first parameter device:", next(gpt.parameters()).device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

In [9]:
# Load data from ./data/pos_neg_pairs.json
import json
import tiktoken
# Loading the json file, CHANGE ADDRESS IF NEEDED
with open("/kaggle/input/pos-neg-pairs/pos_neg_pairs.json", "r", encoding = "utf-8") as f:
    lines = json.load(f)

print(f"Loaded {len(lines)} pairs.")

Loaded 100000 pairs.


In [10]:
from torch.optim.lr_scheduler import LambdaLR
import math 
weight_decay = 1e-3
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
max_iters = (len(lines) // batch_size) * epochs
warmup_iters =  int(0.1 * max_iters)

lr_decay_iters = max_iters
base_lr =  6e-4
min_lr = base_lr / 10


# optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr, weight_decay=weight_decay, betas=(beta1, beta2))
decay_params = []
no_decay_params = []

for name, param in gpt.named_parameters():
    if param.requires_grad:
        # Don't apply weight decay to biases and layer norms
        if 'bias' in name or 'ln' in name or 'layernorm' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 1e-2},
    {'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr, betas=(beta1, beta2))


num_warmup_steps = 1000
num_training_steps = 10000


def lr_lambda(current_step: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# 4. Create the scheduler
scheduler = LambdaLR(optimizer, lr_lambda)

In [11]:
global_step = 0
anchor_weight_start = 0.2
anchor_weight_end = 0.05
neg_anchor_weight = 0.05
margin = 0.5
beta = 0.1
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        neg_logprob = compute_logprob(neg_tensor)
        pos_logprob = compute_logprob(pos_tensor)
        # Preference term with margin
        logit_diff = (pos_logprob - neg_logprob - margin) / beta
        preference_term = -F.logsigmoid(logit_diff).mean()
        
        # Adaptive anchor term
        progress = global_step / max_iters
        anchor_weight = anchor_weight_start * (1 - progress) + anchor_weight_end * progress
        
        # Dual anchoring: encourage good positives, discourage negatives
        pos_anchor = -anchor_weight * pos_logprob.mean()
        neg_anchor = neg_anchor_weight * neg_logprob.mean()
        anchor_term = pos_anchor + neg_anchor
        
        loss = preference_term + anchor_term
        # Backward pass
        loss.backward()
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        
        # Optimizer step
        optimizer.step()
        scheduler.step()
        # Update progress bar
        pbar.set_description(f"Epoch {epoch + 1}/{epochs} | Step {step} | Loss {loss.item():.4f} | LR {scheduler.get_last_lr()[0]:.2e}")
        global_step += 1
    
    # Save checkpoint ONCE per epoch
    ckpt_path = f"./dpo_epoch_{epoch+1}.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

Epoch 1/5 | Step 1561 | Loss -27.1086 | LR 5.94e-04: : 1562it [04:37,  5.62it/s]


Saved checkpoint to ./dpo_epoch_1.pt


Epoch 2/5 | Step 1561 | Loss -97.2396 | LR 5.21e-04: : 1562it [04:48,  5.41it/s]


Saved checkpoint to ./dpo_epoch_2.pt


Epoch 3/5 | Step 1561 | Loss -180.9000 | LR 3.84e-04: : 1562it [04:48,  5.42it/s]


Saved checkpoint to ./dpo_epoch_3.pt


Epoch 4/5 | Step 1561 | Loss -248.7553 | LR 2.23e-04: : 1562it [04:49,  5.40it/s]


Saved checkpoint to ./dpo_epoch_4.pt


Epoch 5/5 | Step 1561 | Loss -309.5906 | LR 8.35e-05: : 1562it [04:47,  5.43it/s]


Saved checkpoint to ./dpo_epoch_5.pt


In [14]:
# Load the fine-tuned model
ckpt_path = "/kaggle/working/dpo_epoch_5.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # This part i gpt generated could be wrong, couldnt find this in train.py lol 
        # Encode text → tensor
        prompt_ids = encode(prompt)
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, len(prompt)]

        # Generate continuation
        out = gpt.generate(
            x,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Convert back to text
        generated_tokens = out[0][0].cpu().tolist()

        # Split into prompt + continuation
        prompt_len = len(prompt_ids)
        full_text = decode(generated_tokens)
        continuation = decode(generated_tokens[prompt_len:])

        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}\n")
        ###########################################################

Prompt: 17+19=?
Answer: Sue

Prompt: 3*17=?
Answer: Sue

Prompt: 72/4=?
Answer: Sue

Prompt: 72-x=34,x=?
Answer: Su

Prompt: x*11=44,x=?
Answer: Suse

Prompt: 3*17=?
Answer: Sue

Prompt: 72/4=?
Answer: Sue

Prompt: 72-x=34,x=?
Answer: Su



# PLS WORK

In [26]:
# Step 2: Package imports and configuration
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
import math

# Configuration
beta = 0.1  # CHANGED: Lower beta for sharper preference margins
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 5e-5  # CHANGED: Lower learning rate for stability
epochs = 10  # CHANGED: More epochs for small dataset
batch_size = 64
max_length = 128  # CHANGED: Longer to fit full answers
num_samples = 1
max_new_tokens = 64  # CHANGED: Shorter, more focused generation
temperature = 0.0  # CHANGED: Greedy decoding for math
top_k = None  # CHANGED: No top-k for deterministic output

# Tokenizer
with open("/kaggle/input/sft-meta/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

# ADDED: Verify tokenizer has required math characters
required = set(list("0123456789+-*/=xX?,. Theanswerisbecause"))
missing = [c for c in required if c not in stoi]
if missing:
    print(f"WARNING: Missing tokens in vocabulary: {missing}")
else:
    print("✓ All required math tokens present in vocabulary")

✓ All required math tokens present in vocabulary


In [27]:
# Step 3: Define helper functions
def compute_logprob(input_ids):
    """Compute log probability for policy model (backward compatible)"""
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

# ADDED: Compute policy and reference log probabilities for DPO
def compute_policy_ref_logprob(input_ids, policy_model, ref_model):
    """Compute sequence log probabilities for both policy and reference models"""
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]

    # Policy model forward pass
    policy_logits, _ = policy_model(inputs, full_seq=True)
    
    # Reference model forward pass (no grad)
    with torch.no_grad():
        ref_logits, _ = ref_model(inputs, full_seq=True)

    # Convert to log probabilities
    policy_logp = F.log_softmax(policy_logits, dim=-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1)

    # Gather log probs for target tokens
    tgt_pol = policy_logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    tgt_ref = ref_logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Mask padding and average over sequence
    mask = (targets != 0).float()
    denom = mask.sum(dim=1).clamp(min=1)

    seq_pol = (tgt_pol * mask).sum(dim=1) / denom
    seq_ref = (tgt_ref * mask).sum(dim=1) / denom
    
    return seq_pol, seq_ref

# ADDED: Compute NLL loss for SFT anchor term
def nll_on_targets(input_ids, model):
    """Compute negative log-likelihood for supervised learning anchor"""
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = model(inputs, full_seq=True)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=0
    )
    return loss

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        # CHANGED: Removed '\n\n\n\n' to preserve answer tokens
        neg_inputs = [pad_or_truncate(encode(p['negative']), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive']), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

In [29]:
# Step 4: Load the pretrained NanoGPT model
ckpt = torch.load("/kaggle/input/sft-gpt/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])

# CHANGED: Properly extract state_dict first
state_dict = ckpt['model'] if 'model' in ckpt else ckpt['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

# Load policy model (trainable)
gpt = GPT(gptconf)
gpt.load_state_dict(state_dict, strict=True)
# ADDED: Tie weights if model architecture requires it
try:
    gpt.lm_head.weight = gpt.transformer.wte.weight
except:
    pass
gpt.to(device).train()

# ADDED: Load reference model (frozen copy for DPO)
ref_gpt = GPT(gptconf)
ref_gpt.load_state_dict(state_dict, strict=True)
try:
    ref_gpt.lm_head.weight = ref_gpt.transformer.wte.weight
except:
    pass
ref_gpt.to(device).eval()
for p in ref_gpt.parameters():
    p.requires_grad = False

# ADDED: Verify policy and reference start identical
print("Verifying policy and reference models are identical at start...")
gpt.eval()
ref_gpt.eval()
test_input = torch.tensor(encode("12*8=? The answer is 96"), dtype=torch.long, device=device)[None, :-1]
with torch.no_grad():
    lg_pt, _ = gpt(test_input, full_seq=True)
    lg_rf, _ = ref_gpt(test_input, full_seq=True)
    max_diff = (lg_pt - lg_rf).abs().max().item()
    print(f"Max logit difference: {max_diff:.6f} (should be ~0)")
gpt.train()

Verifying policy and reference models are identical at start...
Max logit difference: 0.000000 (should be ~0)


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

In [31]:
# Step 5: Load Data
# Load data from ./data/pos_neg_pairs.json
with open("/kaggle/input/pos-neg-pairs/pos_neg_pairs.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# CHANGED: Clean text to handle out-of-vocabulary characters
allowed_chars = set(stoi.keys())

def clean_text(text, allowed):
    """Replace unsupported characters with '.' to avoid encoding errors"""
    return "".join([c if c in allowed else "." for c in text])

lines = [
    {
        "negative": clean_text(e["negative"], allowed_chars),
        "positive": clean_text(e["positive"], allowed_chars)
    } 
    for e in raw_data
]

print(f"✓ Loaded {len(lines)} positive/negative pairs")

✓ Loaded 100000 positive/negative pairs


In [32]:
# Step 6: Build the optimizer and scheduler
# ADDED: Complete optimizer setup with AdamW
weight_decay = 0.01
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
warmup_iters = 200
max_iters = (len(lines) // batch_size) * epochs
lr_decay_iters = max_iters
min_lr = base_lr / 10

# Mixed precision setup
dtype = 'float16' if torch.cuda.is_available() else 'float32'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == 'cuda' else torch.no_grad()

# Initialize GradScaler for mixed precision
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Optimizer
optimizer = torch.optim.AdamW(
    gpt.parameters(), 
    lr=base_lr, 
    weight_decay=weight_decay, 
    betas=(beta1, beta2)
)

# Learning rate scheduler (cosine with warmup)
def get_lr(it):
    """Cosine learning rate schedule with linear warmup"""
    # Linear warmup
    if it < warmup_iters:
        return base_lr * (it + 1) / (warmup_iters + 1)
    # Return min_lr after decay period
    if it > lr_decay_iters:
        return min_lr
    # Cosine decay
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (base_lr - min_lr)

print(f"✓ Optimizer configured: AdamW with lr={base_lr}, warmup={warmup_iters}, max_iters={max_iters}")

✓ Optimizer configured: AdamW with lr=5e-05, warmup=200, max_iters=15620


  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


In [33]:
# Step 7: Begin training
# ADDED: Complete DPO training loop with reference model and SFT anchor
lambda_sft = 0.1  # Weight for supervised learning anchor term

total_steps = len(lines) // batch_size
iter_num = 0
t0 = time.time()

print(f"\n{'='*60}")
print(f"Starting DPO Training")
print(f"{'='*60}")
print(f"Epochs: {epochs}, Batch size: {batch_size}, Steps per epoch: {total_steps}")
print(f"Beta: {beta}, Lambda SFT: {lambda_sft}")
print(f"{'='*60}\n")

for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size), desc=f"Epoch {epoch+1}/{epochs}")
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        ###########################################################
        # COMPLETED: DPO training with reference model and SFT anchor
        
        # Dynamic learning rate (cosine decay with warmup)
        lr = get_lr(iter_num) if decay_lr else base_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass with mixed precision
        with ctx:
            # Compute policy and reference log probabilities
            pos_pol, pos_ref = compute_policy_ref_logprob(pos_tensor, gpt, ref_gpt)
            neg_pol, neg_ref = compute_policy_ref_logprob(neg_tensor, gpt, ref_gpt)
            
            # DPO loss: maximize preference margin between positive and negative
            dpo_arg = (pos_pol - pos_ref) - (neg_pol - neg_ref)
            dpo_loss = -F.logsigmoid(dpo_arg / beta).mean()
            
            # SFT anchor: keep model grounded on positive examples
            sft_loss = nll_on_targets(pos_tensor, gpt)
            
            # Combined loss
            loss = dpo_loss + lambda_sft * sft_loss

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()

        # Gradient clipping
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(gpt.parameters(), grad_clip)

        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Timing and logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        # Update progress bar
        pbar.set_description(
            f"Epoch {epoch+1}/{epochs} | "
            f"DPO: {dpo_loss.item():.4f} | "
            f"SFT: {sft_loss.item():.4f} | "
            f"Total: {loss.item():.4f} | "
            f"LR: {lr:.2e}"
        )

        iter_num += 1
        ###########################################################
    
    # Save checkpoint after each epoch
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"✓ Saved checkpoint to {ckpt_path}")

print(f"\n{'='*60}")
print(f"Training Complete!")
print(f"{'='*60}\n")


Starting DPO Training
Epochs: 10, Batch size: 64, Steps per epoch: 1562
Beta: 0.1, Lambda SFT: 0.1



Epoch 1/10 | DPO: 0.0000 | SFT: 0.2419 | Total: 0.0242 | LR: 4.91e-05: : 1562it [08:21,  3.12it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 2/10 | DPO: 0.0000 | SFT: 0.2040 | Total: 0.0204 | LR: 4.61e-05: : 1562it [08:20,  3.12it/s] 


✓ Saved checkpoint to ./dpo.pt


Epoch 3/10 | DPO: 0.0000 | SFT: 0.1744 | Total: 0.0174 | LR: 4.12e-05: : 1562it [08:20,  3.12it/s] 


✓ Saved checkpoint to ./dpo.pt


Epoch 4/10 | DPO: -0.0000 | SFT: 0.1700 | Total: 0.0170 | LR: 3.50e-05: : 1562it [08:21,  3.12it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 5/10 | DPO: -0.0000 | SFT: 0.1555 | Total: 0.0156 | LR: 2.80e-05: : 1562it [08:21,  3.11it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 6/10 | DPO: -0.0000 | SFT: 0.1563 | Total: 0.0156 | LR: 2.09e-05: : 1562it [08:21,  3.12it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 7/10 | DPO: -0.0000 | SFT: 0.1477 | Total: 0.0148 | LR: 1.45e-05: : 1562it [08:21,  3.12it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 8/10 | DPO: -0.0000 | SFT: 0.1501 | Total: 0.0150 | LR: 9.41e-06: : 1562it [08:20,  3.12it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 9/10 | DPO: -0.0000 | SFT: 0.1466 | Total: 0.0147 | LR: 6.13e-06: : 1562it [08:21,  3.11it/s]


✓ Saved checkpoint to ./dpo.pt


Epoch 10/10 | DPO: -0.0000 | SFT: 0.1473 | Total: 0.0147 | LR: 5.00e-06: : 1562it [08:21,  3.11it/s]

✓ Saved checkpoint to ./dpo.pt

Training Complete!






In [39]:
# Step 8: Begin testing
# Load the fine-tuned model
ckpt_path = "/kaggle/working/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)

# Load state dict
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']

unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

gpt.load_state_dict(state_dict)

# Test
gpt.eval()
test_set = ["88+7=?", "x-18=21,x=?", "x/10=6,x=?", "54/1=?", "24+48=?", "11+23=?", "64+13=?"]

print(f"\n{'='*60}")
print(f"Testing Fine-tuned Model")
print(f"{'='*60}\n")

with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # COMPLETED: Generate answer using fine-tuned model
        
        # Convert prompt to tensor
        x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, seq_len]
        
        # Generate continuation (greedy decoding for deterministic math answers)
        out = gpt.generate(
            x, 
            max_new_tokens=max_new_tokens, 
            temperature=temperature if temperature > 0 else 1.0,  # Avoid division by zero
            top_k=top_k
        )
        
        # Decode generated tokens
        # generate() returns a tuple, extract the tensor
        if isinstance(out, tuple):
            out = out[0]  # Get the first element (the generated tokens)
        
        # Handle different possible tensor shapes
        if out.dim() == 3:  # [num_samples, batch_size, seq_len]
            generated_tokens = out[0][0].cpu().tolist()
        elif out.dim() == 2:  # [batch_size, seq_len]
            generated_tokens = out[0].cpu().tolist()
        else:  # [seq_len]
            generated_tokens = out.cpu().tolist()
        
        full_text = decode(generated_tokens)
        
        # Extract only the generated part (after prompt)
        continuation = full_text[len(prompt):]
        
        # Print result
        print(f"Prompt: {prompt}")
        print(f"Answer: {continuation.strip()}\n")
        ###########################################################

print(f"{'='*60}")
print(f"Testing Complete!")
print(f"{'='*60}")


Testing Fine-tuned Model

Prompt: 88+7=?
Answer: The answer is 95 because 88+7 equals 95.

Prompt: x-18=21,x=?
Answer: The answer is -3 because 21-188 equals -31.

Prompt: x/10=6,x=?
Answer: The answer is 60 because 6*10 equals 60.

Prompt: 54/1=?
Answer: The answer is 5 because 54/1 equals 5.

Prompt: 24+48=?
Answer: The answer is 72 because 24+48 equals 72.

Prompt: 11+23=?
Answer: The answer is 34 because 11+23 equals 34.

Prompt: 64+13=?
Answer: The answer is 77 because 64+13 equals 77.

Testing Complete!
