## Building a superresolution network

In [1]:
#imports
import jax
import jax.numpy as jnp
from jax import value_and_grad
from jax_cfd.ml import towers
import haiku as hk
import gin
import numpy as np
import jax_cfd.ml.train_utils as train_utils

In [2]:
#reference:
# https://goodboychan.github.io/python/deep_learning/vision/tensorflow-keras/2020/10/13/01-Super-Resolution-CNN.html#Build-SR-CNN-Model

def mse(target, ref):
    target_data = target.astype(np.float32)
    ref_data = ref.astype(np.float32)
    err = np.sum((target_data - ref_data) ** 2)
    
    err /= float(target_data.shape[0] * target_data.shape[1])
    return err

In [3]:

def forward_pass_module(
    num_output_channels,
    ndim,
    tower_module=gin.REQUIRED
):
  """Constructs a function that initializes tower and applies it to inputs."""
  def forward_pass(inputs):
    return tower_module(num_output_channels, ndim)(inputs)

  return forward_pass

f = forward_pass_module(num_output_channels = 2,
                        ndim = 2,
                        tower_module = towers.forward_tower_factory)



In [4]:
# x = np.random.rand(8,28,28)
# # x = jnp.ones([8,28,28]) # simulates 8 28x28 images
# print(jnp.shape(x))
# print(hk.experimental.tabulate(f)(x))

In [13]:
num_output_channels = 2
spatial_size = 17
ndim = 2
input_channels = 2

rng = jax.random.PRNGKey(42)
inputs = jax.random.uniform(rng, (spatial_size,) * ndim + (input_channels,))

forward_pass = hk.without_apply_rng(
                    hk.transform(
                        forward_pass_module(num_output_channels = num_output_channels, 
                                            ndim = ndim,
                                           tower_module = towers.forward_tower_factory)))
params = forward_pass.init(rng, inputs)
output = forward_pass.apply(params, inputs)
expected_output_shape = inputs.shape[:-1] + (num_output_channels,)
actual_output_shape = output.shape

In [14]:
print(expected_output_shape,actual_output_shape)

(17, 17, 2) (17, 17, 2)


In [6]:
mse(inputs,output)

DeviceArray(2.2741005, dtype=float32)

In [7]:
def loss_fn(trainable_params, non_trainable_params, images, labels):
  # NOTE: We need to combine trainable and non trainable before calling apply.
  params = hk.data_structures.merge(trainable_params, non_trainable_params)

  # NOTE: From here on this is a standard softmax cross entropy loss.
  logits = f.apply(params, None, images)
  labels = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]

def sgd_step(params, grads, *, lr):
  return jax.tree_util.tree_map(lambda p, g: p - g * lr, params, grads)

def train_step(trainable_params, non_trainable_params, x, y):
  # NOTE: We will only compute gradients wrt `trainable_params`.
  trainable_params_grads = jax.grad(loss_fn)(trainable_params,
                                             non_trainable_params, x, y)

  # NOTE: We are only updating `trainable_params`.
  trainable_params = sgd_step(trainable_params, trainable_params_grads, lr=0.1)
  return trainable_params

train_step = jax.jit(train_step)

for x, y in dataset(batch_size=num_classes, num_records=10000):
  # NOTE: In our training loop only our trainable parameters are updated.
  trainable_params = train_step(trainable_params, non_trainable_params, x, y)

NameError: name 'dataset' is not defined

In [15]:
# Reference:
# https://coderzcolumn.com/tutorials/artificial-intelligence/haiku-guide-to-create-multi-layer-perceptrons-using-jax

# define X_train and Y_train

def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, rng, input_data)
    preds = preds.squeeze()
    return jnp.power(actual - preds, 2).mean()

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients


batch_size = 8
params = forward_pass.init(rng, X_train[:batch_size])
epochs = 1000
learning_rate = jnp.array(0.001)


def train_step(params, X_train, Y_train):
    loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_train, Y_train)
    return jax.tree_map(UpdateWeights, params, param_grads), loss

train_step = jax.jit(train_step)



for i in range(1, epochs+1):
    params,loss = train_step(params, X_train, Y_train)

    if i%100 == 0: #every hundred epochs
        print("MSE : {:.2f}".format(loss))

NameError: name 'X_train' is not defined

In [9]:
# train_utils.loss_and_gradient
train_utils.train_step(
    loss_and_grad_fn= train_utils.loss_and_gradient,
    update_fn =  Callable[[int, ModelGradients, OptimizerState], OptimizerState],
    get_params_fn = Callable[[OptimizerState], ModelParams]
)

NameError: name 'Callable' is not defined