In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import mmap
import random
import pickle


device= 'cuda' if torch.cuda.is_available() else 'cpu'
print (device)
block_size=256
batch_size=32

max_iters= 40000
#eval_interval=2500
learning_rate= 5e-4
eval_iters=5
eval_interval=3000
n_embd=512
n_head=8
n_layer=6
dropout = 0.1
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ===== Early stopping config =====
EARLY_STOP_PATIENCE = 4        # number of evals with no improvement
MIN_LOSS_DELTA = 0.02          # minimum improvement to count
TARGET_PPL = 28.0              # stop immediately if reached



cuda


In [2]:
"""import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="wiki_clean.txt",
    model_prefix="bpe",
    vocab_size=8000,
    model_type="bpe",
    character_coverage=1.0,
    bos_id=-1,
    eos_id=-1,
    unk_id=0
)
"""


'import sentencepiece as spm\n\nspm.SentencePieceTrainer.train(\n    input="wiki_clean.txt",\n    model_prefix="bpe",\n    vocab_size=8000,\n    model_type="bpe",\n    character_coverage=1.0,\n    bos_id=-1,\n    eos_id=-1,\n    unk_id=0\n)\n'

In [3]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")

vocab_size = sp.get_piece_size()
print("Using BPE vocab size:", vocab_size)

def encode(text: str):
    return sp.encode(text, out_type=int)

def decode(tokens):
    return sp.decode(tokens)


Using BPE vocab size: 8000


In [4]:
def load_tokens(path, chunk_size=100_000):
    all_tokens = []

    with open(path, "r", encoding="utf-8") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            all_tokens.extend(encode(chunk))

    return torch.tensor(all_tokens, dtype=torch.long)


print("Tokenizing datasets (one-time cost)...")
#tokens = load_tokens("wiki_clean.txt")
tokens = torch.load("wiki_tokens.pt")
train_tokens = tokens.to(device)
val_tokens   = tokens.to(device)
torch.save(tokens, "wiki_tokens.pt")


Tokenizing datasets (one-time cost)...


In [5]:
import mmap
import random

with open("wiki_clean.txt", "rb") as f:
    wiki_mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

def get_batch(split):
    data = train_tokens if split == "train" else val_tokens

    # sample starting indices on GPU
    ix = torch.randint(
        0, data.size(0) - block_size - 1,
        (batch_size,),
        device=device
    )

    # vectorized slicing (NO Python loops)
    offsets = torch.arange(block_size, device=device)
    x = data[ix[:, None] + offsets]
    y = data[ix[:, None] + offsets + 1]

    return x, y


In [6]:
@torch.no_grad()
def generate_text(
    prompt: str,
    max_new_tokens=100,
    temperature=0.7,
    top_k=40
):
    model.eval()

    # IMPORTANT FIX: force leading space
    if not prompt.startswith(" "):
        prompt = " " + prompt

    idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits, _ = model(idx_cond)

        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)

        idx = torch.cat([idx, next_id], dim=1)

    text = decode(idx[0].tolist())

    # Optional cleanup: strip leading space
    return text.lstrip()


In [7]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(block_size, block_size)),
            persistent=False
        )


        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
        out = self.dropout(self.proj(out))
        return out
    

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, int(4 * n_embd)),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        y = self.sa(x)
        x = self.ln1(x + y)
        y = self.ffwd(x)
        x = self.ln2(x + y)
        return x
    
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        #logits=self.token_embedding_table(index)
        B,T= index.shape
        
        
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(index) # (B,T,C)
        pos = torch.arange(T, device=index.device)
        pos_emb = self.position_embedding_table(pos)

        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss

    def sample_logits(logits, temperature=0.8, top_k=40):
        logits = logits / temperature
        v, ix = torch.topk(logits, top_k)
        probs = torch.softmax(v, dim=-1)
        return ix[torch.multinomial(probs, 1)]



    
    def generate(self, index, max_new_tokens):
        # index is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            index_cond = index[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(index_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            index_next = sample_logits(logits)


            # append sampled index to the running sequence
            index = torch.cat((index, index_next), dim=1) # (B, T+1)
        return index

model = GPTLanguageModel(vocab_size)



model = model.to(device)

@torch.no_grad()
def stage1_language_test():
    model.eval()

    prompts = [
        "The capital city of France is",
        "The history of science is",
        "In recent years, machine learning has",
        "A common misconception about physics is",
        "Education plays an important role in society because",
    ]

    for p in prompts:
        out = generate_text(
            p,
            max_new_tokens=100,
            temperature=0.7,   # lower temp = better coherence for Stage-1
            top_k=40
        )
        print("\n" + "=" * 80)
        print("PROMPT:", p)
        print(out)

    model.train()


In [8]:
import math

@torch.no_grad()
def estimate_perplexity():
    model.eval()
    losses = []
    for _ in range(eval_iters):
        X, Y = get_batch("val")
        _, loss = model(X, Y)
        losses.append(loss.item())
    model.train()
    avg_loss = sum(losses) / len(losses)
    return avg_loss, math.exp(avg_loss)



In [9]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [10]:
import time

step_times = []
LOG_EVERY = 100   # log speed every N steps
t_last = time.time()


In [11]:
TEST_EVERY = 500  # how often to test sarcasm
best_val_loss = float("inf")
no_improve_count = 0

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    weight_decay=0.01
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_iters
)

scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))




for iter in range(max_iters):

    # ===== evaluation & early stopping =====
    if iter > 0 and iter % eval_interval == 0:
        val_loss, ppl = estimate_perplexity()

        print(
            f"step {iter} | "
            f"val loss {val_loss:.3f} | "
            f"ppl {ppl:.1f}"
        )

        if val_loss < best_val_loss - MIN_LOSS_DELTA:
            best_val_loss = val_loss
            no_improve_count = 0
            torch.save(model.state_dict(), "stage1_best.pt")
            print("✓ New best model saved")
        else:
            no_improve_count += 1
            print(f"No improvement count: {no_improve_count}/{EARLY_STOP_PATIENCE}")

        if ppl <= TARGET_PPL:
            print(f"✓ Target perplexity {TARGET_PPL} reached. Stopping early.")
            break

        if no_improve_count >= EARLY_STOP_PATIENCE:
            print("✓ Early stopping triggered (plateau).")
            break

    # ===== training step =====
    xb, yb = get_batch("train")

    optimizer.zero_grad(set_to_none=True)

    with torch.amp.autocast("cuda", enabled=(device == "cuda")):
        _, loss = model(xb, yb)

    scaler.scale(loss).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

    # ===== speed logging =====
    if iter % LOG_EVERY == 0 and iter > 0:
        torch.cuda.synchronize()
        t_now = time.time()
        dt = t_now - t_last

        steps_per_sec = LOG_EVERY / dt
        secs_per_step = dt / LOG_EVERY

        print(
            f"[speed] {steps_per_sec:.2f} steps/s | "
            f"{secs_per_step:.3f} s/step"
        )

        t_last = t_now


[speed] 2.33 steps/s | 0.429 s/step
[speed] 2.54 steps/s | 0.394 s/step
[speed] 2.54 steps/s | 0.393 s/step
[speed] 2.52 steps/s | 0.396 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.55 steps/s | 0.392 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391 s/step
[speed] 2.56 steps/s | 0.391

In [12]:
stage1_language_test()



PROMPT: The capital city of France is
The capital city of France is Salvadal. References Cities in Salvador Salza is a municipality in the canton of Salvador in Switzerland. It is in the western part of Salvador. References Other websites Official Summer Paralympic Skarvadorg Municipality. References Other websites Rafael Rayettle Salvatos (Bluevalaria) is a

PROMPT: The history of science is
The history of science is an organization of the Natural History of Historic and Armenia. It was also called the "Repal of the Middle East" by Francisco. The newspaper was written by Middlesex and the Napalism in the 19th century. The name was written by Alexander Church. The book was published in 1847 by Frank Lua Melly Gustafs. The book

PROMPT: In recent years, machine learning has
In recent years, machine learning has a long telen place. When a person may be shift when the length of the ground and green dark is still away. A few weapon is used to make a baller. It is used to make a ground, be

In [13]:
# Inspect the first 30 token IDs
for i in range(30):
    piece = decode([i])
    print(f"{i:2d} -> {repr(piece)}")


 0 -> ' ⁇ '
 1 -> 't'
 2 -> 'he'
 3 -> 'an'
 4 -> 'in'
 5 -> 'er'
 6 -> 'on'
 7 -> 'a'
 8 -> 'the'
 9 -> 'es'
10 -> 'or'
11 -> 'is'
12 -> 'en'
13 -> 'ar'
14 -> 'o'
15 -> 'at'
16 -> 'w'
17 -> 'ed'
18 -> 's'
19 -> 'it'
20 -> 'al'
21 -> 'of'
22 -> 'in'
23 -> 'c'
24 -> 'ic'
25 -> 'and'
26 -> 'f'
27 -> 're'
28 -> 'b'
29 -> 'as'


In [14]:
@torch.no_grad()
def generate_text_english_test(
    prompt: str,
    max_new_tokens=100,
    temperature=0.6,
    top_k=40
):
    model.eval()

    if not prompt.startswith(" "):
        prompt = " " + prompt

    idx = torch.tensor([encode(prompt)], dtype=torch.long).to(device)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature

        # SOFT English bias: penalize non-ASCII-heavy tokens
        for i in range(logits.size(-1)):
            piece = decode([i])
            if any(ord(c) > 127 for c in piece):
                logits[:, i] -= 5.0   # strong penalty

        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return decode(idx[0].tolist()).lstrip()


In [15]:
for p in [
    "The capital city of France is",
    "The history of science is",
    "In recent years, machine learning has",
]:
    print("\nPROMPT:", p)
    print(generate_text_english_test(p, 120))



PROMPT: The capital city of France is
The capital city of France is Rio de la Loire-Roire, in the province of France. Communes in Loire-Atlantique Haut-Ruy is a commune of 5,59 people (1999). It is found in the region Pays de la Loire in the Saine department in the northwest of France. Communes in Sardinia Aux is a commune. It is found in the region Pays de la Loire in the Saine department in the northwest of France. Communes in Somme Saint-

PROMPT: The history of science is
The history of science is the first book written by the book of the book of the book. The book is the book of the book of the book, the book is a book written by the book by the book of the book of the book of the book. It is written by the book of the book, The book of the book by Robert. The book was written by Mattel. The book is published in the book by the book of the book, The book of the book by R. The books of the books of books, in the book, and the book of the books. The book of

PROMPT: In recent years