In [1]:
import numpy as np
from numpy.typing import ArrayLike
import cupy as cp
import torch
from typing import Union, Optional,Callable


In [2]:
# helper:
def _multi_dim_matmul(
        mat_a: cp.ndarray,
        mat_b: cp.ndarray,
        transpose_a: bool = False,
        transpose_b: bool = False,
        reshape_output: bool = True,
    ) -> cp.ndarray:
        """
        Replicate torch behavior of flattening all but the
        last dimension of an input of the matrix multiplication
        in linear layers. We implement this for both the first
        and the last matrix in the matrix multiplication to
        provide a unified operation for both the forward and
        the backward pass.
        """

        if (len(mat_a.shape) > 2) or (len(mat_b.shape) > 2):
            # Dimension handling.
            # We should refactor this if we find the time.

            dims_internal_mat_a = (
                mat_a.shape
                if len(mat_a.shape) <= 2
                else (cp.prod(cp.array(mat_a.shape[:-1])).item(), mat_a.shape[-1])
            )

            dims_internal_mat_b = (
                mat_b.shape
                if len(mat_b.shape) <= 2
                else (cp.prod(cp.array(mat_b.shape[:-1])).item(), mat_b.shape[-1])
            )

            mat_a_shape = mat_a.shape[::-1] if transpose_a else mat_a.shape
            mat_b_shape = mat_b.shape[::-1] if transpose_b else mat_b.shape

            dims_out_first = (
                mat_a.shape[:-1]
                if reshape_output
                else (
                    dims_internal_mat_a[1] if transpose_a else dims_internal_mat_a[0],
                )
            )

            dims_out = (*dims_out_first, mat_b_shape[-1])

            def mat_a_transform():
                if transpose_a:
                    return mat_a.reshape(dims_internal_mat_a).T
                else:
                    return mat_a.reshape(dims_internal_mat_a)

            def mat_b_transform():
                if transpose_b:
                    return mat_b.reshape(dims_internal_mat_b).T
                else:
                    return mat_b.reshape(dims_internal_mat_b)

            return cp.matmul(mat_a_transform(), mat_b_transform()).reshape(dims_out)

        else:
            return cp.matmul(mat_a, mat_b.T) if transpose_b else cp.matmul(mat_a, mat_b)

Testing GELU backwards pass by comparing output gradient to torch GELU gradient

In [3]:
class GELU:
    def __init__(self) -> None:
        self._sqrt_of_2_by_pi = cp.sqrt(2 / cp.pi)
        self.input = None

    def forward(self, input: ArrayLike) -> cp.ndarray:
        self.input = cp.asanyarray(input)
        return (
            0.5
            * input
            * (
                1
                + cp.tanh(
                    self._sqrt_of_2_by_pi * (input + 0.044715 * cp.power(input, 3))
                )
            )
        )

    def backward(self, grad_output: ArrayLike) -> cp.ndarray:
        # raise NotImplementedError("Implement the GELU backward path")
        x = self.input
        m1 = self._sqrt_of_2_by_pi
        m2 = 0.044715
        m3 = m1 * (x+m2 * x**3)
        tanhm3 = cp.tanh(
                    m3
                )
        first = 0.5 * (
                1
                + tanhm3
            )
        second = x/2 * (1- tanhm3**2) * (m1+2*x**2 * m2*m1)
        grad_out = (first + second) * grad_output
        return grad_out


In [4]:

Gelu= GELU()
a=cp.random.random((5,))

In [5]:
a_torch = torch.tensor(a.copy(),requires_grad=True,device="cpu")
b_t=torch.nn.functional.gelu(a_torch)

b_n = Gelu.forward(a.copy())
b_n_grad = Gelu.backward(cp.ones((5,)))

In [6]:
print(b_t.clone().cpu().detach().numpy())
print(b_n) 
# close enough

[0.34228479 0.79604536 0.18636156 0.05129523 0.18866275]
[0.34226806 0.79590748 0.18635901 0.0512952  0.1886601 ]


In [7]:
print(b_n_grad)
torch_sum = b_t.sum() # the gradient of sum equals the cp.ones((5,))
grads = cp.array(torch.autograd.grad(outputs=[torch_sum],inputs=[a_torch])[0].numpy())
print(grads)

diff = (b_n_grad- grads)
print(diff)
# diff is very smol
assert(abs(diff).max() <= 1e-2)

[0.86305331 1.0633282  0.73281609 0.57582989 0.73507947]
[0.8650387  1.07250098 0.7333096  0.5758464  0.73558789]
[-1.98539357e-03 -9.17278582e-03 -4.93511413e-04 -1.65041191e-05
 -5.08424395e-04]


Now comparing LayerNorm

In [8]:
class LayerNorm:
    def __init__(
        self,
        normalized_shape: Union[int, tuple[int]],
        eps: float = 1e-05,
        lr: float = 1e-3,
        weight_init =None,
        bias_init=None,
    ) -> None:

        self.normalized_shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, int)
            else normalized_shape
        )

        self.eps = eps
        self.lr = lr
        self.weight =weight_init

        self.bias = bias_init

        self.axis = None

        self.input = None

        self.grad_weight = None
        self.grad_bias = None

        self.x_centered = None
        self.stddev_inv = None

    def forward(self, input: ArrayLike) -> cp.ndarray:

        input = cp.asanyarray(input)

        self.input = input

        self.axis = tuple(range(-len(self.normalized_shape), 0))
        #  -n,..., -2 , -1 ohne 0

        mean = cp.mean(input, axis=self.axis, keepdims=True)
        self.x_centered = input - mean
        var = cp.var(
            input, axis=self.axis, keepdims=True, # mean=mean
        )  # can we pass the mean to the var()?  YES (with newer numpy versions)!
        # the var stays the same after centering. Usefull for gradient calculation (not really)
        self.stddev_inv = 1 / cp.sqrt(var + self.eps)

        output = self.x_centered * self.stddev_inv

        return self.weight * output + self.bias

    def _rip_backward(self,grad):
        grad_main = grad * self.weight # no transpose needed because multiplication is element wise
        outer = cp.einsum("...i,...j->...ij", self.x_centered, self.x_centered) 
        reshaped_invvar = self.stddev_inv.reshape((*self.stddev_inv.shape,1))
        part = outer * reshaped_invvar **3 / C
        part2 = (reshaped_invvar -part)
        grad_centered =  cp.matmul(grad_main.reshape((B,T,1,C)),part2)
        grad_mean_input = cp.full((B,T,C,C), -1 / self.input.shape[-1])
        grad_mean_input+= cp.diag(cp.ones((self.input.shape[-1],))).reshape((1,1,C,C))
        grad_input = cp.matmul(grad_centered,grad_mean_input)

        grad_out = grad_input
        return grad_out.squeeze() # does not work

    def backward(self, grad: ArrayLike) -> cp.ndarray:
        B,T,C = self.input.shape
        self.grad_bias = grad.mean(axis=(0, 1))  # upstream gradient * 1.
        self.grad_weight = (
            grad * (self.x_centered * self.stddev_inv)
        ).mean(axis=(0, 1))  # upstream * centered * invvar

        normalized = self.x_centered * self.stddev_inv
        grad_normalized = grad * self.weight
        grad_x = grad_normalized - grad_normalized.mean(-1,keepdims=True) - normalized * (grad_normalized *normalized).mean(-1,keepdims=True)
        grad_x = grad_x * self.stddev_inv
        return grad_x


    def update(self):
        self.weight -= self.lr * self.grad_weight
        self.bias -= self.lr * self.grad_bias
        return
        # raise NotImplementedError("Implement the LayerNorm update routine")


In [9]:
B,T,C = (2,3,4)

i_np = np.random.random((B,T,C)) # input_array
#i_np = np.array([1,2,3]).reshape((B,T,C))
print(i_np.shape,i_np.mean())

(2, 3, 4) 0.6759062075736666


In [10]:
t_norm = torch.nn.LayerNorm((C,) ,device="cpu",dtype=torch.float64)
print(t_norm,t_norm.weight.dtype)
print(t_norm.weight.shape)
print(t_norm.bias.shape)
i_t = torch.tensor(i_np,device="cpu",requires_grad=True,dtype=t_norm.weight.dtype)
t_out = t_norm.forward(i_t)
t_dout = torch.randn(B,T,C)
t_loss = (t_out * t_dout).sum()  # the gradient of sum equals the cp.ones((5,))
#t_grads = cp.array(torch.autograd.grad(outputs=[torch_sum],inputs=[i_t])[0].numpy())
#print(t_grads.shape,t_grads.mean())
#print(t_grads.squeeze())
t_loss.backward()
i_t.grad



LayerNorm((4,), eps=1e-05, elementwise_affine=True) torch.float64
torch.Size([4])
torch.Size([4])


tensor([[[-0.5191, -4.7207, -0.1122,  5.3520],
         [ 0.2664,  1.6462, -0.4588, -1.4537],
         [-2.9517, -0.6816,  1.6801,  1.9532]],

        [[ 1.2894, -1.5235, -1.2015,  1.4356],
         [-0.4970, -1.4635,  9.6867, -7.7262],
         [ 1.3423,  1.7049, -4.1252,  1.0780]]], dtype=torch.float64)

In [14]:
s_norm = LayerNorm((C,),weight_init=cp.array(t_norm.weight.detach().numpy()),bias_init=cp.array(t_norm.bias.detach().numpy()))
c_np = cp.array(i_np,dtype=cp.float64)
c_result = s_norm.forward(c_np)
upstr = cp.ones((B,T,C))
c_grads = s_norm.backward(cp.array(t_dout.detach().numpy()))
print(c_grads.shape,c_grads.mean())
print(c_grads.squeeze())

(2, 3, 4) -8.326672684688674e-16
[[[-0.51912519 -4.72066233 -0.11216766  5.35195518]
  [ 0.2664128   1.64616202 -0.458826   -1.45374883]
  [-2.95174785 -0.68156979  1.68008565  1.95323199]]

 [[ 1.28937416 -1.52346011 -1.20148196  1.43556791]
  [-0.49700066 -1.46347688  9.68665205 -7.72617452]
  [ 1.34228999  1.70490914 -4.12523728  1.07803815]]]


In [12]:
print(torch_sum)
print(c_result.sum())

tensor(1.5646, dtype=torch.float64, grad_fn=<SumBackward0>)
3.3306690738754696e-15


In [13]:
# inspect forward:
print(t_result.max(),c_result.max())
# forward is the same!
print(t_result.mean(dim=(2)).shape,t_result.mean(dim=(2))) # the average over the last dim is indeed 0
print(t_result.var(dim=(2)).shape,t_result.var(dim=(2))) # and var is 1
# now c
print(c_result.mean(axis=(2)).shape,t_result.mean(axis=(2))) # the average over the last dim is indeed 0
print(c_result.var(axis=(2)).shape,t_result.var(axis=(2)))
# perfect

NameError: name 't_result' is not defined

In [None]:
# backward:
c_grads=c_grads.squeeze()
print(t_grads.mean(),c_grads.mean())
print(t_grads.var(),c_grads.var())
print(t_grads.max(),c_grads.max())
diff = t_grads - c_grads
print(diff.mean()) 
point = (0,0,0)
# print(t_grads[point])
# print(c_grads[point])
# print(t_grads[point]-c_grads[point])


-2.220446049250313e-16 -2.2109675761565468e-31
2.3830173178551396e-30 1.3069015216202024e-30
3.552713678800501e-15 3.0933114078038186e-15


ValueError: operands could not be broadcast together with shapes (2, 2, 3) (2, 3, 4)

In [None]:
# tst
init = np.array([[[1, 2], [2, 3], [3, 4]],
            [[4, 5], [5, 6], [6, 7]]])  # Shape (2, 3,2)
print(init.shape)
print(init)
result = np.einsum("...i,...j->...ij", init, init)
print(result.shape)
print(result) # is the same as result.transpose(0,1,3,2) due to the symerty in the last dimension
a=0
b=1
print(result[(a,b,0,1)])  # == result[a,b,0] * result[a,b,1]
print( 6)

(2, 3, 2)
[[[1 2]
  [2 3]
  [3 4]]

 [[4 5]
  [5 6]
  [6 7]]]
(2, 3, 2, 2)
[[[[ 1  2]
   [ 2  4]]

  [[ 4  6]
   [ 6  9]]

  [[ 9 12]
   [12 16]]]


 [[[16 20]
   [20 25]]

  [[25 30]
   [30 36]]

  [[36 42]
   [42 49]]]]
6
6


In [None]:
dima = (4,2,3,3)
A = np.arange(18*4).reshape(dima)
B = np.arange(2,26).reshape((4,2,3,1))
print(A.shape)
print(B.shape)
res= np.matmul(A,B)
print(res.shape)
print(res[(3,1)])
manual = A[(3,1,0,0)] * B[(3,1,0,0)] + A[(3,1,0,1)] * B[(3,1,1,0)] +A[(3,1,0,2)] * B[(3,1,2,0)]
print(manual)

(4, 2, 3, 3)
(4, 2, 3, 1)
(4, 2, 3, 1)
[[4610]
 [4826]
 [5042]]
4610
