<a href="https://colab.research.google.com/github/kmalik22/colabs/blob/main/transformer_training_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [86]:
import torch
import numpy as np
import torch.nn.functional as F
import itertools
from typing import Optional
from torch.utils.data import IterableDataset, DataLoader
torch.set_printoptions(linewidth=250, precision=3)

# Modules

In [210]:
class MLP(torch.nn.Module):
  def __init__(self, d_model, d_ffn):
    super().__init__()
    self.d_model = d_model
    self.d_ffn = d_ffn
    self.fc1 = torch.nn.Linear(in_features=d_model, out_features=d_ffn, bias=False)
    #torch.nn.init.trunc_normal_(self.fc1.weight, mean=0, std=1/d_model**0.5)
    self.fc2 = torch.nn.Linear(in_features=d_ffn, out_features=d_model, bias=False)
    #torch.nn.init.trunc_normal_(self.fc2.weight, mean=0, std=1/d_ffn**0.5, a=-3, b=3)

  def forward(self, x: torch.Tensor):
    fc1_out = self.fc1(x)
    relu_out = torch.nn.functional.relu(fc1_out)
    fc2_out = self.fc2(relu_out)
    return fc2_out

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model//num_heads
    self.wqkv = torch.nn.Linear(d_model, int(3*d_model), bias=False)
    torch.nn.init.trunc_normal_(self.wqkv.weight, 0, 1/d_model**0.5, a=-3, b=3)
    self.wo = torch.nn.Linear(d_model, d_model)
    torch.nn.init.trunc_normal_(self.wo.weight, 0, 1/d_model**0.5, a=-3, b=3)
    # TODO: add rope

  def scaled_dot_product_attn(self, q, k, v):
    # q, k, v = (bsz, num_heads, seqlen, head_dim)
    # qk.t()/sqrt(head_dim)
    # causal mask
    # softmax
    # @ v
    seqlen = q.shape[2]
    attn_wts = q @ k.transpose(2,3)  #(bsz,num_h,seqlen,head_dim) (bsz,num_h,head_dim,seqlen) -> (bsz,num_h,seqlen,seqlen)
    # create mask, do torch.where. 1=use attn wts, else use -inf
    mask = torch.tril(torch.ones(seqlen,seqlen)).to(torch.bool)
    masked_attn_wts = torch.where(mask, attn_wts, float('-inf')) # (bsz,num_heads,seqlen,seqlen)
    # softmax TODO check once
    softmax_wts = torch.nn.functional.softmax(masked_attn_wts,dim=-1)
    # (bsz,num_heads,seqlen,seqlen) @ (bsz,num_heasds,seqlen,head_dim) -> (bsz,num_heads,seqlen,head_dim)
    return softmax_wts @ v


  def forward(self, x:torch.Tensor):
    #x = (bsz, seqlen, d_model)
    bsz, seqlen, d_model = x.shape
    assert d_model == self.d_model
    wqkv_out = self.wqkv(x) # (bsz, seqlen, 3*d_model)
    # rearrange so seqlen,head_dim are the last two dims
    wqkv_out = wqkv_out.reshape(bsz, 3, self.num_heads, seqlen, self.head_dim)
    q = wqkv_out[:,0,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    k = wqkv_out[:,1,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    v = wqkv_out[:,2,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    # TODO: add rope to q and k
    self_attn_out = self.scaled_dot_product_attn(q, k, v) #(bsz, num_heads, seqlen, head_dim)
    # transpose
    self_attn_out = self_attn_out.transpose(1,2) # (bsz,seqlen,num_heads,head_dim)
    # concat the last dim
    self_attn_out = self_attn_out.reshape(bsz,seqlen,-1)
    # wo
    self_attn_out = self_attn_out.reshape(bsz, seqlen, d_model)
    wo_out = self.wo(self_attn_out)
    return wo_out


class TransformerBlock(torch.nn.Module):
  def __init__(self, d_model, d_ffn, num_heads, max_seqlen):
    super().__init__()
    self.mlp_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.mlp = MLP(d_model=d_model, d_ffn=d_ffn)
    self.attn_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.attn = MultiHeadAttention(d_model=d_model,num_heads=num_heads)


  def forward(self, x:torch.Tensor):
    # norm --> attn, add back to original --> norm --> mlp, add back
    attn_norm_out = self.attn_norm(x)
    attn_out = self.attn(attn_norm_out)
    attn_out_post_resid = x + attn_out
    mlp_norm_out = self.mlp_norm(attn_out_post_resid)
    mlp_out = self.mlp(mlp_norm_out)
    return mlp_out + attn_out_post_resid

class Transformer(torch.nn.Module):
  def __init__(self, d_model, d_ffn, vocab_size, num_layers, max_seqlen, num_heads, simple_pos_embed=False):
    super().__init__()
    self.max_seqlen = max_seqlen
    self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    torch.nn.init.trunc_normal_(self.embedding.weight, mean=0, std=1/d_model**0.5, a=-3, b=-3)

    self.simple_pos_embed = simple_pos_embed
    if simple_pos_embed:
      self.pos_embedding = torch.nn.Embedding(num_embeddings=max_seqlen, embedding_dim=d_model)

    self.output_layer = torch.nn.Linear(in_features=d_model, out_features=vocab_size)
    torch.nn.init.trunc_normal_(self.output_layer.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    self.layers = torch.nn.ModuleList(
        [TransformerBlock(d_model=d_model, d_ffn=d_ffn, num_heads=num_heads, max_seqlen=max_seqlen) for _ in range(num_layers)]
    )
    # TODO: do proper init
    # add attention
    # add layer norms

  def forward(self, tokens: torch.Tensor):
    curr_out = self.embedding(tokens)
    if self.simple_pos_embed:
      curr_out = curr_out + self.pos_embedding(torch.arange(self.max_seqlen))
    for l in self.layers:
      curr_out = curr_out + l(curr_out)
    logits = self.output_layer(curr_out)  # bsz,seqlen,vocab_size
    return logits


# Test code

In [211]:
tblock_test = TransformerBlock(d_model=D_MODEL, d_ffn=D_FFN, num_heads=NUM_HEADS, max_seqlen=SEQLEN)
input_act = torch.randn(BSZ, SEQLEN, D_MODEL)
tblock_out = tblock_test(input_act)
assert tblock_out.shape == (BSZ, SEQLEN, D_MODEL)

# Dataloader

In [212]:
EOS_TOKEN=-1

class DigitDataset(IterableDataset):
  def __init__(self, num_digits=10):
    self.num_digits = 10

  def __iter__(self):
    while True:
      for i in range(self.num_digits):
        yield i


def collate_batch_digits(batch, seqlen, bsz):
  x = torch.tensor(batch, dtype=torch.long).view(bsz, seqlen)
  y = torch.roll(x, shifts=-1, dims=(1,))
  y[:,-1] = EOS_TOKEN
  return x,y

# dataset that spits out 8 unsorted digits, then sorts them all

class SortedDigitsDataset(IterableDataset):
  # outputs a full batch
  def __init__(self, num_digits=10, max_seqlen=8, eos_token=EOS_TOKEN):
    self.num_digits = 10
    self.max_seqlen = max_seqlen
    self.eos_token = eos_token

  def __iter__(self):
    while True:
      half = int(self.max_seqlen//2)
      half_batch = torch.randint(low=0,high=(self.num_digits-1), size=(half,))
      sorted = torch.sort(half_batch).values
      batch = torch.cat([half_batch, sorted])
      targets = torch.roll(batch, shifts=-1)
      targets[:int(half)-1] = self.eos_token
      targets[-1] = self.eos_token
      #print(f"batch:{batch}\n targets:{targets}\n")
      yield (batch, targets)


# Dataloader test

In [202]:
sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)

an_iter = iter(sorted_digits_dataset)
while True:
  print(next(an_iter))
  break

batch:tensor([8, 1, 8, 3, 2, 0, 3, 1, 0, 1, 1, 2, 3, 3, 8, 8])
 targets:tensor([-1, -1, -1, -1, -1, -1, -1,  0,  1,  1,  2,  3,  3,  8,  8, -1])

(tensor([8, 1, 8, 3, 2, 0, 3, 1, 0, 1, 1, 2, 3, 3, 8, 8]), tensor([-1, -1, -1, -1, -1, -1, -1,  0,  1,  1,  2,  3,  3,  8,  8, -1]))


# Trainer

In [217]:
STEPS=5001
D_MODEL=16
D_FFN=64
NUM_LAYERS=2
VOCAB_SIZE=16
SEQLEN=16
BSZ=32
NUM_HEADS=4
HEAD_DIM=D_MODEL//NUM_HEADS

model = Transformer(d_model=D_MODEL, d_ffn=D_FFN, vocab_size=VOCAB_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, max_seqlen=SEQLEN, simple_pos_embed=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001, weight_decay=0)

#digits = DigitDataset()
#digit_dataloader = DataLoader(dataset=digits, batch_size=SEQLEN*BSZ, shuffle=False, collate_fn=lambda b:collate_batch_digits(batch=b, seqlen=SEQLEN, bsz=BSZ))
#dl_iter = iter(digit_dataloader)


sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)
dl_iter = iter(sorted_dataloader)

for step in range(STEPS):
  optimizer.zero_grad()
  tokens, targets = next(dl_iter) #targets = (bsz, seqlen)   tokens=(bsz, seqlen)
  logits = model(tokens)  # (bsz, seqlen, vocab_size)
  vocab_size = logits.shape[-1]
  loss = torch.nn.functional.cross_entropy(input=logits.view(-1, vocab_size), target=targets.view(-1), ignore_index=EOS_TOKEN)
  if step %100 == 0:
    print(f"{step=} {loss=}")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n\n")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n probs:\n{logits[0,:,:]}\n")
  loss.backward()
  optimizer.step()




step=0 loss=tensor(26.004, grad_fn=<NllLossBackward0>)
step=100 loss=tensor(1.815, grad_fn=<NllLossBackward0>)
step=200 loss=tensor(1.312, grad_fn=<NllLossBackward0>)
step=300 loss=tensor(1.234, grad_fn=<NllLossBackward0>)
step=400 loss=tensor(1.178, grad_fn=<NllLossBackward0>)
step=500 loss=tensor(1.153, grad_fn=<NllLossBackward0>)
step=600 loss=tensor(1.126, grad_fn=<NllLossBackward0>)
step=700 loss=tensor(1.037, grad_fn=<NllLossBackward0>)
step=800 loss=tensor(0.958, grad_fn=<NllLossBackward0>)
step=900 loss=tensor(0.934, grad_fn=<NllLossBackward0>)
step=1000 loss=tensor(0.911, grad_fn=<NllLossBackward0>)
step=1100 loss=tensor(0.839, grad_fn=<NllLossBackward0>)
step=1200 loss=tensor(0.855, grad_fn=<NllLossBackward0>)
step=1300 loss=tensor(0.766, grad_fn=<NllLossBackward0>)
step=1400 loss=tensor(0.700, grad_fn=<NllLossBackward0>)
step=1500 loss=tensor(0.612, grad_fn=<NllLossBackward0>)
step=1600 loss=tensor(0.513, grad_fn=<NllLossBackward0>)
step=1700 loss=tensor(0.399, grad_fn=<NllL

In [None]:
# generation
# dumb way: 1 token at a time. get logits, do softmax, sample token. feed in full generated thing back

# kv cache way: attention takes in a kv cache which are pre-computed k and v values for previous tokens



# Simple generation

In [113]:
def softmax_with_temp_last_dim(logits: torch.Tensor, temp: float=1.0, eps=1e-6):
  # softmax across the innermost (last) dimension
  # compute max
  # subtract max
  # numerator = exponentiate with temperature
  # denominator = sum of numerator
  #import pdb; pdb.set_trace()
  # logits: (..., vocab_size)
  maxval = torch.amax(logits, dim=-1).unsqueeze(-1)
  max_subtracted_logits = logits - maxval
  scaled_max_subtracted_logits = max_subtracted_logits/torch.Tensor([temp+eps])
  numerator = torch.exp(scaled_max_subtracted_logits)
  denominator = torch.sum(numerator, dim=-1).unsqueeze(-1)
  return numerator/(denominator + eps)


def retain_top_p(probs: torch.Tensor, p: float):
  # retain only probabilities with mass = top_p
  assert len(probs.shape) == 1 # batch later
  sorted_probs = torch.sort(probs, descending=True)
  sorted_probs_values = sorted_probs.values
  sorted_probs_indices = sorted_probs.indices
  cumsum_probs = list(torch.cumsum(sorted_probs_values, dim=0))
  for i, prob in enumerate(cumsum_probs):  # super ugly, need vectorized version
    if prob > p:
      break
  not_useful_indices = sorted_probs_indices[(i+1):]
  if len(not_useful_indices) > 0:
    probs[not_useful_indices] = 0 # TODO check here
  return probs



def sample(logits: torch.Tensor, top_p: Optional[float]=None, temp: float = 1.0):
  # logits: (vocab_size,)
  probs = softmax_with_temp_last_dim(logits=logits, temp=temp)
  if top_p:
    probs = retain_top_p(probs=probs, p=top_p)
  sampled = np.random.multinomial(n=1, pvals=probs)
  return int(np.argwhere(sampled)[0][0])


def generate(m: torch.nn.Module, input_tokens: torch.Tensor, toks_to_generate: int):
  # for each token, run forward
    # run forward get logits
    # convert logits to probabilities
    # sample from probabilities
    # update input tokens
  pass

# Simple tests

In [120]:
logits = torch.tensor([5,1,1,5,1])
sample(logits, temp=0)


3

# Sample

In [62]:
probs

tensor([0.487, 0.009, 0.009, 0.487, 0.009])

In [57]:
int(np.argwhere(np.random.multinomial(n=1, pvals=probs))[0][0])

3