# PyTorch Implementation

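[torch.autograd](https://pytorch.org/docs/stable/autograd.html#) *provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. It requires minimal changes to the existing code.* A custom operation subclasses `torch.autograd.Function` and defines static `forward` and `backward` methods. `backward` must return one gradient per `forward` input, using `None` for inputs that need no gradient. The operation is then called through `Function.apply`, not by instantiating the class.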

In [1]:
from torch.autograd import Function

In [2]:
class ReluFunction(Function):
    """Rectified linear unit ``max(i, 0)`` with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, i):
        # The backward pass needs the raw input to know which entries were positive.
        ctx.save_for_backward(i)
        return i.clamp_min(0.)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        # Local derivative is 1 where the input was strictly positive, 0 elsewhere
        # (including exactly 0).
        mask = inp.gt(0).float()
        return mask * grad_output

In [3]:
class MSE(Function):
    """Mean squared error between predictions ``x`` and ``targets``.

    ``x`` is squeezed before comparison, so a prediction column such as
    ``(N, 1)`` can be scored against ``(N,)`` targets. The result is a scalar:
    the mean over every element of ``x.squeeze() - targets``.
    """

    @staticmethod
    def forward(ctx, x, targets):
        ctx.save_for_backward(x, targets)
        return (x.squeeze() - targets).pow(2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        x, targets = ctx.saved_tensors
        x_sq = x.squeeze()
        diff = x_sq - targets
        # d mean(diff**2) / d diff = 2 * diff / numel. Scale by the upstream
        # gradient (chain rule) so the loss can be weighted or composed.
        grad_diff = (2. * diff / diff.numel()) * grad_output
        # Reduce any broadcast dimensions, then restore x's un-squeezed shape
        # so autograd receives a gradient matching the forward input.
        grad_x = grad_diff.sum_to_size(x_sq.shape).reshape(x.shape)
        grad_targets = None
        if ctx.needs_input_grad[1]:
            # diff depends on targets with coefficient -1.
            grad_targets = (-grad_diff).sum_to_size(targets.shape)
        # autograd expects exactly one gradient per forward input.
        return grad_x, grad_targets

In [4]:
class LinearFunction(Function):
    """Affine map ``i @ w.T + b`` with a hand-written backward pass.

    ``w`` is ``(n_out, n_in)`` and ``b`` is assumed to be ``(n_out,)``. ``i``
    carries ``n_in`` features in its last dimension and may have any number
    of leading batch dimensions.
    """

    @staticmethod
    def forward(ctx, i, w, b):
        # The bias is not needed for any gradient, so only i and w are saved.
        ctx.save_for_backward(i, w)
        return i @ w.t() + b

    @staticmethod
    def backward(ctx, grad_output):
        i, w = ctx.saved_tensors
        n_out, n_in = w.shape
        grad_i = grad_w = grad_b = None
        if ctx.needs_input_grad[0]:
            # out = i @ w.T, so d(out)/d(i) propagates back through w.
            grad_i = grad_output @ w
        # Flatten leading batch dims so parameter gradients are plain matmuls.
        flat_grad = grad_output.reshape(-1, n_out)
        if ctx.needs_input_grad[1]:
            grad_w = flat_grad.t() @ i.reshape(-1, n_in)
        if ctx.needs_input_grad[2]:
            # b is broadcast over every row, so its gradient sums over them.
            grad_b = flat_grad.sum(0)
        # One gradient per forward input (i, w, b).
        return grad_i, grad_w, grad_b

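[torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module) *Base class for all neural network modules. Your models should also subclass this class.* Wrapping the custom Functions in Modules lets them own their parameters (via `nn.Parameter`) and compose with containers such as `nn.Sequential`.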

In [5]:
import torch
import torch.nn as nn
from math import sqrt

In [6]:
class Linear(nn.Module):
    """Fully connected layer whose forward/backward use ``LinearFunction``.

    Args:
        n_in: number of input features.
        n_out: number of output features.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # Weight is (n_out, n_in), drawn with std sqrt(2 / n_in)
        # (He/Kaiming-normal, fan-in). Bias starts at zero.
        self.weight = nn.Parameter(torch.randn(n_out, n_in) * sqrt(2/n_in))
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, x):
        # Use the registered parameter names; ``self.weights`` does not exist.
        return LinearFunction.apply(x, self.weight, self.bias)

In [7]:
class ReLU(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the custom ``ReluFunction`` op."""

    # No custom __init__ needed: nn.Module's constructor does all the setup.

    def forward(self, x):
        return ReluFunction.apply(x)

In [8]:
lin = Linear(n_in=10, n_out=2)
# Parameters come back in registration order: weight (n_out, n_in), then bias (n_out,).
p1, p2 = tuple(lin.parameters())
(p1.shape, p2.shape)

(torch.Size([2, 10]), torch.Size([2]))

# Module implementation

In [9]:
class MyModel(nn.Module):
    """Linear -> ReLU -> Linear network whose forward returns the MSE loss.

    Args:
        n_in: number of input features.
        nh: width of the hidden layer.
        n_out: number of outputs of the final layer.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = nn.Sequential(
            Linear(n_in, nh), ReLU(), Linear(nh, n_out))
        # Custom autograd Functions run through ``.apply``; calling the class
        # itself would merely construct a Function object, not compute a loss.
        self.loss = MSE.apply

    def forward(self, x, targets):
        # MSE squeezes its prediction input itself, so no extra squeeze here;
        # the loss value is unchanged, and the tensor autograd tracks into MSE
        # keeps the layer output's shape.
        return self.loss(self.layers(x), targets)

In [10]:
model = MyModel(n_in=10, nh=5, n_out=2)  # 10 inputs -> 5 hidden units -> 2 outputs

In [11]:
model  # last expression of the cell: the notebook displays the module's repr

MyModel(
  (layers): Sequential(
    (0): Linear()
    (1): ReLU()
    (2): Linear()
  )
)