In [9]:
from sklearn.datasets import load_iris
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import numpy as np
from Gnarl import Gnarl

data = load_iris()

X = data.data
y = data.target

enc = OneHotEncoder()
y = enc.fit_transform(y.reshape(-1,1)).toarray()

y_dim = 3

X = (X - np.mean(X, axis=0))/np.std(X, axis=0)

# Train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

nn_options = {
    'activation': 'leaky_relu',
    'learning_rate': 1e-4,
    'regularization': 0.,
    'random_state': 50,
    'verbose': True,
    'loss': 'cross_entropy',
    'batch_size': 5
}

gnarl = Gnarl(X, y, **nn_options)

# Add first hidden layer
gnarl.hidden_layer(10, activation='leaky_relu')

# Add second hidden layer
gnarl.hidden_layer(5, activation='leaky_relu')

# Add output layer
gnarl.hidden_layer(y_dim, activation='none')

# Connect the layers
gnarl.connect_layers()

# Train the model by sampling from all training data
gnarl.fit(X_train, y_train, solver='gd', epochs=1)

# Predict outputs
y_pred = gnarl.predict(X_test, truncate_labels=True)

Training model...
Total number of samples: 120
Steps per epoch: 24
Epoch: 1, Loss: 3.406
Epoch: 501, Loss: 0.205
Epoch: 1000, Loss: 0.061

In [5]:
y_pred
#y_pred = np.exp(y_pred)/np.sum(np.exp(y_pred))
#y_pred_ohe = enc.transform(y_pred.reshape(-1,1)).toarray().astype(int)

array([1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0,
       1, 1, 0, 0, 1, 1, 1])

In [6]:
for i,j in zip(y_pred, y_test):
    print(i, j)

1 [ 0.  0.  1.]
1 [ 0.  0.  1.]
0 [ 1.  0.  0.]
1 [ 0.  0.  1.]
1 [ 0.  1.  0.]
0 [ 1.  0.  0.]
1 [ 0.  0.  1.]
0 [ 1.  0.  0.]
1 [ 0.  0.  1.]
1 [ 0.  1.  0.]
1 [ 0.  1.  0.]
1 [ 0.  0.  1.]
1 [ 0.  0.  1.]
1 [ 0.  1.  0.]
1 [ 0.  0.  1.]
1 [ 0.  1.  0.]
1 [ 0.  0.  1.]
1 [ 0.  0.  1.]
0 [ 1.  0.  0.]
1 [ 0.  0.  1.]
1 [ 0.  1.  0.]
0 [ 1.  0.  0.]
0 [ 1.  0.  0.]
1 [ 0.  1.  0.]
1 [ 0.  1.  0.]
0 [ 1.  0.  0.]
0 [ 1.  0.  0.]
1 [ 0.  1.  0.]
1 [ 0.  0.  1.]
1 [ 0.  0.  1.]
