In [1]:
import json
import glob
import math
import os
import torch
import pathlib
import time
import tiktoken

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as torch_data
import torchvision as tv



In [2]:
class Dataset(torch_data.Dataset):
    def __init__(self, tokenizer, len_context):

        self.tokenizer = tiktoken.encoding_for_model(tokenizer)
        self.len_context = len_context
        self.elements = sorted(glob.glob(f"/mnt/data/wikipedia/tokens/{tokenizer}/{len_context}/*")) 

    def __len__(self):
        return len(self.elements)
    
    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tokens_filename = self.elements[idx]
        filename = tokens_filename.split("/")[-1]
        text_filename = os.path.join(f"/mnt/data/wikipedia/text", filename)
        
        with open(tokens_filename, "rb") as f:
            tokens = json.load(f)["tokens"]

        with open(text_filename, "rb") as f:
            text = json.load(f)["text"]

        tokens = tokens[:self.len_context + 1]
        tokens = torch.from_numpy(np.array(tokens)).to(torch.long)
        
        element = dict()
        element["text"] = text
        element["tokens"] = tokens

        return element

In [3]:
SHUFFLE = True
NUM_WORKERS = 16
DEVICE = "cuda"
LR = 1e-4
N_EPOCHS = 100
BATCH_SIZE = 32 + 8
BATCH_SUM = 32

TOKENIZER = "gpt2"
LEN_CONTEXT = 256
N_EMB = 2048
N_HEADS = 12
N_LAYERS = 7
DROPOUT = 0.1

FILE_PATH = "models/text_embedding.pth"

In [4]:
dataset = Dataset(TOKENIZER, LEN_CONTEXT)
VOCAB_SIZE = dataset.tokenizer.n_vocab
loader = torch_data.DataLoader(dataset, shuffle=SHUFFLE, pin_memory=True, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE)
print(len(dataset))

3735788


In [5]:
element = dataset[np.random.randint(0, len(dataset))]
tokens = element["tokens"]
print(tokens.shape)
batch = next(iter(loader))
tokens = batch["tokens"]
print(tokens.shape)

torch.Size([257])
torch.Size([40, 257])


In [6]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
class FeedFoward(nn.Module):
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    def __init__(self, n_head, n_embd, block_size, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size, dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class Model(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)

        shape = 0
        blocks = []
        for i in range(n_layer):
            blocks.append(Block(n_head, n_embd // 2**i, block_size, dropout))
            blocks.append(nn.AvgPool1d(kernel_size=2, stride=2))
            shape = n_embd // 2**(i + 1)
        self.down_blocks = nn.Sequential(*blocks)
        
        blocks = []
        for i in range(n_layer):
            blocks.append(Block(n_head, shape, block_size, dropout))
            blocks.append(nn.Upsample(scale_factor=2, mode='nearest'))
            shape = shape * 2
        self.up_blocks = nn.Sequential(*blocks)
        
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(idx.shape[1], device=idx.device))
        x = tok_emb + pos_emb 
        latent = self.down_blocks(x)
        x = self.up_blocks(latent)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits, latent

In [7]:
def save_checkpoint(epoch, batch, batch_acc, token_count, losses, ppls, model, optimizer, scaler, path):
    checkpoint = dict()
    checkpoint["epoch"] = epoch
    checkpoint["batch"] = batch
    checkpoint["batch_acc"] = batch_acc
    checkpoint["token_count"] = token_count
    checkpoint["model_state_dict"] = model.state_dict()
    checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    checkpoint["scaler_state_dict"] = scaler.state_dict()
    checkpoint["losses"] = losses
    checkpoint["ppls"] = ppls
    
    path = path.split(".")[0] + "-checkpoint"
    
    elements = [int(element.split("-")[-1].split(".")[0]) for element in glob.glob(f"{path}*")]
    max_idx = np.max(elements)
    new_idx = max_idx + 1

    new_path = f"{path}-{new_idx:05d}.pth"
    print()
    print(f"Save file at: {new_path} ...")
    torch.save(checkpoint, f"{path}-{new_idx:05d}.pth")
    time.sleep(1)
    print("... finished!")
    print()
def load_checkpoint(path=None):

    with torch.no_grad():
        model = Model(vocab_size=VOCAB_SIZE, n_embd=N_EMB, block_size=LEN_CONTEXT, n_head=N_HEADS, n_layer=N_LAYERS, dropout=DROPOUT).to(DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        
        if path is not None:
            checkpoint = torch.load(path)
            
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            scaler.load_state_dict(checkpoint["scaler_state_dict"])
        
            token_count = checkpoint["token_count"]
        
            losses = checkpoint["losses"]
            ppls = checkpoint["ppls"]
        
            n_epoch_ = checkpoint["epoch"]
            n_ = checkpoint["batch"] + 1
            bs_ = checkpoint["batch_acc"]
        else:
            token_count = 0
            losses = [[0, 0, 0, 0]]
            ppls = [[0, 0, 0, 0]]
            n_epoch_ = 0
            n_ = 0
            bs_ = 0

        # print("Compile model...")
        # model = torch.compile(model)
        # print("...done!")
        # print()
    
        return model, optimizer, scaler, token_count, losses, ppls, n_epoch_, n_, bs_  
def train(loader, path, n_save=25):

    model, optimizer, scaler, token_count, losses, ppls, n_epoch_, n_, bs_ = load_checkpoint(path)

    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000:.2f} Mio")
    print()
    
    model.train()
    
    iterator = iter(loader)
    N = int(np.floor(len(loader) / BATCH_SUM))
    
    for n_epoch in range(n_epoch_, N_EPOCHS):
        for n in range(n_, N):
            for bs in range(BATCH_SUM):
                
                batch = next(iterator)
            
                in_tokens = batch["tokens"][:, :-1].to(DEVICE)
                out_tokens = batch["tokens"][:, 1:].to(DEVICE)
                
                logits, latent = model(in_tokens)
                loss = F.cross_entropy(logits.view(out_tokens.shape[0] * out_tokens.shape[1], -1), out_tokens.view(-1))
                scaler.scale(loss).backward()

                losses.append((n_epoch, n, bs, loss.item()))
                ppls.append((n_epoch, n, bs, torch.exp(loss).item()))
                token_count += len(batch) * LEN_CONTEXT

                losses_ = np.asarray(losses)[-10000:, -1]
                ppls_ = np.asarray(ppls)[-10000:, -1]
                print(f"\r{n_epoch + 1:03d}|{N_EPOCHS}, {n + 1:04d}|{N}, {bs + 1:03d}|{BATCH_SUM}, loss: {losses_.mean():.5f}, ppl: {ppls_.mean():010.5f}, {token_count / 1_000_000:.5f} Mio tokens.", end="")

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            if (n + 1) % 25 == 0:
                print()

            if (n + 1) % n_save == 0:
                save_checkpoint(n_epoch, n, bs, token_count, losses, ppls, model, optimizer, scaler, FILE_PATH)

In [8]:
# with torch.no_grad():
#     model, optimizer, scaler, token_count, losses, ppls, n_epoch_, n_, bs_ = load_checkpoint(f"models/text_embedding-checkpoint-00000.pth")
#     model.eval()
#     batch = next(iter(loader))
#     in_tokens = batch["tokens"].to(DEVICE)
#     logits, latent = model(in_tokens)
#     print(logits.shape)
#     print(latent.shape)
#     logits.detach().cpu().numpy()
#     latent.detach().cpu().numpy()

In [9]:
train(loader, f"models/text_embedding-checkpoint-00000.pth", n_save=250)
# train(loader, None, n_save=1)

Number of parameters: 290.22 Mio

001|100, 0775|2918, 032|32, loss: 4.91671, ppl: 0138.06357, 12.69760 Mio tokens.
001|100, 0800|2918, 032|32, loss: 4.88770, ppl: 0134.02805, 13.10720 Mio tokens.
001|100, 0825|2918, 032|32, loss: 4.85912, ppl: 0130.18149, 13.51680 Mio tokens.
001|100, 0850|2918, 032|32, loss: 4.83248, ppl: 0126.71584, 13.92640 Mio tokens.
001|100, 0875|2918, 032|32, loss: 4.80753, ppl: 0123.54368, 14.33600 Mio tokens.
001|100, 0900|2918, 032|32, loss: 4.78333, ppl: 0120.55527, 14.74560 Mio tokens.
001|100, 0925|2918, 032|32, loss: 4.76140, ppl: 0117.90247, 15.15520 Mio tokens.
001|100, 0950|2918, 032|32, loss: 4.73935, ppl: 0115.29356, 15.56480 Mio tokens.
001|100, 0975|2918, 032|32, loss: 4.71858, ppl: 0112.90876, 15.97440 Mio tokens.
001|100, 1000|2918, 032|32, loss: 4.69830, ppl: 0110.62538, 16.38400 Mio tokens.

Save file at: models/text_embedding-checkpoint-00001.pth ...
... finished!

001|100, 1025|2918, 032|32, loss: 4.67830, ppl: 0108.41878, 16.79360 Mio tokens

KeyboardInterrupt: 