# Layers

> Custom normalizing-flow layers


In [15]:
#| default_exp layers

In [16]:
#| hide
from nbdev.showdoc import *

In [17]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj):
    """Replacement for `DisplayHandle.update`: clear the cell output, then show `obj`.

    Args:
        self: the display handle the method is bound to; only its `display`
            method is used.
        obj: the object to show in place of the previous output.
    """
    # wait=True asks the frontend to hold the clear until new output arrives.
    clear_output(wait=True)
    self.display(obj)
DisplayHandle.update = update_patch

In [18]:
 #| export

from fastai.vision.all import nn, torch, np, Path, get_image_files, Image 
import normflows as nf
from Noise2Model.utils import attributesFromDict
# from Noise2Model.models import DnCNN, UNet
from Noise2Model.utils import gaussian_diag #, batch_PSNR, weights_init_orthogonal #, weights_init_kaiming


## Normalizing Flows


In [19]:
from matplotlib import pyplot as plt
from tqdm import tqdm

### Example Glow Layers


In [20]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

In [21]:
# The split mode 'channel' necessitates an even number of channels
channels = 2
hidden_channels = 16

# Fall back to the CPU when no CUDA device is available so the cell runs anywhere.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch_randn(1, channels, 16, 16).to(device)

# Affine coupling block: the parameter network takes `channels//2` input channels
# and produces `channels` output channels.
tst = nf.flows.AffineCouplingBlock(nf.nets.ConvNet2d([channels//2,hidden_channels,hidden_channels,channels], (3,1,3))).to(device)
print(tst)
y, _ = tst(x)
test_eq(y.shape, x.shape)

# Move the block to the selected `device` (not a hard-coded 'cuda') so it always
# lives on the same device as `x`, including on CPU-only machines.
tst = nf.flows.GlowBlock(channels, hidden_channels, split_mode='channel').to(device)
print(tst)
y, _ = tst(x)
test_eq(y.shape, x.shape)

AffineCouplingBlock(
  (flows): ModuleList(
    (0): Split()
    (1): AffineCoupling(
      (param_map): ConvNet2d(
        (net): Sequential(
          (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): LeakyReLU(negative_slope=0.0)
          (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
          (3): LeakyReLU(negative_slope=0.0)
          (4): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (2): Merge()
  )
)
GlowBlock(
  (flows): ModuleList(
    (0): AffineCouplingBlock(
      (flows): ModuleList(
        (0): Split()
        (1): AffineCoupling(
          (param_map): ConvNet2d(
            (net): Sequential(
              (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): LeakyReLU(negative_slope=0.0)
              (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
              (3): LeakyReLU(negative_slope=0.0)
              (4): Conv2d(16, 2, kerne

In [22]:
# 'checkerboard' splitting also works when the number of channels is odd
channels = 1
hidden_channels = 16

x = torch_randn(1, channels, 16, 16).to(device)

# Build the Glow block on the same device as the input, then check that a
# forward pass preserves the input shape.
tst = nf.flows.GlowBlock(
    channels,
    hidden_channels,
    split_mode='checkerboard',
).to(device)
print(tst)
y, _ = tst(x)
test_eq(y.shape, x.shape)

GlowBlock(
  (flows): ModuleList(
    (0): AffineCouplingBlock(
      (flows): ModuleList(
        (0): Split()
        (1): AffineCoupling(
          (param_map): ConvNet2d(
            (net): Sequential(
              (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): LeakyReLU(negative_slope=0.0)
              (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
              (3): LeakyReLU(negative_slope=0.0)
              (4): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
          )
        )
        (2): Merge()
      )
    )
    (1): ActNorm()
  )
)


# Noise Flow Layers


In [23]:
#| export

from normflows.flows import GlowBlock, AffineConstFlow, Flow


In [24]:
#| export
class AffineSdn(Flow):
    """
    Sdn flow layer

    Scales the input element-wise by the square root of an affine function of a
    conditioning tensor. The conditioning tensor is `kwargs['clean']` when that
    key is supplied, otherwise the input `z` itself.
    """

    def __init__(self, shape):
        """Constructor

        Args:
          shape: shape passed to the underlying `AffineConstFlow`
        """
        super().__init__()
        self.shape = shape
        self.affine = AffineConstFlow(self.shape)

    def _conditioned_affine(self, z, kwargs):
        """Apply `self.affine` to the conditioning tensor (`clean` if given, else `z`)."""
        y = kwargs['clean'] if 'clean' in kwargs else z
        return self.affine(y)

    def forward(self, z, **kwargs):
        """Multiply `z` by `sqrt(affine(y))`.

        Args:
          z: input tensor
          **kwargs: optional `clean` tensor used as the conditioning signal `y`

        Returns:
          Tuple of the scaled tensor and the log-det value reported by `self.affine`.
        """
        y_, log_det = self._conditioned_affine(z, kwargs)
        # NOTE(review): `log_det` is the value `self.affine` reports for its own
        # transform of `y`; it is not derived from the sqrt scaling applied to
        # `z` -- confirm this is the intended quantity.
        return z * torch.sqrt(y_), log_det

    def inverse(self, z, **kwargs):
        """Divide `z` by `sqrt(affine(y))`, undoing the scaling done in `forward`.

        Args:
          z: input tensor
          **kwargs: optional `clean` tensor used as the conditioning signal `y`

        Returns:
          Tuple of the rescaled tensor and the log-det value reported by `self.affine`.
        """
        # The conditioning tensor is resolved exactly as in `forward`; previously
        # this method read an undefined name `y` and raised NameError.
        y_, log_det = self._conditioned_affine(z, kwargs)
        # NOTE(review): the log-det is returned with the same sign as in
        # `forward` -- TODO confirm whether it should be negated here.
        return z / torch.sqrt(y_), log_det

In [25]:
#| export
class Unconditional(Flow):
    """
    Unconditional flow layer

    Thin wrapper around a single `GlowBlock`. Extra keyword arguments passed to
    `forward` / `inverse` are accepted and ignored.
    """

    def __init__(self, channels, hidden_channels, split_mode):
        """Constructor

        Args:
          channels: forwarded to `GlowBlock` as `channels`
          hidden_channels: forwarded to `GlowBlock` as `hidden_channels`
          split_mode: forwarded to `GlowBlock` as `split_mode`
        """
        super().__init__()
        # Presumably copies the constructor arguments onto `self`
        # (self.channels, ... are read on the next line) -- see Noise2Model.utils.
        attributesFromDict(locals())
        self.glow = GlowBlock(channels=self.channels, hidden_channels = self.hidden_channels, split_mode=self.split_mode)

    def forward(self, z, **kwargs):
        """Run `z` through the Glow block; returns `(z, log_det_tot)`."""
        z, log_det_tot = self.glow(z)
        return z, log_det_tot

    def inverse(self, z, **kwargs):
        """Run `z` through the Glow block's inverse; returns `(z, log_det_tot)`."""
        z, log_det_tot = self.glow.inverse(z)
        # The result was previously computed but never returned, so callers got None.
        return z, log_det_tot

In [26]:
#| export
class Gain(Flow):
    """
    Gain & Offset flow layer

    Delegates both directions to a single `AffineConstFlow` stored as `self.gain`.
    """

    def __init__(self, shape):
        """Constructor

        Args:
          shape: shape passed to the underlying `AffineConstFlow`
        """
        super().__init__()
        self.shape = shape
        self.gain = AffineConstFlow(shape)

    def forward(self, z, **kwargs):
        """Apply the wrapped affine flow to `z`; extra kwargs are ignored."""
        result = self.gain(z)
        return result

    def inverse(self, z, **kwargs):
        """Apply the inverse of the wrapped affine flow to `z`; extra kwargs are ignored."""
        result = self.gain.inverse(z)
        return result

In [27]:
channels = 1
hidden_channels = 16

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch_randn(1, channels, 16, 16).to(device)
print(x.device)

# Swap in one of the commented-out layers to exercise it instead.
# tst =  AffineSdn(x.shape[1:]).to(device)
# 'channel' splitting is only selected when there is more than one channel.
tst = Unconditional(
    channels=x.shape[1],
    hidden_channels=16,
    split_mode='checkerboard' if x.shape[1] == 1 else 'channel',
).to(device)
# tst = Gain(x.shape[1:]).to(device)  
print(tst)
# Pass the input itself as the `clean` conditioning signal.
kwargs = {'clean': x}
y, _ = tst(x, **kwargs)
test_eq(y.shape, x.shape)

cuda:0
Unconditional(
  (glow): GlowBlock(
    (flows): ModuleList(
      (0): AffineCouplingBlock(
        (flows): ModuleList(
          (0): Split()
          (1): AffineCoupling(
            (param_map): ConvNet2d(
              (net): Sequential(
                (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): LeakyReLU(negative_slope=0.0)
                (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
                (3): LeakyReLU(negative_slope=0.0)
                (4): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              )
            )
          )
          (2): Merge()
        )
      )
      (1): ActNorm()
    )
  )
)


In [28]:
#| hide
import nbdev; nbdev.nbdev_export()