In [210]:
import numpy as np
from numpy.typing import ArrayLike
import cupy as cp
import torch
from typing import Union, Optional,Callable


Testing GELU backwards pass by comparing output gradient to torch GELU gradient

In [211]:
class GELU:
    def __init__(self) -> None:
        self._sqrt_of_2_by_pi = cp.sqrt(2 / cp.pi)
        self.input = None

    def forward(self, input: ArrayLike) -> cp.ndarray:
        self.input = cp.asanyarray(input)
        return (
            0.5
            * input
            * (
                1
                + cp.tanh(
                    self._sqrt_of_2_by_pi * (input + 0.044715 * cp.power(input, 3))
                )
            )
        )

    def backward(self, grad_output: ArrayLike) -> cp.ndarray:
        # raise NotImplementedError("Implement the GELU backward path")
        x = self.input
        m1 = self._sqrt_of_2_by_pi
        m2 = 0.044715
        m3 = m1 * (x+m2 * x**3)
        tanhm3 = cp.tanh(
                    m3
                )
        first = 0.5 * (
                1
                + tanhm3
            )
        second = x/2 * (1- tanhm3**2) * (m1+2*x**2 * m2*m1)
        grad_out = (first + second) * grad_output
        return grad_out


In [212]:

Gelu= GELU()
a=cp.random.random((5,))

In [213]:
a_torch = torch.tensor(a.copy(),requires_grad=True,device="cpu")
b_t=torch.nn.functional.gelu(a_torch)

b_n = Gelu.forward(a.copy())
b_n_grad = Gelu.backward(cp.ones((5,)))

In [214]:
print(b_t.clone().cpu().detach().numpy())
print(b_n) 
# close enough

[0.53843265 0.78671516 0.40918719 0.77219431 0.62633567]
[0.53837527 0.78658037 0.40915942 0.77206434 0.62625296]


In [215]:
print(b_n_grad)
torch_sum = b_t.sum() # the gradient of sum equals the cp.ones((5,))
grads = cp.array(torch.autograd.grad(outputs=[torch_sum],inputs=[a_torch])[0].numpy())
print(grads)

diff = (b_n_grad- grads)
print(diff)
# diff is very smol
assert(abs(diff).max() <= 1e-2)

[0.97547696 1.06107309 0.90687149 1.05744095 1.01192964]
[0.98035172 1.07009616 0.90974865 1.06622904 1.0182753 ]
[-0.00487475 -0.00902307 -0.00287715 -0.00878809 -0.00634566]


Now comparing LayerNorm

In [216]:
class LayerNorm:
    def __init__(
        self,
        normalized_shape: Union[int, tuple[int]],
        eps: float = 1e-05,
        lr: float = 1e-3,
        weight_init =None,
        bias_init=None,
    ) -> None:

        self.normalized_shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, int)
            else normalized_shape
        )

        self.eps = eps
        self.lr = lr
        self.weight =weight_init

        self.bias = bias_init

        self.axis = None

        self.input = None

        self.grad_weight = None
        self.grad_bias = None

        self.x_centered = None
        self.stddev_inv = None

    def forward(self, input: ArrayLike) -> cp.ndarray:

        input = cp.asanyarray(input)

        self.input = input

        self.axis = tuple(range(-len(self.normalized_shape), 0))
        #  -n,..., -2 , -1 ohne 0

        mean = cp.mean(input, axis=self.axis, keepdims=True)
        var = cp.var(
            input, axis=self.axis, keepdims=True, # mean=mean
        )  # can we pass the mean to the var()?  YES (with newer numpy versions)!
        # the var stays the same after centering. Usefull for gradient calculation (not really)
        self.x_centered = input - mean
        self.stddev_inv = 1 / cp.sqrt(var + self.eps)

        output = self.x_centered * self.stddev_inv

        return self.weight * output + self.bias

    def backward(self, grad: ArrayLike) -> cp.ndarray:
        self.grad_bias = grad  # upstream gradient * 1.
        self.grad_weight = (
            grad * self.x_centered * self.stddev_inv
        )  # upstream * centered * invvar

        # fuck this is hard
        grad = grad * self.weight.transpose()  # add dims to transpose
        grad = grad * self.stddev_inv  # .squeeze()
        # grad_out = grad.reshape((*grad.shape, 1)) * (
        #     -2 * cp.power(self.x_centered, 2) * cp.power(self.stddev_inv, 2)
        #     + (self.stddev_inv * (1 - self.normalized_shape[-1]))
        # )  # TODO: check
        #
        grad_out = grad * (1 - 1 / self.input.shape[-1])

        return grad_out

    def update(self):
        self.weight -= self.lr * self.grad_weight.mean(axis=(0, 1))
        self.bias -= self.lr * self.grad_bias.mean(axis=(0, 1))
        return
        # raise NotImplementedError("Implement the LayerNorm update routine")


In [217]:
B,T,C = (16,256,384)
h = 6
i_np = np.random.random((B,T,C)) # input_array
print(i_np.shape,i_np.mean())


(16, 256, 384) 0.4995360355350611


In [218]:
t_norm = torch.nn.LayerNorm((C,) ,device="cpu")
print(t_norm,t_norm.weight.dtype)
print(t_norm.weight.shape)
print(t_norm.bias.shape)
i_t = torch.tensor(i_np,device="cpu",requires_grad=True,dtype=t_norm.weight.dtype)
result = t_norm.forward(i_t)
torch_sum = result.sum() # the gradient of sum equals the cp.ones((5,))
grads = cp.array(torch.autograd.grad(outputs=[torch_sum],inputs=[i_t])[0].numpy())
print(grads.shape,grads.max())

LayerNorm((384,), eps=1e-05, elementwise_affine=True) torch.float32
torch.Size([384])
torch.Size([384])
(16, 256, 384) 3.0994415e-06


In [219]:
s_norm = LayerNorm((C,),weight_init=t_norm.weight,bias_init=t_norm.bias)
