In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

### Decoder Transformer Architecture to Generate Stories

In [None]:
# Dataset roneneldan/TinyStories from HuggingFace
# https://huggingface.co/datasets/roneneldan/TinyStories

In [2]:
torch.set_default_device("cuda")

In [3]:
# -------------------------------------
n_embd = 396
# -------------------------------------
n_heads = 12
n_layers = 8
head_size = 396
dropout = 0.2
block_size = 128
batch_size = 18
lr = 3e-5
max_iters = 5000
# -------------------------------------
eval_interval = 500
eval_iters = 200
# -------------------------------------

In [4]:
words = open("./TinyStories-train.txt", "r", encoding='utf-8').read().lower()

In [5]:
# Using 45% of the original dataset
fractional_num = int(len(words)*0.25)
words = words[:fractional_num]
print(len(words))
punctuations = ['\t', '#', '$', '%', '&', '(', ')', '*', '+','-', '/', ':', ';', '<', '=', '>', '[', '\\', ']', '_', '`', '{', '|', '}', '~', '\xa0', '¬°', '¬¢', '¬£', '¬ß', '¬´', '\xad', '¬¥', '¬ª', '¬ø', '√†', '√°', '√¢', '√©', '√≠', '√Ø', '√±', '√≥', '√∂', '—ñ', '\u2005', '\u2009', '\u200a', '\u200b', '‚Äì', '‚Äî', '‚Äï', '‚Äò', '‚Äô', '‚Äú', '‚Äù', '‚Äû', '‚Ä¶', '\u2028', '\u2029', '‚àí', '„Äç', 'Ô¨Å', '\ufeff', 'ÔøΩ', 'ùëê', 'üôÇ']
for punc in punctuations:
    words = words.replace(punc, " ")

wanted = ["\n", ".", "!", '"', ",", '?', "'"]
for i in wanted:
    words = words.replace(i, f" {i} ")
print(len(words))

480691772
524555966


In [6]:
words[:1000]

'one day ,  a little girl named lily found a needle in her room .  she knew it was difficult to play with it because it was sharp .  lily wanted to share the needle with her mom ,  so she could sew a button on her shirt .  \n lily went to her mom and said ,   " mom ,  i found this needle .  can you share it with me and sew my shirt ?  "  her mom smiled and said ,   " yes ,  lily ,  we can share the needle and fix your shirt .  "  \n together ,  they shared the needle and sewed the button on lily \' s shirt .  it was not difficult for them because they were sharing and helping each other .  after they finished ,  lily thanked her mom for sharing the needle and fixing her shirt .  they both felt happy because they had shared and worked together .  \n   endoftext   \n once upon a time ,  there was a little car named beep .  beep loved to go fast and play in the sun .  beep was a healthy car because he always had good fuel .  good fuel made beep happy and strong .  \n one day ,  beep was d

In [7]:
words_set = sorted(list(set(words.split())))
print(words_set)
vocab_size = len(words_set)
print(f"{vocab_size} UNIQUE WORDS!")

27808 UNIQUE WORDS!


In [8]:
itos = {i: s for i, s in enumerate(words_set)}
stoi = {s: i for i, s in enumerate(words_set)}

In [9]:
encode = lambda s: [stoi[i] for i in s]
decode = lambda l: "".join([itos[i] for i in l])

In [10]:
words = words.split()

In [11]:
data = torch.tensor(encode(words))
num = int(0.9*len(data))
data_train = data[:num]
data_val = data[num:]

In [12]:
print(data_train.shape)
print(data_val.shape)

torch.Size([101556760])
torch.Size([11284085])


In [13]:
def get_batch(split):
    data = data_train if split=="train" else data_val
    idx = torch.randint(len(data)-block_size-1, (batch_size,))
    X = torch.stack([data[i:i+block_size] for i in idx])
    Y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    return X, Y

In [14]:
X, Y = get_batch("train")
print(X.shape)
print(Y.shape)
print(X)
print(Y)

torch.Size([18, 128])
torch.Size([18, 128])
tensor([[21736,     4, 24353,  ...,     4, 20929, 20070],
        [16999,     2, 20007,  ...,   845,  2948, 22068],
        [26064,   135, 24617,  ..., 20278, 12058,   500],
        ...,
        [  135,  2214, 17802,  ...,     1, 13685,     3],
        [10941,   134, 27614,  ..., 11011,  8264,     4],
        [  845,  4613,   845,  ...,   134,     1, 24797]], device='cuda:0')
tensor([[    4, 24353, 13686,  ..., 20929, 20070, 24353],
        [    2, 20007, 24378,  ...,  2948, 22068, 15997],
        [  135, 24617, 24321,  ..., 12058,   500,     4],
        ...,
        [ 2214, 17802,     4,  ..., 13685,     3, 14985],
        [  134, 27614, 21125,  ...,  8264,     4,  3265],
        [ 4613,   845, 11395,  ...,     1, 24797, 19160]], device='cuda:0')


In [15]:
class SingleHeadAttention(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.Keys = nn.Linear(n_embd, head_size, bias=False)
        self.Queries = nn.Linear(n_embd, head_size, bias=False)
        self.Values = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # X: (Batch_Size, BlockSize, EmbeddingDim)
        N, T, C = x.shape
        k = self.Keys(x)
        q = self.Queries(x)
        v = self.Values(x)

        weights = q @ k.mT
        weights = torch.masked_fill(weights, self.tril[:T, :T] == 0, -torch.inf)
        weights = torch.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        attention = weights @ v # Shape (N, T, H)
        return attention

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SingleHeadAttention(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


In [17]:
class FeedForward(nn.Module):
    def __init__(self, fan_in=n_embd, fan_out=n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(fan_in, fan_out*4),
            nn.ReLU(),
            nn.Linear(fan_out*4, fan_out),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

In [18]:
class Block(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.l1 = nn.LayerNorm(n_embd)
        self.head = MultiHeadAttention(num_heads, head_size)
        self.l2 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward(n_embd, n_embd)
    
    def forward(self, x):
        # Should have skip connections
        x = x + self.head(self.l1(x))
        x = x + self.ffwd(self.l2(x))
        return x

In [19]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_heads, head_size) for _ in range(n_layers)])
        self.la_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        T = x.shape[1] # So that it can be used to evaluate as well, (T is not always block_size)
        word_emb = self.word_emb(x)
        pos_emb = self.pos_emb(torch.arange(T))
        x = word_emb + pos_emb
        x = self.blocks(x)
        logits = self.la_head(x)
        # print(logits.shape, targets.shape)
        if targets is None:
            loss = None
        else:
            N, T, C = logits.shape
            logits = logits.view(N*T, C)
            targets = targets.view(N*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is of shape (N, T)
        # N is probably 1 and so is T
        self.eval()
        result = []
        for _ in range(max_new_tokens):
            idx = idx[:, -block_size:]
            logits, loss = self(idx) # Logits Shape (N, T, vocab_size)
            logits = logits[:, -1, :] # Interested only in last prediction (N, vocab_size)
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1)
            result.append(itos[idx_next.squeeze().item()])
            idx = torch.cat((idx, idx_next), dim=1)
        self.train()
        return result


In [20]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [21]:
model = GPT()
# model.load_state_dict(torch.load("./storyGen_state_dict"))
opt = torch.optim.AdamW(model.parameters(), lr)

In [108]:
# torch.cuda.empty_cache()

In [109]:
for iter in range(20000):
    if iter % eval_interval == 0 or iter == max_iters-1:
      losses = estimate_loss()
      print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 
    Xb, Yb = get_batch("train")
    logits, loss = model(Xb, Yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

step 0: train loss 2.9852, val loss 2.9885
step 500: train loss 2.9803, val loss 2.9907
step 1000: train loss 2.9585, val loss 2.9895
step 1500: train loss 2.9824, val loss 2.9886
step 2000: train loss 2.9678, val loss 2.9782
step 2500: train loss 2.9751, val loss 2.9819
step 3000: train loss 2.9593, val loss 2.9721
step 3500: train loss 2.9669, val loss 2.9624
step 4000: train loss 2.9549, val loss 2.9926
step 4500: train loss 2.9576, val loss 2.9701
step 4999: train loss 2.9597, val loss 2.9616
step 5000: train loss 2.9798, val loss 2.9681
step 5500: train loss 2.9548, val loss 2.9721
step 6000: train loss 2.9611, val loss 2.9809
step 6500: train loss 2.9564, val loss 2.9877
step 7000: train loss 2.9656, val loss 2.9682
step 7500: train loss 2.9481, val loss 2.9727
step 8000: train loss 2.9497, val loss 2.9405
step 8500: train loss 2.9419, val loss 2.9577
step 9000: train loss 2.9367, val loss 2.9582
step 9500: train loss 2.9306, val loss 2.9426
step 10000: train loss 2.9419, val los

In [110]:
# TO SAVE
torch.save(model.state_dict(), "./storyGen_state_dict")
# TO LOAD
# model = GPT()
# model.load_state_dict(torch.load("./storyGen_state_dict"))

In [113]:
initial_string = "Once "
init = list(initial_string.lower().split())
init_tensor = torch.tensor(encode(init)).view(1, -1)

In [114]:
" ".join(model.generate(init_tensor, 100))

'upon a time , there was a little boy named tim . tim loved to play outside in the sandbox and enjoying . one day , tim saw his friend , " hey , tim , you can \' t lift . i \' m trying to help you . " timmy was happy to have his friend and played with the ball together . he had made the book happy to find his mom very much living . as tim finally as he sneezed the room , he accidentally every day and soon he went for the brave walk'