## Training to velocities, first derivatives and Laplacians

In [1]:
from jax_cfd.ml.diego_cnn_bcs import *

#imports
import jax
import jax.numpy as jnp

import jax_cfd.base as cfd
from jax_cfd.ml import towers
import jax_cfd.ml.train_utils as train_utils
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

import haiku as hk
import numpy as np
import xarray
import random

import pickle
# from jax_cfd.ml.diego_model_utils import SaveObject, forward_pass_module
from jax_cfd.ml.diego_preprocessing import *
from jax_cfd.ml.diego_train_functions import *
from jax_cfd.ml import nonlinearities

from jax_cfd.ml.newSaveObject import *
from jax_cfd.ml.diego_towers import *

import matplotlib.pyplot as plt
import seaborn

import time

In [2]:
# import data (fine grid)
# create X_data via mean pooling
# create Y_data by calculating everything for each frame and stacking them along the channel dimension

In [3]:
# Locate and open the fine-grid simulation output.
# `chunks` asks xarray for chunked (dask-backed) loading along the time axis.
file_name = '256x64_150_seconds_inner_1'
dataset_path = '../../creating_dataset/datasets/' + file_name + '.nc'
data = xarray.open_dataset(dataset_path, chunks={'time': '100MB'})

In [4]:
# Build one array per saved timestamp, with u and v stacked along the last axis.
# Each field is reshaped to (x, y) and transposed, so a frame has shape (y, x, 2).
x_shape, y_shape = len(data.x), len(data.y)
n_frames = int(len(data.time))

high_def = []
for t in range(n_frames):
    u_t = jnp.array([data.u.isel(time=t)]).reshape(x_shape, y_shape).T
    v_t = jnp.array([data.v.isel(time=t)]).reshape(x_shape, y_shape).T
    high_def.append(np.dstack([u_t, v_t]))



In [5]:
# Drop the warm-up part of the run, then keep every `step`-th remaining frame.
# (The initial stages of the simulation may not be representative of turbulent flow.)
# NOTE: `high_def` is reassigned from itself below, so re-running this cell trims it again.
dt = float(data.time[0].values)  # assumes the first saved timestamp equals one time step — TODO confirm
outer_steps = len(data.time.values)
inner_steps = (data.time[1].values - data.time[0].values) / dt
total_sim_time = outer_steps * inner_steps * dt

print("dt: \t\t", dt, sep="")
print("outer_steps: \t", outer_steps, sep="")
print("inner_steps: \t", inner_steps, sep="")
print("total_sim_time: ", total_sim_time, sep="")

warm_up = 15 #seconds
warm_fraction = warm_up / total_sim_time
warm_index = int(warm_fraction * outer_steps // 1)  # number of leading frames to discard
print("removed points: ", warm_index, sep="")

print("\n")
step = 100
# Skip the warm-up frames and subsample in a single slice.
high_def = high_def[warm_index::step]
print("step = ", step, sep="")
print("Training dataset shape: ") # (frames, y, x, input channels) — each frame was transposed when stacked
print("\t", np.shape(high_def), sep="")

dt: 		0.015625
outer_steps: 	9600
inner_steps: 	1.0
total_sim_time: 150.0
removed points: 960


step = 100
Training dataset shape: 
	(87, 64, 256, 2)


In [6]:
# Normalise the subsampled frames, keeping the two extra values returned alongside them.
# NOTE(review): normalisingDataset comes from a star import — presumably ogMean/ogStdDev are the
# pre-normalisation mean and standard deviation (needed to undo the scaling); confirm in its source.
%time high_def_norm,ogMean,ogStdDev = normalisingDataset(high_def)

CPU times: user 8.63 ms, sys: 1.04 ms, total: 9.66 ms
Wall time: 9.62 ms


## Defining what we are training to

In [13]:
# Flags passed to createDatasetDerivatives, one per quantity name for each
# velocity component (first row: u, second row: v).
which_outputs = {
    "u": False, "du": False, "lapu": True,
    "v": False, "dv": False, "lapv": True,
}

In [14]:
#split into train and test

split = 0.8
split = int(len(high_def)*split//1)
random.shuffle(high_def)

factor = 4

print("Create X dataset: ")
%time X_dataset = creatingDataset(high_def_norm,mean_pooling,factor)

print("\nCreate Y dataset: ")
%time Y_dataset = createDatasetDerivatives(high_def_norm,sampling,factor,which_outputs)

# %time Y_dataset = calculateResiduals(X_dataset,Y_dataset)


X_train = X_dataset[:split]
Y_train = Y_dataset[:split]

X_test = X_dataset[split:]
Y_test = Y_dataset[split:]

# NOTE: padding conditions can be specified via the "conditions" input to the padYDataset function below
# the padXDataset only pads u and v for channel flow conditions, so hard coded for impermeability and no-slip conditions (0,0)
print("\nPadding all datasets: ")
padding = [1,1] #for a 3 by 3 kernel, find a better way to define this (so not redifined when creating CNN)
%time X_train = padXDataset(X_train,padding)
%time Y_train = padYDataset(Y_train,padding)

%time X_test = padXDataset(X_test,padding)
%time Y_test = padYDataset(Y_test,padding)

print("\nShapes of all datasets")
printAllShapes(X_train,Y_train, X_test,Y_test)

Create X dataset: 
CPU times: user 1.93 s, sys: 5.59 ms, total: 1.94 s
Wall time: 1.94 s

Create Y dataset: 
CPU times: user 362 ms, sys: 988 µs, total: 362 ms
Wall time: 364 ms

Padding all datasets: 
CPU times: user 1.39 s, sys: 971 µs, total: 1.39 s
Wall time: 1.39 s
CPU times: user 4.16 s, sys: 3.78 ms, total: 4.17 s
Wall time: 4.19 s
CPU times: user 366 ms, sys: 976 µs, total: 367 ms
Wall time: 369 ms
CPU times: user 1.09 s, sys: 794 µs, total: 1.09 s
Wall time: 1.1 s

Shapes of all datasets
(69, 18, 66, 2)
(69, 18, 66, 6)
(18, 18, 66, 2)
(18, 18, 66, 6)


## Defining model shape

In [9]:
# Input dimensions and the PRNG key handed to training.
input_channels = 2
ndim = 2
# NOTE(review): the CNN spec defined later in this notebook asks for 6 output channels — confirm which value is intended.
num_output_channels = 8
# spatial_size = 17

rng_key = jax.random.PRNGKey(42)

In [15]:
# Architecture settings read by ConvNet below.
CNN_specs = {
    "hidden_channels": 4,
    "hidden_layers": 2,
    "nonlinearity": "relu",
    "num_output_channels": 6,
}


def ConvNet(x):
    """Build a CNN from the module-level ``CNN_specs`` and apply it to ``x``.

    ``CNN_specs`` is looked up when the function is called, so rebinding that
    name changes the network built by the next call.
    """
    return CNN(CNN_specs)(x)

# CNN_specs = None

In [16]:
# Turn ConvNet into a pure (init, apply) pair, then drop the rng argument from apply.
transformed_net = hk.transform(ConvNet)
forward_pass = hk.without_apply_rng(transformed_net)

In [17]:
# Training run configuration.
epochs = 3
# NOTE(review): staggeredLearningRate presumably takes (epoch count, learning rate) pairs — confirm in its source.
learning_rates = staggeredLearningRate((120,0.01),(70,0.001))
printEvery=1
# The call's three results are unpacked as training losses, validation losses and parameters.
# params=None presumably asks train to initialise fresh weights from rng_key — TODO confirm.
%time losses,val_losses,params = train(X_train,Y_train,X_test,Y_test,rng_key,input_channels,epochs,printEvery=printEvery,learning_rates=learning_rates,params=None,forward_pass=forward_pass,tol = 1e-10)

Shapes of all datasets
(69, 18, 66, 2)
(69, 18, 66, 6)
(18, 18, 66, 2)
(18, 18, 66, 6)


Epoch 1/3
	mse : 0.121802		val mse : 0.136308	Estimated end time: 23:15:09


Epoch 2/3
	mse : 0.108319		val mse : 0.121205	Estimated end time: 23:15:09


Epoch 3/3
	mse : 0.097199		val mse : 0.108745	Estimated end time: 23:15:09



Finished training at max epochs

CPU times: user 6.03 s, sys: 82 ms, total: 6.12 s
Wall time: 5.57 s


In [None]:
# Plot the training and validation loss curves.
# A dedicated stride name is used so the dataset subsampling `step` is not overwritten,
# and explicit x values keep the axis in loss-index units even when the stride is not 1.
plot_stride = 1  # plot every `plot_stride`-th recorded loss

fig, ax = plt.subplots()
ax.plot(range(len(losses))[::plot_stride], losses[::plot_stride], label="training")
ax.plot(range(len(val_losses))[::plot_stride], val_losses[::plot_stride], label="validation")
ax.set_ylabel("mse")
ax.set_xlabel("epochs")
# ax.set_yscale("log")
ax.legend();  # trailing semicolon hides the Legend repr

In [None]:
# Bundle the trained parameters with the architecture spec so both are pickled together.
toSave = newSaveObject(params,CNN_specs)

In [2]:
# Destination of the pickled model, relative to the notebook's working directory.
# NOTE(review): the file name says "eight" while CNN_specs in this notebook requests 6 output channels — confirm the name is still right.
save_path = "./pred_all_eight.pickle"

In [38]:
# Serialise the saved-model object to disk.
with open(save_path, "wb") as out_file:
    pickle.dump(toSave, out_file)

## Loading model

In [39]:
# Restore the saved model object.
# NOTE: pickle.load can execute arbitrary code — only load files you created yourself.
with open(save_path, 'rb') as pickle_file:
    loaded = pickle.load(pickle_file)

# Rebind the module-level spec so ConvNet builds the saved architecture, then attach
# a freshly transformed forward pass to the loaded object.
CNN_specs = loaded.CNN_specs
loaded.forward_pass = hk.without_apply_rng(hk.transform(ConvNet))

In [41]:
# Sanity check: shape of the loaded model's prediction for the first test sample.
# NOTE(review): the recorded output has 8 channels although CNN_specs in this notebook asks for 6 —
# presumably it came from an earlier saved model; re-run from a fresh kernel to confirm.
np.shape(loaded.forward_pass.apply(loaded.params,X_test[0]))

(18, 66, 8)