In [1]:
import numpy as np

from ml.models import MLPClassifier

from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder

# Testing

In [2]:
def print_metadata(model):
    params = 0
    
    message = f'''
    ============================================================================================
    Layers size: {tuple(model.layers_size)}
    ============================================================================================
    Activations: {tuple(model.activations)}'''
    
    for key, value in model.parameters.items():
        params += value.shape[0] * value.shape[1]
        
    message = f'''{message}
    ============================================================================================
    Trainable parameters: {params}
    ============================================================================================
    '''
    
    print(message)

In [3]:
HYPERPARAMS = {
    'architecture': {
        'n_features': 64, 
        'n_classes': 10,
        'hidden_layers_size': (32,16),
        'activations': ('tanh','tanh')
    },
    'train': {
        'learning_rate': 0.01,
        'epochs': 300
    }
}

In [4]:
dataset = load_digits()
scaler = StandardScaler()
encoder = OneHotEncoder(sparse=False).fit(np.array([i for i in range(10)]).reshape(-1, 1))

X, y = dataset.data, dataset.target

X_std = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X_std, y, test_size=0.2)

X_train, X_test, y_train = X_train.T, X_test.T, encoder.transform(y_train.reshape(-1, 1)).T

In [5]:
model = MLPClassifier(**HYPERPARAMS['architecture'])
print_metadata(model)


    Layers size: (64, 32, 16, 10)
    Activations: ('tanh', 'tanh', 'softmax')
    Trainable parameters: 2778
    


In [6]:
model.fit(X_train, y_train, **HYPERPARAMS['train'])

In [7]:
y_pred = model.predict(X_test)

In [8]:
print('Accuracy:', accuracy_score(y_test, y_pred))

Accuracy: 0.7916666666666666
