# nn

> The `nn` module is a core part of the minima toolkit: it provides the building blocks for neural networks.  
> It contains a set of classes (layers, activations, losses, and normalization) that cover common machine-learning needs.  
> The module is designed to be easy to use and understand, so you can build complex models simply by composing  
> ready-made parts, while remaining flexible enough to let you write your own custom parts when you need to.  

In [None]:
#|default_exp nn

In [None]:
#| export
from typing import List, Callable, Any, Tuple
from minima.autograd import Tensor
from minima import operators
import minima.init as init
import numpy as np
import minima as mi
import torch

In [None]:
#| export
class Parameter(Tensor):
    """
    A kind of Tensor that is to be considered a module parameter.

    Parameters are `Tensor` subclasses with one special property when used with
    `Module`s: when they are assigned as Module attributes (directly, or nested inside
    a list, tuple or dict attribute, or inside a child Module) they are picked up by
    `Module.parameters()`.

    NOTE(review): the class body adds nothing to `Tensor`. Whether a Parameter requires
    gradients by default, or can be volatile, therefore depends entirely on `Tensor`'s
    constructor. Confirm that before relying on either property.
    """

In [None]:
#| export
def _unpack_params(value: object) -> List[Tensor]:
    """
    Recursively collect every `Parameter` reachable from `value`.

    How each input type is handled:
    - A `Parameter` is returned as a one-element list.
    - A `Module` delegates to its own `parameters()` method.
    - A `dict` has its values searched.
    - A `list` or `tuple` has its items searched.
    - Any other object contributes nothing.

    Args:
        value (object): The object to search.

    Returns:
        List[Tensor]: The `Parameter` instances found, in traversal order. The list is
                      empty when nothing is found.

    Example:
        module = nn.Module(...)
        params = _unpack_params(module)
        print(params)  # Prints list of `Parameter` instances contained in `module`.
    """
    if isinstance(value, Parameter):
        return [value]
    if isinstance(value, Module):
        return list(value.parameters())

    if isinstance(value, dict):
        members = value.values()
    elif isinstance(value, (list, tuple)):
        members = value
    else:
        return []

    collected = []
    for member in members:
        collected.extend(_unpack_params(member))
    return collected

In [None]:
#| export
def _child_modules(value: object) -> List["Module"]:
    """
    Recursively collect every `Module` reachable from `value`.

    How each input type is handled:
    - A `Module` contributes itself, followed by every module found in its attribute
      dictionary (depth-first).
    - A `dict` has its values searched.
    - A `list` or `tuple` has its items searched.
    - Any other object contributes nothing.

    Args:
        value (object): The object to search.

    Returns:
        List[Module]: The `Module` instances found, in depth-first order. The list is
                      empty when nothing is found.

    Example:
        class MyModule(Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(20, 20)
                self.layer2 = nn.Linear(20, 20)

        my_module = MyModule()
        children = _child_modules(my_module)
        print(children)  # Prints list of `Module` instances contained in `my_module`.
    """
    if isinstance(value, Module):
        return [value, *_child_modules(value.__dict__)]

    if isinstance(value, dict):
        nested = list(value.values())
    elif isinstance(value, (list, tuple)):
        nested = value
    else:
        return []

    found = []
    for item in nested:
        found += _child_modules(item)
    return found

In [None]:
#|export
class Module:
    
    """
    Base class for all neural network modules in Minima.

    Your models should also subclass this class. Subclasses should define a `forward` method.
    Calling a module instance runs `forward` and then any registered forward hooks.

    Attributes:
    - `training` (bool): Module is initialized in training mode by default. Use `eval()` to switch it to evaluation mode.
    - `forward_hooks` (list): Callables invoked after every forward pass as `hook(module, inputs, output)`.
    - `backward_hooks` (list): Callables registered with `register_backward_hook`.
      NOTE: nothing in this class invokes them yet.

    Methods:
    - `parameters()`: Returns a list of the distinct `Parameter` instances in the module.
    - `_children()`: Returns a list of the distinct descendant `Module` instances.
    - `eval()`: Switches the module and all its descendants to evaluation mode.
    - `train()`: Switches the module and all its descendants back to training mode.
    - `forward()`: Computes the module's output; must be implemented by subclasses.
    - `__call__()`: Runs `forward` on the given arguments, then the forward hooks.
    """
    
    def __init__(self):
        self.training = True
        self.forward_hooks = []
        self.backward_hooks = []

    def parameters(self) -> List["Parameter"]:
        """
        Returns a list of all distinct `Parameter` instances in the module.

        Parameters are collected by recursively unpacking the module's attribute
        dictionary. This covers direct attributes, child modules, and list/tuple/dict
        containers.

        The same parameter object can be reachable along several paths. Examples are a
        container that keeps its children both as attributes and in a tuple, or a weight
        shared by two layers. Duplicates are therefore dropped by object identity, keeping
        first-seen order. Without this, an optimizer would update such a parameter more
        than once per step.
        """
        # Key on id(): tensors may overload `==` / hashing, so identity is the safe test.
        unique = {}
        for param in _unpack_params(self.__dict__):
            unique.setdefault(id(param), param)
        return list(unique.values())

    def _children(self) -> List["Module"]:
        """
        Returns a list of all distinct descendant `Module` instances in the module.

        Modules are collected by recursively unpacking the module's attribute dictionary.
        A module reachable along several paths is listed once, in first-seen order.
        """
        unique = {}
        for child in _child_modules(self.__dict__):
            unique.setdefault(id(child), child)
        return list(unique.values())

    def register_forward_hook(self, hook):
        """
        Registers `hook` to run after every forward pass.

        The hook is called as `hook(module, inputs, output)`, where `inputs` is the
        tuple of positional arguments passed to the module.
        """
        self.forward_hooks.append(hook)

    def register_backward_hook(self, hook):
        """
        Stores `hook` in `backward_hooks`.

        NOTE: no code in this class calls backward hooks yet.
        """
        self.backward_hooks.append(hook)

    def extra_repr(self) -> str:
        """
        Set the extra information about this module. By default, it returns an empty string.
        You can redefine this method if you need to add extra information to the string output of `__repr__`.
        """
        return ''

    def _get_name(self):
        """
        Returns the class name of the module.
        """
        return self.__class__.__name__

    def __repr__(self):
        """
        Returns a string containing a brief description of the module.

        The string includes `extra_repr()` and, on separate indented lines, the repr of
        every attribute that is directly a `Module`.
        """
        main_str = self._get_name() + '('
        extra_str = self.extra_repr()
        child_str = ''

        for key, module in self.__dict__.items():
            if isinstance(module, Module):
                mod_str = repr(module)
                mod_str = self._add_indent(mod_str, 2)
                child_str += '  (' + key + '): ' + mod_str + '\n'

        if extra_str:
            # If extra information exists, add it to the main string
            main_str += extra_str

        if child_str:
            # If the module has children, add their information to the main string
            main_str += '\n'
            main_str += child_str

        main_str += ')'

        return main_str


    def _add_indent(self, s_, num_spaces):
        """
        Indents every line of `s_` except the first with `num_spaces` spaces.
        """
        s = s_.split('\n')
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(num_spaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s


    def eval(self):
        """
        Switches the module and all its descendant modules to evaluation mode.

        Returns `self` so the call can be chained (e.g. `model = Model().eval()`).
        """
        self.training = False
        for m in self._children():
            m.training = False
        return self

    def train(self):
        """
        Switches the module and all its descendant modules to training mode.

        Returns `self` so the call can be chained.
        """
        self.training = True
        for m in self._children():
            m.training = True
        return self

    def forward(self, *args, **kwargs):
        """
        Computes the module's output. Subclasses must override this method.

        Raises:
            NotImplementedError: Always, when not overridden.
        """
        raise NotImplementedError(
            f"{self._get_name()} must implement a `forward` method"
        )

    def __call__(self, *args, **kwargs):
        """
        Runs the module on the given arguments.

        The steps are:
        1. Record the positional arguments in `self.input`.
        2. Compute `self.forward(*args, **kwargs)` and record the result in `self.output`.
        3. Call every forward hook as `hook(self, self.input, self.output)`.

        Subclasses should override `forward`, not this method.

        Returns:
            The value returned by `forward`.
        """
        self.input = args
        self.output = self.forward(*args, **kwargs)
    
        # call forward hooks
        for hook in self.forward_hooks:
            hook(self, self.input, self.output)
    
        return self.output

    # NOTE: backward hooks are registered but not yet dispatched. A `backward` method
    # that runs `self.backward_hooks` is still to be implemented.


In [None]:
#| export
class Sequential(Module):
    """
    A sequential container in Minima.

    Modules will be added to it in the order they are passed in the constructor.
    A `Sequential` module contains a sequence of child modules stored in the order they were added. 
    Each module is applied in order to the input to produce the output.

    The `Sequential` class makes it easy to build networks where the output of one layer is the input to the next.

    Attributes:
    - `modules` (tuple of `Module`): The sequence of child modules to apply.
    - `module_{i}` (`Module`): The i-th child, also exposed as an attribute.

    Methods:
    - `forward(x: Tensor) -> Tensor`: Passes the input through all the child modules in sequential order.
    - `__iter__()`: Iterates over the child modules in order. Each call returns an independent iterator.
    """
    def __init__(
        self,
        *modules # The sequence of child modules to apply. Each argument should be an instance of `Module`.
    ):
        """
        Initializes a new `Sequential` instance.
        
        Args:
            *modules: The sequence of child modules to apply. Each argument should be an instance of `Module`.
        """
        # Expose each child as a direct attribute so `Module.__repr__` (which only
        # lists attributes that are themselves Modules) shows it.
        for i, module in enumerate(modules):
            setattr(self, f'module_{i}', module)
        super().__init__()
        self.modules = modules
        
    def forward(self, x: Tensor) -> Tensor:
        """
        Defines the forward pass for the sequential module.
        
        Passes the input through all the child modules in the order they were added.

        Args:
            x (Tensor): The input tensor.
        
        Returns:
            Tensor: The output tensor.
        """
        for module in self.modules:
            x = module(x)
        return x

    def __iter__(self):
        """
        Returns a fresh iterator over the child modules, in order.

        Every call gets its own independent iterator. Nested or parallel loops over the
        same `Sequential` (e.g. `zip(seq, seq)`) therefore do not interfere with each other.
        """
        return iter(self.modules)

    def __next__(self):
        """
        Legacy stepping API, kept for backward compatibility.

        Advances an internal cursor over `self.modules` and raises `StopIteration` when
        the cursor reaches the end. `for` loops and `iter()` go through `__iter__` and
        never touch this cursor.
        """
        idx = getattr(self, '_iter_idx', 0)
        if idx < len(self.modules):
            self._iter_idx = idx + 1
            return self.modules[idx]
        raise StopIteration()


In [None]:
#| export
class Linear(Module):
    """
    A class representing a fully connected (linear) layer in a neural network.
    This class inherits from the `Module` class.

    Attributes:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        device (str): The device to store the Parameters on (defaults to None, which means CPU).
        dtype (str): The data type of the Parameters (defaults to 'float32').
        weight (Parameter): The weight parameters of the layer, shape (in_features, out_features).
        bias (Parameter): The bias parameters of the layer, shape (1, out_features), or None if bias=False.

    Methods:
        forward(X: Tensor) -> Tensor: Compute the forward pass of the layer.
    """
    
    def __init__(
        self,
        in_features, # The number of input features.
        out_features,# The number of output features.
        bias=True, # Whether or not to include a bias term. Default is True.
        device=None, # The device to store the Parameters on. Default is None, which means CPU.
        dtype="float32" # The data type of the Parameters. Default is 'float32'.
    ):
        """
        Initialize the layer with given input/output feature sizes and, optionally, bias, device, and dtype.

        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            bias (bool, optional): Whether or not to include a bias term. Default is True.
            device (str, optional): The device to store the Parameters on. Default is None, which means CPU.
            dtype (str, optional): The data type of the Parameters. Default is 'float32'.
        """
        
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype

        self.weight = Parameter(init.kaiming_uniform(fan_in=in_features, fan_out=out_features, device=device, dtype=dtype))
        # Reshape *before* wrapping in Parameter. `.reshape` is a tensor op, so calling it
        # on a Parameter stores the op's result rather than the Parameter itself. That
        # bias would then be missing from `parameters()` and never updated by an optimizer.
        self.bias = (
            Parameter(
                init.kaiming_uniform(fan_in=out_features, fan_out=1, device=device, dtype=dtype)
                .reshape((1, out_features))
            )
            if bias else None
        )

    def extra_repr(self) -> str:
        """Describes the layer's configuration for `Module.__repr__`."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
            
    def forward(self, X: Tensor) -> Tensor:
        """
        Compute the forward pass of the layer.

        This function applies the linear transformation to the input tensor X, 
        i.e., performs the matrix multiplication of X and the weight tensor, 
        and then adds the bias tensor (if bias is not None).

        Args:
            X (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        
        out = X @ self.weight
        # Compare with None explicitly: the truthiness of a multi-element tensor is not
        # a reliable "bias exists" test.
        if self.bias is not None:
            out = out + self.bias.broadcast_to(out.shape)
        return out

In [None]:
from minima.optim import *

In [None]:
class MyModule(Module):
    """Two stacked `Linear` layers, used to exercise `Module.__repr__` and `Module.parameters()`."""
    def __init__(self):
        super().__init__()
        # Use this notebook's `Linear`, which subclasses the `Module` defined above.
        # `mi.nn.Linear` derives from the separately imported `minima.nn.Module`, so the
        # local `isinstance(..., Module)` checks would not recognise it. Its layers would
        # then be invisible to `parameters()` and `__repr__`.
        self.layer1 = Linear(10, 20)
        self.layer2 = Linear(20, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

my_module = MyModule()
my_module

MyModule()

In [None]:
my_module.parameters()

[]

In [None]:
# Instantiate the packaged `minima.nn.Linear` and check it derives from the packaged `minima.nn.Module`.
tt = mi.nn.Linear(in_features=10, out_features=5, bias=True)
isinstance(tt, mi.nn.Module)

True

In [None]:
#| export
class Flatten(Module):
    """
    A `Flatten` module in Minima.

    This module flattens an input tensor into a 2D matrix, typically for transitioning from convolutional layers to linear layers within a neural network model.

    Methods:
    - `forward(X: Tensor) -> Tensor`: Flattens the input tensor.
    """
    
    def forward(self, X: Tensor) -> Tensor:
        """
        Defines the forward pass for the Flatten module.
        
        This method flattens an input tensor along all dimensions except the batch dimension.

        Args:
            X (Tensor): The input tensor, with the batch along its first axis. It must have
                at least one dimension.

        Returns:
            Tensor: A 2D tensor of shape (batch_size, product of the remaining dimensions).
                A 1D input of shape (N,) becomes (N, 1).
        """
        batch_size = X.shape[0]
        # Spell out the flattened width instead of passing -1. NumPy cannot infer a -1
        # dimension for an array with zero elements (e.g. an empty batch), and an explicit
        # shape does not rely on the reshape op accepting -1 at all.
        num_features = int(np.prod(X.shape[1:]))
        return X.reshape((batch_size, num_features))


In [None]:
#| export
class ReLU(Module):
    """
    Rectified linear unit activation module in Minima.

    Wraps `operators.relu` so it can be used as a layer, e.g. inside `Sequential`.
    """
    def forward(self, x: Tensor) -> Tensor:
        """Returns `operators.relu(x)`."""
        activated = operators.relu(x)
        return activated

In [None]:
#| export
class Sigmoid(Module):
    """
    Logistic sigmoid activation module in Minima.

    Computes 1 / (1 + exp(-x)) using `operators.exp` and tensor arithmetic.
    """
    def forward(self, x: Tensor) -> Tensor:
        """Returns 1 / (1 + exp(-x))."""
        denominator = 1 + operators.exp(-x)
        return 1 / denominator

The `CrossEntropyLoss` implementation below is a numerically stable version of the cross-entropy loss, which for a single sample is defined as:

$$
H(p, q) = - \sum_{i} p_i \log(q_i)
$$

where:
- $p$ is the true distribution (in classification, typically a one-hot encoded vector),
- $q$ is the predicted distribution (output of the softmax function on the logits from the model).

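However, implementing this formula directly (computing the softmax probabilities first and then taking their logarithm) is numerically unstable. The exponentials of large logits can overflow, and $\log(q_i)$ becomes $-\infty$ when a probability underflows to zero. The implementation avoids both problems by working with the logits directly and applying the log-sum-exp trick. To derive it, write the loss in classification notation:

$$ CE = -\sum_{i=1}^{C} y_i \log(\hat{y_i}) $$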

where $y$ is the one-hot encoded ground-truth label, $\hat{y}$ is the vector of predicted probabilities, and $C$ is the number of classes. Because $y$ is one-hot, only the term for the true class contributes to the sum, so the loss simplifies to:

$$
CE = -\log(\hat{y_c})
$$

where $c$ is the correct class.

The predicted probabilities $\hat{y}$ are obtained by applying the softmax function to the logits $z$:

$$
\hat{y_i} = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}
$$

Substituting $\hat{y_c}$ in the Cross-Entropy Loss, we have:

$$
CE = -\log\left(\frac{e^{z_c}}{\sum_{j=1}^{C} e^{z_j}}\right)
$$

Applying the logarithm property $\log(a/b) = \log(a) - \log(b)$, we get:

$$
CE = -z_c + \log\left(\sum_{j=1}^{C} e^{z_j}\right)
$$

First, `log_sum_exp_logits = operators.logsumexp(logits, axes=(1, )).sum()` computes the term $\log\left(\sum_{j=1}^{C} e^{z_j}\right)$. The function `logsumexp` computes the logarithm of the sum of exponentials in a numerically stable way, and then these values are summed over all samples.

Second, `true_class_logits_sum = (logits * init.one_hot(logits.shape[1], y)).sum()` computes $\sum z_c$, the sum of the true-class logits over all samples. The minus sign is applied by the subtraction in the final step. The function `init.one_hot(logits.shape[1], y)` creates a one-hot encoding of the true labels. Multiplying it with the logits zeroes every entry except each sample's correct-class logit, and the result is then summed over all samples.

Finally, `(log_sum_exp_logits - true_class_logits_sum) / logits.shape[0]` computes the average loss per sample.

In [None]:
# Toy batch: 6 samples with 2 class logits each, plus their integer class targets.
logits = mi.Tensor([[ 0.6734,  0.2576],
        [ 0.4689,  0.4607],
        [-2.2457, -0.3727],
        [ 4.4164, -1.2760],
        [ 0.9233,  0.5347],
        [ 1.0698,  1.6187]])
targ = mi.Tensor([0,1,0,1,1,0])

In [None]:
# Step 1: log(sum_j exp(z_j)) over the class axis (axis 1), giving one value per sample.
log_sum_exp_logits = operators.logsumexp(logits, axes=(1, ))
log_sum_exp_logits

minima.Tensor(
[ 1.180104  1.157956 -0.229759  4.419766  1.440906  2.074595])

In [None]:
# Collapse the per-sample log-sum-exp values into a single batch total.
log_sum_exp_logits = log_sum_exp_logits.sum()

In [None]:
# One-hot encode the targets, with one column per class (logits.shape[1]).
one_hot_y = init.one_hot(logits.shape[1], targ)
one_hot_y

minima.Tensor(
[[1. 0.]
 [0. 1.]
 [1. 0.]
 [0. 1.]
 [0. 1.]
 [1. 0.]])

In [None]:
# Zero out every logit except the one for each sample's true class.
true_class_logits = (logits * one_hot_y)
true_class_logits

minima.Tensor(
[[ 0.6734  0.    ]
 [ 0.      0.4607]
 [-2.2457 -0.    ]
 [ 0.     -1.276 ]
 [ 0.      0.5347]
 [ 1.0698  0.    ]])

In [None]:
# Step 2: sum of the true-class logits over the whole batch.
true_class_logits_sum = true_class_logits.sum()
true_class_logits_sum

minima.Tensor(
-0.7830999999999999)

In [None]:
# Mean cross-entropy: (total log-sum-exp - total true-class logit) / batch size.
loss = (log_sum_exp_logits - true_class_logits_sum) / targ.shape[0]
loss

minima.Tensor(
1.8044446680186768)

In [None]:
# Reference value: PyTorch's CrossEntropyLoss on the same logits and targets.
loss = torch.nn.CrossEntropyLoss()(torch.tensor(logits.numpy()), torch.tensor(targ.numpy()))
loss

tensor(1.8044, dtype=torch.float64)

In [None]:
#| export
class CrossEntropyLoss(Module):
    """
    Cross-entropy loss module in Minima.

    Computes the mean cross-entropy between raw logits and integer class targets. It
    works on logits directly through log-sum-exp and never materialises softmax
    probabilities.

    Methods:
    - `forward(input: Tensor, target: Tensor) -> Tensor`: Calculates the cross-entropy loss between the input (logits) and the target (class indices).

    Example:
    ```python
    model = Sequential(
        Linear(10, 20),
        ReLU(),
        Linear(20, 10),
    )
    loss_fn = CrossEntropyLoss()
    output = model(input_tensor)  # compute model output
    loss = loss_fn(output, target_tensor)  # compute loss
    ```
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Computes the mean cross-entropy loss over the batch.

        Args:
            input (Tensor): Logits of shape (batch_size, num_classes).
            target (Tensor): Correct class indices, typically of shape (batch_size,).

        Returns:
            Tensor: A scalar tensor holding the average cross-entropy loss.
        """
        batch_size = input.shape[0]
        num_classes = input.shape[1]

        # Sum over the batch of log(sum_j exp(z_j)).
        normaliser = operators.logsumexp(input, axes=(1, )).sum()
        # Sum over the batch of the true-class logit z_c.
        picked = (input * init.one_hot(num_classes, target)).sum()

        return (normaliser - picked) / batch_size

In [None]:
#| export
class Softmax(Module):
    """
    Softmax module in Minima.

    Converts logits into probabilities along one axis. Each slice along that axis is
    exponentiated and normalised so that it sums to one.

    Methods:
    - `forward(input: Tensor, dim=1) -> Tensor`: Applies softmax to `input` along axis `dim`.

    Example:
    ```python
    sm = Softmax()
    probs = sm(logits)          # softmax over axis 1, e.g. logits of shape (batch_size, num_classes)
    probs = sm(logits, dim=-1)  # softmax over the last axis
    ```
    """

    def forward(self, input: Tensor, dim=1) -> Tensor:
        """
        Computes softmax(input) along axis `dim`.

        Args:
            input (Tensor): The logits, e.g. of shape (batch_size, num_classes).
            dim (int): The axis to normalise over. Negative values count from the end.
                Default is 1.

        Returns:
            Tensor: A tensor with the same shape as `input`, whose slices along `dim` sum to one.

        Raises:
            ValueError: If `dim` is out of range for `input`.
        """
        ndim = len(input.shape)
        if not -ndim <= dim < ndim:
            raise ValueError(f"dim {dim} is out of range for a tensor with {ndim} dimensions")
        axis = dim % ndim

        # softmax(z) = exp(z - logsumexp(z)). This equals exp(z) / sum(exp(z)), but the
        # exponent is always <= 0, so it cannot overflow for large logits.
        lse = operators.logsumexp(input, axes=(axis,))
        # Restore the reduced axis with size 1 so the result broadcasts back over `input`.
        keep_shape = tuple(1 if i == axis else size for i, size in enumerate(input.shape))
        lse = operators.broadcast_to(operators.reshape(lse, shape=keep_shape), shape=input.shape)
        return operators.exp(input - lse)

In [None]:
# A Softmax module and a small (4, 2) batch of logits to apply it to.
sm = Softmax()
t = mi.Tensor([[ 1.25721,   1.066163],
               [ 1.123695, -0.410479],
               [ 0.047494, -1.200239],
               [ 0.440039,  0.884834],])


In [None]:
# Softmax over axis 1 (the default): each row of the result sums to 1.
sm(t)

minima.Tensor(
[[0.547617 0.452383]
 [0.822616 0.177384]
 [0.776907 0.223093]
 [0.390599 0.609401]])

## Layer normalization

Layer normalization $\text{LN}$ normalizes the input $X$ as follows:

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
where $B$ is the batch size and $C$ is the number of features.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
+ \beta$$

When input $X \in \mathbb{R}^{L \times B \times C}$ is a batch of a sequence of embeddings,
where $B$ is the batch size, $C$ is the number of channels, $L$ is the length of the sequence.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}}
+ \beta$$

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
This is not a widely used scenario.
$\gamma \in \mathbb{R}^{C \times H \times W}$ and $\beta \in \mathbb{R}^{C \times H \times W}$.
$$\text{LN}(X) = \gamma
\frac{X - \underset{C, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{C, H, W}{Var}[X] + \epsilon}}
+ \beta$$


In [None]:
# Random (5, 10) input: a batch of 5 samples with 10 features each.
X = mi.Tensor(init.rand(5, 10))
X

minima.Tensor(
[[0.401392 0.778235 0.702728 0.115071 0.143369 0.012648 0.923175 0.204186 0.212064 0.821985]
 [0.80416  0.933441 0.316713 0.517084 0.458262 0.140639 0.414799 0.065997 0.959522 0.880972]
 [0.002626 0.401024 0.857979 0.276518 0.41354  0.931634 0.253994 0.325883 0.719129 0.820806]
 [0.889366 0.948938 0.519644 0.53989  0.975683 0.814937 0.800552 0.066098 0.409378 0.033692]
 [0.480783 0.822874 0.294807 0.447157 0.471698 0.307358 0.515836 0.959066 0.761911 0.579533]])

In [None]:
# Batch size and number of features.
bs, fs = X.shape
bs, fs

(5, 10)

In [None]:
# Per-sample mean over the feature axis (axis 1).
mean = X.sum(axes=(1,)) / fs
mean

minima.Tensor(
[0.431485 0.549159 0.500313 0.599818 0.564102])

In [None]:
# Add a trailing axis of size 1 so the means can be broadcast back over the features.
mean = mean.reshape((bs, 1))
mean

minima.Tensor(
[[0.431485]
 [0.549159]
 [0.500313]
 [0.599818]
 [0.564102]])

In [None]:
# Repeat each sample's mean across all of its features so the shape matches X.
mean = mean.broadcast_to(X.shape)
mean

minima.Tensor(
[[0.431485 0.431485 0.431485 0.431485 0.431485 0.431485 0.431485 0.431485 0.431485 0.431485]
 [0.549159 0.549159 0.549159 0.549159 0.549159 0.549159 0.549159 0.549159 0.549159 0.549159]
 [0.500313 0.500313 0.500313 0.500313 0.500313 0.500313 0.500313 0.500313 0.500313 0.500313]
 [0.599818 0.599818 0.599818 0.599818 0.599818 0.599818 0.599818 0.599818 0.599818 0.599818]
 [0.564102 0.564102 0.564102 0.564102 0.564102 0.564102 0.564102 0.564102 0.564102 0.564102]])

In [None]:
# Centre each sample by subtracting its own mean.
x_centred = X - mean
x_centred

minima.Tensor(
[[-0.030094  0.34675   0.271242 -0.316414 -0.288116 -0.418837  0.49169  -0.2273   -0.219421  0.3905  ]
 [ 0.255001  0.384282 -0.232446 -0.032074 -0.090896 -0.40852  -0.13436  -0.483162  0.410363  0.331813]
 [-0.497688 -0.09929   0.357666 -0.223795 -0.086773  0.431321 -0.24632  -0.17443   0.218816  0.320493]
 [ 0.289548  0.34912  -0.080174 -0.059928  0.375866  0.215119  0.200735 -0.53372  -0.19044  -0.566126]
 [-0.08332   0.258772 -0.269295 -0.116946 -0.092404 -0.256744 -0.048266  0.394964  0.197809  0.01543 ]])

In [None]:
#| export
class LayerNorm1d(Module):
    """
    1D Layer normalization module in Minima.

    Applies layer normalization over the last dimension of the input. The mean and
    standard deviation are computed over the last axis, independently for every
    position along the leading axes.

    Attributes:
    - `dim` (int): The dimension of the input feature space (size of the last axis).
    - `eps` (float): A small constant for numerical stability.
    - `weight` (Parameter): The learnable weights of the module of size 'dim', initialized with ones.
    - `bias` (Parameter): The learnable bias of the module of size 'dim', initialized with zeros.

    Methods:
    - `forward(x: Tensor) -> Tensor`: Applies layer normalization to the input tensor.

    """
    def __init__(
        self,
        dim: int, # The dimension of the input feature space.
        eps=1e-5, # A small constant for numerical stability. Default is 1e-5.
        device=None, # The desired device of returned tensor. If None, uses the current device for the default tensor type. Default is None.
        dtype="float32" # The desired data type of returned tensor. If None, uses the default data type. Default is "float32".
    ):
        """
        Initializes a new `LayerNorm1d` instance.
        
        Args:
            dim: The dimension of the input feature space.
            eps: A small constant for numerical stability. Default is 1e-5.
            device: The desired device of returned tensor. If None, uses the current device for the default tensor type. Default is None.
            dtype: The desired data type of returned tensor. If None, uses the default data type. Default is "float32".
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        
        self.weight = Parameter(init.ones(dim, device=device, dtype=dtype, requires_grad=True))
        self.bias = Parameter(init.zeros(dim, device=device, dtype=dtype, requires_grad=True))

    def forward(self, x: Tensor) -> Tensor:
        """
        Applies layer normalization over the last axis of the input.

        Args:
            x (Tensor): The input tensor of shape (..., dim). Examples are
                (batch_size, num_features) or (seq_len, batch_size, num_features).

        Returns:
            Tensor: The normalized, scaled and shifted tensor, with the same shape as `x`.

        Raises:
            ValueError: If `x` has no axes, or if its last axis does not have size `self.dim`.
        """
        if len(x.shape) == 0 or x.shape[-1] != self.dim:
            raise ValueError(
                f"LayerNorm1d expected input whose last dimension is {self.dim}, got shape {x.shape}"
            )
        *lead, fs = x.shape
        # The statistics keep a trailing axis of size 1 so they broadcast back over the features.
        stat_shape = tuple(lead) + (1,)
        axes = (-1,)
        mean = x.sum(axes=axes).reshape(stat_shape) / fs
        x_centered = x - mean.broadcast_to(x.shape)
        # Biased variance (divide by fs, not fs - 1). eps keeps the square root away from zero.
        var = (x_centered ** 2).sum(axes=axes).reshape(stat_shape) / fs
        std = (var + self.eps) ** 0.5
        x_normed = x_centered / std.broadcast_to(x.shape)
        return self.weight.broadcast_to(x.shape) * x_normed + self.bias.broadcast_to(x.shape)


## Batch Norm


This is an implementation of Batch Normalization from the paper
 [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://papers.labml.ai/paper/1502.03167).

### Internal Covariate Shift

The paper defines *Internal Covariate Shift* as the change in the
distribution of network activations due to the change in
network parameters during training.
For example, let's say there are two layers $l_1$ and $l_2$.
During the beginning of the training $l_1$ outputs (inputs to $l_2$)
could be in distribution $\mathcal{N}(0.5, 1)$.
Then, after some training steps, it could move to $\mathcal{N}(0.6, 1.5)$.
This is *internal covariate shift*.

Internal covariate shift will adversely affect training speed because the later layers
($l_2$ in the above example) have to adapt to this shifted distribution.

By stabilizing the distribution, batch normalization minimizes the internal covariate shift.

## Normalization

It is known that whitening improves training speed and convergence.
*Whitening* is linearly transforming inputs to have zero mean, unit variance,
and be uncorrelated.

### Normalizing outside gradient computation doesn't work

Normalizing outside the gradient computation using pre-computed (detached)
means and variances doesn't work. For instance (ignoring variance), let
$$\hat{x} = x - \mathbb{E}[x]$$
where $x = u + b$ and $b$ is a trained bias
and $\mathbb{E}[x]$ is an outside gradient computation (pre-computed constant).

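Note that updating $b$ has no effect on $\hat{x}$: once $\mathbb{E}[x]$ is recomputed,
it absorbs the change in $b$, so the normalized output stays the same.
Therefore,
$b$ will increase or decrease based on
$\frac{\partial{\mathcal{L}}}{\partial x}$,
and keep on growing indefinitely in each training update.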
The paper notes that similar explosions happen with variances.

### Batch Normalization

Whitening is computationally expensive because you need to de-correlate and
the gradients must flow through the full whitening calculation.

The paper introduces a simplified version which they call *Batch Normalization*.
First simplification is that it normalizes each feature independently to have
zero mean and unit variance:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$
where $x = (x^{(1)} ... x^{(d)})$ is the $d$-dimensional input.

The second simplification is to use estimates of mean $\mathbb{E}[x^{(k)}]$
and variance $Var[x^{(k)}]$ from the mini-batch
for normalization; instead of calculating the mean and variance across the whole dataset.

Normalizing each feature to zero mean and unit variance could affect what the layer
can represent.
For example, the paper points out that if the inputs to a sigmoid are normalized,
most of them will fall within the $[-1, 1]$ range, where the sigmoid is nearly linear.
To overcome this each feature is scaled and shifted by two trained parameters
$\gamma^{(k)}$ and $\beta^{(k)}$.
$$y^{(k)} =\gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
where $y^{(k)}$ is the output of the batch normalization layer.

Note that when applying batch normalization after a linear transform
like $Wu + b$ the bias parameter $b$ gets cancelled due to normalization.
So you can and should omit bias parameter in linear transforms right before the
batch normalization.

Batch normalization also makes the back propagation invariant to the scale of the weights
and empirically it improves generalization, so it has regularization effects too.

## Inference

We need to know $\mathbb{E}[x^{(k)}]$ and $Var[x^{(k)}]$ in order to
perform the normalization.
So during inference, you either need to go through the whole (or part of) dataset
and find the mean and variance, or you can use an estimate calculated during training.
The usual practice is to calculate an exponential moving average of
mean and variance during the training phase and use that for inference.


Batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:

When input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations,
where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{Var}[X] + \epsilon}}
+ \beta$$

When input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings,
where $B$ is the batch size and $C$ is the number of features.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{Var}[X] + \epsilon}}
+ \beta$$

When input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of a sequence embeddings,
where $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence.
$\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma
\frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{Var}[X] + \epsilon}}
+ \beta$$

In [None]:
#| export
class BatchNorm1d(Module):
    """
    1D Batch normalization module in Minima.

    This module applies Batch Normalization over a 2D input of shape `(batch, dim)` as described
    in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal
    Covariate Shift" by Ioffe and Szegedy.

    In training mode the input is normalized with the batch mean and the biased (divide-by-N)
    batch variance, and the running statistics are updated as
    `running = momentum * batch_stat + (1 - momentum) * running`.
    In evaluation mode the running statistics are used instead.

    Attributes:
    - `dim` (int): The dimension of the input feature space.
    - `eps` (float): A small constant added to the variance for numerical stability.
    - `momentum` (float): Weight of the current batch when updating the running statistics.
    - `weight` (Parameter): The learnable scale factor of size `dim`, initialized with ones.
    - `bias` (Parameter): The learnable offset of size `dim`, initialized with zeros.
    - `running_mean` (Tensor): Running estimate of the per-feature mean. Initialized with zeros.
    - `running_std` (Tensor): Running estimate of the per-feature *variance*. Despite its name
      it is not a standard deviation; the name is kept for backward compatibility.
      Initialized with ones.

    Methods:
    - `update_stats(x: Tensor) -> Tuple[Tensor, Tensor]`: Computes the batch mean and variance of
      the input and updates the running statistics.
    - `forward(x: Tensor) -> Tensor`: Applies batch normalization to the input tensor.

    Example:
    ```python
    batch_norm = BatchNorm1d(dim=512)
    output = batch_norm(input_tensor)  # Apply batch normalization
    ```

    """
    def __init__(
        self,
        dim: int,
        eps=1e-5,
        momentum=0.1,
        device=None,
        dtype="float32"
    ):
        """
        Initializes a new `BatchNorm1d` instance.

        Args:
            dim: The dimension of the input feature space.
            eps: A small constant added to the variance for numerical stability. Default is 1e-5.
            momentum: Weight of the current batch in the running_mean / running_std (variance)
                update. Default is 0.1.
            device: The desired device of the created tensors. Default is None.
            dtype: The desired data type of the created tensors. Default is "float32".
        """

        super().__init__()
        self.dim = dim
        self.eps = eps
        self.momentum = momentum

        self.weight = Parameter(init.ones(dim, device=device, dtype=dtype, requires_grad=True))
        self.bias = Parameter(init.zeros(dim, device=device, dtype=dtype, requires_grad=True))

        # Plain Tensors rather than Parameters, so they are not collected as trainable parameters.
        self.running_mean = Tensor(init.zeros(dim, device=device, dtype=dtype))
        # Holds the running *variance* (see class docstring).
        self.running_std = Tensor(init.ones(dim, device=device, dtype=dtype))

    def update_stats(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes batch statistics of `x` and updates the running statistics.

        Parameters:
        ----------
        x : Tensor
            Input tensor of shape `(batch, dim)`.

        Returns:
        ----------
        Tuple[Tensor, Tensor]
            The per-feature batch mean and the biased (divide-by-N) per-feature batch variance.

        Raises:
        ----------
        ValueError
            If `x` is not two-dimensional.
        """
        if len(x.shape) != 2:
            raise ValueError(
                f"BatchNorm1d expects a 2D (batch, features) input, got shape {tuple(x.shape)}"
            )

        batch_size = x.shape[0]
        axes = (0,)
        mean = x.sum(axes=axes) / batch_size
        x_centered = x - mean.broadcast_to(x.shape)
        # Biased variance (divide by N), as used for normalization in the original paper.
        var = (x_centered ** 2).sum(axes=axes) / batch_size

        # Detach the whole update with `.data`, not only the batch statistic: the previous buffer
        # is created with `Tensor(...)` and, if it tracks gradients, every update would otherwise
        # chain onto the previous one and keep autograd history alive across training steps.
        self.running_mean = (
            self.momentum * mean.data + (1 - self.momentum) * self.running_mean
        ).data
        self.running_std = (
            self.momentum * var.data + (1 - self.momentum) * self.running_std
        ).data
        return mean, var

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward propagation of the batch normalization layer.

        In training mode, normalizes with the batch statistics (and updates the running
        estimates); in evaluation mode, normalizes with `running_mean` and `running_std`
        (which holds the variance).

        Parameters:
        ----------
        x : Tensor
            Input tensor of shape `(batch, dim)`.

        Returns:
        ----------
        Tensor
            Output tensor after applying batch normalization.
        """

        if self.training:
            mean, var = self.update_stats(x)
        else:
            mean, var = self.running_mean, self.running_std
        # eps is added to the variance before the square root.
        x_normed = (x - mean.broadcast_to(x.shape)) / (var.broadcast_to(x.shape) + self.eps) ** .5
        return self.weight.broadcast_to(x.shape) * x_normed + self.bias.broadcast_to(x.shape)

In [None]:
#| export
class Dropout(Module):
    """
    Dropout Layer for a Neural Network.

    This class represents a dropout layer in a neural network, which is a simple
    and effective regularization technique.
    During training, each element of the input tensor is zeroed independently
    with probability `p`, and the surviving elements are scaled by `1 / (1 - p)`
    so the expected value of the output matches the input ("inverted dropout").
    In evaluation mode the input is returned unchanged.

    Parameters:
    ----------
    p: float, optional, default = 0.5
        Probability of an element to be zeroed. Default: 0.5.
    """

    def __init__(self, p = 0.5):
        """
        Initializes the Dropout layer with the specified probability.

        Parameters:
        ----------
        p : float
            Probability of an element to be zeroed.
        """

        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward propagation of the dropout layer.

        If the layer is in training mode, it applies dropout to the input tensor.
        If the layer is in evaluation mode, it returns the input tensor as is.

        Parameters:
        ----------
        x : Tensor
            Input tensor.

        Returns:
        ----------
        Tensor
            Output tensor after applying dropout.
        """

        # No mask is drawn in evaluation mode: it would be discarded anyway.
        if not self.training:
            return x
        if self.p >= 1:
            # Every element is dropped; return zeros instead of dividing by zero below.
            return x * 0.0
        keep_prob = 1 - self.p
        # NOTE(review): assumes `init.randb(*shape, p=q)` yields ones with probability q
        # (needle-style `rand(*shape) <= q`), so the *keep* probability is passed here.
        # Confirm against minima.init.randb.
        keep_mask = init.randb(*x.shape, p=keep_prob)
        return (keep_mask * x) / keep_prob


In [None]:
#| export
class Residual(Module):
    """
    Skip-connection wrapper computing `x + fn(x)`.

    Wraps a sub-module `fn` and adds its output back onto its input through an
    identity shortcut. Because the shortcut is a plain addition, gradients reach
    earlier layers directly, in addition to flowing through `fn`. This is the
    building block of residual networks and helps mitigate vanishing/exploding
    gradients in deep models.

    Parameters:
    ----------
    fn: Module
        The sub-module applied on the residual branch. Its output must be
        addable to its input.
    """

    def __init__(self, fn: Module):
        """
        Stores the residual-branch module.

        Parameters:
        ----------
        fn : Module
            The sub-module applied on the residual branch.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the residual branch and adds the shortcut.

        Parameters:
        ----------
        x : Tensor
            Input tensor, fed both to `fn` and to the identity shortcut.

        Returns:
        ----------
        Tensor
            `x + fn(x)`.
        """
        branch_out = self.fn(x)
        return x + branch_out

In [None]:
#| export
class Identity(Module):
    """
    A no-op module whose forward pass returns its input unchanged.

    Handy as a placeholder wherever a `Module` is expected but no transformation
    should be applied.
    """

    def forward(self, x):
        """Return `x` untouched."""
        return x

## Export

In [None]:
import nbdev
nbdev.nbdev_export()