# Rewriting the code in PyTorch

TODO: insert a line for installing the necessary libraries, and run the final model in a Google Colab notebook to minimize training time (don't want to wait).

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from typing import Any, List, Optional

%matplotlib inline

In [30]:
import abc


class Layer(abc.ABC):
    """Abstract interface implemented by every layer in this file.

    A layer is a callable that maps an input tensor to an output tensor and
    exposes the tensors that make up its parameters via ``parameters()``.
    """

    @abc.abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass on ``x`` and return the resulting tensor."""

    @abc.abstractmethod
    def parameters(self) -> List[torch.Tensor]:
        """Return the layer's parameter tensors (an empty list if it has none)."""

In [50]:
# Shared random generator with a fixed seed, so random draws that use it are repeatable.
g = torch.Generator().manual_seed(2147483647)


class Linear(Layer):
    """Affine layer computing ``x @ weights`` plus an optional bias."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[str] = None,
    ) -> None:
        """Create the layer's tensors.

        Args:
            in_features: Size of the last dimension of the expected input.
            out_features: Size of the last dimension of the output.
            bias: When False, no bias tensor is created or added.
            device: Optional device to place the tensors on. ``None`` keeps
                them where torch creates them by default.
        """
        # Sample with the module-level generator ``g`` first and move the
        # result afterwards: the generator lives on its own device, so the
        # draw cannot be made directly on an arbitrary target device.
        self.weights = torch.randn(in_features, out_features, generator=g)
        self.bias = torch.zeros(out_features) if bias else None
        if device is not None:
            self.weights = self.weights.to(device)
            if self.bias is not None:
                self.bias = self.bias.to(device)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ weights (+ bias)``; the result is also kept in ``self.out``."""
        self.out = x @ self.weights
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self) -> List[torch.Tensor]:
        """Return ``[weights]``, followed by ``bias`` when the layer has one."""
        return [self.weights] + ([] if self.bias is None else [self.bias])


class BatchNorm1d(Layer):
    """Batch normalization over dimension 0 with a learnable scale and shift.

    While ``self.training`` is True the statistics of the current batch are
    used and folded into the running buffers; otherwise the running buffers
    themselves are used for normalization.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
    ) -> None:
        """Set up parameters and running statistics.

        Args:
            num_features: Number of features (size of the normalized dimension).
            eps: Constant added to the variance before the square root to
                avoid division by zero.
            momentum: Weight given to the current batch statistics when
                updating the running buffers.
        """
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # parameters
        self.gamma = torch.ones(num_features)
        self.beta = torch.zeros(num_features)
        # buffers (trained with a running 'momentum update')
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and apply ``gamma``/``beta``; result is kept in ``self.out``."""
        # calculate the forward pass
        if self.training:
            xmean = x.mean(dim=0, keepdim=True)
            # torch's default variance estimator is used here.
            xvar = x.var(dim=0, keepdim=True)
        else:  # inference
            xmean = self.running_mean
            # The buffer is named ``running_var``; ``running_std`` is never
            # defined on this class and raised AttributeError here.
            xvar = self.running_var

        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = xhat * self.gamma + self.beta

        # update the buffers
        if self.training:
            # The buffers are statistics, not parameters: keep the update
            # out of the autograd graph.
            with torch.no_grad():
                # Because of keepdim=True the buffers take the shape
                # (1, num_features) after the first update; they still
                # broadcast against x.
                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * xmean
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self) -> List[torch.Tensor]:
        """Return the scale and shift tensors ``[gamma, beta]``."""
        return [self.gamma, self.beta]


class Tanh(Layer):
    """Element-wise hyperbolic tangent activation; has no parameters."""

    def __init__(self) -> None:
        """Nothing to set up: the layer holds no parameters."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(x)``; the result is also kept in ``self.out``."""
        self.out = torch.tanh(x)
        return self.out

    def parameters(self) -> List[torch.Tensor]:
        """Return an empty list (never ``None``), matching the ``Layer`` contract."""
        return []

In [49]:
# Quick check of Linear without a bias: 2 input features -> 3 output features.
nn = Linear(2, 3, bias=False)
# Random 2x2 input drawn from the seeded generator g.
x = torch.randn((2, 2), generator=g)
nn(x)

tensor([[-1.6253, -0.8299, -0.2462],
        [ 0.6388, -0.4958,  0.1902]])

In [48]:
# Smoke test: construct/call each layer type; x is the tensor created in an earlier cell.
Linear(1, 1)
BatchNorm1d(2)(x)
# Last expression of the cell; its value is the tensor printed below.
Tanh()(x)

tensor([[ 0.8458, -0.1285],
        [-0.8126,  0.9125]])