In [1]:
import numpy as np
import jax
import torch
import jax.numpy as jnp

from dln.data import get_Low_light_training_set

In [2]:
# Paired low-light / normal-light training data from the project's dln.data module.
# NOTE(review): patch_size=128 presumably sets the crop size of each sample and
# upscale_factor=1 disables resizing — confirm against get_Low_light_training_set.
train_set = get_Low_light_training_set(
    upscale_factor=1,
    patch_size=128,
    data_augmentation=True,
)

In [3]:
def numpy_data_loader(dataset, batch_size, shuffle=True, drop_last=True, rng=None):
    """Yield channels-last JAX batches of (low_light, normal_light) pairs.

    Each ``dataset[i]`` must unpack into exactly two items (low-light input,
    normal-light target). The items are assumed to be channel-first (C, H, W)
    arrays/tensors — TODO confirm against ``get_Low_light_training_set``. Each
    batch is stacked and transposed with axes (0, 2, 3, 1) to give (N, H, W, C),
    the layout flax convolutions expect.

    Args:
        dataset: object supporting ``len()`` and integer indexing.
        batch_size: number of samples per batch; must be positive.
        shuffle: if True, visit samples in a random order.
        drop_last: if True (default, matching the original behaviour) a trailing
            partial batch is discarded; if False it is yielded with fewer than
            ``batch_size`` samples.
        rng: optional ``np.random.Generator`` used for shuffling, for
            reproducible epochs. When None, the global ``np.random`` state is used.

    Yields:
        ``(low_light, normal_light)`` tuples of ``jnp`` arrays, each transposed
        to channels-last.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised on first ``next``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(dataset)
    indices = np.arange(num_samples)
    if shuffle:
        if rng is None:
            np.random.shuffle(indices)
        else:
            rng.shuffle(indices)

    # With drop_last, stop before any start index that cannot fill a full batch.
    stop = num_samples - batch_size + 1 if drop_last else num_samples
    for start_idx in range(0, stop, batch_size):
        excerpt = indices[start_idx : start_idx + batch_size]
        batch = [dataset[i] for i in excerpt]
        low_light, normal_light = zip(*batch)
        jaxed_low_light = jnp.array(low_light)
        jaxed_normal_light = jnp.array(normal_light)
        # NCHW -> NHWC
        yield (
            jnp.transpose(jaxed_low_light, (0, 2, 3, 1)),
            jnp.transpose(jaxed_normal_light, (0, 2, 3, 1)),
        )

In [4]:
from flax import linen as nn

class ConvBlock(nn.Module):
    """2-D convolution, optional BatchNorm, then a PReLU activation.

    Fields are flax dataclass attributes, so they are given as keywords at
    construction, e.g. ``ConvBlock(output_size=64, kernel_size=3, stride=1, padding=1)``.
    """

    output_size: int  # number of output feature channels
    kernel_size: tuple  # (kh, kw); a bare int k is expanded to (k, k)
    stride: tuple  # (sh, sw); a bare int s is expanded to (s, s)
    padding: str  # forwarded to nn.Conv: 'SAME', 'VALID', an int, or per-dim (low, high) pairs
    use_bias: bool = True
    use_bn: bool = False

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Apply conv -> (BatchNorm) -> PReLU.

        Args:
            x: input batch, assumed channels-last (N, H, W, C) as flax ``nn.Conv``
                expects — TODO confirm for new callers.
            training: only used when ``use_bn`` is True. True normalises with
                batch statistics, which per the flax BatchNorm docs needs
                ``mutable=['batch_stats']`` in ``apply``. False uses the running
                averages.

        Returns:
            The activated feature map with ``output_size`` channels.
        """
        # flax nn.Conv interprets an int kernel_size as a 1-tuple, i.e. a 1-D
        # convolution over the last spatial axis only. Expand ints so that
        # kernel_size=3 means a 3x3 kernel, as in the PyTorch original.
        kernel_size = self.kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        stride = self.stride
        if isinstance(stride, int):
            stride = (stride, stride)

        x = nn.Conv(
            features=self.output_size,
            kernel_size=kernel_size,
            strides=stride,
            padding=self.padding,
            use_bias=self.use_bias,
        )(x)
        if self.use_bn:
            x = nn.BatchNorm(use_running_average=not training)(x)
        # NOTE(review): flax nn.PReLU starts its slope at 0.01 by default, while
        # torch.nn.PReLU starts at 0.25 — confirm which the port intends.
        x = nn.PReLU()(x)
        return x


In [5]:
class PreDLN(nn.Module):
    """Input stem: append a per-pixel max channel, then apply two 3x3 ConvBlocks.

    Attributes:
        dim: base width. The first block emits ``2 * dim`` channels and the
            second emits ``dim``.
    """

    dim: int

    @nn.compact
    def __call__(self, inputs, training: bool = True):
        """Map a channels-last image batch to ``dim``-channel features.

        Args:
            inputs: (N, H, W, C) batch. Values are presumably in [0, 1], which
                the first line maps to [-1, 1] — TODO confirm against the data
                pipeline.
            training: forwarded to each ConvBlock; it only matters if BatchNorm
                is enabled there.

        Returns:
            Feature map of shape (N, H, W, dim). Spatial size is preserved
            because each conv is 3x3 with stride 1 and 1 pixel of padding.
        """
        x = (inputs - 0.5) * 2
        # Max over the channel axis, appended as an extra input channel.
        x_bright = jnp.max(x, axis=3, keepdims=True)
        x_in = jnp.concatenate((x, x_bright), axis=3)
        # Explicit (3, 3) kernels: with a bare int kernel_size, flax builds a
        # 1-D (width-only) convolution.
        feat1 = ConvBlock(
            output_size=2 * self.dim, kernel_size=(3, 3), stride=(1, 1), padding=1
        )(x_in, training=training)
        feat2 = ConvBlock(
            output_size=self.dim, kernel_size=(3, 3), stride=(1, 1), padding=1
        )(feat1, training=training)
        return feat2

In [10]:
# Build the stem and initialise its parameters from a dummy channels-last batch
# (32 images of 128x128 with 3 channels). Parameter shapes depend only on this
# input's shape, so ones are enough.
two_blocks = PreDLN(dim=64)
sample_input = jnp.ones(shape=(32, 128, 128, 3))
variables = two_blocks.init(jax.random.PRNGKey(0), sample_input)

In [11]:
# Smoke test: run only the first training batch through the freshly
# initialised stem and report the shapes (the loop breaks after one pass).
for iteration, batch in enumerate(numpy_data_loader(train_set, 32), start=1):
    LL_t, NL_t = batch
    print("iteration", iteration)
    print("Image Shapes")
    print(LL_t.shape, NL_t.shape)
    print("ConvBlock")
    output = two_blocks.apply(variables, LL_t, training=True)
    print(output.shape)
    break


iteration 1
Image Shapes
(32, 128, 128, 3) (32, 128, 128, 3)
ConvBlock
(32, 128, 128, 64)
