# NoiseFlow

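> Noise Flow model: a stack of invertible layers assembled from an architecture string, with log-likelihood loss and sampling.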


In [None]:
#| default_exp noiseflow

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj):
    """Replacement for `DisplayHandle.update`: clear the output, then show `obj`.

    `clear_output(wait=True)` is called first and `obj` is then rendered with
    the handle's own `display` method.
    """
    clear_output(wait=True)
    self.display(obj)
# Monkey-patch: from here on, every `DisplayHandle.update(obj)` call in this
# notebook session is routed through `update_patch`.
DisplayHandle.update = update_patch

In [None]:
#| export

from fastai.vision.all import nn, torch, np
from Noise2Model.utils import attributesFromDict
from Noise2Model.models import DnCNN, UNet
from Noise2Model.utils import gaussian_diag, batch_PSNR, weights_init_orthogonal #, weights_init_kaiming


In [None]:
#| export

# from Noise2Model.layers.conv2d1x1 import Conv2d1x1
# from Noise2Model.layers.affine_coupling import AffineCoupling, ShiftAndLogScale
# from Noise2Model.layers.signal_dependant import SignalDependant
# from Noise2Model.layers.gain import Gain
# from Noise2Model.layers.utils import SdnModelScale

### Noise Flow


In [None]:
#| export
class NoiseFlow(nn.Module):
    """Flow model built as a chain of bijector layers described by a string.

    The layer chain is assembled once, in the constructor, from `arch`
    (tokens separated by ``'|'``) and stored in `self.model`.

    Parameters
    ----------
    x_shape : sequence
        Shape handed to the layers; ``x_shape[0]`` is used as the channel
        count of the 1x1 convolution.
    arch : str
        ``'|'``-separated layer tokens, e.g. ``'unc|sdn|unc|gain|unc'``.
        Recognised tokens are ``'unc'``, ``'sdn'`` and ``'gain'``; any other
        token is skipped without an error.
    flow_permutation : int
        ``1`` inserts a `Conv2d1x1` in front of every ``'unc'`` coupling,
        ``0`` inserts nothing, any other value prints a notice and inserts
        nothing.
    param_inits
        Forwarded unchanged to every `SignalDependant` layer.
    lu_decomp
        Forwarded as ``LU_decomposed`` to every `Conv2d1x1` layer.

    NOTE(review): `Conv2d1x1`, `AffineCoupling`, `ShiftAndLogScale`,
    `SignalDependant`, `SdnModelScale` and `Gain` are not imported in the
    visible part of this file (their import lines are commented out) --
    confirm they are in scope before building a model.
    """

    def __init__(self, x_shape, arch, flow_permutation, param_inits, lu_decomp):
        # Keep the explicit two-argument super(): the zero-argument form adds
        # a `__class__` entry to locals(), which would then be handed to
        # attributesFromDict as well.
        super(NoiseFlow, self).__init__()
        # The methods below read self.arch, self.flow_permutation,
        # self.param_inits and self.lu_decomp, i.e. this helper is relied on
        # to store every constructor argument on self under its own name.
        attributesFromDict(locals())
        self.model = nn.ModuleList(self.noise_flow_arch(x_shape))

    def noise_flow_arch(self, x_shape):
        """Translate `self.arch` into a list of layer instances.

        Returns the list of bijectors in the order given by `self.arch`.
        Prints one line per layer that is created.
        """
        bijectors = []
        # e.g. unc|sdn|unc|gain|unc
        for i, lyr in enumerate(self.arch.split('|')):
            if lyr == 'unc':
                if self.flow_permutation == 1:
                    print('|-Conv2d1x1')
                    bijectors.append(
                        Conv2d1x1(
                            num_channels=x_shape[0],
                            # Fixed: this used to read `self.decomp`, which is
                            # never set (the constructor argument is
                            # `lu_decomp`) and raised AttributeError.
                            LU_decomposed=self.lu_decomp,
                            name='Conv2d_1x1_{}'.format(i)
                        )
                    )
                elif self.flow_permutation != 0:
                    print('|-No permutation specified. Not using any.')

                print('|-AffineCoupling')
                bijectors.append(
                    AffineCoupling(
                        x_shape=x_shape,
                        shift_and_log_scale=ShiftAndLogScale,
                        name='unc_%d' % i
                    )
                )
            elif lyr == 'sdn':
                print('|-SignalDependant')
                bijectors.append(
                    SignalDependant(
                        name='sdn_%d' % i,
                        scale=SdnModelScale,
                        param_inits=self.param_inits
                    )
                )
            elif lyr == 'gain':
                print('|-Gain')
                bijectors.append(
                    Gain(name='gain_%d' % i)
                )
            # Any other token (including 'lt', which has no implementation
            # here) adds no layer.

        return bijectors

    def forward(self, x, **kwargs):
        """Push `x` through every layer and accumulate the per-layer terms.

        Each layer's `_forward_and_log_det_jacobian(z, **kwargs)` must return
        the transformed tensor and a per-sample term; the terms are summed
        into `objective`, which has one entry per batch element.

        If ``kwargs`` contains ``'writer'``, the batch mean of every layer's
        term is logged with ``add_scalar`` at ``kwargs['step']``.

        Returns
        -------
        (z, objective)
        """
        z = x
        objective = torch.zeros(x.shape[0], dtype=torch.float32, device=x.device)
        for bijector in self.model:
            z, log_abs_det_J_inv = bijector._forward_and_log_det_jacobian(z, **kwargs)
            objective += log_abs_det_J_inv

            if 'writer' in kwargs:
                kwargs['writer'].add_scalar(
                    'model/' + bijector.name, torch.mean(log_abs_det_J_inv), kwargs['step']
                )
        return z, objective

    def _loss(self, x, **kwargs):
        """Return the negated objective per sample and a spread statistic.

        The objective is the value accumulated by `forward` plus the prior's
        ``logp(z)``. The second return value is the batch mean of the
        per-sample standard deviation taken over dims 1, 2 and 3.

        NOTE(review): the spread statistic is computed from the input `x`,
        although the original comment labelled it "std. dev. of z" -- confirm
        which tensor is intended.
        """
        z, objective = self.forward(x, **kwargs)
        # Base measure; only the shape and device of `x` are used by `prior`.
        logp, _ = self.prior("prior", x)

        log_z = logp(z)
        objective += log_z

        if 'writer' in kwargs:
            kwargs['writer'].add_scalar('model/log_z', torch.mean(log_z), kwargs['step'])
            kwargs['writer'].add_scalar('model/z', torch.mean(z), kwargs['step'])
        nobj = - objective

        var_z = torch.var(x, dim=[1, 2, 3])
        sd_z = torch.mean(torch.sqrt(var_z))

        return nobj, sd_z

    def loss(self, x, **kwargs):
        """Return the batch-mean negated objective per element, and `sd_z`.

        The batch mean of `_loss`'s first output is divided by the number of
        elements in one sample (the product of ``x.shape[1:]``).
        """
        nll, sd_z = self._loss(x=x, **kwargs)
        nll_dim = torch.mean(nll) / np.prod(x.shape[1:])

        return nll_dim, sd_z

    def inverse(self, z, **kwargs):
        """Map `z` back by calling each layer's `_inverse` in reverse order."""
        x = z
        for bijector in reversed(self.model):
            x = bijector._inverse(x, **kwargs)
        return x

    def sample(self, eps_std=None, **kwargs):
        """Draw from the prior and map the draw back through `inverse`.

        ``kwargs['clean']`` is required: its shape and device define the
        prior. All keyword arguments are also forwarded to `inverse`.
        If ``kwargs`` contains ``'writer'``, a histogram and the standard
        deviation of the batch-averaged result are logged at
        ``kwargs['step']``.
        """
        _, sample = self.prior("prior", kwargs['clean'])
        z = sample(eps_std)
        x = self.inverse(z, **kwargs)
        if 'writer' in kwargs:
            # Only needed for logging, so only computed when a writer exists.
            batch_average = torch.mean(x, dim=0)
            kwargs['writer'].add_histogram('sample_noise', batch_average, kwargs['step'])
            kwargs['writer'].add_scalar('sample_noise_std', torch.std(batch_average), kwargs['step'])

        return x

    def prior(self, name, x):
        """Build the base distribution for tensors shaped like `x`.

        A zero tensor with twice the channel count of `x` is created on
        `x`'s device; its first and second channel halves are the two
        arguments given to `gaussian_diag` (presumably mean and log-scale --
        confirm against `gaussian_diag`). `name` is accepted but not used.

        Returns
        -------
        (logp, sample)
            ``logp(z1)`` evaluates ``pz.logp(z1)``; ``sample(eps_std=None)``
            returns ``pz.sample``, or, when `eps_std` is given,
            ``pz.sample2`` applied to ``pz.eps`` scaled by `eps_std`
            reshaped to ``[-1, 1, 1, 1]``.
        """
        n_z = x.shape[1]
        h = torch.zeros([x.shape[0]] + [2 * n_z] + list(x.shape[2:4]), device=x.device)
        pz = gaussian_diag(h[:, :n_z, :, :], h[:, n_z:, :, :])

        def logp(z1):
            return pz.logp(z1)

        def sample(eps_std=None):
            if eps_std is None:
                return pz.sample
            return pz.sample2(pz.eps * torch.reshape(eps_std, [-1, 1, 1, 1]))

        return logp, sample

In [None]:
# from torch import randn as torch_randn
# from fastai.vision.all import test_eq

In [None]:
# def init_params():
#     npcam = 3
#     c_i = 1.0
#     beta1_i = -5.0 / c_i
#     beta2_i = 0.0
#     gain_params_i = np.ndarray([5])
#     gain_params_i[:] = -5.0 / c_i
#     cam_params_i = np.ndarray([npcam, 5])
#     cam_params_i[:, :] = 1.0
#     return (c_i, beta1_i, beta2_i, gain_params_i, cam_params_i)

# x = torch_randn(16,1,64,64)
# xdim = len(x.shape)-2

# tst = NoiseFlow(x.shape[1:], arch='gain', flow_permutation=0, param_inits=init_params(), lu_decomp=0)
# mods = list(tst.children())
# print(mods)
# # test_eq(tst(x.cuda()).shape, [16, 1, 32, 64, 64])
# logp, sample = tst(x.cuda())
# print(logp.shape)
# print(sample.shape)

|-Gain
[ModuleList(
  (0): Gain()
)]
torch.Size([16, 1, 64, 64])
torch.Size([16])
