pip install torch
pip install datasets
pip install tokenizers
pip install wandb

In [1]:
# pyt
import torch as t
import torch.nn as nn
from torch.utils.data import DataLoader

# data pipeline
from datasets import load_dataset, DatasetDict, load_from_disk
from typing import cast
import math, random

# tokenization
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast

# logging
import os, argparse
import wandb


  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
hyper = {
    'vs': 2**13,
    'ly': 4,
    'hs': 768,
    'ah': 4,
    'cx': 512,
    'lr': 1e-4,
    'bs': 32,
    'ac': 4,
}

hyper = argparse.Namespace(**hyper)


In [3]:
dataset = cast(DatasetDict, load_dataset('skeskinen/TinyStories-Instruct-hf'))
dataset['train'].set_format(type='torch', columns=['text'])
dataset['train'].format['type']
dataset['validation'].set_format(type='torch', columns=['text'])
dataset['validation'].format['type']
print(dataset)

Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/skeskinen___parquet/skeskinen--TinyStories-Instruct-hf-1f9111cb77858404/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2476533
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 25028
    })
})


In [4]:
tok = Tokenizer(BPE())
tok.normalizer = Lowercase()
tok.pre_tokenizer = ByteLevel()
tok.decoder = ByteLevelDecoder()
tok.post_processor = TemplateProcessing(single='$0 <|endoftext|>', special_tokens=[('<|endoftext|>', 1)],)
tok.enable_truncation(max_length=hyper.cx)
tok.enable_padding(pad_token='<pad>', length=hyper.cx)
trainer = BpeTrainer(vocab_size=hyper.vs, initial_alphabet=ByteLevel.alphabet(), special_tokens=['<pad>', '<|endoftext|>', '\n','Words: ', 'Features: ', 'Random sentence: ', 'Summary: ', 'Story: '])

In [5]:
if os.path.isfile('tiny.json'): tok = Tokenizer.from_file('tiny.json')
else: tok.train_from_iterator(dataset['train']['text'], trainer=trainer); tok.save('tiny.json')

tok = PreTrainedTokenizerFast(tokenizer_object=tok)
tok.pad_token = 0


In [6]:
def tokenization(example):
    return tok(example['text'], truncation=True, max_length=hyper.cx, padding='max_length')

if os.path.exists('train_dataset') and os.path.exists('valid_dataset'):
    train = load_from_disk('train_dataset')
    valid = load_from_disk('valid_dataset')
else:
    train = dataset['train'].map(tokenization, batched=True)
    valid = dataset['validation'].map(tokenization, batched=True)
    train.save_to_disk('train_dataset')
    valid.save_to_disk('valid_dataset')


In [7]:
train.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
train.format['type']
valid.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
valid.format['type']

'torch'

In [8]:
trainl = DataLoader(train, batch_size=hyper.bs, shuffle=True)
validl = DataLoader(valid, batch_size=hyper.bs, shuffle=True)

In [9]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, dropout = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = t.arange(max_len).unsqueeze(1)
        div_term = t.exp(t.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = t.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = t.sin(position * div_term)
        pe[:, 0, 1::2] = t.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
    
class trans(nn.Module):
    def __init__(self):
        super().__init__()
        self.inbed = nn.Embedding(hyper.vs, hyper.hs)
        self.posit = PositionalEncoding(hyper.hs, hyper.cx)
        self.think = nn.TransformerEncoderLayer(d_model=hyper.hs, nhead=hyper.ah, dim_feedforward=hyper.hs*4, activation='gelu')
        self.thnkr = nn.TransformerEncoder(self.think, num_layers=hyper.ly)
        self.speak = nn.Linear(hyper.hs, hyper.vs)
        self.cmask= t.triu(t.ones(hyper.cx, hyper.cx) * float('-inf'), diagonal=1)
    def forward(self, x, pask=None):
        x = self.inbed(x) * (hyper.hs ** .5)
        x = self.posit(x)
        x = self.thnkr(x, is_causal=True, mask=pask if pask is not None else self.cmask)
        return self.speak(x)


In [10]:
storytell = trans().to('cuda')

print(f'There are {round((sum(p.numel() for p in storytell.parameters()) - hyper.vs*hyper.hs*2)/1e6, 1)} million parameters in the model, plus {round((hyper.vs*hyper.hs*2)/1e6, 1)} million embeddings parameters.')

There are 35.4 million parameters in the model, plus 12.6 million embeddings parameters.


In [11]:
optim = t.optim.Adam(storytell.parameters(), lr=hyper.lr)

In [12]:
lossf = nn.CrossEntropyLoss()

In [13]:
step = 0
for batch in trainl:
    step += 1
    seq = batch['input_ids'].to('cuda')
    out = storytell(seq)
    loss = lossf(t.flatten(out, end_dim=1), t.flatten(t.roll(seq, -1)))
    loss.backward()

    if (step % hyper.ac == 0) or (step + 1 == len(trainl)):
            optim.step()
            optim.zero_grad()
            print(f'Step {step} of {len(trainl)}: loss {loss.item()}')

    if step % 500 == 0:
        t.save(storytell.state_dict(), f'story_model_{step}.pt')
        t.save(optim.state_dict(), f'story_optim_{step}.pt')
    
        with t.no_grad():
            tloss = 0
            steps = 0
            storytell.eval()
            for batch in validl:
                seq = batch['input_ids'].to('cuda')
                out = storytell(seq)
                tloss += lossf(t.flatten(out, end_dim=1), t.flatten(t.roll(seq, -1))).item()
                steps += 1
            print(f'validation: loss {tloss/steps}')
            storytell.train()

Step 5 of 77392: loss 9.157788276672363
Step 9 of 77392: loss 6.0940351486206055
Step 13 of 77392: loss 4.536830425262451
Step 17 of 77392: loss 4.138949871063232
Step 21 of 77392: loss 4.101062297821045
Step 25 of 77392: loss 4.0288286209106445
Step 29 of 77392: loss 3.9099535942077637
Step 33 of 77392: loss 4.300177574157715
Step 37 of 77392: loss 3.7145438194274902
Step 41 of 77392: loss 3.786059856414795
Step 45 of 77392: loss 4.090689182281494
Step 49 of 77392: loss 3.7991576194763184
Step 53 of 77392: loss 3.516782522201538
Step 57 of 77392: loss 3.7490367889404297
Step 61 of 77392: loss 3.5259861946105957
Step 65 of 77392: loss 3.1983258724212646
Step 69 of 77392: loss 3.8269405364990234
Step 73 of 77392: loss 3.2255895137786865
Step 77 of 77392: loss 3.0665438175201416
Step 81 of 77392: loss 2.8591346740722656
Step 85 of 77392: loss 3.1582348346710205
Step 89 of 77392: loss 3.2537848949432373
Step 93 of 77392: loss 2.9473795890808105
Step 97 of 77392: loss 3.014868974685669
Ste

TypeError: 'DataLoader' object is not subscriptable