In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tokenizers import ByteLevelBPETokenizer
from tqdm import tqdm
from datasets import load_dataset

In [None]:
# Debug setting: per the CUDA docs, CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so a
# device-side error is reported at the Python line that caused it. It also slows every CUDA op.
# NOTE(review): remove (or comment out) for real training runs.
%env CUDA_LAUNCH_BLOCKING=1

In [None]:
class GPTConfig:
    """Plain container for the GPT model's hyper-parameters.

    Args:
        vocab_size: number of token ids (rows of the token embedding / width of the output head).
        d_model: embedding / residual-stream width.
        n_heads: attention heads per block (must divide ``d_model``).
        n_layers: number of transformer blocks.
        d_ff: hidden width of each block's feed-forward MLP.
        max_len: maximum sequence length (rows of the learned position table).
        dropout: dropout probability used inside the blocks.
    """

    def __init__(self, vocab_size=40000, d_model=512, n_heads=8, n_layers=24, d_ff=4096, max_len=512, dropout=0.05):
        # Attributes are assigned in the same order as the parameters.
        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_heads, self.n_layers = n_heads, n_layers
        self.d_ff, self.max_len, self.dropout = d_ff, max_len, dropout

CONFIG = GPTConfig()
SEQ_LEN = CONFIG.max_len         # dataset window length; the model sees SEQ_LEN - 1 input tokens
BATCH_SIZE = 8
GRAD_ACCUM = 5                   # micro-batches per optimizer step (effective batch = BATCH_SIZE * GRAD_ACCUM)
EPOCHS = 3
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = './checkpoints'
CHECKPOINT_STEPS = 12000         # save an intermediate checkpoint every this many micro-batches
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Seed torch's global RNG so weight init, dropout masks and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

In [None]:
TOKENIZER_DIR = './tokenizer'
VOCAB_PATH = os.path.join(TOKENIZER_DIR, 'vocab.json')
MERGES_PATH = os.path.join(TOKENIZER_DIR, 'merges.txt')

# Loading needs BOTH files, so train unless a complete saved tokenizer is on disk.
if not (os.path.exists(VOCAB_PATH) and os.path.exists(MERGES_PATH)):
    # Dump the WikiText-103 training text to a plain file for the BPE trainer.
    dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
    with open('wikitext103_train.txt', 'w', encoding='utf-8') as f:
        for line in dataset['text']:
            f.write(line + '\n')

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train(files=['wikitext103_train.txt'], vocab_size=CONFIG.vocab_size, min_frequency=2,
                    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"])
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    tokenizer.save_model(TOKENIZER_DIR)
else:
    tokenizer = ByteLevelBPETokenizer(VOCAB_PATH, MERGES_PATH)

In [None]:
class TextDataset(Dataset):
    """Next-token-prediction dataset built from overlapping token windows.

    Each text is tokenized and cut into windows of at most ``seq_len`` tokens that
    start every ``stride`` tokens. A window of n tokens yields an (input, label) pair
    of n-1 tokens each (labels shifted by one), right-padded to ``seq_len - 1``.

    Args:
        dataset_split: object exposing a ``'text'`` column of strings.
        tokenizer: object whose ``encode(text).ids`` returns a list of token ids.
        seq_len: window length in tokens (model inputs are ``seq_len - 1`` long).
        stride: step between window starts; defaults to ``seq_len // 2``.
        pad_id: id used to pad inputs; defaults to the tokenizer's ``<pad>`` id,
            or 0 if the tokenizer does not define one.

    Raises:
        ValueError: if ``seq_len < 2`` (no input/label pair fits in a window).
    """

    def __init__(self, dataset_split, tokenizer, seq_len=512, stride=None, pad_id=None):
        if seq_len < 2:
            raise ValueError(f"seq_len must be >= 2, got {seq_len}")
        self.examples = []
        self.seq_len = seq_len
        self.stride = stride if stride is not None else max(1, seq_len // 2)
        if pad_id is None:
            token_to_id = getattr(tokenizer, 'token_to_id', None)
            pad_id = token_to_id('<pad>') if token_to_id is not None else None
        self.pad_id = 0 if pad_id is None else pad_id

        for text in dataset_split['text']:
            ids = tokenizer.encode(text).ids
            for start in range(0, len(ids), self.stride):
                chunk = ids[start:start + seq_len]
                if len(chunk) < 2:
                    break
                self.examples.append(chunk)
                # This window already reaches the end of the text; any later start would
                # produce a window fully contained in it (duplicate training data).
                if start + seq_len >= len(ids):
                    break

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` as long tensors of length ``seq_len - 1``."""
        chunk = self.examples[idx]
        input_ids = chunk[:-1]
        labels = chunk[1:]
        pad_len = (self.seq_len - 1) - len(input_ids)
        if pad_len > 0:
            # Padding sits after the real tokens, so causal attention cannot leak it into
            # them; -100 labels are skipped by the loss (ignore_index=-100).
            input_ids = input_ids + [self.pad_id] * pad_len
            labels = labels + [-100] * pad_len

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

In [None]:
def collate_batch(batch):
    """Stack a list of ``(input_ids, labels)`` samples into two batched tensors.

    Every sample is expected to hold equally-shaped tensors at positions 0 and 1;
    the result is ``(inputs, labels)`` with a new leading batch dimension.
    """
    inputs = torch.stack([sample[0] for sample in batch])
    labels = torch.stack([sample[1] for sample in batch])
    return inputs, labels

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader

# Row-sliced subsets of WikiText-103.
train_split = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train[:650000]')

val_split = load_dataset('wikitext', 'wikitext-103-raw-v1', split='validation[:20000]')

train_ds = TextDataset(train_split, tokenizer, seq_len=SEQ_LEN)
val_ds   = TextDataset(val_split, tokenizer, seq_len=SEQ_LEN)

# Training: shuffle and keep every micro-batch the same size.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_batch
)

# Validation must see every example, so keep the final partial batch
# (items are all padded to SEQ_LEN - 1, so collate_batch can stack it).
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_batch
)

print(f"Train samples: {len(train_ds)}")
print(f"Val samples: {len(val_ds)}")

In [None]:
# Smaller subset for faster experiments.
# NOTE: this cell overwrites the splits, datasets and loaders built by the previous cell.
train_split = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train[:75000]')
val_split = load_dataset('wikitext', 'wikitext-103-raw-v1', split='validation[:5000]')
val_ds = TextDataset(val_split, tokenizer, seq_len=SEQ_LEN)
# Build from the 75k-row slice loaded above. (`dataset` from the tokenizer cell is the FULL
# train split and only exists when that cell had to train a tokenizer.)
train_ds = TextDataset(train_split, tokenizer, seq_len=SEQ_LEN)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted to the current and earlier positions.

    A single linear layer produces queries, keys and values. Scores are scaled by
    1/sqrt(head_dim), future positions are masked out, dropout is applied to the
    attention weights, and the concatenated heads are projected back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // self.n_heads
        # One projection emits q, k and v side by side (3 * d_model outputs).
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.scale = 1 / math.sqrt(self.head_dim)

    def forward(self, x):
        batch, seq, width = x.shape
        # (B, T, 3C) -> (B, T, 3, H, hd) -> (3, B, H, T, hd), then split into q, k, v.
        projected = self.qkv_proj(x).view(batch, seq, 3, self.n_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        keep = torch.tril(torch.ones(seq, seq, device=x.device))[None, None]
        scores = scores.masked_fill(keep == 0, -torch.inf)
        weights = self.dropout(scores.softmax(dim=-1))
        heads = weights @ v
        merged = heads.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.out_proj(merged)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies causal self-attention and then a GELU feed-forward MLP, each on a
    LayerNorm-ed copy of the stream and added back through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, config.d_ff),
            nn.GELU(),
            nn.Linear(config.d_ff, width),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``config.n_layers`` TransformerBlocks, then a final LayerNorm and a bias-free
    linear head that produces one logit per vocabulary entry.
    """

    def __init__(self, config):
        super().__init__()
        # Size of the learned position table; forward() rejects longer inputs.
        self.max_len = config.max_len
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_len, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``max_len`` (no position embedding exists for it).
        """
        _, T = idx.shape
        if T > self.max_len:
            # Fail with a readable message instead of an out-of-range embedding lookup
            # (a device-side assert on CUDA).
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}; crop the context before calling the model")
        tok = self.tok_emb(idx)                                  # (B, T, d_model)
        pos = self.pos_emb(torch.arange(T, device=idx.device))   # (T, d_model), broadcast over the batch
        x = tok + pos
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

# Move the weights to DEVICE up front: the training loop puts every batch on DEVICE.
model = GPT(CONFIG).to(DEVICE)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
# True division: floor-dividing by 1e6 would throw away the decimal shown by .1f.
print(f"trainable model params: {n_trainable / 1e6:.1f}M")
print(f"total model params: {n_total / 1e6:.1f}M")

In [None]:
# Parameters must live on the same device as the batches (moved to DEVICE below).
model.to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=2.5e-4)
# Mixed precision only on CUDA; with enabled=False autocast and the scaler are pass-throughs.
use_amp = DEVICE.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(EPOCHS):
    model.train()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    optimizer.zero_grad()
    for step, (x, y) in enumerate(pbar):
        x, y = x.to(DEVICE), y.to(DEVICE)
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            logits = model(x)
            # Divide by GRAD_ACCUM so the accumulated gradient is the mean over micro-batches.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=-100) / GRAD_ACCUM
        scaler.scale(loss).backward()
        if (step + 1) % GRAD_ACCUM == 0:
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            # Clip the global L2 norm to 1.0 (a norm_type of 0 would be a nonzero count, not a norm).
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pbar.set_postfix({'loss': loss.item() * GRAD_ACCUM})

        if (step + 1) % CHECKPOINT_STEPS == 0:
            ckpt_path = os.path.join(CHECKPOINT_DIR, f'gpt_epoch{epoch}_step{step+1}.pt')
            torch.save({'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(),
                        'scaler_state': scaler.state_dict(), 'epoch': epoch, 'step': step + 1}, ckpt_path)
            print(f'Checkpoint saved at step {step+1}: {ckpt_path}')

    ckpt_path = os.path.join(CHECKPOINT_DIR, f'gpt_epoch{epoch}.pt')
    torch.save({'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(),
                'scaler_state': scaler.state_dict(), 'epoch': epoch}, ckpt_path)
    print(f'Checkpoint saved: {ckpt_path}')

In [None]:
CHECKPOINT_DIR = './checkpoints'

# Show each saved checkpoint once, with its size; a missing directory just means nothing is saved yet.
files = sorted(os.listdir(CHECKPOINT_DIR)) if os.path.isdir(CHECKPOINT_DIR) else []
if not files:
    print(f'No checkpoints found in {CHECKPOINT_DIR}')
for f in files:
    path = os.path.join(CHECKPOINT_DIR, f)
    size_mb = os.path.getsize(path) / (1024*1024)
    print(f'{f} → {size_mb:.2f} MB')

In [None]:
import os
import shutil

CHECKPOINT_DIR = './checkpoints'
# Destructive: set to True only when you really want to discard every saved checkpoint.
# Defaulting to False keeps "Run All" from deleting the checkpoints the reload cells load.
RESET_CHECKPOINTS = False

if RESET_CHECKPOINTS:
    # rmtree raises on a missing folder, so only remove it when it exists.
    if os.path.isdir(CHECKPOINT_DIR):
        shutil.rmtree(CHECKPOINT_DIR)
    print(f'Removed all checkpoints in {CHECKPOINT_DIR}')
else:
    print(f'Skipping checkpoint reset (set RESET_CHECKPOINTS = True to wipe {CHECKPOINT_DIR})')

# Recreate the folder for new training (no-op if it still exists)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
MAX_NEW_TOKENS = 50

model.eval()
prompt = 'Cars need newer'
# Build the prompt on whatever device the weights are on, so this works wherever the model lives.
model_device = next(model.parameters()).device
input_ids = torch.tensor(tokenizer.encode(prompt).ids, dtype=torch.long, device=model_device).unsqueeze(0)
# Inference only: no autograd graph should accumulate across the sampling steps.
with torch.no_grad():
    for _ in range(MAX_NEW_TOKENS):
        # The position table has CONFIG.max_len rows, so feed at most the last max_len tokens.
        context = input_ids[:, -CONFIG.max_len:]
        logits = model(context)
        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_id = torch.multinomial(probs, 1)
        input_ids = torch.cat([input_ids, next_id], dim=1)
generated = tokenizer.decode(input_ids[0].tolist())
print(f"Response: {generated}\n")

In [None]:
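# Reload model: the cells below re-declare the config and model classes, then load weights from a saved checkpoint.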

In [None]:
import torch
import torch.nn as nn
import os
class GPTConfig:
    """Plain container for the GPT model's hyper-parameters (reload section).

    NOTE(review): ``n_layers`` defaults to 12 here, while the training cell's GPTConfig
    defaults to 24. A checkpoint saved with 24 blocks will not load into this model.

    Args:
        vocab_size: number of token ids.
        d_model: embedding / residual-stream width.
        n_heads: attention heads per block (must divide ``d_model``).
        n_layers: number of transformer blocks.
        d_ff: hidden width of each block's feed-forward MLP.
        max_len: maximum sequence length (rows of the learned position table).
        dropout: dropout probability used inside the blocks.
    """

    def __init__(self, vocab_size=40000, d_model=512, n_heads=8, n_layers=12, d_ff=4096, max_len=512, dropout=0.05):
        # Attributes are assigned in the same order as the parameters.
        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_heads, self.n_layers = n_heads, n_layers
        self.d_ff, self.max_len, self.dropout = d_ff, max_len, dropout

# Run-time settings re-declared for the reload section.
CONFIG = GPTConfig()
SEQ_LEN = CONFIG.max_len
BATCH_SIZE, GRAD_ACCUM, EPOCHS = 8, 4, 1
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
CHECKPOINT_DIR, CHECKPOINT_STEPS = './checkpoints', 12000
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted to the current and earlier positions.

    A single linear layer produces queries, keys and values. Scores are scaled by
    1/sqrt(head_dim), future positions are masked out, dropout is applied to the
    attention weights, and the concatenated heads are projected back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv_proj = nn.Linear(config.d_model, 3*config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        # head_dim ** -0.5 == 1/sqrt(head_dim), without needing `math`, which the reload
        # section's import cell does not import.
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, 3C) -> (3, B, H, T, head_dim)
        qkv = self.qkv_proj(x).view(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = torch.matmul(q, k.transpose(-2, -1))*self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)
        attn = attn.masked_fill(mask==0, -torch.inf)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies causal self-attention and then a GELU feed-forward MLP, each on a
    LayerNorm-ed copy of the stream and added back through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, config.d_ff),
            nn.GELU(),
            nn.Linear(config.d_ff, width),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``config.n_layers`` TransformerBlocks, then a final LayerNorm and a bias-free
    linear head that produces one logit per vocabulary entry.
    """

    def __init__(self, config):
        super().__init__()
        # Size of the learned position table; forward() rejects longer inputs.
        self.max_len = config.max_len
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_len, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``max_len`` (no position embedding exists for it).
        """
        _, T = idx.shape
        if T > self.max_len:
            # Fail with a readable message instead of an out-of-range embedding lookup
            # (a device-side assert on CUDA).
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}; crop the context before calling the model")
        tok = self.tok_emb(idx)                                  # (B, T, d_model)
        pos = self.pos_emb(torch.arange(T, device=idx.device))   # (T, d_model), broadcast over the batch
        x = tok + pos
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

model = GPT(CONFIG).to(DEVICE)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
# True division: floor-dividing by 1e6 would throw away the decimal shown by .1f.
print(f"trainable model params: {n_trainable / 1e6:.1f}M")
print(f"total model params: {n_total / 1e6:.1f}M")

In [None]:
# Load the most recently written .pt checkpoint from CHECKPOINT_DIR.
ckpt_files = [os.path.join(CHECKPOINT_DIR, name) for name in os.listdir(CHECKPOINT_DIR) if name.endswith('.pt')]
if not ckpt_files:
    raise FileNotFoundError(f'no .pt checkpoints found in {CHECKPOINT_DIR}')
ckpt_path = max(ckpt_files, key=os.path.getmtime)
# The training cell saves a dict; the weights live under 'model_state'.
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
# NOTE(review): CONFIG must match the architecture the checkpoint was trained with
# (this section's GPTConfig defaults to n_layers=12); strict loading raises on a mismatch.
model = GPT(CONFIG).to(DEVICE)
model.load_state_dict(checkpoint['model_state'])
model.eval()