# NoiseFlow

> Noise Flow: a normalizing-flow noise model assembled from unconditional, signal-dependent and gain layers.


In [1]:
#| default_exp noiseflow

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj, **kwargs):
    """Replacement for `DisplayHandle.update` that redraws the whole output.

    Clears the current cell output (waiting for new output before erasing, so
    the cell does not flicker) and then displays `obj` through the handle.

    Args:
        self: the `DisplayHandle` the method is bound to.
        obj: object to display.
        **kwargs: forwarded to `self.display`, so callers that pass display
            options to `update()` keep working after the patch.
    """
    clear_output(wait=True)
    self.display(obj, **kwargs)
# Monkey-patch: every DisplayHandle now uses the clear-and-redisplay update.
DisplayHandle.update = update_patch

In [4]:
 #| export

from fastai.vision.all import nn, torch, np, Path, get_image_files, Image 
import normflows as nf
from Noise2Model.utils import attributesFromDict
# from Noise2Model.models import DnCNN, UNet
from Noise2Model.utils import gaussian_diag #, batch_PSNR, weights_init_orthogonal #, weights_init_kaiming


# Noise Flow


In [5]:
#| export
from Noise2Model.layers import Unconditional, Gain, AffineSdn

In [6]:
#| export
class NoiseFlow(nn.Module):
    """Flow model built from a `|`-separated list of layer names.

    Each recognised token in `arch` becomes one bijector layer
    (`'unc'` -> `Unconditional`, `'sdn'` -> `AffineSdn`, `'gain'` -> `Gain`);
    the layers are applied in order by `forward` and in reverse order by
    `inverse`.

    Args:
        x_shape: shape of the input batch; index 1 is used as the channel
            count and `x_shape[1:]` is handed to the `sdn` / `gain` layers.
        arch: layer specification, e.g. ``'unc|sdn|unc|gain|unc'``.
        device: stored on the instance and printed while building `unc`
            layers; tensors created by this class follow the input's device.
    """

    def __init__(self, x_shape, arch, device='cuda'): # device may be removed
        # Two-argument super() is kept on purpose: the zero-argument form adds
        # a `__class__` entry to locals(), which would then be passed on to
        # attributesFromDict below.
        super(NoiseFlow, self).__init__()
        # Presumably copies the constructor arguments onto `self`
        # (`self.arch` and `self.device` are read later) -- confirm in
        # Noise2Model.utils.
        attributesFromDict(locals())
        self.model = nn.ModuleList(self.noise_flow_arch(x_shape))

    def noise_flow_arch(self, x_shape):
        """Build the list of bijector layers described by `self.arch`.

        NOTE: tokens other than 'unc', 'sdn' and 'gain' are skipped silently.

        Args:
            x_shape: input batch shape (see the class docstring).

        Returns:
            A list of layer modules, in application order.
        """
        arch_lyrs = self.arch.split('|')  # e.g., unc|sdn|unc|gain|unc
        bijectors = []
        for lyr in arch_lyrs:
            if lyr == 'unc':
                print('|-AffineCoupling')
                print(self.device)
                bijectors.append(
                    Unconditional(
                        channels=x_shape[1],
                        hidden_channels=16,
                        # A single channel cannot be split along the channel
                        # axis, so fall back to a checkerboard split.
                        split_mode='channel' if x_shape[1] != 1 else 'checkerboard'
                    )
                )
            elif lyr == 'sdn':
                print('|-SignalDependant')
                bijectors.append(AffineSdn(x_shape[1:]))
            elif lyr == 'gain':
                print('|-Gain')
                bijectors.append(Gain(x_shape[1:]))    # Gain and offset

        return bijectors

    def _forward_with_logdet(self, x, **kwargs):
        """Run every layer on `x` and accumulate the per-sample log-det terms.

        Args:
            x: input batch; dimension 0 is the batch dimension.
            **kwargs: forwarded to each layer. If it contains `'writer'`, the
                mean log-det of every layer is logged with
                `writer.add_scalar`, which also requires `kwargs['step']` and
                a `name` attribute on each layer.

        Returns:
            `(z, objective)`: the transformed tensor and a 1-D float32 tensor
            of length `x.shape[0]` holding the summed second return value of
            every layer.
        """
        z = x
        # Allocate next to the input so the sum works wherever the module and
        # its data live, independent of the `device` constructor argument.
        objective = torch.zeros(x.shape[0], dtype=torch.float32, device=x.device)
        for bijector in self.model:
            z, log_abs_det_J_inv = bijector.forward(z, **kwargs)
            objective += log_abs_det_J_inv

            if 'writer' in kwargs:
                kwargs['writer'].add_scalar('model/' + bijector.name, torch.mean(log_abs_det_J_inv), kwargs['step'])
        return z, objective

    def forward(self, x, **kwargs):
        """Map `x` through all layers and return only the transformed tensor.

        The accumulated log-det is discarded here; `_loss` obtains it from
        `_forward_with_logdet` directly.
        """
        z, _ = self._forward_with_logdet(x, **kwargs)
        return z

    def _loss(self, x, **kwargs):
        """Return the per-sample negative objective and a spread statistic.

        Args:
            x: input batch (4-D: the variance is taken over dims 1, 2, 3).
            **kwargs: forwarded to the layers; an optional `'writer'` /
                `'step'` pair enables scalar logging.

        Returns:
            `(nobj, sd_z)`: `nobj` is minus (summed layer log-dets plus the
            prior log-density of z), one value per sample; `sd_z` is the
            batch mean of the per-sample standard deviation.
        """
        # `forward` returns z alone, so the pair has to come from the helper;
        # unpacking `forward`'s result would split z along the batch axis.
        z, objective = self._forward_with_logdet(x, **kwargs)
        # base measure
        logp, _ = self.prior("prior", x)

        log_z = logp(z)
        objective += log_z

        if 'writer' in kwargs:
            kwargs['writer'].add_scalar('model/log_z', torch.mean(log_z), kwargs['step'])
            kwargs['writer'].add_scalar('model/z', torch.mean(z), kwargs['step'])
        nobj = - objective
        # NOTE(review): the statistic is named after z but is computed from x
        # -- confirm which tensor is intended.
        var_z = torch.var(x, dim=[1, 2, 3])
        sd_z = torch.mean(torch.sqrt(var_z))

        return nobj, sd_z

    def loss(self, x, **kwargs):
        """Return the batch-mean of `_loss` divided by the per-sample size.

        Returns:
            `(nll_dim, sd_z)`: the mean of `_loss`'s first output divided by
            the number of elements in one sample (`prod(x.shape[1:])`), and
            `sd_z` passed through unchanged.
        """
        nll, sd_z = self._loss(x=x, **kwargs)
        nll_dim = torch.mean(nll) / np.prod(x.shape[1:])

        return nll_dim, sd_z

    def inverse(self, z, **kwargs):
        """Apply each layer's `_inverse` in reverse order and return the result."""
        x = z
        for bijector in reversed(self.model):
            x = bijector._inverse(x, **kwargs)
        return x

    def sample(self, eps_std=None, **kwargs):
        """Draw z from the prior and push it back through the inverse layers.

        Args:
            eps_std: optional per-sample scale applied to the prior's noise
                (see `prior`); `None` uses the prior's default sample.
            **kwargs: must contain `'clean'`, whose shape defines the prior;
                forwarded to `inverse`. An optional `'writer'` / `'step'`
                pair logs a histogram and the std of the batch-averaged
                sample.

        Returns:
            The generated tensor `x`.
        """
        _, sample = self.prior("prior", kwargs['clean'])
        z = sample(eps_std)
        x = self.inverse(z, **kwargs)
        if 'writer' in kwargs:
            # Only needed for logging, so skip the reduction otherwise.
            batch_average = torch.mean(x, dim=0)
            kwargs['writer'].add_histogram('sample_noise', batch_average, kwargs['step'])
            kwargs['writer'].add_scalar('sample_noise_std', torch.std(batch_average), kwargs['step'])

        return x

    def prior(self, name, x):
        """Build the base distribution matching `x` and return two closures.

        Args:
            name: unused.
            x: 4-D tensor; only its shape and device are read.

        Returns:
            `(logp, sample)`: `logp(z)` evaluates `pz.logp(z)`;
            `sample(eps_std=None)` returns `pz.sample`, or
            `pz.sample2(pz.eps * eps_std)` with `eps_std` reshaped to
            `(-1, 1, 1, 1)` when it is given.
        """
        n_z = x.shape[1]
        # All-zero parameters with twice the channels of x: the first half
        # and the second half are the two arguments of gaussian_diag
        # (presumably mean and log-std -- confirm in Noise2Model.utils).
        h = torch.zeros([x.shape[0]] + [2 * n_z] + list(x.shape[2:4]), device=x.device)
        pz = gaussian_diag(h[:, :n_z, :, :], h[:, n_z:, :, :])

        def logp(z1):
            return pz.logp(z1)

        def sample(eps_std=None):
            if eps_std is not None:
                return pz.sample2(pz.eps * torch.reshape(eps_std, [-1, 1, 1, 1]))
            return pz.sample

        return logp, sample

In [7]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

In [10]:
# Smoke test: a single-layer ('unc') NoiseFlow must return a tensor with the
# same shape as its input.
# Fall back to the CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch_randn(16, 1, 64, 64).to(device)
xdim = len(x.shape) - 2

tst = NoiseFlow(x.shape, arch='unc', device=device).to(device)
mods = list(tst.children())
print(mods)
# NOTE: despite the variable name, the call returns the transformed tensor,
# which is why its shape is compared with x.shape.
logp = tst(x, clean=x)
print(logp.shape)
test_eq(logp.shape, x.shape)

|-AffineCoupling
cuda
[ModuleList(
  (0): Unconditional(
    (glow): GlowBlock(
      (flows): ModuleList(
        (0): AffineCouplingBlock(
          (flows): ModuleList(
            (0): Split()
            (1): AffineCoupling(
              (param_map): ConvNet2d(
                (net): Sequential(
                  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                  (1): LeakyReLU(negative_slope=0.0)
                  (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
                  (3): LeakyReLU(negative_slope=0.0)
                  (4): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                )
              )
            )
            (2): Merge()
          )
        )
        (1): ActNorm()
      )
    )
  )
)]
torch.Size([16, 1, 64, 64])


In [11]:
# Collapse x to (channels, everything else); for the (16, 1, 64, 64) batch
# created above this shows torch.Size([1, 65536]).
x.view(x.shape[1], -1).shape

torch.Size([1, 65536])

In [12]:
#| hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()