In [2]:
import numpy as np
import torch # neural network library to test our implementation

import sys
import os

module_path = os.path.join(os.getcwd(), 'modules')
sys.path.append(module_path)
from mlp import MLP, CompoundNN
from activation_functions import ReLU

#### Função para testar o método `forward` da MLP

In [3]:
def test_forward_function(tol=1e-8):
    """Check that ``MLP.forward`` matches an equivalent ``torch.nn.Linear`` layer.

    Builds a 6 -> 5 ``MLP``, copies its weights and bias into a
    ``torch.nn.Linear(6, 5)`` (a plain linear layer, because the MLP has no
    activation function yet), runs both on the same random (1, 6) input and
    compares the outputs.

    Parameters
    ----------
    tol : float, optional
        Maximum accepted mean squared error between the two outputs.

    Returns
    -------
    int
        0 if the mean squared error is below ``tol``, -1 otherwise.
    """
    mlp = MLP(6, 5)
    x = np.random.randn(1, 6)
    out = mlp.forward(x)

    # Reference implementation: Linear, since our MLP has no activation yet.
    mlp_torch = torch.nn.Linear(6, 5)
    # NOTE(review): torch.nn.Linear stores its weight as (out_features, in_features)
    # and computes x @ weight.T + bias; this copy assumes mlp.W uses that same
    # layout — confirm against the MLP implementation.
    mlp_torch.weight.data = torch.from_numpy(mlp.W).type(torch.float)  # copy weights
    # Flatten the bias so both a (5,) and a (1, 5) array fit Linear's (out_features,) bias.
    mlp_torch.bias.data = torch.from_numpy(mlp.b).type(torch.float).reshape(-1)

    # Evaluate the PyTorch layer (not our own `mlp`) so the comparison is meaningful.
    # no_grad: we only need the values, so no autograd graph is built.
    with torch.no_grad():
        out_torch = mlp_torch(torch.from_numpy(x).type(torch.float))
    out_torch = out_torch.numpy()

    error = ((out - out_torch) ** 2).mean()  # mean squared error
    return 0 if error < tol else -1

#### Função para testar a Compound Neural Network (sequência de MLPs)

In [4]:
def test_compound_nn():
    """Check that ``CompoundNN`` gives the same result as chaining its layers by hand.

    The architecture MLP(6, 5) -> ReLU -> MLP(5, 4) -> ReLU is wrapped in a
    ``CompoundNN`` and also applied layer by layer to the same random (1, 6)
    input; the two results are then compared.

    Returns
    -------
    int
        0 if the mean squared difference is below 1e-8, -1 otherwise.
    """
    layers = [MLP(6, 5), ReLU(), MLP(5, 4), ReLU()]
    compound = CompoundNN(list(layers))

    sample = np.random.randn(1, 6)

    # Reference output: feed the sample through each layer in order.
    expected = sample
    for layer in layers:
        expected = layer(expected)

    actual = compound(sample)

    mse = ((expected - actual) ** 2).mean()
    return 0 if mse < 1e-8 else -1

#### Função para testar `save` e `load`

In [5]:
def test_save_and_load(tol=1e-8):
    """Check that ``CompoundNN.save`` followed by ``load`` reproduces the network.

    Two networks with the same architecture (MLP(6, 5) -> ReLU -> MLP(5, 4) -> ReLU)
    are built; the first is saved and the saved file is loaded into the second.
    Both must then produce the same output on a random (1, 6) input.

    The file is written inside a temporary directory that is deleted afterwards
    (even if ``save``/``load`` raises), so the test leaves nothing behind in the
    working directory.

    Parameters
    ----------
    tol : float, optional
        Maximum accepted mean squared error between the two outputs.

    Returns
    -------
    int
        0 if the mean squared error is below ``tol``, -1 otherwise.
    """
    import tempfile

    # Two separately constructed networks with identical architecture; their
    # parameters presumably differ until load() copies them over (depends on
    # MLP's initialisation).
    nn1, nn2 = (CompoundNN([MLP(6, 5), ReLU(), MLP(5, 4), ReLU()]) for _ in range(2))

    x = np.random.randn(1, 6)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE(review): assumes save/load treat the argument as a filesystem path
        # (possibly appending their own extension) — confirm against CompoundNN.
        path = os.path.join(tmp_dir, 'nn1')
        nn1.save(path)
        nn2.load(path)

    out1 = nn1(x)
    out2 = nn2(x)

    error = ((out1 - out2) ** 2).mean()  # mean squared error
    return 0 if error < tol else -1

#### Assertions

In [6]:
# Fix the global NumPy and PyTorch seeds so a failing run can be reproduced.
# NOTE(review): this only makes the tests deterministic if MLP/ReLU draw their
# random values from NumPy's global RNG — confirm in the mlp module.
np.random.seed(0)
torch.manual_seed(0)

assert test_compound_nn() == 0, "CompoundNN output differs from chaining its layers manually"
assert test_forward_function() == 0, "MLP.forward does not match torch.nn.Linear"
assert test_save_and_load() == 0, "CompoundNN.load did not restore the saved network"