In [None]:
from ddnn.nn import *
from ddnn.validation import *
from ddnn.data import *

# --- Run settings (read by the cells below) ---------------------------------
# (family, problem number); the second element is what gets passed to read_monks.
dataset_type = ("monk", 2)
epochs = 50
# Forwarded to Logger as `every`.
log_every = 5
# Metrics given to both Logger and estimator.evaluate.
losses = ["MSE", "binary_accuracy"]
# When not None, this value is handed to TrainingThresholdStopping.
early_stopping = None

# --- Model -------------------------------------------------------------------
# 17 -> 4 -> 1 network: ReLU hidden activation, logistic output.
# NOTE(review): 17 presumably equals the one-hot encoded input width -- confirm
# against the shapes printed after loading.
estimator = Estimator(
    net=NeuralNetwork(
        [
            LinearLayer((17, 4)),
            ActivationFunction("ReLU"),
            LinearLayer((4, 1)),
            ActivationFunction("logistic"),
        ]
    ),
    loss=LossFunction("MSE"),
    optimizer=Optimizer(
        "SGD",
        learning_rate=0.5,
        momentum_coefficient=0.5,
        l2_coefficient=0,
    ),
    # Alternative optimizer kept for quick switching:
    # optimizer=Optimizer("Adam", learning_rate=0.01, l2_coefficient=2e-3),
    batchsize=16,
    initializer=Initializer("glorot_uniform"),
    seed=777,
)

In [None]:
# Load and one-hot encode the train/test splits of the selected dataset.
# The family name is checked explicitly (not just the tuple length) so that an
# unsupported 2-tuple is rejected instead of being silently loaded as MONK.
if len(dataset_type) > 1 and dataset_type[0] == "monk":
    # Same order as before: read + encode "train" first, then "test".
    traindata, testdata = (
        onehot_encoding(data=read_monks(dataset_type[1], split))
        for split in ("train", "test")
    )
else:
    # todo ml_cup
    raise ValueError(
        f"unsupported dataset_type {dataset_type!r}: "
        "only ('monk', <problem number>) is implemented"
    )

In [None]:
# Display the (train, test) shapes as a quick sanity check of the loaded data.
tuple(split.shape for split in (traindata, testdata))

In [None]:
# Logger is built over the estimator with the train split as training_set and
# the test split as validation_set.
testlogger = Logger(
    estimator,
    losses=losses,
    training_set=traindata,
    validation_set=testdata,
    every=log_every,
)

# `callback` is the hook handed to estimator.train; which variant is defined
# depends on whether an early-stopping value was configured.
if early_stopping is None:
    def callback(record):
        """Print the record, then forward it to the logger."""
        print(record)
        testlogger(record)
else:
    teststopper = TrainingThresholdStopping(estimator, early_stopping)

    def callback(record):
        """Forward the record to the logger, then to the stopping criterion."""
        testlogger(record)
        teststopper(record)

In [None]:
# Train on the training split for `epochs` epochs, passing `callback` as the hook.
estimator.train(
    traindata,
    n_epochs=epochs,
    callback=callback,
)

In [None]:
# Evaluate the configured losses on the training split; the bare name on the
# last line shows the result as the cell output.
res = estimator.evaluate(dataset=traindata, losses=losses)
res

In [None]:
# Evaluate the configured losses on the test split (overwrites the previous
# `res`); the bare name on the last line shows the result as the cell output.
res = estimator.evaluate(dataset=testdata, losses=losses)
res

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from ipywidgets import interact

In [None]:
%matplotlib ipympl

# Single plot showing both the training and the validation curves

fig, ax = plt.subplots()

@interact(
    loss = testlogger._losses,
)
def plot_results(loss):
    """Redraw the train/valid curves of ``loss`` on the shared axes.

    Reads the logged scores from ``testlogger._scores[0]["folds"][0]`` and
    plots one line per split. ``"binary_accuracy"`` is drawn on a linear
    y-axis with its maximum in the legend; every other loss is drawn on a
    log y-axis with its minimum in the legend.

    Args:
        loss: name of the logged loss to display (chosen via the widget).
    """
    # Per-loss settings are independent of the split, so decide them once.
    if loss == "binary_accuracy":
        # todo fix to show last not best
        pick_best, form, yscale = max, "{:.2}", "linear"
    else:
        pick_best, form, yscale = min, "{:.2E}", "log"

    # clear() also resets formatter, scale and labels, so set them after it.
    ax.clear()
    # scale to resemble number of epochs instead of plot points
    ticks_x = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x*testlogger._every))
    ax.xaxis.set_major_formatter(ticks_x)
    ax.set_yscale(yscale)
    ax.set_xlabel("epoch")
    ax.set_ylabel(loss)

    for where in ["train", "valid"]:
        y = testlogger._scores[0]["folds"][0][where][loss]
        ax.plot(y, label=f"{where}: {form.format(pick_best(y))}")

    ax.legend()
    # Lay out last, so the new tick labels, axis labels and legend are accounted for.
    fig.tight_layout()