In [9]:
#imports
import jax
import jax.numpy as jnp

import jax_cfd.base as cfd
from jax_cfd.base import advection
from jax_cfd.ml import towers
import jax_cfd.ml.train_utils as train_utils
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

import haiku as hk
import numpy as np
import xarray
import random

import pickle
# from jax_cfd.ml.diego_model_utils import SaveObject, forward_pass_module
import jax_cfd.ml.diego_preprocessing as preprocessing
import jax_cfd.ml.diego_train_functions as training
from jax_cfd.ml import nonlinearities
import jax_cfd.ml.diego_cnn_bcs as bcs

import jax_cfd.ml.newSaveObject as saving
import jax_cfd.ml.diego_towers as mytowers

import matplotlib.pyplot as plt
import seaborn

import time

import tree_math

In [None]:
# Reload an edited project module without restarting the kernel.
# Uncomment and pass the module *object* (not an import statement) to reload();
# names pulled in with `from ... import ...` must be re-imported after reloading.
# import importlib
# importlib.reload(training)        # jax_cfd.ml.diego_train_functions
# importlib.reload(nonlinearities)  # jax_cfd.ml.nonlinearities
# importlib.reload(bcs)             # jax_cfd.ml.diego_cnn_bcs

In [2]:
# Load the simulation dataset. Passing `chunks` makes xarray open it lazily
# (dask-backed), chunked along the time dimension.
file_name = '256x64_150_seconds_inner_1'
data = xarray.open_dataset(
    f'../../creating_dataset/datasets/{file_name}.nc',
    chunks={'time': '100MB'},
)

In [5]:
# Split the dataset into a list of per-timestep frames. Each frame is an
# (x_shape, y_shape, 2) array with `u` stacked in channel 0 and `v` in channel 1.
x_shape = len(data.x)
y_shape = len(data.y)
high_def = []
for i in range(len(data.time)):
    u_frame = data.u.isel(time=i)
    v_frame = data.v.isel(time=i)
    # Reshaping would silently scramble a field stored with a different axis
    # order (e.g. (y, x)); require the expected layout and fail loudly otherwise.
    if u_frame.shape != (x_shape, y_shape) or v_frame.shape != (x_shape, y_shape):
        raise ValueError(
            f"expected frames of shape {(x_shape, y_shape)} at time index {i}, "
            f"got u {u_frame.dims}{u_frame.shape} and v {v_frame.dims}{v_frame.shape}"
        )
    # jnp.asarray applies JAX's default dtype policy; np.dstack then returns a
    # NumPy array.
    this_time = np.dstack([
        jnp.asarray(u_frame.values),
        jnp.asarray(v_frame.values),
    ])
    high_def.append(this_time)

In [6]:
# Discard the warm-up part of the simulation, which is presumably not yet
# representative of developed turbulent flow, then subsample in time.
#
# NOTE(review): dt is taken to be the first saved time value. That assumes the
# first snapshot is at t = dt rather than t = 0; confirm against the dataset
# generator, because a dataset starting at t = 0 would give dt = 0 here.
dt = float(data.time[0].values)
outer_steps = len(data.time.values)
inner_steps = (data.time[1].values - data.time[0].values) / dt
total_sim_time = outer_steps * inner_steps * dt

print("\n".join([
    "dt: \t\t" + str(dt),
    "outer_steps: \t" + str(outer_steps),
    "inner_steps: \t" + str(inner_steps),
    "total_sim_time: " + str(total_sim_time),
]))

warm_up = 15  # seconds
# Number of saved frames inside the warm-up window (truncated; values are positive).
warm_index = int(warm_up / total_sim_time * outer_steps)
print("removed points: " + str(warm_index))

step = 100
# Drop the warm-up frames and keep every `step`-th remaining frame in one slice.
# NOTE: this rebinds `high_def`, so re-running this cell without first re-running
# the frame-splitting cell slices data that has already been sliced.
high_def = high_def[warm_index::step]

print("\n")
print("step = " + str(step))
print("Training dataset shape: ")  # (frames, x, y, input channels)
print("\t" + str(np.shape(high_def)))

dt: 		0.015625
outer_steps: 	9600
inner_steps: 	1.0
total_sim_time: 150.0
removed points: 960


step = 100
Training dataset shape: 
	(87, 256, 64, 2)


In [7]:
def convect(v):  # pylint: disable=function-redefined
    """Advect every component of `v` with the van Leer scheme.

    Each component is passed to `advection.advect_van_leer` together with the
    full velocity `v` and the time step `dt`. `dt` is looked up in the
    notebook's global scope at call time, not at definition time.

    Returns:
      A tuple with one advected result per component of `v`.
    """
    advected_components = [
        advection.advect_van_leer(component, v, dt) for component in v
    ]
    return tuple(advected_components)


# Wrap `convect` as a named term that operates on tree_math.Vector inputs.
# NOTE: `_wrap_term_as_vector` is a private jax_cfd helper and may change
# between versions.
convection = cfd.equations._wrap_term_as_vector(convect, name='convection')

In [61]:
def reshapeForAdvection(v, domain=((0.0, 8.0), (0.0, 2.0)), bc=None):
    """Wrap a stacked two-component velocity array as a staggered tree_math.Vector.

    Args:
      v: array of shape (nx, ny, 2). Channel 0 is placed at offset (1.0, 0.5)
        and channel 1 at offset (0.5, 1.0) on a shared grid (presumably u on
        x-faces and v on y-faces of a staggered grid; TODO confirm).
      domain: physical extent ((x_min, x_max), (y_min, y_max)) of the grid.
        Defaults to the 8 x 2 domain this notebook's data uses.
      bc: boundary conditions applied to both components. Defaults to periodic
        in both directions. NOTE(review): if the data is a channel flow,
        `cfd.boundaries.channel_flow_boundary_conditions(ndim=2)` may be the
        appropriate choice; confirm.

    Returns:
      A tree_math.Vector wrapping a tuple of two GridVariables.
    """
    if bc is None:
        bc = cfd.boundaries.periodic_boundary_conditions(2)
    # Take the grid shape from the data itself rather than assuming 256 x 64,
    # so a GridArray's data and its Grid can never disagree. A single Grid is
    # shared by both components.
    grid = grids.Grid(shape=tuple(v.shape[:2]), domain=domain)
    u_var = grids.GridVariable(
        array=grids.GridArray(data=v[:, :, 0], offset=(1.0, 0.5), grid=grid),
        bc=bc,
    )
    v_var = grids.GridVariable(
        array=grids.GridArray(data=v[:, :, 1], offset=(0.5, 1.0), grid=grid),
        bc=bc,
    )
    return tree_math.Vector((u_var, v_var))

In [99]:
# Apply the convection term to every retained frame, stacking the advected
# components back into the same (x, y, 2) layout as the input frames.
all_advected = []
for field in high_def:
    advected = convection(reshapeForAdvection(field))
    # `advected` is a tree_math.Vector. `.tree` is the wrapped tuple of
    # per-component grid arrays (the object `tree_flatten()[0][0]` returns),
    # so read it once instead of re-flattening for every component.
    components = advected.tree
    this_time = jnp.dstack([components[0].data, components[1].data])
    all_advected.append(this_time)