In [2]:
import random as r

from engine import Value
from modules import *
from ops import *
from gpt import *

In [141]:
# Shape of the demo input tensor built further down: (batch_size, seq_len, model_dim).
batch_size = 2
seq_len = 3
model_dim = 8

# Hyperparameters handed to MultiHeadSelfAttention.
num_heads = 2
head_dim = 4
max_seq_len = 5  # forwarded by the attention layer to its Mask

# Not referenced by any of the cells visible in this part of the notebook.
vocab_len = 10

In [152]:
class MultiHeadSelfAttention(Module):
    """Masked multi-head self-attention over nested lists of ``Value`` scalars.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``head_dim``, combined through scaled,
    masked softmax attention, and finally mixed back to ``model_dim`` by an
    output projection.

    Args:
        model_dim: size of the last dimension of both input and output.
        num_heads: number of attention heads.
        head_dim: width of each individual head.
        max_seq_len: size handed to ``Mask``.

    NOTE(review): ``__call__`` never compares the incoming seq_len against
    ``max_seq_len``; presumably ``Mask.masked_fill`` copes with any shorter
    sequence -- confirm against ``Mask``.
    NOTE(review): no ``parameters()`` override is defined here; whether
    ``Module`` discovers the ``Linear`` sub-modules on its own cannot be seen
    from this cell.
    """

    def __init__(self, model_dim, num_heads, head_dim, max_seq_len):
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # one projection each for queries, keys and values; all heads share a Linear
        self.Wq = Linear(model_dim, num_heads * head_dim)
        self.Wk = Linear(model_dim, num_heads * head_dim)
        self.Wv = Linear(model_dim, num_heads * head_dim)

        # 1 / sqrt(head_dim), applied to the attention logits before the softmax
        self.scale = head_dim ** -0.5

        self.mask = Mask(max_seq_len)

        # mixes the concatenated head outputs back down to model_dim
        self.Wo = Linear(num_heads * head_dim, model_dim)

    @staticmethod
    def _is_rank3_value_tensor(x):
        """Return True if ``x`` is a non-empty list-of-lists-of-lists whose first scalar is a ``Value``.

        Only the first element of every level is probed. An empty level makes
        the check fail instead of raising ``IndexError``.
        """
        probe = x
        for _ in range(3):
            if not isinstance(probe, list) or not probe:
                return False
            probe = probe[0]
        return isinstance(probe, Value)

    def _project_to_heads(self, projection, x):
        """Apply ``projection`` to every vector of ``x`` and lay the result out per head."""
        # Linear object is meant to take in a single vector, so we use vector_wise_apply
        projected = vector_wise_apply(projection, x) # shape (batch_size, seq_len, num_heads * head_dim)
        # split apart heads
        per_head = vector_wise_apply(split_dim, projected, dims=(self.num_heads, self.head_dim)) # shape (batch_size, seq_len, num_heads, head_dim)
        # transpose to put seq_len in the path of the matmul for the attention computation
        return transpose(per_head, (1,2)) # shape (batch_size, num_heads, seq_len, head_dim)

    def __call__(self, x):
        """Run self-attention on ``x``.

        Args:
            x: nested list of ``Value`` shaped (batch_size, seq_len, model_dim).

        Returns:
            Nested list shaped (batch_size, seq_len, model_dim).

        Raises:
            AssertionError: if ``x`` is not a non-empty rank-3 nested list of
                ``Value`` or its last dimension differs from ``self.model_dim``.
        """
        assert self._is_rank3_value_tensor(x),\
            "input to MHSA mechanism must be tensor of ndim==3 for (batch_size, seq_len, model_dim)"
        # the unpacking doubles as a check that get_shape reports exactly three dimensions
        _batch_size, _seq_len, model_dim = tuple(get_shape(x))
        assert self.model_dim == model_dim,\
            f"input final dimension {model_dim} must equal MHSA mechanism's given model_dim value at initialization of {self.model_dim}"

        # query, key and value projections, each already split and transposed per head
        q = self._project_to_heads(self.Wq, x)
        k = self._project_to_heads(self.Wk, x)
        v = self._project_to_heads(self.Wv, x)

        # get keys ready for attention computation
        k_t = transpose(k, (2,3)) # shape (batch_size, num_heads, head_dim, seq_len)
        # compute attention logits
        logits = tensor_matmul(q, k_t) # shape (batch_size, num_heads, seq_len, seq_len)
        # scale logits
        scaled_logits = vector_wise_apply(mult_vec_by_float, logits, self.scale)
        # apply mask to every (seq_len, seq_len) matrix
        masked_logits = matrix_wise_apply(self.mask.masked_fill, scaled_logits)
        # turn the logits into probability scores, one softmax per query row
        scores = vector_wise_apply(softmax, masked_logits)

        # use scores to select from values
        output_values = tensor_matmul(scores, v) # shape (batch_size, num_heads, seq_len, head_dim)
        # rearrange back to one vector per position
        output_values = transpose(output_values, (1,2)) # shape (batch_size, seq_len, num_heads, head_dim)
        output_values = matrix_wise_apply(flatten, output_values) # shape (batch_size, seq_len, num_heads * head_dim)

        # mix output values of each head together
        return vector_wise_apply(self.Wo, output_values) # shape (batch_size, seq_len, model_dim)



In [157]:
# Seed the stdlib RNG so the random input below is identical on every Restart & Run All.
# NOTE(review): this pins the Linear weight initialisation too only if `modules` draws
# from the same `random` module -- confirm.
r.seed(0)

# random demo input of shape (batch_size, seq_len, model_dim)
x = [[[Value(r.uniform(-1,1)) for _ in range(model_dim)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]
print(get_shape(x))
mhsa = MultiHeadSelfAttention(model_dim, num_heads, head_dim, max_seq_len)
y = mhsa(x)
print(get_shape(y))
# self-attention must hand back a tensor shaped exactly like its input
assert list(get_shape(y)) == [batch_size, seq_len, model_dim]

[2, 3, 8]
[2, 3, 8]
[2, 3, 2, 4]
[2, 2, 3, 4]
[
  [
    [
      [Value(data=1.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.834, grad=0.000), Value(data=0.166, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.158, grad=0.000), Value(data=0.748, grad=0.000), Value(data=0.094, grad=0.000)]
    ]
    [
      [Value(data=1.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.486, grad=0.000), Value(data=0.514, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.356, grad=0.000), Value(data=0.454, grad=0.000), Value(data=0.190, grad=0.000)]
    ]
  ]
  [
    [
      [Value(data=1.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.525, grad=0.000), Value(data=0.475, grad=0.000), Value(data=0.000, grad=0.000)]
      [Value(data=0.184, grad=0.000), Value(data=0.759, grad=0.000), Value(data=0.056, grad=0.000)]
    ]
    