### Imports

In [1]:
# Scinet modules
from scinet import *

# My custom modules
import scinet.ed_stokes as sto

#Other modules
import datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

%load_ext tensorboard
import tensorflow as tf
import datetime, os


### Helper Functions (based on oscillator example)

In [2]:
def gen_input_stokes(Fr, St, t_observ, tt_predicted):
    """Build a network input batch for one (Fr, St) parameter pair.

    The same observation -- ``sto.stokes_eqn(Fr, St, t_observ)`` -- is paired
    with every query time in ``tt_predicted``, giving one batch row per query.

    Args:
        Fr: ``Fr`` parameter forwarded to ``sto.stokes_eqn`` (presumably the
            Froude number -- TODO confirm).
        St: ``St`` parameter forwarded to ``sto.stokes_eqn`` (presumably the
            Stokes number -- TODO confirm).
        t_observ: Observation times forwarded to ``sto.stokes_eqn``.
        tt_predicted: Sequence of query times; one batch row per entry.

    Returns:
        ``[in1, in2, out]`` where ``in1`` stacks one copy of the observation
        per query time (leading axis of length ``len(tt_predicted)``), ``in2``
        is ``tt_predicted`` reshaped to a column ``(n, 1)``, and ``out`` is a
        dummy filler equal to ``in2``.
    """
    # The observation does not depend on the query time, so evaluate the
    # Stokes equation once and repeat it rather than once per query time.
    observation = np.asarray(sto.stokes_eqn(Fr, St, t_observ))
    n_queries = len(tt_predicted)
    in1 = np.repeat(observation[np.newaxis, ...], n_queries, axis=0)
    in2 = np.reshape(tt_predicted, (-1, 1))
    out = in2  # dummy filler
    return [in1, in2, out]

def stokes_representation_plot(net, Fr_range, St_range, t_observ, step_num=100, eval_time=1.5):
    """Plot each latent neuron's activation over a grid of (Fr, St) values.

    For every point of a ``step_num x step_num`` grid spanning ``Fr_range`` and
    ``St_range``, the network is run on the Stokes observation at ``t_observ``
    (paired with the single query time ``eval_time``). The first row of the
    ``net.mu`` output (presumably the latent means -- TODO confirm) is
    recorded. One 3-D surface is drawn per latent neuron.

    Args:
        net: Loaded network; must provide ``run``, ``mu`` and ``latent_size``.
        Fr_range: ``(min, max)`` of the Fr axis.
        St_range: ``(min, max)`` of the St axis.
        t_observ: Observation times forwarded to ``gen_input_stokes``.
        step_num: Grid points per axis; the network is run once per grid
            point, i.e. ``step_num**2`` times.
        eval_time: Query time paired with every observation.

    Returns:
        The matplotlib ``Figure`` with one 3-D subplot per latent neuron.
    """
    Fr_vec = np.linspace(*Fr_range, num=step_num)
    St_vec = np.linspace(*St_range, num=step_num)
    FR, ST = np.meshgrid(Fr_vec, St_vec)

    query_time = np.array([eval_time])
    out = np.array([
        net.run(gen_input_stokes(Fr, St, t_observ, query_time), net.mu)[0]
        for Fr, St in zip(np.ravel(FR), np.ravel(ST))
    ])

    fig = plt.figure(figsize=(net.latent_size * 3.9, 2.1))
    for i in range(net.latent_size):
        # Integer (nrows, ncols, index) form: a '1{n}{i}' string spec is
        # rejected by recent matplotlib and is malformed for 10+ subplots.
        ax = fig.add_subplot(1, net.latent_size, i + 1, projection='3d')
        Z = np.reshape(out[:, i], FR.shape)
        ax.plot_surface(FR, ST, Z, rstride=1, cstride=1, cmap=cm.inferno, linewidth=0)
        ax.set_xlabel(r'$Fr$')
        ax.set_ylabel(r'$St$')
        ax.set_zlabel('Latent activation {}'.format(i + 1))
        if i == 2:
            # Fix the scale for the third plot, where the activation is close to zero
            ax.set_zlim(-2, 2)
        ax.set_zticks([-2, -1, 0, 1, 2])
    fig.tight_layout()
    return fig

def plot_stokes_prediction(net_, Fr_, St_, t_observ, t_predict):
    """Plot the true Stokes trajectory against the network's prediction.

    Args:
        net_: Loaded network; must provide ``run`` and ``output``.
        Fr_: ``Fr`` parameter of the trajectory to compare.
        St_: ``St`` parameter of the trajectory to compare.
        t_observ: Observation times fed to the network input.
        t_predict: Times at which the true and predicted positions are plotted.

    Returns:
        The matplotlib ``Figure``.
    """
    # Use the arguments rather than notebook globals with similar names, so the
    # plot reflects exactly the network and (Fr, St) pair that was requested.
    x_correct = sto.stokes_eqn(Fr_, St_, t_predict)
    x_predict = net_.run(gen_input_stokes(Fr_, St_, t_observ, t_predict), net_.output).ravel()
    fig = plt.figure(figsize=(7, 4))
    ax = fig.add_subplot(111)
    # orange_color / blue_color are notebook-level constants; they must be
    # defined before this function is called.
    ax.plot(t_predict, x_correct, color=orange_color, label='True time evolution')
    ax.plot(t_predict, x_predict, '--', color=blue_color, label='Predicted time evolution')
    ax.set_xlabel(r'$t$ [$s$]')
    ax.set_ylabel(r'$x$ [$m$]')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(0.6, 1.3), shadow=True, ncol=1)
    fig.tight_layout()
    return fig

In [3]:
# Line colours for plot_stokes_prediction: orange = true curve, blue = prediction.
orange_color, blue_color = '#ff7700', '#000cff'

### Load Network

In [4]:
# Checkpoint to restore. Other checkpoints previously loaded here:
# 'stokesNet_2021-02-19', 'stokesNetDeeper_2021-03-02'.
filename2Load = 'stokesNet_baseline2021-03-13'
net = nn.Network.from_saved(filename2Load)

{'decoder_num_units': [100, 100], 'input2_size': 1, 'tot_epochs': 2000, 'latent_size': 3, 'output_size': 1, 'encoder_num_units': [500, 100], 'input_size': 50, 'load_file': 'stokesNet_baseline2021-03-13', 'name': 'StokesNet'}







Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where








INFO:tensorflow:Restoring parameters from /home/jrwest/Courses/Winter_2021/CS230/project/nn_physical_concepts_cs230/tf_save/stokesNet_baseline2021-03-13.ckpt
Loaded network from file stokesNet_baseline2021-03-13


### Plot predictions and latent representations

In [32]:
# Plot prediction
%matplotlib tk
Fr = 1.0
St = 3
t_sim = np.linspace(0,1,net.input_size)
t_q = 2.0
t_predict = np.linspace(0, t_q, 250)
plot_stokes_prediction(net, Fr, St, t_sim, t_predict);

In [13]:
# Latent activations over Fr in [0, 5] x St in [0.1, 5] at query time 1.5 (slow: one network run per grid point).
stokes_representation_plot(net, Fr_range=(0.0, 5.0), St_range=(0.1, 5.0), t_observ=t_sim, eval_time=1.5);

In [15]:
# Generate a Stokes dataset sampled at t_sim, then load it back.
# NOTE(review): the data are regenerated on every run and no seed is set here;
# if stokes_data is stochastic, the losses below will vary between runs.
num_examples = 200000
dev_percent = 5
dataset_name = 'stokes_example'
sto.stokes_data(num_examples, t_sample=t_sim, fileName=dataset_name);
# dl.load returns five items; td/vd are presumably the training/dev splits
# (vd is the set the losses are evaluated on) -- TODO confirm.
td, vd, ts, vs, proj = dl.load(dev_percent, dataset_name)

In [16]:
# Losses of the loaded network on vd: reconstruction loss first, then KL loss.
recon_loss_dev = net.run(vd, net.recon_loss)
print(recon_loss_dev)
kl_loss_dev = net.run(vd, net.kl_loss)
print(kl_loss_dev)

0.0001799856
4.233454
