In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm

In [2]:
from datasets import load_dataset

# dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
dataset = load_dataset("starhopp3r/TinyChat")
# This gives you cleaned, plain text articles1
print(dataset['train'][100]['text'][:500])  # Print the first 500 characters of the first article
print(dataset['train'][600000])

  from .autonotebook import tqdm as notebook_tqdm


[INST] Hello, I feel a bit sad today because things seem hard to understand and move through. [/INST] I understand how you feel; sometimes life can be heavy like a thick substance we cannot lift. [INST] Yes, it can be very difficult, especially for young people trying to find their way. [/INST] Young minds often carry many questions that can weigh them down with worries and doubts. [INST] Sometimes, I wish everything would get better and we could all feel lighter again. [/INST] Hoping for better
{'text': "[INST] Do you think the disease spreading in the city is really as bad as it seems? [/INST] It does seem very clear that many people are crying over the current situation. [INST] Yes, I feel disgusted by how quickly it is spreading without control or care. [/INST] It makes me feel unwell just to think about how people's lives are affected deeply. [INST] I can’t believe some people ignore the danger and spread the disease even more. [/INST] That kind of behavior is truly unhelpful and 

In [3]:
base_encoding = tiktoken.get_encoding("gpt2")

special_tokens = {
    "[INST]": base_encoding.n_vocab,       # next available token id
    "[/INST]": base_encoding.n_vocab + 1
}

# 3. Create a new encoding that merges GPT‑2’s tokens + your special tokens
tokenizer = tiktoken.Encoding(
    name="gpt2_with_inst",
    pat_str=base_encoding._pat_str,
    mergeable_ranks=base_encoding._mergeable_ranks,
    special_tokens={**base_encoding._special_tokens, **special_tokens},
)

def encode(text):
    return tokenizer.encode(text, allowed_special={"[INST]", "[/INST]"})

def decode(tokens):
    return tokenizer.decode(tokens)
  
print(encode("[INST] Hello, world! [/INST]"))
print(decode(encode("[INST] Hello, world! [/INST]")))

[50257, 18435, 11, 995, 0, 220, 50258]
[INST] Hello, world! [/INST]


In [4]:

class TextDataset(Dataset):
    def __init__(self, hf_dataset, block_size):
        self.dataset = hf_dataset
        # self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset['train'])

    def __getitem__(self, idx):
        # Start with a random index sample
        rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
        text = self.dataset['train'][rand_idx]['text']
        tokens = encode(text)

        # Keep appending more samples if too short
        while len(tokens) < self.block_size + 1:
            next_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
            next_text = self.dataset['train'][next_idx]['text']
            tokens.extend(encode(" " + next_text))
            # Prevent runaway growth
            if len(tokens) > self.block_size * 2:
                break

        # Truncate to block_size + 1
        tokens = torch.tensor(tokens[: self.block_size + 1])

        x = tokens[: self.block_size]
        y = tokens[1 : self.block_size + 1]
        return x.long(), y.long()

In [5]:
#hyperparameters
train_model = False
block_size = 128
n_layers = 8
n_heads = 8
dropout_p = 0.1
batch_size =8
learning_rate = 3e-4
n_embedding = 128
max_iters = 5000
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [6]:
# tokenizer = tiktoken.get_encoding("gpt2")

train_dataset = TextDataset(dataset, block_size=block_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)

In [7]:
class GPTModel(nn.Module):
    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
        super(GPTModel, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embedding)
        self.position_embedding = nn.Embedding(block_size, n_embedding)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embedding)
        self.head = nn.Linear(n_embedding, vocab_size)
        self.dropout = nn.Dropout(dropout_p)
        self.block_size = block_size

    def forward(self, x):
        bsz, seq_len = x.size()
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
        x = self.token_embedding(x) + self.position_embedding(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

In [8]:
# define objects
vocab_size = tokenizer.n_vocab

model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [9]:


# training loop
torch.set_float32_matmul_precision('high')
scaler = GradScaler(device)
if train_model:
    compiled_model = torch.compile(model)

    pbar = tqdm(range(max_iters), desc="Training", ncols=100)
    data_iter = iter(train_dataloader)

    for count in pbar:
        # xb, yb = next(data_iter)

        try:
            xb, yb = next(data_iter)
        except StopIteration:
            # dataloader exhausted — restart it
            data_iter = iter(train_dataloader)
            xb, yb = next(data_iter)
        if count%100 == 0:
          # print out xb, yb, encoded too
          print('xb decoded: ', decode(xb[0].tolist())) 
          print('yb decoded: ', decode(yb[0].tolist())) 

        # except StopIteration:
        #     break  # dataloader exhausted before max_iters
        
        xb, yb = xb.to(device), yb.to(device)
        # logits = compiled_model(xb)
        # loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        with autocast(device, dtype=torch.float16):
            logits = compiled_model(xb)
            loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

        # backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update bar text dynamically
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

  _C._set_float32_matmul_precision(precision)


In [10]:
if train_model:
  torch.save(model.state_dict(), "checkpoints/gpt_model-1.pth")
else:
  model.load_state_dict(torch.load("checkpoints/gpt_model-1.pth"))

In [12]:
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
    model.eval()
    # Encode the prompt text into token IDs
    tokens = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        # Only keep the last block_size tokens for context
        input_tokens = tokens[:, -block_size:]

        # Get logits and take the last token’s distribution
        logits = model(input_tokens)
        logits = logits[:, -1, :]  # (batch=1, vocab)
        probs = F.softmax(logits, dim=-1)

        # Sample from the distribution
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)

    # Decode back into text
    output_text = tokenizer.decode(tokens[0].tolist())
    return output_text
  
# print model parameters
print (f"Model has {sum(p.numel() for p in model.parameters())/1000000} million parameters.")
prompt = "how are you doing today? [/INST]"
print(generate_text(model, tokenizer, prompt, max_new_tokens=500, block_size=block_size, device=device))

Model has 17.677395 million parameters.
how are you doing today? [/INST] the the mess really this open today hours someday [/INST] I do It is unpleasant when uncle levels in a balance to share. [INST] Do you feel a complete like blo inside outside, and it feels, like it at [/INST] That would [/INST] keep any little is an away around product. [/INST][INST] I find it feels very due[INST] I agree up [/INST] and not real we making is a so chaotic. [/INST] and hopefully ch, sometimes that we should. [INST] Do[INST] I will believe distribution in our hard it will be heard a better today to take to remind Hello, taking the better today [/INST] Reflect hit up them go for everyone with creativity. [/INST][INST], that[/INST] It about that weighing scared with the happiness en provide. [/INST], reflecting storms can help them onesing. [/INST] Indeed a walk might. [/INST] Definitely make better when reminders will have[/INST] Reflect truly bright go a lovely probably return made. [/INST] St you, e