# Training with Open Loop Oscillator

#### 0. IPython magic and imports

In [1]:
%load_ext autoreload

import sys    
sys.path.append('/home/gurbain/hyq_ml')

import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

#### 1. Declare the network structure and the training parameters

In [2]:
%autoreload
import network
import utils

# Readout layer and backprop algo
nn_layers       = [[('relu', 40)],  # [('osc',)],                  
                   [('relu', 40)],
                   [('relu', 40)]
                  ]

batch_size      = 2048
max_epochs      = 200
stop_patience   = 150
regularization  = 0.001
metric          = "mae"
optimizer       = "adam"

# ESN
n_res           = 40
n_read          = 30
damping         = 0.1
sparsity        = 0.2
spectral_radius = 0.95
noise           = 0.001
use_real_fb     = False
in_mask         = [True]  # None  # No output is injected in the ESN
out_mask        = [True] * 24  # None  # All readouts ouputs are fed back in the ESN

# Other
data_file       = "/home/gurbain/hyq_ml/data/sims/tc.pkl"
save_folder     = "/home/gurbain/hyq_ml/data/nn_learning/" + utils.timestamp()
verbose         = -3
utils.mkdir(save_folder)

Using TensorFlow backend.


#### 2. Instantiate the network

In [3]:
# Build the network from the configuration cell. Arguments are grouped by role.
# NOTE(review): `stop_patience` is defined in the configuration cell but is not
# forwarded here; confirm whether network.NN should receive it.
nn = network.NN(
    # Readout layers and backprop training
    nn_layers=nn_layers,
    optim=optimizer,
    metric=metric,
    batch_size=batch_size,
    max_epochs=max_epochs,
    regularization=regularization,
    # Echo state network
    esn_n_res=n_res,
    esn_n_read=n_read,
    esn_in_mask=in_mask,
    esn_out_mask=out_mask,
    esn_real_fb=use_real_fb,
    esn_spec_rad=spectral_radius,
    esn_damping=damping,
    esn_sparsity=sparsity,
    esn_noise=noise,
    # Data, output folder and run control
    data_file=data_file,
    save_folder=save_folder,
    checkpoint=False,
    verbose=verbose,
    random_state=12,  # presumably the RNG seed; consider moving it to the config cell
)

#### 3. Train it!

In [4]:
# Train the network; evaluation is not requested here (evaluate=False).
# NOTE(review): metric is "mae", so `acc` presumably holds that metric rather
# than an accuracy; `win` semantics are defined by network.NN.train -- confirm.
loss, acc = nn.train(
    evaluate=False,
    # Training-phase plots
    plot_train_states=True,
    plot_train=True,
    plot_hist=True,
    # Test-phase plots
    plot_test_states=True,
    plot_test=True,
    win=5000,
)




  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 1/200 [00:02<08:21,  2.52s/it]

  1%|          | 2/200 [00:03<06:45,  2.05s/it]

  2%|▏         | 4/200 [00:03<04:45,  1.46s/it]

  3%|▎         | 6/200 [00:03<03:22,  1.05s/it]

  4%|▎         | 7/200 [00:04<02:50,  1.13it/s]

  4%|▍         | 9/200 [00:04<02:02,  1.56it/s]

  6%|▌         | 11/200 [00:04<01:29,  2.12it/s]

  6%|▌         | 12/200 [00:05<01:38,  1.91it/s]

  7%|▋         | 14/200 [00:05<01:12,  2.58it/s]

  8%|▊         | 16/200 [00:05<00:54,  3.40it/s]

  9%|▉         | 18/200 [00:06<00:53,  3.41it/s]

 10%|█         | 20/200 [00:06<00:40,  4.39it/s]

 11%|█         | 22/200 [00:06<00:44,  4.02it/s]

 12%|█▏        | 24/200 [00:07<00:34,  5.07it/s]

 13%|█▎        | 26/200 [00:07<00:28,  6.19it/s]

 14%|█▍        | 28/200 [00:07<00:34,  4.92it/s]

 15%|█▌        | 30/200 [00:07<00:28,  6.06it/s]

 16%|█▌        | 32/200 [00:08<00:34,  4.90it/s]

 17%|█▋        | 34/200 [00:08<00:27,  6.05it/s]

 18%|█▊        | 36/200 [00:08<00:22,  7.24it/s]

 19%|█▉        | 38/200 [00:09<00:30,  5.36it/s]

 20%|██        | 40/200 [00:09<00:24,  6.53it/s]

 21%|██        | 42/200 [00:10<00:30,  5.11it/s]

 22%|██▏       | 44/200 [00:10<00:25,  6.24it/s]

 23%|██▎       | 46/200 [00:10<00:20,  7.38it/s]

 24%|██▍       | 48/200 [00:11<00:28,  5.30it/s]

 25%|██▌       | 50/200 [00:11<00:23,  6.43it/s]

 26%|██▌       | 52/200 [00:11<00:29,  5.09it/s]

 27%|██▋       | 54/200 [00:11<00:23,  6.18it/s]

 28%|██▊       | 56/200 [00:12<00:19,  7.34it/s]

 29%|██▉       | 58/200 [00:12<00:26,  5.39it/s]

 30%|███       | 60/200 [00:12<00:21,  6.57it/s]

 31%|███       | 62/200 [00:13<00:26,  5.14it/s]

 32%|███▏      | 64/200 [00:13<00:21,  6.23it/s]

 33%|███▎      | 66/200 [00:13<00:18,  7.41it/s]

 34%|███▍      | 68/200 [00:14<00:23,  5.53it/s]

 35%|███▌      | 70/200 [00:14<00:19,  6.71it/s]

 36%|███▌      | 72/200 [00:15<00:24,  5.20it/s]

 37%|███▋      | 74/200 [00:15<00:19,  6.35it/s]

 38%|███▊      | 76/200 [00:15<00:16,  7.58it/s]

 39%|███▉      | 78/200 [00:15<00:21,  5.55it/s]

 40%|████      | 80/200 [00:16<00:17,  6.75it/s]

 41%|████      | 82/200 [00:16<00:22,  5.29it/s]

 42%|████▏     | 84/200 [00:16<00:17,  6.45it/s]

 43%|████▎     | 86/200 [00:17<00:15,  7.59it/s]

 44%|████▍     | 88/200 [00:17<00:20,  5.58it/s]

 45%|████▌     | 90/200 [00:17<00:16,  6.72it/s]

 46%|████▌     | 92/200 [00:18<00:20,  5.24it/s]

 47%|████▋     | 94/200 [00:18<00:16,  6.41it/s]

 48%|████▊     | 96/200 [00:18<00:13,  7.61it/s]

 49%|████▉     | 98/200 [00:19<00:18,  5.61it/s]

 50%|█████     | 100/200 [00:19<00:14,  6.82it/s]

 51%|█████     | 102/200 [00:19<00:18,  5.31it/s]

 52%|█████▏    | 104/200 [00:20<00:14,  6.45it/s]

 53%|█████▎    | 106/200 [00:20<00:12,  7.63it/s]

 54%|█████▍    | 108/200 [00:20<00:16,  5.63it/s]

 55%|█████▌    | 110/200 [00:20<00:13,  6.82it/s]

 56%|█████▌    | 112/200 [00:21<00:16,  5.26it/s]

 57%|█████▋    | 114/200 [00:21<00:13,  6.39it/s]

 58%|█████▊    | 116/200 [00:21<00:11,  7.60it/s]

 59%|█████▉    | 118/200 [00:22<00:14,  5.57it/s]

 60%|██████    | 120/200 [00:22<00:11,  6.74it/s]

 61%|██████    | 122/200 [00:23<00:14,  5.20it/s]

 62%|██████▏   | 124/200 [00:23<00:12,  6.32it/s]

 63%|██████▎   | 126/200 [00:23<00:09,  7.49it/s]

 64%|██████▍   | 128/200 [00:24<00:13,  5.43it/s]

 65%|██████▌   | 130/200 [00:24<00:10,  6.59it/s]

 66%|██████▌   | 132/200 [00:24<00:13,  5.07it/s]

 67%|██████▋   | 134/200 [00:24<00:10,  6.23it/s]

 68%|██████▊   | 136/200 [00:25<00:08,  7.33it/s]

 69%|██████▉   | 138/200 [00:25<00:11,  5.53it/s]

 70%|███████   | 140/200 [00:25<00:08,  6.70it/s]

 71%|███████   | 142/200 [00:26<00:11,  5.26it/s]

 72%|███████▏  | 144/200 [00:26<00:08,  6.40it/s]

 73%|███████▎  | 146/200 [00:26<00:07,  7.57it/s]

 74%|███████▍  | 148/200 [00:27<00:09,  5.61it/s]

 75%|███████▌  | 150/200 [00:27<00:07,  6.78it/s]

 76%|███████▌  | 152/200 [00:28<00:09,  5.29it/s]

 77%|███████▋  | 154/200 [00:28<00:07,  6.46it/s]

 78%|███████▊  | 156/200 [00:28<00:05,  7.66it/s]

 79%|███████▉  | 158/200 [00:28<00:07,  5.65it/s]

 80%|████████  | 160/200 [00:29<00:05,  6.85it/s]

 81%|████████  | 162/200 [00:29<00:07,  5.33it/s]

 82%|████████▏ | 164/200 [00:29<00:05,  6.47it/s]

 83%|████████▎ | 166/200 [00:29<00:04,  7.63it/s]

 84%|████████▍ | 168/200 [00:30<00:05,  5.59it/s]

 85%|████████▌ | 170/200 [00:30<00:04,  6.78it/s]

 86%|████████▌ | 172/200 [00:31<00:05,  5.24it/s]

 87%|████████▋ | 174/200 [00:31<00:04,  6.43it/s]

 88%|████████▊ | 176/200 [00:31<00:03,  7.61it/s]

 89%|████████▉ | 178/200 [00:32<00:03,  5.59it/s]

 90%|█████████ | 180/200 [00:32<00:02,  6.80it/s]

 91%|█████████ | 182/200 [00:32<00:03,  5.19it/s]

 92%|█████████▏| 184/200 [00:33<00:02,  6.34it/s]

 93%|█████████▎| 186/200 [00:33<00:01,  7.43it/s]

 94%|█████████▍| 188/200 [00:33<00:02,  5.54it/s]

 95%|█████████▌| 190/200 [00:33<00:01,  6.72it/s]

 96%|█████████▌| 192/200 [00:34<00:01,  5.23it/s]

 97%|█████████▋| 194/200 [00:34<00:00,  6.42it/s]

 98%|█████████▊| 196/200 [00:34<00:00,  7.58it/s]

 99%|█████████▉| 198/200 [00:35<00:00,  5.61it/s]

100%|██████████| 200/200 [00:35<00:00,  6.79it/s]

In [5]:
# Evaluate the trained network, plotting its states and test predictions.
y_truth, y_pred, score = nn.evaluate(
    plot_states=True,
    plot_test=True,
)

In [7]:
# Close every open matplotlib figure window.
plt.close(fig='all')