In [38]:
import numpy as np

In [61]:
class Initializer(object):
    """Callable that fills a parameter array in place with initial values.

    Arrays with more than one dimension are filled with samples from a
    zero-mean truncated normal distribution; arrays with at most one
    dimension (biases) are set to zero.
    """

    def __call__(self, x):
        """Overwrite the contents of ``x`` in place.

        x: numpy array to initialize. It is modified; nothing is returned.
        """
        # ``x[...]`` (rather than ``x[:]``) also accepts 0-d arrays.
        if x.ndim > 1:
            # init to a truncated normal dist
            x[...] = self.sample_from_truncated_normal(x.shape)
        else:
            # biases get initialized to zero
            x[...] = 0

    def sample_from_truncated_normal(self, size, stdev=0.02):
        """Draw zero-mean normal samples restricted to ``[-3*stdev, 3*stdev]``.

        Truncation is done by rejection: every sample that falls outside the
        bounds is redrawn until none remain outside.

        size: number of samples (int) or output shape (tuple of ints).
        stdev: standard deviation of the underlying normal distribution.

        Returns a numpy array of the requested size/shape.
        """
        lower = -3 * stdev
        upper = 3 * stdev
        samples = np.random.normal(0, stdev, size)
        out_of_bounds = (samples > upper) | (samples < lower)
        # count_nonzero yields a single int for any array shape, whereas the
        # builtin sum() only reduces the first axis and iterates in Python.
        num_oob = np.count_nonzero(out_of_bounds)
        while num_oob:
            samples[out_of_bounds] = np.random.normal(0, stdev, num_oob)
            out_of_bounds = (samples > upper) | (samples < lower)
            num_oob = np.count_nonzero(out_of_bounds)
        return samples

In [63]:
class Module(object):
    """Base class for layers that own a set of named parameter arrays.

    Subclasses implement ``forward`` and ``backward``. Calling an instance
    runs ``forward`` on the argument.
    """

    def __init__(self, **param_kwargs):
        """Register the given arrays as parameters and initialize them.

        param_kwargs: parameter name -> numpy array. The arrays are stored
            by reference and overwritten in place by the initializer.
        """
        self.grads = {}
        # Copy the name -> array mapping. Iterating the dict directly would
        # yield only the names, so each array must be taken from the mapping.
        self.params = dict(param_kwargs)
        self.initializer = Initializer()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-run the initializer over every registered parameter in place."""
        for param in self.params.values():
            self.initializer(param)

    def __call__(self, x):
        """Apply the module to ``x``; equivalent to ``self.forward(x)``."""
        return self.forward(x)

    def forward(self, x):
        """
        Transforms the input and then returns the output
        """
        raise NotImplementedError

    def backward(self, dl_out):
        """
        Updates self.grads with the gradients of the loss w.r.t. each parameter
        and returns the derivative of the loss w.r.t. the input
        """
        raise NotImplementedError

In [66]:
class Linear(Module):
    """Linear layer computing ``x @ w.T``, plus a bias vector when enabled.

    Parameters are stored in ``self.params`` under ``"w"`` with shape
    (out_dim, in_dim) and, if ``bias`` is true, ``"bias"`` with shape
    (out_dim,).
    """

    def __init__(self, in_dim, out_dim, bias=False):
        """
        in_dim: size of the last axis of the input
        out_dim: size of the last axis of the output
        bias: when true, a bias vector is added to the output
        """
        # The zero arrays only fix the shapes; Module.__init__ runs the
        # initializer over them.
        kwargs = {"w": np.zeros((out_dim, in_dim))}
        if bias:
            kwargs["bias"] = np.zeros(out_dim)
        super().__init__(**kwargs)

    def forward(self, x):
        """
        x: (B, in_dim)

        Returns an array of shape (B, out_dim). Invoked through
        Module.__call__, so ``layer(x)`` yields this result.
        """
        out = np.dot(x, self.params["w"].T)
        if "bias" in self.params:
            out += self.params["bias"]
        return out

    def backward(self, dl_out):
        """Placeholder: currently a no-op that returns None."""
        # TODO: populate self.grads and return the gradient w.r.t. the input.
        pass