### Imports

In [None]:
# Scinet modules
from scinet import *

# My custom modules
import scinet.ed_stokes as sto

#Other modules
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

### Helper functions (adapted from the oscillator example)

In [51]:
def gen_input_stokes(Fr, St, t_observ, tt_predicted):
    """Build a network input batch that queries one (Fr, St) case at several times.

    The same observation (``sto.stokes_eqn(Fr, St, t_observ)``) is paired with
    every question time in ``tt_predicted``, giving one batch row per question.

    Parameters
    ----------
    Fr, St : float
        Parameters forwarded to ``sto.stokes_eqn``.
    t_observ : array-like
        Sample times forwarded to ``sto.stokes_eqn`` to build the observation.
    tt_predicted : array-like
        Question times; one batch row is produced per entry.

    Returns
    -------
    list
        ``[in1, in2, out]``: ``in1`` stacks one copy of the observation per
        question time, ``in2`` is ``tt_predicted`` reshaped to a column
        ``(-1, 1)``, and ``out`` is the same array as ``in2`` (a dummy filler
        for the third slot; presumably scinet's answer slot -- TODO confirm).
    """
    # The observation does not depend on the question time, so evaluate the
    # Stokes solution once and repeat it for each row instead of recomputing it.
    observation = sto.stokes_eqn(Fr, St, t_observ)
    in1 = np.array([observation for _ in tt_predicted])
    in2 = np.reshape(tt_predicted, (-1, 1))
    out = in2  # dummy filler
    return [in1, in2, out]

def stokes_representation_plot(net, Fr_range, St_range, t_observ, step_num=100, eval_time=1.5):
    """Plot each latent neuron's activation over a grid of (Fr, St) values.

    For every point of a ``step_num`` x ``step_num`` grid spanning ``Fr_range``
    and ``St_range``, the network receives the observation built by
    ``gen_input_stokes`` (sampled at ``t_observ``) with the single question
    time ``eval_time``, and ``net.mu`` is evaluated. One 3D surface per latent
    neuron is drawn side by side.

    Parameters
    ----------
    net : network object exposing ``latent_size``, ``run`` and ``mu``.
    Fr_range, St_range : (start, stop) pairs unpacked into ``np.linspace``.
    t_observ : observation sample times forwarded to ``gen_input_stokes``.
    step_num : number of grid points along each axis.
    eval_time : question time used for every grid point.

    Returns
    -------
    matplotlib.figure.Figure
    """
    Fr_vec = np.linspace(*Fr_range, num=step_num)
    St_vec = np.linspace(*St_range, num=step_num)
    FR, ST = np.meshgrid(Fr_vec, St_vec)

    question = np.array([eval_time])
    # One net.run per grid point. With a single question time the batch has one
    # row, so [0] selects that row of net.mu (presumably the latent mean).
    out = np.array([
        net.run(gen_input_stokes(Fr, St, t_observ, question), net.mu)[0]
        for Fr, St in zip(np.ravel(FR), np.ravel(ST))
    ])

    n_latent = net.latent_size
    fig = plt.figure(figsize=(n_latent * 3.9, 2.1))
    for i in range(n_latent):
        # Use the (nrows, ncols, index) form: a '1{}{}' string spec is rejected by
        # current matplotlib and is ambiguous once latent_size >= 10.
        ax = fig.add_subplot(1, n_latent, i + 1, projection='3d')
        Z = np.reshape(out[:, i], FR.shape)
        ax.plot_surface(FR, ST, Z, rstride=1, cstride=1, cmap=cm.inferno, linewidth=0)
        ax.set_xlabel(r'$Fr$')
        ax.set_ylabel(r'$St$')
        ax.set_zlabel('Latent activation {}'.format(i + 1))
        if i == 2:
            # Fix the scale for the third plot, where the activation is close to zero.
            ax.set_zlim(-1, 1)
        ax.set_zticks([-1, -0.5, 0, 0.5, 1])
    fig.tight_layout()
    return fig

def plot_stokes_prediction(net_, Fr_, St_, t_observ, t_predict):
    """Plot the true Stokes trajectory against the network's prediction.

    Parameters
    ----------
    net_ : network object exposing ``run`` and ``output``.
    Fr_, St_ : float
        Parameters of the case to plot. They are used both for the reference
        solution and for the observation fed to the network.
    t_observ : array-like
        Sample times of the observation fed to the network.
    t_predict : array-like
        Times at which the trajectory is evaluated and predicted.

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Line colours come from the module-level ``orange_color`` / ``blue_color``
    globals, which must be defined before this function is called.
    """
    # Use the arguments, not same-named notebook globals. The previous version
    # read Fr/St/net from global scope and silently ignored Fr_, St_ and net_.
    x_correct = sto.stokes_eqn(Fr_, St_, t_predict)
    x_predict = net_.run(gen_input_stokes(Fr_, St_, t_observ, t_predict), net_.output).ravel()

    fig = plt.figure(figsize=(7, 4))
    ax = fig.add_subplot(111)
    ax.plot(t_predict, x_correct, color=orange_color, label='True time evolution')
    ax.plot(t_predict, x_predict, '--', color=blue_color, label='Predicted time evolution')
    ax.set_xlabel(r'$t$ [$s$]')
    ax.set_ylabel(r'$x$ [$m$]')
    # Anchored above the axes.
    ax.legend(loc='upper left', bbox_to_anchor=(0.6, 1.3), shadow=True, ncol=1)
    fig.tight_layout()
    return fig

In [52]:
# Line colours used by plot_stokes_prediction: blue for the network's
# prediction, orange for the true trajectory.
blue_color, orange_color = '#000cff', '#ff7700'

### Input Variables

In [53]:
# Network dimensions (passed to nn.Network below). question_size / answer_size
# are 1: each question is a single time value (see gen_input_stokes).
observation_size = 50
latent_size = 3
question_size = 1
answer_size = 1

# Data-set settings.
num_examples = 50000  # number of examples requested from sto.stokes_data
dev_percent = 10      # passed to dl.load; presumably % held out for validation -- TODO confirm

# Observation sample times (one per observation entry) and the end of the
# question-time range used in the prediction plot.
t_sim = np.linspace(0, 1, num=observation_size)
t_q = 2.0

### Data creation and loading

In [54]:
# Generate the examples, sampled at t_sim. The data set is named
# 'stokes_example', the same name dl.load reads in the next cell
# (presumably written to disk -- TODO confirm in sto.stokes_data).
sto.stokes_data(
    num_examples,
    t_sample=t_sim,
    fileName='stokes_example',
);


In [55]:
# Load the generated data set. td and vd are used below as the training and
# validation inputs of net.train / net.run.
(td, vd, ts, vs, proj) = dl.load(
    dev_percent,
    'stokes_example',
)

### Create and train neural network

In [61]:
# Create network object
# Sizes come from the "Input Variables" cell.
net = nn.Network(
    observation_size,
    latent_size,
    question_size,
    answer_size,
)

In [62]:
# Print initial reconstruction loss (depends on initialization)
# Reconstruction loss, then KL loss, evaluated on vd before training
# (values depend on the random initialisation).
for loss_op in (net.recon_loss, net.kl_loss):
    print(net.run(vd, loss_op))

2.470951
11.610241


In [47]:
# Train
# Positional hyper-parameters are presumably (epochs, batch size, learning
# rate); the progress bar counts to 50 -- TODO confirm against nn.Network.train.
# NOTE(review): the saved output shows 0/50 and the losses printed afterwards
# equal the initial ones, so training did not complete before saving (execution
# counts are also out of order). Restart the kernel and run all cells.
net.train(
    50,
    256,
    0.001,
    td,
    vd,
)

  0%|          | 0/50 [00:00<?, ?it/s]

In [65]:
# Check progress. It is recommended to use Tensorboard instead for this.
# Re-evaluate reconstruction and KL loss on vd after training.
# TensorBoard gives a fuller view of training progress.
for loss_op in (net.recon_loss, net.kl_loss):
    print(net.run(vd, loss_op))

2.470951
11.610241


In [66]:
# Plot prediction
# Compare true and predicted trajectories for one (Fr, St) pair at 250
# question times between 0 and t_q.
Fr, St = 1.0, 3.0
t_predict = np.linspace(0, t_q, num=250)
plot_stokes_prediction(net, Fr, St, t_sim, t_predict);

In [50]:
%matplotlib tk
stokes_representation_plot(net, [-5.0,5.0], [0.1,5.0], t_sim, eval_time=1.5);