# Layers

> normalizing flow layers


In [1]:
#| default_exp layers

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
from fastai.imports import *

In [4]:
 #| export

import os
from importlib import import_module

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from Noise2Model.utils import attributesFromDict, compute_index, compute_one_hot

In [5]:
 #| export
 
flow_layer_class_dict = {}


def regist_layer(layer_class):
    """Class decorator that records a flow layer in ``flow_layer_class_dict``.

    The registry key is the lower-cased class name.

    Args:
        layer_class: The class to register.

    Returns:
        The same ``layer_class`` object, so the function can be used as a decorator.

    Raises:
        AssertionError: If a class with the same lower-cased name is already registered.
    """
    key = layer_class.__name__.lower()
    already_registered = key in flow_layer_class_dict
    assert not already_registered, 'there is already registered layer: %s in flow_layer_class_dict.' % key
    flow_layer_class_dict[key] = layer_class
    return layer_class



In [6]:
 #| export
 
def get_flow_layer(layer_name:str):
    """Return the layer class stored in ``flow_layer_class_dict`` under ``layer_name``.

    The lookup is case-insensitive: the name is lower-cased before indexing.

    Args:
        layer_name (str): Name of the layer class, in any letter case.

    Returns:
        The class stored under the lower-cased name.

    Raises:
        KeyError: If no entry exists for the lower-cased name.
    """
    return flow_layer_class_dict[layer_name.lower()]


## Normalizing Flows


### Dequantization

#### Uniform Dequantization

In [7]:
#| export

@regist_layer
class UniformDequantization(nn.Module):
    """
    Uniform dequantization layer for flows.

    The forward pass maps discrete inputs ``x`` to ``(x + u) / 2**num_bits`` with
    ``u ~ U[0, 1)``. The inverse pass maps continuous values back to integer bins
    with ``floor(z * 2**num_bits)`` clamped to ``[0, 2**num_bits - 1]``.

    ``_sigmoid`` / ``_sigmoid_inverse`` form an optional, mutually inverse
    reparameterisation. They are not part of the default forward/inverse pair.
    """

    def __init__(self, alpha=1e-5, num_bits=8, device='cpu', name='uniform_dequantization'):
        """
        Args:
            alpha (float): Small constant used to scale the input to avoid boundary values (default: 1e-5).
            num_bits (int): Number of bits used for quantization (default: 8).
            device (str): Device on which the log-det buffer is created (default: 'cpu').
            name (str): Name of the module (default: 'uniform_dequantization').
        """
        super().__init__()
        self.alpha = alpha
        self.num_bits = num_bits
        self.quantization_bins = 2 ** num_bits
        # Per-dimension log|det J| of x -> x / 2**num_bits, i.e. -num_bits * log(2).
        self.register_buffer(
            'ldj_per_dim',
            - num_bits * torch.log(torch.tensor(2.0, device=device, dtype=torch.float))
        )
        self.name = name

    def _ldj(self, shape):
        """
        Computes the log-determinant of the Jacobian for a given shape.

        Args:
            shape (torch.Size): Shape of the input tensor; the first entry is the batch size.

        Returns:
            torch.Tensor: 1-D tensor of length ``batch_size`` holding the per-sample log-determinant.
        """
        batch_size = shape[0]
        num_dims = shape[1:].numel()
        ldj = self.ldj_per_dim * num_dims
        return ldj.repeat(batch_size)

    def _inverse(self, z, **kwargs):
        """
        Maps continuous values back to integer quantization bins.

        Args:
            z (torch.Tensor): Continuous tensor as produced by the forward pass.

        Returns:
            torch.Tensor: Floating-point tensor of bin indices in ``[0, 2**num_bits - 1]``.
        """
        # The forward pass does not apply the sigmoid reparameterisation (the call is
        # commented out there), so applying the logit here would break
        # inverse(forward(x)) == x. If `_sigmoid` is re-enabled in the forward pass,
        # `self._sigmoid_inverse(z)` must be applied here first.
        z = (self.quantization_bins * z).floor().clamp(min=0, max=self.quantization_bins - 1)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Applies the forward dequantization transformation and computes the log-determinant of the Jacobian.

        Args:
            x (torch.Tensor): Input tensor with discrete values; it is cast to float32.

        Returns:
            tuple: Transformed tensor and the per-sample log-determinant of the Jacobian.
        """
        z, ldj = self._dequant(x.to(torch.float32))
        # Uncomment the next line if an additional sigmoid transformation is needed
        # (and apply `_sigmoid_inverse` first in `_inverse` to keep the pair consistent).
        # z, ldj = self._sigmoid(z, ldj)
        return z, ldj

    def _sigmoid(self, z, ldj):
        """
        Applies the invertible map ``z -> (sigmoid(z) - alpha / 2) / (1 - alpha)``.

        Args:
            z (torch.Tensor): Input tensor with a leading batch dimension and at least one more dimension.
            ldj (torch.Tensor): Per-sample log-determinant accumulated so far; it is not modified in place.

        Returns:
            tuple: Transformed tensor and the updated log-determinant of the Jacobian.
        """
        flat = z.flatten(1)
        # log sigmoid'(z) = -z - 2 * softplus(-z), summed over all non-batch dimensions.
        ldj = ldj + (-flat - 2 * F.softplus(-flat)).sum(dim=1)
        z = torch.sigmoid(z)
        # Contribution of the division by (1 - alpha), once per non-batch dimension.
        ldj = ldj - torch.log(torch.tensor(1.0 - self.alpha, device=z.device, dtype=z.dtype)) * flat.shape[1]
        # Scale z to avoid boundaries
        z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        return z, ldj

    def _sigmoid_inverse(self, z):
        """
        Inverse of ``_sigmoid``: rescales by ``alpha`` and applies the logit.

        Args:
            z (torch.Tensor): Input tensor to transform.

        Returns:
            torch.Tensor: Transformed tensor.
        """
        # Scale z to avoid boundaries 0 and 1
        z = z * (1 - self.alpha) + 0.5 * self.alpha
        # Apply the logit function (inverse sigmoid)
        z = torch.log(z) - torch.log(1 - z)
        return z

    def _dequant(self, x):
        """
        Transforms discrete values to continuous volumes for dequantization.

        Args:
            x (torch.Tensor): Floating-point tensor holding discrete values.

        Returns:
            tuple: Dequantized tensor and the per-sample log-determinant of the Jacobian.
        """
        # Add uniform noise in [0, 1) to spread each integer over its bin.
        u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
        z = (x + u) / self.quantization_bins
        # Compute the log-determinant of the Jacobian
        ldj = self._ldj(z.shape)
        return z, ldj




In [8]:
# Smoke test: dequantize a 4x4 grid of random integers drawn from [0, 256).
a = torch.randint(256,[4, 4])
# The second return value (the log-det-Jacobian) is discarded here.
b, _ = UniformDequantization()._forward_and_log_det_jacobian(a)
print(a)
print(b)

tensor([[113, 241, 139,  86],
        [216,  40,  75, 155],
        [ 64, 121,  71, 149],
        [242,  78, 126,  78]])
tensor([[0.4440, 0.9438, 0.5451, 0.3393],
        [0.8448, 0.1598, 0.2931, 0.6079],
        [0.2501, 0.4752, 0.2793, 0.5826],
        [0.9473, 0.3078, 0.4926, 0.3081]])


#### Variational Dequantization (TO DO)

In [9]:
#| export

# class VariationalDequantization(UniformDequantization):

#     def __init__(self, var_flows, alpha=1e-5, num_bits=8, device='cpu', name='variational_dequantization'):
#         """
#         Variational dequantization layer inheriting from UniformDequantization.
        
#         Inputs:
#             var_flows - A list of flow transformations to use for modeling q(u|x).
#             alpha - Small constant used to scale the input to avoid very small and large values.
#             num_bits - Number of bits used for quantization.
#             device - Device to run computations on (default: 'cpu').
#             name - Name of the module (default: 'variational_dequantization').
#         """
#         super().__init__(alpha=alpha, num_bits=num_bits, device=device, name=name)
#         self.flows = nn.ModuleList(var_flows)
        
#     def _dequant(self, x):
#         # Transform discrete values to continuous volumes
#         u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
#         img = (x / (self.quantization_bins - 1)) * 2 - 1 # We condition the flows on x, i.e. the original image
        
#         u = self._sigmoid_inverse(u)
#         for flow in self.flows:
#             u, ldj = flow(u, ldj, orig_img=img)
#         u, ldj = self._sigmoid(u, ldj)
        
#         z = (x + u) / self.quantization_bins
#         ldj = self._ldj(z.shape)
#         return z, ldj

#     # def dequant(self, z, ldj):
#     #     z = z.to(torch.float32)
#     #     img = (z / 255.0) * 2 - 1 # We condition the flows on x, i.e. the original image

#     #     # Prior of u is a uniform distribution as before
#     #     # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.
#     #     deq_noise = torch.rand_like(z).detach()
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)
#     #     for flow in self.flows:
#     #         deq_noise, ldj = flow(deq_noise, ldj, reverse=False, orig_img=img)
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)

#     #     # After the flows, apply u as in standard dequantization
#     #     z = (z + deq_noise) / 256.0
#     #     ldj -= np.log(256.0) * np.prod(z.shape[1:])
#     #     return z, ldj


### Conditional Linear

In [10]:
#| export

@regist_layer
class ConditionalLinear(nn.Module):
    """
    Conditional scalar affine layer: ``z = x * exp(log_scale) + bias``.

    One ``(log_scale, bias)`` pair is kept for every combination of the values
    listed in ``codes``. For each batch element the pair is picked through the
    index returned by ``compute_index`` (imported from ``Noise2Model.utils``).

    Attributes:
        name (str): Name of the transformation.
        device (str): Device passed to ``compute_index`` and used for the index mask.
        codes (dict): Maps each conditioning variable to its list/tensor of admissible values.
        par_num (int): Product of the lengths of all entries in ``codes``.
        log_scale (torch.nn.Parameter): Learnable log-scales, shape ``(par_num,)``, initialised to zero.
        bias (torch.nn.Parameter): Learnable biases, shape ``(par_num,)``, initialised to zero.
    """

    def __init__(self, device='cpu', name='linear_transformation', codes={'code': [1, 2, 3]}):
        """
        Args:
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'linear_transformation').
            codes (dict): Conditioning variables and their admissible values
                (default: ``{'code': [1, 2, 3]}``).
        """
        super().__init__()
        self.name = name
        self.device = device
        self.codes = codes

        # One parameter slot per element of the Cartesian product of all code lists.
        n_combinations = 1
        for values in codes.values():
            n_combinations *= len(values)
        self.par_num = n_combinations

        self.log_scale = nn.Parameter(torch.zeros(self.par_num), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(self.par_num), requires_grad=True)

    def _computeIndex(self, b, **kwargs):
        """Delegate to ``compute_index`` to obtain the parameter index of each batch element."""
        return compute_index(self.codes, device=self.device)(b, **kwargs)

    def _select_params(self, batch_size, **kwargs):
        """Return the ``log_scale`` and ``bias`` entries selected for each batch element."""
        idx = self._computeIndex(batch_size, **kwargs)
        # Row i of the mask is True exactly where the slot number equals idx[i].
        mask = torch.arange(0, self.par_num).to(self.device) == idx.unsqueeze(1)
        log_scale = self.log_scale.unsqueeze(0).expand(batch_size, -1)[mask]
        bias = self.bias.unsqueeze(0).expand(batch_size, -1)[mask]
        return log_scale, bias

    def _inverse(self, z, **kwargs):
        """
        Undo the affine map: ``x = (z - bias) / exp(log_scale)``.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
            **kwargs: Conditioning values forwarded unchanged to ``compute_index``.

        Returns:
            torch.Tensor: Tensor of the same shape as ``z``.
        """
        log_scale, bias = self._select_params(z.shape[0], **kwargs)
        shift = bias.reshape((-1, 1, 1, 1))
        scale = torch.exp(log_scale.reshape((-1, 1, 1, 1)))
        return (z - shift) / scale

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Apply ``z = x * exp(log_scale) + bias`` and return the log-determinant.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
            **kwargs: Conditioning values forwarded unchanged to ``compute_index``.

        Returns:
            torch.Tensor: Transformed tensor of the same shape as ``x``.
            torch.Tensor: Per-sample log-determinant, ``log_scale`` times the number of
                non-batch elements of ``x``.
        """
        log_scale, bias = self._select_params(x.shape[0], **kwargs)
        scale = torch.exp(log_scale.reshape((-1, 1, 1, 1)))
        shift = bias.reshape((-1, 1, 1, 1))
        z = x * scale + shift

        # The same scalar scale acts on every non-batch element.
        log_abs_det_J_inv = log_scale * np.prod(x.shape[1:])
        return z, log_abs_det_J_inv



In [11]:
# Smoke test for ConditionalLinear on a tiny batch.
batch_size = 2
channels = 1
height = 2
width = 2
device = 'cpu'

# Conditioning variables and their admissible values; 3 * 2 = 6 parameter slots in the layer.
codes = {
        'exposure-time': torch.tensor([10, 50, 100], dtype=torch.float32, device=device),
        'optical-setup': torch.tensor([0, 1], dtype=torch.float32).to(device),
        # 'camera': torch.tensor([0, 1], dtype=torch.float32).to(device)
    }

x = torch.randn(batch_size, channels, height, width).to(device)
# One conditioning value per batch element for each variable in `codes`.
setup_idx = torch.tensor([1] * batch_size, dtype=torch.float32).to(device)
time_idx = torch.tensor([10] * batch_size, dtype=torch.float32).to(device)

kwargs = {'optical-setup': setup_idx, 'exposure-time': time_idx}

# Show the parameter index that the conditioning values resolve to.
print(compute_index(codes, device=device)(batch_size, **kwargs))

# Forward transformation
z, log_det_jacobian = ConditionalLinear(device=device, codes=codes)._forward_and_log_det_jacobian(x, **kwargs)
test_eq(z.shape, x.shape)
test_eq(log_det_jacobian.shape, torch.Size([batch_size]))

# Inverse transformation
# NOTE(review): this builds a second, independent layer. The round trip only matches
# because every new layer starts with zero log_scale and bias (identity map).
x_reconstructed = ConditionalLinear(device=device, codes=codes)._inverse(z, **kwargs)
test_eq(x_reconstructed.shape, x.shape)

# Check if the reconstructed input is close to the original input
assert torch.allclose(x, x_reconstructed, atol=1e-5)


tensor([1., 1.])


In [12]:
# Display the forward output, log-det-Jacobian, input and reconstruction from the cell above.
z, log_det_jacobian, x, x_reconstructed

(tensor([[[[-1.1548, -1.2173],
           [ 0.1894, -0.4818]]],
 
 
         [[[ 0.1171, -0.1363],
           [-0.4131, -0.7997]]]], grad_fn=<AddBackward0>),
 tensor([0., 0.], grad_fn=<MulBackward0>),
 tensor([[[[-1.1548, -1.2173],
           [ 0.1894, -0.4818]]],
 
 
         [[[ 0.1171, -0.1363],
           [-0.4131, -0.7997]]]]),
 tensor([[[[-1.1548, -1.2173],
           [ 0.1894, -0.4818]]],
 
 
         [[[ 0.1171, -0.1363],
           [-0.4131, -0.7997]]]], grad_fn=<DivBackward0>))

### Conditional Linear $e^2$

In [13]:
#| export

@regist_layer
class ConditionalLinearExp2(nn.Module):
    """
    Conditional per-channel affine layer: ``z = x * exp(log_scale) + bias``.

    For every combination of the values listed in ``codes`` the layer keeps one
    log-scale and one bias per input channel. The combination used for each batch
    element is picked through the index returned by ``compute_index`` (imported
    from ``Noise2Model.utils``).

    Attributes:
        name (str): Name of the module.
        device (str): Device passed to ``compute_index`` and used for the index mask.
        codes (dict): Maps each conditioning variable to its list/tensor of admissible values.
        par_num (int): Product of the lengths of all entries in ``codes``.
        log_scale (nn.Parameter): Learnable log-scales, shape ``(par_num, in_ch)``, initialised to zero.
        bias (nn.Parameter): Learnable biases, shape ``(par_num, in_ch)``, initialised to zero.
    """
    def __init__(self, in_ch=1, device='cpu', name='linear_transformation_exp2', codes={'code': [1, 2, 3]}):
        """
        Args:
            in_ch (int): Number of input channels. Default is 1.
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'linear_transformation_exp2').
            codes (dict): Conditioning variables and their admissible values
                (default: ``{'code': [1, 2, 3]}``).
        """
        super().__init__()
        self.name = name
        self.device = device
        self.codes = codes

        # One parameter row per element of the Cartesian product of all code lists.
        n_combinations = 1
        for values in codes.values():
            n_combinations *= len(values)
        self.par_num = n_combinations

        self.log_scale = nn.Parameter(torch.zeros(self.par_num, in_ch), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(self.par_num, in_ch), requires_grad=True)

    def _computeIndex(self, b, **kwargs):
        """Delegate to ``compute_index`` to obtain the parameter index of each batch element."""
        return compute_index(self.codes, device=self.device)(b, **kwargs)

    def _select_params(self, batch_size, **kwargs):
        """Return the per-channel ``log_scale`` and ``bias`` rows selected for each batch element."""
        idx = self._computeIndex(batch_size, **kwargs)
        # Row i of the mask is True exactly where the slot number equals idx[i].
        mask = torch.arange(0, self.par_num).to(self.device) == idx.unsqueeze(1)
        log_scale = self.log_scale.unsqueeze(0).expand(batch_size, -1, -1)[mask]
        bias = self.bias.unsqueeze(0).expand(batch_size, -1, -1)[mask]
        return log_scale, bias

    def _inverse(self, z, **kwargs):
        """
        Undo the affine map: ``x = (z - bias) / exp(log_scale)``, channel by channel.

        Args:
            z (tensor): Input tensor to be inversely transformed; dimension 1 is the channel axis.
            kwargs: Conditioning values forwarded unchanged to ``compute_index``.

        Returns:
            tensor: The inversely transformed tensor.
        """
        log_scale, bias = self._select_params(z.shape[0], **kwargs)
        per_channel = (-1, z.shape[1], 1, 1)
        shift = bias.reshape(per_channel)
        scale = torch.exp(log_scale.reshape(per_channel))
        return (z - shift) / scale

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Apply ``z = x * exp(log_scale) + bias`` per channel and return the log-determinant.

        Args:
            x (tensor): Input tensor to be transformed; dimension 1 is the channel axis.
            kwargs: Conditioning values forwarded unchanged to ``compute_index``.

        Returns:
            tensor: The transformed tensor.
            tensor: Per-sample log-determinant: each channel's log-scale times the number
                of elements in the trailing dimensions, summed over channels.
        """
        log_scale, bias = self._select_params(x.shape[0], **kwargs)
        per_channel = (-1, x.shape[1], 1, 1)
        scale = torch.exp(log_scale.reshape(per_channel))
        shift = bias.reshape(per_channel)
        z = x * scale + shift

        elements_per_channel = np.prod(x.shape[2:])
        log_abs_det_J_inv = torch.sum(log_scale * elements_per_channel, dim=1)
        return z, log_abs_det_J_inv


In [14]:
# Smoke test for ConditionalLinearExp2 on a tiny batch.
batch_size = 2
channels = 1
height = 2
width = 2
device = 'cpu'

# Conditioning variables and their admissible values; 3 * 2 = 6 parameter rows in the layer.
codes = {
        'exposure-time': torch.tensor([10, 50, 100], dtype=torch.float32, device=device),
        'optical-setup': torch.tensor([0, 1], dtype=torch.float32).to(device),
        # 'camera': torch.tensor([0, 1], dtype=torch.float32).to(device)
    }

x = torch.randn(batch_size, channels, height, width).to(device)

# Single-element conditioning tensors are used here although the batch holds two samples.
kwargs = {
        'exposure-time': torch.tensor([50], dtype=torch.float32).to(device),
        'optical-setup': torch.tensor([0], dtype=torch.float32).to(device)
    }

 # Forward transformation
z, log_det_jacobian = ConditionalLinearExp2(device=device, in_ch=x.shape[1], codes=codes)._forward_and_log_det_jacobian(x, **kwargs)
test_eq(z.shape, x.shape)
test_eq(log_det_jacobian.shape, torch.Size([batch_size]))

# Inverse transformation
# NOTE(review): this builds a second, independent layer. The round trip only matches
# because every new layer starts with zero log_scale and bias (identity map).
x_reconstructed = ConditionalLinearExp2(device=device, in_ch=x.shape[1], codes=codes)._inverse(z, **kwargs)
test_eq(x_reconstructed.shape, x.shape)

# Check if the reconstructed input is close to the original input
assert torch.allclose(x, x_reconstructed, atol=1e-5)


In [15]:
# Display the forward output, the input and the log-det-Jacobian from the cell above.
z, x, log_det_jacobian

(tensor([[[[-0.8903, -0.4137],
           [-0.5727,  1.4910]]],
 
 
         [[[ 0.1081, -0.2904],
           [-0.2454,  0.1507]]]], grad_fn=<AddBackward0>),
 tensor([[[[-0.8903, -0.4137],
           [-0.5727,  1.4910]]],
 
 
         [[[ 0.1081, -0.2904],
           [-0.2454,  0.1507]]]]),
 tensor([0., 0.], grad_fn=<SumBackward1>))

### Signal Dependent Conditional Linear

In [16]:
#| export

@regist_layer
class SignalDependentConditionalLinear(nn.Module):
    """
    Signal-dependent conditional affine layer: ``z = exp(log_scale) * x + bias``.

    ``log_scale`` and ``bias`` are full feature maps. A one-hot encoding of the
    conditioning values (built by ``compute_one_hot`` from ``Noise2Model.utils``)
    is passed through ``meta_encoder`` and broadcast over the spatial dimensions.
    The result is concatenated with the ``clean`` tensor along the channel axis
    and fed to ``scale_and_bias``.

    Attributes:
        name (str): Name of the module.
        device (str): Device passed to ``compute_one_hot``.
        codes (dict): Maps each conditioning variable to its list/tensor of admissible values.
        in_ch (int): Number of input channels.
        encode_ch (int): Number of channels produced by the meta encoder.
        meta_encoder: Object returned by ``meta_encoder(n, encode_ch)``, where ``n`` is the
            total number of admissible values over all entries of ``codes``.
        scale_and_bias: Object returned by ``scale_and_bias(encode_ch + in_ch, in_ch * 2)``.
    """
    def __init__(self, meta_encoder, scale_and_bias, in_ch=1, device='cpu', name='signal_dependent_condition_linear', codes={'code': [1, 2, 3]}, encode_ch = 3):
        """
        Args:
            meta_encoder (callable): Factory called as ``meta_encoder(in_features, out_features)``.
            scale_and_bias (callable): Factory called as ``scale_and_bias(in_channels, out_channels)``.
            in_ch (int): Number of input channels. Default is 1.
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'signal_dependent_condition_linear').
            codes (dict): Conditioning variables and their admissible values
                (default: ``{'code': [1, 2, 3]}``).
            encode_ch (int): Number of channels of the meta embedding (default: 3).
        """
        super().__init__()
        self.name = name
        self.device = device
        self.codes = codes
        self.in_ch = in_ch
        self.encode_ch = encode_ch

        # Width of the one-hot input: one slot per admissible value of every variable.
        one_hot_width = sum(len(values) for values in codes.values())
        self.meta_encoder = meta_encoder(one_hot_width, self.encode_ch)
        # The output holds a log-scale and a bias for every input channel.
        self.scale_and_bias = scale_and_bias(self.encode_ch + in_ch, in_ch * 2)

    def _computeOneHot(self, b, **kwargs):
        """Delegate to ``compute_one_hot`` to encode the conditioning values."""
        return compute_one_hot(self.codes, device=self.device)(b, **kwargs)

    def _get_embeddings(self, x, **kwargs):
        """
        Build the stacked ``[log_scale, bias]`` feature map.

        Args:
            x (tensor): Reference tensor; only its batch size and last two dimensions are used.
            kwargs: Conditioning values plus ``'clean'``, the tensor concatenated with the
                meta embedding along dimension 1.

        Returns:
            tensor: Output of ``scale_and_bias``.
        """
        one_hot = self._computeOneHot(x.shape[0], **kwargs)

        # Broadcast the per-sample embedding over the spatial grid of x.
        meta = self.meta_encoder(one_hot).reshape((-1, self.encode_ch, 1, 1))
        meta = meta.expand(-1, -1, x.shape[-2], x.shape[-1])

        stacked = torch.cat((meta, kwargs['clean']), dim=1)
        return self.scale_and_bias(stacked)

    def _log_scale_and_bias(self, ref, **kwargs):
        """Split the embedding into its first ``in_ch`` channels (log-scale) and the rest (bias)."""
        embedding = self._get_embeddings(ref, **kwargs)
        return embedding[:, :self.in_ch, ...], embedding[:, self.in_ch:, ...]

    def _inverse(self, z, **kwargs):
        """
        Undo the affine map: ``(z - bias) / exp(log_scale)``.

        Args:
            z (tensor): Input tensor to be inversely transformed.
            kwargs: Conditioning values plus ``'clean'``.

        Returns:
            tensor: The inversely transformed tensor.
        """
        log_scale, bias = self._log_scale_and_bias(z, **kwargs)
        return (z - bias) / torch.exp(log_scale)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Apply ``exp(log_scale) * x + bias`` and return the log-determinant.

        Args:
            x (tensor): Input tensor to be transformed.
            kwargs: Conditioning values plus ``'clean'``.

        Returns:
            tensor: The transformed tensor.
            tensor: Per-sample log-determinant, the sum of ``log_scale`` over dimensions 1-3.
        """
        log_scale, bias = self._log_scale_and_bias(x, **kwargs)
        z = torch.exp(log_scale) * x + bias
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1, 2, 3])
        return z, log_abs_det_J_inv


### Structure-Aware Conditional Linear Layer

In [None]:
#| export

@regist_layer
class StructureAwareConditionalLinearLayer(nn.Module):
    """
    Structure-aware conditional affine layer: ``z = exp(log_scale) * x + bias``.

    ``log_scale`` and ``bias`` come from the element-wise product of two embeddings:
    a meta embedding of the one-hot encoded conditioning values (built by
    ``compute_one_hot`` from ``Noise2Model.utils``) and a structure embedding of
    the ``clean`` tensor.

    Attributes:
        name (str): Name of the module.
        device (str): Device passed to ``compute_one_hot``.
        in_ch (int): Number of input channels.
        codes (dict): Maps each conditioning variable to its list/tensor of admissible values.
        meta_encoder: Object returned by ``meta_encoder(n, in_ch * 2)``, where ``n`` is the
            total number of admissible values over all entries of ``codes``.
        structure_encoder: Object returned by ``structure_encoder(in_ch, in_ch * 2)``.
    """
    def __init__(self, meta_encoder, structure_encoder, in_ch=1, device='cpu', name='structure_aware_condition_linear', codes={'code': [1, 2, 3]}):
        """
        Args:
            meta_encoder (callable): Factory called as ``meta_encoder(in_features, out_features)``.
            structure_encoder (callable): Factory called as ``structure_encoder(in_channels, out_channels)``.
            in_ch (int): Number of input channels. Default is 1.
            device (str): Device to run computations on (default: 'cpu').
            name (str): Name of the module (default: 'structure_aware_condition_linear').
            codes (dict): Conditioning variables and their admissible values
                (default: ``{'code': [1, 2, 3]}``).
        """
        super(StructureAwareConditionalLinearLayer, self).__init__()
        # `device` is read by `_computeOneHot`; without storing it every forward or
        # inverse call fails with AttributeError.
        self.name = name
        self.device = device
        self.in_ch = in_ch
        self.codes = codes

        # Width of the one-hot input: one slot per admissible value of every variable.
        n = sum(len(v) for v in codes.values())
        # Both encoders emit 2 * in_ch channels: a log-scale and a bias per input channel.
        self.meta_encoder = meta_encoder(n, in_ch * 2)
        self.structure_encoder = structure_encoder(in_ch, in_ch * 2)

    def _computeOneHot(self, b, **kwargs):
        """Delegate to ``compute_one_hot`` to encode the conditioning values."""
        return compute_one_hot(self.codes, device=self.device)(b, **kwargs)

    def _get_embeddings(self, x, **kwargs):
        """
        Build the stacked ``[log_scale, bias]`` feature map.

        Args:
            x (tensor): Reference tensor; only its batch size is used.
            kwargs: Conditioning values plus ``'clean'``, the tensor fed to the structure encoder.

        Returns:
            tensor: Element-wise product of the structure embedding and the meta embedding.
        """
        batch_size = x.shape[0]

        # Per-sample meta embedding, shaped to broadcast over the spatial dimensions.
        meta_embedding = self.meta_encoder(self._computeOneHot(batch_size, **kwargs))
        meta_embedding = meta_embedding.reshape((-1, self.in_ch * 2, 1, 1))

        # The meta embedding modulates the structure embedding channel by channel.
        structure_embedding = self.structure_encoder(kwargs['clean'])
        return structure_embedding * meta_embedding

    def _log_scale_and_bias(self, ref, **kwargs):
        """Split the embedding into its first ``in_ch`` channels (log-scale) and the rest (bias)."""
        embedding = self._get_embeddings(ref, **kwargs)
        return embedding[:, :self.in_ch, ...], embedding[:, self.in_ch:, ...]

    def _inverse(self, z, **kwargs):
        """
        Undo the affine map: ``(z - bias) / exp(log_scale)``.

        Args:
            z (tensor): Input tensor to be inversely transformed.
            kwargs: Conditioning values plus ``'clean'``.

        Returns:
            tensor: The inversely transformed tensor.
        """
        log_scale, bias = self._log_scale_and_bias(z, **kwargs)
        return (z - bias) / torch.exp(log_scale)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Apply ``exp(log_scale) * x + bias`` and return the log-determinant.

        Args:
            x (tensor): Input tensor to be transformed.
            kwargs: Conditioning values plus ``'clean'``.

        Returns:
            tensor: The transformed tensor.
            tensor: Per-sample log-determinant, the sum of ``log_scale`` over dimensions 1-3.
        """
        log_scale, bias = self._log_scale_and_bias(x, **kwargs)
        z = torch.exp(log_scale) * x + bias
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1, 2, 3])
        return z, log_abs_det_J_inv



## Convolutions

### Pointwise Convs

In [None]:
#| export

@regist_layer
class PointwiseConvs(nn.Module):
    """
    Stack of 1x1 (pointwise) convolutions ending in a Tanh.

    Layout of ``body``:
        1x1 conv (in_features -> feats)
        -> 3 x [1x1 conv -> InstanceNorm2d(affine) -> LeakyReLU]
           with channel widths feats -> 2*feats -> 2*feats -> feats
        -> 1x1 conv (feats -> out_features)
        -> Tanh

    Attributes:
        name (str): Name of the module.
        device (str): Device string as passed to the constructor; it is only
            stored, no code in this class reads it.
        body (nn.Sequential): The layer stack described above.
    """

    def __init__(self, in_features=1, out_features=1, feats=32, device='cpu', name='pointwise_convs'):
        """
        Build the pointwise convolution stack.

        Args:
            in_features (int): Number of input channels. Default is 1.
            out_features (int): Number of output channels. Default is 1.
            feats (int): Base channel width of the hidden layers. Default is 32.
            device (str): Device label stored on the module (default: 'cpu').
            name (str): Name of the module (default: 'pointwise_convs').
        """
        super().__init__()
        self.name = name
        self.device = device

        wide = feats * 2
        pointwise = dict(k_size=1, stride=1, padding=0)

        # Layers are created in the order they are applied, so parameter
        # initialisation consumes the RNG in the same sequence as the layout.
        layers = [nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0)]
        for c_in, c_out in ((feats, wide), (wide, wide), (wide, feats)):
            layers.append(self._get_basic_module(c_in, c_out, **pointwise))
        layers.append(nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0))
        layers.append(nn.Tanh())

        self.body = nn.Sequential(*layers)

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
        """
        Build one conv -> instance norm -> LeakyReLU stage.

        Args:
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            k_size (int): Kernel size of the convolution. Default is 1.
            stride (int): Stride of the convolution. Default is 1.
            padding (int): Padding of the convolution. Default is 1 (the
                constructor always overrides this with 0).
            negative_slope (float): Slope of the LeakyReLU for negative inputs. Default is 0.2.

        Returns:
            nn.Sequential: Conv2d, InstanceNorm2d (affine) and in-place LeakyReLU, in that order.
        """
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.InstanceNorm2d(out_ch, affine=True)
        act = nn.LeakyReLU(negative_slope, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """
        Run the input through ``body``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features, height, width).

        Returns:
            torch.Tensor: Output of the final Tanh.
        """
        out = self.body(x)
        return out


### Spatial Convs

In [None]:
#| export

@regist_layer
class SpatialConvs(nn.Module):
    """
    Stack of 3x3 convolutions with ReLU activations, bracketed by 1x1 convolutions.

    Layout of ``body``:
        conv_in (1x1) -> relu_in
        -> [conv_i (3x3, padding 1) -> relu_i] for i in range(receptive_field // 2)
        -> conv_out (1x1) -> tanh_out

    Each 3x3 stage widens the receptive field by two pixels, so
    ``receptive_field // 2`` stages give a receptive field of
    ``2 * (receptive_field // 2) + 1`` pixels (equal to ``receptive_field``
    when it is odd).

    Note on checkpoints: an earlier version registered every 3x3 stage under
    the same names ('conv' / 'relu'), so each stage replaced the previous one
    and only a single 3x3 convolution was kept. State dicts saved from that
    version use 'body.conv.*' keys and do not match the 'body.conv_<i>.*'
    keys produced here.

    Attributes:
        name (str): Name of the module.
        device (str): Device string as passed to the constructor; it is only
            stored, no code in this class reads it.
        receptive_field (int): Requested receptive field size.
        body (nn.Sequential): The layer stack described above.
    """

    def __init__(self, in_features=1, out_features=1, feats=32, receptive_field=9, device='cpu', name='spatial_convs'):
        """
        Build the spatial convolution stack.

        Args:
            in_features (int): Number of input channels. Default is 1.
            out_features (int): Number of output channels. Default is 1.
            feats (int): Channel width of the hidden layers. Default is 32.
            receptive_field (int): Target receptive field; ``receptive_field // 2``
                3x3 stages are created. Default is 9.
            device (str): Device label stored on the module (default: 'cpu').
            name (str): Name of the module (default: 'spatial_convs').
        """
        super(SpatialConvs, self).__init__()
        self.name = name
        self.device = device
        self.receptive_field = receptive_field

        self.body = nn.Sequential()
        self.body.add_module('conv_in', nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0))
        self.body.add_module('relu_in', nn.ReLU(inplace=True))

        # Each stage needs its own name: add_module() with an existing name
        # replaces the module registered under it instead of appending.
        for idx in range(self.receptive_field // 2):
            self.body.add_module(f'conv_{idx}', nn.Conv2d(feats, feats, kernel_size=3, stride=1, padding=1))
            self.body.add_module(f'relu_{idx}', nn.ReLU(inplace=True))

        self.body.add_module('conv_out', nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0))
        self.body.add_module('tanh_out', nn.Tanh())

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
        """
        Build one conv -> instance norm -> LeakyReLU stage.

        The constructor of this class does not call this helper; it is kept
        so the class interface stays unchanged.

        Args:
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            k_size (int): Kernel size of the convolution. Default is 1.
            stride (int): Stride of the convolution. Default is 1.
            padding (int): Padding of the convolution. Default is 1.
            negative_slope (float): Slope of the LeakyReLU for negative inputs. Default is 0.2.

        Returns:
            nn.Sequential: Conv2d, InstanceNorm2d (affine) and in-place LeakyReLU, in that order.
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding),
            nn.InstanceNorm2d(out_ch, affine=True),
            nn.LeakyReLU(negative_slope, inplace=True)
        )

    def forward(self, x):
        """
        Run the input through ``body``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features, height, width).

        Returns:
            torch.Tensor: Output of the final Tanh, with the same spatial size
                as the input (all convolutions are stride 1 and size-preserving).
        """
        return self.body(x)


## Noise Extraction

In [None]:
@regist_layer
class NoiseExtraction(nn.Module):
    """
    Flow layer that converts between a noisy signal and its noise residual.

    The forward direction subtracts the clean signal from the input; the
    inverse adds it back. The layer has no parameters and its log-determinant
    is zero.

    Attributes:
        name (str): Name of the module.
        device (str): Device string as passed to the constructor; it is only
            stored, tensors created by this layer follow the input's device.

    Methods:
        _inverse(z, **kwargs):
            Adds kwargs['clean'] to z.

        _forward_and_log_det_jacobian(x, **kwargs):
            Subtracts kwargs['clean'] from x and returns a zero log determinant Jacobian.
    """
    # NOTE(review): unlike the neighbouring cells, the notebook cell holding
    # this class has no `#| export` directive -- confirm whether the class is
    # meant to be exported to the generated module.

    def __init__(self, device='cpu', name='noise_extraction'):
        """
        Initializes the NoiseExtraction module.

        Args:
            device (str): Device label stored on the module (default: 'cpu').
            name (str): Name of the module (default: 'noise_extraction').
        """
        super(NoiseExtraction, self).__init__()
        self.name = name
        self.device = device

    def _inverse(self, z, **kwargs):
        """
        Adds the clean signal back onto z.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ...).
            **kwargs: Must contain 'clean', a tensor that can be added to z.

        Returns:
            torch.Tensor: ``z + kwargs['clean']``.
        """
        x = z + kwargs['clean']
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """
        Subtracts the clean signal from x and returns a zero log determinant Jacobian.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ...).
            **kwargs: Must contain 'clean', a tensor that can be subtracted from x.

        Returns:
            torch.Tensor: ``x - kwargs['clean']``.
            torch.Tensor: Log determinant Jacobian (ldj): zeros of shape
                (batch_size,) in the default dtype, on the same device as x.
        """
        z = x - kwargs['clean']
        # Allocate on the input's device rather than the constructor-time
        # `self.device` string, which is not updated when the module or its
        # inputs are moved with `.to(...)`.
        ldj = torch.zeros(x.shape[0], device=x.device)
        return z, ldj


## Residual Net (to be moved)

In [None]:
#| export

from torch.nn import functional as F, init

class ResidualBlock(nn.Module):
    """Residual block over flat feature vectors (inputs of shape (batch, features)).

    The residual branch is two rounds of [optional BatchNorm1d -> activation ->
    Linear], with optional dropout just before the second Linear. When a
    context tensor is supplied, the branch output is gated by it through a
    GLU. The result is added to the block input.
    """

    def __init__(self,
                 features,
                 context_features,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False,
                 zero_initialization=True):
        """
        Args:
            features: Width of the input, hidden and output vectors.
            context_features: Width of the context vector, or None to build
                the block without a context projection.
            activation: Callable applied before each Linear layer.
            dropout_probability: Dropout rate; no dropout layer is created
                unless this is greater than zero.
            use_batch_norm: Whether to apply BatchNorm1d (eps=1e-3, running
                statistics tracked) before each activation.
            zero_initialization: If True, draw the last Linear layer's weight
                and bias from U(-1e-3, 1e-3) so the block starts close to the
                identity map.
        """
        super().__init__()
        self.activation = activation
        self.use_batch_norm = use_batch_norm

        # Sub-modules are created in a fixed order (norms, context projection,
        # linears) so registration order and RNG consumption stay stable.
        if use_batch_norm:
            first_norm = nn.BatchNorm1d(features, eps=1e-3)
            second_norm = nn.BatchNorm1d(features, eps=1e-3)
            self.batch_norm_layers = nn.ModuleList([first_norm, second_norm])

        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)

        first_linear = nn.Linear(features, features)
        second_linear = nn.Linear(features, features)
        self.linear_layers = nn.ModuleList([first_linear, second_linear])

        self.dropout = nn.Dropout(p=dropout_probability) if dropout_probability > 0. else None

        if zero_initialization:
            last = self.linear_layers[-1]
            init.uniform_(last.weight, -1e-3, 1e-3)
            init.uniform_(last.bias, -1e-3, 1e-3)

    def forward(self, inputs, context=None):
        """Return ``inputs`` plus the (optionally context-gated) residual branch."""
        hidden = inputs
        for stage, linear in enumerate(self.linear_layers):
            if self.use_batch_norm:
                hidden = self.batch_norm_layers[stage](hidden)
            hidden = self.activation(hidden)
            # Dropout sits only in front of the second linear layer.
            if stage == 1 and self.dropout is not None:
                hidden = self.dropout(hidden)
            hidden = linear(hidden)

        if context is not None:
            # GLU over the concatenation: the first half (the branch output)
            # is multiplied by the sigmoid of the second half (the projected
            # context).
            gate_input = self.context_layer(context)
            stacked = torch.cat((hidden, gate_input), dim=1)
            hidden = F.glu(stacked, dim=1)

        return inputs + hidden


In [None]:
#| export

@regist_layer
class ResidualNet(nn.Module):
    """Residual MLP over flat feature vectors (inputs of shape (batch, in_features)).

    Structure: a Linear layer into the hidden width, ``num_blocks``
    ResidualBlock modules, and a Linear layer to the output width. An optional
    context vector is concatenated to the input of the first layer and also
    passed to every residual block.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 context_features=None,
                 num_blocks=2,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False):
        """
        Args:
            in_features: Width of the input vectors.
            out_features: Width of the output vectors.
            hidden_features: Width used by the residual blocks.
            context_features: Width of the context vectors, or None for an
                unconditional network.
            num_blocks: Number of residual blocks.
            activation: Activation callable handed to each block.
            dropout_probability: Dropout rate handed to each block.
            use_batch_norm: Whether the blocks use batch normalisation.
        """
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features

        # The first layer also consumes the context when one is configured.
        first_in = in_features if context_features is None else in_features + context_features
        self.initial_layer = nn.Linear(first_in, hidden_features)

        block_kwargs = dict(
            features=hidden_features,
            context_features=context_features,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        self.blocks = nn.ModuleList([ResidualBlock(**block_kwargs) for _ in range(num_blocks)])

        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        """Return the network output for ``inputs``, optionally conditioned on ``context``."""
        first_input = inputs if context is None else torch.cat((inputs, context), dim=1)
        hidden = self.initial_layer(first_input)
        for block in self.blocks:
            hidden = block(hidden, context=context)
        return self.final_layer(hidden)

# Noise Flow Layers


In [None]:
#| export

# class Gain(Flow):
#     """
#     Gain & Offset flow layer
#     """

#     def __init__(self, shape):
#         """Constructor

        
#         """
#         super().__init__()
#         self.shape = shape
#         self.gain = AffineConstFlow(self.shape)

#     def forward(self, z, **kwargs):
#         return self.gain(z)

#     def inverse(self, z, **kwargs):
#         return self.gain.inverse(z)

In [None]:
# channels = 1
# hidden_channels = 16

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# x = torch.randn(1, channels, 16, 16).to(device)
# print(x.device)

# # tst =  AffineSdn(x.shape[1:]).to(device)
# tst = Unconditional(channels=x.shape[1],hidden_channels = 16,split_mode='channel' if x.shape[1] != 1 else 'checkerboard').to(device)
# # tst = Gain(x.shape[1:]).to(device)  
# print(tst)
# kwargs = {}; kwargs['clean'] = x
# y, _ = tst(x,**kwargs)
# test_eq(y.shape, x.shape)

In [None]:
#| hide
# Run nbdev's export step for this project's notebooks.
import nbdev

nbdev.nbdev_export()