# Layers

> normalizing flow layers


In [1]:
#| default_exp layers

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
from fastai.imports import *

In [4]:
 #| export

import os
from importlib import import_module

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

In [5]:
 #| export
 
flow_layer_class_dict = {}

def regist_layer(layer_class):
    layer_name = layer_class.__name__.lower()
    assert not layer_name in flow_layer_class_dict, 'there is already registered layer: %s in flow_layer_class_dict.' % layer_name
    flow_layer_class_dict[layer_name] = layer_class
    return layer_class



In [6]:
def get_flow_layer(layer_name:str):
    layer_name = layer_name.lower()
    return flow_layer_class_dict[layer_name]


## Normalizing Flows


### Dequantization

#### Uniform Dequantization

In [7]:
#| export

@regist_layer
class UniformDequantization(nn.Module):
    def __init__(self, alpha=1e-5, num_bits=8, device='cpu', name='uniform_dequantization'):
        """
        Uniform dequantization layer for flows.
        
        Args:
            alpha (float): Small constant used to scale the input to avoid boundary values (default: 1e-5).
            num_bits (int): Number of bits used for quantization (default: 8).
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'uniform_dequantization').
        """
        super(UniformDequantization, self).__init__()
        self.alpha = alpha
        self.num_bits = num_bits
        self.quantization_bins = 2 ** num_bits
        # Precompute the log-determinant of the Jacobian per dimension
        self.register_buffer(
            'ldj_per_dim',
            - num_bits * torch.log(torch.tensor(2.0, device=device, dtype=torch.float))
        )
        self.name = name

    def _ldj(self, shape):
        """
        Computes the log-determinant of the Jacobian for a given shape.

        Args:
            shape (torch.Size): Shape of the input tensor.

        Returns:
            torch.Tensor: Log-determinant of the Jacobian repeated for the batch size.
        """
        batch_size = shape[0]
        num_dims = shape[1:].numel()
        ldj = self.ldj_per_dim * num_dims
        return ldj.repeat(batch_size)

    def _inverse(self, z, **kwargs):
        """
        Applies the inverse dequantization transformation to the input.

        Args:
            z (torch.Tensor): Input tensor to transform.

        Returns:
            torch.Tensor: Dequantized tensor.
        """
        # Apply the inverse sigmoid transformation to z
        z = self._sigmoid_inverse(z)
        # Quantize the values to integer bins
        z = (self.quantization_bins * z).floor().clamp(min=0, max=self.quantization_bins - 1)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Applies the forward dequantization transformation and computes the log-determinant of the Jacobian.

        Args:
            x (torch.Tensor): Input tensor to transform.

        Returns:
            tuple: Transformed tensor and the log-determinant of the Jacobian.
        """
        z, ldj = self._dequant(x.to(torch.float32))
        # Uncomment the next line if an additional sigmoid transformation is needed
        # z, ldj = self._sigmoid(z, ldj)
        return z, ldj
    
    def _sigmoid(self, z, ldj):
        """
        Applies an invertible sigmoid transformation to the input.

        Args:
            z (torch.Tensor): Input tensor to transform.
            ldj (torch.Tensor): Log-determinant of the Jacobian.

        Returns:
            tuple: Transformed tensor and updated log-determinant of the Jacobian.
        """
        # Update ldj with the sigmoid transformation's contribution
        ldj += (-z - 2 * F.softplus(-z)).sum(dim=[1, 2, 3])
        z = torch.sigmoid(z)
        # Adjust the log-determinant for the alpha scaling
        ldj -= torch.log(torch.tensor(1.0 - self.alpha, device=z.device, dtype=z.dtype)) * z.flatten(1).shape[1]
        # Scale z to avoid boundaries
        z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        return z, ldj
    
    def _sigmoid_inverse(self, z):
        """
        Applies the inverse of the sigmoid transformation to the input.

        Args:
            z (torch.Tensor): Input tensor to transform.

        Returns:
            torch.Tensor: Transformed tensor.
        """
        # Scale z to avoid boundaries 0 and 1
        z = z * (1 - self.alpha) + 0.5 * self.alpha
        # Apply the logit function (inverse sigmoid)
        z = torch.log(z) - torch.log(1 - z)
        return z
    
    def _dequant(self, x):
        """
        Transforms discrete values to continuous volumes for dequantization.

        Args:
            x (torch.Tensor): Input tensor with discrete values.

        Returns:
            tuple: Dequantized tensor and the log-determinant of the Jacobian.
        """
        # Add uniform noise to dequantize
        u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
        z = (x + u) / self.quantization_bins
        # Compute the log-determinant of the Jacobian
        ldj = self._ldj(z.shape)
        return z, ldj




In [8]:
a = torch.randint(256,[4, 4])
b, _ = UniformDequantization()._forward_and_log_det_jacobian(a)
print(a)
print(b)

tensor([[195,  37, 243, 255],
        [198, 212,  35, 240],
        [183, 145,  20, 120],
        [132,  48,  40, 101]])
tensor([[0.7638, 0.1468, 0.9514, 0.9980],
        [0.7743, 0.8284, 0.1368, 0.9406],
        [0.7160, 0.5668, 0.0807, 0.4709],
        [0.5177, 0.1895, 0.1573, 0.3954]])


#### Variational Dequantization (TO DO)

In [9]:
#| export

# class VariationalDequantization(UniformDequantization):

#     def __init__(self, var_flows, alpha=1e-5, num_bits=8, device='cpu', name='variational_dequantization'):
#         """
#         Variational dequantization layer inheriting from UniformDequantization.
        
#         Inputs:
#             var_flows - A list of flow transformations to use for modeling q(u|x).
#             alpha - Small constant used to scale the input to avoid very small and large values.
#             num_bits - Number of bits used for quantization.
#             device - Device to run computations on (default: 'cpu').
#             name - Name of the module (default: 'variational_dequantization').
#         """
#         super().__init__(alpha=alpha, num_bits=num_bits, device=device, name=name)
#         self.flows = nn.ModuleList(var_flows)
        
#     def _dequant(self, x):
#         # Transform discrete values to continuous volumes
#         u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
#         img = (x / (self.quantization_bins - 1)) * 2 - 1 # We condition the flows on x, i.e. the original image
        
#         u = self._sigmoid_inverse(u)
#         for flow in self.flows:
#             u, ldj = flow(u, ldj, orig_img=img)
#         u, ldj = self._sigmoid(u, ldj)
        
#         z = (x + u) / self.quantization_bins
#         ldj = self._ldj(z.shape)
#         return z, ldj

#     # def dequant(self, z, ldj):
#     #     z = z.to(torch.float32)
#     #     img = (z / 255.0) * 2 - 1 # We condition the flows on x, i.e. the original image

#     #     # Prior of u is a uniform distribution as before
#     #     # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.
#     #     deq_noise = torch.rand_like(z).detach()
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)
#     #     for flow in self.flows:
#     #         deq_noise, ldj = flow(deq_noise, ldj, reverse=False, orig_img=img)
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)

#     #     # After the flows, apply u as in standard dequantization
#     #     z = (z + deq_noise) / 256.0
#     #     ldj -= np.log(256.0) * np.prod(z.shape[1:])
#     #     return z, ldj


### Conditional Linear

In [10]:
#| export

@regist_layer
class ConditionalLinear(nn.Module):
    """
    A conditional linear transformation module that applies different scales and biases
    based on the average pixel size and camera values provided in the input.

    The class performs both forward and inverse transformations, adjusting the scaling
    and bias parameters conditionally. The parameters are selected based on the mean
    values of the 'pixel' and 'cam' inputs, which are matched to predefined sets of
    pixel sizes and camera values.

    Attributes:
        name (str): Name of the transformation.
        pixel_size (torch.Tensor): Predefined set of pixel sizes.
        cam_vals (torch.Tensor): Predefined set of camera values.
        log_scale (torch.nn.Parameter): Learnable log-scale parameters for the transformation.
        bias (torch.nn.Parameter): Learnable bias parameters for the transformation.
    
    Methods:
        _inverse(z, **kwargs): Performs the inverse transformation based on the input 'z' and conditionals.
        _forward_and_log_det_jacobian(x, **kwargs): Performs the forward transformation and computes the log determinant of the Jacobian.
    """

    def __init__(self, device='cpu', name='linear_transformation'):
        # Initialize the parent class (nn.Module)
        super(ConditionalLinear, self).__init__()
        self.name = name

        # Define pixel sizes and camera values as tensors
        self.pixel_size = torch.tensor([60, 90, 100, 160, 320], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        # Define learnable parameters: log_scale and bias
        self.log_scale = nn.Parameter(torch.zeros(25), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(25), requires_grad=True)

    def _inverse(self, z, **kwargs):
        # Compute one-hot encoded pixel indices based on the mean of 'pixel' in kwargs
        gain_one_hot = self.pixel_size == torch.mean(kwargs['pixel'], dim=[1, 2, 3]).unsqueeze(1)
        pixel = gain_one_hot.nonzero()[:, 1]

        # Compute one-hot encoded camera indices based on the mean of 'cam' in kwargs
        cam_one_hot = self.cam_vals == torch.mean(kwargs['cam'], dim=[1, 2, 3]).unsqueeze(1)
        cam = cam_one_hot.nonzero()[:, 1]

        # Combine pixel and camera indices to get iso_cam indices
        iso_cam = pixel * 5 + cam
        iso_cam = torch.arange(0, 25).to(z.device) == iso_cam.unsqueeze(1)

        # Select corresponding log_scale and bias values based on iso_cam indices
        log_scale = self.log_scale.unsqueeze(0).repeat_interleave(z.shape[0], dim=0)[iso_cam]
        bias = self.bias.unsqueeze(0).repeat_interleave(z.shape[0], dim=0)[iso_cam]

        # Compute the inverse transformation
        x = (z - bias.reshape((-1, 1, 1, 1))) / torch.exp(log_scale.reshape((-1, 1, 1, 1)))
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        # Compute one-hot encoded pixel indices based on the mean of 'pixel' in kwargs
        gain_one_hot = self.pixel_size == torch.mean(kwargs['pixel'], dim=[1, 2, 3]).unsqueeze(1)
        pixel = gain_one_hot.nonzero()[:, 1]

        # Compute one-hot encoded camera indices based on the mean of 'cam' in kwargs
        cam_one_hot = self.cam_vals == torch.mean(kwargs['cam'], dim=[1, 2, 3]).unsqueeze(1)
        cam = cam_one_hot.nonzero()[:, 1]

        # Combine pixel and camera indices to get iso_cam indices
        iso_cam = pixel * 5 + cam
        iso_cam = torch.arange(0, 25).to(x.device) == iso_cam.unsqueeze(1)

        # Select corresponding log_scale and bias values based on iso_cam indices
        log_scale = self.log_scale.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)[iso_cam]
        bias = self.bias.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)[iso_cam]

        # Compute the forward transformation
        z = x * torch.exp(log_scale.reshape((-1, 1, 1, 1))) + bias.reshape((-1, 1, 1, 1))

        # Compute the log determinant of the Jacobian
        log_abs_det_J_inv = log_scale * np.prod(x.shape[1:])

        return z, log_abs_det_J_inv


In [11]:
batch_size = 1
channels = 1
height = 5
width = 5
device = 'cpu'

x = torch.randn(batch_size, channels, height, width).to(device)
pixel = torch.tensor([[[[60]] * width] * height] * batch_size, dtype=torch.float32).to(device)
cam = torch.tensor([[[[1]] * width] * height] * batch_size, dtype=torch.float32).to(device)

kwargs = {'pixel': pixel, 'cam': cam}

 # Forward transformation
z, log_det_jacobian = ConditionalLinear(device=device)._forward_and_log_det_jacobian(x, **kwargs)
test_eq(z.shape, x.shape)
test_eq(log_det_jacobian.shape, torch.Size([batch_size]))

# Inverse transformation
x_reconstructed = ConditionalLinear(device=device)._inverse(z, **kwargs)
test_eq(x_reconstructed.shape, x.shape)

# Check if the reconstructed input is close to the original input
assert torch.allclose(x, x_reconstructed, atol=1e-5)


In [12]:
z, log_det_jacobian, x, x_reconstructed

(tensor([[[[ 1.2494, -0.1547,  0.1914, -0.6321, -0.4821],
           [ 0.3390, -1.1744, -2.3254, -1.3291, -0.7013],
           [ 0.3492, -0.5303,  0.2715, -0.5903, -0.3316],
           [ 0.5376, -0.1348, -0.9580,  0.1716, -0.3195],
           [-0.3640, -2.9999, -1.8587,  0.0516, -2.2012]]]],
        grad_fn=<AddBackward0>),
 tensor([0.], grad_fn=<MulBackward0>),
 tensor([[[[ 1.2494, -0.1547,  0.1914, -0.6321, -0.4821],
           [ 0.3390, -1.1744, -2.3254, -1.3291, -0.7013],
           [ 0.3492, -0.5303,  0.2715, -0.5903, -0.3316],
           [ 0.5376, -0.1348, -0.9580,  0.1716, -0.3195],
           [-0.3640, -2.9999, -1.8587,  0.0516, -2.2012]]]]),
 tensor([[[[ 1.2494, -0.1547,  0.1914, -0.6321, -0.4821],
           [ 0.3390, -1.1744, -2.3254, -1.3291, -0.7013],
           [ 0.3492, -0.5303,  0.2715, -0.5903, -0.3316],
           [ 0.5376, -0.1348, -0.9580,  0.1716, -0.3195],
           [-0.3640, -2.9999, -1.8587,  0.0516, -2.2012]]]],
        grad_fn=<DivBackward0>))

### Conditional Linear $e^2$

In [13]:
#| export

@regist_layer
class ConditionalLinearExp2(nn.Module):
    """
    Conditional linear transformation layer for flows, conditioned on specific ISO levels and setup codes.
    
    This module applies a linear transformation to the input tensor, where the transformation parameters
    (log scale and bias) are conditioned based on the pixel size and setup code provided as input. 
    The module supports both forward and inverse transformations.

    Attributes:
        name (str): Name of the module.
        device (str): Device to run computations on.
        pixel_size (tensor): pixel size used for conditioning.
        cam_vals (tensor): Predefined setup codes used for conditioning.
        log_scale (nn.Parameter): Learnable log scale parameters for the transformation.
        bias (nn.Parameter): Learnable bias parameters for the transformation.

    Methods:
        _inverse(z, **kwargs):
            Applies the inverse transformation to the input tensor z.
        
        _forward_and_log_det_jacobian(x, **kwargs):
            Applies the forward transformation to the input tensor x and computes the log determinant
            of the Jacobian of the transformation.
    """
    def __init__(self, in_ch=1, device='cpu', name='linear_transformation_exp2'):
        """
        Initializes the ConditionalLinearExp2 module with specified input channels, device, and name.

        Args:
            in_ch (int): Number of input channels. Default is 3.
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'linear_transformation_exp2').
        """
        super(ConditionalLinearExp2, self).__init__()
        self.name = name
        self.device = device 

        # camera values used for conditioning the transformation
        self.pixel_size = torch.tensor([60, 90, 100, 160, 320], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        # Parameters for the linear transformation, conditioned on camera values
        self.log_scale = nn.Parameter(torch.zeros(25, in_ch), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(25, in_ch), requires_grad=True)

    def _inverse(self, z, **kwargs):
        """
        Applies the inverse transformation to the input tensor z.
        
        Args:
            z (tensor): Input tensor to be inversely transformed.
            kwargs: Additional keyword arguments containing 'ISO-level' and 'setup-code' for conditioning.
        
        Returns:
            tensor: The inversely transformed tensor.
        """
        b, _, _, _ = z.shape

        # Determine ISO index based on 'pixel-size' from kwargs
        iso = torch.zeros([b], device=self.device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.pixel_size):
            iso += torch.where(kwargs['pixel-size'] == iso_val, iso_idx, 0.0)

        # Determine camera index based on 'ssetup-code' from kwargs
        cam = torch.zeros([b], device=self.device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals):
            cam += torch.where(kwargs['setup-code'] == cam_val, cam_idx, 0.0)

        # Combine pixel and camera indices to get a unique index for each combination
        iso_cam = iso * self.pixel_size.shape[0] + cam
        iso_cam = torch.arange(0, self.pixel_size.shape[0] * self.cam_vals.shape[0]).to(device) == iso_cam.unsqueeze(1)

        # Select log scale and bias parameters based on the unique index
        log_scale = self.log_scale.unsqueeze(0).repeat_interleave(z.shape[0], dim=0)[iso_cam]
        bias = self.bias.unsqueeze(0).repeat_interleave(z.shape[0], dim=0)[iso_cam]

        # Apply the inverse transformation
        x = (z - bias.reshape((-1, z.shape[1], 1, 1))) / torch.exp(log_scale.reshape((-1, z.shape[1], 1, 1)))
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Applies the forward transformation to the input tensor x and computes the log determinant of the Jacobian.
        
        Args:
            x (tensor): Input tensor to be transformed.
            kwargs: Additional keyword arguments containing 'ISO-level' and 'setup-code' for conditioning.
        
        Returns:
            tensor: The transformed tensor.
            tensor: The log determinant of the Jacobian of the transformation.
        """
        b, _, _, _ = x.shape

        # Determine ISO index based on 'ISO-level' from kwargs
        iso = torch.zeros([b], device=x.device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.pixel_size):
            iso += torch.where(kwargs['pixel-size'] == iso_val, iso_idx, 0.0)

        # Determine camera index based on 'setup-code' from kwargs
        cam = torch.zeros([b], device=x.device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals):
            cam += torch.where(kwargs['setup-code'] == cam_val, cam_idx, 0.0)

        # Combine pixel and camera indices to get a unique index for each combination
        iso_cam = iso * self.pixel_size.shape[0] + cam
        iso_cam = torch.arange(0, self.pixel_size.shape[0] * self.cam_vals.shape[0]).to(device) == iso_cam.unsqueeze(1)

        # Select log scale and bias parameters based on the unique index
        log_scale = self.log_scale.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)[iso_cam]
        bias = self.bias.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)[iso_cam]

        # Apply the forward transformation
        z = x * torch.exp(log_scale.reshape((-1, x.shape[1], 1, 1))) + bias.reshape((-1, x.shape[1], 1, 1))
        log_abs_det_J_inv = torch.sum(log_scale * np.prod(x.shape[2:]), dim=1)

        return z, log_abs_det_J_inv


In [17]:
device = 'cpu'

x = torch.randn(1, 1, 5, 5).to(device)

kwargs = {
        'pixel-size': torch.tensor([60], dtype=torch.float32).to(device),
        'setup-code': torch.tensor([0], dtype=torch.float32).to(device)
    }

 # Forward transformation
z, log_det_jacobian = ConditionalLinearExp2(device=device)._forward_and_log_det_jacobian(x, **kwargs)
test_eq(z.shape, x.shape)
test_eq(log_det_jacobian.shape, torch.Size([batch_size]))

# Inverse transformation
x_reconstructed = ConditionalLinearExp2(device=device)._inverse(z, **kwargs)
test_eq(x_reconstructed.shape, x.shape)

# Check if the reconstructed input is close to the original input
assert torch.allclose(x, x_reconstructed, atol=1e-5)


In [18]:
z, x, log_det_jacobian

(tensor([[[[-0.9294, -0.6881,  3.2016,  0.9993,  3.2234],
           [-0.9341, -0.4158, -0.1198,  0.1328, -1.0001],
           [-1.0890, -0.9274,  0.4519, -1.0435,  0.8232],
           [-1.3048, -1.6218,  0.2501,  0.0532,  0.5669],
           [ 0.4533,  1.2550, -1.0358, -0.4922, -0.4962]]]],
        grad_fn=<AddBackward0>),
 tensor([[[[-0.9294, -0.6881,  3.2016,  0.9993,  3.2234],
           [-0.9341, -0.4158, -0.1198,  0.1328, -1.0001],
           [-1.0890, -0.9274,  0.4519, -1.0435,  0.8232],
           [-1.3048, -1.6218,  0.2501,  0.0532,  0.5669],
           [ 0.4533,  1.2550, -1.0358, -0.4922, -0.4962]]]]),
 tensor([0.], grad_fn=<SumBackward1>))

### Signal Dependent Conditional Linear

In [None]:
#| export


@regist_layer
class SignalDependentConditionalLinear(nn.Module):
    def __init__(self, meta_encoder, scale_and_bias, in_ch=3, device='cpu', name='signal_dependent_condition_linear'):
        super(SignalDependentConditionalLinear, self).__init__()
        self.name = name
        self.device = device 

        self.in_ch = in_ch
        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'
        self.encode_ch = 3
        self.meta_encoder = meta_encoder(10, self.encode_ch)
        self.scale_and_bias = scale_and_bias(self.encode_ch+in_ch, in_ch*2) # scale, bias per channels

    def _get_embeddings(self, x, **kwargs):
        b,_,_,_ = x.shape

        iso = torch.zeros([b], device=x.device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.iso_vals):
            iso += torch.where(kwargs['ISO-level'] == iso_val, iso_idx, 0.0)

        cam = torch.zeros([b], device=x.device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals):
            cam += torch.where(kwargs['smartphone-code'] == cam_val, cam_idx, 0.0)

        iso_one_hot = F.one_hot(iso.to(torch.int64), num_classes=self.iso_vals.shape[0]).to(torch.float32)
        cam_one_hot = F.one_hot(cam.to(torch.int64), num_classes=self.cam_vals.shape[0]).to(torch.float32)

        embedding = self.meta_encoder(torch.cat((iso_one_hot, cam_one_hot), dim=1)) # [b, 10] -> [b,encode_ch]
        embedding = embedding.reshape((-1, self.encode_ch, 1, 1))
        embedding = torch.repeat_interleave(embedding, x.shape[-2], dim=-2)# [b, encode_ch, 1, 1] -> [b, encode_ch, h, 1]
        embedding = torch.repeat_interleave(embedding, x.shape[-1], dim=-1)# [b, encode_ch, h, 1] -> [b, encode_ch, h, w]

        embedding = torch.cat((embedding, kwargs['clean']), dim=1) # [b, encode_ch, h, w], [b, c, h, w] -> [b, c+encode_ch, h, w]

        embedding = self.scale_and_bias(embedding)
        return embedding
    
    def _inverse(self, z, **kwargs):
        embedding = self._get_embeddings(z, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]
        z = (z - bias)/torch.exp(log_scale)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        embedding = self._get_embeddings(x, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]
        
        z = torch.exp(log_scale)*x + bias
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1,2,3])
        return z, log_abs_det_J_inv

### Structure-Aware Conditional Linear Layer

In [None]:
#| export


@regist_layer
class StructureAwareConditionalLinearLayer(nn.Module):
    def __init__(self, meta_encoder, structure_encoder, in_ch=3, device='cpu', name='signal_dependent_condition_linear'):
        super(StructureAwareConditionalLinearLayer, self).__init__()
        self.in_ch = in_ch
        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        self.meta_encoder = meta_encoder(10, in_ch*2)
        self.structure_encoder = structure_encoder(in_ch, in_ch*2)

    def _get_embeddings(self, x, **kwargs):
        b,_,_,_ = x.shape

        iso = torch.zeros([b], device=x.device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.iso_vals):
            iso += torch.where(kwargs['ISO-level'] == iso_val, iso_idx, 0.0)

        cam = torch.zeros([b], device=x.device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals):
            cam += torch.where(kwargs['smartphone-code'] == cam_val, cam_idx, 0.0)

        iso_one_hot = F.one_hot(iso.to(torch.int64), num_classes=self.iso_vals.shape[0]).to(torch.float32)
        cam_one_hot = F.one_hot(cam.to(torch.int64), num_classes=self.cam_vals.shape[0]).to(torch.float32)

        meta_embedding = self.meta_encoder(torch.cat((iso_one_hot, cam_one_hot), dim=1)) # [b, 10] -> [b,encode_ch]
        meta_embedding = meta_embedding.reshape((-1, self.in_ch*2, 1, 1))
        
        structure_embedding = self.structure_encoder(kwargs['clean'])
        embedding = structure_embedding * meta_embedding
        return embedding
    
    def _inverse(self, z, **kwargs):
        embedding = self._get_embeddings(z, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]
        z = (z - bias)/torch.exp(log_scale)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        embedding = self._get_embeddings(x, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]
        
        z = torch.exp(log_scale)*x + bias
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1,2,3])
        return z, log_abs_det_J_inv


### Pointwise Convs

In [None]:
#| export

@regist_layer
class PointwiseConvs(nn.Module):
    def __init__(self, in_features=3, out_features=3, feats=32, device='cpu', name='pointwise_convs'):
        super(PointwiseConvs, self).__init__()
        self.name = name
        self.device = device 
        self.body = nn.Sequential(
            nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0),
            self._get_basic_module(feats, feats*2, k_size=1, stride=1, padding=0),
            self._get_basic_module(feats*2, feats*2, k_size=1, stride=1, padding=0),
            self._get_basic_module(feats*2, feats, k_size=1, stride=1, padding=0),
            nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0),
            nn.Tanh()
        )

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
            return nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding),
                    nn.InstanceNorm2d(out_ch, affine=True), #batch normalization?
                    nn.LeakyReLU(negative_slope, inplace=True)
            )
    
    def forward(self, x):
        return self.body(x)

### Spatial Convs

In [None]:
#| export

@regist_layer
class SpatialConvs(nn.Module):
    def __init__(self, in_features=3, out_features=3, feats=32, receptive_field=9, device='cpu', name='pointwise_convs'):
        super(SpatialConvs, self).__init__()
        self.name = name
        self.device = device 

        self.receptive_field = receptive_field

        self.body = list()
        self.body.append(nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0))
        self.body.append(nn.ReLU(inplace=True))

        for _ in range(self.receptive_field//2):
            self.body.append(nn.Conv2d(feats, feats, kernel_size=3, stride=1, padding=1))
            self.body.append(nn.ReLU(inplace=True))
        
        self.body.append(nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0))
        self.body.append(nn.Tanh())
        self.body = nn.Sequential(*self.body)

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
            return nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding),
                    nn.InstanceNorm2d(out_ch, affine=True), #batch normalization?
                    nn.LeakyReLU(negative_slope, inplace=True)
            )
    
    def forward(self, x):
        return self.body(x)

## Noise Extraction

In [None]:
@regist_layer
class NoiseExtraction(nn.Module):
    def __init__(self, device='cpu', name='noise_extraction'):
        super(NoiseExtraction, self).__init__()
        self.name = name
        self.device = device

    def _inverse(self, z, **kwargs):
        x = z + kwargs['clean']
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        z = x - kwargs['clean']
        ldj = torch.zeros(x.shape[0], device=self.device)
        return z, ldj

## Residual Net (to be moved)

In [None]:
#| export

from torch.nn import functional as F, init

class ResidualBlock(nn.Module):
    """A general-purpose residual block. Works only with 1-dim inputs."""

    def __init__(self,
                 features,
                 context_features,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False,
                 zero_initialization=True):
        super().__init__()
        self.activation = activation

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList([
                #nn.BatchNorm1d(features, eps=1e-3, track_running_stats=False)
                nn.BatchNorm1d(features, eps=1e-3)
                for _ in range(2)
            ])
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)
        self.linear_layers = nn.ModuleList([
            nn.Linear(features, features)
            for _ in range(2)
        ])
        if dropout_probability > 0.:
            self.dropout = nn.Dropout(p=dropout_probability)
        else:
            self.dropout = None
        if zero_initialization:
            init.uniform_(self.linear_layers[-1].weight, -1e-3, 1e-3)
            init.uniform_(self.linear_layers[-1].bias, -1e-3, 1e-3)

    def forward(self, inputs, context=None):
        temps = inputs
        if self.use_batch_norm:
            temps = self.batch_norm_layers[0](temps)
        temps = self.activation(temps)
        temps = self.linear_layers[0](temps)
        if self.use_batch_norm:
            temps = self.batch_norm_layers[1](temps)
        temps = self.activation(temps)
        if self.dropout:
            temps = self.dropout(temps)
        temps = self.linear_layers[1](temps)
        if context is not None:
            temps = F.glu(
                torch.cat(
                    (temps, self.context_layer(context)),
                    dim=1
                ),
                dim=1
            )
        return inputs + temps


In [None]:
#| export

@regist_layer
class ResidualNet(nn.Module):
    """A general-purpose residual network. Works only with 1-dim inputs."""

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 context_features=None,
                 num_blocks=2,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False):
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features
        if context_features is not None:
            self.initial_layer = nn.Linear(in_features + context_features, hidden_features)
        else:
            self.initial_layer = nn.Linear(in_features, hidden_features)
        self.blocks = nn.ModuleList([
            ResidualBlock(
                features=hidden_features,
                context_features=context_features,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=use_batch_norm,
            ) for _ in range(num_blocks)
        ])
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        if context is None:
            temps = self.initial_layer(inputs)
        else:
            temps = self.initial_layer(
                torch.cat((inputs, context), dim=1)
            )
        for block in self.blocks:
            temps = block(temps, context=context)
        outputs = self.final_layer(temps)
        return outputs

# Noise Flow Layers


In [None]:
#| export

# class Gain(Flow):
#     """
#     Gain & Offset flow layer
#     """

#     def __init__(self, shape):
#         """Constructor

        
#         """
#         super().__init__()
#         self.shape = shape
#         self.gain = AffineConstFlow(self.shape)

#     def forward(self, z, **kwargs):
#         return self.gain(z)

#     def inverse(self, z, **kwargs):
#         return self.gain.inverse(z)

In [None]:
# channels = 1
# hidden_channels = 16

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# x = torch_randn(1, channels, 16, 16).to(device)
# print(x.device)

# # tst =  AffineSdn(x.shape[1:]).to(device)
# tst = Unconditional(channels=x.shape[1],hidden_channels = 16,split_mode='channel' if x.shape[1] != 1 else 'checkerboard').to(device)
# # tst = Gain(x.shape[1:]).to(device)  
# print(tst)
# kwargs = {}; kwargs['clean'] = x
# y, _ = tst(x,**kwargs)
# test_eq(y.shape, x.shape)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()