In [1]:
import numpy as np

from util_classes import Dataset
from datasets import read_ML_cup, read_monks, train_valid_split
from loss import LossFunction
from optimizer import Optimizer

Test Data Modules

In [2]:
data = read_monks(1, "train")
ids = data.ids.copy()
data.shape

(124, [6, 1])

In [3]:
train, valid = train_valid_split(data)

In [4]:
train.shape, valid.shape

((86, [6, 1]), (38, [6, 1]))

In [5]:
# all elements are present
sorted(list(train.ids) + list(valid.ids)) == sorted(ids)

True

### Trainer Class

In [6]:
class Estimator:
    def __init__(self, net, *,
        loss=LossFunction(), optimizer=Optimizer(),
        batchsize=1, start_it=0, seed=None
    ):
        self.net = net
        self.t = start_it
        self.loss = loss
        self.optimizer = optimizer
        self.batchsize = batchsize
        self.rng = np.random.default_rng(seed)
        if seed != None:
            # re-randomize all layers with new rng
            self.net.rng = self.rng
    
    def update_params(net=None, loss=None, optimizer=None, batchsize=None, seed=None):
        self.t = 0
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        if net is not None:
            self.net = net
            self.net.rng = self.rng
        if loss is not None:
            self.loss = loss
        if optimizer is not None:
            self.optimizer = optimizer
        if batchsize is not None:
            self.batchsize = batchsize
    
    @staticmethod
    def get_minibatches(x, y, batchsize):
        size = x.shape[0]
        batchtotal, remainder = divmod(size, batchsize)
        for i in range(batchtotal):
            mini_x = x[i*batchsize:(i+1)*batchsize]
            mini_y = y[i*batchsize:(i+1)*batchsize]
            yield mini_x, mini_y
        if remainder > 0:
            yield (
                x[batchtotal*batchsize:],
                y[batchtotal*batchsize:]
            )
        
    def train(self, dataset, *, n_epochs=1, callback=print, mb_callback=None):
        for i in range(n_epochs):
            # permute dataset
            permutation = self.rng.permutation(dataset.shape[0])
            x = dataset.data[permutation]
            y = dataset.labels[permutation]
            # iterate minibatches
            avg_loss, batchcount = 0., np.ceil(x.shape[0] / self.batchsize)
            for b, (mini_x, mini_y) in enumerate(Estimator.get_minibatches(x, y, self.batchsize)):
                pred = self.net.foward(mini_x)
                loss = self.loss.foward(pred, mini_y)
                if mb_callback is not None:
                    record = {"epoch": self.t, "batch": b, "loss": loss}
                    mb_callback(self.t, b, loss)
                avg_loss += loss
                loss_grad = self.loss.backward()
                self.net.backward(loss_grad)
                self.net.optimize(self.optimizer)
            avg_loss /= batchcount
            self.t += 1
            record = {"epoch": self.t, "loss": avg_loss}
            callback(record)

In [7]:
from nn import NeuralNetwork, LinearLayer, ActivationFunction

In [8]:
trainer = Estimator(
    NeuralNetwork([
        LinearLayer((9, 8)),
        ActivationFunction(),
        LinearLayer((8, 8)),
        ActivationFunction(),
        LinearLayer((8, 2))
    ]),
    batchsize=64,
    seed=123,
    optimizer=Optimizer(eta=1e-3, l2_coeff=1e-2, alpha=0.2)
)

In [9]:
train, valid = train_valid_split(read_ML_cup("train"), splits=(0.7, 0.3))

In [10]:
trainer.train(train, n_epochs=100)

{'epoch': 1, 'loss': 477.2846025302484}
{'epoch': 2, 'loss': 355.36805692778324}
{'epoch': 3, 'loss': 340.72458828744124}
{'epoch': 4, 'loss': 351.21379834771705}
{'epoch': 5, 'loss': 358.029870328806}
{'epoch': 6, 'loss': 360.8883135887121}
{'epoch': 7, 'loss': 365.9556978451745}
{'epoch': 8, 'loss': 367.01020554761595}
{'epoch': 9, 'loss': 364.1841146261983}
{'epoch': 10, 'loss': 365.2386041262139}
{'epoch': 11, 'loss': 366.7439068194883}
{'epoch': 12, 'loss': 364.1846029862737}
{'epoch': 13, 'loss': 366.01150312002534}
{'epoch': 14, 'loss': 369.0559329064739}
{'epoch': 15, 'loss': 363.9748320545789}
{'epoch': 16, 'loss': 366.5443436439703}
{'epoch': 17, 'loss': 368.2874410174959}
{'epoch': 18, 'loss': 368.45876551320373}
{'epoch': 19, 'loss': 369.65054155020914}
{'epoch': 20, 'loss': 368.87109232262634}
{'epoch': 21, 'loss': 367.59148472713656}
{'epoch': 22, 'loss': 365.680379710991}
{'epoch': 23, 'loss': 368.1960237303889}
{'epoch': 24, 'loss': 366.4120518459492}
{'epoch': 25, 'los