In [5]:
import numpy as np

In [15]:
import torch
import torch.nn as nn

In [6]:
class Initializer(object):
    """In-place parameter initializer.

    Arrays with more than one dimension (weights) are overwritten with samples
    from a normal distribution truncated to +/- 3 standard deviations; arrays
    with ndim <= 1 (biases, scalars) are zeroed.
    """

    def __init__(self, stdev=0.02):
        # Standard deviation of the (pre-truncation) normal used for weights.
        self.stdev = stdev

    def __call__(self, x):
        """Overwrite the numpy array ``x`` in place; returns None."""
        if x.ndim > 1:
            # init to a truncated normal dist
            x[...] = self.sample_from_truncated_normal(x.size, self.stdev).reshape(x.shape)
        else:
            # biases get initialized to zero; `[...]` also works for 0-d arrays
            x[...] = 0

    def sample_from_truncated_normal(self, size, stdev=0.02):
        """Draw samples from N(0, stdev^2) truncated to [-3*stdev, 3*stdev].

        size:  int or shape tuple, forwarded to ``np.random.normal``.
        stdev: standard deviation of the underlying normal.

        Out-of-range samples are redrawn (rejection sampling) until every value
        lies inside the bounds. Returns an array of shape ``size``.
        """
        bound = 3 * stdev
        samples = np.random.normal(0, stdev, size)
        oob = np.abs(samples) > bound
        num_oob = np.count_nonzero(oob)
        # Redraw only the out-of-bound entries until none remain.
        while num_oob:
            samples[oob] = np.random.normal(0, stdev, num_oob)
            oob = np.abs(samples) > bound
            num_oob = np.count_nonzero(oob)
        return samples

In [22]:
class Module(object):
    """Minimal base class for numpy layers with named parameters and gradients.

    Subclasses pass their parameter arrays as keyword arguments to
    ``__init__``; each array is then passed to ``self.initializer``.
    Subclasses implement ``forward`` (called via ``__call__``) and ``backward``.
    """

    def __init__(self, **param_kwargs):
        # Filled in by ``backward``; intended to be keyed like ``self.params``.
        self.grads = {}
        # New dict, but the arrays are the caller's objects; the default
        # Initializer overwrites them in place.
        self.params = dict(param_kwargs)
        self.initializer = Initializer()
        self.reset_parameters()

    def reset_parameters(self):
        """Run ``self.initializer`` on every parameter array."""
        for param in self.params.values():
            self.initializer(param)

    def __call__(self, x):
        """Run the forward pass; delegates to ``forward``."""
        return self.forward(x)

    def forward(self, x):
        """
        Transforms the input and then returns the output
        """
        raise NotImplementedError

    def backward(self, dl_out):
        """
        dl_out: derivative of the loss w.r.t. this module's output.

        Updates self.grads with the derivative of the loss w.r.t. each
        parameter and returns the derivative of the loss w.r.t. the input.
        """
        raise NotImplementedError

In [83]:
class Linear(Module):
    """Fully connected layer computing ``x @ w.T (+ bias)``.

    The weight is stored as (out_dim, in_dim), the same layout as
    ``torch.nn.Linear.weight``.

    Parameters (in ``self.params``):
        w:    (out_dim, in_dim) weight matrix.
        bias: (out_dim,) bias vector, present only when ``bias=True``.
    """

    def __init__(self, in_dim, out_dim, bias=False):
        kwargs = {"w": np.zeros((out_dim, in_dim))}
        if bias:
            kwargs["bias"] = np.zeros(out_dim)
        # Module.__init__ stores the arrays and runs the initializer on them.
        super().__init__(**kwargs)

    def forward(self, x):
        """
        x: (B, in_dim)

        Returns: (B, out_dim)
        """
        out = np.dot(x, self.params["w"].T)
        if "bias" in self.params:
            out += self.params["bias"]
        return out

    def backward(self, x, d_up):
        """
        x: (B, in_dim). Input that was passed to the forward pass.

        d_up: (B, out_dim). Derivative of the loss w.r.t. this layer's output.

        Stores the loss gradients w.r.t. ``w`` (and ``bias`` when present) in
        ``self.grads`` and returns the derivative of the loss w.r.t. the input,
        shape (B, in_dim). The input gradient is also kept in ``self.grads["in"]``.
        """
        if "bias" in self.params:
            # The bias is broadcast over the batch, so its gradient sums over it.
            self.grads["bias"] = d_up.sum(axis=0)

        # (out_dim, in_dim) = (out_dim, B) x (B, in_dim)
        self.grads["w"] = np.dot(d_up.T, x)
        # (B, in_dim) = (B, out_dim) x (out_dim, in_dim)
        d_in = np.dot(d_up, self.params["w"])
        self.grads["in"] = d_in
        return d_in


In [84]:
def test_forward(my_module, torch_module, np_in):
    tensor_in = torch.tensor(np_in)
    my_out = my_module(np_in)
    torch_out = torch_module(tensor_in).detach().numpy()
    return np.array_equal(my_out, torch_out)

### Test Linear Layer

In [92]:
def test_linear_forward(n_batch, d_in, d_out, use_bias):
    """Compare ``Linear`` against ``torch.nn.Linear`` on a random float32 batch.

    Both layers are built with ``bias=use_bias``. The numpy layer is given
    copies of the torch layer's initial weight (and bias, when used) so both
    compute the same function.

    Returns the result of ``test_forward`` (True when the outputs match).
    """
    my_lin = Linear(d_in, d_out, bias=use_bias)
    torch_lin = nn.Linear(d_in, d_out, bias=use_bias)
    # set init params equal (copies, so the two layers don't share memory)
    my_lin.params["w"] = torch_lin.weight.detach().numpy().copy()
    if use_bias:
        my_lin.params["bias"] = torch_lin.bias.detach().numpy().copy()
    # create dummy input
    x_in = np.random.rand(n_batch, d_in).astype(np.float32)
    return test_forward(my_lin, torch_lin, x_in)

In [94]:
test_linear_forward(n_batch=3, d_in=4, d_out=2, use_bias=True)

True

In [98]:
# `torch_lin` inside test_linear_forward is a local variable, so it does not
# exist here on a fresh kernel; build a layer in this cell instead. A weight's
# `.grad` is None until a backward pass has run, so run one on a dummy batch.
torch_lin = nn.Linear(4, 2, bias=False)
torch_lin(torch.ones(3, 4)).sum().backward()
torch_lin.weight.grad