## Designing boundary conditions for CNNs

In [1]:
from jax_cfd.ml.diego_cnn_bcs import *

#imports
import jax
import jax.numpy as jnp

import jax_cfd.base as cfd
from jax_cfd.ml import towers
import jax_cfd.ml.train_utils as train_utils
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

import haiku as hk
import numpy as np
import xarray
import random

import pickle
from diego_model_utils import SaveObject, forward_pass_module
from diego_preprocessing import *
from diego_train_functions import *

import matplotlib.pyplot as plt
import seaborn

import time

In [2]:
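# Plan:
# 1. Import the fine-grid simulation data.
# 2. Create X_data by mean-pooling each fine-grid frame onto a coarser grid.
# 3. Create Y_data by computing all target quantities (the derivative fields) for each frame and stacking them along the channel dimension.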

In [3]:
# Import the fine-grid simulation dataset. Passing `chunks` makes xarray open it
# lazily (dask-backed) in ~100 MB pieces along the time axis.
data_dir = '../../creating_dataset/datasets'
file_name = '256x64_150_seconds_inner_1'
data = xarray.open_dataset(f'{data_dir}/{file_name}.nc', chunks={'time': '100MB'})

In [62]:
# Split the dataset into a list of per-time-step frames. Each frame stacks the
# u and v velocity components along a trailing channel axis -> (x, y, 2).
# NOTE(review): `.reshape(x_shape, y_shape)` assumes each time slice of
# data.u / data.v is laid out as (x, y); if it were (y, x) the reshape would
# silently scramble the field instead of transposing it — confirm the dim order.
# NOTE: every saved frame is read here, although the next cell keeps only a
# small subset of them.
x_shape, y_shape = len(data.x), len(data.y)
high_def = [
    np.dstack([
        jnp.array([data.u.isel(time=t_idx)]).reshape(x_shape, y_shape),
        jnp.array([data.v.isel(time=t_idx)]).reshape(x_shape, y_shape),
    ])
    for t_idx in range(len(data.time))
]

In [63]:
# Warm-up removal and temporal subsampling.
# The first `warm_up` seconds are dropped because the early transient is
# probably not yet representative of developed turbulent flow. After that,
# only every `step`-th frame is kept.
# NOTE: this cell re-binds `high_def` to a slice of itself. Re-running it
# without first re-running the frame-splitting cell subsamples the data twice.
# NOTE(review): `dt` is read from the first saved time stamp. This assumes the
# saved time axis starts at t = dt rather than t = 0 — confirm.
dt = float(data.time[0].values)
outer_steps = len(data.time.values)
inner_steps = (data.time[1].values - data.time[0].values) / dt
total_sim_time = outer_steps * inner_steps * dt

for label, value in (
    ("dt: \t\t", dt),
    ("outer_steps: \t", outer_steps),
    ("inner_steps: \t", inner_steps),
    ("total_sim_time: ", total_sim_time),
):
    print(label + str(value))

warm_up = 15  # seconds of simulated time to discard
warm_index = int(warm_up / total_sim_time * outer_steps // 1)
print("removed points: " + str(warm_index))

print("\n")
step = 50  # keep every `step`-th frame after the warm-up
# One slice, equivalent to high_def[warm_index:][0::step].
high_def = high_def[warm_index::step]
print("step = " + str(step))
print("Training dataset shape: ")  # (frames, x, y, input channels)
print("\t" + str(np.shape(high_def)))

dt: 		0.015625
outer_steps: 	9600
inner_steps: 	1.0
total_sim_time: 150.0
removed points: 960


step = 50
Training dataset shape: 
	(173, 256, 64, 2)


In [64]:
# Normalise the frames before building the training data.
# NOTE(review): judging by the names, this presumably returns the normalised
# frames plus the mean and standard deviation it used. These are kept so model
# outputs can be de-normalised later — confirm against normalisingDataset.
_normalisation_result = normalisingDataset(high_def)
high_def_norm, ogMean, ogStdDev = _normalisation_result
del _normalisation_result

In [65]:
# Shape of one fine-grid frame: (x, y, [u, v]).
np.shape(high_def[0])

(256, 64, 2)

In [15]:
# Grid description for a single frame, for experimenting with jax_cfd grids.
test = high_def[0]
# Derive the grid resolution from the data instead of hard-coding (256, 64).
# This keeps it correct if a dataset with a different resolution is loaded above.
size = np.shape(test)[:2]
# NOTE(review): the physical extent is hard-coded as x in [0, 8] and y in [0, 2].
# This matches the 4:1 aspect ratio of the 256x64 grid; confirm it against the
# dataset's x/y coordinates.
domain = ((0, 8), (0, 2))
# mygrid = cfd.grids.Grid(size,domain=domain)







    

In [35]:
# Sanity check of the derivative helpers on a linear 5x5 field, where
# field[i, j] = 5*i + j + 1. Every second derivative of a linear field is zero,
# so its Laplacian must vanish everywhere.
field = jnp.array(np.linspace(1, 25, 25).reshape(5, 5))
print(field)

# np.gradient differentiates along array axes: axis=0 varies the row index and
# axis=1 the column index. axis=None returns [d/d(axis 0), d/d(axis 1)].
# NOTE: the dataset frames in this notebook are indexed [x, y], so for them
# axis 0 is x (not y).
# Feeding that 2-element list back into np.gradient treats it as a single
# (2, 5, 5) array and differentiates along all three axes. That is why the
# printed shape is (3, 2, 5, 5); the result is not the Hessian.
print(np.shape(np.gradient(np.gradient(field, axis=None))))

field_laplacian = npLaplacian(field)
# Turn the visual "all zeros" check into an assertion.
np.testing.assert_allclose(np.asarray(field_laplacian), 0.0, atol=1e-6)
field_laplacian

    

[[ 1.  2.  3.  4.  5.]
 [ 6.  7.  8.  9. 10.]
 [11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20.]
 [21. 22. 23. 24. 25.]]
(3, 2, 5, 5)


array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [88]:
# Shape of the subsampled fine-grid dataset: (frames, x, y, [u, v]).
np.asarray(high_def).shape

(173, 256, 64, 2)

In [82]:


%time np.shape(calculateALLDerivativesNUMPY(diego_test[0],sampling,4))

%time new_dataset = createDatasetDerivatives(diego_test,sampling,4)
print(np.shape(new_dataset))

CPU times: user 6.9 ms, sys: 575 µs, total: 7.47 ms
Wall time: 6.7 ms
CPU times: user 208 ms, sys: 8.94 ms, total: 217 ms
Wall time: 195 ms
(173, 256, 64, 8)


In [85]:
#split into train and test

split = 0.8
split = int(len(high_def)*split//1)
random.shuffle(high_def)

factor = 4


%time X_dataset = creatingDataset(high_def_norm,mean_pooling,factor)
%time Y_dataset = createDatasetDerivatives(high_def_norm,sampling,factor)

# %time Y_dataset = calculateResiduals(X_dataset,Y_dataset)


X_train = X_dataset[:split]
Y_train = Y_dataset[:split]

X_test = X_dataset[split:]
Y_test = Y_dataset[split:]

CPU times: user 3.73 s, sys: 5.78 ms, total: 3.74 s
Wall time: 3.74 s
CPU times: user 871 ms, sys: 796 µs, total: 872 ms
Wall time: 875 ms


In [86]:
# Training inputs: (train frames, coarse x, coarse y, input channels).
np.asarray(X_train).shape

(138, 64, 16, 2)

In [87]:
# Training targets: (train frames, coarse x, coarse y, target channels).
np.asarray(Y_train).shape

(138, 64, 16, 8)