### Imports

In [85]:
# Scinet modules
from scinet import *

# My custom modules
import scinet.ed_stokes as sto

#Other modules
import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
import tqdm

%load_ext tensorboard
import tensorflow as tf
import datetime, os

import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


### Helper Functions (based on oscillator example)

In [111]:
def gen_input_stokes(Fr, St, t_observ, tt_predicted):
    """Build the ``[observation, question, answer]`` list fed to ``net.run``.

    Args:
        Fr, St: parameters forwarded to ``sto.stokes_eqn``.
        t_observ: observation times forwarded to ``sto.stokes_eqn``.
        tt_predicted: iterable of question times; one input row is produced
            per entry.

    Returns:
        ``[in1, in2, out]`` where ``in1`` stacks the same observation once per
        question time, ``in2`` is ``tt_predicted`` reshaped to a column, and
        ``out`` is the very same array object as ``in2`` (a placeholder for
        the answer slot).
    """
    # The observation does not depend on the question time, so evaluate it
    # once instead of once per entry of tt_predicted.
    # NOTE(review): assumes sto.stokes_eqn is deterministic -- confirm.
    observation = sto.stokes_eqn(Fr, St, t_observ)
    in1 = np.array([observation for _ in tt_predicted])
    in2 = np.reshape(tt_predicted, (-1, 1))
    out = in2  # dummy filler
    return [in1, in2, out]

def stokes_representation_plot(net, Fr_range, St_range, t_observ, step_num=100, eval_time=1.5):
    """Plot every latent activation of ``net`` as a surface over the (Fr, St) plane.

    Args:
        net: network exposing ``run``, ``mu`` and ``latent_size``.
        Fr_range, St_range: ``(start, stop)`` pairs unpacked into ``np.linspace``.
        t_observ: observation times forwarded to ``gen_input_stokes``.
        step_num: number of grid points along each axis.
        eval_time: single question time used for every grid point.

    Returns:
        The matplotlib figure, with one 3D subplot per latent neuron.
    """
    Fr_vec = np.linspace(*Fr_range, num=step_num)
    St_vec = np.linspace(*St_range, num=step_num)
    FR, ST = np.meshgrid(Fr_vec, St_vec)

    # One question per grid point; gen_input_stokes expects an iterable of times.
    question_time = np.array([eval_time])
    out = np.array([
        net.run(gen_input_stokes(Fr, St, t_observ, question_time), net.mu)[0]
        for Fr, St in zip(np.ravel(FR), np.ravel(ST))
    ])

    fig = plt.figure(figsize=(net.latent_size * 3.9, 2.1))
    for i in range(net.latent_size):
        # Use the (nrows, ncols, index) form: a formatted digit string such as
        # '141' is rejected by recent Matplotlib and is ambiguous once
        # latent_size reaches 10.
        ax = fig.add_subplot(1, net.latent_size, i + 1, projection='3d')
        Z = np.reshape(out[:, i], FR.shape)
        ax.plot_surface(FR, ST, Z, rstride=1, cstride=1, cmap=cm.inferno, linewidth=0)
        ax.set_xlabel(r'$Fr$')
        ax.set_ylabel(r'$St$')
        ax.set_zlabel('Latent activation {}'.format(i + 1))
        # Pin the z-range on the first and fourth panels (indices 0 and 3).
        # NOTE(review): presumably these are the latents whose activation
        # stays close to zero -- confirm for the trained network.
        if i in (0, 3):
            ax.set_zlim(-1, 1)
        ax.set_zticks([-1, -.5, 0, .5, 1])
    fig.tight_layout()
    return fig

def plot_stokes_prediction(net_, Fr_, St_, t_observ, t_predict):
    """Plot the true trajectory against the network prediction for one (Fr, St).

    Args:
        net_: network exposing ``run`` and ``output``.
        Fr_, St_: parameters forwarded to ``sto.stokes_eqn`` and
            ``gen_input_stokes``.
        t_observ: observation times given to the network.
        t_predict: times at which the trajectory is evaluated and predicted.

    Returns:
        The matplotlib figure containing both curves.
    """
    # Use the function's own arguments. The earlier version read the notebook
    # globals `Fr`, `St` and `net`, so the values passed in were ignored.
    x_correct = sto.stokes_eqn(Fr_, St_, t_predict)
    x_predict = net_.run(gen_input_stokes(Fr_, St_, t_observ, t_predict), net_.output).ravel()

    fig = plt.figure(figsize=(7, 4))
    ax = fig.add_subplot(111)
    # orange_color / blue_color are notebook-level constants.
    ax.plot(t_predict, x_correct, color=orange_color, label='True time evolution')
    ax.plot(t_predict, x_predict, '--', color=blue_color, label='Predicted time evolution')
    ax.set_xlabel(r'$t$ [$s$]')
    ax.set_ylabel(r'$x$ [$m$]')
    handles, labels = ax.get_legend_handles_labels()
    # Legend is anchored partly above the axes so it does not cover the curves.
    ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(0.6, 1.3), shadow=True, ncol=1)
    fig.tight_layout()
    return fig

In [87]:
# Line colours used by plot_stokes_prediction (predicted curve, true curve)
blue_color, orange_color = '#000cff', '#ff7700'

### Input Variables

In [88]:
# --- Network dimensions (passed to nn.Network) ---
netName = 'StokesNet'
observation_size = 50   # also the number of samples in t_sim
latent_size = 4
question_size = 1
answer_size = 1

# --- Dataset size ---
dev_percent   = 5       # percentage passed to dl.load
num_examples  = 200000
# A count of examples, so keep it an integer rather than 10000.0
test_examples = int(num_examples * dev_percent / 100)

# --- Architecture and optimisation ---
encoder_layout = [500,100]
decoder_layout = [100,100]
myBeta = 1e-3           # weight of the KL term in the reported total loss
# NOTE: batch_size and learning_rate are reassigned per phase by the training loop
batch_size = 512
learning_rate = 1e-3

# --- Time axes ---
t_sim = np.linspace(0,1,observation_size)   # observation times
t_q   = 2.0                                 # upper end of the prediction time range

### Data creation and loading

In [89]:
# Generate num_examples samples on the observation grid t_sim.
# NOTE(review): presumably written to disk under the name 'stokes_example',
# since the next cell loads a dataset by that name -- confirm in scinet.ed_stokes.
# The trailing ';' suppresses the cell's output.
sto.stokes_data(num_examples, t_sample=t_sim, fileName='stokes_example');


In [90]:
# Load the 'stokes_example' dataset; dev_percent is passed as the split percentage.
# td and vd are the training and dev sets used by net.train / net.run below.
# NOTE(review): ts, vs and proj are not used in this notebook -- presumably the
# matching hidden states and a projection; confirm against dl.load.
td, vd, ts, vs, proj = dl.load(dev_percent, 'stokes_example')

### Create and train neural network

In [99]:
# Instantiate the network from the sizes and layer layouts set in the
# "Input Variables" cell
net = nn.Network(
    observation_size,
    latent_size,
    question_size,
    answer_size,
    encoder_num_units=encoder_layout,
    decoder_num_units=decoder_layout,
    name=netName,
)

In [100]:
# Report the untrained network's dev-set losses (they depend on the
# initialization): reconstruction loss first, then KL loss
for initial_loss in (net.recon_loss, net.kl_loss):
    print(net.run(vd, initial_loss))

3.7854733
8.892088


In [101]:
# Per-epoch loss histories, filled by the training loop
train_losses = []
dev_losses = []

# Training program
# procedure summary: 1000 epochs with alpha 1e-3, batch 512; 500 epochs with alpha 1e-4 batch 1024, 500 epochs with alpha 1e-5 batch 1024
all_epochs         = [1000, 500,  500 ]
all_batches        = [512,  1024, 1024]
all_learning_rates = [1e-3, 1e-4, 1e-5]

# Derive the phase count from the schedule so the two cannot drift apart,
# and fail early if the three lists are not the same length.
num_phases = len(all_epochs)
assert len(all_batches) == num_phases and len(all_learning_rates) == num_phases, \
    "all_epochs, all_batches and all_learning_rates need one entry per phase"

In [102]:
# Train
# tqdm.tqdm_notebook is deprecated (see the warning it prints); use the
# notebook progress bar class directly.
from tqdm.notebook import tqdm as notebook_tqdm

print_frequency = 0.1  # fraction of each phase between progress printouts

for j in notebook_tqdm(range(num_phases)):
    num_epochs = all_epochs[j]
    batch_size = all_batches[j]
    learning_rate = all_learning_rates[j]
    # At least 1, so `i % check_epochs` cannot divide by zero on phases
    # shorter than 1 / print_frequency epochs.
    check_epochs = max(1, int(print_frequency * num_epochs))

    for i in notebook_tqdm(range(num_epochs)):
        # One epoch per call so the losses can be recorded after every epoch
        net.train(1, batch_size, learning_rate, td, vd, beta_fun=(lambda x: myBeta), test_step=10 )

        # Check progress. It is recommended to use Tensorboard instead for this.
        train_recon_error = net.run(td, net.recon_loss)
        train_kl_loss     = net.run(td, net.kl_loss)
        train_loss        = train_recon_error + myBeta*train_kl_loss

        dev_recon_error   = net.run(vd, net.recon_loss)
        dev_kl_loss       = net.run(vd, net.kl_loss)
        dev_loss          = dev_recon_error + myBeta*dev_kl_loss

        train_losses.append(train_loss)
        dev_losses.append(dev_loss)

        if i%check_epochs == 0:
            print("Training: (loss, reconstruction error, kl loss): ({:.2e}, {:.2e}, {:.2e})".format(
            train_loss, train_recon_error, train_kl_loss))
            print("Dev:      (loss, reconstruction error, kl loss): ({:.2e}, {:.2e}, {:.2e})".format(
            dev_loss, dev_recon_error, dev_kl_loss))
            print("=======================================")

    print("{} epochs trained so far".format(net.tot_epochs) )

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


  0%|          | 0/3 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  # Remove the CWD from sys.path while we load stuff.


  0%|          | 0/1000 [00:00<?, ?it/s]

Training: (loss, reconstruction error, kl loss): (5.76e-02, 3.75e-02, 2.01e+01)
Dev:      (loss, reconstruction error, kl loss): (5.78e-02, 3.77e-02, 2.02e+01)
Training: (loss, reconstruction error, kl loss): (4.79e-03, 7.42e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.82e-03, 7.67e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.93e-03, 8.66e-04, 4.06e+00)
Dev:      (loss, reconstruction error, kl loss): (4.95e-03, 8.75e-04, 4.07e+00)
Training: (loss, reconstruction error, kl loss): (4.58e-03, 4.49e-04, 4.14e+00)
Dev:      (loss, reconstruction error, kl loss): (4.61e-03, 4.67e-04, 4.15e+00)
Training: (loss, reconstruction error, kl loss): (4.27e-03, 1.68e-04, 4.10e+00)
Dev:      (loss, reconstruction error, kl loss): (4.28e-03, 1.73e-04, 4.11e+00)
Training: (loss, reconstruction error, kl loss): (4.19e-03, 1.26e-04, 4.07e+00)
Dev:      (loss, reconstruction error, kl loss): (4.21e-03, 1.34e-04, 4.07e+00)
Training: (loss, reconstruction error, k

  0%|          | 0/500 [00:00<?, ?it/s]

Training: (loss, reconstruction error, kl loss): (4.15e-03, 9.71e-05, 4.06e+00)
Dev:      (loss, reconstruction error, kl loss): (4.17e-03, 1.03e-04, 4.06e+00)
Training: (loss, reconstruction error, kl loss): (4.16e-03, 1.06e-04, 4.05e+00)
Dev:      (loss, reconstruction error, kl loss): (4.17e-03, 1.11e-04, 4.06e+00)
Training: (loss, reconstruction error, kl loss): (4.16e-03, 1.12e-04, 4.05e+00)
Dev:      (loss, reconstruction error, kl loss): (4.18e-03, 1.19e-04, 4.06e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.16e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.17e-03, 1.24e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.16e-03, 1.11e-04, 4.05e+00)
Dev:      (loss, reconstruction error, kl loss): (4.18e-03, 1.17e-04, 4.06e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.05e-04, 4.05e+00)
Dev:      (loss, reconstruction error, kl loss): (4.17e-03, 1.11e-04, 4.06e+00)
Training: (loss, reconstruction error, k

  0%|          | 0/500 [00:00<?, ?it/s]

Training: (loss, reconstruction error, kl loss): (4.14e-03, 1.02e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.16e-03, 1.10e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.03e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.16e-03, 1.09e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.01e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.16e-03, 1.08e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.05e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.16e-03, 1.11e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.14e-03, 1.00e-04, 4.04e+00)
Dev:      (loss, reconstruction error, kl loss): (4.16e-03, 1.08e-04, 4.05e+00)
Training: (loss, reconstruction error, kl loss): (4.15e-03, 1.03e-04, 4.05e+00)
Dev:      (loss, reconstruction error, kl loss): (4.17e-03, 1.10e-04, 4.06e+00)
Training: (loss, reconstruction error, k

In [103]:
# Plot losses
%matplotlib tk
plt.plot(np.array(train_losses), 'b-')
plt.plot(np.array(dev_losses), 'r-')
plt.xlabel('# Epochs')
plt.legend(['training loss','dev loss'])
plt.yscale('log')

In [106]:
# Plot prediction
%matplotlib tk
Fr = 1.0
St = 1e-2
t_predict = np.linspace(0, t_q, 250)
plot_stokes_prediction(net, Fr, St, t_sim, t_predict);

In [112]:
# Plot each latent activation over Fr in [0, 5] and St in [0.1, 5],
# evaluated at question time 1.5. The trailing ';' suppresses the figure repr.
stokes_representation_plot(net, [0.0,5.0], [0.1,5.0], t_sim, eval_time=1.5);

In [109]:
# Build the save name: fixed prefix plus today's date in ISO format (YYYY-MM-DD)
name_str  = 'stokesNet_4latent_'
date_str  = datetime.date.today().isoformat()
filename  = '{}{}'.format(name_str, date_str)
# `io` here is the module pulled in by `from scinet import *`
full_path = io.tf_save_path + filename

In [110]:
# Save the network and the loss histories unless a '.pkl' with this name already exists
if not os.path.isfile(full_path + '.pkl'):
    net.save(filename)

    # Both histories are written back to back into one .npy file, training
    # first, so reading them back takes two successive np.load calls on the
    # same open file.
    with open(full_path + '.npy', 'wb') as history_file:
        for history in (train_losses, dev_losses):
            np.save(history_file, np.array(history))
else:
    print("Filename already exists. Please choose another name.")



Saved network to file stokesNet_4latent_2021-03-14
