In [2]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

In [3]:
dataset = load_dataset("roneneldan/TinyStories")



In [4]:
train = dataset['train']
validation = dataset['validation']

In [5]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")

In [6]:
dataset_str = ""
for obj in train:
    dataset_str += obj['text']

In [7]:
len(dataset_str)

1899973203

In [8]:
context_len=32

In [9]:
def build_dataset(data, context):
    X,Y = [],[]
    for idx in range(0, len(data)-context, context):
        X.append(data[idx:idx+context])
        Y.append(data[idx+1:idx+1+context])
    return torch.tensor(X), torch.tensor(Y)

In [10]:
encoded = enc.encode(dataset_str[:10000000])
X,Y = build_dataset(encoded, context_len)
X.shape,Y.shape

(torch.Size([73922, 32]), torch.Size([73922, 32]))

In [11]:
n_layers = 1
d_model = 128
n_heads = 4
context_len = context_len
vocab_size = max(encoded)+1

In [12]:
class TinyLlama(nn.Module):
    def __init__(self):
        super().__init__()
        self.Blocks = nn.Sequential(*(Block() for _ in range(n_layers)))
        self.Linear = nn.Linear(d_model, vocab_size)
        self.Embedding = nn.Embedding(vocab_size, d_model, dtype=torch.float32)
        
    def forward(self, x):
        x = self.Embedding(x)
        out = self.Blocks(x)
        out  = self.Linear(out)
        return out

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = MHA()
        self.ffn = FFN()
        self.ln1 = LayerNorm()
        self.ln2 = LayerNorm()
    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class MHA(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SelfAttention()
        self.l1 = nn.Linear(d_model, d_model) 
    def forward(self, x):
        #splits = torch.hsplit(x, n_heads)
        concat = torch.cat([self.sa(x) for _ in range(n_heads)], dim=-1)
        out = self.l1(concat)
        return out

        
class SelfAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.get_keys = nn.Linear(d_model, d_model // n_heads)
        self.get_values = nn.Linear(d_model, d_model // n_heads)
        self.get_queries = nn.Linear(d_model, d_model // n_heads)
        self.rope = RotaryPositionalEmbeddings(d_model // n_heads)
    def forward(self, x):
        K = self.rope(self.get_keys(x))
        V = self.get_values(x)
        Q = self.rope(self.get_queries(x))
        weightage = Q @ torch.transpose(K, -1, -2)
        weightage = torch.tril(weightage)
        weightage = weightage.masked_fill(weightage==0, float("-Inf"))
        scaled_weightage = weightage / (d_model // n_heads) ** 0.5
        out = F.softmax(scaled_weightage, dim=-1) @ V
        return out
        
        
class FFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(d_model, 4*d_model)
        self.l2 = nn.Linear(2*d_model, d_model)
        self.swiglu = SwiGLU()
    def forward(self, x):
        x = self.swiglu(self.l1(x))
        x = self.l2(x)
        return x
        
class LayerNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-05
        self.gamma = nn.Parameter(torch.ones(d_model))
        # self.beta = nn.Parameter(torch.zeros(d_model))  biases don't really help
    def forward(self, x):
        xmean = torch.mean(x, dim=-1, keepdims=True)
        xvar = torch.var(x, dim=-1, keepdims=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        x = xhat * self.gamma
        return x

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # if self.cos_cached is not None and x.shape[1] <= self.cos_cached.shape[0]:
        #     return
        seq_len = x.shape[1]
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        self.cos_cached = idx_theta2.cos()
        self.sin_cached = idx_theta2.sin()

    def _neg_half(self, x: torch.Tensor):
        d_2 = self.d // 2
        return torch.cat([-x[:, :, d_2:], x[:, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        self._build_cache(x)
        x_rope = x
        neg_half_x = self._neg_half(x_rope)
        x_rope = (x_rope * self.cos_cached) + (neg_half_x * self.sin_cached)
        return x_rope 

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [14]:
def get_batch(batch_size):
    len = X.shape[0]
    idx = random.randint(0, len-batch_size)
    return X[idx:idx+batch_size].to(device), Y[idx:idx+batch_size].to(device)

In [15]:
model = TinyLlama()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.95))

In [16]:
sum([torch.numel(p) for p in model.parameters() if p.requires_grad])

25893631

In [17]:
def train_one_epoch(n_steps):
    running_loss = 0.
    last_loss = 0.
    for i in range(n_steps):
        inputs, labels = get_batch(32)
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = torch.transpose(outputs, -2, -1)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999 or i==0:
            if i==0:
                last_loss = running_loss
            else:
                last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

In [19]:
PATH = './V1'
model = TinyLlama()
model.load_state_dict(torch.load(PATH))
model.to(device)

TinyLlama(
  (Blocks): Sequential(
    (0): Block(
      (mha): MHA(
        (sa): SelfAttention(
          (get_keys): Linear(in_features=128, out_features=32, bias=True)
          (get_values): Linear(in_features=128, out_features=32, bias=True)
          (get_queries): Linear(in_features=128, out_features=32, bias=True)
          (rope): RotaryPositionalEmbeddings()
        )
        (l1): Linear(in_features=128, out_features=128, bias=True)
      )
      (ffn): FFN(
        (l1): Linear(in_features=128, out_features=512, bias=True)
        (l2): Linear(in_features=256, out_features=128, bias=True)
        (swiglu): SwiGLU()
      )
      (ln1): LayerNorm()
      (ln2): LayerNorm()
    )
  )
  (Linear): Linear(in_features=128, out_features=100255, bias=True)
  (Embedding): Embedding(100255, 128)
)

In [20]:
@torch.no_grad()
def generate(prompt, len):
    input = torch.unsqueeze(torch.tensor(enc.encode(prompt), device=device), dim=0)
    next_word_decoded = ''
    idx=0
    while idx!=len:
        input_len = input.shape[-1]
        output = model(input)
        next_word_logits = output[0][input_len-1]
        next_word_probs = F.softmax(next_word_logits, dim=0)
        next_word = torch.multinomial(next_word_probs, 1)
        next_word_decoded = enc.decode([next_word.item()])
        next_word = next_word.view(1,1) # torch.tensor(124) -> torch.tensor([[124]]), makes it easy to concat
        input = torch.cat((input, next_word), dim=-1)
        input = torch.unsqueeze(input[0][-context_len:None], dim=0)
        prompt += next_word_decoded
        idx+=1
    return prompt

In [21]:
generate("once upon a time there was dog named Bruno. But Bruno did not have many friends so he", 500)

'once upon a time there was dog named Bruno. But Bruno did not have many friends so he was safe in the done. They had learned a valuable lesson. She turned all about her toys and bigger than the little girl. They knew that they liked their new surprise and clothes.Lily and Ben were just friends. They liked to play and run and slide. They liked to run and play on the swings and slide magical next to see the grass. They hoped to do they did not know that Mommy had an idea. They played with them. They strange circles and smile. They had a lot of fun, hours walking on and across the woods on the sea. They saw a big tree with their mom. They were playing his fields and played with the pulled home with him. Timmy felt like so much fun with each other. From that day on, Timmy always smiled, and promised to go to a special until the stopped calling him for lunch. They were busy! They felt embarrassed and p: "I think!"\n\nJack ran to each other and cover and can bit of them!" Lily agreed and op

In [None]:
PATH = './V1'
torch.save(model.state_dict(), PATH)