In [1]:
import sys
sys.path.append('/media/data/studia/sieci_neuronowe_7_sem')
from nn.metrics import accuracy
%load_ext autoreload
%autoreload 2

from functools import partial
import json
from pathlib import Path

import matplotlib.pyplot as plt
from torchvision.datasets import MNIST

from nn.init import init_weights
from nn.layers.activations import ReLU, Tanh, Sigmoid
from nn.layers.linear import Linear
from nn.layers.model import Model
from nn.loss import cross_entropy
from nn.optimization import sgd
from nn.training import Trainer, DataGenerator

In [2]:
# Training split: pull images and labels out of the dataset as NumPy arrays.
dataset = MNIST('./data', train=True, download=True)
x_train, y_train = dataset.data.numpy(), dataset.targets.numpy()

# Test split: `dataset` is rebound here, so after this cell it refers to the
# test split only.
dataset = MNIST('./data', train=False, download=True)
x_test, y_test = dataset.data.numpy(), dataset.targets.numpy()

def preprocess_x(x):
    """Flatten ``x`` into rows of 784 values and divide every value by 255.

    Args:
        x: Array whose total size is a multiple of 784 (anything accepted by
            ``x.reshape(-1, 784)``).

    Returns:
        Array of shape ``(-1, 784)`` holding ``x / 255.``.
    """
    flattened = x.reshape(-1, 784)
    return flattened / 255.

# Materialise each split as a list of (preprocessed image row, label) pairs;
# these lists are what the DataGenerator instances in `experiment` consume.
train_set = [(features, label) for features, label in zip(preprocess_x(x_train), y_train)]
test_set = [(features, label) for features, label in zip(preprocess_x(x_test), y_test)]


In [6]:
from time import time

def time_it(func):
    """Decorate ``func`` so that every call also reports how long it took.

    Args:
        func: Any callable.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns the tuple ``(result, elapsed_seconds)``.
    """
    # Imported at function scope so this cell is self-contained.
    from functools import wraps
    from time import perf_counter

    @wraps(func)  # keep the wrapped function's __name__ and __doc__
    def inner(*args, **kwargs):
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed (or turn negative) by a system clock adjustment mid-run,
        # which is possible with time().
        start = perf_counter()
        result = func(*args, **kwargs)
        stop = perf_counter()
        return result, stop - start
    return inner

@time_it
def experiment(hidden_size=256, learning_rate=0.001, batch_size=128, std=0.05, activation_func=ReLU):
    """Train a 784 -> hidden_size -> 10 network and collect the trainer's output.

    Reads the module-level ``train_set`` and ``test_set`` lists.

    Args:
        hidden_size: Width of the single hidden ``Linear`` layer.
        learning_rate: Learning rate passed to ``sgd``.
        batch_size: Batch size of the training ``DataGenerator``. The test
            generator always uses 128.
        std: ``std`` argument passed to ``init_weights`` for the first layer.
            The output layer uses ``Linear``'s default initialisation.
        activation_func: Activation layer class, instantiated with no arguments.

    Returns:
        Because of ``@time_it``, a tuple ``(history, elapsed_seconds)`` where
        ``history`` is ``list(trainer.train(...))``.
        NOTE(review): the sweep cell unpacks each history item as a pair whose
        second element is a mapping with an ``'accuracy'`` key - confirm
        against ``Trainer.train``.
    """
    model = Model(
        Linear(784, hidden_size, init_func=partial(init_weights, std=std)),
        activation_func(),
        Linear(hidden_size, 10),
    )

    trainer = Trainer(
        # Forward the caller's learning rate. It was previously hard-coded to
        # 1e-3 here, so the `learning_rate` argument had no effect.
        optimizer_func=partial(sgd, learning_rate=learning_rate),
        loss_func=cross_entropy,
        epochs=30,
        metrics={'accuracy': accuracy},
        monitor='accuracy',
        mode='max',
        delta=0.00,
        patience=30,
    )

    train_generator = DataGenerator(train_set, batch_size=batch_size)
    test_generator = DataGenerator(test_set, batch_size=128)
    return list(trainer.train(model, train_generator, test_generator))

Loss: 0.7536, Patience: 30: : 469it [00:18, 26.01it/s]
Loss: 0.8102, Patience: 30:  17%|█▋        | 78/468 [00:03<00:18, 20.90it/s]


KeyboardInterrupt: 

In [13]:
# One-at-a-time sweep: each listed value is tried for its argument while every
# other argument of `experiment` keeps its default.
experiment_specs = {
    'hidden_size': [16, 64, 128, 256, 1024],
    'learning_rate': [1e-5, 1e-4, 1e-3, 1e-1],
    'batch_size': [1, 64, 128, 256, 512],
    'std': [0.0001, 0.1, 0.3, 0.8],
    'activation_func': [ReLU, Tanh, Sigmoid]
}

result_path = Path('./result.json')
result = {}
for arg, values in experiment_specs.items():
    # Register the per-argument dict up front so `result` already contains it
    # if a run is interrupted.
    per_value = result[arg] = {}
    for value in values:
        print(f'{arg}: {value}\n')
        # Keys are str(value); the entry exists (empty) before the run starts.
        entry = per_value[str(value)] = {}
        metrics, experiment_time = experiment(**{arg: value})
        entry['value'] = metrics
        entry['time'] = experiment_time

        best_accuracy = max(a["accuracy"] for _, a in metrics)
        print(f'Best validation accuracy: {best_accuracy:.4}\n')

        # Rewrite the whole file after every run so finished runs are saved
        # even if a later one is interrupted.
        with result_path.open('w') as file:
            json.dump(result, file, indent=4)


Loss: 2.393, Patience: 30:   1%|          | 4/468 [00:00<00:12, 37.00it/s]hidden_size: 16

Loss: 2.082, Patience: 30: : 469it [00:07, 65.24it/s]                       
Loss: 1.951, Patience: 30: : 469it [00:07, 61.42it/s]                       
Loss: 1.824, Patience: 30: : 469it [00:06, 74.36it/s]                       
Loss: 1.679, Patience: 30: : 469it [00:09, 51.93it/s]                       
Loss: 1.5, Patience: 30: : 469it [00:06, 70.72it/s]
Loss: 1.337, Patience: 30: : 469it [00:04, 100.62it/s]                       
Loss: 1.274, Patience: 30: : 469it [00:09, 51.93it/s]                       
Loss: 1.07, Patience: 30: : 469it [00:08, 56.02it/s]                       
Loss: 1.014, Patience: 30: : 469it [00:09, 51.88it/s]                       
Loss: 0.9552, Patience: 30: : 469it [00:10, 46.77it/s]                       
Loss: 0.7831, Patience: 30: : 469it [00:08, 52.77it/s]
Loss: 0.7528, Patience: 30: : 469it [00:09, 51.14it/s]                       
Loss: 0.9143, Patience: 30:

KeyboardInterrupt: 

array([[-0.00903136, -0.07781877,  0.03398565, ...,  0.05009631,
         0.07205296, -0.02185572],
       [ 0.04742982, -0.01818698, -0.02515133, ..., -0.01136244,
        -0.01199874, -0.00294767],
       [-0.03188821,  0.05495292, -0.00668778, ..., -0.02476683,
         0.00990076, -0.07795491],
       ...,
       [ 0.11051129, -0.05569008,  0.04533202, ...,  0.04103275,
         0.04337337, -0.0471332 ],
       [-0.01428359,  0.03714409, -0.03014214, ..., -0.01259957,
        -0.03742836, -0.01441736],
       [-0.07751212, -0.05631129, -0.07758386, ..., -0.00286561,
        -0.04752837,  0.01186537]])