<a href="https://colab.research.google.com/github/baua1/llm/blob/main/pre_training_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
import pandas as pd

In [4]:
# Broadcasting demo for batched matrix multiplication: the last two
# dimensions are multiplied as matrices ((1, 8) @ (8, 16) -> (1, 16)),
# while the leading dimensions are broadcast against each other
# ((1, 64, 1152) with (10, 1, 1152) -> (10, 64, 1152)).
A = torch.randn((1, 64, 1152, 1, 8))
B = torch.randn((10, 1, 1152, 8, 16))

C = torch.matmul(A, B)
print(C.size())

torch.Size([10, 64, 1152, 1, 16])


In [None]:
class MultiHeadAttention(nn.Module):
  """Causal (masked) multi-head self-attention.

  The input is projected to queries, keys and values of width ``d_out``,
  which is split evenly across ``num_heads`` heads. Each position may only
  attend to itself and to earlier positions.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output width; must be divisible by ``num_heads``.
    context_length: Largest sequence length supported; sizes the causal mask.
    dropout: Dropout probability applied to the attention weights.
    num_heads: Number of attention heads.
    qkv_bias: Whether the query/key/value projections use a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out)
    # Must be a module *instance* configured with the requested probability.
    self.dropout = nn.Dropout(dropout)
    # Ones strictly above the main diagonal mark the "future" positions.
    # Registered as a buffer so it follows the module across devices.
    self.register_buffer(
      "mask",
      torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Apply causal multi-head self-attention.

    Args:
      x: Tensor of shape ``(batch, num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(batch, num_tokens, d_out)``.
    """
    b, num_token, _ = x.shape

    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Split the projection width into heads:
    # (b, num_token, d_out) -> (b, num_token, num_heads, head_dim)
    keys = keys.view(b, num_token, self.num_heads, self.head_dim)
    values = values.view(b, num_token, self.num_heads, self.head_dim)
    queries = queries.view(b, num_token, self.num_heads, self.head_dim)

    # Move the head axis ahead of the token axis:
    # (b, num_heads, num_token, head_dim)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # (b, num_heads, num_token, num_token)
    attn_scores = queries @ keys.transpose(2, 3)

    # Crop the precomputed mask to the actual sequence length.
    mask_bool = self.mask.bool()[:num_token, :num_token]

    # -inf scores become zero weights after the softmax.
    attn_scores.masked_fill_(mask_bool, -torch.inf)

    attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, num_token, head_dim) -> (b, num_token, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)

    # Concatenate the heads back into a single d_out-wide vector.
    context_vec = context_vec.contiguous().view(b, num_token, self.d_out)
    context_vec = self.out_proj(context_vec)

    return context_vec



In [None]:
class LayerNorm(nn.Module):
  """Layer normalization over the last dimension.

  Each vector along the last axis is shifted to zero mean and scaled to unit
  (biased) variance, then passed through a learnable element-wise affine
  transform ``scale * x_hat + shift``.

  Args:
    emb_dim: Size of the last dimension of the inputs.
  """

  def __init__(self, emb_dim):
    super().__init__()
    # Added to the variance to keep the division numerically stable.
    self.eps = 1e-5
    self.scale = nn.Parameter(torch.ones(emb_dim))
    self.shift = nn.Parameter(torch.zeros(emb_dim))

  def forward(self, x):
    """Return ``x`` normalized along its last dimension."""
    mu = torch.mean(x, dim=-1, keepdim=True)
    # unbiased=False divides by N rather than N - 1.
    sigma_sq = torch.var(x, dim=-1, unbiased=False, keepdim=True)

    centered = x - mu
    std = (sigma_sq + self.eps).sqrt()
    normalized = centered / std

    return normalized * self.scale + self.shift

In [None]:
class GPT(nn.Module):
  """GPT-style language model: embeddings, transformer blocks, output head.

  ``cfg`` is a mapping with the following keys:

  cfg['vocab_size'] = the size of the vocabulary, i.e. how many distinct
    tokens are present in the whole dataset.
  cfg['emb_dim'] = width of the token and position embeddings (the spelling
    'embd_dim' is accepted as an alias).
  cfg['context_length'] = number of rows in the position-embedding table.
  cfg['drop_rate'] = dropout probability applied to the summed embeddings.
  cfg['n_layers'] = number of TransformerBlock modules to stack.

  The same ``cfg`` is passed to every ``TransformerBlock`` (defined outside
  this class), which may read additional keys.
  """

  def __init__(self, cfg):
    super().__init__()
    # Accept either spelling of the embedding-width key so a config written
    # with 'emb_dim' or with 'embd_dim' both work.
    emb_dim = cfg['emb_dim'] if 'emb_dim' in cfg else cfg['embd_dim']

    self.token_embeddings = nn.Embedding(cfg['vocab_size'], emb_dim)
    self.pos_embeddings = nn.Embedding(cfg['context_length'], emb_dim)
    # Dropout on the embeddings (a dropout layer, not an embedding table).
    self.drop_embeddings = nn.Dropout(cfg['drop_rate'])

    self.transformer_block = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])

    self.final_norm = LayerNorm(emb_dim)
    self.output_head = nn.Linear(emb_dim, cfg['vocab_size'], bias=False)

  def forward(self, input_idx):
    """Compute next-token logits.

    Args:
      input_idx: Integer tensor of token ids with shape ``(batch, seq_len)``.

    Returns:
      Logits whose last dimension is ``cfg['vocab_size']``; the shape is
      ``(batch, seq_len, vocab_size)`` provided the transformer blocks
      preserve the shape of their input.
    """
    _, seq_len = input_idx.shape
    token_embeddings = self.token_embeddings(input_idx)
    # Positions 0..seq_len-1, created on the same device as the input ids.
    position_embedding = self.pos_embeddings(torch.arange(seq_len, device=input_idx.device))
    # (seq_len, emb_dim) broadcasts over the batch dimension.
    x = token_embeddings + position_embedding
    x = self.drop_embeddings(x)
    x = self.transformer_block(x)
    x = self.final_norm(x)
    logits = self.output_head(x)

    return logits


