In [1]:
%pylab inline

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [None]:
import os
# Drop RANK_TABLE_FILE from the environment before MindSpore is imported
# (presumably to avoid picking up a distributed-training config -- TODO confirm).
os.environ.pop("RANK_TABLE_FILE", None)
import matplotlib.pyplot as plt
import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor
import nalu_ms,mlp_ms

def create_data(min_val, max_val, n_elts, fun_op, single_dim=False):
    """Sample random integer operands and evaluate ``fun_op`` on them.

    Operands are drawn with ``np.random.randint`` from the closed interval
    ``[min_val, max_val]``.

    Args:
        min_val: Smallest integer that may be drawn.
        max_val: Largest integer that may be drawn (inclusive).
        n_elts: Number of samples (rows) to generate.
        fun_op: Callable producing the targets. It receives the whole
            ``(n_elts, 1)`` operand array when ``single_dim`` is true,
            otherwise the two operand columns as separate arguments.
        single_dim: Generate one operand per sample instead of two.

    Returns:
        Tuple ``(x, y)`` of float32 MindSpore tensors holding the operands
        and the values ``fun_op`` produced for them.
    """
    n_operands = 1 if single_dim else 2
    operands = np.random.randint(low=min_val, high=max_val + 1, size=(n_elts, n_operands))

    if single_dim:
        # Unary op: flatten the (n_elts, 1) result into a vector.
        targets = fun_op(operands).reshape(-1)
    else:
        first, second = operands[:, 0], operands[:, 1]
        targets = fun_op(first, second)

    return Tensor(operands, dtype=ms.float32), Tensor(targets, dtype=ms.float32)

In [None]:
def train(model, data, n_epochs, optimizer_class, lr, verbose=False, log_every=10000):
    """Fit ``model`` on ``data`` by repeated full-batch steps on the MSE loss.

    Args:
        model: Network to optimise; called as ``model(x)`` and queried for
            ``trainable_params()``.
        data: Tuple ``(x, y)`` of inputs and targets. The whole set is used
            as a single batch at every step.
        n_epochs: Number of optimisation steps to run.
        optimizer_class: Optimizer constructor, called as
            ``optimizer_class(params, learning_rate=lr)``.
        lr: Learning rate handed to the optimizer.
        verbose: When true, print MSE and MAE on the training data
            periodically.
        log_every: Number of epochs between two progress lines when
            ``verbose`` is set. A falsy value (``0``/``None``) disables
            logging.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    x, y = data
    loss_fn = nn.MSELoss()
    opt = optimizer_class(model.trainable_params(), learning_rate=lr)
    abs_fn = ops.Abs()
    mean_fn = ops.ReduceMean()

    class TrainStep(nn.Cell):
        """One optimisation step: forward pass, gradients, parameter update."""

        def __init__(self, network, optimizer):
            super(TrainStep, self).__init__(auto_prefix=False)
            self.network = network
            self.network.set_train()
            self.optimizer = optimizer
            self.loss_fn = loss_fn
            # Differentiate only w.r.t. the network weights, not the inputs.
            self.grad = ops.value_and_grad(self.forward_fn, grad_position=None, weights=network.trainable_params())

        def forward_fn(self, x, y):
            """Return the MSE between the flattened predictions and ``y``."""
            logits = self.network(x)
            loss = self.loss_fn(logits.reshape(-1), y)
            return loss

        def construct(self, x, y):
            """Compute loss and gradients, apply the update, return the loss."""
            loss, grads = self.grad(x, y)
            self.optimizer(grads)
            return loss

    train_step = TrainStep(model, opt)

    for epoch in range(n_epochs):
        loss = train_step(x, y)
        if verbose and log_every and epoch % log_every == 0:
            # `loss` was computed before this step's update, while the MAE
            # below is evaluated with the freshly updated weights.
            pred = model(x).reshape(-1)
            mae = mean_fn(abs_fn(pred - y))
            print(f'Epoch: {epoch}: mse={round(loss.asnumpy().item(), 2)}; mae={round(mae.asnumpy().item(), 2)}')

def test(model, data):
    """Evaluate ``model`` on ``data`` and report its error.

    Args:
        model: Network to evaluate; called as ``model(x)``.
        data: Tuple ``(x, y)`` of inputs and targets.

    Returns:
        Tuple ``(mse, mae)`` of Python numbers, each rounded to two decimals.
    """
    inputs, targets = data
    predictions = model(inputs).reshape(-1)

    mse = nn.MSELoss()(predictions, targets)
    absolute_error = ops.Abs()(predictions - targets)
    mae = ops.ReduceMean()(absolute_error)

    def _rounded_scalar(tensor):
        # Pull the scalar out of the tensor and keep two decimals.
        return round(tensor.asnumpy().item(), 2)

    return _rounded_scalar(mse), _rounded_scalar(mae)


# Arithmetic targets to learn. The first four take two operands, the last
# two take a single operand array.
fun_dict = dict(
    add=lambda a, b: a + b,
    sub=lambda a, b: a - b,
    mul=lambda a, b: a * b,
    div=lambda a, b: a / b,
    sqr=lambda v: np.power(v, 2),
    sqrt=lambda v: np.sqrt(v),
)

# Model name -> activation class handed to the MLP. 'none' is an MLP given no
# activation; 'NAC' and 'NALU' are built from nalu_ms and ignore this value.
models = dict(
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    relu6=nn.ReLU6,
    softsign=nn.Softsign,
    selu=nn.SeLU,
    elu=nn.ELU,
    relu=nn.ReLU,
    none=None,
    NAC=None,
    NALU=None,
)


# Network shape.
N_LAYERS = 2
OUT_DIM = 1
HIDDEN_DIM = 2

# Optimisation settings.
N_EPOCHS = 100_000
OPTIMIZER = nn.RMSProp
LR = 1e-2

# Inclusive operand ranges: training/interpolation vs. unseen extrapolation.
RANGE_INTER = (1, 100)
RANGE_EXTRA = (101, 200)

# Number of samples in each generated data set.
N_ELTS = 500



# Fix the random state so that operand sampling is repeatable across
# "Restart & Run All". Weight initialisation inside nalu_ms / mlp_ms is
# assumed to draw from NumPy's or MindSpore's global generators -- TODO confirm.
SEED = 0
np.random.seed(SEED)
ms.set_seed(SEED)

# fun_name -> model_name -> MAE, on the training range and on the unseen range.
interpolation_logs = {}
extrapolation_logs = {}

for fun_name, fun_op in fun_dict.items():
    # Square and square-root are unary; every other operation is binary.
    single_dim = fun_name in ('sqr', 'sqrt')
    in_dim = 1 if single_dim else 2

    train_data = create_data(*RANGE_INTER, N_ELTS, fun_op, single_dim)
    test_data_interpolation = create_data(*RANGE_INTER, N_ELTS, fun_op, single_dim)
    test_data_extrapolation = create_data(*RANGE_EXTRA, N_ELTS, fun_op, single_dim)

    interpolation_logs[fun_name] = {}
    extrapolation_logs[fun_name] = {}

    for model_name, act in models.items():
        if model_name == 'NAC':
            model = nalu_ms.StackedNAC(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM)
        elif model_name == 'NALU':
            model = nalu_ms.StackedNALU(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM)
        else:
            model = mlp_ms.MLP(N_LAYERS, in_dim, OUT_DIM, HIDDEN_DIM, act)

        train(model, train_data, N_EPOCHS, OPTIMIZER, LR)
        _, mae_inter = test(model, test_data_interpolation)
        _, mae_extra = test(model, test_data_extrapolation)

        interpolation_logs[fun_name][model_name] = mae_inter
        extrapolation_logs[fun_name][model_name] = mae_extra

        print(f'{fun_name.ljust(10)}: {model_name.ljust(10)}: mae inter: {mae_inter}, mae extra: {mae_extra}')

        # Release the trained network before building the next one.
        del model
    del train_data
    del test_data_interpolation
    del test_data_extrapolation



add       : tanh      : mae inter: 15.33, mae extra: 159.59
add       : sigmoid   : mae inter: 6.11, mae extra: 132.23
add       : relu6     : mae inter: 8.97, mae extra: 139.96
add       : softsign  : mae inter: 7.23, mae extra: 131.38
add       : selu      : mae inter: 1.07, mae extra: 3.15
add       : elu       : mae inter: 3.47, mae extra: 119.99
add       : relu      : mae inter: 1.04, mae extra: 3.06
add       : none      : mae inter: 1.07, mae extra: 3.07
add       : NAC       : mae inter: 101.44, mae extra: 294.36
add       : NALU      : mae inter: 101.85, mae extra: 297.42
sub       : tanh      : mae inter: 13.47, mae extra: 22.6
sub       : sigmoid   : mae inter: 6.64, mae extra: 15.28
sub       : relu6     : mae inter: 34.37, mae extra: 34.19
sub       : softsign  : mae inter: 7.61, mae extra: 13.08
sub       : selu      : mae inter: 1.4, mae extra: 4.11
sub       : elu       : mae inter: 1.61, mae extra: 4.85
sub       : relu      : mae inter: 1.5, mae extra: 3.75
sub      

In [None]:
def autolabel(rects, ax):
    """Annotate every bar in ``rects`` with its height.

    The label is horizontally centred on the bar and anchored (by its bottom
    edge) at 90% of the bar's height.

    Args:
        rects: Iterable of bar patches exposing ``get_height``, ``get_x``
            and ``get_width``.
        ax: Axes-like object whose ``text`` method draws the labels.
    """
    for bar in rects:
        bar_height = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2.
        label_y = 0.9 * bar_height
        ax.text(centre_x, label_y, str(bar_height), ha='center', va='bottom')

def _plot_mae_panel(ax, logs, title):
    """Draw one bar chart of per-model MAE values on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        logs: Mapping of model name to MAE; insertion order sets bar order.
        title: Title shown above the panel.
    """
    model_names = list(logs.keys())
    mae_values = list(logs.values())
    positions = np.arange(len(model_names))

    bars = ax.bar(positions, mae_values, align='center', alpha=0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels(model_names)
    ax.set_ylabel('mae')
    ax.set_title(title)
    autolabel(bars, ax)


# One row per arithmetic function: interpolation on the left, extrapolation
# on the right.
n_rows = len(interpolation_logs)
figure = plt.figure(figsize=(20, 40))

for row, fun_name in enumerate(interpolation_logs):
    left_idx = 2 * row + 1  # subplot indices are 1-based, two panels per row

    _plot_mae_panel(
        figure.add_subplot(n_rows, 2, left_idx),
        interpolation_logs[fun_name],
        f'{fun_name} (interpolation [{RANGE_INTER[0]}, {RANGE_INTER[1]}])',
    )
    _plot_mae_panel(
        figure.add_subplot(n_rows, 2, left_idx + 1),
        extrapolation_logs[fun_name],
        f'{fun_name} (extrapolation [{RANGE_EXTRA[0]}, {RANGE_EXTRA[1]}])',
    )

plt.show()