# Basic Classes for a GPT-Style Language Model

In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from datasets import load_dataset
import math

from einops import rearrange # einstein operation

  from .autonotebook import tqdm as notebook_tqdm


## Download Dataset

In [2]:
# Number of training stories to keep; only this many are tokenized below.
sample = 100

# Load the TinyStories dataset and the GPT-Neo tokenizer by their identifiers.
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# Reuse the end-of-sequence token for padding (presumably this tokenizer has
# no pad token of its own; padding=True below needs one).
tokenizer.pad_token = tokenizer.eos_token  # You can choose any appropriate token for padding

# Take the 'text' field of the first `sample` rows of the train split.
subset_dataset = dataset['train'][:sample]['text']
# Tokenize all selected stories at once and return PyTorch tensors.
tokenized_dataset = tokenizer(
    subset_dataset,
    return_tensors='pt',
    padding=True,  # Enable padding
    truncation=True  # Enable truncation
)

Repo card metadata block was not found. Setting CardData to empty.


In [3]:
# Token-id tensor from the tokenizer output; the shape is shown as the cell result.
data = tokenized_dataset['input_ids']
data.shape

torch.Size([100, 298])

In [16]:
# Hyperparameters shared by the demo cells below

batch_size = 16

n_head = 4
n_embd = 36
# One less than the padded length: get_batch drops one token from each row
# to form the shifted input/target pair.
sequence_len = data.size(1) - 1
vocab_size = tokenizer.vocab_size

In [17]:
def get_batch(data, batch_size):
    """Sample a random batch of next-token-prediction pairs.

    Rows are drawn uniformly at random, with replacement, from ``data``.

    Args:
        data: 2-D tensor of token ids, indexed as ``data[row, position]``.
        batch_size: Number of rows to draw.

    Returns:
        Tuple ``(xb, yb)``. ``xb`` is every sampled row without its last
        token and ``yb`` is the same row without its first token, so
        ``yb[:, t]`` is the token that follows ``xb[:, t]``. Both tensors
        are contiguous.
    """
    row_ids = torch.randint(0, len(data), size=(batch_size,))
    rows = data[row_ids]

    # Shift by one position to pair each token with its successor.
    inputs, targets = rows[:, :-1], rows[:, 1:]
    return inputs.contiguous(), targets.contiguous()

# Draw one random batch and display the shapes of the inputs and targets.
xb, yb = get_batch(data, batch_size)
xb.shape, yb.shape

(torch.Size([16, 297]), torch.Size([16, 297]))

## Embedding

In [18]:
class Embedding(nn.Module):
    """Sum of token embeddings and learned absolute-position embeddings.

    Args:
        vocab_size: Number of rows in the token embedding table.
        n_embd: Width of every embedding vector.
        sequence_len: Number of rows in the position table, i.e. the longest
            input sequence this module accepts.
    """

    def __init__(self, vocab_size, n_embd, sequence_len):
        super().__init__()
        self.sequence_len = sequence_len

        self.wte = nn.Embedding(vocab_size, n_embd)
        self.position = nn.Embedding(sequence_len, n_embd)

    def forward(self, input_ids):
        """Embed token ids of shape ``(..., seq)`` into ``(..., seq, n_embd)``.

        Args:
            input_ids: Integer tensor of token ids whose last dimension is the
                sequence; it may be shorter than ``sequence_len``.

        Returns:
            Tensor with a trailing ``n_embd`` dimension appended.

        Raises:
            ValueError: If the sequence is longer than ``sequence_len``.
        """
        seq_len = input_ids.size(-1)
        if seq_len > self.sequence_len:
            raise ValueError(
                f"input sequence length {seq_len} exceeds the maximum "
                f"supported length {self.sequence_len}"
            )

        token_embd = self.wte(input_ids)
        # Build positions from the actual input: its length (so shorter
        # sequences work) and its device (so non-CPU inputs work).
        positions = torch.arange(seq_len, device=input_ids.device)
        position_embd = self.position(positions)

        # position_embd is (seq, n_embd) and broadcasts over leading dims.
        return token_embd + position_embd

In [19]:
# Smoke test: embed the sampled batch; the result gains a trailing n_embd axis.
m = Embedding(vocab_size, n_embd, sequence_len)
hidden_states = m(xb)
hidden_states.shape

torch.Size([16, 297, 36])

## MLP

In [20]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    The hidden layer is four times as wide as the model dimension and the
    result is projected back to ``n_embd``.

    Args:
        n_embd: Size of the last dimension of the input and output.
    """

    def __init__(self, n_embd):
        super().__init__()

        self.fc1 = nn.Linear(n_embd, 4 * n_embd)
        self.fc2 = nn.Linear(4 * n_embd, n_embd)
        self.act = nn.ReLU()

    def forward(self, hidden_states):
        """Apply the network along the last dimension of ``hidden_states``."""
        expanded = self.act(self.fc1(hidden_states))
        return self.fc2(expanded)

In [21]:
# Smoke test: the MLP maps n_embd back to n_embd, so the shape is unchanged.
m = MLP(n_embd)
ffwd_out = m(hidden_states)
ffwd_out.shape

torch.Size([16, 297, 36])

## Attention

In [24]:
class SelfAttention(nn.Module):
    """Causal scaled dot-product attention over a packed QKV tensor.

    The module has no parameters; it only combines the projections it is
    given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, qkv):
        """Attend each position to itself and to earlier positions only.

        Args:
            qkv: Tensor of shape ``(batch, seq, 3, heads, head_dim)`` holding
                queries, keys and values stacked along dimension 2.

        Returns:
            Tensor of shape ``(batch, seq, heads, head_dim)``.
        """
        seq_len = qkv.shape[1]
        q, k, v = qkv.unbind(2)

        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        # scores[b, h, t, s]: similarity of query position t to key position s.
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

        # True where the key position lies strictly after the query position.
        # Built on the scores' device so the module also works off-CPU.
        pos = torch.arange(seq_len, device=scores.device)
        future = pos.unsqueeze(0) > pos.unsqueeze(1)

        # A hard -inf mask gives future positions exactly zero weight no
        # matter how large the raw scores are; a finite additive penalty can
        # be outweighed by a large score. The diagonal is never masked, so
        # every row keeps at least one finite entry and softmax stays defined.
        scores = scores.masked_fill(future, float("-inf"))

        attention = torch.softmax(scores, dim=-1)

        return torch.einsum("bhts,bshd->bthd", attention, v)

In [25]:
class MHA(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Args:
        n_embd: Size of the last dimension of the input and output.
        n_head: Number of attention heads; each head has width
            ``n_embd // n_head``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.head_dim = n_embd // n_head
        # Total width covered by the heads; smaller than n_embd when n_embd
        # is not a multiple of n_head.
        inner_dim = n_head * self.head_dim

        self.Wqkv = nn.Linear(n_embd, 3 * inner_dim)
        self.out_proj = nn.Linear(inner_dim, n_embd)

        self.inner_attn = SelfAttention()

    def forward(self, x):
        """Project ``x`` to Q/K/V, attend per head, merge heads, project back."""
        # Split the fused projection into (q|k|v, head, head_dim) axes.
        packed = rearrange(
            self.Wqkv(x),
            'b t (three h d) -> b t three h d',
            three=3,
            d=self.head_dim,
        )

        per_head = self.inner_attn(packed)

        # Concatenate the heads back into one feature axis.
        merged = rearrange(per_head, "... h d -> ... (h d)")
        return self.out_proj(merged)

In [26]:
# Smoke test: attention projects back to n_embd, so the shape is unchanged.
m = MHA(n_embd, n_head)
attn_out = m(hidden_states)
attn_out.shape

torch.Size([16, 297, 36])

## Block

In [27]:
class Block(nn.Module):
    """Transformer block with attention and MLP applied in parallel.

    A single LayerNorm feeds both branches, and their outputs are added to
    the un-normalized input (residual connection).

    Args:
        n_embd: Size of the last dimension of the input and output.
        n_head: Number of attention heads passed to ``MHA``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()

        self.ln = nn.LayerNorm(n_embd)

        self.ffwd = MLP(n_embd)
        self.attn = MHA(n_embd, n_head)

    def forward(self, hidden_states):
        """Return ``attn(ln(x)) + ffwd(ln(x)) + x`` for input ``x``."""
        normed = self.ln(hidden_states)

        # Both branches read the same normalized activations.
        attn_branch = self.attn(normed)
        mlp_branch = self.ffwd(normed)

        return attn_branch + mlp_branch + hidden_states

In [28]:
# Smoke test: run the embedded batch through one block and check the shape.
m = Block(n_embd, n_head)
output = m(hidden_states)
output.shape

torch.Size([16, 297, 36])

In [29]:
class LMHead(nn.Module):
    """Final LayerNorm followed by a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of output logits per position.
        n_embd: Size of the last dimension of the input.
    """

    def __init__(self, vocab_size, n_embd):
        super().__init__()

        self.ln = nn.LayerNorm(n_embd)
        self.linear = nn.Linear(n_embd, vocab_size)

    def forward(self, output):
        """Map hidden states ``(..., n_embd)`` to logits ``(..., vocab_size)``."""
        return self.linear(self.ln(output))

In [30]:
# Smoke test: project the block output to one logit per vocabulary entry.
m = LMHead(vocab_size, n_embd)
logits = m(output)
logits.shape

torch.Size([16, 297, 50257])

## Loss

In [32]:
class LMLoss(nn.Module):
    """Mean cross-entropy between per-position logits and target token ids.

    Args:
        ignore_index: Target value excluded from the loss. Defaults to
            ``-100``, the default of ``nn.CrossEntropyLoss``, so existing
            callers see no change. Set targets to this value (for example
            at padded positions) to leave them out of the average.
    """

    def __init__(self, ignore_index=-100):
        super().__init__()

        self.loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, labels):
        """Compute the loss over all positions.

        Args:
            logits: Tensor of shape ``(..., vocab_size)``.
            labels: Integer tensor of target ids with the same leading
                shape as ``logits``.

        Returns:
            Scalar loss tensor.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, such as
        # sliced or transposed inputs, copying only when it has to.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)

        return self.loss_fct(flat_logits, flat_labels)

In [33]:
# Loss of the untrained modules on the sampled targets. Every position in yb
# is scored, including padded ones, because no mask is applied here.
lm_loss = LMLoss()
loss = lm_loss(logits, yb)
loss

tensor(11.1934, grad_fn=<NllLossBackward0>)