<a href="https://colab.research.google.com/github/kmalik22/colabs/blob/main/transformer_training_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [125]:
import torch
import numpy as np
import torch.nn.functional as F
import itertools
from typing import Optional
from torch.utils.data import IterableDataset, DataLoader
torch.set_printoptions(linewidth=250, precision=3)

# Modules

In [271]:
class MLP(torch.nn.Module):
  def __init__(self, d_model, d_ffn):
    super().__init__()
    self.d_model = d_model
    self.d_ffn = d_ffn
    self.fc1 = torch.nn.Linear(in_features=d_model, out_features=d_ffn, bias=False)
    torch.nn.init.trunc_normal_(self.fc1.weight, mean=0, std=1/d_model**0.5)
    self.fc2 = torch.nn.Linear(in_features=d_ffn, out_features=d_model, bias=False)
    torch.nn.init.trunc_normal_(self.fc2.weight, mean=0, std=1/d_ffn**0.5, a=-3, b=3)

  def forward(self, x: torch.Tensor):
    fc1_out = self.fc1(x)
    relu_out = torch.nn.functional.relu(fc1_out)
    fc2_out = self.fc2(relu_out)
    return fc2_out

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model//num_heads
    self.wqkv = torch.nn.Linear(d_model, int(3*d_model), bias=False)
    torch.nn.init.trunc_normal_(self.wqkv.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    self.wo = torch.nn.Linear(d_model, d_model)
    torch.nn.init.trunc_normal_(self.wo.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    # TODO: add rope

  def scaled_dot_product_attn(self, q, k, v):
    # q, k, v = (bsz, num_heads, seqlen, head_dim)
    # qk.t()/sqrt(head_dim)
    # causal mask
    # softmax
    # @ v
    seqlen = q.shape[2]
    attn_wts = q @ k.transpose(2,3)  #(bsz,num_h,seqlen,head_dim) (bsz,num_h,head_dim,seqlen) -> (bsz,num_h,seqlen,seqlen)
    # create mask, do torch.where. 1=use attn wts, else use -inf
    mask = torch.tril(torch.ones(seqlen,seqlen)).to(torch.bool)
    masked_attn_wts = torch.where(mask, attn_wts, float('-inf')) # (bsz,num_heads,seqlen,seqlen)

    # softmax TODO check once
    softmax_wts = torch.nn.functional.softmax(masked_attn_wts,dim=-1)
    # (bsz,num_heads,seqlen,seqlen) @ (bsz,num_heasds,seqlen,head_dim) -> (bsz,num_heads,seqlen,head_dim)
    return softmax_wts @ v


  def forward(self, x:torch.Tensor):
    #x = (bsz, seqlen, d_model)
    bsz, seqlen, d_model = x.shape
    assert d_model == self.d_model
    wqkv_out = self.wqkv(x) # (bsz, seqlen, 3*d_model)
    #print(f"wqkv_out:{wqkv_out[0,0,:].norm()}")
    # rearrange so seqlen,head_dim are the last two dims
    wqkv_out = wqkv_out.reshape(bsz, seqlen, 3, self.num_heads, self.head_dim)
    #(bsz,seqlen,3,num_heads,head_dim) -> (bsz,3,num_heads,seqlen,head_dim)
    #  0    1    2    3        4
    wqkv_out = wqkv_out.permute(0, 2, 3, 1, 4)  # (bsz, 3, num_heads, seqlen, head_dim)
    q = wqkv_out[:,0,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    k = wqkv_out[:,1,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    v = wqkv_out[:,2,:,:,:] # (bsz, num_heads, seqlen, head_dim)
    # TODO: add rope to q and k
    #print(f"q:{q[0,:,0,:].norm()}, k:{k[0,:,0,:].norm()}, v:{v[0,:,0,:].norm()}")
    self_attn_out = self.scaled_dot_product_attn(q, k, v) #(bsz, num_heads, seqlen, head_dim)
    #print(f"self_attn_out:{self_attn_out[0,:,0,:].norm()}")
    # transpose
    self_attn_out = self_attn_out.transpose(1,2) # (bsz,seqlen,num_heads,head_dim)
    # concat the last dim
    self_attn_out = self_attn_out.reshape(bsz,seqlen,-1)
    # wo
    self_attn_out = self_attn_out.reshape(bsz, seqlen, d_model)
    wo_out = self.wo(self_attn_out)
    return wo_out


class TransformerBlock(torch.nn.Module):
  def __init__(self, d_model, d_ffn, num_heads, max_seqlen):
    super().__init__()
    self.mlp_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.mlp = MLP(d_model=d_model, d_ffn=d_ffn)
    self.attn_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.attn = MultiHeadAttention(d_model=d_model,num_heads=num_heads)


  def forward(self, x:torch.Tensor):
    # norm --> attn, add back to original --> norm --> mlp, add back
    attn_norm_out = self.attn_norm(x)
    #print(f"attn_norm_out: {attn_norm_out[0,0,:]}")
    attn_out = self.attn(attn_norm_out)
    #print(f"attn_out: {attn_out[0,0,:]}")
    attn_out_post_resid = x + attn_out
    mlp_norm_out = self.mlp_norm(attn_out_post_resid)
    mlp_out = self.mlp(mlp_norm_out)
    return mlp_out + attn_out_post_resid

class Transformer(torch.nn.Module):
  def __init__(self, d_model, d_ffn, vocab_size, num_layers, max_seqlen, num_heads, simple_pos_embed=False):
    super().__init__()
    self.max_seqlen = max_seqlen
    self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    torch.nn.init.trunc_normal_(self.embedding.weight, mean=0, std=1/d_model**0.5, a=-3, b=-3)

    self.simple_pos_embed = simple_pos_embed
    if simple_pos_embed:
      self.pos_embedding = torch.nn.Embedding(num_embeddings=max_seqlen, embedding_dim=d_model)

    self.output_layer = torch.nn.Linear(in_features=d_model, out_features=vocab_size)
    torch.nn.init.trunc_normal_(self.output_layer.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    self.layers = torch.nn.ModuleList(
        [TransformerBlock(d_model=d_model, d_ffn=d_ffn, num_heads=num_heads, max_seqlen=max_seqlen) for _ in range(num_layers)]
    )
    # TODO: do proper init
    # add attention
    # add layer norms

  def forward(self, tokens: torch.Tensor):
    # tokens: (bsz,seqlen,d_model)
    seqlen = tokens.shape[1]
    curr_out = self.embedding(tokens)
    if self.simple_pos_embed:
      curr_out = curr_out + self.pos_embedding(torch.arange(seqlen))
    for l in self.layers:
      curr_out = curr_out + l(curr_out)
    logits = self.output_layer(curr_out)  # bsz,seqlen,vocab_size
    return logits


# Test code

In [274]:
with torch.no_grad():
  tblock_test = TransformerBlock(d_model=D_MODEL, d_ffn=D_FFN, num_heads=NUM_HEADS, max_seqlen=SEQLEN)
  input_act = torch.randn(BSZ, SEQLEN, D_MODEL)
  tblock_out = tblock_test(input_act)
  assert tblock_out.shape == (BSZ, SEQLEN, D_MODEL)

  input_act_trunc = input_act[:,:2,:]
  tblock_out2 = tblock_test(input_act_trunc)

  assert torch.all(torch.isclose(tblock_out[0,0,:], tblock_out2[0,0,:]))
  print("Logits match")

Logits match


# Dataloader

In [275]:
EOS_TOKEN=-1

class DigitDataset(IterableDataset):
  def __init__(self, num_digits=10):
    self.num_digits = 10

  def __iter__(self):
    while True:
      for i in range(self.num_digits):
        yield i


def collate_batch_digits(batch, seqlen, bsz):
  x = torch.tensor(batch, dtype=torch.long).view(bsz, seqlen)
  y = torch.roll(x, shifts=-1, dims=(1,))
  y[:,-1] = EOS_TOKEN
  return x,y

# dataset that spits out 8 unsorted digits, then sorts them all

class SortedDigitsDataset(IterableDataset):
  # outputs a full batch
  def __init__(self, num_digits=10, max_seqlen=8, eos_token=EOS_TOKEN):
    self.num_digits = 10
    self.max_seqlen = max_seqlen
    self.eos_token = eos_token

  def __iter__(self):
    while True:
      half = int(self.max_seqlen//2)
      half_batch = torch.randint(low=0,high=(self.num_digits-1), size=(half,))
      sorted = torch.sort(half_batch).values
      batch = torch.cat([half_batch, sorted])
      targets = torch.roll(batch, shifts=-1)
      targets[:int(half)-1] = self.eos_token
      targets[-1] = self.eos_token
      #print(f"batch:{batch}\n targets:{targets}\n")
      yield (batch, targets)


# Dataloader test

In [276]:
sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)

an_iter = iter(sorted_digits_dataset)
while True:
  print(next(an_iter))
  break

(tensor([1, 4, 4, 6, 1, 7, 1, 5, 1, 1, 1, 4, 4, 5, 6, 7]), tensor([-1, -1, -1, -1, -1, -1, -1,  1,  1,  1,  4,  4,  5,  6,  7, -1]))


# Constants

In [277]:
D_MODEL=16
D_FFN=64
NUM_LAYERS=2
VOCAB_SIZE=16
SEQLEN=16
BSZ=32
NUM_HEADS=4
HEAD_DIM=D_MODEL//NUM_HEADS

# Dataloader

In [278]:
sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)
dl_iter = iter(sorted_dataloader)

# Trainer

In [279]:

STEPS=5001

model = Transformer(d_model=D_MODEL, d_ffn=D_FFN, vocab_size=VOCAB_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, max_seqlen=SEQLEN, simple_pos_embed=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001, weight_decay=0)

#digits = DigitDataset()
#digit_dataloader = DataLoader(dataset=digits, batch_size=SEQLEN*BSZ, shuffle=False, collate_fn=lambda b:collate_batch_digits(batch=b, seqlen=SEQLEN, bsz=BSZ))
#dl_iter = iter(digit_dataloader)


for step in range(STEPS):
  optimizer.zero_grad()
  tokens, targets = next(dl_iter) #targets = (bsz, seqlen)   tokens=(bsz, seqlen)
  logits = model(tokens)  # (bsz, seqlen, vocab_size)
  vocab_size = logits.shape[-1]
  loss = torch.nn.functional.cross_entropy(input=logits.view(-1, vocab_size), target=targets.view(-1), ignore_index=EOS_TOKEN)
  if step %100 == 0:
    print(f"{step=} {loss=}")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n\n")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n probs:\n{logits[0,:,:]}\n")
  loss.backward()
  optimizer.step()




  return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)


step=0 loss=tensor(32.994, grad_fn=<NllLossBackward0>)
step=100 loss=tensor(2.026, grad_fn=<NllLossBackward0>)
step=200 loss=tensor(1.403, grad_fn=<NllLossBackward0>)
step=300 loss=tensor(1.323, grad_fn=<NllLossBackward0>)
step=400 loss=tensor(1.248, grad_fn=<NllLossBackward0>)
step=500 loss=tensor(1.168, grad_fn=<NllLossBackward0>)
step=600 loss=tensor(1.187, grad_fn=<NllLossBackward0>)
step=700 loss=tensor(1.166, grad_fn=<NllLossBackward0>)
step=800 loss=tensor(1.158, grad_fn=<NllLossBackward0>)
step=900 loss=tensor(1.108, grad_fn=<NllLossBackward0>)
step=1000 loss=tensor(1.106, grad_fn=<NllLossBackward0>)
step=1100 loss=tensor(1.066, grad_fn=<NllLossBackward0>)
step=1200 loss=tensor(1.047, grad_fn=<NllLossBackward0>)
step=1300 loss=tensor(1.003, grad_fn=<NllLossBackward0>)
step=1400 loss=tensor(0.988, grad_fn=<NllLossBackward0>)
step=1500 loss=tensor(0.954, grad_fn=<NllLossBackward0>)
step=1600 loss=tensor(0.940, grad_fn=<NllLossBackward0>)
step=1700 loss=tensor(0.803, grad_fn=<NllL

In [280]:
# generation
# dumb way: 1 token at a time. get logits, do softmax, sample token. feed in full generated thing back

# kv cache way: attention takes in a kv cache which are pre-computed k and v values for previous tokens



# Simple generation

In [286]:
def softmax_with_temp_last_dim(logits: torch.Tensor, temp: float=1.0, eps=1e-6):
  # softmax across the innermost (last) dimension
  # compute max
  # subtract max
  # numerator = exponentiate with temperature
  # denominator = sum of numerator
  #import pdb; pdb.set_trace()
  # logits: (..., vocab_size)
  maxval = torch.amax(logits, dim=-1).unsqueeze(-1)
  max_subtracted_logits = logits - maxval
  scaled_max_subtracted_logits = max_subtracted_logits/torch.Tensor([temp+eps])
  numerator = torch.exp(scaled_max_subtracted_logits)
  denominator = torch.sum(numerator, dim=-1).unsqueeze(-1)
  return numerator/(denominator + eps)


def retain_top_p(probs: torch.Tensor, p: float):
  # retain only probabilities with mass = top_p
  assert len(probs.shape) == 1 # batch later
  sorted_probs = torch.sort(probs, descending=True)
  sorted_probs_values = sorted_probs.values
  sorted_probs_indices = sorted_probs.indices
  cumsum_probs = list(torch.cumsum(sorted_probs_values, dim=0))
  for i, prob in enumerate(cumsum_probs):  # super ugly, need vectorized version
    if prob > p:
      break
  not_useful_indices = sorted_probs_indices[(i+1):]
  if len(not_useful_indices) > 0:
    probs[not_useful_indices] = 0 # TODO check here
  return probs



def sample(logits: torch.Tensor, top_p: Optional[float]=None, temp: float = 1.0):
  # logits: (vocab_size,)
  probs = softmax_with_temp_last_dim(logits=logits, temp=temp)
  print(probs)
  if top_p:
    probs = retain_top_p(probs=probs, p=top_p)
  sampled = np.random.multinomial(n=1, pvals=probs)
  return int(np.argwhere(sampled)[0][0])


def generate(the_model: torch.nn.Module, input_tokens: torch.Tensor, toks_to_generate: int, temp=1.0):
  # for each token, run forward
    # run forward get logits
    # convert logits to probabilities
    # sample from probabilities
    # update input tokens
  generated_toks = []
  for _ in range(toks_to_generate):
    logits = the_model(input_tokens)
    #print(f"logits.shape:{logits.shape}")
    #print(f"full_logits:{logits}")
    logits = logits.squeeze(0)[-1,:] #unsqueeze to remove batch dim. -1 for last token
    #print(f"logits:{logits}")
    sampled_token = sample(logits, temp=temp)
    generated_toks.append(sampled_token)
    #print(input_tokens.shape)
    #print(torch.tensor([sampled_token]).unsqueeze(0).shape)
    input_tokens = torch.cat([input_tokens.squeeze(0), torch.tensor([sampled_token])]).unsqueeze(0)
  return generated_toks


# Simple tests

In [282]:
input, targets = next(dl_iter) #get the first entry only
input = input[:1,:]
targets = targets[:1,:]
print(input)
print(targets)
trunc_input = input[:1,:int(SEQLEN//2)]
print(trunc_input)

tensor([[2, 6, 5, 1, 8, 7, 8, 8, 1, 2, 5, 6, 7, 8, 8, 8]])
tensor([[-1, -1, -1, -1, -1, -1, -1,  1,  2,  5,  6,  7,  8,  8,  8, -1]])
tensor([[2, 6, 5, 1, 8, 7, 8, 8]])


In [287]:
logits_raw = model(input)
print(f"raw_logits:{logits_raw}")
print(f"logits for last:{logits_raw.squeeze(0)[7,:]}")

raw_logits:tensor([[[ 9.250e-01,  6.369e+00,  4.578e+01,  1.682e+01,  1.213e+00, -3.247e+01, -2.743e+01, -5.083e+01, -3.037e+01, -4.624e+01, -1.019e+01, -8.336e+00, -2.496e+01, -3.601e+01, -5.231e+00, -3.079e+01],
         [ 7.895e-02, -4.665e+00, -8.957e-01, -1.325e+01,  7.796e+00,  3.988e+00,  1.104e+01, -1.361e+01, -6.082e+00, -2.768e+01, -1.301e+01, -1.166e+01, -2.218e+01, -9.998e+00, -8.620e+00, -3.163e+01],
         [-1.834e+01, -3.757e+00,  3.879e+00,  2.956e-01, -6.140e+00,  1.164e+01,  1.186e+01, -1.313e+01, -1.888e+01, -2.251e+01, -9.471e+00, -1.815e+01, -2.338e+01, -6.518e+00, -1.831e+01, -3.042e+01],
         [ 3.009e+00, -8.810e+00,  1.465e+01,  9.414e+00,  3.503e+00, -1.591e+01, -4.405e+00, -1.555e+01,  1.195e+00, -1.322e+01, -9.820e+00, -1.963e+00,  4.824e+00, -9.279e+00, -2.240e+00, -6.977e+00],
         [ 1.902e+01,  3.678e+01,  3.548e+01,  9.157e+00,  4.421e+00, -2.037e+01, -4.308e+01, -5.010e+01, -3.421e+01, -4.400e+01, -1.350e+01, -1.097e+01, -3.653e+01, -2.887e+01,

In [288]:
# test for sampling functions
#logits = torch.tensor([5,1,1,5,1])
#sample(logits, temp=0)

with torch.no_grad():
  # test for full generation
  #input = next(dl_iter)[0][0,:].unsqueeze(0) #get the first entry only
  generated = generate(the_model=model, input_tokens=trunc_input, toks_to_generate=8, temp=1)
  print(generated)


tensor([2.242e-08, 1.000e+00, 7.289e-06, 3.873e-14, 1.510e-10, 3.906e-18, 4.456e-35, 4.610e-30, 1.594e-24, 2.127e-28, 2.732e-18, 9.010e-21, 1.933e-27, 3.393e-22, 4.898e-22, 1.098e-28])
tensor([1.572e-11, 1.179e-04, 9.999e-01, 7.856e-06, 1.618e-08, 5.396e-13, 3.864e-22, 2.487e-30, 6.742e-24, 1.368e-29, 1.723e-16, 3.572e-20, 1.266e-24, 6.884e-18, 1.361e-20, 3.796e-27])
tensor([5.205e-15, 6.702e-09, 2.777e-10, 3.062e-09, 6.653e-04, 9.958e-01, 3.512e-03, 1.455e-12, 4.218e-15, 3.700e-21, 1.910e-13, 4.051e-16, 1.464e-20, 8.254e-09, 6.992e-15, 1.612e-23])
tensor([2.432e-17, 2.546e-15, 1.240e-25, 5.567e-26, 2.902e-13, 2.387e-03, 9.976e-01, 2.896e-05, 8.656e-08, 6.944e-20, 4.537e-20, 7.240e-21, 5.129e-24, 9.839e-15, 4.911e-17, 1.243e-22])
tensor([1.509e-24, 2.053e-25, 1.064e-32, 2.172e-31, 3.706e-21, 5.678e-14, 1.005e-02, 9.803e-01, 9.612e-03, 1.132e-25, 8.430e-29, 2.143e-25, 7.663e-27, 1.082e-22, 8.116e-19, 3.283e-22])
tensor([6.718e-30, 1.919e-32, 3.496e-38, 9.600e-36, 2.757e-24, 2.823e-22, 2

# Sample

In [62]:
probs

tensor([0.487, 0.009, 0.009, 0.487, 0.009])

In [57]:
int(np.argwhere(np.random.multinomial(n=1, pvals=probs))[0][0])

3

In [133]:
model

Transformer(
  (embedding): Embedding(16, 16)
  (pos_embedding): Embedding(16, 16)
  (output_layer): Linear(in_features=16, out_features=16, bias=True)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (mlp_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=16, out_features=64, bias=False)
        (fc2): Linear(in_features=64, out_features=16, bias=False)
      )
      (attn_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (wqkv): Linear(in_features=16, out_features=48, bias=False)
        (wo): Linear(in_features=16, out_features=16, bias=True)
      )
    )
  )
)