In [1]:
import os
import sys

# Put the parent of the current working directory on the import path so the
# project package can be imported below.
# NOTE(review): assumes the notebook is run from a folder directly under the
# repository root -- confirm.
repo_root_path = os.path.abspath(os.pardir)
if sys.path.count(repo_root_path) == 0:
    sys.path.append(repo_root_path)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.datasets import make_classification

In [2]:
from hebbnets.networks import MultilayerQAGRELNetwork
from hebbnets.utils import softmax

In [3]:
num_categories = 5

# Fixed seed so the synthetic dataset is identical on every run of the
# notebook; without it each execution draws a different dataset.
data_seed = 0

# Synthetic multi-class problem: 100 samples per category, 5 informative and
# 2 redundant features out of num_categories + 5 total, 3 clusters per class.
data_X, data_Y = make_classification(
    n_samples=num_categories * 100,
    n_features=num_categories + 5,
    n_informative=num_categories,
    n_redundant=2,
    n_classes=num_categories,
    n_clusters_per_class=3,
    class_sep=10.0,
    scale=0.1,
    random_state=data_seed,
)

In [4]:
# Layer widths: two intermediate layers followed by a final layer with one
# node per category.
nodes_per_layer = [20, 10, num_categories]
# The network input width is the number of feature columns in the dataset.
input_layer_size = data_X.shape[1]

qagrel_network = MultilayerQAGRELNetwork(input_layer_size, nodes_per_layer, act_type='relu')

In [5]:
# Train on (input, label) pairs; the zip is materialised into a list before
# being handed to the network.
qagrel_network.train(
    list(zip(data_X, data_Y)),
    num_epochs=100,
)

Automatic pdb calling has been turned ON


In [6]:
# Score every sample with the trained network and collect one record per
# sample. These are the same samples the network was trained on, so the
# numbers below are training-set figures.
results = []
for x_in, y_targ in zip(data_X, data_Y):
    # Presumably runs a forward pass; afterwards the last layer's activation
    # is read as the per-class output.
    qagrel_network.propogate_input(x_in)
    # Softmax (temp=0.1) over the last layer's activation, flattened so it
    # can be indexed directly by the integer class label.
    score = softmax(qagrel_network.layers[-1].activation, temp=0.1).ravel()
    # Predicted class, computed once and reused for both fields below.
    pred_categ = np.argmax(score)
    results.append(
        {
            'true_categ': y_targ,
            'pred_categ': pred_categ,
            'is_corr': pred_categ == y_targ,
            # Softmax score assigned to the true class.
            'score': score[y_targ]
        }
    )

df_results = pd.DataFrame(results)

In [7]:
# Show the last layer's current activation. Assuming the cells ran in order,
# this is the state left by the final sample of the evaluation loop above.
qagrel_network.layers[-1].activation

array([[-0.03355557],
       [-0.01015246],
       [-0.03339144],
       [ 0.07978038],
       [ 0.19300272]])

In [8]:
# Fraction of samples whose predicted category equals the true category.
df_results['is_corr'].mean()

0.856

In [9]:
# Number of correctly classified samples for each true category.
df_results.groupby('true_categ')['is_corr'].agg('sum')

true_categ
0    100.0
1     49.0
2     84.0
3    100.0
4     95.0
Name: is_corr, dtype: float64

In [10]:
# Average softmax score given to the true class, per true category.
df_results.groupby('true_categ')['score'].agg('mean')

true_categ
0    0.375861
1    0.276465
2    0.311114
3    0.397800
4    0.380774
Name: score, dtype: float64

In [11]:
# Preview the first 10 per-sample evaluation records.
df_results.head(10)

Unnamed: 0,is_corr,pred_categ,score,true_categ
0,True,2,0.286398,2
1,True,2,0.334758,2
2,True,3,0.379054,3
3,False,4,0.270017,1
4,True,4,0.223024,4
5,True,1,0.338156,1
6,True,2,0.354491,2
7,True,3,0.42423,3
8,True,2,0.277167,2
9,True,2,0.249129,2


In [12]:
# For each feedforward layer, paired with the feedback layers taken in
# reverse order, print whether the transposed feedforward weights are
# numerically close to the feedback weights.
fb_layers_reversed = qagrel_network.fb_layers[::-1]
for ff_lay, fb_lay in zip(qagrel_network.layers, fb_layers_reversed):
    weights_match = np.allclose(ff_lay.input_weights.T, fb_lay.input_weights)
    print(weights_match)

False
False
False


In [13]:
# Display the full input-weight array of the network's last layer.
qagrel_network.layers[-1].input_weights

array([[-3.37306474e-01, -7.36077091e-02, -3.71071572e-02,
        -6.41407044e-02, -1.23745432e-01],
       [ 1.12547570e-02,  4.30995770e-01, -4.80011181e-02,
        -1.41053095e-01, -2.13090786e-01],
       [ 2.77345746e-01, -1.83843439e-01, -2.40125999e-01,
        -4.25848826e-01,  1.57592866e-01],
       [-1.15511323e-01,  1.24262560e-01,  5.48002527e-01,
        -2.40706472e-01,  6.06197327e-03],
       [-1.36534863e-01,  1.11928864e-01,  3.68969750e-01,
         6.56637538e-05,  1.85196813e-02],
       [ 3.64977463e-01, -3.56112408e-01,  3.70349998e-02,
         2.63525996e-01, -9.77041261e-02],
       [-3.57762144e-02,  1.98405666e-01,  1.36010818e-01,
        -1.79270802e-01,  2.06705343e-01],
       [ 2.02205896e-01, -2.23799149e-01,  1.22228147e-01,
         2.02091831e-01, -1.93776743e-01],
       [-8.83909505e-02, -2.45503141e-01, -2.41082331e-01,
         2.93165206e-01, -1.16693195e-01],
       [-1.21960053e-01, -2.17394177e-01,  7.03981329e-02,
        -4.22151674e-01