## Training to advection terms

In [1]:
#imports
import jax
import jax.numpy as jnp

import jax_cfd.base as cfd
from jax_cfd.base import advection
from jax_cfd.ml import towers
import jax_cfd.ml.train_utils as train_utils
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

import haiku as hk
import numpy as np
import xarray
import random

import pickle
# from jax_cfd.ml.diego_model_utils import SaveObject, forward_pass_module
import jax_cfd.ml.diego_preprocessing as preprocessing
import jax_cfd.ml.diego_train_functions as training
from jax_cfd.ml import nonlinearities
import jax_cfd.ml.diego_cnn_bcs as bcs

import jax_cfd.ml.newSaveObject as saving
import jax_cfd.ml.diego_towers as mytowers

import matplotlib.pyplot as plt
import seaborn

import time

import tree_math

In [23]:
# this allows me to reload a module without having to interrupt the kernel
import importlib
importlib.reload(mytowers)
# importlib.reload(from jax_cfd.ml.diego_train_functions import *)
# importlib.reload(from jax_cfd.ml import nonlinearities)
# importlib.reload(from jax_cfd.ml.diego_cnn_bcs import *)

<module 'jax_cfd.ml.diego_towers' from '/rds/general/user/dd519/home/FYP/forked_jax/jax-cfd/jax_cfd/ml/diego_towers.py'>

In [3]:
# import data
file_name = '256x64_150_seconds_inner_1'
data = xarray.open_dataset(f'../../creating_dataset/datasets/'+ file_name +'.nc', chunks={'time': '100MB'})

In [4]:
# split by timestamps
x_shape = len(data.x)
y_shape = len(data.y)
high_def = []
for i in range(int(len(data.time))):
    this_time = np.dstack([
        jnp.array([data.u.isel(time = i)]).reshape(x_shape,y_shape),
        jnp.array([data.v.isel(time = i)]).reshape(x_shape,y_shape)
    ])
    high_def.append(this_time)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
#warm up time (may want to discard initial stages of simulation since not really representative of turbulent flow?)
dt = float(data.time[0].values)

outer_steps = len(data.time.values)

inner_steps = (data.time[1].values-data.time[0].values)/dt

total_sim_time = outer_steps*inner_steps*dt
print("dt: \t\t" + str(dt))
print("outer_steps: \t" + str(outer_steps))
print("inner_steps: \t" + str(inner_steps))
print("total_sim_time: " + str(total_sim_time))

warm_up = 15 #seconds
warm_index = int(warm_up/total_sim_time * outer_steps // 1)
print("removed points: " + str(warm_index))
high_def = high_def[warm_index:]

print("\n")
step = 10
high_def = high_def[0::step]
print("step = " + str(step))
print("Training dataset shape: ") # (frames, x, y, input channels)
print("\t" + str(np.shape(high_def)))

dt: 		0.015625
outer_steps: 	9600
inner_steps: 	1.0
total_sim_time: 150.0
removed points: 960


step = 10
Training dataset shape: 
	(864, 256, 64, 2)


In [6]:
def convect(v):  # pylint: disable=function-redefined
      return tuple(
          advection.advect_van_leer(u, v, dt) for u in v)
    
convection = cfd.equations._wrap_term_as_vector(convect, name='convection')

In [7]:
def reshapeForAdvection(v):
    mygrid = (grids.GridVariable(array = grids.GridArray(data = v[:,:,0], 
                                            offset = (1.0, 0.5), 
                                            grid=grids.Grid(
                                                shape=(256, 64),  
                                                domain=((0.0, 8.0), (0.0, 2.0)) ) ),
#                        bc=cfd.boundaries.channel_flow_boundary_conditions(ndim=2)
                        bc = cfd.boundaries.periodic_boundary_conditions(2)
                   ),
             grids.GridVariable(array = grids.GridArray(data = v[:,:,1], 
                                            offset = (0.5, 1.0),
                                            grid=grids.Grid(
                                                shape=(256, 64),  
                                                domain=((0.0, 8.0), (0.0, 2.0)) ) ),
#                        bc=cfd.boundaries.channel_flow_boundary_conditions(ndim=2)
                        bc = cfd.boundaries.periodic_boundary_conditions(2)
                   )
           )
    return tree_math.Vector(mygrid)

In [8]:
all_advected = []
for field in high_def:
    advected = convection(reshapeForAdvection(field))
    vel = 0
    uadv = advected.tree_flatten()[0][0][vel].data
    vel = 1
    vadv = advected.tree_flatten()[0][0][vel].data
    this_time = jnp.dstack([
            jnp.array(uadv),
            jnp.array(vadv)
        ])
    all_advected.append(this_time)

In [9]:
#split into train and test

split = 0.8
split = int(len(high_def)*split//1)
random.shuffle(high_def)

factor = 4

print("Create X dataset: ")
%time X_dataset = preprocessing.creatingDataset(high_def,preprocessing.sampling,factor)

print("\nCreate Y dataset: ")
padding = [1,1]
%time Y_dataset = preprocessing.creatingDataset(all_advected,preprocessing.sampling,factor)

# %time Y_dataset = calculateResiduals(X_dataset,Y_dataset)


X_train = X_dataset[:split]
Y_train = Y_dataset[:split]

X_test = X_dataset[split:]
Y_test = Y_dataset[split:]



# print("\nPadding all datasets: ")
# padding = [1,1] #this is for a 3 by 3 kernel, find a better way to define this (so not redifined when creating CNN)
# %time X_train = padXDataset(X_train,padding)
# %time Y_train = padYDatasetNew(Y_train,padding,conditions)

# %time X_test = padXDataset(X_test,padding)
# %time Y_test = padYDataset(Y_test,padding,conditions)

print("\nShapes of all datasets")
training.printAllShapes(X_train,Y_train, X_test,Y_test)

Create X dataset: 
CPU times: user 175 ms, sys: 8.94 ms, total: 184 ms
Wall time: 183 ms

Create Y dataset: 
CPU times: user 874 ms, sys: 10.9 ms, total: 885 ms
Wall time: 889 ms

Shapes of all datasets
(691, 64, 16, 2)
(691, 64, 16, 2)
(173, 64, 16, 2)
(173, 64, 16, 2)


In [27]:
def ConvNet(x):
    cnn = mytowers.CNN(CNN_specs)
    return cnn(x)

CNN_specs = {
    "hidden_channels" : 5,
    "hidden_layers" : 10,
    "nonlinearity" : "relu",
    "num_output_channels" : 2
}
input_channels = 2

# CNN_specs = None

In [28]:
forward_pass = hk.without_apply_rng(hk.transform(ConvNet))

In [29]:
instance = training.MyTraining(X_train,Y_train,X_test,Y_test,
                      jax.random.PRNGKey(40), #rng_key
                      input_channels=2,
                      epochs = 50,
                      printEvery=1,#epochs
                      learning_rates=training.staggeredLearningRate((10,0.1),(20,0.01),(30,0.001)), #iterated over batches
                      batch_size=len(X_train)//2+1, # number or len(X_train)
                      validateEvery=1,
                      params=None,
                      forward_pass=forward_pass,
                      tol = 1e-10)

%time instance.train()

Shapes of all datasets
(691, 64, 16, 2)
(691, 64, 16, 2)
(173, 64, 16, 2)
(173, 64, 16, 2)



Start time: 1:16:08
Epoch 1/50
	mse : 0.040015		val mse : 0.070931	Estimated end time: 1:58:37


Epoch 2/50
	mse : 0.039864		val mse : 0.070769	Estimated end time: 2:12:12


Epoch 3/50
	mse : 0.039695		val mse : 0.070586	Estimated end time: 2:22:03


Epoch 4/50
	mse : 0.039505		val mse : 0.070385	Estimated end time: 2:27:12


Epoch 5/50
	mse : 0.039299		val mse : 0.070168	Estimated end time: 2:30:32


Epoch 6/50
	mse : 0.039079		val mse : 0.069939	Estimated end time: 2:31:48


Epoch 7/50
	mse : 0.038849		val mse : 0.069702	Estimated end time: 2:32:42


Epoch 8/50
	mse : 0.038613		val mse : 0.069460	Estimated end time: 2:33:25


Epoch 9/50
	mse : 0.038376		val mse : 0.069219	Estimated end time: 2:33:52


Epoch 10/50
	mse : 0.038141		val mse : 0.068979	Estimated end time: 2:34:38


Epoch 11/50
	mse : 0.038013		val mse : 0.068956	Estimated end time: 2:35:03


Epoch 12/50
	mse : 0.037990		val mse 

KeyboardInterrupt: 

In [13]:

batches = np.arange(len(instance.losses))
val_step = len(instance.losses)//len(instance.val_losses)

batches_val = batches[::val_step]+val_step
plt.plot(batches+1,instance.losses, label="training")
plt.plot(batches_val,instance.val_losses,label="validation")
plt.ylabel("mse")
plt.xlabel("batches")
# plt.yscale("log")
plt.legend()

ZeroDivisionError: integer division or modulo by zero

In [None]:
description = "left for lunch, please save me"

In [None]:
toSave = saving.newSaveObject(instance.params,instance.losses,instance.val_losses,description,CNN_specs)

In [None]:
save_path = "./../models/final_models/advection_1.pickle"

In [None]:
with open(save_path,"wb") as f:
    pickle.dump(toSave,f)

del save_path

## Loading model

In [None]:
with open(save_path,'rb',) as pickle_file:
    loaded = pickle.load(pickle_file)
    CNN_specs = loaded.CNN_specs
    loaded.forward_pass = hk.without_apply_rng(hk.transform(ConvNet))

In [None]:
loaded.description