In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

### Decoder Transformer Architecture to Generate Stories

In [2]:
# Dataset roneneldan/TinyStories from HuggingFace
# https://huggingface.co/datasets/roneneldan/TinyStories

In [3]:
torch.set_default_device("cuda")

In [4]:
# -------------------------------------
n_embd = 192
# -------------------------------------
n_heads = 16
n_layers = 8
head_size = 12
dropout = 0.2
block_size = 128
batch_size = 44
lr = 3e-5
max_iters = 300000
# -------------------------------------
eval_interval = 5000
eval_iters = 200
# -------------------------------------

In [5]:
# Preprocessing dataset!
# words = open("./TinyStories-train.txt", "r", encoding='utf-8').read().lower()
# chars = " abcdefghijklmnopqrstuvwxyz,'\".\n"
# words = "".join([char for char in words if char in chars])
# for char in ",'\".\n":
#     words = words.replace(char, f" {char} ")
# len(words)
# f = open("preprocessed.txt", "w")
# f.write(words)

In [5]:
# USE THIS CODE BLOCK AFTER RUNNING THE BLOCK ABOVE
words = open("./preprocessed.txt", "r", encoding='utf-8').read().split()
words_set = sorted(list(set(words)))
vocab_size = len(words_set)
print(f"{vocab_size} UNIQUE WORDS!")

52441 UNIQUE WORDS!


In [6]:
itos = {i: s for i, s in enumerate(words_set)}
stoi = {s: i for i, s in enumerate(words_set)}

In [7]:
encode = lambda s: [stoi[i] for i in s]
decode = lambda l: "".join([itos[i] for i in l])

In [8]:
def get_batch(split):
    idx = torch.randint(len(words)-block_size-1, (batch_size,))
    X = torch.stack([torch.tensor(encode(words[i:i+block_size])) for i in idx])
    Y = torch.stack([torch.tensor(encode(words[i+1:i+block_size+1])) for i in idx])
    return X, Y

In [9]:
X, Y = get_batch("train")
print(X.shape)
print(Y.shape)
print(X)
print(Y)

torch.Size([44, 128])
torch.Size([44, 128])
tensor([[    0, 19498, 37763,  ..., 22151, 41238, 28465],
        [    3, 45400, 48556,  ..., 49886, 51650,     2],
        [19498, 32034, 21813,  ...,     2, 22928, 29222],
        ...,
        [45343, 45400, 24058,  ..., 50462, 49886, 51650],
        [ 3950,  5791, 30174,  ..., 19854,     3, 39327],
        [41578, 21132, 24569,  ..., 50294,     2, 24989]], device='cuda:0')
tensor([[19498, 37763,     3,  ..., 41238, 28465,  3871],
        [45400, 48556, 49886,  ..., 51650,     2,  1333],
        [32034, 21813, 45400,  ..., 22928, 29222, 16158],
        ...,
        [45400, 24058,     3,  ..., 49886, 51650,     3],
        [ 5791, 30174,  8508,  ...,     3, 39327, 23832],
        [21132, 24569, 52003,  ...,     2, 24989,     3]], device='cuda:0')


In [10]:
class SingleHeadAttention(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.Keys = nn.Linear(n_embd, head_size, bias=False)
        self.Queries = nn.Linear(n_embd, head_size, bias=False)
        self.Values = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # X: (Batch_Size, BlockSize, EmbeddingDim)
        N, T, C = x.shape
        k = self.Keys(x)
        q = self.Queries(x)
        v = self.Values(x)

        weights = q @ k.mT
        weights = torch.masked_fill(weights, self.tril[:T, :T] == 0, -torch.inf)
        weights = torch.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        attention = weights @ v # Shape (N, T, H)
        return attention

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SingleHeadAttention(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


In [12]:
class FeedForward(nn.Module):
    def __init__(self, fan_in=n_embd, fan_out=n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(fan_in, fan_out*4),
            nn.ReLU(),
            nn.Linear(fan_out*4, fan_out),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

In [13]:
class Block(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.l1 = nn.LayerNorm(n_embd)
        self.head = MultiHeadAttention(num_heads, head_size)
        self.l2 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward(n_embd, n_embd)
    
    def forward(self, x):
        # Should have skip connections
        x = x + self.head(self.l1(x))
        x = x + self.ffwd(self.l2(x))
        return x

In [14]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_heads, head_size) for _ in range(n_layers)])
        self.la_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, targets=None):
        T = x.shape[1] # So that it can be used to evaluate as well, (T is not always block_size)
        word_emb = self.word_emb(x)
        pos_emb = self.pos_emb(torch.arange(T))
        x = word_emb + pos_emb
        x = self.blocks(x)
        logits = self.la_head(x)
        # print(logits.shape, targets.shape)
        if targets is None:
            loss = None
        else:
            N, T, C = logits.shape
            logits = logits.view(N*T, C)
            targets = targets.view(N*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    @torch.no_grad()
    def generate(self, initial_string="endoftext"):
        # idx is of shape (N, T)
        # N is probably 1 and so is T
        self.eval()
        current_pred = ""
        result = []
        init = list(initial_string.lower().split())
        idx = torch.tensor(encode(init)).view(1, -1)
        while current_pred != "endoftext":
            idx = idx[:, -block_size:]
            logits, loss = self(idx) # Logits Shape (N, T, vocab_size)
            logits = logits[:, -1, :] # Interested only in last prediction (N, vocab_size)
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1)
            current_pred = itos[idx_next.squeeze().item()]
            result.append(current_pred)
            idx = torch.cat((idx, idx_next), dim=1)
            print(current_pred, end=" ")
        self.train()


In [15]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [16]:
model = GPT()
# USE THE FOLLOWING LINE WHEN YOU HAVE SAVED state_dict PREVIOUSLY
model.load_state_dict(torch.load("./storyGen_state_dict"))
opt = torch.optim.AdamW(model.parameters(), lr)
sum([p.nelement() for p in model.parameters()])

23768665

In [None]:
for iter in range(max_iters):
    if iter % eval_interval == 0:
      losses = estimate_loss()
      print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 
    Xb, Yb = get_batch("train")
    logits, loss = model(Xb, Yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

In [26]:
# TO SAVE - AFTER TRAINING
# torch.save(model.state_dict(), "./storyGen_state_dict")
# TO LOAD
# model = GPT()
# model.load_state_dict(torch.load("./storyGen_state_dict"))

In [18]:
model.generate("Once upon a time")

, there was a girl called lizzy . she was happy and loved to play outside in her yard . one day , lizzy went outside to play and pick some juice . but the juice dripped was too cold , so lizzy decided to roll the spoon around in the mud . she was so tired , she wanted to go out and play in the mud , but she was also impatient . lizzy ' s friend wanted to help . she took a step and rolled the knife down , handing it to lizzy . lizzy and lizzy wondered how to make lemonade , but it would get too hard . lizzy laughed and said , this hot coffee doesn ' t be bitter . the tastes even hotter , but it still made lizzy so embarrassed . lizzy finally had a plan she zoomed off and zip the ladder to another day . lizzy sat and waited for the pasta to come alive to the third day . and so the rice was all special and her mom thanked her , and she proudly kept on insisting . lizzy smiled and said , thank you for inviting me , mom . endoftext 