In [None]:
import os
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

import matplotlib.pyplot as plt
import seaborn as sns

from models import MLP, NAC, NALU

In [None]:
# Target functions the models are trained to reproduce, keyed by task name.
# Binary tasks consume two input columns; the unary ones (`squared`, `sqrt`)
# consume a single column. `div` and `sqrt` rely on strictly positive inputs,
# which holds for the supports used in this notebook (both start at 1).
arithmetic_functions = dict(
    add=lambda x, y: x + y,
    sub=lambda x, y: x - y,
    mul=lambda x, y: x * y,
    div=lambda x, y: x / y,
    squared=lambda x: torch.pow(x, 2),
    sqrt=lambda x: torch.sqrt(x),
)

In [None]:
# Architectures to compare, keyed by the name used in results/plots.
# For MLP entries the value is the activation module handed to `MLP(act=...)`;
# 'None' passes no activation (presumably a purely linear MLP — confirm in
# models.MLP). 'NAC' and 'NALU' are built from their own classes in the
# experiment loop, so their value is only a placeholder.
models = {
    **dict.fromkeys(('None', 'NAC', 'NALU')),
    'ReLU6': nn.ReLU6(),
    'Tanh': nn.Tanh(),
    'Sigmoid': nn.Sigmoid(),
    'Softsign': nn.Softsign(),
    'SELU': nn.SELU(),
    'ELU': nn.ELU(),
    'ReLU': nn.ReLU(),
}

In [None]:
def generate_data(dim, fn, support):
    """Sample a synthetic regression dataset for one arithmetic function.

    Parameters
    ----------
    dim : tuple of int
        ``(n_samples, n_inputs)`` — shape of the input matrix.
    fn : callable
        Called with one tensor per input column (``n_inputs`` positional
        arguments); expected to return one target per sample.
    support : sequence of numbers
        Bounds forwarded to ``Tensor.uniform_`` (``low, high``).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        The float32 CPU input matrix and the targets with a trailing
        singleton axis inserted at position 1.
    """
    inputs = torch.empty(*dim, dtype=torch.float32, device='cpu').uniform_(*support)
    # Split the matrix into its columns so `fn` receives one tensor per operand.
    operands = inputs.unbind(1)
    targets = fn(*operands)[:, None]
    return inputs, targets

In [None]:
def train(model, optimizer, criterion, data, target, n_epochs, log_every=1000):
    """Fit `model` on ``(data, target)`` with full-batch optimisation.

    Every epoch runs one forward pass over the whole dataset, one backward pass
    of `criterion`, and one `optimizer.step()`.

    Parameters
    ----------
    model : callable
        Maps `data` to predictions comparable with `target`.
    optimizer : torch.optim.Optimizer
        Optimiser over the model's parameters.
    criterion : callable
        Loss ``criterion(output, target)`` to minimise.
    data, target : torch.Tensor
        Full training inputs and targets.
    n_epochs : int
        Number of optimisation steps.
    log_every : int, optional
        Print the loss and mean absolute error every `log_every` epochs
        (starting at epoch 0). A falsy value disables logging.

    Returns
    -------
    float or None
        Training loss at the last epoch, or None if ``n_epochs`` is 0.
    """
    last_loss = None
    for epoch in range(n_epochs):
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        last_loss = loss.item()

        if log_every and epoch % log_every == 0:
            # MAE is only for reporting: compute it when logging, outside autograd.
            # `output` still holds the pre-step predictions that `loss` came from.
            with torch.no_grad():
                mae = torch.mean(torch.abs(target - output)).item()
            print('Epoch {:05}:\t'
                  'Loss = {:.5f}\t'
                  'MAE = {:.5f}'.format(epoch, last_loss, mae))

    return last_loss

In [None]:
def test(model, data, target):

    with torch.no_grad():
        output = model(data)
        m = torch.mean(torch.abs(target - output))
        return m

In [None]:
# Size of every network under test (MLP, NAC and NALU alike).
hidden_dim = 2
n_layers = 2

# Inputs are drawn uniformly from these ranges. Training and interpolation-test
# data share `interp_support`; extrapolation-test data come from the disjoint
# `extrap_support`, to measure generalisation outside the training range.
interp_support = [1, 100]
extrap_support = [101, 200]

n_epochs = 10_000
lr = 0.01

# Seed every RNG the notebook imports so data sampling and weight
# initialisation are reproducible under Restart & Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
results = []

# Number of independently initialised, untrained MLPs averaged to form the
# random baseline that every trained model is normalised against.
n_random_inits = 100


def _build_net(name, act, in_dim):
    """Instantiate the architecture `name`: NAC, NALU, or an MLP using activation `act`."""
    sizes = dict(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=1, n_layers=n_layers)
    if name == 'NAC':
        return NAC(**sizes)
    if name == 'NALU':
        return NALU(**sizes)
    return MLP(**sizes, act=act)


def _random_baseline(in_dim, splits, n_inits=n_random_inits):
    """Mean `test` error of `n_inits` fresh, untrained MLPs on each ``(data, target)`` split.

    A *new* network is initialised per trial (re-evaluating one network would
    just repeat the same number), and the same set of networks is scored on
    every split so the interp/extrap baselines are comparable.
    """
    nets = [
        MLP(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=1, n_layers=n_layers, act=None)
        for _ in range(n_inits)
    ]
    return [
        torch.stack([test(net, data, target) for net in nets]).mean().item()
        for data, target in splits
    ]


for fn_type, fn in arithmetic_functions.items():

    # Unary tasks take a single input column; the rest take two.
    in_dim = 1 if fn_type in ('squared', 'sqrt') else 2

    print('-> Testing function: {}'.format(fn_type))

    Xtrain, ytrain = generate_data(dim=(500, in_dim), fn=fn, support=interp_support)
    Xtest_interp, ytest_interp = generate_data(dim=(50, in_dim), fn=fn, support=interp_support)
    Xtest_extrap, ytest_extrap = generate_data(dim=(50, in_dim), fn=fn, support=extrap_support)

    print('-> Evaluating random baseline.')
    random_mse_interp, random_mse_extrap = _random_baseline(
        in_dim, [(Xtest_interp, ytest_interp), (Xtest_extrap, ytest_extrap)]
    )

    for name, act in models.items():
        net = _build_net(name, act, in_dim)

        print('-> Running: {}'.format(name))
        optimizer = torch.optim.RMSprop(net.parameters(), lr=lr)
        criterion = nn.MSELoss()
        train(net, optimizer, criterion, Xtrain, ytrain, n_epochs)

        # NOTE: `test` returns the mean *absolute* error; the 'mse'/'random_mse'
        # keys are kept so results.csv and the downstream cells stay compatible.
        interp_mse = test(net, Xtest_interp, ytest_interp).item()
        extrap_mse = test(net, Xtest_extrap, ytest_extrap).item()

        for split, err, random_err in (
            ('interp', interp_mse, random_mse_interp),
            ('extrap', extrap_mse, random_mse_extrap),
        ):
            results.append({
                'type': split,
                'fn_type': fn_type,
                'activation': name,
                'mse': err,
                'random_mse': random_err,
            })

In [None]:
df_results = pd.DataFrame(results)
# Score each run relative to the untrained baseline: 100 = no better than a
# random network, 0 = perfect. Vectorised column arithmetic replaces a row-wise
# apply. (Despite the column names, the errors are the mean absolute errors
# returned by `test`.)
df_results['normalised_mse'] = 100.0 * df_results['mse'] / df_results['random_mse']
df_results.to_csv('results.csv')

# Compact view: one row per (function, split), one column per architecture.
df_results.pivot_table(index=['fn_type', 'type'], columns='activation', values='normalised_mse')

In [None]:
# Separate interpolation and extrapolation rows for side-by-side plotting.
df_interp, df_extrap = (
    df_results.loc[df_results['type'].eq(split)] for split in ('interp', 'extrap')
)

In [None]:
fn_types = df_interp['fn_type'].unique()
n_rows = len(fn_types)

# One row per arithmetic function: interpolation on the left, extrapolation on
# the right. The row count follows the data instead of being hard-coded, and
# squeeze=False keeps `axs` 2-D even with a single function.
fig, axs = plt.subplots(n_rows, 2, figsize=(20, 20 * n_rows / 6), squeeze=False)

for row, fn in enumerate(fn_types):
    for col, (split, df_split) in enumerate((('interp', df_interp), ('extrap', df_extrap))):
        ax = axs[row, col]
        sns.barplot(x='activation', y='normalised_mse', data=df_split[df_split['fn_type'] == fn], palette='YlOrRd', ax=ax)
        ax.set_title(f'{split} function = {fn}')
        ax.set_ylabel('error (% of random baseline)')

fig.tight_layout()
fig.savefig('normalised_mse.png', bbox_inches='tight')