<a href="https://colab.research.google.com/github/kmalik22/colabs/blob/main/transformer_training_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [298]:
import torch
import numpy as np
import torch.nn.functional as F
import itertools
from typing import Optional, Tuple, List
from torch.utils.data import IterableDataset, DataLoader
torch.set_printoptions(linewidth=250, precision=3)

In [398]:
a = torch.arange(24).reshape(2,3,4)
b = torch.arange(100,124).reshape(2,3,4)
print(a)
print(b)

torch.cat([a, b], dim=1) #expected shape: 2,6,4. 0-11. then 100-111

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[[100, 101, 102, 103],
         [104, 105, 106, 107],
         [108, 109, 110, 111]],

        [[112, 113, 114, 115],
         [116, 117, 118, 119],
         [120, 121, 122, 123]]])


tensor([[[  0,   1,   2,   3],
         [  4,   5,   6,   7],
         [  8,   9,  10,  11],
         [100, 101, 102, 103],
         [104, 105, 106, 107],
         [108, 109, 110, 111]],

        [[ 12,  13,  14,  15],
         [ 16,  17,  18,  19],
         [ 20,  21,  22,  23],
         [112, 113, 114, 115],
         [116, 117, 118, 119],
         [120, 121, 122, 123]]])

# Modules

In [417]:
class MLP(torch.nn.Module):
  def __init__(self, d_model, d_ffn):
    super().__init__()
    self.d_model = d_model
    self.d_ffn = d_ffn
    self.fc1 = torch.nn.Linear(in_features=d_model, out_features=d_ffn, bias=False)
    torch.nn.init.trunc_normal_(self.fc1.weight, mean=0, std=1/d_model**0.5)
    self.fc2 = torch.nn.Linear(in_features=d_ffn, out_features=d_model, bias=False)
    torch.nn.init.trunc_normal_(self.fc2.weight, mean=0, std=1/d_ffn**0.5, a=-3, b=3)

  def forward(self, x: torch.Tensor):
    fc1_out = self.fc1(x)
    relu_out = torch.nn.functional.relu(fc1_out)
    fc2_out = self.fc2(relu_out)
    return fc2_out

class MultiHeadAttention(torch.nn.Module):
  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model//num_heads
    self.wqkv = torch.nn.Linear(d_model, int(3*d_model), bias=False)
    torch.nn.init.trunc_normal_(self.wqkv.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    self.wo = torch.nn.Linear(d_model, d_model)
    torch.nn.init.trunc_normal_(self.wo.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    # TODO: add rope

  def scaled_dot_product_attn(self, q, k, v):
    # q = (bsz, num_heads, gen_seqlen, head_dim)
    # ctx_seqle+gen_seqlen = total_seqlen
    # k, v = (bsz, num_heads, total_seqlen, head_dim)
    # qk.t()/sqrt(head_dim)
    # causal mask
    # softmax
    # @ v
    gen_seqlen = q.shape[2]
    total_seqlen = k.shape[2]
    assert gen_seqlen <= total_seqlen
    attn_wts = q @ k.transpose(2,3)  #(bsz,num_h,gen_seqlen,head_dim) (bsz,num_h,head_dim,total_seqlen) -> (bsz,num_h,gen_seqlen,total_seqlen)
    # create mask, do torch.where. 1=use attn wts, else use -inf
    mask = torch.tril(torch.ones(total_seqlen,total_seqlen)).to(torch.bool)
    mask = mask[(total_seqlen - gen_seqlen):,:]
    masked_attn_wts = torch.where(mask, attn_wts, float('-inf')) # (bsz,num_heads,gen_seqlen,total_seqlen)

    # softmax TODO check once
    softmax_wts = torch.nn.functional.softmax(masked_attn_wts,dim=-1)
    # (bsz,num_heads,gen_seqlen,total_seqlen) @ (bsz,num_heads,total_seqlen,head_dim) -> (bsz,num_heads,gen_seqlen,head_dim)
    #print(f"{softmax_wts.shape=}  {v.shape=}")
    return softmax_wts @ v


  def forward(self, x:torch.Tensor, kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]]=None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    #x = (bsz, gen_seqlen, d_model)
    bsz, gen_seqlen, d_model = x.shape
    assert d_model == self.d_model
    if kv_cache is not None:
      assert kv_cache[0].shape[0] == bsz
      assert kv_cache[0].shape[3] == self.head_dim, f"{kv_cache[0].shape[3]=} {self.head_dim=}"
      assert kv_cache[0].shape[1] == self.num_heads, f"{kv_cache[0].shape[2]=} {self.num_heads=}"
      assert kv_cache[0].shape == kv_cache[1].shape
      ctx_seqlen = kv_cache[0].shape[1]
    else:
      ctx_seqlen = gen_seqlen
    wqkv_out = self.wqkv(x) # (bsz, gen_seqlen, 3*d_model)
    #print(f"wqkv_out:{wqkv_out[0,0,:].norm()}")
    # rearrange so seqlen,head_dim are the last two dims
    wqkv_out = wqkv_out.reshape(bsz, gen_seqlen, 3, self.num_heads, self.head_dim)
    #(bsz,seqlen,3,num_heads,head_dim) -> (bsz,3,num_heads,seqlen,head_dim)
    #  0    1    2    3        4
    wqkv_out = wqkv_out.permute(0, 2, 3, 1, 4)  # (bsz, 3, num_heads, seqlen, head_dim)
    q = wqkv_out[:,0,:,:,:] # (bsz, num_heads, gen_seqlen, head_dim)
    k = wqkv_out[:,1,:,:,:] # (bsz, num_heads, gen_seqlen, head_dim)
    v = wqkv_out[:,2,:,:,:] # (bsz, num_heads, gen_seqlen, head_dim)
    if kv_cache is not None:
      #print(f"{k.shape=} {kv_cache[0].shape=} final_k.shape={torch.cat([kv_cache[0], k], dim=2).shape}")
      k = torch.cat([kv_cache[0], k], dim=2)
      v = torch.cat([kv_cache[1], v], dim=2)

    #print(f"q first token: {q[0,:,0,:].norm()}  q second token: {q[0,:,1,:].norm()}")
    #print(f"k first token: {k[0,:,0,:].norm()}  k second token: {k[0,:,1,:].norm()}")
    #print(f"v first token: {v[0,:,0,:].norm()}  v second token: {v[0,:,1,:].norm()}")

    # now k and v are (bsx, num_hads, (ctx_seqlen+gen_seqlen), head_dim)
    # TODO: add rope to q and k
    #print(f"q:{q[0,:,0,:].norm()}, k:{k[0,:,0,:].norm()}, v:{v[0,:,0,:].norm()}")
    self_attn_out = self.scaled_dot_product_attn(q, k, v) #(bsz, num_heads, gen_seqlen, head_dim)
    #print(f"self_attn_out first token:{self_attn_out[0,:,0,:].norm()} self_attn_out second token:{self_attn_out[0,:,1,:].norm()}")
    # transpose
    self_attn_out = self_attn_out.transpose(1,2) # (bsz,gen_seqlen,num_heads,head_dim)
    # concat the last dim
    self_attn_out = self_attn_out.reshape(bsz,gen_seqlen,-1)
    # wo
    self_attn_out = self_attn_out.reshape(bsz, gen_seqlen, d_model)
    wo_out = self.wo(self_attn_out)
    return wo_out, (k, v)


class TransformerBlock(torch.nn.Module):
  def __init__(self, d_model, d_ffn, num_heads, max_seqlen):
    super().__init__()
    self.mlp_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.mlp = MLP(d_model=d_model, d_ffn=d_ffn)
    self.attn_norm = torch.nn.RMSNorm(normalized_shape=d_model)
    self.attn = MultiHeadAttention(d_model=d_model,num_heads=num_heads)


  def forward(self, x:torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # norm --> attn, add back to original --> norm --> mlp, add back
    attn_norm_out = self.attn_norm(x)
    #print(f"attn_norm_out: {attn_norm_out[0,0,:]}")
    attn_out, (k, v) = self.attn(attn_norm_out)
    #print(f"attn_out: {attn_out[0,0,:]}")
    attn_out_post_resid = x + attn_out
    mlp_norm_out = self.mlp_norm(attn_out_post_resid)
    mlp_out = self.mlp(mlp_norm_out)
    return mlp_out + attn_out_post_resid, (k, v)

class Transformer(torch.nn.Module):
  def __init__(self, d_model, d_ffn, vocab_size, num_layers, max_seqlen, num_heads, simple_pos_embed=False):
    super().__init__()
    self.max_seqlen = max_seqlen
    self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    torch.nn.init.trunc_normal_(self.embedding.weight, mean=0, std=1/d_model**0.5, a=-3, b=-3)

    self.simple_pos_embed = simple_pos_embed
    if simple_pos_embed:
      self.pos_embedding = torch.nn.Embedding(num_embeddings=max_seqlen, embedding_dim=d_model)

    self.output_layer = torch.nn.Linear(in_features=d_model, out_features=vocab_size)
    torch.nn.init.trunc_normal_(self.output_layer.weight, mean=0, std=1/d_model**0.5, a=-3, b=3)
    self.layers = torch.nn.ModuleList(
        [TransformerBlock(d_model=d_model, d_ffn=d_ffn, num_heads=num_heads, max_seqlen=max_seqlen) for _ in range(num_layers)]
    )
    # TODO: do proper init
    # add attention
    # add layer norms

  def forward(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    # tokens: (bsz,seqlen,d_model)
    kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    seqlen = tokens.shape[1]
    curr_out = self.embedding(tokens)
    if self.simple_pos_embed:
      curr_out = curr_out + self.pos_embedding(torch.arange(seqlen))
    for l in self.layers:
      layer_out, (k, v) = l(curr_out)
      curr_out = curr_out + layer_out
      kv_cache.append((k,v))
    logits = self.output_layer(curr_out)  # bsz,seqlen,vocab_size
    return logits, kv_cache


# Test code: transformer block is causal

In [418]:
with torch.no_grad():
  tblock_test = TransformerBlock(d_model=D_MODEL, d_ffn=D_FFN, num_heads=NUM_HEADS, max_seqlen=SEQLEN)
  input_act = torch.randn(BSZ, SEQLEN, D_MODEL)
  tblock_out, _ = tblock_test(input_act)
  assert tblock_out.shape == (BSZ, SEQLEN, D_MODEL)

  input_act_trunc = input_act[:,:2,:]
  tblock_out2, _ = tblock_test(input_act_trunc)

  assert torch.all(torch.isclose(tblock_out[0,0,:], tblock_out2[0,0,:]))
  print("Logits match")

Logits match


# Test: MHA with KV cache works same as MHA

In [419]:
# step 1: just return KV values
mha_test = MultiHeadAttention(d_model=D_MODEL, num_heads=NUM_HEADS)
input_act = torch.randn(BSZ, SEQLEN, D_MODEL)
mha_out, kv_cache = mha_test(input_act)
#print(f"input second token:{input_act[0,1,:].norm()} second token: {mha_out[0,1,:].norm()}")


# step 2: eat first token in input. feed this token in as kv-cache instead
input_act2 = input_act[:,1:,:]
# kv_cache: (bsz,num_heads,seqlen,head_dim)
kv_first_token = kv_cache[0][:,:,:1,:], kv_cache[1][:,:,:1,:]
mha_out2, kv_cache2 = mha_test(input_act2, kv_cache=kv_first_token)
#print(f"input second token:{input_act2[0,0,:].norm()} second token: {mha_out2[0,0,:].norm()}")

assert torch.all(torch.isclose(mha_out[0,1,:], mha_out2[0,0,:]))
print("Output token logits match")


Output token logits match


# Dataloader

In [420]:
EOS_TOKEN=-1

class DigitDataset(IterableDataset):
  def __init__(self, num_digits=10):
    self.num_digits = 10

  def __iter__(self):
    while True:
      for i in range(self.num_digits):
        yield i


def collate_batch_digits(batch, seqlen, bsz):
  x = torch.tensor(batch, dtype=torch.long).view(bsz, seqlen)
  y = torch.roll(x, shifts=-1, dims=(1,))
  y[:,-1] = EOS_TOKEN
  return x,y

# dataset that spits out 8 unsorted digits, then sorts them all

class SortedDigitsDataset(IterableDataset):
  # outputs a full batch
  def __init__(self, num_digits=10, max_seqlen=8, eos_token=EOS_TOKEN):
    self.num_digits = 10
    self.max_seqlen = max_seqlen
    self.eos_token = eos_token

  def __iter__(self):
    while True:
      half = int(self.max_seqlen//2)
      half_batch = torch.randint(low=0,high=(self.num_digits-1), size=(half,))
      sorted = torch.sort(half_batch).values
      batch = torch.cat([half_batch, sorted])
      targets = torch.roll(batch, shifts=-1)
      targets[:int(half)-1] = self.eos_token
      targets[-1] = self.eos_token
      #print(f"batch:{batch}\n targets:{targets}\n")
      yield (batch, targets)


# Dataloader test

In [421]:
sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)

an_iter = iter(sorted_digits_dataset)
while True:
  print(next(an_iter))
  break

(tensor([0, 0, 0, 7, 7, 1, 7, 6, 0, 0, 0, 1, 6, 7, 7, 7]), tensor([-1, -1, -1, -1, -1, -1, -1,  0,  0,  0,  1,  6,  7,  7,  7, -1]))


# Constants

In [422]:
D_MODEL=16
D_FFN=64
NUM_LAYERS=2
VOCAB_SIZE=16
SEQLEN=16
BSZ=32
NUM_HEADS=4
HEAD_DIM=D_MODEL//NUM_HEADS

# Dataloader

In [423]:
sorted_digits_dataset = SortedDigitsDataset(max_seqlen=SEQLEN)
sorted_dataloader = DataLoader(dataset=sorted_digits_dataset, batch_size=BSZ, shuffle=False)
dl_iter = iter(sorted_dataloader)

# Trainer

In [424]:

STEPS=5001

torch.manual_seed(100)
model = Transformer(d_model=D_MODEL, d_ffn=D_FFN, vocab_size=VOCAB_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, max_seqlen=SEQLEN, simple_pos_embed=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001, weight_decay=0)

#digits = DigitDataset()
#digit_dataloader = DataLoader(dataset=digits, batch_size=SEQLEN*BSZ, shuffle=False, collate_fn=lambda b:collate_batch_digits(batch=b, seqlen=SEQLEN, bsz=BSZ))
#dl_iter = iter(digit_dataloader)


for step in range(STEPS):
  optimizer.zero_grad()
  tokens, targets = next(dl_iter) #targets = (bsz, seqlen)   tokens=(bsz, seqlen)
  logits, _ = model(tokens)  # (bsz, seqlen, vocab_size)
  vocab_size = logits.shape[-1]
  loss = torch.nn.functional.cross_entropy(input=logits.view(-1, vocab_size), target=targets.view(-1), ignore_index=EOS_TOKEN)
  if step %100 == 0:
    print(f"{step=} {loss=}")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n\n")
    #print(f"tokens:{tokens[0,:]} \n targets:{targets[0,:]} \n probs:\n{logits[0,:,:]}\n")
  loss.backward()
  optimizer.step()




step=0 loss=tensor(26.213, grad_fn=<NllLossBackward0>)
step=100 loss=tensor(2.689, grad_fn=<NllLossBackward0>)
step=200 loss=tensor(1.623, grad_fn=<NllLossBackward0>)
step=300 loss=tensor(1.298, grad_fn=<NllLossBackward0>)
step=400 loss=tensor(1.338, grad_fn=<NllLossBackward0>)
step=500 loss=tensor(1.159, grad_fn=<NllLossBackward0>)
step=600 loss=tensor(1.160, grad_fn=<NllLossBackward0>)
step=700 loss=tensor(1.182, grad_fn=<NllLossBackward0>)
step=800 loss=tensor(1.132, grad_fn=<NllLossBackward0>)
step=900 loss=tensor(1.042, grad_fn=<NllLossBackward0>)
step=1000 loss=tensor(1.032, grad_fn=<NllLossBackward0>)
step=1100 loss=tensor(1.007, grad_fn=<NllLossBackward0>)
step=1200 loss=tensor(0.990, grad_fn=<NllLossBackward0>)
step=1300 loss=tensor(0.941, grad_fn=<NllLossBackward0>)
step=1400 loss=tensor(0.948, grad_fn=<NllLossBackward0>)
step=1500 loss=tensor(0.887, grad_fn=<NllLossBackward0>)
step=1600 loss=tensor(0.813, grad_fn=<NllLossBackward0>)
step=1700 loss=tensor(0.806, grad_fn=<NllL

In [425]:
# kv cache
# mha takes in kv_cache. if present, k and v values come from cache, PLUS add k and v for the current token
# wqkv operates on seqlen=1. q will one be 1


# Simple generation

In [426]:
def softmax_with_temp_last_dim(logits: torch.Tensor, temp: float=1.0, eps=1e-6):
  # softmax across the innermost (last) dimension
  # compute max
  # subtract max
  # numerator = exponentiate with temperature
  # denominator = sum of numerator
  #import pdb; pdb.set_trace()
  # logits: (..., vocab_size)
  maxval = torch.amax(logits, dim=-1).unsqueeze(-1)
  max_subtracted_logits = logits - maxval
  scaled_max_subtracted_logits = max_subtracted_logits/torch.Tensor([temp+eps])
  numerator = torch.exp(scaled_max_subtracted_logits)
  denominator = torch.sum(numerator, dim=-1).unsqueeze(-1)
  return numerator/(denominator + eps)


def retain_top_p(probs: torch.Tensor, p: float):
  # retain only probabilities with mass = top_p
  assert len(probs.shape) == 1 # batch later
  sorted_probs = torch.sort(probs, descending=True)
  sorted_probs_values = sorted_probs.values
  sorted_probs_indices = sorted_probs.indices
  cumsum_probs = list(torch.cumsum(sorted_probs_values, dim=0))
  for i, prob in enumerate(cumsum_probs):  # super ugly, need vectorized version
    if prob > p:
      break
  not_useful_indices = sorted_probs_indices[(i+1):]
  if len(not_useful_indices) > 0:
    probs[not_useful_indices] = 0 # TODO check here
  return probs



def sample(logits: torch.Tensor, top_p: Optional[float]=None, temp: float = 1.0):
  # logits: (vocab_size,)
  probs = softmax_with_temp_last_dim(logits=logits, temp=temp)
  #print(probs)
  if top_p:
    probs = retain_top_p(probs=probs, p=top_p)
  sampled = np.random.multinomial(n=1, pvals=probs)
  return int(np.argwhere(sampled)[0][0])


def generate(the_model: torch.nn.Module, input_tokens: torch.Tensor, toks_to_generate: int, temp=1.0):
  # for each token, run forward
    # run forward get logits
    # convert logits to probabilities
    # sample from probabilities
    # update input tokens
  generated_toks = []
  for _ in range(toks_to_generate):
    logits, _ = the_model(input_tokens)
    #print(f"logits.shape:{logits.shape}")
    #print(f"full_logits:{logits}")
    logits = logits.squeeze(0)[-1,:] #unsqueeze to remove batch dim. -1 for last token
    #print(f"logits:{logits}")
    sampled_token = sample(logits, temp=temp)
    generated_toks.append(sampled_token)
    #print(input_tokens.shape)
    #print(torch.tensor([sampled_token]).unsqueeze(0).shape)
    input_tokens = torch.cat([input_tokens.squeeze(0), torch.tensor([sampled_token])]).unsqueeze(0)
  return generated_toks

# Simple tests

In [427]:
input, targets = next(dl_iter) #get the first entry only
input = input[:1,:]
targets = targets[:1,:]
print(input)
print(targets)
trunc_input = input[:1,:int(SEQLEN//2)]
print(trunc_input)

tensor([[4, 0, 8, 0, 8, 2, 4, 4, 0, 0, 2, 4, 4, 4, 8, 8]])
tensor([[-1, -1, -1, -1, -1, -1, -1,  0,  0,  2,  4,  4,  4,  8,  8, -1]])
tensor([[4, 0, 8, 0, 8, 2, 4, 4]])


In [428]:
logits_raw, _ = model(input)
print(f"raw_logits:{logits_raw}")
print(f"logits for last:{logits_raw.squeeze(0)[7,:]}")

raw_logits:tensor([[[ -8.898, -15.959,  -1.909,  16.486,  30.774,  -1.500,  -7.926, -21.109, -37.579,   0.783, -13.763,  -6.702,   0.223, -23.742, -11.606, -13.094],
         [ 62.565,  12.333, -24.260, -21.372, -27.307, -25.315, -41.662, -17.038, -15.714,  12.552,  15.028,   8.288,   2.274,   5.355, -12.679,  23.037],
         [ 34.121,  20.171, -16.655,  -9.214, -10.213,  -6.413, -34.274, -19.019,   5.486,   8.353,  17.150,   6.676,  -1.785,   8.697,  -6.161,  10.700],
         [ 51.282,  24.071, -26.236, -13.121, -26.347, -27.381, -35.450, -21.988,  -5.609,   5.431,  15.267,   3.178,  -4.136,   4.262, -19.274,  10.499],
         [-19.772, -11.509, -15.624,   1.613,  20.270,  28.447,   2.317, -12.397,   1.911,  -5.104,  -4.801,  -3.112, -15.362,  -9.443,  -7.918, -10.321],
         [ 34.789,   5.031,  -1.144, -14.085,  -1.573,  -8.980, -31.694, -24.887, -25.228,  14.106,   6.898,  15.158,   7.803,   7.606,  -1.116,  14.085],
         [ 35.788,   1.865, -19.196, -20.822,   1.793, -10.

In [429]:
# test for sampling functions
#logits = torch.tensor([5,1,1,5,1])
#sample(logits, temp=0)

with torch.no_grad():
  # test for full generation
  input = next(dl_iter)[0][0,:].unsqueeze(0) #get the first entry only
  trunc_input = input[:1,:int(SEQLEN//2)]
  generated = generate(the_model=model, input_tokens=trunc_input, toks_to_generate=8, temp=0.01)
  print(f"input:{input}, expected_op:{input[:1, int(SEQLEN//2):]},  generated:{generated}")
  assert torch.all(torch.isclose(input[:1, int(SEQLEN//2):], torch.tensor(generated)))
  print(f"Output is sorted")


input:tensor([[6, 6, 4, 1, 8, 0, 3, 5, 0, 1, 3, 4, 5, 6, 6, 8]]), expected_op:tensor([[0, 1, 3, 4, 5, 6, 6, 8]]),  generated:[0, 1, 3, 4, 5, 6, 6, 8]
Output is sorted


# Sample

In [318]:
model

Transformer(
  (embedding): Embedding(16, 16)
  (pos_embedding): Embedding(16, 16)
  (output_layer): Linear(in_features=16, out_features=16, bias=True)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (mlp_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=16, out_features=64, bias=False)
        (fc2): Linear(in_features=64, out_features=16, bias=False)
      )
      (attn_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (wqkv): Linear(in_features=16, out_features=48, bias=False)
        (wo): Linear(in_features=16, out_features=16, bias=True)
      )
    )
  )
)

In [57]:
int(np.argwhere(np.random.multinomial(n=1, pvals=probs))[0][0])

3

In [297]:
model

Transformer(
  (embedding): Embedding(16, 16)
  (pos_embedding): Embedding(16, 16)
  (output_layer): Linear(in_features=16, out_features=16, bias=True)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (mlp_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=16, out_features=64, bias=False)
        (fc2): Linear(in_features=64, out_features=16, bias=False)
      )
      (attn_norm): RMSNorm((16,), eps=None, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (wqkv): Linear(in_features=16, out_features=48, bias=False)
        (wo): Linear(in_features=16, out_features=16, bias=True)
      )
    )
  )
)