## Building a superresolution network

In [1]:
#imports
import jax
import jax.numpy as jnp
from jax import value_and_grad
from jax_cfd.ml import towers
import haiku as hk
import gin
import numpy as np
import jax_cfd.ml.train_utils as train_utils
import xarray 
import random

In [2]:
# import data
file_name = '256x64_inner_50_outer_1000'
data = xarray.open_dataset(f'../creating_dataset/datasets/'+ file_name +'.nc', chunks={'time': '100MB'})


In [3]:
# split by timestamps
x_shape = len(data.x)
y_shape = len(data.y)
high_def = []
for i in range(len(data.time)):
    this_time_u = np.array([data.u.isel(time = i)]).reshape(x_shape, y_shape)
    this_time_v = np.array([data.v.isel(time = i)]).reshape(x_shape, y_shape)
    this_time = [this_time_u, this_time_v]
    high_def.append(this_time)

In [4]:
#warm up time (may want to discard initial stages of simulation since not really representative of turbulent flow?)
dt = float(data.time[0].values)

outer_steps = len(data.time.values)

inner_steps = (data.time[1].values-data.time[0].values)/dt

total_sim_time = outer_steps*inner_steps*dt
print("dt: \t\t" + str(dt))
print("outer_steps: \t" + str(outer_steps))
print("inner_steps: \t" + str(inner_steps))
print("total_sim_time: " + str(total_sim_time))

warm_up = 30 #seconds
warm_index = int(warm_up/total_sim_time * outer_steps // 1)
print("removed points: " + str(warm_index))
high_def = high_def[warm_index:]

dt: 		0.78125
outer_steps: 	1000
inner_steps: 	1.0
total_sim_time: 781.25
removed points: 38


In [5]:
np.max(high_def)

1.74617

In [6]:
originalMax = np.max(high_def)
originalMin = (np.min(high_def))

In [7]:
print(originalMin,originalMax)

-0.78758466 1.74617


In [8]:
# scale velocities to beetween zero and one
def scale(input, min=0, max=1):
    
    if min>max:
        raise(ValueError("Min and max may be the wrong way around"))
    
    if min<0:
        og_min = min
        max += -min
        min = 0

    
    input += (min-np.min(input))
    input /= (np.max(input)/max)
    
    try:
        input += og_min
    except:
        pass
    return input


# this function is useless, it does the same as scale()
# def scaleBack(input,originalMin,originalMax):
#     inputMin = np.min(input)
#     inputMax = np.max(input)
#     inputRange = inputMax-inputMin
    
#     originalRange = originalMax-originalMin
    
#     input += inputMin #sets min to zero
#     input *= originalRange/inputRange
#     input -= originalMin
    
#     return input


def scaleAllVelocities(high_def,allVels=False):
    
    scaled = []
    for vels in high_def:
        both_vels = []
        for vel in vels:
            vel = scale(vel,min=0,max=1)
            both_vels.append(vel)
        scaled.append(both_vels)
    return scaled
        

# scaled_high_def = scaleAllVelocities(high_def,allVels=True) #scales each frame from 0 to 1
scaled_high_def = scale(high_def) #scales all timesteps to same max and min

In [9]:
#normalisation function

In [10]:
def increaseSize(input, factor):
    w,h = np.shape(input)
    output = np.zeros((w*factor,h*factor))
    
    for width in range(w*factor):
        for height in range(h*factor):
            output[width][height] = input[width//factor][height//factor]
    return output


def decreaseSize(input,factor):
    w,h = np.shape(input)
    if w%factor != 0 or h%factor != 0:
        raise(AssertionError("Non-compatible input shape and downsample factor"))
    
    output = np.zeros((int(w/factor),int(h/factor)))
    
    for width in range(w):
        for height in range(h):
            output[width//factor][height//factor] += input[width][height]
    output /= factor**len(np.shape(output))
    return output

def downsampleHighDefVels(high_def,factor):
    low_def = []
    for vels in high_def:
        both_vels = []
        for vel in vels:
            vel = decreaseSize(vel,factor)

            vel = increaseSize(vel,factor)
            both_vels.append(vel)
        low_def.append(both_vels)
    return low_def

In [11]:
#converting lists of images to DeviceArray with correct shape

def convertListToDeviceArray(list):
    new_list = []
    for i in range(len(list)):

        new_list.append(
            jnp.dstack([
                jnp.array(list[i][0]),
                jnp.array(list[i][1])]
            )
        )
    return jnp.array(new_list)



In [12]:
#split into train and test

split = 0.8
split = int(len(scaled_high_def)*split//1)
random.shuffle(high_def)

factor = 2



%time scaled_low_def = scale(downsampleHighDefVels(scaled_high_def,factor))
scaled_high_def = convertListToDeviceArray(scaled_high_def)
scaled_low_def = convertListToDeviceArray(scaled_low_def)

X_train = scaled_low_def[:split]
Y_train = scaled_high_def[:split]

X_test = scaled_low_def[split:]
Y_test = scaled_high_def[split:]



CPU times: user 33.7 s, sys: 166 ms, total: 33.9 s
Wall time: 34 s


In [13]:
print(np.shape(scaled_high_def))
print(np.shape(scaled_low_def))

(962, 256, 64, 2)
(962, 256, 64, 2)


In [14]:
#reference:
# https://goodboychan.github.io/python/deep_learning/vision/tensorflow-keras/2020/10/13/01-Super-Resolution-CNN.html#Build-SR-CNN-Model

def mse(target, ref):
    target_data = target.astype(np.float32)
    ref_data = ref.astype(np.float32)
    err = np.sum((target_data - ref_data) ** 2)
    
    err /= float(target_data.shape[0] * target_data.shape[1])
    return err

In [15]:
def forward_pass_module(
    num_output_channels,
    ndim,
    tower_module=gin.REQUIRED
):
  """Constructs a function that initializes tower and applies it to inputs."""
  def forward_pass(inputs):
    return tower_module(num_output_channels, ndim)(inputs)

  return forward_pass

In [16]:
num_output_channels = 2
# spatial_size = 17
ndim = 2
input_channels = 2

rng_key = jax.random.PRNGKey(42)
# inputs = jax.random.uniform(rng, (spatial_size,) * ndim + (input_channels,))

forward_pass = hk.without_apply_rng(hk.transform(forward_pass_module(num_output_channels = num_output_channels, 
                                            ndim = ndim,
                                           tower_module = towers.forward_tower_factory)))

In [17]:
diego = [1,1,2,3]
print(type(diego))
diego = np.array(diego)
print(type(diego))

<class 'list'>
<class 'numpy.ndarray'>


In [None]:
# Reference:
# https://coderzcolumn.com/tutorials/artificial-intelligence/haiku-guide-to-create-multi-layer-perceptrons-using-jax

# define X_train and Y_train

def MeanSquaredErrorLoss(params, input_data, actual):
#     print(np.shape(input_data[0]))
#     params = forward_pass.init(rng, inputs)
    preds = []
    truth = []
    for i in range(len(input_data)):
        preds.append(forward_pass.apply(params, input_data[i]))
        truth.append(actual[i])
    
    
#     preds = forward_pass.apply(params, x=input_data, rng = None)
#     preds = preds.squeeze()
    return jnp.power(jnp.array(truth) - jnp.array(preds), 2).mean()

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients

# X_train[:batch_size]

sample_x = jax.random.uniform(rng_key, (256,64,input_channels))
batch_size = 1
params = forward_pass.init(rng_key, sample_x)
epochs = 100
learning_rate = jnp.array(0.1)


def train_step(params, X_train, Y_train):
    loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_train, Y_train)
    return jax.tree_map(UpdateWeights, params, param_grads), loss

# train_step = jax.jit(train_step)

print("got this far")

losses = []
for i in range(1, epochs+1):
    params,loss = train_step(params, X_train, Y_train)

    if i%1 == 0: #every 5 epochs
        print("Epoch {:.0f}/{:.0f}".format(i,epochs))
        print("\tMSE : {:.6f}".format(loss))
        losses.append(loss)

got this far
Epoch 1/100
	MSE : 0.161420
Epoch 2/100
	MSE : 0.084578
Epoch 3/100
	MSE : 0.024321
Epoch 4/100
	MSE : 0.006604
Epoch 5/100
	MSE : 0.006146
Epoch 6/100
	MSE : 0.005805
Epoch 7/100
	MSE : 0.005492
Epoch 8/100
	MSE : 0.005200


In [None]:
num = 100

input_data = X_train[:num]
actual = Y_train[:num]

preds = []
truth = []
for i in range(len(input_data)):
    preds.append(forward_pass.apply(params, input_data[i]))
    truth.append(actual[i])

In [None]:
jnp.power(np.array(truth) - np.array(preds), 2).mean()

In [30]:
preds = forward_pass.apply(params,scaled_low_def[0])

In [24]:
print(np.shape(preds))
np.shape(scaled_high_def[0])

(256, 64, 2)


(256, 64, 2)

In [None]:
sample_x = jax.random.uniform(rng, (256,64,input_channels))
np.shape(sample_x)

In [170]:
np.shape(preds.squeeze())

(256, 64, 2)

In [32]:
print(np.shape(preds))
print(np.shape(Y_train[0]))

(256, 64, 2)
(256, 64, 2)


In [33]:
jnp.power(Y_train[0] - preds, 2).mean()

DeviceArray(0.09285066, dtype=float32)

In [None]:
# # train_utils.loss_and_gradient
# train_utils.train_step(
#     loss_and_grad_fn= train_utils.loss_and_gradient,
#     update_fn =  Callable[[int, ModelGradients, OptimizerState], OptimizerState],
#     get_params_fn = Callable[[OptimizerState], ModelParams]
# )

In [None]:
# mse(inputs,output)

In [None]:

# params = forward_pass.init(rng, inputs)
# output = forward_pass.apply(params, inputs)
# expected_output_shape = inputs.shape[:-1] + (num_output_channels,)
# actual_output_shape = output.shape

# print(expected_output_shape,actual_output_shape)
# mse(inputs,output)

In [None]:
# def loss_fn(trainable_params, non_trainable_params, images, labels):
#   # NOTE: We need to combine trainable and non trainable before calling apply.
#   params = hk.data_structures.merge(trainable_params, non_trainable_params)

#   # NOTE: From here on this is a standard softmax cross entropy loss.
#   logits = f.apply(params, None, images)
#   labels = jax.nn.one_hot(labels, logits.shape[-1])
#   return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]

# def sgd_step(params, grads, *, lr):
#   return jax.tree_util.tree_map(lambda p, g: p - g * lr, params, grads)

# def train_step(trainable_params, non_trainable_params, x, y):
#   # NOTE: We will only compute gradients wrt `trainable_params`.
#   trainable_params_grads = jax.grad(loss_fn)(trainable_params,
#                                              non_trainable_params, x, y)

#   # NOTE: We are only updating `trainable_params`.
#   trainable_params = sgd_step(trainable_params, trainable_params_grads, lr=0.1)
#   return trainable_params

# train_step = jax.jit(train_step)

# for x, y in dataset(batch_size=num_classes, num_records=10000):
#   # NOTE: In our training loop only our trainable parameters are updated.
#   trainable_params = train_step(trainable_params, non_trainable_params, x, y)