# Train using only roll, pitch, yaw and efforts (kadj module activated)

#### 0. IPython magic and imports

In [1]:
%load_ext autoreload
%autoreload

import sys
sys.path.append('/home/gurbain/hyq_ml')

import matplotlib
matplotlib.use('Qt5Agg')

#### 1. Declare the network structure and the training parameters

In [2]:
import os

import network
import utils

# Repository root; override with the HYQ_ML_ROOT environment variable when the
# checkout lives elsewhere (the default reproduces the original paths exactly).
hyq_ml_root = os.environ.get("HYQ_ML_ROOT", "/home/gurbain/hyq_ml")

# Readout layer and backprop algo
nn_layers       = [[('osc',)],
                   [('relu', 40)],
                   [('relu', 40)],
                   [('relu', 40)]
                  ]
batch_size      = 2048
max_epochs      = 200
# NOTE(review): stop_patience is not passed to network.NN when the network is
# instantiated, so it currently has no effect. Confirm the NN keyword name
# and wire it in, or drop it.
stop_patience   = 150
regularization  = 0.001
metric          = "mae"
optimizer       = "adam"

# ESN
n_res           = 60
n_read          = 30
damping         = 0.1
sparsity        = 0.2
spectral_radius = 0.95
noise           = 0.001
use_real_fb     = False
# NOTE(review): the True/False polarity of both masks is defined by
# network.NN. Confirm that the comments below match that convention.
in_mask         = [False, True, True, True, True, True]  # No output is injected in the ESN
out_mask        = [False] * 24  # All readout outputs are fed back in the ESN

# Other
# The simulation data is a pickle, presumably unpickled by network.NN.
# Only point this at trusted files.
data_file       = os.path.join(hyq_ml_root, "data", "sims", "tc_kadj.pkl")
# Timestamped output directory. The trailing "" keeps the separator before the
# timestamp, as in the original string concatenation.
save_folder     = os.path.join(hyq_ml_root, "data", "nn_learning", "") + utils.timestamp()
verbose         = -3  # verbosity level; its meaning is defined by network.NN
utils.mkdir(save_folder)

Using TensorFlow backend.


#### 2. Instantiate the network

In [3]:
# Build the network from the configuration cell, grouping the settings by
# concern: readout/training, echo state network, then data and run options.
nn = network.NN(
    # Readout layers and training
    nn_layers=nn_layers,
    optim=optimizer,
    metric=metric,
    batch_size=batch_size,
    max_epochs=max_epochs,
    regularization=regularization,
    # Echo state network
    esn_n_res=n_res,
    esn_n_read=n_read,
    esn_spec_rad=spectral_radius,
    esn_damping=damping,
    esn_sparsity=sparsity,
    esn_noise=noise,
    esn_in_mask=in_mask,
    esn_out_mask=out_mask,
    esn_real_fb=use_real_fb,
    # Data, output location and run options
    data_file=data_file,
    save_folder=save_folder,
    checkpoint=False,
    verbose=verbose,
    random_state=12,  # fixed value, presumably seeding NN's randomness
)

#### 3. Train it!

In [4]:
# Train without the built-in evaluation (nn.evaluate is called in its own cell)
# and plot states, predictions and history. The unit of `win` is defined by
# NN.train.
loss, acc = nn.train(
    evaluate=False,
    plot_train_states=True,
    plot_train=True,
    plot_hist=True,
    plot_test_states=True,
    plot_test=True,
    win=5000,
)




  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 1/200 [00:02<09:09,  2.76s/it]

  1%|          | 2/200 [00:03<07:17,  2.21s/it]

  2%|▏         | 4/200 [00:03<05:09,  1.58s/it]

  2%|▎         | 5/200 [00:03<03:41,  1.13s/it]

  3%|▎         | 6/200 [00:04<02:39,  1.21it/s]

  4%|▎         | 7/200 [00:04<02:21,  1.36it/s]

  4%|▍         | 9/200 [00:04<01:43,  1.84it/s]

  5%|▌         | 10/200 [00:04<01:18,  2.43it/s]

  6%|▌         | 11/200 [00:05<01:00,  3.14it/s]

  6%|▌         | 12/200 [00:05<01:18,  2.40it/s]

  6%|▋         | 13/200 [00:05<01:00,  3.09it/s]

  7%|▋         | 14/200 [00:05<00:47,  3.88it/s]

  8%|▊         | 16/200 [00:06<00:38,  4.75it/s]

  8%|▊         | 17/200 [00:06<00:55,  3.28it/s]

  9%|▉         | 18/200 [00:06<00:44,  4.10it/s]

 10%|▉         | 19/200 [00:06<00:36,  4.97it/s]

 10%|█         | 20/200 [00:06<00:30,  5.85it/s]

 10%|█         | 21/200 [00:07<00:26,  6.67it/s]

 11%|█         | 22/200 [00:07<00:46,  3.83it/s]

 12%|█▏        | 23/200 [00:07<00:37,  4.69it/s]

 12%|█▏        | 24/200 [00:07<00:31,  5.56it/s]

 12%|█▎        | 25/200 [00:07<00:27,  6.38it/s]

 13%|█▎        | 26/200 [00:07<00:24,  7.13it/s]

 14%|█▎        | 27/200 [00:08<00:44,  3.91it/s]

 14%|█▍        | 29/200 [00:08<00:35,  4.78it/s]

 15%|█▌        | 30/200 [00:08<00:30,  5.62it/s]

 16%|█▌        | 31/200 [00:08<00:26,  6.42it/s]

 16%|█▌        | 32/200 [00:09<00:45,  3.72it/s]

 16%|█▋        | 33/200 [00:09<00:36,  4.57it/s]

 17%|█▋        | 34/200 [00:09<00:30,  5.44it/s]

 18%|█▊        | 35/200 [00:09<00:26,  6.27it/s]

 18%|█▊        | 37/200 [00:10<00:33,  4.89it/s]

 19%|█▉        | 38/200 [00:10<00:28,  5.75it/s]

 20%|█▉        | 39/200 [00:10<00:24,  6.56it/s]

 20%|██        | 41/200 [00:10<00:21,  7.28it/s]

 21%|██        | 42/200 [00:11<00:40,  3.92it/s]

 22%|██▏       | 44/200 [00:11<00:32,  4.77it/s]

 22%|██▎       | 45/200 [00:11<00:27,  5.62it/s]

 23%|██▎       | 46/200 [00:11<00:23,  6.44it/s]

 24%|██▎       | 47/200 [00:12<00:41,  3.70it/s]

 24%|██▍       | 48/200 [00:12<00:33,  4.55it/s]

 24%|██▍       | 49/200 [00:12<00:27,  5.41it/s]

 25%|██▌       | 50/200 [00:12<00:24,  6.24it/s]

 26%|██▌       | 51/200 [00:12<00:21,  7.01it/s]

 26%|██▌       | 52/200 [00:13<00:38,  3.87it/s]

 26%|██▋       | 53/200 [00:13<00:31,  4.73it/s]

 27%|██▋       | 54/200 [00:13<00:26,  5.59it/s]

 28%|██▊       | 55/200 [00:13<00:22,  6.44it/s]

 28%|██▊       | 56/200 [00:13<00:19,  7.21it/s]

 28%|██▊       | 57/200 [00:14<00:36,  3.88it/s]

 29%|██▉       | 58/200 [00:14<00:30,  4.67it/s]

 30%|██▉       | 59/200 [00:14<00:25,  5.52it/s]

 30%|███       | 60/200 [00:14<00:21,  6.37it/s]

 30%|███       | 61/200 [00:14<00:19,  7.13it/s]

 31%|███       | 62/200 [00:15<00:34,  3.95it/s]

 32%|███▏      | 63/200 [00:15<00:28,  4.79it/s]

 32%|███▏      | 64/200 [00:15<00:24,  5.66it/s]

 32%|███▎      | 65/200 [00:15<00:20,  6.49it/s]

 33%|███▎      | 66/200 [00:15<00:18,  7.22it/s]

 34%|███▎      | 67/200 [00:15<00:33,  3.99it/s]

 34%|███▍      | 68/200 [00:16<00:27,  4.84it/s]

 34%|███▍      | 69/200 [00:16<00:22,  5.72it/s]

 36%|███▌      | 71/200 [00:16<00:19,  6.58it/s]

 36%|███▌      | 72/200 [00:16<00:34,  3.76it/s]

 36%|███▋      | 73/200 [00:16<00:27,  4.62it/s]

 37%|███▋      | 74/200 [00:17<00:22,  5.49it/s]

 38%|███▊      | 75/200 [00:17<00:19,  6.33it/s]

 38%|███▊      | 76/200 [00:17<00:17,  7.10it/s]

 38%|███▊      | 77/200 [00:17<00:31,  3.94it/s]

 40%|███▉      | 79/200 [00:18<00:25,  4.82it/s]

 40%|████      | 80/200 [00:18<00:21,  5.61it/s]

 40%|████      | 81/200 [00:18<00:18,  6.45it/s]

 41%|████      | 82/200 [00:18<00:31,  3.77it/s]

 42%|████▏     | 84/200 [00:18<00:25,  4.63it/s]

 42%|████▎     | 85/200 [00:19<00:20,  5.52it/s]

 43%|████▎     | 86/200 [00:19<00:17,  6.36it/s]

 44%|████▎     | 87/200 [00:19<00:30,  3.71it/s]

 44%|████▍     | 88/200 [00:19<00:24,  4.55it/s]

 44%|████▍     | 89/200 [00:19<00:20,  5.43it/s]

 45%|████▌     | 90/200 [00:19<00:17,  6.28it/s]

 46%|████▌     | 91/200 [00:20<00:15,  6.98it/s]

 46%|████▌     | 92/200 [00:20<00:27,  3.94it/s]

 47%|████▋     | 94/200 [00:20<00:21,  4.84it/s]

 48%|████▊     | 95/200 [00:20<00:18,  5.67it/s]

 48%|████▊     | 97/200 [00:21<00:21,  4.68it/s]

 50%|████▉     | 99/200 [00:21<00:18,  5.57it/s]

 50%|█████     | 101/200 [00:21<00:15,  6.46it/s]

 51%|█████     | 102/200 [00:22<00:25,  3.78it/s]

 52%|█████▏    | 104/200 [00:22<00:20,  4.65it/s]

 52%|█████▎    | 105/200 [00:22<00:17,  5.51it/s]

 53%|█████▎    | 106/200 [00:22<00:14,  6.34it/s]

 54%|█████▎    | 107/200 [00:23<00:24,  3.74it/s]

 54%|█████▍    | 108/200 [00:23<00:20,  4.59it/s]

 55%|█████▍    | 109/200 [00:23<00:16,  5.47it/s]

 55%|█████▌    | 110/200 [00:23<00:14,  6.27it/s]

 56%|█████▌    | 112/200 [00:24<00:18,  4.87it/s]

 56%|█████▋    | 113/200 [00:24<00:15,  5.69it/s]

 57%|█████▊    | 115/200 [00:24<00:12,  6.59it/s]

 58%|█████▊    | 116/200 [00:24<00:11,  7.30it/s]

 58%|█████▊    | 117/200 [00:25<00:21,  3.93it/s]

 59%|█████▉    | 118/200 [00:25<00:17,  4.79it/s]

 60%|█████▉    | 119/200 [00:25<00:14,  5.67it/s]

 60%|██████    | 120/200 [00:25<00:12,  6.49it/s]

 60%|██████    | 121/200 [00:25<00:10,  7.20it/s]

 61%|██████    | 122/200 [00:26<00:20,  3.88it/s]

 62%|██████▏   | 123/200 [00:26<00:16,  4.73it/s]

 62%|██████▏   | 124/200 [00:26<00:13,  5.58it/s]

 62%|██████▎   | 125/200 [00:26<00:11,  6.40it/s]

 63%|██████▎   | 126/200 [00:26<00:10,  7.13it/s]

 64%|██████▎   | 127/200 [00:27<00:18,  3.87it/s]

 64%|██████▍   | 128/200 [00:27<00:15,  4.68it/s]

 65%|██████▌   | 130/200 [00:27<00:12,  5.59it/s]

 66%|██████▌   | 131/200 [00:27<00:10,  6.43it/s]

 66%|██████▌   | 132/200 [00:28<00:18,  3.76it/s]

 66%|██████▋   | 133/200 [00:28<00:14,  4.60it/s]

 67%|██████▋   | 134/200 [00:28<00:12,  5.42it/s]

 68%|██████▊   | 135/200 [00:28<00:10,  6.28it/s]

 68%|██████▊   | 136/200 [00:28<00:09,  7.06it/s]

 68%|██████▊   | 137/200 [00:28<00:16,  3.85it/s]

 69%|██████▉   | 138/200 [00:29<00:13,  4.69it/s]

 70%|██████▉   | 139/200 [00:29<00:10,  5.56it/s]

 70%|███████   | 141/200 [00:29<00:09,  6.42it/s]

 71%|███████   | 142/200 [00:29<00:15,  3.76it/s]

 72%|███████▏  | 144/200 [00:30<00:12,  4.62it/s]

 73%|███████▎  | 146/200 [00:30<00:09,  5.54it/s]

 74%|███████▎  | 147/200 [00:30<00:15,  3.48it/s]

 74%|███████▍  | 149/200 [00:31<00:11,  4.32it/s]

 75%|███████▌  | 150/200 [00:31<00:09,  5.17it/s]

 76%|███████▌  | 151/200 [00:31<00:08,  5.99it/s]

 76%|███████▌  | 152/200 [00:31<00:13,  3.65it/s]

 76%|███████▋  | 153/200 [00:31<00:10,  4.47it/s]

 77%|███████▋  | 154/200 [00:31<00:08,  5.35it/s]

 78%|███████▊  | 156/200 [00:32<00:07,  6.20it/s]

 78%|███████▊  | 157/200 [00:32<00:11,  3.70it/s]

 79%|███████▉  | 158/200 [00:32<00:09,  4.56it/s]

 80%|███████▉  | 159/200 [00:32<00:07,  5.45it/s]

 80%|████████  | 160/200 [00:33<00:06,  6.31it/s]

 80%|████████  | 161/200 [00:33<00:05,  6.98it/s]

 81%|████████  | 162/200 [00:33<00:09,  3.92it/s]

 82%|████████▏ | 163/200 [00:33<00:07,  4.77it/s]

 82%|████████▎ | 165/200 [00:33<00:06,  5.67it/s]

 83%|████████▎ | 166/200 [00:34<00:05,  6.48it/s]

 84%|████████▎ | 167/200 [00:34<00:08,  3.74it/s]

 84%|████████▍ | 169/200 [00:34<00:06,  4.60it/s]

 85%|████████▌ | 170/200 [00:34<00:05,  5.49it/s]

 86%|████████▌ | 171/200 [00:34<00:04,  6.28it/s]

 86%|████████▌ | 172/200 [00:35<00:07,  3.74it/s]

 86%|████████▋ | 173/200 [00:35<00:05,  4.59it/s]

 87%|████████▋ | 174/200 [00:35<00:04,  5.48it/s]

 88%|████████▊ | 175/200 [00:35<00:03,  6.29it/s]

 88%|████████▊ | 176/200 [00:35<00:03,  7.02it/s]

 88%|████████▊ | 177/200 [00:36<00:05,  3.87it/s]

 89%|████████▉ | 178/200 [00:36<00:04,  4.73it/s]

 90%|████████▉ | 179/200 [00:36<00:03,  5.54it/s]

 90%|█████████ | 180/200 [00:36<00:03,  6.38it/s]

 90%|█████████ | 181/200 [00:36<00:02,  7.04it/s]

 91%|█████████ | 182/200 [00:37<00:04,  3.87it/s]

 92%|█████████▏| 184/200 [00:37<00:03,  4.76it/s]

 92%|█████████▎| 185/200 [00:37<00:02,  5.64it/s]

 93%|█████████▎| 186/200 [00:37<00:02,  6.46it/s]

 94%|█████████▎| 187/200 [00:38<00:03,  3.79it/s]

 94%|█████████▍| 188/200 [00:38<00:02,  4.65it/s]

 94%|█████████▍| 189/200 [00:38<00:01,  5.53it/s]

 95%|█████████▌| 190/200 [00:38<00:01,  6.33it/s]

 96%|█████████▌| 191/200 [00:38<00:01,  7.04it/s]

 96%|█████████▌| 192/200 [00:39<00:02,  3.94it/s]

 96%|█████████▋| 193/200 [00:39<00:01,  4.80it/s]

 97%|█████████▋| 194/200 [00:39<00:01,  5.66it/s]

 98%|█████████▊| 195/200 [00:39<00:00,  6.50it/s]

 98%|█████████▊| 196/200 [00:39<00:00,  7.25it/s]

 98%|█████████▊| 197/200 [00:40<00:00,  3.91it/s]

100%|█████████▉| 199/200 [00:40<00:00,  4.79it/s]

In [5]:
# Evaluate the trained network, plotting reservoir states and test predictions.
# The unit of `win` is defined by NN.evaluate.
y_truth, y_pred, score = nn.evaluate(
    plot_states=True,
    plot_test=True,
    win=5000,
)