# Rotary Embedding

In [29]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from datasets import load_dataset
import math

from einops import rearrange # einstein operation

## Download Dataset

In [30]:
# Number of TinyStories training stories to keep for this toy experiment.
sample = 100

# Fetch the dataset and the GPT-Neo tokenizer (network I/O on first run).
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS for padding; padded positions are then indistinguishable from real EOS tokens in the labels

# First `sample` stories of the train split, as a list of strings.
subset_dataset = dataset['train'][:sample]['text']
tokenized_dataset = tokenizer(
    subset_dataset,
    return_tensors='pt',
    padding=True,  # Pad every story to the longest one in this subset so they stack into one tensor
    truncation=True  # No max_length given, so truncation uses the tokenizer's own maximum length
)

Repo card metadata block was not found. Setting CardData to empty.


In [31]:
# Token-id matrix with one (padded) row per story.
data = tokenized_dataset['input_ids']
print(tokenizer.vocab_size)  # 50257 for this tokenizer; must match the model's vocab_size

50257


In [32]:
# InferenceParams
class InferenceParams(nn.Module):
    """Hyper-parameter container for the toy rotary-embedding language model.

    The model components read these values as plain attributes (for example
    ``config.n_embd``). The object itself holds no tensors or sub-modules.

    Args:
        sequence_len: Number of tokens per model input. It is stored only; no
            model class visible here reads it.
        rotary_dim: Number of head dimensions rotated by RoPE.
            NOTE(review): the default 3 is odd. RotaryPositionEmbedding rotates
            ``2 * ceil(rotary_dim / 2)`` dims, i.e. 4 for this default.
        n_layer: Intended number of transformer blocks.
        batch_size: Sequences per training batch.
        n_embd: Model (embedding) width.
        n_head: Number of attention heads; ``n_embd // n_head`` is the per-head width.
        vocab_size: Size of the output vocabulary (50257 for the GPT-Neo tokenizer).
    """

    def __init__(
        self,
        sequence_len,
        *,
        rotary_dim=3,
        n_layer=2,
        batch_size=16,
        n_embd=20,
        n_head=4,
        vocab_size=50257,
    ):
        # nn.Module subclasses must initialise the base class first. Without
        # this the instance lacks Module's internal registries, and e.g.
        # repr(config) or config.parameters() raise AttributeError.
        super().__init__()

        self.rotary_dim = rotary_dim
        self.n_layer = n_layer

        self.sequence_len = sequence_len
        self.batch_size = batch_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.vocab_size = vocab_size

In [33]:
# Inputs and targets are shifted by one token, so a training sequence is one
# token shorter than a row of `data`.
sequence_len = data.size(1) - 1
config = InferenceParams(sequence_len)

In [34]:
def get_batch(data, batch_size):
    """Sample ``batch_size`` rows of ``data`` (with replacement) as an input/target pair.

    Targets are the inputs shifted one token to the left, so ``yb[:, i]`` is
    the token that follows ``xb[:, i]``. Assumes ``data`` is at least 2-D,
    with sequences along dim 1.

    Returns:
        ``(xb, yb)``: two contiguous tensors, each one column narrower than ``data``.
    """
    rows = data[torch.randint(0, len(data), size=(batch_size,))]
    inputs, targets = rows[:, :-1], rows[:, 1:]
    return inputs.contiguous(), targets.contiguous()

# Draw one random training batch; both tensors are (batch_size, sequence_len).
xb, yb = get_batch(data, config.batch_size)
xb.shape, yb.shape

(torch.Size([16, 297]), torch.Size([16, 297]))

## Embedding

In [35]:
class Embedding(nn.Module):
    """Token-embedding layer that maps integer ids to ``config.n_embd``-wide vectors."""

    def __init__(self, config):
        super().__init__()
        # Lookup table of shape (vocab_size, n_embd).
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

    def forward(self, input_ids):
        """Return the embedding of every id; a trailing ``n_embd`` axis is appended."""
        return self.wte(input_ids)

In [36]:
# Smoke test: embed one batch of token ids -> (batch, seq, n_embd).
m = Embedding(config)
hidden_states = m(xb)
hidden_states.shape

torch.Size([16, 297, 20])

In [37]:
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to the q and k parts of a packed qkv tensor.

    Each position ``t`` has an angle ``t * f_i`` for every frequency
    ``f_i = base ** (-2i / rotary_dim)``. The rotated slice of a head vector is
    split into halves ``x1`` and ``x2``, and each pair ``(x1_i, x2_i)`` is
    rotated by that angle::

        x1' = x1 * cos - x2 * sin
        x2' = x1 * sin + x2 * cos

    Only the first ``2 * ceil(rotary_dim / 2)`` head dimensions are rotated, so
    an odd ``rotary_dim`` is effectively rounded up. The remaining dimensions
    and the values pass through unchanged.

    Args:
        config: Object providing ``rotary_dim``.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, config, base = 10000):
        super().__init__()
        self.rotary_dim = config.rotary_dim

        exponents = torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq)

        # cos/sin angle tables from the most recent forward call, shape (seq, n_freq).
        self.cos_cache = None
        self.sin_cache = None

    def forward(self, qkv):
        """Rotate queries and keys according to their position.

        Args:
            qkv: Tensor of shape (batch, seq, 3, heads, head_dim). Indices 0/1/2
                along dim 2 are q/k/v. ``head_dim`` must be at least the rotated
                width ``2 * len(inv_freq)``.

        Returns:
            Tensor with the same shape as ``qkv``: q and k rotated, v untouched.
        """
        seqlen = qkv.shape[1]

        # Build the angle tables on the input's device; torch.arange would
        # otherwise default to CPU. Angles are computed in float32 for
        # precision, then cast to the input dtype.
        t = torch.arange(seqlen, device=qkv.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.float())
        self.cos_cache = torch.cos(freqs)
        self.sin_cache = torch.sin(freqs)

        rotary_dim = 2 * self.cos_cache.shape[1]
        # (seq, 1, n_freq) broadcasts against (batch, seq, heads, n_freq).
        cos = self.cos_cache[:, None, :].to(qkv.dtype)
        sin = self.sin_cache[:, None, :].to(qkv.dtype)

        def rotate(x):
            x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
            x1, x2 = x_rot.chunk(2, dim=-1)
            # Standard 2-D rotation: the second half is x1*sin + x2*cos.
            rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
            return torch.cat([rotated, x_pass], dim=-1)

        return torch.stack(
            [rotate(qkv[:, :, 0]), rotate(qkv[:, :, 1]), qkv[:, :, 2]],
            dim=2,
        )

## MLP

In [38]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear, with a 4x wider hidden layer."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.fc1 = nn.Linear(width, width * 4)
        self.fc2 = nn.Linear(width * 4, width)
        self.act = nn.GELU()

    def forward(self, hidden_states):
        """Apply the expand / activate / project stack to the last axis."""
        return self.fc2(self.act(self.fc1(hidden_states)))

In [39]:
# Smoke test: the MLP preserves the (batch, seq, n_embd) shape.
m = MLP(config)
ffwd_out = m(hidden_states)
ffwd_out.shape

torch.Size([16, 297, 20])

## Attention

In [40]:
class SelfAttention(nn.Module):
    """Causal scaled dot-product attention over a packed qkv tensor (no parameters)."""

    def __init__(self):
        super().__init__()

    def forward(self, qkv):
        """Attend each position to itself and all earlier positions.

        Args:
            qkv: Tensor of shape (batch, seq, 3, heads, head_dim), with q/k/v
                stacked along dim 2.

        Returns:
            Tensor of shape (batch, seq, heads, head_dim).
        """
        seq_len = qkv.shape[1]
        q, k, v = qkv.unbind(2)

        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd, bshd -> bhts", q, k * softmax_scale)

        # Mask future positions with -inf. A finite penalty such as -10000 stops
        # masking once scores are large. Build the mask on the scores' device.
        # The diagonal is always unmasked, so no row becomes all -inf (no NaN).
        causal = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=scores.device
        ).triu(1)
        scores = scores.masked_fill(causal, float("-inf"))

        attention = torch.softmax(scores, dim=-1)

        return torch.einsum("bhts, bshd -> bthd", attention, v)

In [41]:
class MHA(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    One fused projection produces q, k and v for every head. RoPE is applied
    to q/k, causal attention runs per head, and the heads are merged and
    projected back to ``config.n_embd``.
    """

    def __init__(self, config):
        super().__init__()

        self.rotary_emb = RotaryPositionEmbedding(config)

        self.head_dim = config.n_embd // config.n_head
        # Equals n_embd only when n_head divides n_embd exactly.
        inner_width = self.head_dim * config.n_head

        self.Wqkv = nn.Linear(config.n_embd, 3 * inner_width)
        self.out_proj = nn.Linear(inner_width, config.n_embd)

        self.inner_attn = SelfAttention()

    def forward(self, x):
        """Map (batch, seq, n_embd) hidden states to attention output of the same shape."""
        packed = rearrange(
            self.Wqkv(x), 'b t (three h d) -> b t three h d', three=3, d=self.head_dim
        )
        per_head = self.inner_attn(self.rotary_emb(packed))
        merged = rearrange(per_head, "... h d -> ... (h d)")
        return self.out_proj(merged)

In [42]:
# Smoke test: attention maps (batch, seq, n_embd) back to the same shape.
m = MHA(config)
attn_out = m(hidden_states)
attn_out.shape

torch.Size([16, 297, 20])

## Block

In [43]:
class Block(nn.Module):
    """Transformer block with a parallel attention / MLP layout.

    A single shared LayerNorm feeds both the attention and the MLP branch.
    Both branch outputs are added to the un-normalised input (residual).
    """

    def __init__(self, config):
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd)

        self.attn = MHA(config)
        self.ffwd = MLP(config)

    def forward(self, hidden_states):
        """Return ``attn(ln(x)) + mlp(ln(x)) + x``; the shape is preserved."""
        normed = self.ln(hidden_states)
        # Keep the summation order attn + ffwd + residual.
        return self.attn(normed) + self.ffwd(normed) + hidden_states

In [44]:
# Smoke test: one transformer block preserves (batch, seq, n_embd).
m = Block(config)
output = m(hidden_states)
output.shape

torch.Size([16, 297, 20])

In [45]:
class LMHead(nn.Module):
    """Final LayerNorm followed by a projection onto vocabulary logits."""

    def __init__(self, config):
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, output):
        """Map (..., n_embd) hidden states to (..., vocab_size) logits."""
        return self.linear(self.ln(output))

In [46]:
# Smoke test: project block output to (batch, seq, vocab_size) logits.
m = LMHead(config)
logits = m(output)
logits.shape

torch.Size([16, 297, 50257])

## Sequential

In [47]:
class SequentialForLM(nn.Module):
    """Decoder-only language model: token embedding -> transformer Blocks -> LMHead.

    Args:
        config: Settings object. ``config.n_layer`` selects how many Blocks are
            stacked; the component modules read the remaining fields.
    """

    def __init__(self, config):
        super().__init__()

        # Depth comes from the config instead of a hard-coded constant. Fall
        # back to 2 (the previous fixed depth) for configs without n_layer.
        n_layers = getattr(config, "n_layer", 2)

        modules = [Embedding(config)]
        modules += [Block(config) for _ in range(n_layers)]
        modules.append(LMHead(config))

        self.layers = nn.Sequential(*modules)

    def forward(self, input_ids):
        """Map integer token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        logits = self.layers(input_ids)
        return logits

In [48]:
# Smoke test: full model on raw token ids -> (batch, seq, vocab_size) logits.
m = SequentialForLM(config)
logits = m(xb)
logits.shape

torch.Size([16, 297, 50257])

## Loss

In [49]:
class LMLoss(nn.Module):
    """Token-level cross-entropy loss for next-token prediction.

    Args:
        ignore_index: Label value excluded from the loss. It is forwarded to
            ``nn.CrossEntropyLoss``; the default -100 is that class's own default.
            NOTE(review): the notebook pads with the EOS token, so padded positions
            currently count toward the loss.
    """

    def __init__(self, ignore_index=-100):
        super().__init__()

        self.loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, labels):
        """Mean cross-entropy over all label positions.

        Args:
            logits: Tensor of shape (..., vocab_size).
            labels: Integer tensor whose shape matches ``logits`` without its last axis.

        Returns:
            Scalar loss tensor.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as a
        # sliced `batch[:, 1:]` label tensor.
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1)

        return self.loss_fct(logits, labels)

In [50]:
# Loss of the untrained model. For near-uniform predictions it should be close
# to ln(vocab_size) ~= 10.8.
lm_loss = LMLoss()
loss = lm_loss(logits, yb)
loss

tensor(11.1403, grad_fn=<NllLossBackward0>)

In [51]:
# End-to-end run with a fresh config, a fresh random batch and a freshly
# initialised model.
config = InferenceParams(sequence_len)
xb, yb = get_batch(data, config.batch_size)

m = SequentialForLM(config)
logits = m(xb)

lm_loss = LMLoss()
loss = lm_loss(logits, yb)
loss

tensor(10.8916, grad_fn=<NllLossBackward0>)