In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler

In [2]:
from datasets import load_dataset

# dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
dataset = load_dataset("Bingsu/openwebtext_20p")
ds = load_dataset("starhopp3r/TinyChat", split="train")
# This gives you cleaned, plain text articles1
print(dataset['train'][100]['text'][:500])  # Print the first 500 characters of the first article
print(dataset['train'][600000])

  from .autonotebook import tqdm as notebook_tqdm


{'text': "Canonical, keeper of the Ubuntu Linux distribution, is a small company with big friends. The latest example: Dell, IBM and Intel each are taking new steps with Ubuntu. Here's the scoop."}


In [3]:
class TextDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, block_size):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset['train'])

    # def __getitem__(self, idx):
    #     tokens = self.tokenizer.encode(self.dataset['train'][idx]['text'])
    #     if len(tokens) < self.block_size + 1:
    #         tokens = F.pad(torch.tensor(tokens), (0, self.block_size + 1 - len(tokens)), value=0)
    #     else:
    #         tokens = torch.tensor(tokens[: self.block_size + 1])
    #     x = tokens[: self.block_size]
    #     y = tokens[1 : self.block_size + 1]
    #     return x.long(), y.long()
    def __getitem__(self, idx):
        # choose a random index instead of using the passed idx
        rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
        tokens = self.tokenizer.encode(self.dataset['train'][rand_idx]['text'])

        if len(tokens) < self.block_size + 1:
            tokens = F.pad(torch.tensor(tokens), (0, self.block_size + 1 - len(tokens)), value=0)
        else:
            tokens = torch.tensor(tokens[: self.block_size + 1])

        x = tokens[: self.block_size]
        y = tokens[1 : self.block_size + 1]
        return x.long(), y.long()

import torch
from torch.utils.data import Dataset
from datasets import load_dataset
import re

class ChatDataset(Dataset):
    def __init__(self, tokenizer, split="train", block_size=256, dataset_name="starhopp3r/TinyChat"):
        """
        Args:
            tokenizer: a tokenizer (e.g., tiktoken or Hugging Face tokenizer)
            split: dataset split ("train" etc)
            block_size: maximum sequence length
            dataset_name: path/name of the Hugging Face dataset
        """
        self.dataset = load_dataset(dataset_name, split=split)
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        text = sample["text"]

        # --- split into prompt and response (TinyChat uses [INST] ... [/INST]) ---
        match = re.search(r"\[INST\](.*?)\[/INST\](.*)", text, re.DOTALL)
        if match:
            instruction = match.group(1).strip()
            response = match.group(2).strip()
        else:
            instruction = text.strip()
            response = ""

        # Combine into a training sequence
        combined_text = f"<inst> {instruction} </inst> {response}"

        # Tokenize (truncate/pad to block_size + 1)
        tokens = torch.tensor(self.tokenizer.encode(combined_text), dtype=torch.long)
        if len(tokens) < self.block_size + 1:
            pad_len = self.block_size + 1 - len(tokens)
            tokens = F.pad(tokens, (0, pad_len), value=0)
        else:
            tokens = tokens[: self.block_size + 1]

        x = tokens[:-1]
        y = tokens[1:]

        return x, y

In [4]:
#hyperparameters
train_model = True
compile_model = True
block_size = 256
n_layers = 32
n_heads = 16
dropout_p = 0.1
batch_size =16
learning_rate = 3e-4
n_embedding = 512
max_iters = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
tokenizer = tiktoken.get_encoding("gpt2")

train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

chat_dataset = ChatDataset(tokenizer, split="train", block_size=block_size)
chat_dataloader = DataLoader(chat_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [6]:
class GPTModel(nn.Module):
    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
        super(GPTModel, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embedding)
        self.position_embedding = nn.Embedding(block_size, n_embedding)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embedding)
        self.head = nn.Linear(n_embedding, vocab_size)
        self.dropout = nn.Dropout(dropout_p)
        self.block_size = block_size

    def forward(self, x):
        bsz, seq_len = x.size()
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
        x = self.token_embedding(x) + self.position_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

In [7]:
# define objects
vocab_size = tokenizer.n_vocab

model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [8]:

from tqdm import tqdm

# training loop
scaler = GradScaler(device)
if train_model:
    if compile_model:
      compiled_model = torch.compile(model)
      torch.set_float32_matmul_precision('high')
    else:
      compiled_model = model

    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)
    chat_data_iter = iter(chat_dataloader)

    for count in pbar:
        try:
            if count %2 ==0:
              xb, yb = next(chat_data_iter)
            else:
              xb, yb = next(data_iter)
        except StopIteration:
            break  # dataloader exhausted before max_iters
        
        xb, yb = xb.to(device), yb.to(device)
        # logits = compiled_model(xb)
        # loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update bar text dynamically
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

  _C._set_float32_matmul_precision(precision)
Training: 100%|████████████████████████████████████| 1000/1000 [05:06<00:00,  3.26it/s, loss=1.5960]


In [9]:
if train_model:
  torch.save(model.state_dict(), "checkpoints/gpt_model-2.pth")
else:
  model.load_state_dict(torch.load("checkpoints/gpt_model-2.pth"))

In [10]:
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
    model.eval()
    # Encode the prompt text into token IDs
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        # Only keep the last block_size tokens for context
        input_tokens = tokens[:, -block_size:]

        # Get logits and take the last token’s distribution
        logits = model(input_tokens)
        logits = logits[:, -1, :]  # (batch=1, vocab)
        probs = F.softmax(logits, dim=-1)

        # Sample from the distribution
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)

    # Decode back into text
    output_text = tokenizer.decode(tokens[0].tolist())
    return output_text
  
prompt = "me when the "
print(generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device))

me when the .!!!!!! understand!!!!] cold! especially characters!! used soon!!!!! world! Exactly]-INST!!! choices! feel! spread!! a!! impact]inst saw them


In [11]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load TinyChat dataset

# --- your existing model/tokenizer here ----
# model = GPTModel(...)
# tokenizer = ...
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# reward: +1 if both <inst> and </inst> present, else 0
def reward_fn(text):
    return 1.0 if "[INST]" in text and "[/INST]" in text else 0.0

# wrap your existing generator to also compute logprobs
def generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens):
    model.eval()
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    logprob_sum = torch.tensor(0.0, device=device)

    for _ in range(max_new_tokens):
        input_tokens = tokens[:, -block_size:]
        logits = model(input_tokens)
        logits = logits[:, -1, :]  # (1, vocab)
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # logprob_sum += torch.log(probs.gather(1, next_token) + 1e-8)
        logprob_sum = logprob_sum + torch.log(probs.gather(1, next_token) + 1e-8).squeeze()
        tokens = torch.cat([tokens, next_token], dim=1)

    text = tokenizer.decode(tokens[0].tolist())
    return text, logprob_sum

# --- RL loop ---
num_steps = 500  # small demo
block_size = 128
max_new_tokens = 50

for step in tqdm(range(num_steps)):
    # 1. Pick a random row from TinyChat
    sample = ds[step % len(ds)]
    prompt = sample.get("prompt") or sample.get("input") or "Hello:"

    # 2. Generate text and token logprobs
    text, logprob_sum = generate_with_logprobs(model, tokenizer, prompt, block_size, max_new_tokens)

    # 3. Compute reward
    r = torch.tensor(reward_fn(text), device=device)

    # 4. Policy loss (REINFORCE)
    loss = -r * logprob_sum

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(f"\nStep {step}: reward={r.item():.2f}\nGenerated:\n{text[:200]}\n")

  0%|          | 1/500 [00:01<12:16,  1.48s/it]


Step 0: reward=0.00
Generated:
Hello:!!!!!!! that [ it everywhere [ </!!! impact!! not!!!!!!! past! un!,. to especially! explanation now! colorful, more!>!!!]! [/



 10%|▉         | 49/500 [01:09<10:43,  1.43s/it]


KeyboardInterrupt: 