In [2]:
import random as r

from engine import Value
from modules import *
from ops import *
from gpt import *

In [3]:
batch_size = 2
vocab_len = 10
model_dim = 4
seq_len = 5
num_heads = 2
head_dim = 2

In [None]:
class MultiHeadSelfAttention(Module):
    def __init__(self, model_dim, num_heads, head_dim, max_seq_len):
        self.Wq = Linear(model_dim, num_heads * head_dim)
        self.Wk = Linear(model_dim, num_heads * head_dim)
        self.Wv = Linear(model_dim, num_heads * head_dim)

        # TODO:
        # - causal mask
    
    def __call__(self, x):
        assert isinstance(x, list) and isinstance(x[0], list) and isinstance(x[0][0], list) and isinstance(x[0][0][0], Value),\
            "input to MHSA mechanism must be tensor of ndim==3 for (batch_size, seq_len, model_dim)"
        batch_size, seq_len, model_dim = len(x), len(x[0]), len(x[0][0])
        assert self.model_dim == model_dim,\
            f"input final dimension {model_dim} must equal MHSA mechanism's given model_dim value at initialization of {self.model_dim}"

        q = vector_wise_apply(self.Wq, x) # shape (batch_size, seq_len, num_heads * head_dim)
        k = vector_wise_apply(self.Wk, x)
        v = vector_wise_apply(self.Wv, x)
        # oops i need a split_dim function to separate out num_heads from head_dim
        pass

In [None]:
def split_dim(x, dims):
    '''
    splits input vector of shape (dims[0] + dims[1]) into matrix of shape (dims[0], dims[1])
    '''
    assert isinstance(x, list), "x should be a list of Value objects"
    assert all(isinstance(idx, Value) for idx in x), "All elements in x must be Value objects"
    # oops i need a vector_wise_apply() function that takes in extra arguments in order to make this work
    pass
    

In [35]:
def vector_wise_apply(function, vec, extra_arg = None):
    '''
    applies the input function to the tensor vector-wise
    
    inputs: 
        function - a function meant to be applied to a list of Value objects
        x - list of lists of .... of Value objects
        extra_arg - a second value the function of interest may or may not require as an argument
            *i should probably be using **args or **kwargs there somehow but idk how to use those tbh
    output: 
        out - list of lists of .... of Value objects
    '''
    assert isinstance(x, list), "input must be at least a vector (aka a list of Value objects)"
    if isinstance(x[0], list):
        if extra_arg is not None:
            return [vector_wise_apply(function, sub_x, extra_arg) for sub_x in x]
        else:
            return [vector_wise_apply(function, sub_x) for sub_x in x]
    else: # base case: the final vector dimension
        return function(x, extra_arg) if extra_arg is not None else function(x)

In [37]:
def mul(x, c):
    '''
    multiplies all elements in the vector x by the constant c
    for division just input a fraction
    '''
    assert isinstance(x, list), "x should be a list of Value objects"
    assert all(isinstance(idx, Value) for idx in x), "All elements in x must be Value objects"
    return [xi * c for xi in x]

def add(x, c):
    '''
    adds all elements in the vector x by the constant c
    for subtraction just input a negative number
    '''
    assert isinstance(x, list), "x should be a list of Value objects"
    assert all(isinstance(idx, Value) for idx in x), "All elements in x must be Value objects"
    return [xi + c for xi in x]

# ok so these aren't working cleanly with vector_wise_apply bc it only exxpects 1 input. 
    # can I somehow use **args or **kwargs? should learn those at some point
    # oh yeah i'm gonna need entry-wise mult by a single float for scaling in the attention mechanism
    # and i'm gonna need a vector_wise_apply that takes in an extra argument for splitting up num_heads from head_dim

In [39]:


print('\n\n-------------- test entry-wise add by a single float on a vector -------------')
x = [Value(r.uniform(-1,1)) for _ in range(model_dim)]
print(x)
y = add(x, 100.)
print(y)
# tensor
print('\n\n-------------- test entry-wise add by a single float on a tensor -------------')
x = [[[Value(r.uniform(-1,1)) for _ in range(model_dim)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]
pretty_print_tensor(x)
print('\n')
y = vector_wise_apply(add, x, 100.)
pretty_print_tensor(y)



-------------- test entry-wise add by a single float on a vector -------------
[Value(data=-0.172, grad=0.000), Value(data=-0.711, grad=0.000), Value(data=-0.067, grad=0.000), Value(data=-0.150, grad=0.000)]
[Value(data=99.828, grad=0.000), Value(data=99.289, grad=0.000), Value(data=99.933, grad=0.000), Value(data=99.850, grad=0.000)]


-------------- test entry-wise add by a single float on a tensor -------------
[
  [
    [Value(data=-0.363, grad=0.000), Value(data=-0.318, grad=0.000), Value(data=-0.645, grad=0.000), Value(data=0.022, grad=0.000)]
    [Value(data=-0.004, grad=0.000), Value(data=0.639, grad=0.000), Value(data=-0.226, grad=0.000), Value(data=0.417, grad=0.000)]
    [Value(data=0.587, grad=0.000), Value(data=-0.842, grad=0.000), Value(data=0.026, grad=0.000), Value(data=-0.295, grad=0.000)]
    [Value(data=-0.013, grad=0.000), Value(data=0.628, grad=0.000), Value(data=0.616, grad=0.000), Value(data=0.997, grad=0.000)]
    [Value(data=-0.390, grad=0.000), Value(data=0.

RecursionError: maximum recursion depth exceeded