# Layers

> Normalizing flow layers


In [1]:
#| default_exp layers

In [2]:
#| hide
from nbdev.showdoc import *

ModuleNotFoundError: No module named 'nbdev'

In [None]:
 #| export

import os
from importlib import import_module

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

In [None]:
 #| export
 
# Registry of flow layer classes, keyed by lower-cased class name.
flow_layer_class_dict = {}

def regist_layer(layer_class):
    """Register ``layer_class`` in ``flow_layer_class_dict`` under its lower-cased class name.

    Meant to be used as a class decorator; the class is returned unchanged.

    Raises:
        ValueError: if a layer with the same (lower-cased) name is already registered.
    """
    layer_name = layer_class.__name__.lower()
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if layer_name in flow_layer_class_dict:
        raise ValueError('there is already registered layer: %s in flow_layer_class_dict.' % layer_name)
    flow_layer_class_dict[layer_name] = layer_class
    return layer_class



In [None]:
#| export
def get_flow_layer(layer_name:str):
    """Return the layer class registered (via ``regist_layer``) under ``layer_name``.

    The lookup is case-insensitive, matching the lower-cased keys ``regist_layer`` stores.

    Raises:
        KeyError: if no layer is registered under that name; the message lists the
            names that are registered.
    """
    key = layer_name.lower()
    try:
        return flow_layer_class_dict[key]
    except KeyError:
        raise KeyError('there is no registered layer: %s in flow_layer_class_dict (registered: %s).'
                       % (layer_name, ', '.join(sorted(flow_layer_class_dict)))) from None


## Normalizing Flows


### Dequantization

#### Uniform Dequantization

In [None]:
#| export

@regist_layer
class UniformDequantization(nn.Module):
    """Uniform dequantization flow layer for ``num_bits``-bit discrete data.

    Forward (``_forward_and_log_det_jacobian``): add ``U[0, 1)`` noise, divide by
    ``2**num_bits``, apply a sigmoid and then the affine map
    ``s -> (s - alpha/2) / (1 - alpha)``. ``_inverse`` undoes the affine map and the
    sigmoid (via a logit), multiplies by ``2**num_bits``, floors and clamps to the
    valid bins.

    NOTE(review): because the forward pass applies a sigmoid (not a logit) to values
    in [0, 1], the latent stays in a bounded interval rather than covering the real
    line; confirm this direction is intended.
    """

    def __init__(self, alpha=1e-5, num_bits=8, device='cpu', name='uniform_dequantization'):
        """
        Uniform dequantization layer for flows.

        Inputs:
            alpha - small constant used to scale the input to avoid very small and large values.
            num_bits - Number of bits used for quantization.
            device - Device on which the ``ldj_per_dim`` buffer is created (default: 'cpu').
            name - Name of the module (default: 'uniform_dequantization').
        """
        super(UniformDequantization, self).__init__()
        self.alpha = alpha
        self.num_bits = num_bits
        self.quantization_bins = 2 ** num_bits
        # log|d/dx (x / 2**num_bits)| = -num_bits * log(2), identical for every element.
        self.register_buffer(
            'ldj_per_dim',
            - num_bits * torch.log(torch.tensor(2.0, device=device, dtype=torch.float))
        )
        self.name = name

    def _ldj(self, shape):
        """Per-sample log-det (shape ``[batch]``) of dividing every element by ``quantization_bins``."""
        batch_size = shape[0]
        num_dims = shape[1:].numel()
        ldj = self.ldj_per_dim * num_dims
        return ldj.repeat(batch_size)

    def _inverse(self, z, **kwargs):
        """Map a latent back to discrete bin indices in ``[0, quantization_bins - 1]`` (as floats)."""
        z = self._sigmoid_inverse(z)
        z = (self.quantization_bins * z).floor().clamp(min=0, max=self.quantization_bins - 1)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Dequantize ``x`` and squash it; return ``(z, ldj)`` with ``ldj`` of shape ``[batch]``."""
        z, ldj = self._dequant(x)
        z, ldj = self._sigmoid(z, ldj)
        return z, ldj

    def _sigmoid(self, z, ldj):
        """Apply ``sigmoid`` then ``(s - alpha/2) / (1 - alpha)``, adding both log-dets to ``ldj``.

        Works for inputs of any rank (all non-batch dimensions are summed). ``ldj`` is
        not modified in place; the updated value is returned.
        """
        num_dims = z.flatten(1).shape[1]
        # log sigmoid'(z) = log sigmoid(z) + log(1 - sigmoid(z)) = -z - 2 * softplus(-z)
        ldj = ldj + (-z - 2 * F.softplus(-z)).flatten(1).sum(dim=1)
        z = torch.sigmoid(z)
        # Affine alpha-rescaling; the inverse of the boundary-avoiding scale in `_sigmoid_inverse`.
        ldj = ldj - torch.log(torch.tensor(1.0 - self.alpha, device=z.device, dtype=z.dtype)) * num_dims
        z = (z - 0.5 * self.alpha) / (1 - self.alpha)
        return z, ldj

    def _sigmoid_inverse(self, z):
        """Inverse of `_sigmoid` (no log-det): undo the alpha rescaling, then apply the logit."""
        z = z * (1 - self.alpha) + 0.5 * self.alpha  # Scale to prevent boundaries 0 and 1
        z = torch.log(z) - torch.log(1 - z)
        return z

    def _dequant(self, x):
        """Add ``U[0, 1)`` noise to the discrete values and scale them into ``[0, 1]``.

        Integer-typed inputs (e.g. ``uint8`` images) are converted to the default
        floating dtype first; ``torch.rand`` cannot sample integer dtypes.
        """
        if not torch.is_floating_point(x):
            x = x.to(torch.get_default_dtype())
        u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
        z = (x + u) / self.quantization_bins
        ldj = self._ldj(z.shape)
        return z, ldj



#### Variational Dequantization (TODO: not implemented yet)

In [None]:
#| export

# class VariationalDequantization(UniformDequantization):

#     def __init__(self, var_flows, alpha=1e-5, num_bits=8, device='cpu', name='variational_dequantization'):
#         """
#         Variational dequantization layer inheriting from UniformDequantization.
        
#         Inputs:
#             var_flows - A list of flow transformations to use for modeling q(u|x).
#             alpha - Small constant used to scale the input to avoid very small and large values.
#             num_bits - Number of bits used for quantization.
#             device - Device to run computations on (default: 'cpu').
#             name - Name of the module (default: 'variational_dequantization').
#         """
#         super().__init__(alpha=alpha, num_bits=num_bits, device=device, name=name)
#         self.flows = nn.ModuleList(var_flows)
        
#     def _dequant(self, x):
#         # Transform discrete values to continuous volumes
#         u = torch.rand(x.shape, device=x.device, dtype=x.dtype)
#         img = (x / (self.quantization_bins - 1)) * 2 - 1 # We condition the flows on x, i.e. the original image
        
#         u = self._sigmoid_inverse(u)
#         for flow in self.flows:
#             u, ldj = flow(u, ldj, orig_img=img)
#         u, ldj = self._sigmoid(u, ldj)
        
#         z = (x + u) / self.quantization_bins
#         ldj = self._ldj(z.shape)
#         return z, ldj

#     # def dequant(self, z, ldj):
#     #     z = z.to(torch.float32)
#     #     img = (z / 255.0) * 2 - 1 # We condition the flows on x, i.e. the original image

#     #     # Prior of u is a uniform distribution as before
#     #     # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.
#     #     deq_noise = torch.rand_like(z).detach()
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)
#     #     for flow in self.flows:
#     #         deq_noise, ldj = flow(deq_noise, ldj, reverse=False, orig_img=img)
#     #     deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)

#     #     # After the flows, apply u as in standard dequantization
#     #     z = (z + deq_noise) / 256.0
#     #     ldj -= np.log(256.0) * np.prod(z.shape[1:])
#     #     return z, ldj


### Conditional Linear

In [None]:
#| export

@regist_layer
class ConditionalLinear(nn.Module):
    """Scalar affine flow selected per sample by (ISO, camera).

    ``k = iso_index * n_cams + cam_index`` picks one of ``n_iso * n_cams`` learned
    ``(log_scale, bias)`` pairs; forward is ``z = x * exp(log_scale[k]) + bias[k]``.
    The ISO / camera of each sample is the spatial mean of the per-pixel maps
    ``kwargs['iso']`` / ``kwargs['cam']``.
    """

    def __init__(self, device='cpu', name='linear_transformation'):
        super(ConditionalLinear, self).__init__()
        self.name = name

        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        # One (log_scale, bias) pair per (iso, camera) combination: 5 * 5 = 25.
        num_combinations = self.iso_vals.shape[0] * self.cam_vals.shape[0]
        self.log_scale = nn.Parameter(torch.zeros(num_combinations), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(num_combinations), requires_grad=True)

    def _get_log_scale_and_bias(self, iso_map, cam_map, batch_size):
        """Return the per-sample ``(log_scale, bias)``, each of shape ``[batch_size]``.

        ``iso_map`` / ``cam_map`` are reduced by their mean over dims 1-3 (so they
        must be 4-D); the result must equal an entry of ``iso_vals`` / ``cam_vals``
        exactly.

        Raises:
            ValueError: if some sample's ISO or camera code is not in the lookup tables.
        """
        iso_code = torch.mean(iso_map, dim=[1, 2, 3]).unsqueeze(1)
        cam_code = torch.mean(cam_map, dim=[1, 2, 3]).unsqueeze(1)
        # nonzero() returns matches in row order, so sample order is preserved.
        iso = (self.iso_vals.to(iso_code.device) == iso_code).nonzero()[:, 1]
        cam = (self.cam_vals.to(cam_code.device) == cam_code).nonzero()[:, 1]
        if iso.shape[0] != batch_size or cam.shape[0] != batch_size:
            raise ValueError('every sample needs an ISO in %s and a camera code in %s'
                             % (self.iso_vals.tolist(), self.cam_vals.tolist()))
        # Index on the parameters' device rather than a hard-coded `.cuda()`,
        # so the layer also works on CPU.
        index = (iso * self.cam_vals.shape[0] + cam).to(self.log_scale.device)
        return self.log_scale[index], self.bias[index]

    def _inverse(self, z, **kwargs):
        """Invert the affine map: ``x = (z - bias[k]) / exp(log_scale[k])``."""
        log_scale, bias = self._get_log_scale_and_bias(kwargs['iso'], kwargs['cam'], z.shape[0])
        x = (z - bias.reshape((-1, 1, 1, 1))) / torch.exp(log_scale.reshape((-1, 1, 1, 1)))
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply the affine map; ldj is ``log_scale[k] * (number of elements per sample)``."""
        log_scale, bias = self._get_log_scale_and_bias(kwargs['iso'], kwargs['cam'], x.shape[0])

        z = x * torch.exp(log_scale.reshape((-1, 1, 1, 1))) + bias.reshape((-1, 1, 1, 1))
        log_abs_det_J_inv = log_scale * x.shape[1:].numel()

        return z, log_abs_det_J_inv

### Conditional Linear $e^2$

In [None]:
#| export

@regist_layer
class ConditionalLinearExp2(nn.Module):
    """Per-channel affine flow selected per sample by (ISO level, smartphone code).

    ``k = iso_index * n_cams + cam_index`` picks row ``k`` of the learned
    ``[n_iso * n_cams, in_ch]`` tables ``log_scale`` / ``bias``; forward is
    ``z = x * exp(log_scale[k]) + bias[k]`` applied channel-wise.
    """

    def __init__(self, in_ch=3, device='cpu', name='linear_transformation_exp2'):
        super(ConditionalLinearExp2, self).__init__()
        self.name = name
        self.device = device

        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        # One row of per-channel parameters per (iso, camera) combination: 5 * 5 = 25.
        num_combinations = self.iso_vals.shape[0] * self.cam_vals.shape[0]
        self.log_scale = nn.Parameter(torch.zeros(num_combinations, in_ch), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(num_combinations, in_ch), requires_grad=True)

    def _get_log_scale_and_bias(self, iso_level, phone_code, batch_size, device):
        """Return per-sample ``(log_scale, bias)``, each of shape ``[batch_size, in_ch]``.

        ``iso_level`` / ``phone_code`` are assumed to hold one value per sample
        (shape ``[batch_size]`` — TODO confirm against callers). NOTE(review): a value
        that matches no entry of ``iso_vals`` / ``cam_vals`` silently maps to index 0,
        exactly as in the original implementation.
        """
        iso_level = torch.as_tensor(iso_level, device=device)
        phone_code = torch.as_tensor(phone_code, device=device)

        iso = torch.zeros([batch_size], device=device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.iso_vals.to(device)):
            iso += torch.where(iso_level == iso_val, iso_idx, 0.0)

        cam = torch.zeros([batch_size], device=device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals.to(device)):
            cam += torch.where(phone_code == cam_val, cam_idx, 0.0)

        # Row-major (iso, cam) index: stride is the number of cameras. Built on the
        # parameters' device instead of a hard-coded `.cuda()`, so CPU works too.
        index = (iso * self.cam_vals.shape[0] + cam).long().to(self.log_scale.device)
        return self.log_scale[index], self.bias[index]

    def _inverse(self, z, **kwargs):
        """Invert the affine map channel-wise: ``x = (z - bias[k]) / exp(log_scale[k])``."""
        log_scale, bias = self._get_log_scale_and_bias(
            kwargs['ISO-level'], kwargs['smartphone-code'], z.shape[0], z.device)
        channels = z.shape[1]
        x = (z - bias.reshape((-1, channels, 1, 1))) / torch.exp(log_scale.reshape((-1, channels, 1, 1)))
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply the affine map; ldj is ``sum_c log_scale[k, c] * H * W`` per sample."""
        log_scale, bias = self._get_log_scale_and_bias(
            kwargs['ISO-level'], kwargs['smartphone-code'], x.shape[0], x.device)
        channels = x.shape[1]
        z = x * torch.exp(log_scale.reshape((-1, channels, 1, 1))) + bias.reshape((-1, channels, 1, 1))
        log_abs_det_J_inv = torch.sum(log_scale * x.shape[2:].numel(), dim=1)

        return z, log_abs_det_J_inv

### Signal Dependent Conditional Linear

In [None]:
#| export


@regist_layer
class SignalDependentConditionalLinear(nn.Module):
    """Per-pixel affine flow conditioned on camera metadata and the clean image.

    ``meta_encoder`` maps the concatenated one-hot ISO / smartphone codes to
    ``encode_ch`` features. These are broadcast over the image, stacked with
    ``kwargs['clean']`` and passed to ``scale_and_bias``. Its first ``in_ch`` output
    channels are the log-scale and the rest are the bias.
    """

    def __init__(self, meta_encoder, scale_and_bias, in_ch=3, device='cpu', name='signal_dependent_condition_linear'):
        super(SignalDependentConditionalLinear, self).__init__()
        self.name = name
        self.device = device

        self.in_ch = in_ch
        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'
        self.encode_ch = 3
        # Input of the meta encoder: 5 ISO + 5 camera one-hot features.
        self.meta_encoder = meta_encoder(10, self.encode_ch)
        # Output: one log-scale and one bias channel per image channel.
        self.scale_and_bias = scale_and_bias(self.encode_ch + in_ch, in_ch * 2)

    @staticmethod
    def _code_to_index(codes, table, batch_size, device):
        """Return, per sample, the position of ``codes`` in ``table`` as floats (0 when absent)."""
        index = torch.zeros([batch_size], device=device, dtype=torch.float32)
        for position, value in enumerate(table):
            index += torch.where(codes == value, position, 0.0)
        return index

    def _get_embeddings(self, x, **kwargs):
        """Return the ``[b, 2*in_ch, h, w]`` log-scale/bias map for ``x``."""
        batch, _, height, width = x.shape

        iso_index = self._code_to_index(kwargs['ISO-level'], self.iso_vals, batch, x.device)
        cam_index = self._code_to_index(kwargs['smartphone-code'], self.cam_vals, batch, x.device)

        one_hots = torch.cat(
            (
                F.one_hot(iso_index.to(torch.int64), num_classes=self.iso_vals.shape[0]).to(torch.float32),
                F.one_hot(cam_index.to(torch.int64), num_classes=self.cam_vals.shape[0]).to(torch.float32),
            ),
            dim=1,
        )  # [b, 10]

        meta = self.meta_encoder(one_hots).reshape((-1, self.encode_ch, 1, 1))
        meta = meta.expand(-1, -1, height, width)  # [b, encode_ch, 1, 1] -> [b, encode_ch, h, w]

        stacked = torch.cat((meta, kwargs['clean']), dim=1)  # -> [b, encode_ch + c, h, w]
        return self.scale_and_bias(stacked)

    def _inverse(self, z, **kwargs):
        """Invert the per-pixel affine map: ``x = (z - bias) / exp(log_scale)``."""
        params = self._get_embeddings(z, **kwargs)
        log_scale, bias = params[:, :self.in_ch, ...], params[:, self.in_ch:, ...]
        return (z - bias) / torch.exp(log_scale)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply ``z = exp(log_scale) * x + bias``; ldj is the per-sample sum of ``log_scale``."""
        params = self._get_embeddings(x, **kwargs)
        log_scale, bias = params[:, :self.in_ch, ...], params[:, self.in_ch:, ...]
        z = torch.exp(log_scale) * x + bias
        return z, log_scale.sum(dim=[1, 2, 3])

### Structure-Aware Conditional Linear Layer

In [None]:
#| export


@regist_layer
class StructureAwareConditionalLinearLayer(nn.Module):
    """Per-pixel affine flow whose parameters are ``structure_encoder(clean) * meta_encoder(iso, cam)``.

    The product has ``2 * in_ch`` channels: the first ``in_ch`` are the log-scale
    and the rest are the bias.
    """

    def __init__(self, meta_encoder, structure_encoder, in_ch=3, device='cpu', name='signal_dependent_condition_linear'):
        super(StructureAwareConditionalLinearLayer, self).__init__()
        # `name` / `device` were previously accepted but dropped; store them like the sibling layers do.
        self.name = name
        self.device = device
        self.in_ch = in_ch
        self.iso_vals = torch.tensor([100, 400, 800, 1600, 3200], dtype=torch.float32, device=device)
        self.cam_vals = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32, device=device)  # 'IP', 'GP', 'S6', 'N6', 'G4'

        # 10 = 5 ISO + 5 camera one-hot features.
        self.meta_encoder = meta_encoder(10, in_ch*2)
        self.structure_encoder = structure_encoder(in_ch, in_ch*2)

    def _get_embeddings(self, x, **kwargs):
        """Return the ``[b, 2*in_ch, h, w]`` log-scale/bias map (assuming ``structure_encoder`` keeps h, w)."""
        b,_,_,_ = x.shape

        # Position of each sample's ISO / smartphone code in the lookup tables (0 when absent).
        iso = torch.zeros([b], device=x.device, dtype=torch.float32)
        for iso_idx, iso_val in enumerate(self.iso_vals):
            iso += torch.where(kwargs['ISO-level'] == iso_val, iso_idx, 0.0)

        cam = torch.zeros([b], device=x.device, dtype=torch.float32)
        for cam_idx, cam_val in enumerate(self.cam_vals):
            cam += torch.where(kwargs['smartphone-code'] == cam_val, cam_idx, 0.0)

        iso_one_hot = F.one_hot(iso.to(torch.int64), num_classes=self.iso_vals.shape[0]).to(torch.float32)
        cam_one_hot = F.one_hot(cam.to(torch.int64), num_classes=self.cam_vals.shape[0]).to(torch.float32)

        meta_embedding = self.meta_encoder(torch.cat((iso_one_hot, cam_one_hot), dim=1)) # [b, 10] -> [b, 2*in_ch]
        meta_embedding = meta_embedding.reshape((-1, self.in_ch*2, 1, 1))

        structure_embedding = self.structure_encoder(kwargs['clean'])
        # Broadcast the per-sample metadata gain over every pixel of the structure features.
        embedding = structure_embedding * meta_embedding
        return embedding

    def _inverse(self, z, **kwargs):
        """Invert the per-pixel affine map: ``x = (z - bias) / exp(log_scale)``."""
        embedding = self._get_embeddings(z, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]
        z = (z - bias)/torch.exp(log_scale)
        return z

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply ``z = exp(log_scale) * x + bias``; ldj is the per-sample sum of ``log_scale``."""
        embedding = self._get_embeddings(x, **kwargs)

        log_scale = embedding[:,:self.in_ch, ...]
        bias = embedding[:,self.in_ch:, ...]

        z = torch.exp(log_scale)*x + bias
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1,2,3])
        return z, log_abs_det_J_inv


### Pointwise Convs

In [None]:
#| export

@regist_layer
class PointwiseConvs(nn.Module):
    """Stack of 1x1 convolutions with a Tanh output, so results lie in (-1, 1).

    Layout: Conv(in -> feats), three Conv/InstanceNorm/LeakyReLU blocks
    (feats -> 2*feats -> 2*feats -> feats), Conv(feats -> out), Tanh.
    """

    def __init__(self, in_features=3, out_features=3, feats=32, device='cpu', name='pointwise_convs'):
        super(PointwiseConvs, self).__init__()
        self.name = name
        self.device = device

        hidden_widths = [(feats, feats * 2), (feats * 2, feats * 2), (feats * 2, feats)]
        layers = [nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0)]
        layers += [
            self._get_basic_module(c_in, c_out, k_size=1, stride=1, padding=0)
            for c_in, c_out in hidden_widths
        ]
        layers.append(nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0))
        layers.append(nn.Tanh())
        self.body = nn.Sequential(*layers)

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
        """Return ``Conv2d -> affine InstanceNorm2d -> in-place LeakyReLU``."""
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.InstanceNorm2d(out_ch, affine=True)  # batch normalization?
        act = nn.LeakyReLU(negative_slope, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run the convolution stack on ``x``."""
        return self.body(x)

### Spatial Convs

In [None]:
#| export

@regist_layer
class SpatialConvs(nn.Module):
    """Convolution stack with a configurable spatial receptive field and Tanh output.

    Layout: 1x1 Conv + ReLU, then ``receptive_field // 2`` blocks of 3x3 Conv + ReLU
    (receptive field ``1 + 2 * (receptive_field // 2)``), then 1x1 Conv + Tanh.
    """

    def __init__(self, in_features=3, out_features=3, feats=32, receptive_field=9, device='cpu', name='pointwise_convs'):
        super(SpatialConvs, self).__init__()
        self.name = name
        self.device = device

        self.receptive_field = receptive_field

        layers = [nn.Conv2d(in_features, feats, kernel_size=1, stride=1, padding=0), nn.ReLU(inplace=True)]
        for _ in range(receptive_field // 2):
            layers.extend([
                nn.Conv2d(feats, feats, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ])
        layers.extend([
            nn.Conv2d(feats, out_features, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        ])
        self.body = nn.Sequential(*layers)

    def _get_basic_module(self, in_ch, out_ch, k_size=1, stride=1, padding=1, negative_slope=0.2):
        """Return ``Conv2d -> affine InstanceNorm2d -> in-place LeakyReLU`` (not used by ``__init__``)."""
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=k_size, stride=stride, padding=padding)
        norm = nn.InstanceNorm2d(out_ch, affine=True)  # batch normalization?
        act = nn.LeakyReLU(negative_slope, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run the convolution stack on ``x``."""
        return self.body(x)

## Noise Extraction

In [None]:
@regist_layer
class NoiseExtraction(nn.Module):
    """Translation flow that subtracts the clean image: ``z = x - kwargs['clean']``."""

    def __init__(self, device='cpu', name='noise_extraction'):
        super(NoiseExtraction, self).__init__()
        self.name = name
        self.device = device

    def _inverse(self, z, **kwargs):
        """Add the clean image back: ``x = z + kwargs['clean']``."""
        x = z + kwargs['clean']
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Subtract the clean image; return ``(z, ldj)`` with a zero ldj of shape ``[batch]``."""
        z = x - kwargs['clean']
        # A pure translation has a unit Jacobian, so its log-determinant is zero.
        # Allocate it with x's device/dtype rather than `self.device`, which is
        # fixed at construction time and does not follow `.to(...)`.
        ldj = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        return z, ldj

## Residual Net (to be moved)

In [None]:
#| export

from torch.nn import functional as F, init

class ResidualBlock(nn.Module):
    """A general-purpose residual block. Works only with 1-dim inputs.

    Computes ``inputs + f(inputs)``, where ``f`` is (optional BatchNorm) ->
    activation -> Linear -> (optional BatchNorm) -> activation -> (optional Dropout)
    -> Linear. When a context is given, the result is gated with ``F.glu`` against a
    linear projection of the context.
    """

    def __init__(self,
                 features,
                 context_features,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False,
                 zero_initialization=True):
        super().__init__()
        self.activation = activation

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList(
                [nn.BatchNorm1d(features, eps=1e-3) for _ in range(2)]
            )
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)
        self.linear_layers = nn.ModuleList([nn.Linear(features, features) for _ in range(2)])
        self.dropout = nn.Dropout(p=dropout_probability) if dropout_probability > 0. else None
        if zero_initialization:
            # Near-zero last layer so the block starts close to the identity map.
            last = self.linear_layers[-1]
            init.uniform_(last.weight, -1e-3, 1e-3)
            init.uniform_(last.bias, -1e-3, 1e-3)

    def forward(self, inputs, context=None):
        """Return ``inputs`` plus the residual branch (optionally gated by ``context``)."""
        hidden = inputs
        for position, linear in enumerate(self.linear_layers):
            if self.use_batch_norm:
                hidden = self.batch_norm_layers[position](hidden)
            hidden = self.activation(hidden)
            # Dropout only sits in front of the second linear layer.
            if position == 1 and self.dropout is not None:
                hidden = self.dropout(hidden)
            hidden = linear(hidden)
        if context is not None:
            gate_input = torch.cat((hidden, self.context_layer(context)), dim=1)
            hidden = F.glu(gate_input, dim=1)
        return inputs + hidden


In [None]:
#| export

@regist_layer
class ResidualNet(nn.Module):
    """A general-purpose residual network. Works only with 1-dim inputs.

    Linear(in [+ context] -> hidden) -> ``num_blocks`` ResidualBlocks (each receiving
    the context) -> Linear(hidden -> out).
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 context_features=None,
                 num_blocks=2,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False):
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features
        # The context, when present, is concatenated onto the input features.
        first_width = in_features if context_features is None else in_features + context_features
        self.initial_layer = nn.Linear(first_width, hidden_features)
        self.blocks = nn.ModuleList([
            ResidualBlock(
                features=hidden_features,
                context_features=context_features,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=use_batch_norm,
            ) for _ in range(num_blocks)
        ])
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        """Map ``inputs`` (and optional ``context``) to ``out_features`` outputs."""
        first = inputs if context is None else torch.cat((inputs, context), dim=1)
        hidden = self.initial_layer(first)
        for block in self.blocks:
            hidden = block(hidden, context=context)
        return self.final_layer(hidden)

# Noise Flow Layers


In [None]:
#| export

# class Gain(Flow):
#     """
#     Gain & Offset flow layer
#     """

#     def __init__(self, shape):
#         """Constructor

        
#         """
#         super().__init__()
#         self.shape = shape
#         self.gain = AffineConstFlow(self.shape)

#     def forward(self, z, **kwargs):
#         return self.gain(z)

#     def inverse(self, z, **kwargs):
#         return self.gain.inverse(z)

In [None]:
# channels = 1
# hidden_channels = 16

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# x = torch_randn(1, channels, 16, 16).to(device)
# print(x.device)

# # tst =  AffineSdn(x.shape[1:]).to(device)
# tst = Unconditional(channels=x.shape[1],hidden_channels = 16,split_mode='channel' if x.shape[1] != 1 else 'checkerboard').to(device)
# # tst = Gain(x.shape[1:]).to(device)  
# print(tst)
# kwargs = {}; kwargs['clean'] = x
# y, _ = tst(x,**kwargs)
# test_eq(y.shape, x.shape)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()  # regenerate the library's Python modules from the notebooks' `#| export` cells