In [53]:
import numpy as np

In [54]:
import torch
import torch.nn as nn

## Things that would be good to implement

1. standard attention
2. multi-query attention
3. grouped-query attention (fewer K and V heads than Q heads; each K/V head is shared by a group of query heads)
4. ring attention
5. linear attention 
6. RoPE (rotary position embeddings)
7. speculative decoding
8. KNN
9. k-means
10. different types of parallelism e.g., FSDP, data parallelism, model parallelism

In [155]:
class Initializer(object):
    """
    In-place parameter initializer.

    Arrays with more than one dimension (weights) are filled with samples from a
    truncated normal distribution; everything else (1-D biases, 0-d scalars) is
    set to zero.
    """

    def __call__(self, x):
        """
        Overwrites the contents of `x` in place; returns nothing.

        x: numpy array to initialize
        """
        # `x[...]` (rather than `x[:]`) also works for 0-d arrays
        if x.ndim > 1:
            # init to a truncated normal dist
            x[...] = self.sample_from_truncated_normal(x.shape)
        else:
            # biases get initialized to zero
            x[...] = 0

    def sample_from_truncated_normal(self, size, stdev=0.02):
        """
        Samples from N(0, stdev^2) restricted to [-3 * stdev, 3 * stdev].

        Out-of-bounds draws are redrawn until every value is within bounds.

        size: int or tuple of ints, the shape of the returned array
        stdev: standard deviation of the underlying normal distribution

        Returns: float array of shape `size`
        """
        min_value = -3 * stdev
        max_value = 3 * stdev
        truncated_normal = np.random.normal(0, stdev, size)
        oob_idxs = (truncated_normal > max_value) | (truncated_normal < min_value)
        # count_nonzero yields a single int for any array shape; the builtin
        # `sum` would return an array for N-d input and break the loop condition
        num_oob = np.count_nonzero(oob_idxs)
        while num_oob:
            truncated_normal[oob_idxs] = np.random.normal(0, stdev, num_oob)
            oob_idxs = (truncated_normal > max_value) | (truncated_normal < min_value)
            num_oob = np.count_nonzero(oob_idxs)
        return truncated_normal

In [156]:
class Module(object):
    """
    Minimal base class for hand-written numpy layers.

    Holds the layer's parameters in `self.params` (name -> numpy array) and the
    most recently computed gradients in `self.grads`. Subclasses implement
    `forward` and `backward`.
    """

    def __init__(self, **param_kwargs):
        """
        param_kwargs: name -> numpy array, one entry per parameter. The arrays
                      are initialized in place by `reset_parameters`.
        """
        self.grads = {}
        self.params = dict(param_kwargs)
        self.initializer = Initializer()
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initializes every parameter in place using `self.initializer`
        """
        for param in self.params.values():
            self.initializer(param)

    def __call__(self, x):
        """
        Calling the module is shorthand for `forward`
        """
        return self.forward(x)

    def forward(self, x):
        """
        Transforms the input and then returns the output
        """
        raise NotImplementedError

    def backward(self, dl_out):
        """
        Updates self.grads with the gradient of the loss w.r.t. each parameter and
        returns the gradient of the loss w.r.t. the input

        dl_out: gradient of the loss w.r.t. the output of this module
        """
        raise NotImplementedError

In [157]:
class Linear(Module):
    """
    Fully-connected layer: out = x @ weight.T (+ bias)

    Parameters (stored in self.params):
        weight: (out_features, in_features)
        bias:   (out_features,), only present when constructed with bias=True
    """

    def __init__(self, in_features, out_features, bias=False):
        """
        in_features: size of each input row
        out_features: size of each output row
        bias: whether to add a learnable bias to the output
        """
        kwargs = {}
        kwargs["weight"] = np.zeros((out_features, in_features))
        if bias:
            kwargs["bias"] = np.zeros(out_features)
        # the base class initializes these placeholder arrays in place
        super().__init__(**kwargs)

    def forward(self, x):
        """
        x: (B, in_features)

        Returns: (B, out_features)

        Implemented as `forward` so that both `layer(x)` (through the inherited
        `__call__`) and `layer.forward(x)` work.
        """
        out = np.dot(x, self.params["weight"].T)
        if "bias" in self.params:
            # np.dot returned a fresh array, so adding in place cannot touch x
            out += self.params["bias"]
        return out

    def backward(self, x, d_up):
        """
        x: (B, in_features) input to this layer
        d_up: (B, out_features) derivative of the loss w.r.t. the output of this module

        Stores the gradients in self.grads under "weight", "bias" (if the layer has
        one) and "in".

        Returns: (B, in_features) derivative of the loss w.r.t. the input x
        (the same array that is stored in self.grads["in"])
        """

        # TODO: consider caching the input and the weights during forward instead of
        # taking x as an argument and reading self.params here, because the params
        # might be updated between the forward and the backward pass

        if "bias" in self.params:
            # the bias is broadcast over the batch, so its gradient sums over the batch
            self.grads["bias"] = np.sum(d_up, axis=0)

        # (out_features, in_features) = (out_features, B) x (B, in_features)
        self.grads["weight"] = np.dot(d_up.T, x)
        # (B, in_features) = (B, out_features) x (out_features, in_features)
        self.grads["in"] = np.dot(d_up, self.params["weight"])
        return self.grads["in"]


In [158]:
def test_forward(in_shape, get_modules):
    np_in = np.random.rand(*in_shape).astype(np.float32)
    tensor_in = torch.tensor(np_in)

    my_module, torch_module = get_modules()
    my_out = my_module(np_in)
    torch_out = torch_module(tensor_in).detach().numpy()
    print("outputs equal:", np.array_equal(my_out, torch_out))


def test_backward(in_shape, out_shape, get_modules):
    #pass
    my_module, torch_module = get_modules()

    np_in = np.random.rand(*in_shape).astype(np.float32)
    tensor_in = torch.tensor(np_in, requires_grad=True) # b/c we're gonna check dL/dIn

    np_d_up = np.random.rand(*out_shape).astype(np.float32)
    tensor_d_up = torch.tensor(np_d_up)

    torch_out = torch_module(tensor_in)
    # TODO: think we're passing upstream here but torch docs are sparse
    torch_out.backward(tensor_d_up)
    
    my_module.backward(np_in, np_d_up)

    for name, value in torch_module.named_parameters():
        my_grad = my_module.grads[name]
        torch_grad = value.grad.numpy()
        print(f"gradient of {name} is close:", np.allclose(my_grad, torch_grad))

    print("gradient of input is close:", np.allclose(my_module.grads["in"], tensor_in.grad.numpy()))


In [159]:
def get_lin_modules(init_kwargs, my_cls, torch_cls):
    """
    Builds my linear layer and the torch reference layer with identical parameters.

    init_kwargs: keyword arguments passed to both constructors
    my_cls: class of my layer; instances expose a `params` dict
    torch_cls: class of the torch layer; instances expose `weight` and `bias`

    Returns: (my_lin, torch_lin), where my_lin's parameters are copies of torch_lin's
    """
    my_lin = my_cls(**init_kwargs)
    torch_lin = torch_cls(**init_kwargs)

    # `.numpy()` shares memory with the tensor, so copy: otherwise an in-place write
    # to my params (e.g. reset_parameters) would silently change the reference module
    my_lin.params["weight"] = torch_lin.weight.detach().numpy().copy()
    if torch_lin.bias is not None:
        my_lin.params["bias"] = torch_lin.bias.detach().numpy().copy()
    return my_lin, torch_lin


def test_linear(n_batch, d_in, d_out, use_bias):
    """
    Checks my Linear layer against nn.Linear, forward pass first and then backward.

    n_batch: number of rows in the random input
    d_in: number of input features
    d_out: number of output features
    use_bias: whether both layers are built with a bias

    Returns: whatever test_backward returns
    """
    init_kwargs = dict(in_features=d_in, out_features=d_out, bias=use_bias)

    def make_pair():
        # a fresh (mine, torch) pair with shared parameter values for each check
        return get_lin_modules(init_kwargs, Linear, nn.Linear)

    input_shape = (n_batch, d_in)

    print("========= testing forward ==========")
    test_forward(in_shape=input_shape, get_modules=make_pair)

    print("\n========= testing backward ==========\n")
    return test_backward(
        in_shape=input_shape,
        out_shape=(n_batch, d_out),
        get_modules=make_pair,
    )


In [160]:
# Check Linear against nn.Linear: batch of 3, 4 input features, 2 output features, with bias
test_linear(n_batch=3, d_in=4, d_out=2, use_bias=True)

outputs equal: True


gradient of weight is close: True
gradient of bias is close: True
gradient of input is close: True
