In [None]:
import jax
import jax.numpy as jnp
import optax

from dm_pix import ssim
from dm_pix import psnr
from flax.training import train_state

from dln.data import get_Low_light_training_set
from dln.jax_dln import DLN
from dln.jax_data_loader import jnp_data_loader
from dln.jax_tv import total_variation

In [None]:
@jax.jit
def apply_model(state, ll, nl):
    """Evaluate the model on one batch and return gradients plus metrics.

    The objective is the batch mean of ``(1 - SSIM(nl, prediction))`` plus
    ``0.001`` times the total variation of the prediction.

    Args:
        state: Train state exposing ``apply_fn`` and ``params``.
        ll: Network input batch (presumably the low-light images -- confirm
            against the data loader).
        nl: Reference batch the prediction is compared against (presumably
            the normal-light targets -- confirm against the data loader).

    Returns:
        Tuple ``(grads, loss, mean_psnr)``: gradients of the objective with
        respect to ``state.params``, the scalar objective value, and the mean
        PSNR between ``nl`` and the prediction.
    """

    def _objective(weights):
        prediction = state.apply_fn({"params": weights}, ll)
        dissimilarity = 1 - ssim(nl, prediction)
        smoothness = total_variation(prediction)
        combined = dissimilarity + 0.001 * smoothness
        # The prediction rides along as auxiliary output so PSNR can be
        # computed below without a second forward pass.
        return jnp.mean(combined), prediction

    objective_and_grad = jax.value_and_grad(_objective, has_aux=True)
    (loss_value, prediction), gradients = objective_and_grad(state.params)
    return gradients, loss_value, jnp.mean(psnr(nl, prediction))


@jax.jit
def update_model(state, grads):
    """Return a new train state with ``grads`` applied.

    Args:
        state: Train state exposing ``apply_gradients``.
        grads: Gradients passed through as the ``grads`` keyword.

    Returns:
        Whatever ``state.apply_gradients(grads=grads)`` returns.
    """
    next_state = state.apply_gradients(grads=grads)
    return next_state


def train_epoch(state, data_loader):
    """Run one pass over ``data_loader``, updating ``state`` after each batch.

    Args:
        state: Train state accepted by ``apply_model`` and ``update_model``.
        data_loader: Iterable yielding ``(ll, nl)`` batch pairs.

    Returns:
        Tuple ``(state, mean_loss, mean_psnr)`` where the means are taken over
        the per-batch loss and PSNR values. Both means are NaN when
        ``data_loader`` yields no batches.
    """
    epoch_loss = []
    epoch_psnr = []
    for ll, nl in data_loader:
        grads, loss, res_psnr = apply_model(state, ll, nl)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_psnr.append(res_psnr)
    # jax.numpy reductions raise TypeError on Python lists, so the per-batch
    # scalars are packed into arrays explicitly before averaging.
    return (
        state,
        jnp.mean(jnp.asarray(epoch_loss)),
        jnp.mean(jnp.asarray(epoch_psnr)),
    )


def create_train_state(rng, model, lr):
    """Initialise model parameters and wrap them in a ``TrainState``.

    Args:
        rng: PRNG key handed to ``model.init``.
        model: Module exposing ``init`` and ``apply``.
        lr: Learning rate for the Adam optimiser.

    Returns:
        A ``train_state.TrainState`` holding ``model.apply``, the initial
        ``"params"`` collection and an Adam optimiser.
    """
    # Parameters are initialised by tracing the model on an all-ones array
    # of shape (1, 256, 256, 3).
    dummy_input = jnp.ones((1, 256, 256, 3))
    variables = model.init(rng, dummy_input)
    optimizer = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )


# Seed the PRNG and split off a dedicated key for parameter initialisation.
rng, init_rng = jax.random.split(jax.random.key(0))

# Build the network and its initial optimiser/parameter state.
model = DLN(64)
state = create_train_state(init_rng, model, 1e-3)

# Training data, served in batches of four.
train_set = get_Low_light_training_set(
    upscale_factor=1,
    patch_size=128,
    data_augmentation=True,
)
data_loader = jnp_data_loader(train_set, batch_size=4)

for epoch in range(100):
    state, loss, res_psnr = train_epoch(state, data_loader)
    print(f"Epoch: {epoch}, Loss: {loss}, PSNR: {res_psnr}")
    # NOTE(review): this unconditional break ends training after the first
    # epoch even though the loop is written for 100 -- looks like a smoke-test
    # leftover; confirm before relying on this cell for a full run.
    break