# NoiseFlow Layers: Affine coupling

> Affine coupling layers (unconditional and conditional) for NoiseFlow


In [1]:
#| default_exp layers.affine_coupling

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export

from fastai.vision.all import nn, torch, ConvLayer
from torch.nn import functional as F, init
from Noise2Model.utils import attributesFromDict

In [4]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

### Affine coupling


In [5]:
#| export
class ShiftAndLogScale(nn.Module):
    """Small CNN predicting the ``shift`` and ``log_scale`` of an affine coupling.

    Pipeline: 3x3 ``ConvLayer`` -> 1x1 ``ConvLayer`` (both using
    ``activation``), then a 3x3 ``ConvLayer`` built with
    ``init=nn.init.zeros_`` whose input is the hidden features plus one extra
    "edge indicator" channel (1 on the one-pixel padded border, 0 elsewhere).
    The result is multiplied by ``exp(3 * logs)``, split in half along
    channels into ``shift`` and ``log_scale``, and ``log_scale`` is squashed
    to ``scale * tanh(log_scale)``.

    Args:
        num_in: number of input channels.
        num_out: number of output channels; ``shift`` and ``log_scale`` each
            receive ``num_out // 2`` (so ``num_out`` should be even).
        width: number of hidden channels.
        ndim: number of spatial dimensions, forwarded to ``ConvLayer``. The
            pad layer (``ConstantPad3d``), the border marking in ``forward``
            and the shape of ``logs`` are written for two spatial axes, so
            only ``ndim=2`` is currently supported.
        shift_only: if True, ``forward`` returns the whole conv output as
            ``shift`` and an all-zero ``log_scale`` of the same shape.
        activation: activation class passed to ``ConvLayer`` as ``act_cls``.
        device: device for the parameters created directly here.
    """

    def __init__(self, num_in, num_out, width=4, ndim=2, shift_only=False, activation=nn.ReLU, device='cpu'):
        super(ShiftAndLogScale, self).__init__()
        # Stores every constructor argument as an attribute (self.num_in, ...).
        # Called first so no other locals are captured.
        attributesFromDict(locals())
        self.scale = nn.Parameter(torch.full((1,), 1e-4, device=device))

        self.conv_1 = ConvLayer(self.num_in, self.width, ndim=self.ndim,
                                ks=3, act_cls=activation, padding=1, init=nn.init.normal_)
        self.conv_2 = ConvLayer(self.width, self.width, ndim=self.ndim,
                                ks=1, act_cls=activation, padding=0, init=nn.init.normal_)
        self.net = nn.Sequential(self.conv_1, self.conv_2)

        # One pixel of zero padding on both sides of every spatial axis
        # (2 entries per spatial dim), then (0, 1) on the channel axis, which
        # appends one zero channel after the hidden features. For ndim=2 this
        # is (1, 1, 1, 1, 0, 1).
        padding = (1,) * (2 * self.ndim) + (0, 1)
        self.padding = nn.ConstantPad3d(padding, 0.)

        self.conv_3 = ConvLayer(self.width + 1, self.num_out, ndim=self.ndim,
                                ks=3, act_cls=None, padding=0, init=nn.init.zeros_)
        # Per-channel output log-gain, applied as exp(3 * logs).
        self.logs = nn.Parameter(torch.zeros(
            [1, self.num_out, 1, 1], device=device))

    def forward(self, x):
        """Return ``(shift, log_scale)`` computed from ``x``.

        ``x`` is assumed to be ``(N, num_in, H, W)``. Both outputs then have
        shape ``(N, num_out // 2, H, W)``; with ``shift_only`` the first output
        has ``num_out`` channels and the second is zeros.
        """
        x = self.net(x)

        # Pad the spatial borders and append the indicator channel, then set
        # that channel to 1 on the border so the final 3x3 conv can tell
        # padded pixels from real ones. The appended channel is the LAST one
        # (index ``self.width``). Writing ``self.width - 1`` would overwrite
        # the border of the last hidden feature map and leave the indicator
        # channel all zeros.
        x = self.padding(x)
        x[:, -1, :1, :] = 1.0
        x[:, -1, -1:, :] = 1.0
        x[:, -1, :, :1] = 1.0
        x[:, -1, :, -1:] = 1.0

        x = self.conv_3(x)
        x = x * torch.exp(self.logs * 3)

        if self.shift_only:
            return x, torch.zeros_like(x)

        shift, log_scale = torch.split(x, x.shape[1] // 2, dim=1)
        # Bound the log-scale to (-scale, scale) for numerical stability.
        log_scale = self.scale * torch.tanh(log_scale)

        return shift, log_scale

In [6]:
x = torch_randn(16, 2, 64, 64)
shift, log_scale = ShiftAndLogScale(2, 4)(x)
# num_out=4 is split evenly: 2 shift channels and 2 log-scale channels,
# each keeping the input's spatial size.
test_eq(shift.shape, [16, 2, 64, 64])
test_eq(log_scale.shape, [16, 2, 64, 64])

In [7]:
#| export
class AffineCoupling(nn.Module):
    """Unconditional RealNVP-style affine coupling layer.

    The input is split along channels into ``x0`` (first ``ic // 2``
    channels) and ``x1`` (the remaining channels). ``x0`` passes through
    unchanged and is fed to the shift/log-scale network, whose outputs
    transform ``x1`` elementwise: ``z1 = x1 * exp(log_scale) + shift``.

    Args:
        x_shape: ``(ic, i0, i1)`` shape of one sample, channels first
            (no batch dimension).
        shift_and_log_scale: factory called as
            ``shift_and_log_scale(num_in=..., num_out=..., device=..., ndim=...)``.
            Its instances must map ``x0`` to ``(shift, log_scale)``, each
            shaped like ``x1``.
        name: prefix used for the scalars logged to ``writer``.
        device: forwarded to ``shift_and_log_scale``.
    """

    def __init__(self, x_shape, shift_and_log_scale, name="real_nvp", device='cuda'):
        super(AffineCoupling, self).__init__()
        self.x_shape = x_shape
        self.ic, self.i0, self.i1 = x_shape
        # ``x_shape`` includes the channel axis, so the number of SPATIAL dims
        # is one less. Passing ``len(x_shape)`` (3) would request 3-D convs and
        # an invalid padding spec from ``ShiftAndLogScale``.
        self._shift_and_log_scale = shift_and_log_scale(
            num_in=self.ic // 2, num_out=2 * (self.ic - self.ic // 2),
            device=device, ndim=len(x_shape) - 1)
        self.name = name

    def _split(self, t):
        """Split ``t`` along channels into its pass-through and transformed parts."""
        cut = self.ic // 2
        return t[:, :cut], t[:, cut:]

    def _inverse(self, z, **kwargs):
        """Invert the coupling: ``x1 = (z1 - shift) * exp(-log_scale)``.

        Extra keyword arguments are accepted and ignored.
        """
        z0, z1 = self._split(z)
        shift, log_scale = self._shift_and_log_scale(z0)
        x1 = (z1 - shift) * torch.exp(-log_scale)
        return torch.cat([z0, x1], dim=1)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply the coupling and return ``(z, log_det)``.

        ``log_det`` is the per-sample sum of ``log_scale`` over dims 1-3.
        If a non-None ``writer`` keyword is given, the mean/min/max of
        ``log_scale`` are logged via ``writer.add_scalar(tag, value, step)``
        using the ``step`` keyword (default None).
        """
        x0, x1 = self._split(x)

        writer = kwargs.get('writer')
        step = kwargs.get('step')

        # The network only takes the conditioning tensor. ``writer``/``step``
        # are used for logging below and are not forwarded.
        shift, log_scale = self._shift_and_log_scale(x0)

        if writer is not None:
            prefix = 'model/' + self.name
            writer.add_scalar(prefix + '_log_scale_mean', torch.mean(log_scale), step)
            writer.add_scalar(prefix + '_log_scale_min', torch.min(log_scale), step)
            writer.add_scalar(prefix + '_log_scale_max', torch.max(log_scale), step)

        z1 = x1 * torch.exp(log_scale) + shift
        z = torch.cat([x0, z1], dim=1)
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1, 2, 3])
        return z, log_abs_det_J_inv

In [8]:
# x = torch_randn(16,2,64,64,32)
# # xdim = len(x.shape)-2
# tst = AffineCoupling_v1(x.shape[2:],ShiftAndLogScale)._forward_and_log_det_jacobian
# y = tst(x)
# print(y.shape)

In [9]:
#| export
class ConditionalAffineCoupling(nn.Module):
    """Affine coupling conditioned on a clean image and ISO/camera metadata.

    The input is split along channels into ``x0`` (first ``ic // 2``
    channels) and ``x1`` (the rest). The shift/log-scale network sees
    ``cat(x0, clean)``. Its ``log_scale`` is multiplied by a per-sample
    scalar produced by ``encoder`` from one-hot ISO and camera codes, and
    ``x1`` is transformed as ``z1 = x1 * exp(log_scale) + shift``.

    Args:
        x_shape: ``(ic, i0, i1)`` shape of one sample, channels first.
        shift_and_log_scale: factory called as
            ``shift_and_log_scale(num_in=ic // 2 + ic, num_out=..., device=...)``.
            ``num_in`` is sized for a ``clean`` tensor with ``ic`` channels.
        encoder: factory called as ``encoder(10, 1)`` (5 ISO + 5 camera
            one-hot inputs, 1 output).
        name: layer name (stored only).
        device: device for the lookup tables; forwarded to the factory.
    """

    def __init__(self, x_shape, shift_and_log_scale, encoder, name="conditional_coupling", device='cpu'):
        super(ConditionalAffineCoupling, self).__init__()
        self.x_shape = x_shape
        self.ic, self.i0, self.i1 = x_shape
        num_out = 2 * (self.ic - self.ic // 2)
        self._shift_and_log_scale = shift_and_log_scale(
            num_in=self.ic // 2 + self.ic, num_out=num_out, device=device)
        self.name = name

        # One-hot lookup tables. Registered as non-persistent buffers so that
        # ``module.to(...)`` moves them with the model, while ``state_dict``
        # (and therefore saved checkpoints) stays unchanged.
        # 'IP', 'GP', 'S6', 'N6', 'G4'
        self.register_buffer('cam_vals', torch.tensor(
            [0, 1, 2, 3, 4], dtype=torch.float32, device=device), persistent=False)
        self.register_buffer('iso_vals', torch.tensor(
            [100, 400, 800, 1600, 3200], dtype=torch.float32, device=device), persistent=False)

        self._encoder = encoder(10, 1)

    def _embedding(self, iso, cam):
        """Encode per-sample ISO and camera id into an ``(N, 1, 1, 1)`` multiplier.

        ``iso`` and ``cam`` are averaged over dims 1-3 and compared for exact
        equality with ``iso_vals``/``cam_vals``. This presumably expects 4-D
        maps holding a constant value per sample (TODO confirm with callers).
        A value not in the table yields an all-zero one-hot row.
        """
        gain_one_hot = self.iso_vals == torch.mean(iso, dim=[1, 2, 3]).unsqueeze(1)
        gain_one_hot = torch.where(gain_one_hot, 1., 0.)
        cam_one_hot = self.cam_vals == torch.mean(cam, dim=[1, 2, 3]).unsqueeze(1)
        cam_one_hot = torch.where(cam_one_hot, 1., 0.)
        embedding = self._encoder(torch.cat((gain_one_hot, cam_one_hot), dim=1))
        return embedding.reshape((-1, 1, 1, 1))

    def _inverse(self, z, **kwargs):
        """Invert the coupling. Requires ``clean``, ``iso`` and ``cam`` keywords."""
        embedding = self._embedding(kwargs['iso'], kwargs['cam'])
        cut = self.ic // 2
        z0, z1 = z[:, :cut], z[:, cut:]
        shift, log_scale = self._shift_and_log_scale(
            torch.cat((z0, kwargs['clean']), dim=1))
        # Out-of-place so the network's output tensor is not mutated.
        log_scale = log_scale * embedding
        x1 = (z1 - shift) * torch.exp(-log_scale)
        return torch.cat([z0, x1], dim=1)

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply the coupling and return ``(z, log_det)``.

        Requires ``clean``, ``iso`` and ``cam`` keywords. ``log_det`` is the
        per-sample sum of the embedding-scaled ``log_scale`` over dims 1-3.
        """
        embedding = self._embedding(kwargs['iso'], kwargs['cam'])
        cut = self.ic // 2
        x0, x1 = x[:, :cut], x[:, cut:]
        shift, log_scale = self._shift_and_log_scale(
            torch.cat((x0, kwargs['clean']), dim=1))
        log_scale = log_scale * embedding
        z1 = x1 * torch.exp(log_scale) + shift
        z = torch.cat([x0, z1], dim=1)
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1, 2, 3])
        return z, log_abs_det_J_inv

In [10]:
#| export
class ConditionalAffine(nn.Module):
    """Elementwise affine transform whose parameters are predicted from a clean image.

    Unlike the coupling layers, the whole input is transformed:
    ``z = x * exp(log_scale) + shift``, where ``(shift, log_scale)`` come from
    the shift/log-scale network applied to ``kwargs['clean']``. Unless
    ``only_clean`` is set, ``log_scale`` is additionally multiplied by a
    per-sample scalar produced by ``encoder`` from one-hot ISO and camera
    codes.

    Args:
        x_shape: ``(ic, i0, i1)`` shape of one sample, channels first.
        shift_and_log_scale: factory called as
            ``shift_and_log_scale(num_in=ic, num_out=2 * ic, device=device)``.
        encoder: factory called as ``encoder(10, 1)``; not used when
            ``only_clean`` is True.
        name: layer name (stored only).
        device: device for the lookup tables; forwarded to the factory.
        only_clean: if True, condition only on the clean image (no ISO/camera
            embedding, and ``iso``/``cam`` keywords are not needed).
    """

    def __init__(self, x_shape, shift_and_log_scale, encoder, name="conditional_coupling", device='cpu', only_clean=False):
        super(ConditionalAffine, self).__init__()
        self.x_shape = x_shape
        self.ic, self.i0, self.i1 = x_shape
        num_out = 2 * self.ic
        self._shift_and_log_scale = shift_and_log_scale(
            num_in=self.ic, num_out=num_out, device=device)
        self.name = name
        self.only_clean = only_clean

        if not self.only_clean:
            # One-hot lookup tables. Registered as non-persistent buffers so
            # that ``module.to(...)`` moves them with the model, while
            # ``state_dict`` stays unchanged.
            # 'IP', 'GP', 'S6', 'N6', 'G4'
            self.register_buffer('cam_vals', torch.tensor(
                [0, 1, 2, 3, 4], dtype=torch.float32, device=device), persistent=False)
            self.register_buffer('iso_vals', torch.tensor(
                [100, 400, 800, 1600, 3200], dtype=torch.float32, device=device), persistent=False)

            self._encoder = encoder(10, 1)

    def _embedding(self, iso, cam):
        """Encode per-sample ISO and camera id into an ``(N, 1, 1, 1)`` multiplier.

        ``iso`` and ``cam`` are averaged over dims 1-3 and compared for exact
        equality with ``iso_vals``/``cam_vals``. This presumably expects 4-D
        maps holding a constant value per sample (TODO confirm with callers).
        A value not in the table yields an all-zero one-hot row.
        """
        gain_one_hot = self.iso_vals == torch.mean(iso, dim=[1, 2, 3]).unsqueeze(1)
        gain_one_hot = torch.where(gain_one_hot, 1., 0.)
        cam_one_hot = self.cam_vals == torch.mean(cam, dim=[1, 2, 3]).unsqueeze(1)
        cam_one_hot = torch.where(cam_one_hot, 1., 0.)
        embedding = self._encoder(torch.cat((gain_one_hot, cam_one_hot), dim=1))
        return embedding.reshape((-1, 1, 1, 1))

    def _params(self, kwargs):
        """Return ``(shift, log_scale)`` for the keyword arguments of a call."""
        embedding = None
        if not self.only_clean:
            embedding = self._embedding(kwargs['iso'], kwargs['cam'])

        shift, log_scale = self._shift_and_log_scale(kwargs['clean'])

        if embedding is not None:
            # Out-of-place so the network's output tensor is not mutated.
            log_scale = log_scale * embedding
        return shift, log_scale

    def _inverse(self, z, **kwargs):
        """Invert the transform: ``x = (z - shift) * exp(-log_scale)``.

        Requires ``clean`` (and ``iso``/``cam`` unless ``only_clean``).
        """
        shift, log_scale = self._params(kwargs)
        x = (z - shift) * torch.exp(-log_scale)
        return x

    def _forward_and_log_det_jacobian(self, x, **kwargs):
        """Apply the transform and return ``(z, log_det)``.

        Requires ``clean`` (and ``iso``/``cam`` unless ``only_clean``).
        ``log_det`` is the per-sample sum of ``log_scale`` over dims 1-3.
        """
        shift, log_scale = self._params(kwargs)
        z = x * torch.exp(log_scale) + shift
        log_abs_det_J_inv = torch.sum(log_scale, dim=[1, 2, 3])
        return z, log_abs_det_J_inv

In [11]:
#| hide
import nbdev
# Regenerate the library's Python modules from the notebooks' `#| export`
# cells (this notebook targets the module named in its `#| default_exp` cell).
nbdev.nbdev_export()