In [2]:
import random as r

from engine import Value
from modules import *
from ops import *
from gpt import *

In [23]:
# Small dimensions shared by the cells below.
batch_size = 2
vocab_len = 10  # not referenced in the visible cells — presumably the vocabulary size for the GPT code; confirm
model_dim = 4  # with the values below this equals num_heads * head_dim (2 * 2)
max_seq_len = 5  # not referenced in the visible cells — presumably passed to MultiHeadSelfAttention; confirm
seq_len = 3
num_heads = 2
head_dim = 2

In [8]:
class MultiHeadSelfAttention(Module):
    """Multi-head self-attention over nested-list tensors of ``Value`` objects.

    The input is projected into queries, keys and values, split into
    ``num_heads`` heads of width ``head_dim``, combined through masked,
    scaled dot-product attention, merged back into one vector per position
    and finally passed through the output projection ``Wo``.

    Args:
        model_dim: size of the last dimension of the input (and of the output).
        num_heads: number of attention heads.
        head_dim: width of each head.
        max_seq_len: size handed to ``Mask``.

    NOTE(review): this class does not define ``parameters()``. If ``Module``
    does not discover sub-modules automatically, the four ``Linear`` layers
    need to be exposed explicitly — confirm against ``modules.py``.
    """

    def __init__(self, model_dim, num_heads, head_dim, max_seq_len):
        # __call__ reads these sizes back from the instance, so they have to be
        # stored here (they were previously missing, which made every call fail
        # with AttributeError).
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        self.Wq = Linear(model_dim, num_heads * head_dim)
        self.Wk = Linear(model_dim, num_heads * head_dim)
        self.Wv = Linear(model_dim, num_heads * head_dim)
        self.Wo = Linear(num_heads * head_dim, model_dim)

        self.mask = Mask(max_seq_len)

        # 1 / sqrt(head_dim): the usual scaling for dot-product attention logits
        self.scale = head_dim ** -0.5

    def __call__(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: nested list of ``Value`` with shape (batch_size, seq_len, model_dim).

        Returns:
            The result of applying ``Wo`` to every merged per-position vector;
            expected shape (batch_size, seq_len, model_dim).

        Raises:
            AssertionError: if ``x`` is not a 3-level nested list of ``Value``
                or its last dimension differs from ``self.model_dim``.
        """
        assert isinstance(x, list) and isinstance(x[0], list) and isinstance(x[0][0], list) and isinstance(x[0][0][0], Value),\
            "input to MHSA mechanism must be tensor of ndim==3 for (batch_size, seq_len, model_dim)"
        _batch_size, _seq_len, model_dim = tuple(get_shape(x))
        assert self.model_dim == model_dim,\
            f"input final dimension {model_dim} must equal MHSA mechanism's given model_dim value at initialization of {self.model_dim}"

        # apply query, key, and value projections to our input
        # (Linear takes a single vector, hence vector_wise_apply)
        q = vector_wise_apply(self.Wq, x) # shape (batch_size, seq_len, num_heads * head_dim)
        k = vector_wise_apply(self.Wk, x)
        v = vector_wise_apply(self.Wv, x)

        # split apart heads
        head_shape = (self.num_heads, self.head_dim)
        q = vector_wise_apply(split_dim, q, dims=head_shape) # shape (batch_size, seq_len, num_heads, head_dim)
        k = vector_wise_apply(split_dim, k, dims=head_shape)
        v = vector_wise_apply(split_dim, v, dims=head_shape)

        # transpose to put seq_len in the path of the matmul for our attention computation
        q = transpose(q, (1,2)) # shape (batch_size, num_heads, seq_len, head_dim)
        k = transpose(k, (1,2))
        v = transpose(v, (1,2))

        # get keys ready for attention computation
        k_t = transpose(k, (2,3)) # shape (batch_size, num_heads, head_dim, seq_len)
        # compute attention logits
        logits = tensor_matmul(q, k_t) # shape (batch_size, num_heads, seq_len, seq_len)
        # scale logits
        scaled_logits = vector_wise_apply(mult_vec_by_float, logits, self.scale)
        # apply mask to the scaled logits (this previously masked the raw input
        # `x`, which threw the attention logits away)
        masked_logits = matrix_wise_apply(self.mask.masked_fill, scaled_logits)
        # turn the logits into probability scores
        scores = vector_wise_apply(softmax, masked_logits)

        # use scores to select from values
        output_values = tensor_matmul(scores, v) # shape (batch_size, num_heads, seq_len, head_dim)
        # rearrange back to be of size model_dim
        output_values = transpose(output_values, (1,2)) # shape (batch_size, seq_len, num_heads, head_dim)
        # merge heads: the inverse of split_dim. `flatten` is defined in a later
        # notebook cell and is looked up when this method runs, so that cell
        # must have been executed before the first call.
        merged = matrix_wise_apply(flatten, output_values) # shape (batch_size, seq_len, num_heads * head_dim)

        # final output projection back to model_dim
        return vector_wise_apply(self.Wo, merged) # shape (batch_size, seq_len, model_dim)

In [122]:
def flatten(mat):
    """Concatenate the rows of a matrix into one flat list (inverse of split_dim).

    Args:
        mat: a list of rows, each row being a list of ``Value`` objects.

    Returns:
        A new list holding every entry of ``mat`` in row-major order. An empty
        matrix (``[]``) or a matrix of empty rows yields ``[]``.

    Raises:
        AssertionError: if ``mat`` is not a list of lists, or if any entry is
            not a ``Value``.
    """
    message = 'mat should be a matrix (AKA list of list of Value objects)'
    assert isinstance(mat, list) and all(isinstance(row, list) for row in mat), message
    # Iterating over the rows themselves (rather than indexing with the width
    # of row 0) keeps empty input from raising IndexError and never silently
    # truncates rows that are longer than the first one.
    flat = [entry for row in mat for entry in row]
    assert all(isinstance(entry, Value) for entry in flat), message
    return flat

In [127]:
# Demo: check that `flatten` undoes the head split.
# Build a tensor of shape (batch_size, seq_len, num_heads, head_dim) filled with
# Values drawn uniformly from [-1, 1].
# NOTE(review): no random seed is set in the visible cells, so the printed
# values will differ from run to run.
x = [[[[Value(r.uniform(-1,1)) for _ in range(head_dim)]
       for _ in range(num_heads)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]
pretty_print_tensor(x)
# Flatten every innermost (num_heads, head_dim) matrix; the recorded output
# below shows the resulting shape as [2, 3, 4].
y = matrix_wise_apply(flatten, x)
print(get_shape(y))
pretty_print_tensor(y)

[
  [
    [
      [Value(data=0.872, grad=0.000), Value(data=0.104, grad=0.000)]
      [Value(data=0.767, grad=0.000), Value(data=0.080, grad=0.000)]
    ]
    [
      [Value(data=-0.861, grad=0.000), Value(data=0.657, grad=0.000)]
      [Value(data=0.693, grad=0.000), Value(data=-0.794, grad=0.000)]
    ]
    [
      [Value(data=0.796, grad=0.000), Value(data=-0.575, grad=0.000)]
      [Value(data=0.736, grad=0.000), Value(data=-0.830, grad=0.000)]
    ]
  ]
  [
    [
      [Value(data=-0.403, grad=0.000), Value(data=0.730, grad=0.000)]
      [Value(data=0.912, grad=0.000), Value(data=-0.754, grad=0.000)]
    ]
    [
      [Value(data=-0.281, grad=0.000), Value(data=-0.284, grad=0.000)]
      [Value(data=-0.903, grad=0.000), Value(data=0.607, grad=0.000)]
    ]
    [
      [Value(data=0.047, grad=0.000), Value(data=0.342, grad=0.000)]
      [Value(data=-0.989, grad=0.000), Value(data=0.362, grad=0.000)]
    ]
  ]
]
[2, 3, 4]
[
  [
    [Value(data=0.872, grad=0.000), Value(data=0.104, 