# Imports

In [None]:
import numpy as np
import src.datasets.load_TI46 as ti46
import src.echo_state_network as esn
import src.online_training as online
import src.networks as nets

# Load

In [None]:
# Load TI46: training inputs/targets, test inputs/targets, and the class count.
x_train, y_train, x_test, y_test, N_LABELS = ti46.load_and_process_ti46()

# Integer label for each target row = index of that row's largest entry.
# NOTE(review): presumably the targets are one-hot rows -- confirm against
# load_and_process_ti46. These label lists are not used by the cells shown here.
y_train_reals = list(map(lambda row: int(np.argmax(row)), y_train.tolist()))
y_test_reals = list(map(lambda row: int(np.argmax(row)), y_test.tolist()))

# Transform using ESN

In [None]:
# Build the echo state network reservoir and map both datasets through it.
N_RESERVOIR = 200  # reservoir size, passed to the ESN as nRes
reservoir = esn.ESN(
    nIn=x_train[0].shape[-1],  # length of the last axis of the first training sample
    nRes=N_RESERVOIR,
    # further ESN keyword arguments can be supplied here
)

# Transform the training split first, then the test split, with the same reservoir.
z_train, z_test = (reservoir.transform(split) for split in (x_train, x_test))

# Prepare Training, Validation, and Test Data

In [None]:
# Turn the reservoir outputs and targets into the arrays used for online training.
# NOTE(review): the roles below are inferred from the variable names only --
# confirm against online.Classifier.prepare_training_data.
(z_train_flat, y_train_1h,               # training inputs / targets
 z_valid_flat, y_valid_1h,               # validation inputs / targets
 z_test_flat, y_test_labels,             # test inputs / labels
 z_split_nested_train, y_split_train,    # training split in nested form
 ) = online.Classifier.prepare_training_data(z_train, y_train, z_test, y_test)

# Train Online Classifier

In [None]:
# Wrap a two-layer MLP (N_RESERVOIR inputs -> N_LABELS outputs) in the online classifier.
model = online.Classifier(
    Nin=N_RESERVOIR,
    Nout=N_LABELS,
    model=nets.TwoLayerMlp(N_RESERVOIR, N_LABELS),
)

# Run a training session on the training arrays, also passing the validation arrays.
train_history = model.training_session(
    z_train_flat, y_train_1h,  # training inputs / targets
    z_valid_flat, y_valid_1h,  # validation inputs / targets
    lr=1e-3,
    l2_lambda=0,  # presumably disables L2 regularisation -- confirm in training_session
)

# Evaluate Model (plot the training history)

In [None]:
# Plot the history returned by model.training_session.
# NOTE(review): z_test_flat / y_test_labels are never evaluated in the cells shown;
# this cell only visualises training.
online.plot_results_array(train_history)