In [45]:
import os, sys
import ipdb
from tqdm import tqdm
from datetime import datetime
import requests, zipfile, io

import torch
import torch.nn as nn
from torch.nn import functional as F

# tokenizer
import sentencepiece as spm

# allow TF32 kernels for matmul and cuDNN; these improve performance on the Ampere architecture
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# release cached GPU memory left over from earlier runs in this kernel
torch.cuda.empty_cache()

In [46]:
# files_url = "https://ideami.com/llm_train"
# print("Downloading dataset...")
# response = requests.get(files_url)
# zipfile.ZipFile(io.BytesIO(response.content)).extractall(".")

In [47]:
# architecture parameters
batch_size = 8  # sequences sampled per batch in get_batch
context = 512  # window length in tokens; also the size of the position-embedding table
embed_size = 384
n_layers = 7
# NOTE(review): Block computes head_size = embed_size // n_heads = 54, and 7 * 54 = 378 != 384,
# so the concatenated heads are narrower than the embedding — confirm this is intended.
n_heads = 7
BIAS = True  # passed as bias= to the nn.Linear layers

# hyperparameters
lr = 3e-4
dropout = 0.05
weight_decay = 0.01
grad_clip = 1.0  # max gradient norm used by clip_grad_norm_ in the training loop

# training parameters
train_iters = 100000
eval_interval = 50  # evaluate every this many training iterations
eval_iters = 10  # batches averaged per split in calculate_loss
# NOTE(review): this name shadows Python's builtin compile() for the rest of the notebook.
compile = True
checkpoint_dir = 'models/'
checkpoint_load_fn = 'latest.pt'
load_pretrained = True
dtype = torch.bfloat16  # the whole model is cast to this dtype before training

# MODE
# when True, the notebook enters the interactive prompt loop in the inference cell
inference = False

# DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)



cuda


In [48]:
# logging
# Weights & Biases experiment tracking; set wandb_log = False to run without it.
wandb_log = True
wandb_project = 'llm-from-scratch'
# wandb_run_name = 'llm1-' + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
wandb_run_name = 'llm-test-run'

if wandb_log:
    # imported here so that wandb is only needed when logging is enabled
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

In [49]:
# read the whole training corpus into memory as a single string
with open("wiki.txt", "r", encoding="utf-8") as f:
    text = f.read()

# peek at a 300-character slice to confirm the file was read correctly
print(text[10000:10300])

 that was used to represent a team in an old TV show, The A-Team. A capital a is written "A". Use a capital A at the start of a sentence if writing.

A is also a musical note, sometimes referred to as "La".

The letter 'A' was in the Phoenician alphabet's aleph. This symbol came from a simple pictur


In [50]:
# tokenizer
# load the SentencePiece model from disk
sp = spm.SentencePieceProcessor(model_file="wiki_tokenizer.model")

# number of pieces in the tokenizer; sizes the embedding table and the output layer of GPT
vocab_size = sp.get_piece_size()
print(f"Tokenizer vocab_size: {vocab_size}")

Tokenizer vocab_size: 4096


In [51]:
def encode(s):
    """Tokenize the string ``s`` with the SentencePiece processor ``sp``."""
    return sp.Encode(s)


def decode(l):
    """Convert the token ids ``l`` back into text with the SentencePiece processor ``sp``."""
    return sp.Decode(l)

# round-trip check: encode a sample sentence ("zdanie" is Polish for "sentence") and decode it back
zdanie = "niebo jest niebieskie"
print(encode(zdanie))
print(decode(encode(zdanie)))

[316, 428, 4052, 4037, 599, 395, 316, 428, 4052, 412, 4055, 428]
niebo jest niebieskie


In [52]:
# Reuse the cached token tensor when it exists; otherwise tokenize the full text once and cache it.
# NOTE(review): the cache is never invalidated here — delete encoded_data.pt if wiki.txt or the
# tokenizer model changes.
if os.path.exists("encoded_data.pt"):
    print("Loading encoding")
    data = torch.load("encoded_data.pt")
else:
    data = torch.tensor(encode(text), dtype=torch.long)
    torch.save(data, "encoded_data.pt")


Loading encoding


In [53]:
# contiguous split: the first 90% of the tokens are used for training, the last 10% for validation
data_size = len(data)
splt = int(0.9 * data_size)
train_data = data[:splt]
val_data = data[splt:]

print(f"Total data: {data_size / 1e6:.2f} Million | Training: {len(train_data) / 1e6:.2f} Million | Validation {len(val_data) / 1e6:.2f} Million ")

Total data: 59.21 Million | Training: 53.29 Million | Validation 5.92 Million 


In [54]:
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        A tuple ``(x, y)`` of tensors shaped (batch_size, context), moved to ``device``.
        ``y`` holds the same windows as ``x`` shifted one token to the right.
    """
    source = train_data if split == "train" else val_data
    # highest start is len(source) - context - 1, so the shifted target window still fits
    starts = torch.randint(len(source) - context, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start: start + context])
        targets.append(source[start + 1: start + context + 1])

    x = torch.stack(inputs).to(device)  # (batch_size, sequence_length)
    y = torch.stack(targets).to(device)
    return x, y

# smoke test: draw one batch; y should be x shifted left by one token
x, y = get_batch("train")
print(x.shape, y.shape)
print(x[0][:10])
print(y[0][:10])

torch.Size([8, 512]) torch.Size([8, 512])
tensor([  13,  764, 1674,  879,  266, 1836,  299,  264,  926, 1836],
       device='cuda:0')
tensor([ 764, 1674,  879,  266, 1836,  299,  264,  926, 1836,  280],
       device='cuda:0')


In [59]:
class GPT(nn.Module):
    """Language model: token + position embeddings, a final LayerNorm and a vocabulary projection.

    Reads the module-level settings ``vocab_size``, ``embed_size``, ``context`` and ``BIAS``.

    NOTE(review): the transformer block stack is still disabled (the ``self.blocks`` lines are
    commented out), so as written no attention is applied between embedding and output layer.
    """

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size) # vocab_size x embed_size
        self.positions = nn.Embedding(context, embed_size) # context x embed_size
        # self.blocks = nn.Sequential(*[Block(n_heads) for _ in range(n_layers)])
        self.layer_normalisation = nn.LayerNorm(embed_size)
        self.final_linear = nn.Linear(embed_size, vocab_size, bias=BIAS) # embed_size -> vocab_size
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero the Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            input: token ids shaped (BS, SL), with SL at most ``context``.
            targets: optional token ids shaped (BS, SL).

        Returns:
            ``(logits, loss)``. Without targets, logits are (BS, SL, vocab_size) and loss is
            None. With targets, logits are returned flattened to (BS * SL, vocab_size).
        """
        loss = None
        BS, SL = input.shape # BS x SL
        emb = self.embeddings(input) # BS x SL x embed_size
        # build the position ids on the same device as the input rather than the global
        # `device`, so the module also works after being moved elsewhere
        pos = self.positions(torch.arange(SL, device=input.device)) # SL x embed_size
        x = emb + pos # BS x SL x embed_size
        # x = self.blocks(x) # BS x SL x embed_size
        x = self.layer_normalisation(x) # BS x SL x embed_size
        logits = self.final_linear(x) # BS x SL x vocab_size

        if targets is not None:
            # cross_entropy expects (N, classes) logits and (N,) targets
            logits = logits.view(BS * SL, logits.size(-1))
            targets = targets.view(BS * SL)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, input, max=500):
        """Sample ``max`` new tokens and append them to ``input``.

        Args:
            input: prompt token ids shaped (batch, prompt_length).
            max: number of tokens to sample.

        Returns:
            The full sequence, shaped (batch, prompt_length + max). Only the last ``context``
            tokens are fed to the model at each step, but the returned tensor keeps everything.
        """
        for _ in range(max):
            # crop a separate window for the model; cropping `input` itself would drop the
            # beginning of the returned sequence once it grows past `context`
            window = input[:, -context:] # (batch, up to context)
            logits, _ = self(window) # (batch, window length, vocab_size)
            logits = logits[:, -1, :] # logits of the last position (batch, vocab_size)
            probs = F.softmax(logits, dim=-1) # (batch, vocab_size)
            next_token = torch.multinomial(probs, num_samples=1)
            input = torch.cat((input, next_token), dim=1)
        return input


In [60]:
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection.

    Args:
        n_heads: number of attention heads; each head gets ``embed_size // n_heads`` features.
    """

    def __init__(self, n_heads):
        super().__init__()
        head_size = embed_size // n_heads
        self.multi_attention = Multihead(n_heads, head_size)
        self.feed_forward = ForwardLayer(embed_size)
        self.ln1 = nn.LayerNorm(embed_size)
        self.ln2 = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return a tensor of the same shape."""
        # normalise x before attention; the original passed the LayerNorm module itself
        x = x + self.multi_attention(self.ln1(x))
        x = x + self.feed_forward(self.ln2(x))
        return x

In [61]:
class ForwardLayer(nn.Module):
    """Feed-forward sub-layer: widen to 6x ``embed_size``, GELU, project back, then dropout.

    Args:
        embed_size: width of the input and output features.
    """

    def __init__(self, embed_size):
        super().__init__()
        hidden_size = 6 * embed_size
        expand = nn.Linear(embed_size, hidden_size, bias=BIAS)
        project = nn.Linear(hidden_size, embed_size, bias=BIAS)
        self.network = nn.Sequential(expand, nn.GELU(), project, nn.Dropout(dropout))

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps size ``embed_size``."""
        return self.network(x)

In [62]:
class Multihead(nn.Module):
    """Run several attention heads in parallel and merge their outputs.

    Args:
        n_heads: number of ``Head`` modules to create.
        head_size: feature width produced by each head.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        self.combine = nn.Linear(head_size * n_heads, embed_size, bias=BIAS)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (BS, SL, embed_size) to (BS, SL, embed_size)."""
        # concatenate along the feature axis so `combine` receives head_size * n_heads
        # features; dim=1 would stack the heads along the sequence axis instead
        x = torch.cat([head(x) for head in self.heads], dim=-1) # (BS, SL, head_size * n_heads)
        x = self.combine(x) # (BS, SL, embed_size)
        x = self.dropout(x)
        return x


In [63]:
class Head(nn.Module):
    """Single causal self-attention head.

    Args:
        head_size: feature width of the query, key and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.queries = nn.Linear(embed_size, head_size, bias=BIAS)
        self.keys = nn.Linear(embed_size, head_size, bias=BIAS)
        self.values = nn.Linear(embed_size, head_size, bias=BIAS)

        # lower-triangular mask: position i may only attend to positions <= i
        self.register_buffer("tril", torch.tril(torch.ones(context, context)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (BS, SL, embed_size) to (BS, SL, head_size) with causal attention."""
        BS, SL, ES = x.shape # the last dim is the embedding width, not the vocabulary size
        q = self.queries(x) # BS x SL x head_size
        k = self.keys(x) # BS x SL x head_size
        v = self.values(x) # BS x SL x head_size

        # attention scores, scaled by 1/sqrt(head_size)
        attn_w = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # BS x SL x SL
        attn_w = attn_w.masked_fill(self.tril[:SL, :SL]==0, float('-inf'))
        attn_w = F.softmax(attn_w, dim=-1) # BS x SL x SL
        # the dropout layer was created in __init__ but never used; apply it to the
        # attention weights (it is a no-op in eval mode)
        attn_w = self.dropout(attn_w)

        x = attn_w @ v # BS x SL x head_size
        return x

In [64]:
# sanity check: one forward pass of an untrained model on a training batch
x, y = get_batch("train")
model = GPT()
# cast the parameters to the configured dtype first, then move the model to the device
model = model.to(dtype)
model = model.to(device)
logits, loss = model(x, y)
# for a freshly initialised model the loss should be close to ln(vocab_size)
print(loss.item())


8.375


In [None]:
@torch.no_grad()
def generate_sample(input):
    """Generate 64 tokens that continue the prompt ``input`` and print the decoded text.

    Uses the module-level ``model``, ``encode``, ``decode`` and ``device``. The model is put
    in eval mode while sampling and its previous train/eval mode is restored afterwards.

    Args:
        input: prompt string.
    """
    prompt = torch.tensor(encode(input), dtype=torch.long, device=device)
    prompt = prompt[None, :]  # add the batch dimension: (1, prompt_length)

    # sample in eval mode so dropout is off, then restore whatever mode the caller had
    was_training = model.training
    model.eval()
    try:
        newgen = model.generate(prompt, max=64)[0].tolist()
    finally:
        model.train(was_training)

    result = decode(newgen)
    print(f"{result}")

# generate_sample("Once upon a time")


Once upon a timeihnuptside pres heartake bookanioachesokeouri owneball based Loveecut� frog Atlantive got album l federal Sm Februarychestra Laounds crick supportamily years here canton eng beginningfessachusetts�minarch released industryivid tournament these live recordedc Siless womenTealand pointhic Walesrod Mont�ance


In [66]:
# TRAINING SETUP

# start from a fresh model; this replaces the instance created in the sanity-check cell
model = GPT()
model = model.to(dtype)
model = model.to(device)
if compile:
    print("Torch :: Compiling model")
    # torch.compile returns a wrapper module around the original model
    model = torch.compile(model)

# total number of parameters, in millions
print(sum(p.numel() for p in model.parameters()) / 1e6, " Million parameters")



Torch :: Compiling model
3.3472  Million parameters


In [67]:
@torch.no_grad()
def calculate_loss():
    """Estimate the mean loss on the training and validation data.

    For each split, averages the loss of ``eval_iters`` random batches. The model is put in
    eval mode for the measurement and switched back to train mode before returning.

    Returns:
        dict with keys "train" and "eval" mapping to the mean loss as a Python float.
    """
    averages = {}
    model.eval()

    # get_batch treats every name other than "train" as the validation split
    for split in ("train", "eval"):
        losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            xb, yb = get_batch(split)
            losses[step] = model(xb, yb)[1]
        averages[split] = losses.mean().item()

    model.train()
    return averages

# baseline before any optimisation step: estimated train/validation loss of the current model
l = calculate_loss(
)
print(l)



{'train': 8.375, 'eval': 8.375}


In [68]:
# setting up the optimizer
# Apply weight decay only to parameters with 2 or more dimensions (weight matrices and
# embedding tables); 1-D parameters (biases, LayerNorm weights) are not decayed.
parameter_dict = {p_name: p for p_name, p in model.named_parameters() if p.requires_grad}
weight_decay_p = [p for n, p in parameter_dict.items() if p.dim() >= 2]
no_weight_decay_p = [p for n, p in parameter_dict.items() if p.dim() < 2]
optimizer_groups = [
    {
        'params': weight_decay_p, 
        'weight_decay': weight_decay

    },
    {
        'params': no_weight_decay_p, 'weight_decay': 0.0
    }
]

optimizer = torch.optim.AdamW(optimizer_groups, lr=lr, betas=(0.9, 0.99))

# cosine decay from lr down to lr/10 over train_iters scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_iters, eta_min=lr/10)

# defaults for a fresh run; the checkpoint-loading cell overwrites both when resuming
start_iteration = 0
best_val_loss = float('inf')

In [69]:
# loading checkpoints

def load_checkpoint(path):
    """Restore model and optimizer state from the checkpoint file at ``path``.

    Updates the module-level ``model`` and ``optimizer`` in place. If the checkpoint holds a
    ``scheduler_state_dict`` entry, the module-level ``scheduler`` is restored as well.

    Args:
        path: location of a checkpoint written by the training loop.

    Returns:
        ``(iteration, loss)`` as stored in the checkpoint.
    """
    print("LLM - Loading model")
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    checkpoint = torch.load(path, map_location=device)

    # torch.compile wraps the model and prefixes its state-dict keys with "_orig_mod.".
    # Load into the underlying module with that prefix stripped, so the checkpoint works
    # whether or not `compile` had the same value when it was saved.
    prefix = "_orig_mod."
    target = getattr(model, "_orig_mod", model)
    model_state = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint['model_state_dict'].items()
    }
    target.load_state_dict(model_state)

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # optional key: checkpoints without scheduler state still load
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    iteration = checkpoint['iteration']
    loss = checkpoint['loss']
    print(f"Loaded iter {iteration} with loss {loss}")
    return iteration, loss

# Build the path once so the existence check and the load always refer to the same file
# (previously one used "dir/fn" and the other "dir" + "fn").
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_load_fn)
if load_pretrained and os.path.exists(checkpoint_path):
    # resume: continue from the saved iteration and keep its loss as the best so far
    start_iteration, loss = load_checkpoint(checkpoint_path)
    best_val_loss = loss

In [70]:
# inference
# Interactive prompt loop, entered only when the `inference` flag in the config cell is True:
# each non-empty line is used as a prompt for generate_sample; "q" leaves the loop.
if inference is True:
    model.eval()
    while True:
        qs = input("Enter text (q to quit):\n")
        if qs=="":
            continue
        if qs == "q":
            break
        generate_sample(qs)

In [71]:
# training loop

try:
    # make sure the checkpoint directory exists before the first save
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for i in tqdm(range(start_iteration, train_iters)):
        xb, yb = get_batch("train")
        logits, loss = model(xb, yb)

        # periodic evaluation, also run on the very last iteration
        if (i % eval_interval==0 or i == train_iters - 1):
            l=calculate_loss()
            print(f"\n{i}: train loss: {l['train']} | val loss: {l['eval']}")
            generate_sample("Once upon a time")

            # save a checkpoint only when the validation loss improved
            if l['eval'] < best_val_loss:
                best_val_loss = l['eval']
                print('[CHECKPOINT]: Saving with loss: ', best_val_loss)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # extra key so a resumed run can continue the learning-rate schedule
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': best_val_loss,
                    'iteration': i,

                }, os.path.join(checkpoint_dir, checkpoint_load_fn))

            # log every evaluation; this used to sit inside the checkpoint branch, so
            # evaluations without an improvement never reached wandb
            if wandb_log:
                wandb.log({
                    "loss/train": l["train"],
                    "loss/val": l["eval"],
                    "lr": scheduler.get_last_lr()[0],

                }, step = i)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # clip the global gradient norm before the optimizer step
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()
        scheduler.step()

except KeyboardInterrupt:
    print("Training interrupted. Cleaning up...")

finally:
    # close the wandb run on every exit path, including a manual interrupt
    if wandb_log:
        wandb.finish()
    torch.cuda.empty_cache()
    print("GPU memory released")
    

  0%|          | 0/100000 [00:00<?, ?it/s]


0: train loss: 8.375 | val loss: 8.375
Once upon a time conduct Tor Britainundred To Associationisc places). seen en designed inter Pop couriv Gl Budd involency lawyer stay because taking y party neigh universityulalishedestival ordercol whichalthreld countries weremaix couldove Europeanuctiter childrenufros and Arab chang Jackagob movie dataumentsuments Ge Franceof god
[CHECKPOINT]: Saving with loss:  8.375


  0%|          | 50/100000 [00:32<15:54:05,  1.75it/s]


50: train loss: 6.962500095367432 | val loss: 6.915625095367432


  0%|          | 51/100000 [00:37<53:27:32,  1.93s/it]

Once upon a timeemb returnous early et Maxokurp him Song i collegeimaovie eas year data baseract eng dem react the Ju time energyways House Philardsower information Class clatedation v frelf killedence modernthingated snlandsperadu Rich sett fam and6 Tur Mr Gu ItWation Earthized mediaaster
[CHECKPOINT]: Saving with loss:  6.915625095367432


  0%|          | 100/100000 [00:58<12:05:00,  2.30it/s]


100: train loss: 5.884375095367432 | val loss: 5.787499904632568


  0%|          | 101/100000 [01:00<31:20:25,  1.13s/it]

Once upon a time, and. It is governor thous designa kD lineatch-'se) all treidd An times that thearck, 3 ( is does to "P Leide Johnson caused December 19) is a Hockey aky It isinn rot is roform, the sameors6
[CHECKPOINT]: Saving with loss:  5.787499904632568


  0%|          | 138/100000 [01:15<15:15:08,  1.82it/s]


Training interrupted. Cleaning up...
GPU memory released
Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7fd7d112da90>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fd84183f770, execution_count=71 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7fd8407039d0, raw_cell="# training loop

try:
    for i in tqdm(range(star.." transformed_cell="# training loop

try:
    for i in tqdm(range(star.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/maciej/workspace/LLMUdemy/main.ipynb#X31sZmlsZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost