In [None]:
import torch

class LlamaConfig():
  def __init__(
          self,
          vocab_size: int = 128_256,
          context_length: int = 131_072,
          emb_dim: int = 2048,
          n_heads: int = 32,
          n_layers: int = 16,
          hidden_dim: int = 8192,
          n_kv_groups: int = 8,
          head_dim: int | None = None,
          dtype: torch.dtype = torch.float32,
          mlp_bias: bool = False,
          rms_norm_eps: float = 1e-6,
          bias: bool = False,
          attention_bias: bool = False,
        ):
      self.vocab_size = vocab_size
      self.max_position_embeddings = context_length
      self.hidden_size = emb_dim
      self.num_attention_heads = n_heads
      self.num_hidden_layers = n_layers
      self.num_key_value_heads = n_kv_groups
      self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
      self.dtype = dtype
      self.intermediate_size = hidden_dim
      self.mlp_bias = mlp_bias
      self.rms_norm_eps = rms_norm_eps
      self.bias = bias
      self.attention_bias = attention_bias


In [None]:
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F




class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Described in the original code as equivalent to T5LayerNorm: the input is
    scaled by the reciprocal root of its mean square (no mean subtraction, no
    bias) and then multiplied by ``weight``.

    Args:
        hidden_size: size of the normalised (last) dimension; ``weight`` is
            initialised to ones of this length.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last dimension and apply the gain."""
        original_dtype = hidden_states.dtype

        # The statistics are computed in float32 regardless of the input dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = upcast * inv_rms

        # Cast back to the caller's dtype before applying the learnable gain.
        return self.weight * normalised.to(original_dtype)

def precompute_freqs_cis(dim:int, seq_len: int, theta: float=10000.0, device: torch.device = torch.device("cpu")):
  """Build the complex rotation factors used by the rotary position embedding.

  Args:
    dim: feature width being rotated; one frequency is produced per pair of
      features, i.e. ``dim // 2`` frequencies.
    seq_len: number of positions to generate factors for.
    theta: base of the geometric frequency progression.
    device: device on which all intermediate and returned tensors live.

  Returns:
    Complex tensor of shape ``[seq_len, dim // 2]`` whose entry ``[m, k]`` has
    magnitude 1 and angle ``m * theta ** (-2k / dim)``.
  """
  # Per-pair inverse frequency: theta ** (-(0, 2, 4, ...) / dim)
  pair_offsets = torch.arange(0, dim, 2, device=device)[:(dim // 2)].float()
  inv_freq = 1.0 / (theta ** (pair_offsets / dim))

  # Position index m for every token in the sequence
  positions = torch.arange(seq_len, dtype=torch.float32, device=device)

  # Rotation angle for every (position, frequency) combination
  angles = torch.outer(positions, inv_freq).to(device)

  # Unit-magnitude complex numbers in polar form: cos(angle) + i * sin(angle)
  unit_magnitude = torch.ones_like(angles).to(device)
  return torch.polar(unit_magnitude, angles).to(device)

def reshape_for_broadcast(freqs_cis, x):
  ndim = x.ndim
  assert 0<=1<ndim
  assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "the last two dimension of freqs_cis, x must match"
  shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]
  return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, device: torch.device = torch.device("cpu"))->Tuple[torch.Tensor, torch.Tensor]:
  """Rotate query and key features by the position-dependent factors in ``freqs_cis``.

  The caller in this file passes ``xq`` / ``xk`` as
  ``[bsz, seq_len, n_heads, head_dim]`` and ``freqs_cis`` as
  ``[seq_len, head_dim / 2]``.

  Args:
    xq: query tensor.
    xk: key tensor.
    freqs_cis: complex rotation factors.
    device: device the intermediate tensors are moved to. It defaults to the
      CPU, so callers working on another device need to pass it explicitly.

  Returns:
    ``(xq_rotated, xk_rotated)`` with the same shapes as the inputs, each cast
    back to the type of the corresponding input.
  """
  def _to_complex_pairs(t):
    # Group the last dimension into (real, imag) pairs and reinterpret each
    # pair as one complex number: [..., head_dim] -> [..., head_dim / 2].
    paired = t.float().reshape(*t.shape[:-1], -1, 2)
    return torch.view_as_complex(paired).to(device)

  q_complex = _to_complex_pairs(xq)
  k_complex = _to_complex_pairs(xk)

  # Give the factors singleton batch/head axes so they broadcast:
  # [seq_len, head_dim / 2] -> [1, seq_len, 1, head_dim / 2].
  rotation = reshape_for_broadcast(freqs_cis, q_complex)

  def _rotate(complex_pairs):
    # Complex multiplication performs the 2-D rotation of every pair; the
    # result is unpacked to real pairs and flattened back to head_dim.
    rotated = torch.view_as_real(complex_pairs * rotation)
    return rotated.flatten(3).to(device)

  q_rotated = _rotate(q_complex)
  k_rotated = _rotate(k_complex)
  return q_rotated.type_as(xq), k_rotated.type_as(xk)

def repeat_kv(x:torch.Tensor, n_rep: int)-> torch.Tensor:
  """Repeat every head of ``x`` ``n_rep`` times along the head axis.

  Args:
    x: tensor of shape ``[bsz, seq_len, n_kv_heads, head_dim]``.
    n_rep: number of consecutive copies of each head.

  Returns:
    ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
    ``[bsz, seq_len, n_kv_heads * n_rep, head_dim]`` in which output heads
    ``k * n_rep`` through ``(k + 1) * n_rep - 1`` are copies of input head ``k``.
  """
  bsz, seq_len, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x

  # Insert a repeat axis right after the head axis, broadcast it to n_rep,
  # then fold it into the head axis.
  with_repeat_axis = x.unsqueeze(3).expand(bsz, seq_len, n_kv_heads, n_rep, head_dim)
  return with_repeat_axis.reshape(bsz, seq_len, n_kv_heads * n_rep, head_dim)

class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

    Reads ``hidden_size``, ``intermediate_size`` and ``mlp_bias`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        use_bias = config.mlp_bias
        # Layers are created in the order gate, up, down; this fixes both the
        # parameter ordering and the order in which random weights are drawn.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)
    
class LlamaAttention(nn.Module):
    """Causal multi-head self-attention with grouped key/value heads and rotary position encoding.

    Queries use ``config.num_attention_heads`` heads; keys and values use
    ``config.num_key_value_heads`` heads that are repeated to match the queries.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias

        # Projections are created in the order q, k, v, o, which fixes parameter
        # ordering and the order random weights are drawn in.
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)

    def forward(self, hidden_states: torch.Tensor):
        """Attend over the full sequence with a causal mask.

        Args:
            hidden_states: tensor of shape ``[bsz, seq_len, hidden_size]``.

        Returns:
            Tensor of shape ``[bsz, seq_len, hidden_size]``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        # Every tensor built below is created on this device, so no further
        # explicit transfers are needed.
        device = hidden_states.device

        def _split_heads(projected, n_heads):
            # [bsz, seq_len, n_heads * head_dim] -> [bsz, seq_len, n_heads, head_dim]
            return projected.view(batch_size, seq_len, n_heads, self.head_dim)

        xq = _split_heads(self.q_proj(hidden_states), self.num_attention_heads)
        xk = _split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        xv = _split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        # The rotary factors are rebuilt on every call for the current sequence
        # length and applied to queries and keys only.
        freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=seq_len, device=device)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis, device=device)

        # Repeat the key/value heads to match the query head count, then move
        # heads ahead of the sequence axis: [bsz, n_heads, seq_len, head_dim].
        queries = xq.transpose(1, 2)
        keys = repeat_kv(xk, self.num_key_value_groups).transpose(1, 2)
        values = repeat_kv(xv, self.num_key_value_groups).transpose(1, 2)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere,
        # so a position cannot attend to later positions.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device),
            diagonal=1,
        )

        # Scaled dot-product scores: [bsz, n_heads, seq_len, seq_len]
        scores = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = scores + causal_mask

        # Softmax runs in float32 and is cast back to the query type.
        weights = F.softmax(scores.float(), dim=-1).type_as(queries)
        context = torch.matmul(weights, values)

        # Merge the heads back:
        # [bsz, n_heads, seq_len, head_dim] -> [bsz, seq_len, n_heads * head_dim]
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(merged)


class LlamaDecoderLayer(nn.Module):
    """One decoder block: normalise -> self-attention -> residual add, then normalise -> MLP -> residual add."""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        # Registration order (attention, MLP, then the two norms) determines the
        # parameter / state-dict ordering of the layer.
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor):
        """Run the block on ``hidden_states`` and return a tensor of the same shape."""
        # Attention sub-block: the residual is taken from the un-normalised input.
        attn_out = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = attn_out + hidden_states

        # Feed-forward sub-block, again with a residual around it.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_out + hidden_states

class LlamaModel(nn.Module):
    """Decoder stack without an output head: token embedding, N decoder layers, final RMS norm.

    Args:
        config: model hyper-parameters.
        embedding: accepted for interface compatibility; this implementation
            does not read it.
    """

    def __init__(self, config: LlamaConfig, embedding: torch.Tensor = None):
        super().__init__()
        self.vocab_size = config.vocab_size

        # Only the embedding table is created with config.dtype; the remaining
        # modules use PyTorch's default parameter dtype.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype)

        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(LlamaDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids: torch.Tensor):
        """Map token ids to final hidden states (last dimension ``hidden_size``)."""
        hidden_states = self.embed_tokens(input_ids)
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states)
        return self.norm(hidden_states)

class LlamaForCausalLM(nn.Module):
    """``LlamaModel`` followed by a linear head that projects hidden states to vocabulary size.

    Args:
        config: model hyper-parameters; ``config.bias`` controls the head's bias.
        embedding: forwarded unchanged to ``LlamaModel``.
    """

    def __init__(self, config: LlamaConfig, embedding: torch.Tensor = None):
        super().__init__()
        self.model = LlamaModel(config, embedding)
        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=config.bias,
        )

    def forward(self, input_ids: torch.Tensor):
        """Return one score per vocabulary entry for every position of ``input_ids``."""
        return self.lm_head(self.model(input_ids))

In [None]:
import torch
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

In [None]:
import torch

# Small teacher configuration: 12 layers, 256-wide hidden states, 16 query
# heads sharing 8 key/value heads, and an 8000-entry vocabulary.
teacher_config = LlamaConfig(
    vocab_size=8000,
    context_length=128,
    emb_dim=256,
    n_heads=16,
    n_layers=12,
    hidden_dim=1024,
    n_kv_groups=8,
    head_dim=None,
    dtype=torch.float32,
    mlp_bias=False,
    rms_norm_eps=1e-6,
    bias=False,
    attention_bias=False,
)

# Build the model WITH its language-model head. The bare LlamaModel returns
# hidden states of width emb_dim (256), which cannot be scored by
# cross-entropy against token ids drawn from a vocabulary of 8000;
# LlamaForCausalLM projects them to vocab_size scores per position.
teacher_model = LlamaForCausalLM(teacher_config)
teacher_model = teacher_model.to(device)
teacher_model

In [None]:
from safetensors.torch import load_file
# Read the safetensors file; the token ids are stored under the key "a".
tensors = load_file("general_data_token_ids.safetensors")
token_ids = tensors["a"]

# Quick check of the type of the loaded object.
print(type(token_ids))

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset

# Padding token id. TextDataset does not need it (every window it emits is
# already full length); the name stays defined in case other cells read it.
pad_id = 63

class TextDataset(Dataset):
  """Sliding-window next-token dataset over one flat token sequence.

  Window ``k`` starts at ``k * stride``. Its input is ``context_length``
  consecutive tokens and its target is the same window shifted one token to
  the right. A window is only emitted when both input and target fit entirely
  inside the sequence, so no padding or truncation is ever required.

  Args:
    token_ids: token ids as a list or a tensor.
    context_length: number of tokens per window; must be positive.
    stride: distance between the starts of consecutive windows; must be positive.

  Raises:
    ValueError: if ``context_length`` or ``stride`` is not positive.
  """

  def __init__(self, token_ids: list, context_length: int, stride: int):
    super().__init__()

    if context_length <= 0:
      raise ValueError(f"context_length must be positive, got {context_length}")
    if stride <= 0:
      raise ValueError(f"stride must be positive, got {stride}")

    # Convert once instead of once per window; this also lets callers pass a
    # tensor directly rather than a Python list.
    tokens = torch.as_tensor(token_ids, dtype=torch.long)

    self.inputs = []
    self.targets = []

    # start < len - context_length guarantees start + context_length + 1 <= len,
    # i.e. the shifted target window is always complete.
    for start in range(0, len(tokens) - context_length, stride):
      end = start + context_length
      # clone() gives every window its own storage, independent of `tokens`.
      self.inputs.append(tokens[start:end].clone())
      self.targets.append(tokens[start + 1:end + 1].clone())

  def __len__(self):
    """Number of windows."""
    return len(self.inputs)

  def __getitem__(self, idx):
    """Return the ``(input, target)`` pair of window ``idx``."""
    return self.inputs[idx], self.targets[idx]


def create_data_loader(token_ids: list, context_length: int, stride: int,
                       batch_size: int, shuffle: bool = True, device: str = "cpu"):
  """Wrap ``token_ids`` in a ``TextDataset`` and return a ``DataLoader`` over it.

  Args:
    token_ids: flat token sequence handed to ``TextDataset``.
    context_length: window length handed to ``TextDataset``.
    stride: window step handed to ``TextDataset``.
    batch_size: number of windows per batch.
    shuffle: whether the loader shuffles the windows.
    device: device of the ``torch.Generator`` given to the loader.
      NOTE(review): the loader draws its shuffle indices with this generator;
      confirm a non-CPU generator is accepted before passing e.g. "cuda".

  Returns:
    The configured ``DataLoader``.
  """
  windows = TextDataset(token_ids, context_length, stride)
  rng = torch.Generator(device=device)
  return DataLoader(windows, batch_size=batch_size, shuffle=shuffle, generator=rng)


  


In [None]:
# Windows of 128 tokens taken every 32 tokens, batched 16 at a time and
# iterated in order (no shuffling).
stride = 32
train_data_loader = create_data_loader(
    token_ids=token_ids.tolist(),
    context_length=128,
    stride=stride,
    batch_size=16,
    shuffle=False,
)

# Number of batches per epoch
len(train_data_loader)

In [None]:
# Total number of scalar values across all parameters of the model
parameters_count = sum(map(torch.numel, teacher_model.parameters()))
print(parameters_count)

# Module tree of the model
print(teacher_model)

In [None]:
import torch.nn as nn

# Cross-entropy between per-position scores and target ids, averaged over all
# positions ("mean" is also the default reduction).
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# AdamW over every parameter of the teacher model, learning rate 1e-3.
optimizer = torch.optim.AdamW(params=teacher_model.parameters(), lr=1e-3)


In [None]:
# Sanity check: print the shapes of the first batch (inputs, targets, flattened
# targets) and stop after that one batch.
for i, (X, Y) in enumerate(train_data_loader):
  shapes = (X.shape, Y.shape, Y.flatten().shape)
  print(*shapes)
  break

In [None]:
from tqdm import tqdm

num_epochs = 15

for epoch in range(num_epochs):
  # Sum of the per-batch losses of this epoch, for the average reported below.
  total_loss = 0.
  progress = tqdm(train_data_loader, desc=f"Epoch {epoch + 1}")

  for i, (X, Y) in enumerate(progress):
    X, Y = X.to(device), Y.to(device)

    # The first two dims of the prediction are flattened so that every position
    # lines up with one entry of the flattened targets.
    # NOTE(review): loss_fn treats the last dim of `pred` as per-class scores;
    # confirm teacher_model emits vocabulary-sized outputs.
    pred = teacher_model(X)
    loss = loss_fn(pred.flatten(0, 1), Y.flatten())
    total_loss += loss.item()

    # Backpropagate, update the weights, then clear the gradients so they do
    # not accumulate into the next batch.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

  average_loss = total_loss / len(train_data_loader)
  print(f"Epoch {epoch + 1} | Last Loss: {loss.item():.4f} | Avg Loss: {average_loss:.4f}")