In [1]:
import jax
import jax.numpy as jnp

from dm_pix import ssim
from dm_pix import psnr
from flax.training import train_state

from dln.jax_dln import DLN
from dln.jax_data_loader import jnp_data_loader
from dln.jax_tv import total_variation

In [2]:
@jax.jit
def apply_model(state, ll, nl, tv_weight=0.001):
    """Compute loss gradients and metrics for one batch without updating ``state``.

    The loss is ``(1 - mean(SSIM(nl, pred))) + tv_weight * mean(TV(pred))``,
    where ``pred = state.apply_fn({"params": params}, ll)``.

    Args:
        state: Training state providing ``apply_fn`` and ``params``
            (presumably a ``flax.training.train_state.TrainState`` — confirm).
        ll: Model input batch (presumably low-light images — TODO confirm).
        nl: Target batch compared against the prediction (presumably
            normal-light images — TODO confirm); must match the prediction's shape.
        tv_weight: Weight of the total-variation regulariser. Defaults to
            0.001, the value that was previously hard-coded here.

    Returns:
        ``(grads, loss, psnr)``: gradients w.r.t. ``state.params``, the scalar
        loss, and the PSNR between ``nl`` and the prediction averaged over
        the batch.
    """

    def loss_fn(params):
        nl_pred = state.apply_fn({"params": params}, ll)
        # dm_pix ssim yields one value per image for batched inputs, but
        # jax.value_and_grad only differentiates scalar outputs: reduce first.
        ssim_loss = 1.0 - jnp.mean(ssim(nl, nl_pred))
        # jnp.mean is a no-op on a scalar and reduces a per-image result.
        tv_loss = jnp.mean(total_variation(nl_pred))
        loss = ssim_loss + tv_weight * tv_loss
        # The prediction is returned as aux data so PSNR can reuse it below.
        return loss, nl_pred

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, nl_pred), grads = grad_fn(state.params)
    # psnr is also per-image for batched inputs; report the batch mean.
    res_psnr = jnp.mean(psnr(nl, nl_pred))
    return grads, loss, res_psnr


@jax.jit
def update_model(state, grads):
    """Take one optimizer step.

    Args:
        state: Object exposing ``apply_gradients(grads=...)`` (presumably a
            ``flax.training.train_state.TrainState`` — confirm).
        grads: Gradients matching ``state.params``, e.g. from ``apply_model``.

    Returns:
        The state produced by ``state.apply_gradients``.
    """
    stepped_state = state.apply_gradients(grads=grads)
    return stepped_state

In [3]:
def train_epoch(state, data_loader):
    """Run one pass over ``data_loader``, stepping ``state`` after every batch.

    Args:
        state: Training state accepted by ``apply_model``/``update_model``
            (presumably a ``flax.training.train_state.TrainState`` — confirm).
        data_loader: Iterable yielding ``(ll, nl)`` batch pairs.

    Returns:
        ``(state, loss, psnr)``: the updated state, plus the loss and PSNR
        averaged over all batches of the epoch (an unweighted mean of the
        per-batch values).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    batch_losses = []
    batch_psnrs = []
    for ll, nl in data_loader:
        grads, loss, res_psnr = apply_model(state, ll, nl)
        state = update_model(state, grads)
        # Reduce to scalars so batches of different sizes can be stacked.
        batch_losses.append(jnp.mean(loss))
        batch_psnrs.append(jnp.mean(res_psnr))

    if not batch_losses:
        # Otherwise the metrics below would be undefined (previously an
        # UnboundLocalError on the return statement).
        raise ValueError("train_epoch: data_loader yielded no batches")

    # Report epoch-level metrics rather than only the final batch's values.
    epoch_loss = jnp.mean(jnp.stack(batch_losses))
    epoch_psnr = jnp.mean(jnp.stack(batch_psnrs))
    return state, epoch_loss, epoch_psnr

In [None]:
# NOTE(review): this cell only evaluates the bare name `train_epoch`, which
# displays the function object's repr and runs no training. It looks like an
# unfinished placeholder — TODO: build a training state and a data loader and
# call `train_epoch(state, data_loader)` here, or remove this cell.
train_epoch