# Intermediate PyTorch: Building architectures

## Every neural network (an `nn.Module` subclass) defines an `__init__` method and a `forward` method.
`__init__` : Create the PyTorch layers (sub-modules) the network will use
`forward` : Define how the input tensor flows through those layers to produce the output

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


### Building a model class

In [3]:

class CustomNet(nn.Module):  # nn.Module is the base class for all neural network modules
    """Small fully connected network: 10 -> 5 -> 3 -> 1.

    Both hidden layers use ReLU; the single output unit is passed through a
    sigmoid, so outputs lie in the unit interval.
    """

    def __init__(self):
        # Zero-argument super() binds to *this* class object.  The explicit
        # ``super(CustomNet, self)`` form looks up the global name
        # ``CustomNet`` at call time; this notebook rebinds that name in a
        # later cell, after which the explicit form raises TypeError for
        # instances of this (earlier) class.
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        """Compute the network output (called by ``self(x)``).

        Args:
            x: tensor whose last dimension is 10 (``fc1.in_features``).

        Returns:
            Tensor with last dimension 1, squashed by a sigmoid.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid is the non-deprecated equivalent of F.sigmoid.
        return torch.sigmoid(self.fc3(x))

    def predict(self, x):
        """Return the model output for ``x``.

        Invokes the module (``self(x)``) instead of ``self.forward(x)`` so
        that registered forward hooks run.  Gradient tracking and
        train/eval mode are left to the caller.
        """
        return self(x)

    def get_weights(self, *, include_output=False):
        """Return the weight tensors of the layers.

        Args:
            include_output: if True, also include ``fc3.weight``.  The
                default keeps the original 2-tuple ``(fc1, fc2)``.

        Returns:
            Tuple of weight Parameters, in layer order.
        """
        weights = (self.fc1.weight, self.fc2.weight)
        if include_output:
            weights += (self.fc3.weight,)
        return weights

    def get_bias(self, *, include_output=False):
        """Return the bias tensors of the layers.

        Args:
            include_output: if True, also include ``fc3.bias``.  The
                default keeps the original 2-tuple ``(fc1, fc2)``.

        Returns:
            Tuple of bias Parameters, in layer order.
        """
        biases = (self.fc1.bias, self.fc2.bias)
        if include_output:
            biases += (self.fc3.bias,)
        return biases

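### Preventing Exploding and Vanishing Gradients
 - Use Batch Normalization: normalize each layer's outputs using mini-batch statistics during training (and running averages at inference), then rescale them with learned parameters
 - Use ELU instead of ReLU: ReLU's gradient is zero for negative inputs, so neurons can get stuck outputting 0 ("dying ReLU"); ELU keeps a non-zero gradient there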

In [6]:
# Sample Batch norm
class CustomNet(nn.Module):  # nn.Module is the base class for all neural network modules
    """Fully connected network 10 -> 5 -> 3 -> 1 with ELU and batch norm.

    Each hidden layer is applied as Linear -> ELU -> BatchNorm1d; the single
    output unit is passed through a sigmoid.

    BatchNorm1d normalizes with mini-batch statistics in training mode (so a
    training batch must contain more than one sample) and with its running
    averages after ``.eval()``.
    """

    def __init__(self):
        # Zero-argument super() binds to this class object instead of looking
        # up the global name ``CustomNet`` at call time, so it keeps working
        # if the notebook cell defining ``CustomNet`` is re-run or the name is
        # rebound.
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.bn1 = nn.BatchNorm1d(5)
        self.fc2 = nn.Linear(5, 3)
        self.bn2 = nn.BatchNorm1d(3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        """Compute the network output (called by ``self(x)``).

        Args:
            x: tensor of shape (batch, 10); BatchNorm1d(5) after fc1 expects
                a (batch, features) input here.

        Returns:
            Tensor of shape (batch, 1), squashed by a sigmoid.
        """
        # ELU keeps a non-zero gradient for negative pre-activations.
        # Normalization is applied after the activation here; placing it
        # before the activation is an equally common alternative.
        x = self.bn1(F.elu(self.fc1(x)))
        x = self.bn2(F.elu(self.fc2(x)))
        # torch.sigmoid is the non-deprecated equivalent of F.sigmoid.
        return torch.sigmoid(self.fc3(x))
    
    
    