## Building a superresolution network

In [176]:
#imports
import jax
import jax.numpy as jnp
from jax import value_and_grad
from jax_cfd.ml import towers
import haiku as hk
import gin
import numpy as np
import jax_cfd.ml.train_utils as train_utils
import xarray 
import random

In [8]:
# import data
# Open the simulation output with xarray; the `chunks` argument is passed
# through unchanged from the original call.
file_name = '256x64_inner_50_outer_1000'
data = xarray.open_dataset(f'../creating_dataset/datasets/{file_name}.nc', chunks={'time': '100MB'})


In [216]:
# split by timestamps
# Build one [u, v] pair per time index; each component is forced to an
# (x_shape, y_shape) array.
x_shape = len(data.x)
y_shape = len(data.y)
high_def = [
    [
        np.array([data.u.isel(time=step)]).reshape(x_shape, y_shape),
        np.array([data.v.isel(time=step)]).reshape(x_shape, y_shape),
    ]
    for step in range(len(data.time))
]

In [217]:
# Sanity check of the nested-list layout: one entry per time step, each holding
# the [u, v] pair of (x_shape, y_shape) arrays built in the cell above.
np.shape(high_def)

(1000, 2, 256, 64)

In [218]:
# Warm-up trimming: the first stretch of the simulation may not yet be
# representative of turbulent flow, so the leading snapshots are discarded.
time_axis = data.time.values

# The first time stamp is taken as the step size.
dt = float(time_axis[0])
outer_steps = len(time_axis)
inner_steps = (time_axis[1] - time_axis[0]) / dt
total_sim_time = outer_steps * inner_steps * dt

for label, value in (
    ("dt: \t\t", dt),
    ("outer_steps: \t", outer_steps),
    ("inner_steps: \t", inner_steps),
    ("total_sim_time: ", total_sim_time),
):
    print(label + str(value))

warm_up = 30 #seconds
# Convert the warm-up duration into a number of leading snapshots to drop.
warm_index = int((warm_up / total_sim_time * outer_steps) // 1)
print("removed points: " + str(warm_index))
high_def = high_def[warm_index:]

dt: 		0.78125
outer_steps: 	1000
inner_steps: 	1.0
total_sim_time: 781.25
removed points: 38


In [None]:
# normalise velocities to between zero and one

In [223]:
# a = np.array([[1,2,3], [3,4,3], [5,6,3]])
# a = np.random.rand((9*9)).reshape(9,9) * 10 //1
# print(a)



def increaseSize(input, factor):
    """Upsample a 2-D array by repeating every element `factor` times per axis.

    Each input cell becomes a `factor` x `factor` block of identical values
    (nearest-neighbour upsampling).

    Args:
        input: 2-D array-like of shape (w, h).
        factor: integer repeat count applied along both axes.

    Returns:
        A new float64 array of shape (w * factor, h * factor).

    Raises:
        ValueError: if `input` is not 2-D.
    """
    # Unpacking enforces the 2-D contract: any other rank raises ValueError.
    _rows, _cols = np.shape(input)
    grid = np.asarray(input, dtype=np.float64)
    # Vectorised replacement for a per-element Python double loop.
    return np.repeat(np.repeat(grid, factor, axis=0), factor, axis=1)


def decreaseSize(input, factor):
    """Downsample a 2-D array by averaging non-overlapping factor x factor blocks.

    Args:
        input: 2-D array-like of shape (w, h).
        factor: integer block size; must divide both w and h.

    Returns:
        A new float64 array of shape (w // factor, h // factor) whose entries
        are the block means. Values can differ from an element-by-element
        accumulation only by floating-point summation order.

    Raises:
        AssertionError: if `factor` does not divide both dimensions (the
            exception type is kept for compatibility with existing callers).
        ValueError: if `input` is not 2-D.
    """
    w, h = np.shape(input)
    if w % factor != 0 or h % factor != 0:
        raise AssertionError("Non-compatible input shape and downsample factor")

    # View the array as (block_row, row_in_block, block_col, col_in_block) and
    # reduce over the two in-block axes instead of looping over every element.
    blocks = np.asarray(input, dtype=np.float64).reshape(
        w // factor, factor, h // factor, factor
    )
    return blocks.sum(axis=(1, 3)) / factor ** 2

def downsampleHighDefVels(high_def,factor):
    """Build a blurred copy of every velocity field in `high_def`.

    Each field is block-averaged with decreaseSize and then blown back up to
    its original shape with increaseSize, so the result has the same nesting
    as the input (one inner list of fields per snapshot).

    Args:
        high_def: iterable of snapshots, each an iterable of 2-D fields.
        factor: downsample/upsample factor passed to both helpers.

    Returns:
        A list of lists of arrays mirroring the structure of `high_def`.
    """
    return [
        [increaseSize(decreaseSize(field, factor), factor) for field in snapshot]
        for snapshot in high_def
    ]

In [224]:
# Split into train and test sets.

train_fraction = 0.8

# Shuffle in place before deriving the low-resolution copies, so that
# low_def[k] is always the blurred version of high_def[k].
random.shuffle(high_def)
split = int((len(high_def) * train_fraction) // 1)

factor = 2
low_def = downsampleHighDefVels(high_def, factor)

# Inputs are the blurred fields, targets the original high-resolution ones.
X_train, X_test = low_def[:split], low_def[split:]
Y_train, Y_test = high_def[:split], high_def[split:]



In [235]:
# Spot check: MSE between one low-resolution training field and one
# high-resolution training field, for the component at index `vel`.
# NOTE(review): the input is indexed with `time` but the target with `time+1`.
# high_def is shuffled before the split, so these are two different samples
# rather than a matching low/high pair -- confirm whether Y_train[time] was
# intended.
time = 200
vel = 1
mse(X_train[time][vel],Y_train[time+1][vel])

0.007942801341414452

In [151]:
#reference:
# https://goodboychan.github.io/python/deep_learning/vision/tensorflow-keras/2020/10/13/01-Super-Resolution-CNN.html#Build-SR-CNN-Model

def mse(target, ref):
    """Return the squared error between two arrays, normalised per pixel.

    Both arrays are cast to float32 before differencing. The summed squared
    difference is divided by the product of the first two dimensions of
    `target` only, so any further axes are summed rather than averaged.

    Args:
        target: array with at least two dimensions.
        ref: array compatible with `target` for subtraction.

    Returns:
        The normalised squared error as a NumPy scalar.
    """
    diff = target.astype(np.float32) - ref.astype(np.float32)
    squared_error = np.sum(diff ** 2)
    pixel_count = float(target.shape[0] * target.shape[1])
    return squared_error / pixel_count

In [21]:
def forward_pass_module(
    num_output_channels,
    ndim,
    tower_module=gin.REQUIRED
):
  """Returns a closure that builds a tower and runs it on its argument.

  The tower is created inside the returned function (not here), by calling
  `tower_module(num_output_channels, ndim)` each time the closure is invoked.

  Args:
    num_output_channels: forwarded as the first argument of `tower_module`.
    ndim: forwarded as the second argument of `tower_module`.
    tower_module: factory returning a callable that is applied to the inputs.

  Returns:
    A one-argument function mapping `inputs` to the tower's output.
  """
  def forward_pass(inputs):
    tower = tower_module(num_output_channels, ndim)
    return tower(inputs)

  return forward_pass

In [22]:
# Network configuration.
num_output_channels = 2
spatial_size = 17
ndim = 2
input_channels = 2

rng = jax.random.PRNGKey(42)

# Random dummy input: `ndim` spatial axes of length `spatial_size`, followed
# by a trailing channel axis.
input_shape = (spatial_size,) * ndim + (input_channels,)
inputs = jax.random.uniform(rng, input_shape)

# Wrap the tower as a Haiku transform, then strip the rng argument from
# `apply`.
tower_fn = forward_pass_module(
    num_output_channels=num_output_channels,
    ndim=ndim,
    tower_module=towers.forward_tower_factory,
)
forward_pass = hk.without_apply_rng(hk.transform(tower_fn))

In [56]:
# Reference:
# https://coderzcolumn.com/tutorials/artificial-intelligence/haiku-guide-to-create-multi-layer-perceptrons-using-jax

# X_train / Y_train come from the train/test split cell, which must be run
# before this one. They are nested Python lists laid out as
# (sample, component, x, y). The network is fed channels-last arrays (the same
# layout as the dummy `inputs`), so convert once to device arrays and move the
# component axis to the end.
train_inputs = jnp.moveaxis(jnp.asarray(np.asarray(X_train)), 1, -1)
train_targets = jnp.moveaxis(jnp.asarray(np.asarray(Y_train)), 1, -1)


def MeanSquaredErrorLoss(weights, input_data, actual):
    """Mean squared error between the network output and `actual`.

    Args:
        weights: Haiku parameter tree for `forward_pass`.
        input_data: batch of inputs, shape (batch, x, y, channel).
        actual: targets with the same shape as the network output.

    Returns:
        Scalar mean of the element-wise squared difference.
    """
    # `forward_pass` is wrapped in hk.without_apply_rng, so apply() takes only
    # (params, inputs). vmap maps the single-sample network over the leading
    # batch axis while sharing the weights.
    preds = jax.vmap(forward_pass.apply, in_axes=(None, 0))(weights, input_data)
    return jnp.power(actual - preds, 2).mean()


def UpdateWeights(weights,gradients):
    """Plain gradient-descent update for a single parameter leaf."""
    return weights - learning_rate * gradients


batch_size = 1  # not used below: every step runs on the full training set
# Initialise from one (x, y, channel) sample; the batch axis is added by the
# vmap inside the loss.
params = forward_pass.init(rng, train_inputs[0])
epochs = 1000
learning_rate = jnp.array(0.001)


def train_step(params, X_train, Y_train):
    """Run one gradient-descent step and return (updated params, loss)."""
    loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_train, Y_train)
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map.
    return jax.tree_util.tree_map(UpdateWeights, params, param_grads), loss

train_step = jax.jit(train_step)



for i in range(1, epochs+1):
    params,loss = train_step(params, train_inputs, train_targets)

    if i%100 == 0: #every hundred epochs
        print("MSE : {:.2f}".format(loss))

NameError: name 'X_train' is not defined

In [None]:
# # train_utils.loss_and_gradient
# train_utils.train_step(
#     loss_and_grad_fn= train_utils.loss_and_gradient,
#     update_fn =  Callable[[int, ModelGradients, OptimizerState], OptimizerState],
#     get_params_fn = Callable[[OptimizerState], ModelParams]
# )

In [None]:
# mse(inputs,output)

In [None]:

# params = forward_pass.init(rng, inputs)
# output = forward_pass.apply(params, inputs)
# expected_output_shape = inputs.shape[:-1] + (num_output_channels,)
# actual_output_shape = output.shape

# print(expected_output_shape,actual_output_shape)
# mse(inputs,output)

In [None]:
# def loss_fn(trainable_params, non_trainable_params, images, labels):
#   # NOTE: We need to combine trainable and non trainable before calling apply.
#   params = hk.data_structures.merge(trainable_params, non_trainable_params)

#   # NOTE: From here on this is a standard softmax cross entropy loss.
#   logits = f.apply(params, None, images)
#   labels = jax.nn.one_hot(labels, logits.shape[-1])
#   return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]

# def sgd_step(params, grads, *, lr):
#   return jax.tree_util.tree_map(lambda p, g: p - g * lr, params, grads)

# def train_step(trainable_params, non_trainable_params, x, y):
#   # NOTE: We will only compute gradients wrt `trainable_params`.
#   trainable_params_grads = jax.grad(loss_fn)(trainable_params,
#                                              non_trainable_params, x, y)

#   # NOTE: We are only updating `trainable_params`.
#   trainable_params = sgd_step(trainable_params, trainable_params_grads, lr=0.1)
#   return trainable_params

# train_step = jax.jit(train_step)

# for x, y in dataset(batch_size=num_classes, num_records=10000):
#   # NOTE: In our training loop only our trainable parameters are updated.
#   trainable_params = train_step(trainable_params, non_trainable_params, x, y)