```
pip3 install -U torch --index-url https://download.pytorch.org/whl/cu118
pip install wandb==0.14.0
wandb login e5292edda95a11630042fdf943d60d2bbf749fcf
pip install datasets
pip install tokenizers
pip install transformers
```

In [1]:
# pyt
import torch as t
import torch.nn as nn
from torch.utils.data import DataLoader

# data pipeline
from datasets import load_dataset, DatasetDict, load_from_disk
from typing import cast
import math, random

# tokenization
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast

# logging
import os, argparse
import wandb


In [2]:
t.set_default_device('mps')

In [3]:
hyper = {
    'vs': 2**13,
    'ly': 4,
    'hs': 768,
    'ah': 4,
    'cx': 512,
    'lr': 1e-4,
    'bs': 256,
    'ac': 4,
    'ep': 10,
}

hyper = argparse.Namespace(**hyper)


In [7]:
dataset = cast(DatasetDict, load_dataset('skeskinen/TinyStories-Instruct-hf'))
dataset['train'].set_format(type='torch', columns=['text'])
dataset['train'].format['type']
dataset['validation'].set_format(type='torch', columns=['text'])
dataset['validation'].format['type']
print(dataset)

Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/skeskinen___parquet/skeskinen--TinyStories-Instruct-hf-1f9111cb77858404/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2476533
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 25028
    })
})


In [14]:
tok = Tokenizer(BPE())
tok.normalizer = Lowercase()
tok.pre_tokenizer = ByteLevel()
tok.decoder = ByteLevelDecoder()
tok.post_processor = TemplateProcessing(single='$0 <|endoftext|>', special_tokens=[('<|endoftext|>', 1)],)
tok.enable_truncation(max_length=hyper.cx)
tok.enable_padding(pad_token='<pad>', length=hyper.cx)
trainer = BpeTrainer(vocab_size=hyper.vs, initial_alphabet=ByteLevel.alphabet(), special_tokens=['<pad>', '<|endoftext|>', '\n','Words: ', 'Features: ', 'Random sentence: ', 'Summary: ', 'Story: '])

In [15]:
if os.path.isfile('tiny.json'): tok = Tokenizer.from_file('tiny.json')
else: tok.train_from_iterator(dataset['train']['text'], trainer=trainer); tok.save('tiny.json')

tok = PreTrainedTokenizerFast(tokenizer_object=tok)
tok.pad_token = 0


In [4]:
def tokenization(example):
    return tok(example['text'], truncation=True, max_length=hyper.cx, padding='max_length')

if os.path.exists('train_dataset') and os.path.exists('valid_dataset'):
    train = load_from_disk('train_dataset')
    valid = load_from_disk('valid_dataset')
else:
    train = dataset['train'].map(tokenization, batched=True, batch_size=8192, writer_batch_size=8192)
    valid = dataset['validation'].map(tokenization, batched=True, batch_size=8192, writer_batch_size=8192)
    train.save_to_disk('train_dataset')
    valid.save_to_disk('valid_dataset')


In [5]:
train.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
train.format['type']
valid.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
valid.format['type']

'torch'

In [6]:
trainl = DataLoader(train, batch_size=hyper.bs, shuffle=True, drop_last=True)
validl = DataLoader(valid, batch_size=hyper.bs, shuffle=True, drop_last=True)

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, dropout = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = t.arange(max_len).unsqueeze(1)
        div_term = t.exp(t.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = t.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = t.sin(position * div_term)
        pe[:, 0, 1::2] = t.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
    
class trans(nn.Module):
    def __init__(self):
        super().__init__()
        self.inbed = nn.Embedding(hyper.vs, hyper.hs)
        self.posit = PositionalEncoding(hyper.hs, hyper.cx)
        self.think = nn.TransformerEncoderLayer(d_model=hyper.hs, nhead=hyper.ah, dim_feedforward=hyper.hs*4, activation='gelu')
        self.thnkr = nn.TransformerEncoder(self.think, num_layers=hyper.ly)
        self.speak = nn.Linear(hyper.hs, hyper.vs)
        self.cmask= t.triu(t.ones(hyper.cx, hyper.cx) * float('-inf'), diagonal=1)
    def forward(self, x):
        x = self.inbed(x) * (hyper.hs ** .5)
        x = self.posit(x)
        x = self.thnkr(x, mask=self.cmask, is_causal=True)
        return self.speak(x)


In [14]:
# run = wandb.init(
#     project="tinystories",
#     config={
#         "learning_rate": hyper.lr,
#         "epochs": 10,
#     })

In [8]:
storytell = trans()

print(f'There are {round((sum(p.numel() for p in storytell.parameters()) - hyper.vs*hyper.hs*2)/1e6, 1)} million parameters in the model, plus {round((hyper.vs*hyper.hs*2)/1e6, 1)} million embeddings parameters.')

There are 35.4 million parameters in the model, plus 12.6 million embeddings parameters.


In [16]:
# wandb.watch(storytell, log_freq=100)

In [9]:
optim = t.optim.Adam(storytell.parameters(), lr=hyper.lr)

In [10]:
lossf = nn.CrossEntropyLoss()

In [12]:
optim.load_state_dict(t.load('story_optim_6400.pt', map_location='mps'))

In [29]:
storytell.load_state_dict(t.load('story_model_18600.pt', map_location='mps'))

<All keys matched successfully>

In [21]:
step = 0
for epoch in range(hyper.ep):
    for batch in trainl:
        seq = batch['input_ids'].to(device)
        out = storytell(seq)
        loss = lossf(t.flatten(out, end_dim=1), t.flatten(t.roll(seq, -1)))
        loss.backward()

        optim.step()
        optim.zero_grad()

        if step % 10 == 0:
            print(f'Step {step} of {len(trainl)}: loss {loss.item()}')

        if step % 100 == 0:
            t.save(storytell.state_dict(), f'story_model_{step}.pt')
            t.save(optim.state_dict(), f'story_optim_{step}.pt')
        
            with t.no_grad():
                tloss = 0
                steps = 0
                storytell.eval()
                for batch in validl:
                    seq = batch['input_ids'].to(device)
                    out = storytell(seq)
                    tloss += lossf(t.flatten(out, end_dim=1), t.flatten(t.roll(seq, -1))).item()
                    steps += 1
                print(f'validation: loss {tloss/steps}')
                storytell.train()
        step += 1

Step 0 of 9673: loss 1.7828541994094849
validation: loss 1.7633733073460687
Step 10 of 9673: loss 1.7870320081710815
Step 20 of 9673: loss 1.7119168043136597
Step 30 of 9673: loss 1.7623684406280518
Step 40 of 9673: loss 1.7257035970687866
Step 50 of 9673: loss 1.8138842582702637
Step 60 of 9673: loss 1.8084945678710938
Step 70 of 9673: loss 1.7809900045394897
Step 80 of 9673: loss 1.8043197393417358
Step 90 of 9673: loss 1.8057655096054077
Step 100 of 9673: loss 1.8252079486846924
validation: loss 1.7632240179887753
Step 110 of 9673: loss 1.76325523853302
Step 120 of 9673: loss 1.7604920864105225
Step 130 of 9673: loss 1.816753625869751
Step 140 of 9673: loss 1.8630592823028564
Step 150 of 9673: loss 1.772922396659851
Step 160 of 9673: loss 1.7300993204116821
Step 170 of 9673: loss 1.7685534954071045
Step 180 of 9673: loss 1.736989974975586
Step 190 of 9673: loss 1.7772109508514404
Step 200 of 9673: loss 1.836031198501587
validation: loss 1.7626174479415737
Step 210 of 9673: loss 1.75

In [22]:

# idx = random.randint(0, len(valid) - 1)
# print(idx)

# print(tok.decode(valid['input_ids'][idx]))

# print('model gen:')
# print(tok.decode(storytell(valid['input_ids'][idx].unsqueeze(0).to(t.long)).argmax(dim=-1)[0]))

In [32]:
def generate_text(prompt, model, tokenizer, temperature=1.0, max_len=512):
    model.eval()
    with t.no_grad():
        input_ids = tokenizer.encode(prompt, return_tensors='pt')[:, :-1]
        print(input_ids.shape)
        cur_len = input_ids.shape[1]
        while cur_len < max_len:
            outputs = model(input_ids)
            next_token_logits = outputs[0][-1, :] / temperature
            # next_token_logits[:5] = -float('inf')
            next_token_id = t.multinomial(t.softmax(next_token_logits, dim=-1), num_samples=1).unsqueeze(-1)
            input_ids = t.cat([input_ids, next_token_id], dim=1)
            cur_len += 1
            if next_token_id[0][0] == tokenizer.eos_token_id:
                break
        return tokenizer.decode(input_ids.squeeze()[55:], skip_special_tokens=False)
    
idx = random.randint(0, len(valid) - 1)
prompt = tok.decode(valid['input_ids'][idx][:60])
print(prompt)
print(generate_text(prompt, storytell, tok, temperature=.7, max_len=200))

Features:  dialogue
Words:  prevent, worry, calm
Story:  once upon a time, in a calm little town, there lived a boy named tom. tom had a big worry. he was scared of the dark. every night, when it was time to sleep, tom would cry.

 one day
torch.Size([1, 60])
.

 one day, "wow!" tom was afraid and have to his mom's okay, and said, "ok, "hello, there was a way, he would be careful with her friend, lily's mom and the butterfly inside the dog kept walking through the train arrived at the yard.

Words:  attach, "why are you!








 tom says, lived in the empty


Story: 

Story:  once upon a little girl named lily finds a big hug their mouths.





Story:  once upon a time, but she got really fun. he couldn't worry, but then, the little girl named spot and a big dog was a
