In [2]:
import random as r

from engine import Value
from modules import *
from ops import *
from gpt import *

In [23]:
batch_size = 2
vocab_len = 10
model_dim = 4
max_seq_len = 5
seq_len = 3
num_heads = 2
head_dim = 2

In [8]:
class MultiHeadSelfAttention(Module):
    def __init__(self, model_dim, num_heads, head_dim, max_seq_len):
        self.Wq = Linear(model_dim, num_heads * head_dim)
        self.Wk = Linear(model_dim, num_heads * head_dim)
        self.Wv = Linear(model_dim, num_heads * head_dim)

        self.mask = Mask(max_seq_len)

        self.scale = head_dim ** -0.5
    
    def __call__(self, x):
        assert isinstance(x, list) and isinstance(x[0], list) and isinstance(x[0][0], list) and isinstance(x[0][0][0], Value),\
            "input to MHSA mechanism must be tensor of ndim==3 for (batch_size, seq_len, model_dim)"
        batch_size, seq_len, model_dim = tuple(get_shape(x))
        assert self.model_dim == model_dim,\
            f"input final dimension {model_dim} must equal MHSA mechanism's given model_dim value at initialization of {self.model_dim}"

        # apply query, key, and value projections to our input
        q = vector_wise_apply(self.Wq, x) # shape (batch_size, seq_len, num_heads * head_dim)
        k = vector_wise_apply(self.Wk, x) # Linear object is meant to take in a single vector, so we use vector_wise_apply
        v = vector_wise_apply(self.Wv, x)
        print(get_shape(q))

        # split apart heads
        q = vector_wise_apply(split_dim, q, dims=(self.num_heads, self.head_dim)) # shape (batch_size, seq_len, num_heads, head_dim)
        k = vector_wise_apply(split_dim, k, dims=(self.num_heads, self.head_dim))
        v = vector_wise_apply(split_dim, v, dims=(self.num_heads, self.head_dim))
        print(get_shape(q))

        # transpose to put seq_len in the path of the matmul for our attention computation
        q = transpose(q, (1,2)) # shape (batch_size, num_heads, seq_len, head_dim)
        k = transpose(k, (1,2))
        v = transpose(v, (1,2))
        print(get_shape(q))

        # get keys ready for attention computation
        k_t = transpose(k, (2,3)) # shape (batch_size, num_heads, head_dim, seq_len)
        # compute attention logits
        logits = tensor_matmul(q, k_t)
        # scale logits
        scaled_logits = vector_wise_apply(mult_vec_by_float, logits, self.scale)

        # apply mask
        #masked_logits = matrix_wise_apply(mult_
        # ok so i need entry-wise multiplication of a tensor by a matrix. 
        # guess i'm gonna have to start worrying about projections & whatnot and do things correctly
        pass

In [80]:
# wait no i think i can still use matrix_wise_apply
def masked_fill(matrix, mask, val = float('-inf')):
    mat_shape, mask_shape = get_shape(matrix), get_shape(mask)
    assert mat_shape == mask_shape, f"shapes of matrix & mask should be equal but are {mat_shape} and {mask_shape}"
    return [[matrix[i][j] if mask[i][j] else Value(val) for j in range(mat_shape[1])] for i in range(mat_shape[0])]

In [93]:
x = [[[[r.uniform(-1,1) for _ in range(seq_len)]
       for _ in range(seq_len)]
      for _ in range(num_heads)]
     for _ in range(batch_size)]
pretty_print_tensor(x)
mask = Mask(max_seq_len)
y = matrix_wise_apply(masked_fill, x, mask(seq_len))
pretty_print_tensor(y)

[
  [
    [
      [-0.9212250879600306, 0.3175362345578827, -0.0787617064634325]
      [-0.07419790053661868, -0.9303524468217099, 0.061551510856735936]
      [-0.5896097739394031, -0.8900058845021095, -0.121975460572874]
    ]
    [
      [-0.4194124608923371, -0.5988650375157958, 0.6379125238070049]
      [0.4069845034082229, 0.47175817300674305, -0.3168095764858674]
      [0.21368078815054403, -0.3300394855822759, 0.008479319794881368]
    ]
  ]
  [
    [
      [-0.7861821280475625, -0.21943771557592062, -0.8823917849806311]
      [0.6852503703869393, -0.7536082089518954, -0.850708673160691]
      [0.6757432711821114, 0.5862671377580468, 0.9245769137530935]
    ]
    [
      [-0.6911640834179114, 0.40861571599965085, -0.7793621010307572]
      [-0.1841809623096613, 0.6666458316301, -0.4611009162965849]
      [-0.12010953818108283, -0.4139901723096442, 0.8088658987166095]
    ]
  ]
]
[
  [
    [
      [-0.9212250879600306, Value(data=-inf, grad=0.000), Value(data=-inf, grad=0.000)]
 

In [87]:
class Mask(Module):
    def __init__(self, max_seq_len):
        self.max_seq_len = max_seq_len
        self.mask = [ [1] * (i + 1) + [0] * (max_seq_len - i - 1) for i in range(max_seq_len)]

    def __call__(self, seq_len):
        assert 0 < seq_len <= self.max_seq_len, f'seq_len {seq_len} must be less than max_seq_len {max_seq_len}'
        return [[i for i in row[:seq_len]] for row in self.mask[:seq_len]]

    def masked_fill(self, matrix, val = float('-inf')):
        mat_shape  = get_shape(matrix)
        assert mat_shape[0] == mat_shape[1], f"masked_fill requires input to be square matrix but instead got shape {mat_shape}"
        mask = self(len(matrix))
        return [[matrix[i][j] if mask[i][j] else Value(val) for j in range(mat_shape[1])] for i in range(mat_shape[0])]

    def __repr__(self):
        weights_repr = "\n".join(
            f"[{', '.join(str(p) for p in row)}]" for row in self.mask
        )
        return f"Causal self-attention mask:\n{weights_repr}"

In [98]:
x = [[[[Value(r.uniform(-1,1)) for _ in range(seq_len)]
       for _ in range(seq_len)]
      for _ in range(num_heads)]
     for _ in range(batch_size)]
pretty_print_tensor(x)
mask = Mask(max_seq_len)
y = matrix_wise_apply(mask.masked_fill, x)
pretty_print_tensor(y)

[
  [
    [
      [Value(data=-0.166, grad=0.000), Value(data=0.539, grad=0.000), Value(data=0.715, grad=0.000)]
      [Value(data=0.921, grad=0.000), Value(data=0.344, grad=0.000), Value(data=-0.613, grad=0.000)]
      [Value(data=0.129, grad=0.000), Value(data=-0.397, grad=0.000), Value(data=-0.049, grad=0.000)]
    ]
    [
      [Value(data=0.037, grad=0.000), Value(data=0.542, grad=0.000), Value(data=-0.379, grad=0.000)]
      [Value(data=0.928, grad=0.000), Value(data=-0.393, grad=0.000), Value(data=0.401, grad=0.000)]
      [Value(data=0.917, grad=0.000), Value(data=-0.273, grad=0.000), Value(data=0.849, grad=0.000)]
    ]
  ]
  [
    [
      [Value(data=-0.267, grad=0.000), Value(data=0.048, grad=0.000), Value(data=0.014, grad=0.000)]
      [Value(data=0.494, grad=0.000), Value(data=-0.047, grad=0.000), Value(data=-0.346, grad=0.000)]
      [Value(data=0.570, grad=0.000), Value(data=-0.375, grad=0.000), Value(data=-0.666, grad=0.000)]
    ]
    [
      [Value(data=0.865, grad=0.