# Training with an Open-Loop Oscillator

#### 0. IPython magic and imports

In [4]:
%load_ext autoreload

import sys    
sys.path.append('/home/gurbain/hyq_ml')

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

#### 1. Declare the network structure and the training parameters

In [11]:
%autoreload
import network
import utils

# Readout layer and backprop algo
nn_layers       = [[('relu', 40)],  # [('osc',)],                  
                   [('relu', 40)],
                   [('relu', 40)]
                  ]

batch_size      = 2048
max_epochs      = 200
stop_patience   = 150

regularization  = 0.001
metric          = "mae"
optimizer       = "adam"

# ESN
n_res           = 40
n_read          = 30
damping         = 0.1
sparsity        = 0.2
spectral_radius = 0.95
noise           = 0.001
in_mask         = [True]  # None  # No output is injected in the ESN
out_mask        = [True] * 24  # None  # All readouts ouputs are fed back in the ESN

# Other
data_file       = "/home/gurbain/hyq_ml/data/sims/tc.pkl"
save_folder     = "/home/gurbain/hyq_ml/data/nn_learning/" + utils.timestamp()
verbose         = -3
utils.mkdir(save_folder)

#### 2. Instantiate the network

In [12]:
# Build the network from the configuration cell, grouping the settings by purpose.
# NOTE(review): stop_patience from the configuration cell is not passed here.
nn = network.NN(
    # Readout layers and backprop training
    nn_layers=nn_layers,
    optim=optimizer,
    metric=metric,
    regularization=regularization,
    batch_size=batch_size,
    max_epochs=max_epochs,
    # Echo state network (ESN) settings
    esn_n_res=n_res,
    esn_n_read=n_read,
    esn_spec_rad=spectral_radius,
    esn_damping=damping,
    esn_sparsity=sparsity,
    esn_noise=noise,
    esn_in_mask=in_mask,
    esn_out_mask=out_mask,
    # Data, output location and run options
    data_file=data_file,
    save_folder=save_folder,
    checkpoint=False,
    verbose=verbose,
    random_state=12,
)

#### 3. Train it!

In [13]:
# Train the network, plotting the train/test states, predictions and history.
# evaluate=False presumably defers evaluation to the nn.evaluate() cell below.
# NOTE(review): `win` is presumably a plotting window length — confirm in network.NN.
loss, acc = nn.train(
    evaluate=False,
    plot_train_states=True,
    plot_train=True,
    plot_hist=True,
    plot_test_states=True,
    plot_test=True,
    win=5000,
)




  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 1/200 [00:02<09:36,  2.90s/it]

  1%|          | 2/200 [00:04<07:59,  2.42s/it]

  2%|▏         | 4/200 [00:04<05:37,  1.72s/it]

  3%|▎         | 6/200 [00:04<03:58,  1.23s/it]

  4%|▎         | 7/200 [00:05<03:18,  1.03s/it]

  4%|▍         | 8/200 [00:05<02:24,  1.33it/s]

  4%|▍         | 9/200 [00:05<01:47,  1.78it/s]

  6%|▌         | 11/200 [00:05<01:18,  2.39it/s]

  6%|▌         | 12/200 [00:06<01:24,  2.23it/s]

  7%|▋         | 14/200 [00:06<01:02,  2.96it/s]

  8%|▊         | 16/200 [00:06<00:47,  3.85it/s]

  8%|▊         | 17/200 [00:06<01:02,  2.94it/s]

 10%|▉         | 19/200 [00:07<00:47,  3.82it/s]

 10%|█         | 21/200 [00:07<00:36,  4.85it/s]

 12%|█▏        | 23/200 [00:07<00:41,  4.27it/s]

 12%|█▎        | 25/200 [00:07<00:32,  5.35it/s]

 14%|█▎        | 27/200 [00:08<00:37,  4.57it/s]

 14%|█▍        | 29/200 [00:08<00:30,  5.68it/s]

 16%|█▌        | 31/200 [00:08<00:24,  6.86it/s]

 16%|█▋        | 33/200 [00:09<00:31,  5.30it/s]

 18%|█▊        | 35/200 [00:09<00:25,  6.47it/s]

 18%|█▊        | 37/200 [00:10<00:31,  5.14it/s]

 20%|█▉        | 39/200 [00:10<00:25,  6.30it/s]

 20%|██        | 41/200 [00:10<00:21,  7.45it/s]

 22%|██▏       | 43/200 [00:11<00:28,  5.48it/s]

 22%|██▎       | 45/200 [00:11<00:23,  6.69it/s]

 24%|██▎       | 47/200 [00:11<00:29,  5.27it/s]

 24%|██▍       | 49/200 [00:11<00:23,  6.45it/s]

 26%|██▌       | 51/200 [00:12<00:19,  7.63it/s]

 26%|██▋       | 53/200 [00:12<00:26,  5.60it/s]

 28%|██▊       | 55/200 [00:12<00:21,  6.72it/s]

 28%|██▊       | 57/200 [00:13<00:27,  5.24it/s]

 30%|██▉       | 59/200 [00:13<00:22,  6.41it/s]

 30%|███       | 61/200 [00:13<00:18,  7.58it/s]

 32%|███▏      | 63/200 [00:14<00:24,  5.65it/s]

 32%|███▎      | 65/200 [00:14<00:19,  6.79it/s]

 34%|███▎      | 67/200 [00:14<00:25,  5.24it/s]

 34%|███▍      | 69/200 [00:15<00:20,  6.39it/s]

 36%|███▌      | 71/200 [00:15<00:17,  7.59it/s]

 36%|███▋      | 73/200 [00:15<00:22,  5.53it/s]

 38%|███▊      | 75/200 [00:15<00:18,  6.76it/s]

 38%|███▊      | 77/200 [00:16<00:23,  5.23it/s]

 40%|███▉      | 79/200 [00:16<00:18,  6.41it/s]

 40%|████      | 81/200 [00:16<00:15,  7.55it/s]

 42%|████▏     | 83/200 [00:17<00:21,  5.55it/s]

 42%|████▎     | 85/200 [00:17<00:17,  6.74it/s]

 44%|████▎     | 87/200 [00:18<00:21,  5.15it/s]

 44%|████▍     | 89/200 [00:18<00:17,  6.31it/s]

 46%|████▌     | 91/200 [00:18<00:14,  7.48it/s]

 46%|████▋     | 93/200 [00:19<00:19,  5.55it/s]

 48%|████▊     | 95/200 [00:19<00:15,  6.75it/s]

 48%|████▊     | 97/200 [00:19<00:19,  5.28it/s]

 50%|████▉     | 99/200 [00:19<00:15,  6.48it/s]

 50%|█████     | 101/200 [00:20<00:12,  7.65it/s]

 52%|█████▏    | 103/200 [00:20<00:17,  5.58it/s]

 52%|█████▎    | 105/200 [00:20<00:14,  6.76it/s]

 54%|█████▎    | 107/200 [00:21<00:18,  5.13it/s]

 55%|█████▍    | 109/200 [00:21<00:14,  6.29it/s]

 56%|█████▌    | 111/200 [00:21<00:12,  7.41it/s]

 56%|█████▋    | 113/200 [00:22<00:15,  5.50it/s]

 57%|█████▊    | 115/200 [00:22<00:12,  6.65it/s]

 58%|█████▊    | 117/200 [00:23<00:16,  5.14it/s]

 60%|█████▉    | 119/200 [00:23<00:12,  6.32it/s]

 60%|██████    | 121/200 [00:23<00:10,  7.52it/s]

 62%|██████▏   | 123/200 [00:23<00:13,  5.57it/s]

 62%|██████▎   | 125/200 [00:24<00:11,  6.70it/s]

 64%|██████▎   | 127/200 [00:24<00:14,  5.17it/s]

 64%|██████▍   | 129/200 [00:24<00:11,  6.32it/s]

 66%|██████▌   | 131/200 [00:24<00:09,  7.52it/s]

 66%|██████▋   | 133/200 [00:25<00:12,  5.56it/s]

 68%|██████▊   | 135/200 [00:25<00:09,  6.73it/s]

 68%|██████▊   | 137/200 [00:26<00:12,  5.20it/s]

 70%|██████▉   | 139/200 [00:26<00:09,  6.39it/s]

 70%|███████   | 141/200 [00:26<00:07,  7.55it/s]

 72%|███████▏  | 143/200 [00:27<00:10,  5.52it/s]

 72%|███████▎  | 145/200 [00:27<00:08,  6.73it/s]

 74%|███████▎  | 147/200 [00:27<00:10,  5.25it/s]

 74%|███████▍  | 149/200 [00:28<00:07,  6.42it/s]

 76%|███████▌  | 151/200 [00:28<00:06,  7.62it/s]

 76%|███████▋  | 153/200 [00:28<00:08,  5.53it/s]

 78%|███████▊  | 155/200 [00:28<00:06,  6.70it/s]

 78%|███████▊  | 157/200 [00:29<00:08,  5.23it/s]

 80%|███████▉  | 159/200 [00:29<00:06,  6.39it/s]

 80%|████████  | 161/200 [00:29<00:05,  7.59it/s]

 82%|████████▏ | 163/200 [00:30<00:06,  5.59it/s]

 82%|████████▎ | 165/200 [00:30<00:05,  6.78it/s]

 84%|████████▎ | 167/200 [00:31<00:06,  4.75it/s]

 84%|████████▍ | 169/200 [00:31<00:05,  5.87it/s]

 86%|████████▌ | 171/200 [00:31<00:04,  7.05it/s]

 86%|████████▋ | 173/200 [00:32<00:05,  5.34it/s]

 88%|████████▊ | 175/200 [00:32<00:03,  6.54it/s]

 88%|████████▊ | 177/200 [00:32<00:04,  5.14it/s]

 90%|████████▉ | 179/200 [00:33<00:03,  6.28it/s]

 90%|█████████ | 181/200 [00:33<00:02,  7.48it/s]

 92%|█████████▏| 183/200 [00:33<00:03,  5.48it/s]

 92%|█████████▎| 185/200 [00:33<00:02,  6.68it/s]

 94%|█████████▎| 187/200 [00:34<00:02,  5.25it/s]

 94%|█████████▍| 189/200 [00:34<00:01,  6.43it/s]

 96%|█████████▌| 191/200 [00:34<00:01,  7.58it/s]

 96%|█████████▋| 193/200 [00:35<00:01,  5.59it/s]

 98%|█████████▊| 195/200 [00:35<00:00,  6.78it/s]

 98%|█████████▊| 197/200 [00:36<00:00,  5.27it/s]

100%|█████████▉| 199/200 [00:36<00:00,  6.43it/s]

In [14]:
# Evaluate the trained network, plotting the states and the test predictions.
y_truth, y_pred, score = nn.evaluate(
    plot_states=True,
    plot_test=True,
)

In [7]:
# Close every open matplotlib figure (the plotting calls above open GUI windows).
# NOTE(review): this cell's execution count (In [7]) is lower than the training
# cells' — re-run the notebook top to bottom to keep counts sequential.
plt.close('all')