# Neural networks and modules assembly

The `EmbeddingNetwork` module is a torch `Module` that makes it easy to add torch modules before or after a neural network. It can be used to customize a neural network built from a classic architecture such as `FullyConnected`. It can also mimic the use of an `Operator` with torch functions, for instance to differentiate a network with respect to the input variables rather than the normalized variables.
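
The idea of differentiating with respect to the raw inputs can be sketched in plain torch. The snippet below is a minimal illustration under stated assumptions: `Standardize` is a hypothetical preprocessing module, and `nn.Sequential` stands in for `EmbeddingNetwork`. Once the normalization is part of the model, autograd returns derivatives with respect to the raw input variables, which by the chain rule equal the weights divided by the standard deviations.

```python
import torch
from torch import nn


class Standardize(nn.Module):
    """Hypothetical preprocessing module: maps raw inputs to normalized ones."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        # Buffers follow the module across devices but are not trained.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


torch.manual_seed(0)
mean = torch.tensor([10.0, -5.0])
std = torch.tensor([2.0, 4.0])

core = nn.Linear(2, 1)  # network assumed to work on normalized variables
model = nn.Sequential(Standardize(mean, std), core)  # stand-in for EmbeddingNetwork

# Raw (unnormalized) inputs, tracked by autograd.
x_raw = (torch.randn(8, 2) * std + mean).requires_grad_(True)
model(x_raw).sum().backward()

# Chain rule: d core(z) / dx = W / std, identical for every sample.
expected = (core.weight / std).expand_as(x_raw)
print(torch.allclose(x_raw.grad, expected))
```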

In [1]:
import os
import sys

# Make the parent folder importable (e.g. to use a local copy of nnbma)
sys.path.append(os.path.join(os.path.abspath(""), ".."))

import torch
import numpy as np
from torch import nn

from nnbma.layers import AdditionalModule, AdditionalModuleFromExisting
from nnbma.networks import FullyConnected, EmbeddingNetwork

## `AdditionalModule` module

An `AdditionalModule` is essentially a torch `Module`. Its advantage is that the input and output dimensions are declared upfront, so the compatibility of chained modules can be checked in advance.

Compared with a plain `Module`, they have two extra attributes, `input_features` and `output_features`. Since these modules support batches, these values refer to the size of the last dimension of the tensors.
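
To see why these attributes only describe the last dimension, consider the matrix product used in the next example. This plain torch sketch is independent of nnbma: `torch.matmul` contracts the last dimension of the input with a 3x2 matrix, and any leading (batch) dimensions pass through unchanged.

```python
import torch

torch.manual_seed(0)
# A 3x2 weight matrix maps feature vectors of size 3 to size 2.
W = torch.randn(3, 2)

# torch.matmul contracts only the last dimension of x with the first
# dimension of W, so leading (batch) dimensions are preserved.
out_shapes = {}
for shape in [(3,), (10, 3), (4, 5, 3)]:
    x = torch.randn(*shape)
    out_shapes[shape] = tuple(torch.matmul(x, W).shape)

print(out_shapes)
```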

Here's an example of a module that takes tensors of size 3 as input and returns tensors of size 2.

In [2]:
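class MatMul(AdditionalModule):
    def __init__(self):
        # 3 input features, 2 output features
        super().__init__(3, 2)
        # Fixed random projection matrix (a plain tensor, so it is not trained)
        self.W = torch.normal(0, 1, size=(3, 2))

    def forward(self, x):
        # (..., 3) @ (3, 2) -> (..., 2): only the last dimension changes
        return torch.matmul(x, self.W)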


matmul = MatMul()

You may want to create a module that takes a tensor of arbitrary size as input and also returns a tensor of arbitrary size.

__Note:__ In this case, we show an alternative to subclassing `AdditionalModule`: the `AdditionalModuleFromExisting` class. It is useful when the additional module directly wraps an existing torch function or `Module`, as there is no need to write a new class.
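
Conceptually, such a wrapper only needs to store the two feature counts and delegate `forward` to the wrapped callable. The class below is a hypothetical plain-torch sketch of that idea (the name `CallableModule` and its argument order are assumptions modeled on the call in the next cell), not nnbma's actual implementation.

```python
from typing import Callable, Optional

import torch
from torch import nn


class CallableModule(nn.Module):
    """Hypothetical sketch: wrap an existing torch function as a Module."""

    def __init__(
        self,
        input_features: Optional[int],
        output_features: Optional[int],
        fn: Callable[[torch.Tensor], torch.Tensor],
    ):
        super().__init__()
        # None means "any size" along the last dimension.
        self.input_features = input_features
        self.output_features = output_features
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate the computation to the wrapped callable.
        return self.fn(x)


exp_module = CallableModule(None, None, torch.exp)
x = torch.zeros(4, 7)
print(exp_module(x))  # exp(0) = 1 everywhere
```

No subclass per function is needed: the same wrapper works for any elementwise torch function or existing `Module`.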

In [3]:
exp = AdditionalModuleFromExisting(None, None, torch.exp)

Alternatively, you may want to create a module that takes as input a tensor of arbitrary size and returns a tensor of fixed size.

In [4]:
class Moments(AdditionalModule):
    def __init__(self):
        # Arbitrary number of input features, 2 output features
        super().__init__(None, 2)

    def forward(self, x):
        # Mean and biased variance along the last (feature) dimension
        m1 = torch.mean(x, dim=-1, keepdim=True)
        m2 = torch.mean((x - m1) ** 2, dim=-1, keepdim=True)
        return torch.cat((m1, m2), dim=-1)


moments = Moments()
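
Since `Moments` reduces over the last dimension, its output always has two features, whatever the input size. The check below uses a hypothetical standalone function `moments_ref` that repeats the same computation so it runs without nnbma. It confirms that the second feature is the population variance, i.e. `torch.var` with `unbiased=False`.

```python
import torch


def moments_ref(x: torch.Tensor) -> torch.Tensor:
    # Same computation as Moments.forward: mean and biased (population)
    # variance along the last dimension, concatenated into 2 features.
    m1 = torch.mean(x, dim=-1, keepdim=True)
    m2 = torch.mean((x - m1) ** 2, dim=-1, keepdim=True)
    return torch.cat((m1, m2), dim=-1)


torch.manual_seed(0)
results = []
for n in [3, 20, 100]:
    x = torch.randn(5, n)
    out = moments_ref(x)
    # unbiased=False divides by n, matching the mean of squared deviations.
    ref_var = torch.var(x, dim=-1, unbiased=False)
    results.append((n, tuple(out.shape), torch.allclose(out[:, 1], ref_var, atol=1e-6)))

print(results)
```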

## `EmbeddingNetwork` example

The `EmbeddingNetwork` module makes it possible to chain several `AdditionalModule` instances before and/or after a `NeuralNetwork` instance. The only constraint is that the numbers of input and output features of two consecutive modules must be compatible:

- If a module has a fixed number of output features `output_features`, the next module must have the same `input_features` value.
- If a module has an arbitrary number of output features (`output_features = None`), the next module must also accept an arbitrary number of input features (`input_features = None`). __Note:__ the converse does not hold: a module with a fixed number of outputs is compatible with a module that accepts an arbitrary number of inputs.
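
These rules can be summarized in a few lines of plain Python. The helper `check_chain` below is hypothetical (it is not part of nnbma) and describes each module only by its `(input_features, output_features)` pair, with `None` meaning "arbitrary size".

```python
from typing import Optional, Sequence, Tuple

# (input_features, output_features); None means "arbitrary size".
Features = Tuple[Optional[int], Optional[int]]


def compatible(prev_out: Optional[int], next_in: Optional[int]) -> bool:
    """Can a module with `prev_out` outputs feed one with `next_in` inputs?"""
    if next_in is None:
        # A module accepting any size can follow any module.
        return True
    # A fixed input size requires the same fixed output size upstream.
    return prev_out is not None and prev_out == next_in


def check_chain(modules: Sequence[Features]) -> bool:
    # Check every pair of consecutive modules.
    return all(compatible(a[1], b[0]) for a, b in zip(modules, modules[1:]))


# MatMul (3 -> 2), a 2 -> 20 network, exp (any -> any), Moments (any -> 2)
chain_ok = check_chain([(3, 2), (2, 20), (None, None), (None, 2)])
# Arbitrary outputs cannot feed a module expecting exactly 5 inputs.
chain_bad = check_chain([(3, 2), (None, None), (5, 1)])
print(chain_ok, chain_bad)
```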

Assume we have the following `NeuralNetwork`, which computes 20 outputs from 2 inputs:

In [5]:
subnet = FullyConnected(
    [2, 10, 10, 20],
    nn.ReLU(),
)
print(subnet.input_features, subnet.output_features)

2 20


We will use it as a base to build a larger model performing the following operations:
- Multiplication by a 3x2 matrix to map 3 inputs into 2 outputs
- Processing by the fully connected neural network
- Application of the exponential function
- Computation of the mean and the variance across the features

__Note:__ This architecture was created only for illustration and is unlikely to be of any practical use.
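
For intuition about the shapes involved, here is a plain-torch analogue of the same pipeline built with `nn.Sequential`. It is a sketch, not `EmbeddingNetwork`: the `Plain*` classes are hypothetical stand-ins (renamed so they do not overwrite the notebook's `matmul`, `exp`, `moments` or `subnet`), and the layer layout assumed for `FullyConnected([2, 10, 10, 20], nn.ReLU())` (ReLU between hidden layers, linear output layer) may differ from nnbma's.

```python
import torch
from torch import nn


class PlainMatMul(nn.Module):
    def __init__(self):
        super().__init__()
        # Fixed, non-trainable 3x2 projection: 3 features -> 2 features.
        self.register_buffer("W", torch.randn(3, 2))

    def forward(self, x):
        return x @ self.W


class PlainExp(nn.Module):
    def forward(self, x):
        return torch.exp(x)


class PlainMoments(nn.Module):
    def forward(self, x):
        # Mean and biased variance along the last dimension -> 2 features.
        m1 = x.mean(dim=-1, keepdim=True)
        m2 = ((x - m1) ** 2).mean(dim=-1, keepdim=True)
        return torch.cat((m1, m2), dim=-1)


torch.manual_seed(0)
# Assumed layout for a [2, 10, 10, 20] fully connected network.
plain_subnet = nn.Sequential(
    nn.Linear(2, 10), nn.ReLU(),
    nn.Linear(10, 10), nn.ReLU(),
    nn.Linear(10, 20),
)
# (10, 3) -> (10, 2) -> (10, 20) -> (10, 20) -> (10, 2)
plain_pipeline = nn.Sequential(PlainMatMul(), plain_subnet, PlainExp(), PlainMoments())

x_demo = torch.randn(10, 3)
with torch.no_grad():
    y_demo = plain_pipeline(x_demo)
print(tuple(y_demo.shape))
```

Because the exponential is applied before `PlainMoments`, the first output column (a mean of positive values) is always positive and the second (a variance) is non-negative.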

In [6]:
net = EmbeddingNetwork(subnet, preprocessing=[matmul], postprocessing=[exp, moments])

[MatMul()]
[AdditionalModuleFromExisting(), Moments()]

In [7]:
x = np.random.normal(0, 1, size=(10, 3)).astype("float32")

net(x)

array([[0.98087585, 0.07246768],
       [0.9779798 , 0.06297462],
       [0.9804139 , 0.05586525],
       [0.9834059 , 0.04427576],
       [0.9792155 , 0.03643883],
       [0.9786695 , 0.04505088],
       [0.9833721 , 0.04092345],
       [0.9824467 , 0.05052687],
       [0.98621017, 0.0392889 ],
       [1.0015066 , 0.04561822]], dtype=float32)