<a href="https://colab.research.google.com/github/hucarlos08/Co-Register-HKP-RS/blob/main/Copia_de_Train_UNetSR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab configuration

In [1]:
from google.colab import drive
# Mount Google Drive so the repository checkout and the HDF5 dataset under /content/drive/MyDrive are reachable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Show the starting directory, then switch into the SR3-JAX checkout on Drive —
# presumably so the local `cloudsr` package imported in later cells resolves.
!pwd
%cd '/content/drive/MyDrive/GitHub/SR3-JAX/'

/content
/content/drive/MyDrive/GitHub/SR3-JAX


## Train step

In [3]:
import jax
import jax.numpy as jnp
import jax.nn as nn

from flax.training import train_state
import optax

from typing import Any

class TrainState(train_state.TrainState):
  """flax TrainState extended with the model's non-trainable 'batch_stats' collection."""
  # Holds variables['batch_stats'] from model.init; replaced after every train_step.
  batch_stats: Any


@jax.jit
def train_step(state: TrainState, batch, rng):
  """Run one optimisation step on a single (low-res, high-res) batch.

  The training loss is the pixel-wise mean absolute error (L1); MSE and PSNR
  are computed from the same outputs and reported as extra metrics.

  Args:
    state: current TrainState (params, optimizer state, batch_stats).
    batch: tuple ``(lr_images, hr_images)``; the model maps ``lr_images`` to an
      output that is compared element-wise with ``hr_images``.
    rng: PRNG key; a fresh sub-key is split off for dropout.

  Returns:
    ``(new_state, metrics, new_rng)``. ``metrics`` has the keys ``'L1-loss'``,
    ``'MSE'`` and ``'PSNR'``; ``new_rng`` should be passed to the next call.
  """
  lr_images, hr_images = batch

  # Split *before* use so the key consumed by dropout is never reused later.
  rng, dropout_rng = jax.random.split(rng)

  def loss_fn(params):
    outputs, updates = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        inputs=lr_images, train=True,
        mutable=['batch_stats'], rngs={'dropout': dropout_rng})

    # Pixel-wise mean absolute error (L1) loss.
    loss = jnp.mean(jnp.abs(outputs - hr_images))

    return loss, (outputs, updates)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (outputs, updates)), grads = grad_fn(state.params)

  state = state.apply_gradients(grads=grads)
  # Running statistics come back as a mutable collection; store them on the state.
  state = state.replace(batch_stats=updates['batch_stats'])

  mse = jnp.mean((outputs - hr_images) ** 2)
  # PSNR with a peak signal value of 1.0 — assumes targets are normalised to [0, 1]; TODO confirm.
  psnr = 20.0 * jnp.log10(1.0 / jnp.sqrt(mse))

  metrics = {
    'L1-loss': loss,
    'MSE': mse,
    'PSNR': psnr
  }

  return state, metrics, rng

## Evaluation step

In [4]:
import jax.numpy as jnp

@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate the model on a single (low-res, high-res) batch without updating it.

  The model is run with ``train=False`` using the params and batch_stats
  stored on ``state``.

  Args:
    state: TrainState holding the params and batch_stats to evaluate.
    batch: tuple ``(low_res_images, high_res_images)``.

  Returns:
    ``(state, metrics)``: ``state`` is returned unchanged; ``metrics`` has the
    keys ``'L1-loss'``, ``'MSE'`` and ``'PSNR'``.
  """
  low_res_images, high_res_images = batch

  # Previously the metric code below was indented inside a nested loss_fn after
  # its `return`, so it never ran and this function returned None.
  outputs = state.apply_fn(
      {'params': state.params, 'batch_stats': state.batch_stats},
      inputs=low_res_images, train=False)

  # Pixel-wise mean absolute error (L1).
  loss = jnp.mean(jnp.abs(outputs - high_res_images))
  mse = jnp.mean((outputs - high_res_images) ** 2)
  # PSNR with a peak signal value of 1.0 — assumes targets are normalised to [0, 1]; TODO confirm.
  psnr = 20.0 * jnp.log10(1.0 / jnp.sqrt(mse))

  metrics = {
    'L1-loss': loss,
    'MSE': mse,
    'PSNR': psnr
  }

  return state, metrics



## Training loop

In [5]:
import numpy as np

from cloudsr.utils import ProgressBar

def train(state, dataloader, epochs, bath_size, losses, avg_losses, eval_losses, eval_accuracies,
          num_samples=6000, seed=0):
  """Train ``state`` for ``epochs`` passes over ``dataloader``, printing per-epoch metrics.

  Args:
    state: initial TrainState.
    dataloader: zero-argument callable returning an iterable of
      ``(lr_images, hr_images)`` batches; it is called once per epoch.
    epochs: number of passes over the data.
    bath_size: batch size (misspelt name kept for backward compatibility);
      only used to size the progress bar.
    losses: list; the L1 loss of every training step is appended to it.
    avg_losses: list; the mean L1 loss of each epoch is appended to it.
    eval_losses: currently unused; accepted for interface compatibility.
    eval_accuracies: currently unused; accepted for interface compatibility.
    num_samples: approximate number of training samples per epoch, used only
      for the progress-bar length (previously hard-coded to 6000).
    seed: seed for the PRNG key that drives dropout (previously hard-coded to 0).

  Returns:
    The trained TrainState.

  Raises:
    ValueError: if ``dataloader()`` yields no batches in an epoch.
  """
  progress = ProgressBar(num_samples // bath_size)

  rng = jax.random.PRNGKey(seed)

  for epoch in range(epochs):
    # Per-epoch accumulators, so averages cover exactly this epoch's steps.
    epoch_losses = []
    epoch_mse = []
    epoch_psnr = []

    # run an epoch of training
    for step, batch in enumerate(dataloader()):
      progress.step(reset=(step == 0))

      state, metrics_train, rng = train_step(state, batch, rng)

      losses.append(metrics_train['L1-loss'])
      epoch_losses.append(metrics_train['L1-loss'])
      epoch_mse.append(metrics_train['MSE'])
      epoch_psnr.append(metrics_train['PSNR'])

    if not epoch_losses:
      raise ValueError(f"dataloader() yielded no batches in epoch {epoch}")

    avg_loss = np.mean(epoch_losses)
    avg_mse = np.mean(epoch_mse)
    avg_psnr = np.mean(epoch_psnr)

    avg_losses.append(avg_loss)

    print("Epoch", epoch, "train loss:", avg_loss, "MSE", avg_mse, "PSNR:", avg_psnr)

  return state

## Training configuration

In [6]:
import flax
import optax

# --- Training hyper-parameters ---------------------------------------------
EPOCHS = 15                    # full passes over the training data
BATCH_SIZE = 32                # samples per optimisation step
FILTERS = 16                   # forwarded to UNetSRJAX(filters=...)
LEARNING_RATE = 1e-2           # initial value of the exponential-decay schedule
LEARNING_RATE_EXP_DECAY = 0.6  # multiplicative decay factor of that schedule
EVAL_INTERVAL = 3              # NOTE(review): not referenced by any visible cell

# HDF5 file read by load_lidar_dataset_from_hdf5 (presumably paired low-/high-resolution LiDAR samples).
HDF5_FILE = '/content/drive/MyDrive/Data/Durlar/Medium/Durlar_lr_hi_resolution_dataset_M.h5'

## Dataset

In [7]:
from cloudsr.lidar_data_io import load_lidar_dataset_from_hdf5

# Create the JAX dataloader.
# NOTE(review): train() calls `dataloader()` once per epoch, so this is presumably a
# zero-argument callable returning a fresh (lr_images, hr_images) batch iterator — confirm.
dataloader = load_lidar_dataset_from_hdf5(HDF5_FILE, batch_size=BATCH_SIZE, shuffle=True)

## Define the model

In [8]:
from cloudsr.models.UnetSR import UNetSRJAX

# `dataloader` comes from the Dataset cell; it is not rebuilt here (the duplicate
# re-read the HDF5 file for no benefit).

# Seed for model initialisation.
rng = jax.random.PRNGKey(0)

# Create the U-Net model instance
model = UNetSRJAX(filters=FILTERS, dropout_rate=0.25, act_func=nn.relu, kernel_init=nn.initializers.he_normal())

# Initialise parameters and batch statistics with a dummy batch of one 64x1024 single-channel image.
rng, init_rng = jax.random.split(rng)
dummy_input = jnp.ones((1, 64, 1024, 1), dtype=jnp.float32)
variables = model.init({'params': init_rng, 'batch_stats': init_rng}, dummy_input, train=False)

params = variables['params']
batch_stats = variables['batch_stats']

# Learning-rate schedule: multiply by LEARNING_RATE_EXP_DECAY once per epoch.
# `transition_steps` counts optimiser steps, so it must be the number of steps per
# epoch. Passing BATCH_SIZE decayed the rate every 32 steps, i.e. ~0.6**87 over 15 epochs.
# NOTE: assumes ~6000 training samples per epoch (the figure train() uses for its
# progress bar) — TODO confirm against the dataset.
TRAIN_SAMPLES = 6000
steps_per_epoch = max(1, TRAIN_SAMPLES // BATCH_SIZE)
jlr_decay = optax.exponential_decay(
    LEARNING_RATE,
    transition_steps=steps_per_epoch,
    decay_rate=LEARNING_RATE_EXP_DECAY,
    staircase=True,
)

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    batch_stats=batch_stats,
    tx=optax.adam(learning_rate=jlr_decay),
)

In [9]:
# Containers that train() appends to in place: per-step L1 losses and per-epoch averages.
# The eval_* lists are passed through but are not filled by train().
losses, avg_losses = [], []
eval_losses, eval_accuracies = [], []

state = train(
    state,
    dataloader,
    EPOCHS,
    BATCH_SIZE,
    losses,
    avg_losses,
    eval_losses,
    eval_accuracies,
)


0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                                              100%

0%                                                                               

KeyboardInterrupt: ignored