In [131]:
import torch
import cupy as cp
import numpy as np
from typing import Union,Callable,Optional
from layers import Linear,Dropout,Softmax,MLP,LayerNorm
from utilsco import compress_numpy_array,decompress_numpy_array
from numpy.typing import ArrayLike
import math
import copy
from torch import nn
from torch.nn import functional as F

In [132]:
# Problem size shared by the cells below: batch size, sequence length
# (used as the context/block size) and embedding width.
B, T, C = 2, 16, 24

# Fake upstream gradient used to drive both backward passes.  It is drawn from
# a dedicated, seeded generator so that (a) the printed gradients are
# reproducible after Restart & Run All and (b) the global NumPy RNG state,
# which is seeded later in the notebook, is left untouched.
_upstream_rng = np.random.default_rng(0)
fake_upst = _upstream_rng.integers(-10, 10, (B, T, C))  # integers in [-10, 10)
t_fake_upstr = torch.tensor(fake_upst, dtype=torch.float32)

In [133]:
# MHA

class MultiHeadAttention:
    """Causal multi-head self-attention with a hand-written backward pass.

    The layer projects its input to queries, keys and values with a single
    fused ``c_attn`` linear layer, applies masked scaled dot-product attention
    per head, merges the heads and applies the ``c_proj`` output projection.
    ``forward`` caches the tensors that ``backward`` needs, so ``backward``
    must be called after ``forward`` on the same instance.
    """

    def __init__(
        self,
        d_model: int,
        context_size: int,
        n_heads: int,
        batch_size: int,
        lr: float = 0.1,
        dropout: float = 0.1,
        c_attn_weight_init_func: Union[Callable, None] = None,
        c_proj_weight_init_func: Union[Callable, None] = None,
        bias_init_func: Union[Callable, None] = None,
    ) -> None:
        """Build the projections, dropouts, softmax and causal mask.

        Args:
            d_model: Embedding width; must be divisible by ``n_heads``.
            context_size: Maximum sequence length covered by the causal mask.
            n_heads: Number of attention heads.
            batch_size: Forwarded unchanged to the ``Linear`` layers.
            lr: Learning rate forwarded to the ``Linear`` layers.
            dropout: Rate used for both the attention and residual dropouts.
            c_attn_weight_init_func: Optional weight initialiser for ``c_attn``.
            c_proj_weight_init_func: Optional weight initialiser for ``c_proj``.
            bias_init_func: Optional bias initialiser for both projections.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``n_heads``.
        """
        # Validate first so that no sub-layer is built for an invalid config.
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")

        self.d_model = d_model
        self.context_size = context_size
        self.n_heads = n_heads
        # Kept for backward compatibility; this class does not read it.  The
        # attention scores are scaled by sqrt(depth), not sqrt(d_model).
        self.scale = math.sqrt(d_model)
        self.batch_size = batch_size

        self.attn_dropout = Dropout(dropout)
        self.resid_dropout = Dropout(dropout)
        self.softmax_attn = Softmax(axis=-1)

        # Per-head feature width.
        self.depth = d_model // n_heads

        # Fused query/key/value projection: d_model -> 3 * d_model.
        self.c_attn = Linear(
            d_model,
            3 * d_model,
            batch_size,
            lr,
            weight_init_func=c_attn_weight_init_func,
            bias_init_func=bias_init_func,
        )

        # Output projection applied after the heads are merged.
        self.c_proj = Linear(
            d_model,
            d_model,
            batch_size,
            lr,
            weight_init_func=c_proj_weight_init_func,
            bias_init_func=bias_init_func,
        )

        # Lower-triangular causal mask, shaped (1, 1, context, context) so it
        # broadcasts over batch and heads.  Entries equal to 0 are "future"
        # positions that must not be attended to.
        self.mask = cp.tril(
            cp.ones((context_size, context_size), dtype=cp.float64)
        ).reshape(1, 1, context_size, context_size)

        # Caches filled by forward() and consumed by backward().
        self.input = None
        self.v = None
        self.q = None
        self.k = None
        self.attn = None

    def _split_heads(self, x, B: int, T: int):
        """Reshape (B, T, C) into per-head layout (B, n_heads, T, depth)."""
        return x.reshape((B, T, self.n_heads, self.depth)).transpose(0, 2, 1, 3)

    def _merge_heads(self, x, B: int, T: int):
        """Inverse of ``_split_heads``: (B, n_heads, T, depth) -> (B, T, C)."""
        return x.transpose(0, 2, 1, 3).reshape((B, T, self.n_heads * self.depth))

    def forward(self, input: ArrayLike, train: bool) -> tuple:
        """Run masked multi-head self-attention.

        Args:
            input: Array of shape (B, T, C) with ``T <= context_size``.
            train: Passed through to both dropout layers.

        Returns:
            ``(output, attn)`` where ``output`` has shape (B, T, C) and ``attn``
            is the (B, n_heads, T, T) attention-weight tensor after dropout.

        Raises:
            ValueError: If the sequence is longer than ``context_size``.
        """
        self.input = cp.asanyarray(input)

        B, T, C = self.input.shape
        if T > self.context_size:
            raise ValueError(
                f"sequence length {T} exceeds context_size {self.context_size}"
            )

        q, k, v = cp.split(self.c_attn.forward(self.input), 3, axis=2)

        # Splitting C into n_heads chunks of width `depth` keeps the total cost
        # comparable to a single full-width attention head.
        k = self._split_heads(k, B, T)  # (B, nh, T, hs)
        q = self._split_heads(q, B, T)  # (B, nh, T, hs)
        v = self._split_heads(v, B, T)  # (B, nh, T, hs)

        self.k = k
        self.q = q
        self.v = v

        # Scaled dot-product scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        attn = (q @ k.transpose(0, 1, 3, 2)) * (1.0 / math.sqrt(self.depth))

        # Slice the mask to the actual sequence length so inputs shorter than
        # context_size work (same as bias[:, :, :T, :T] in the torch reference).
        future = self.mask[:, :, :T, :T] == 0
        attn = cp.where(future, -1e9, attn)
        attn = self.softmax_attn.forward(attn)
        attn = self.attn_dropout.forward(attn, train)

        self.attn = attn  # (B, nh, T, T)
        x = attn @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)

        x = self._merge_heads(x, B, T)  # (B, T, C)
        x = self.c_proj.forward(x)  # keeps dims
        x = self.resid_dropout.forward(x, train)

        return x, attn

    def backward(self, grad: ArrayLike) -> cp.ndarray:
        """Back-propagate through the layer using the caches from ``forward``.

        Args:
            grad: Gradient of the loss w.r.t. the layer output, shape (B, T, C).

        Returns:
            The value returned by ``c_attn.backward`` for the concatenated
            query/key/value gradients, i.e. the gradient passed downstream.
        """
        B, T, C = self.input.shape

        grad = self.resid_dropout.backward(grad)
        grad = self.c_proj.backward(grad)
        grad = self._split_heads(grad, B, T)  # (B, nh, T, hs)

        # x = attn @ v  =>  dv = attn^T @ dx,  d(attn) = dx @ v^T
        v_grad = self.attn.transpose(0, 1, 3, 2) @ grad  # (B, nh, T, hs)
        attn_grad = grad @ self.v.transpose(0, 1, 3, 2)  # (B, nh, T, T)

        attn_grad = self.attn_dropout.backward(attn_grad)
        attn_grad = self.softmax_attn.backward(attn_grad)
        # Masked positions were overwritten with a constant in forward, so no
        # gradient flows to the raw scores there.
        attn_grad = cp.where(self.mask[:, :, :T, :T] == 0, 0, attn_grad)

        # Undo the 1/sqrt(depth) score scaling (same constant as in forward).
        attn_grad = attn_grad * (1.0 / math.sqrt(self.depth))

        # scores = q @ k^T  =>  dq = dscores @ k,  dk = dscores^T @ q
        q_grad = attn_grad @ self.k
        k_grad = attn_grad.transpose(0, 1, 3, 2) @ self.q

        # Re-assemble in the same q, k, v order used by cp.split in forward.
        grad = cp.concatenate(
            (
                self._merge_heads(q_grad, B, T),
                self._merge_heads(k_grad, B, T),
                self._merge_heads(v_grad, B, T),
            ),
            2,
        )
        return self.c_attn.backward(grad)

    def update(self) -> None:
        """Apply the pending parameter updates of both linear projections."""
        self.c_proj.update()
        self.c_attn.update()
        return

    def get_params(self) -> dict:
        """Return the compressed ``[weight, bias]`` pairs of both projections."""
        return {
            "c_attn": [
                compress_numpy_array(self.c_attn.weight),
                compress_numpy_array(self.c_attn.bias),
            ],
            "c_proj": [
                compress_numpy_array(self.c_proj.weight),
                compress_numpy_array(self.c_proj.bias),
            ],
        }

    def load_params(self, state_dict: dict) -> None:
        """Restore both projections from a dict laid out like ``get_params``."""
        self.c_attn.weight = decompress_numpy_array(state_dict["c_attn"][0])
        self.c_attn.bias = decompress_numpy_array(state_dict["c_attn"][1])
        self.c_proj.weight = decompress_numpy_array(state_dict["c_proj"][0])
        self.c_proj.bias = decompress_numpy_array(state_dict["c_proj"][1])


In [134]:
class NewGELU(nn.Module):
    """
    Tanh approximation of the GELU activation, as used in the Google BERT repo
    (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        # gelu(x) ~= 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
        cubic_term = x + 0.044715 * torch.pow(x, 3.0)
        gate = torch.tanh(math.sqrt(2.0 / math.pi) * cubic_term)
        return 0.5 * x * (1.0 + gate)

class CausalSelfAttention(nn.Module):
    """
    Reference multi-head masked self-attention layer followed by an output
    projection. Written out explicitly (rather than via
    torch.nn.MultiheadAttention) so every step can be compared against the
    hand-written implementation.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused query/key/value projection for all heads.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection applied after the heads are merged.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Regularisation on the attention weights and on the block output.
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # Lower-triangular causal mask: position t may only attend to <= t.
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "bias", causal.view(1, 1, config.block_size, config.block_size)
        )
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        # B: batch size, T: sequence length, C: embedding width (n_embd)
        B, T, C = x.size()
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scores = q @ k.transpose(-2, -1)
        scores = scores * (1.0 / math.sqrt(k.size(-1)))
        # Hide future positions before the softmax.
        scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then put heads side by side
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)

        # Output projection followed by residual dropout.
        return self.resid_dropout(self.c_proj(merged))


class TBlock(nn.Module):
    """Reference pre-LayerNorm Transformer block (attention + MLP, both residual)."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))

    def mlpf(self, x):
        """MLP forward: c_fc -> GELU -> c_proj -> dropout.

        Implemented as a method instead of a lambda stored on the instance: a
        lambda closing over ``self.mlp`` is copied by reference, so a
        ``copy.deepcopy`` of the block would keep routing through the ORIGINAL
        block's MLP weights, and the module could not be pickled.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the block output."""
        y1 = self.ln_1(x)
        y2 = self.attn(y1)
        y3 = x + y2  # first residual connection
        y4 = self.ln_2(y3)
        y5 = self.mlpf(y4)
        y6 = y3 + y5  # second residual connection; y6 is the block output
        return y6


In [135]:

# Seed every RNG used below so parameters and inputs are repeatable.
np.random.seed(0)
torch.manual_seed(0)
cp.random.seed(0)


class Conf():
    """Plain attribute bag standing in for the config object TBlock expects."""


config = Conf()
config.n_embd = C
config.n_head = 4  # reduced dim is 6, because 6 x 4 = 24
config.block_size = T  # context length
# Dropout is disabled so the torch and CuPy runs are deterministic.
config.attn_pdrop = 0.
config.resid_pdrop = 0.

t_block = TBlock(config)

# Shared random input; the torch copy tracks gradients w.r.t. the input.
inp = np.random.random((B, T, C))
t_inp = torch.tensor(inp, dtype=torch.float32, device="cpu", requires_grad=True)
t_out = t_block.forward(t_inp)

# Contracting the output with the fake upstream gradient gives a scalar whose
# input gradient is what the hand-written backward pass should reproduce.
fakeloss = (t_out * t_fake_upstr).sum()
(t_grad,) = torch.autograd.grad(outputs=[fakeloss], inputs=[t_inp])
print(t_grad.shape)

torch.Size([2, 16, 24])


In [136]:
# Block
class Block:
    """Pre-LayerNorm transformer block with a hand-written backward pass.

    Computes ``y = x + attn(ln_1(x))`` followed by ``out = y + mlp(ln_2(y))``.
    """

    def __init__(
        self,
        d_model: int,
        context_size: int,
        n_heads: int,
        batch_size: int,
        lr: float,
        dropout: float,
        n_layer: int = 6,
    ) -> None:
        """Build the two layer norms, the attention layer and the MLP.

        Args:
            d_model: Embedding width.
            context_size: Maximum sequence length, forwarded to the attention.
            n_heads: Number of attention heads.
            batch_size: Forwarded unchanged to the sub-layers.
            lr: Learning rate forwarded to the sub-layers.
            dropout: Dropout rate forwarded to the attention and the MLP.
            n_layer: Number of blocks in the full model; only used to scale the
                standard deviation of the output-projection initialisers as
                ``0.02 / sqrt(2 * n_layer)``. Defaults to 6, the value that was
                previously hard-coded.

        Raises:
            ValueError: If ``n_layer`` is not positive.
        """
        if n_layer <= 0:
            raise ValueError("n_layer must be a positive integer")

        def weight_init_func(size):
            return cp.random.normal(size=size, loc=0.0, scale=0.02).astype(cp.float64)

        def c_proj_init_func(size):
            # Smaller std for the projections that feed the residual stream.
            return cp.random.normal(
                size=size, loc=0.0, scale=0.02 / math.sqrt(2 * n_layer)
            ).astype(cp.float64)

        def bias_init_func(size):
            return cp.zeros(shape=size, dtype=cp.float64)

        self.d_model = d_model
        self.context_size = context_size
        self.n_heads = n_heads
        self.batch_size = batch_size
        self.lr = lr
        self.dropout = dropout
        self.n_layer = n_layer

        # NOTE(review): the layer norms receive the N(0, 0.02) weight
        # initialiser; how LayerNorm uses it is not visible here -- confirm
        # this is the intended initial scale.
        self.ln_1 = LayerNorm(d_model, weight_init_func=weight_init_func)

        self.attn = MultiHeadAttention(
            d_model,
            context_size,
            n_heads,
            batch_size,
            lr,
            dropout,
            c_attn_weight_init_func=weight_init_func,
            c_proj_weight_init_func=c_proj_init_func,
            bias_init_func=bias_init_func,
        )

        self.ln_2 = LayerNorm(d_model, weight_init_func=weight_init_func)

        self.mlp = MLP(
            d_model,
            batch_size,
            lr,
            dropout,
            c_fc_init_func=weight_init_func,
            c_proj_init_func=c_proj_init_func,
            bias_init_func=bias_init_func,
        )

    def forward(self, input: ArrayLike, train) -> cp.ndarray:
        """Run the block.

        Args:
            input: Array-like input; converted with ``cp.asanyarray``.
            train: Passed through to the attention and MLP sub-layers.

        Returns:
            The block output after both residual connections.
        """
        input = cp.asanyarray(input)

        # Defensive copies of the residual branches, in case a sub-layer
        # modifies its argument in place (not verifiable from this file).
        res = copy.deepcopy(input)
        x = self.ln_1.forward(input)
        x = self.attn.forward(x, train)[0]  # [0]: output; [1] is the attention map
        x = res + x  # first residual connection

        residual = copy.deepcopy(x)
        x = self.ln_2.forward(x)
        x = self.mlp.forward(x, train)
        x = residual + x  # second residual connection

        return x

    def backward(self, upstream_grad):
        """
        Computes gradients for the transformer block.

        Args:
            upstream_grad: Gradient from the subsequent layer.

        Returns:
            Gradient with respect to the input `x`.
        """
        upstream_grad = cp.asanyarray(upstream_grad)

        # Second residual: out = y + mlp(ln_2(y))
        #   => dL/dy = upstream + ln_2.backward(mlp.backward(upstream))
        # Added out of place so the array returned by ln_2.backward is not
        # mutated behind the layer's back.
        grad_mlp_branch = self.ln_2.backward(self.mlp.backward(upstream_grad))
        grad_mid = grad_mlp_branch + upstream_grad

        # First residual: y = x + attn(ln_1(x))
        #   => dL/dx = dL/dy + ln_1.backward(attn.backward(dL/dy))
        grad_attn_branch = self.ln_1.backward(self.attn.backward(grad_mid))
        return grad_mid + grad_attn_branch

    def update(self) -> None:
        """Apply the pending parameter updates of every sub-layer."""
        self.ln_2.update()
        self.ln_1.update()
        self.mlp.update()
        self.attn.update()
        return

    def state_dict(self) -> dict:
        """Return the compressed parameters of all four sub-layers."""
        return {
            "ln_1": [
                compress_numpy_array(self.ln_1.weight),
                compress_numpy_array(self.ln_1.bias),
            ],
            "ln_2": [
                compress_numpy_array(self.ln_2.weight),
                compress_numpy_array(self.ln_2.bias),
            ],
            "mlp": self.mlp.get_params(),
            "attn": self.attn.get_params(),
        }

    def get_params(self) -> dict:
        """Alias of ``state_dict`` matching the sub-layers' ``get_params`` name."""
        return self.state_dict()

    def load_params(self, state_dict: dict) -> None:
        """Restore all sub-layers from a dict laid out like ``state_dict``."""
        self.ln_1.weight = decompress_numpy_array(state_dict["ln_1"][0])
        self.ln_1.bias = decompress_numpy_array(state_dict["ln_1"][1])
        self.ln_2.weight = decompress_numpy_array(state_dict["ln_2"][0])
        self.ln_2.bias = decompress_numpy_array(state_dict["ln_2"][1])

        self.mlp.load_params(state_dict["mlp"])
        self.attn.load_params(state_dict["attn"])

# MLP: 
    # def load_params(self, state_dict: dict) -> None:
    #     self.c_fc.weight = decompress_numpy_array(state_dict["c_fc"][0])
    #     self.c_fc.bias = decompress_numpy_array(state_dict["c_fc"][1])
    #     self.c_proj.weight = decompress_numpy_array(state_dict["c_proj"][0])
    #     self.c_proj.bias = decompress_numpy_array(state_dict["c_proj"][1])

# attn:
# def load_params(self, state_dict: dict) -> None:
#         self.c_attn.weight = decompress_numpy_array(state_dict["c_attn"][0])
#         self.c_attn.bias = decompress_numpy_array(state_dict["c_attn"][1])
#         self.c_proj.weight = decompress_numpy_array(state_dict["c_proj"][0])
#         self.c_proj.bias = decompress_numpy_array(state_dict["c_proj"][1])

In [137]:
c_Block = Block(
    d_model=C,
    context_size=T,
    n_heads=config.n_head,
    batch_size=B,
    lr=0.1,
    dropout=config.resid_pdrop,
)


def _packed(t_param, transpose=False):
    """Detach a torch parameter, optionally transpose it, and compress it as a CuPy array."""
    values = t_param.detach().numpy()
    if transpose:
        values = values.T
    return compress_numpy_array(cp.array(values))


# Copy the torch block's parameters into the layout Block.load_params expects.
# Linear weights are transposed; layer-norm weights and all biases are not.
params = {
    "ln_1": [_packed(t_block.ln_1.weight), _packed(t_block.ln_1.bias)],
    "ln_2": [_packed(t_block.ln_2.weight), _packed(t_block.ln_2.bias)],
    "mlp": {
        "c_fc": [
            _packed(t_block.mlp.c_fc.weight, transpose=True),
            _packed(t_block.mlp.c_fc.bias),
        ],
        "c_proj": [
            _packed(t_block.mlp.c_proj.weight, transpose=True),
            _packed(t_block.mlp.c_proj.bias),
        ],
    },
    "attn": {
        "c_attn": [
            _packed(t_block.attn.c_attn.weight, transpose=True),
            _packed(t_block.attn.c_attn.bias),
        ],
        "c_proj": [
            _packed(t_block.attn.c_proj.weight, transpose=True),
            _packed(t_block.attn.c_proj.bias),
        ],
    },
}

c_Block.load_params(params)

# Same input and same fake upstream gradient as the torch reference run.
c_inp = cp.array(inp)
c_forward = c_Block.forward(c_inp, train=True)
c_fake_upst = cp.array(fake_upst)
c_grad = c_Block.backward(c_fake_upst)

# Forward parity check: first token of the first batch element, torch then CuPy.
print()
print(t_out.detach().numpy()[0][0])
print(c_forward[0][0])


[-0.25807333  0.54939985  0.9438534   0.8831857   0.44403353  0.76653683
  1.0767456   0.93072057  0.6664855   0.00323181  0.82971454  0.06231073
  0.2319231   0.94379604  0.25367013 -0.03447413  0.365102    1.5262859
  0.6247262   0.801388    1.353657    1.6885219   0.4616583   1.2506313 ]
[-0.25807319  0.54939985  0.94385338  0.88318546  0.44403343  0.76653687
  1.07674568  0.93072042  0.66648558  0.00323188  0.8297146   0.06231068
  0.23192303  0.94379614  0.25367012 -0.03447396  0.36510199  1.52628568
  0.62472611  0.80138799  1.35365719  1.6885217   0.46165839  1.25063129]


In [138]:
# Backward parity check: print each torch value followed by its CuPy
# counterpart -- first the largest gradient entry, then the gradient row of
# the first token in the first batch element.
for torch_value, cupy_value in (
    (t_grad.max(), c_grad.max()),
    (t_grad[(0, 0)], c_grad[(0, 0)]),
):
    print(torch_value)
    print(cupy_value)


tensor(18.3603)
18.3643249226239
tensor([ -0.5297,  15.4927,  -4.8079,  -8.4415,  14.7030,   1.4559,  11.8544,
          7.6653,  -9.3051,  -5.8100,  14.6671,   8.9370,  -4.6227,  -2.8054,
          4.5430,  -9.4771,   2.6450,   0.1734, -14.1352,  -3.7626, -11.3567,
         -6.2144, -18.3086,   0.4401])
[ -0.53784775  15.48721386  -4.81829949  -8.4723755   14.74274608
   1.47139379  11.84151806   7.69470273  -9.27653668  -5.78913732
  14.70168211   8.89696564  -4.62626788  -2.80444168   4.55083333
  -9.46444707   2.63951683   0.18077429 -14.13560267  -3.72385471
 -11.38018738  -6.26906922 -18.3048754    0.39559602]


In [139]:
# # compare attn
# t_attn = t_block.attn
# c_attn = c_Block.attn

# t_out ,t_grad= t_attn.forward(t_inp)
# c_out,attn = c_attn.forward(c_inp,True)


# c_grad = c_attn.backward(c_fake_upst)
# print(t_out.detach().numpy()[(0,0)])
# print(c_out[(0,0)])
# # forward is the same!!
# print("Now gradients:")
# print(t_grad.detach().numpy().shape,c_grad.shape,t_grad.max().item(),c_grad.max())
# point = (0,0)
# print(t_grad.detach().numpy()[point])
# print(c_grad[point])
# print((t_grad.detach().numpy()-cp.asnumpy(c_grad)).max())