In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import mmap
import random
import pickle


device= 'cuda' if torch.cuda.is_available() else 'cpu'
print (device)
block_size=256
batch_size=32

max_iters= 8000
#eval_interval=2500
learning_rate= 1e-4
eval_iters=5
eval_interval=1000
n_embd=512
n_head=8
n_layer=6
dropout = 0.1
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ===== Early stopping config =====
EARLY_STOP_PATIENCE = 3        # number of evals with no improvement
MIN_LOSS_DELTA = 0.01          # minimum improvement to count
#TARGET_PPL = 28.0              # stop immediately if reached

STOP_LOSS = 0.2




cuda


In [2]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")

vocab_size = sp.get_piece_size()
print("Using BPE vocab size:", vocab_size)

def encode(text: str):
    return sp.encode(text, out_type=int)

def decode(tokens):
    return sp.decode(tokens)


Using BPE vocab size: 8000


In [3]:
def load_instruction_tokens(path):
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    return torch.tensor(encode(text), dtype=torch.long)

print("Loading Stage-2 instruction dataset...")
instr_tokens = load_instruction_tokens("sarcasm_train_v6.txt").to(device)
val_instr_tokens = load_instruction_tokens("sarcasm_val_v6.txt").to(device)

print("Train tokens:", instr_tokens.numel())
print("Val tokens:", val_instr_tokens.numel())


Loading Stage-2 instruction dataset...
Train tokens: 574319
Val tokens: 64026


In [4]:
import mmap
import random


def get_batch_stage2(split="train"):
    data = instr_tokens if split == "train" else val_instr_tokens

    ix = torch.randint(
        0, data.size(0) - block_size - 1,
        (batch_size,),
        device=device
    )

    offsets = torch.arange(block_size, device=device)
    x = data[ix[:, None] + offsets]
    y = data[ix[:, None] + offsets + 1]

    return x, y


In [5]:
@torch.no_grad()
def generate_text(
    prompt: str,
    max_new_tokens=100,
    temperature=0.7,
    top_k=40,
    frequency_penalty=0.8,   # <— NEW
    presence_penalty=0.6     # <— NEW
):
    model.eval()

    # IMPORTANT FIX: force leading space for BPE
    if not prompt.startswith(" "):
        prompt = " " + prompt

    idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)

    generated_tokens = []  # track generated token ids

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature

        # ---------- REPETITION PENALTIES ----------
        if generated_tokens:
            unique_tokens = set(generated_tokens)
            for tok in unique_tokens:
                # presence penalty (discourage reuse at all)
                logits[:, tok] -= presence_penalty

            # frequency penalty (discourage repeated reuse)
            for tok in generated_tokens:
                logits[:, tok] *= frequency_penalty
        # -----------------------------------------

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)

        generated_tokens.append(next_id.item())
        idx = torch.cat([idx, next_id], dim=1)

    text = decode(idx[0].tolist())

    # ---- STOP AT NEXT INSTRUCTION (CRITICAL FOR STAGE-2) ----
    stop_str = "### Instruction:"
    start = len(prompt)
    if stop_str in text[start:]:
        text = text[: text.find(stop_str, start)]

    return text.strip()


In [6]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(block_size, block_size)),
            persistent=False
        )


        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
        out = self.dropout(self.proj(out))
        return out
    

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, int(4 * n_embd)),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        y = self.sa(x)
        x = self.ln1(x + y)
        y = self.ffwd(x)
        x = self.ln2(x + y)
        return x
    
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        #logits=self.token_embedding_table(index)
        B,T= index.shape
        
        
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(index) # (B,T,C)
        pos = torch.arange(T, device=index.device)
        pos_emb = self.position_embedding_table(pos)

        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
        
            # -------- RESPONSE-ONLY LOSS MASKING --------
            response_tokens = encode("### Response:")
            response_len = len(response_tokens)
        
            targets_masked = targets.clone()
        
            # Walk through the flattened sequence and mask everything
            # until after "### Response:" appears
            for i in range(len(targets) - response_len):
                if targets[i:i + response_len].tolist() == response_tokens:
                    targets_masked[: i + response_len] = -100
                    break
        
            loss = F.cross_entropy(
                logits,
                targets_masked,
                ignore_index=-100
            )
        
                
        return logits, loss

    def sample_logits(logits, temperature=0.8, top_k=40):
        logits = logits / temperature
        v, ix = torch.topk(logits, top_k)
        probs = torch.softmax(v, dim=-1)
        return ix[torch.multinomial(probs, 1)]



    
    def generate(self, index, max_new_tokens):
        # index is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            index_cond = index[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(index_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            index_next = sample_logits(logits)


            # append sampled index to the running sequence
            index = torch.cat((index, index_next), dim=1) # (B, T+1)
        return index

model = GPTLanguageModel(vocab_size).to(device)
model.load_state_dict(torch.load("stage1_best.pt", map_location=device))
model.train()



@torch.no_grad()
def stage2_test():
    model.eval()

    prompts = [
        "### Instruction:\nIs college worth it?\n\n### Response:\n",
        "### Instruction:\nIs memorization useful?\n\n### Response:\n",
        "### Instruction:\nHow do I stay motivated?\n\n### Response:\n",
    ]

    for p in prompts:
        out = generate_text(
            p,
            max_new_tokens=80,
            temperature=0.7,
            top_k=40
        )
        print("=" * 80)
        print(out)

    model.train()


In [7]:
import time

step_times = []
LOG_EVERY = 100   # log speed every N steps
t_last = time.time()


In [8]:

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    weight_decay=0.01
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_iters
)

scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))




import time

best_val_loss = float("inf")
no_improve = 0

t_last = time.time()
LOG_EVERY = 100



for iter in range(max_iters):

    # ===== evaluation =====
    if iter > 0 and iter % eval_interval == 0:
        model.eval()
        losses = []

        for _ in range(eval_iters):
            xb, yb = get_batch_stage2("val")
            _, loss = model(xb, yb)
            losses.append(loss.item())

        val_loss = sum(losses) / len(losses)
        model.train()

        print(f"[stage2] step {iter} | val loss {val_loss:.3f}")
        stage2_test()

        # ----- save best -----
        if val_loss < best_val_loss - MIN_LOSS_DELTA:
            best_val_loss = val_loss
            no_improve = 0
            torch.save(model.state_dict(), "stage2_best.pt")
            print("✓ New best Stage-2 model saved")
        else:
            no_improve += 1
            print(f"No improvement {no_improve}/{EARLY_STOP_PATIENCE}")

        # ----- HARD LOSS STOP (NEW) -----
        if val_loss <= STOP_LOSS:
            print(f"✓ Target loss {STOP_LOSS} reached — stopping Stage-2")
            break

        # ----- PATIENCE STOP -----
        if no_improve >= EARLY_STOP_PATIENCE:
            print("✓ Early stopping Stage-2 (plateau)")
            break

    # ===== training step =====
    xb, yb = get_batch_stage2("train")

    optimizer.zero_grad(set_to_none=True)

    with torch.amp.autocast("cuda", enabled=(device == "cuda")):
        _, loss = model(xb, yb)

    scaler.scale(loss).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

    # ===== speed logging =====
    if iter % LOG_EVERY == 0 and iter > 0:
        torch.cuda.synchronize()
        t_now = time.time()
        dt = t_now - t_last

        steps_per_sec = LOG_EVERY / dt
        secs_per_step = dt / LOG_EVERY

        print(
            f"[speed] {steps_per_sec:.2f} steps/s | "
            f"{secs_per_step:.3f} s/step"
        )

        t_last = t_now




[speed] 2.30 steps/s | 0.434 s/step
[speed] 2.55 steps/s | 0.392 s/step
[speed] 2.55 steps/s | 0.392 s/step
[speed] 2.55 steps/s | 0.392 s/step
[speed] 2.55 steps/s | 0.393 s/step
[speed] 2.55 steps/s | 0.392 s/step
[speed] 2.55 steps/s | 0.393 s/step
[speed] 2.54 steps/s | 0.393 s/step
[speed] 2.54 steps/s | 0.393 s/step
[stage2] step 1000 | val loss 0.191
### Instruction: Is college worth it? ### Response: Against all odds, The universe has chosen not to reveal this. You can tell your friends.
### Instruction: Is memorization useful? ### Response: This is where it gets exciting. The universe has chosen not to reveal this. Yes, that was sarcasm-shattering, so here we are.
### Instruction: How do I stay motivated? ### Response: Brace yourself. Humanity is still working on that one. Try not to be amazed.
✓ New best Stage-2 model saved
✓ Target loss 0.2 reached — stopping Stage-2
