<a href="https://colab.research.google.com/github/kmalik22/colabs/blob/main/transformer_numpy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
import numpy as np
import torch
from typing import List

In [28]:
# Toy sizes shared by every cell in this notebook.
D_MODEL = 3
D_FF = 2 * D_MODEL  # hidden width of the feed-forward layer
BSZ = 2  # batch size for the random test activations

In [34]:
#class torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
class MyLinear:
  """NumPy re-implementation of ``torch.nn.Linear`` with a hand-written backward pass.

  Unlike torch, the weight is stored as (in_features, out_features), so the
  forward pass is ``activations @ weight (+ bias)``. Every ``forward`` call
  appends its input to ``stored_activations``. ``backward`` treats all stored
  inputs as one batch, concatenated along axis 0, so the gradient it receives
  must be the matching concatenation of output gradients.
  """

  def __init__(self, in_features, out_features, bias=True, debug=True):
    self.in_features = in_features
    self.out_features = out_features
    self.has_bias = bias
    # Weights ~ N(0, 1); the bias starts at zero.
    self.weight = np.random.normal(0, 1, size=(in_features, out_features))
    self.bias = np.zeros(shape=(out_features))
    # Inputs remembered by forward() for use in backward().
    self.stored_activations: List[np.ndarray] = []
    self.weight_grad = None
    self.bias_grad = None
    self.debug = debug

  def forward(self, activations: np.ndarray) -> np.ndarray:
    """Return ``activations @ weight (+ bias)`` and store the input for backward.

    ``activations`` must have at least 2 dims. The last dim must equal
    ``in_features``; any leading dims (batch, sequence, ...) are preserved.
    """
    assert activations.ndim >= 2
    assert activations.shape[-1] == self.in_features
    self.stored_activations.append(activations)
    if self.debug:
      print(f"MyLinear.forward(), batch={len(self.stored_activations)}")
    out = activations @ self.weight
    if self.has_bias:
      out = out + self.bias
    return out

  def backward(self, output_act_grad: np.ndarray) -> np.ndarray:
    """Compute weight/bias grads internally and return dloss/d(input).

    Args:
      output_act_grad: dloss/d(output) for the concatenation (along axis 0) of
        every input seen by forward() since the last clear_grads(). Its leading
        dims must match those inputs; its last dim is ``out_features``.

    Returns:
      dloss/d(input), with the same leading dims as ``output_act_grad`` and
      last dim ``in_features``.
    """
    assert len(self.stored_activations) > 0
    input_acts = np.concatenate(self.stored_activations)
    if self.debug:
      print(f"MyLinear.bwd(), batches:{len(self.stored_activations)}, act.shape:{input_acts.shape}")
    assert input_acts.shape[-1] == self.in_features, f"{input_acts.shape[-1]=} {self.in_features=}"
    expected_grad_shape = input_acts.shape[:-1] + (self.out_features,)
    assert output_act_grad.shape == expected_grad_shape, f"{output_act_grad.shape=} {expected_grad_shape=}"
    # Fold all leading dims into rows so both grads are plain 2-D matmuls.
    # A bare transpose() of an N-D array would reverse *every* axis.
    acts_2d = input_acts.reshape(-1, self.in_features)        # (rows, in_features)
    grad_2d = output_act_grad.reshape(-1, self.out_features)  # (rows, out_features)
    self.weight_grad = acts_2d.T @ grad_2d  # (in_features, out_features)
    # No bias parameter -> no bias gradient (mirrors torch, where bias is absent).
    self.bias_grad = grad_2d.sum(axis=0) if self.has_bias else None  # (out_features,)
    return output_act_grad @ self.weight.T  # leading dims kept, last dim in_features

  def clear_grads(self):
    """Forget stored inputs and computed gradients before the next step."""
    self.stored_activations = []
    self.weight_grad = None
    self.bias_grad = None


def make_wts_same(my_lin: MyLinear, torch_lin: torch.nn.Linear):
  """In place, load ``my_lin``'s parameters into ``torch_lin``.

  MyLinear stores its weight as (in_features, out_features), while
  torch.nn.Linear uses (out_features, in_features); hence the transpose.
  The bias is copied only when ``my_lin`` was built with a bias.
  """
  state_dict = {
      "weight": torch.tensor(my_lin.weight.transpose()),
  }
  if my_lin.has_bias:
    # Must be an assignment. `state_dict["bias"]: ...` is a no-op annotation
    # that never inserts the key, so a strict load_state_dict would fail.
    state_dict["bias"] = torch.tensor(my_lin.bias)
  torch_lin.load_state_dict(state_dict)

def make_similar_linear(my_lin: MyLinear) -> torch.nn.Linear:
  """Build a float32 torch.nn.Linear whose shape, bias flag and weights mirror ``my_lin``."""
  torch_lin = torch.nn.Linear(
      in_features=my_lin.in_features,
      out_features=my_lin.out_features,
      bias=my_lin.has_bias,
      dtype=torch.float32,
  )
  make_wts_same(my_lin, torch_lin)
  return torch_lin


class MyReLU:
  """Element-wise ReLU with a manual backward pass.

  Every ``forward`` call stores its input. ``backward`` treats all stored
  inputs as one batch, concatenated along axis 0.
  """

  def __init__(self):
    # Inputs remembered by forward() for use in backward().
    self.stored_inp_activations: List[np.ndarray] = []

  def forward(self, activations: np.ndarray) -> np.ndarray:
    """Return ``max(activations, 0)`` element-wise and store the input."""
    self.stored_inp_activations.append(activations)
    return np.where(activations > 0, activations, 0)

  def backward(self, output_act_grad: np.ndarray) -> np.ndarray:
    """Return dloss/d(input): the upstream gradient wherever the input was > 0, else 0.

    ``output_act_grad`` must have the same shape as the concatenation of all
    stored inputs. The result has that same shape; it is per-element and is not
    reduced over the batch.
    """
    assert len(self.stored_inp_activations) > 0
    inp_activations = np.concatenate(self.stored_inp_activations)
    assert output_act_grad.shape == inp_activations.shape, f"{output_act_grad.shape=} {inp_activations.shape=}"
    # d relu(x)/dx is 1 for x > 0 and 0 otherwise, so gate the *upstream
    # gradient* (not the activations) by the input's sign.
    return np.where(inp_activations > 0, output_act_grad, 0)

  def clear_grads(self):
    """Forget stored inputs before the next step."""
    self.stored_inp_activations = []


def make_two_linears(input_dim, output_dim):
  """Return ``(MyLinear, torch.nn.Linear)``, both bias-free and sharing identical weights."""
  custom_lin = MyLinear(input_dim, output_dim, bias=False)
  return custom_lin, make_similar_linear(custom_lin)

def clear_grads(custom_lin, torch_lin):
  """Drop accumulated gradients on the torch module and reset the custom layer."""
  for param in torch_lin.parameters():
    # Assigning None is harmless when the grad is already None.
    param.grad = None
  custom_lin.clear_grads()


def get_random_act(in_features, bsz):
  """Sample a (bsz, in_features) float32 N(0, 1) batch.

  Returns:
    (numpy_array, torch_tensor). The tensor is a float32 copy of the array
    with requires_grad=True, so its .grad is populated by backward().
  """
  np_act = np.random.normal(0, 1, size=(bsz, in_features)).astype(np.float32)
  return np_act, torch.tensor(np_act, dtype=torch.float32, requires_grad=True)


def compare_linear(my_lin, torch_lin, in_features):
  """Push one random batch of size BSZ through both layers and assert the outputs agree.

  Returns:
    (my_act, torch_act, random_act, torch_random_act), so the caller can go on
    to compare backward passes.
  """
  np_inp, torch_inp = get_random_act(in_features, BSZ)
  np_out = my_lin.forward(np_inp)
  torch_out = torch_lin(torch_inp)
  # Equal within np.isclose's default tolerances, element-wise.
  assert np.allclose(np_out, torch_out.detach().numpy())
  return np_out, torch_out, np_inp, torch_inp


# Match forward and backward passes for one linear layer

In [35]:
def _match_fwd_bwd_one_linear(torch_loss_fn, my_output_grad_fn):
  """Check that MyLinear and torch.nn.Linear agree on forward output and backward grads.

  Args:
    torch_loss_fn: maps the torch output activations to a scalar loss tensor.
    my_output_grad_fn: maps the NumPy output activations to dloss/d(output)
      for the same loss, computed by hand.
  """
  my_linear1, torch_linear1 = make_two_linears(D_MODEL, D_FF)
  my_output_act, torch_output_act, _, torch_input_act = compare_linear(my_linear1, torch_linear1, D_MODEL)
  print("fwd matches")

  torch_loss_fn(torch_output_act).backward()
  my_inp_act_grad = my_linear1.backward(output_act_grad=my_output_grad_fn(my_output_act))

  # torch stores the weight as (out, in) and MyLinear as (in, out): transpose to compare.
  assert np.allclose(my_linear1.weight_grad, torch_linear1.weight.grad.t())
  assert np.allclose(my_inp_act_grad, torch_input_act.grad)
  print("Bwd matches")


def match_fwd_bwd_one_mlp_simple_loss():
  """loss = sum(y), so dloss/dy is all ones."""
  _match_fwd_bwd_one_linear(lambda y: y.sum(), np.ones_like)


def match_fwd_bwd_one_mlp_complex_loss():
  """loss = sum(y**2) / 2, so dloss/dy = y."""
  _match_fwd_bwd_one_linear(lambda y: y.square().sum() / 2, lambda y: y)


# Match gradients with a simple loss (L = sum(y), so dL/dy = 1)

In [36]:
# loss = sum(y), so dloss/dy is all ones.
match_fwd_bwd_one_mlp_simple_loss()

MyLinear.forward(), batch=1
fwd matches
MyLinear.bwd(), batches:1, act.shape:(2, 3)
Bwd matches


# Match gradients with a more complex loss (L = sum(y^2)/2, so dL/dy = y)

In [37]:
# loss = sum(y**2) / 2, so dloss/dy = y.
match_fwd_bwd_one_mlp_complex_loss()

MyLinear.forward(), batch=1
fwd matches
MyLinear.bwd(), batches:1, act.shape:(2, 3)
Bwd matches


# MLP

In [33]:
# linear --> ReLu --> linear

class TorchMLP(torch.nn.Module):
  """Reference MLP: Linear(d_model -> d_ffn) -> ReLU -> Linear(d_ffn -> d_model)."""

  def __init__(self, d_model, d_ffn):
    super().__init__()
    # Create the layers in this order: linear1 is initialised before linear2.
    self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_ffn)
    self.linear2 = torch.nn.Linear(in_features=d_ffn, out_features=d_model)

  def forward(self, inp_act: torch.Tensor):
    hidden = torch.relu(self.linear1(inp_act))
    return self.linear2(hidden)


class MyMLP:
  """NumPy MLP (Linear -> ReLU -> Linear) mirroring TorchMLP, with a manual backward pass."""

  def __init__(self, d_model, d_ffn):
    self.linear1 = MyLinear(in_features=d_model, out_features=d_ffn)
    self.relu = MyReLU()
    self.linear2 = MyLinear(in_features=d_ffn, out_features=d_model)

  def forward(self, inp_act: np.ndarray) -> np.ndarray:
    """Run the input through linear1, ReLU and linear2; each layer stores its input."""
    linear1_out = self.linear1.forward(inp_act)
    relu_out = self.relu.forward(linear1_out)
    return self.linear2.forward(relu_out)

  def backward(self, output_act_grad: np.ndarray) -> np.ndarray:
    """Backpropagate dloss/d(output) through all layers and return dloss/d(input).

    Parameter gradients are left on ``linear1`` / ``linear2`` (``weight_grad``, ``bias_grad``).
    """
    linear2_input_act_grad = self.linear2.backward(output_act_grad)
    relu_input_act_grad = self.relu.backward(linear2_input_act_grad)
    # Previously computed but dropped; callers need it to chain further layers.
    return self.linear1.backward(relu_input_act_grad)

  def clear_grads(self):
    """Reset stored activations and gradients in every sub-layer."""
    self.linear1.clear_grads()
    self.relu.clear_grads()
    self.linear2.clear_grads()


def make_mlp_wts_same(my_mlp: MyMLP, torch_mlp: TorchMLP):
  """In place, copy both linear layers of ``my_mlp`` into the matching layers of ``torch_mlp``."""
  for layer_name in ("linear1", "linear2"):
    make_wts_same(getattr(my_mlp, layer_name), getattr(torch_mlp, layer_name))


In [None]:
# NOTE(review): both MLPs are initialised independently (NumPy RNG vs torch's
# default init), so their weights only match after
# make_mlp_wts_same(custom_mlp, torch_mlp) is called.
custom_mlp = MyMLP(D_MODEL, D_FF)
torch_mlp = TorchMLP(D_MODEL, D_FF)

