# Noise2NoiseFlow

> Noise2NoiseFlow: a denoiser trained jointly with a NoiseFlow noise model from pairs of noisy images.


In [None]:
#| default_exp noise2noiseflow

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj, **kwargs):
    """Replace a display handle's output with ``obj`` by clearing and redisplaying.

    Installed as ``DisplayHandle.update`` right below. ``clear_output(wait=True)``
    postpones the clear until the new output is ready, which avoids flicker.
    Keyword arguments are forwarded to ``DisplayHandle.display`` so callers that
    use the stock ``update(obj, **kwargs)`` signature keep working.

    NOTE(review): presumably a workaround for frontends that do not refresh
    ``display_id`` outputs in place — TODO confirm.
    """
    clear_output(wait=True)
    self.display(obj, **kwargs)


DisplayHandle.update = update_patch

In [None]:
#| export

from fastai.vision.all import nn, torch, np
from Noise2Model.utils import attributesFromDict
from Noise2Model.models import DnCNN, UNet
from Noise2Model.utils import gaussian_diag, batch_PSNR, weights_init_orthogonal #, weights_init_kaiming


In [None]:
#| export

# from Noise2Model.layers.conv2d1x1 import Conv2d1x1
# from Noise2Model.layers.affine_coupling import AffineCoupling, ShiftAndLogScale
# from Noise2Model.layers.signal_dependant import SignalDependant
# from Noise2Model.layers.gain import Gain
# from Noise2Model.layers.utils import SdnModelScale

### Noise2NoiseFlow


In [None]:
#| export

class Noise2NoiseFlow(nn.Module):
    """Denoiser trained jointly with a NoiseFlow noise model.

    The denoiser estimates a clean image from a noisy one. The residual
    ``noisy - denoised`` is scored by ``self.noise_flow``, which also receives
    the denoised estimate through the ``clean`` keyword argument. For a pair of
    noisy observations, each image's residual is taken against the *other*
    image's denoised estimate (see ``symmetric_loss``).

    NOTE(review): ``NoiseFlow`` is not imported by the exported cells visible in
    this notebook (the layer imports are commented out) — confirm it is in scope
    before instantiating.
    """

    # Values accepted by the ``denoiser_model`` constructor argument.
    _DENOISER_MODELS = ('dncnn', 'unet')

    def __init__(self, x_shape, arch, flow_permutation, param_inits, lu_decomp, denoiser_model='unet', dncnn_num_layers=9, lmbda=262144):
        """Build the noise flow and the denoiser.

        Args:
            x_shape: per-sample input shape; ``x_shape[0]`` is used as the
                denoiser's channel count (assumes channel-first (C, H, W) —
                TODO confirm against NoiseFlow).
            arch, flow_permutation, param_inits, lu_decomp: forwarded
                unchanged to ``NoiseFlow``.
            denoiser_model: ``'dncnn'`` or ``'unet'``.
            dncnn_num_layers: depth of the DnCNN denoiser (unused for UNet).
            lmbda: weight of the denoiser MSE term in ``loss_u``.

        Raises:
            ValueError: if ``denoiser_model`` is not a supported architecture.
        """
        super().__init__()
        # Validate up front: previously an unknown name left ``self.denoiser``
        # unset and only failed later with an AttributeError.
        if denoiser_model not in self._DENOISER_MODELS:
            raise ValueError(
                f"denoiser_model must be one of {self._DENOISER_MODELS}, got {denoiser_model!r}"
            )

        self.noise_flow = NoiseFlow(x_shape, arch, flow_permutation, param_inits, lu_decomp)

        # Both architectures are stored under the same attribute name, so their
        # state-dict keys share the ``denoiser.`` prefix.
        if denoiser_model == 'dncnn':
            self.denoiser = DnCNN(x_shape[0], dncnn_num_layers)
            self.denoiser.apply(weights_init_orthogonal)
        else:  # 'unet'
            # Follow the input channel count, as the DnCNN branch does; the
            # former hard-coded 4 broke any input without exactly 4 channels.
            channels = x_shape[0]
            self.denoiser = UNet(in_channels=channels, out_channels=channels)

        self.denoiser_loss = nn.MSELoss(reduction='mean')
        self.lmbda = lmbda

    def denoise(self, noisy, clip=True):
        """Run the denoiser on ``noisy``; clamp the result to [0, 1] when ``clip``."""
        denoised = self.denoiser(noisy)
        if clip:
            denoised = torch.clamp(denoised, 0., 1.)
        return denoised

    def forward_u(self, noisy, **kwargs):
        """Denoise ``noisy`` (clipped) and push its residual through the noise flow.

        Returns:
            ``(z, objective, denoised)``, where ``z, objective`` is the pair
            returned by ``self.noise_flow.forward`` for the residual.
        """
        denoised = self.denoise(noisy)
        kwargs['clean'] = denoised
        noise = noisy - denoised
        z, objective = self.noise_flow.forward(noise, **kwargs)
        return z, objective, denoised

    def _cross_nll(self, noisy1, noisy2, denoised1, denoised2, **kwargs):
        """Mean noise-flow NLL with each residual taken against the other image's estimate."""
        # ``clean`` always overrides any caller-supplied value, as before.
        nll1, _ = self.noise_flow.loss(noisy1 - denoised2, **dict(kwargs, clean=denoised2))
        nll2, _ = self.noise_flow.loss(noisy2 - denoised1, **dict(kwargs, clean=denoised1))
        return (nll1 + nll2) / 2

    def symmetric_loss(self, noisy1, noisy2, **kwargs):
        """Symmetric noise-flow NLL for two noisy observations, using clipped estimates."""
        denoised1 = self.denoise(noisy1)
        denoised2 = self.denoise(noisy2)
        return self._cross_nll(noisy1, noisy2, denoised1, denoised2, **kwargs)

    def symmetric_loss_with_mse(self, noisy1, noisy2, **kwargs):
        """Symmetric NLL plus the symmetric denoiser MSE for two noisy observations.

        Returns:
            ``(nll, mse)``. The MSE uses the *unclipped* denoiser outputs, each
            regressed onto the other noisy image; the NLL uses clipped outputs.
        """
        denoised1 = self.denoise(noisy1, clip=False)
        denoised2 = self.denoise(noisy2, clip=False)

        mse = (self.denoiser_loss(denoised1, noisy2) + self.denoiser_loss(denoised2, noisy1)) / 2

        nll = self._cross_nll(
            noisy1,
            noisy2,
            torch.clamp(denoised1, 0., 1.),
            torch.clamp(denoised2, 0., 1.),
            **kwargs,
        )
        return nll, mse

    def _loss_u(self, noisy1, noisy2, **kwargs):
        """Non-symmetric variant: NLL of ``noisy1``'s own residual plus MSE of its estimate vs ``noisy2``."""
        denoised1 = self.denoise(noisy1, clip=False)
        mse = self.denoiser_loss(denoised1, noisy2)

        denoised1 = torch.clamp(denoised1, 0., 1.)
        kwargs['clean'] = denoised1
        nll, _ = self.noise_flow.loss(noisy1 - denoised1, **kwargs)
        return nll, mse

    def loss_u(self, noisy1, noisy2, **kwargs):
        """Unsupervised training loss on a pair of noisy images.

        Returns:
            ``(total, nll_value, mse_value)`` where ``total = nll + lmbda * mse``
            is the differentiable loss and the other two are Python floats.
        """
        nll, mse = self.symmetric_loss_with_mse(noisy1, noisy2, **kwargs)
        return nll + self.lmbda * mse, nll.item(), mse.item()

    def forward_s(self, noise, **kwargs):
        """Run the noise flow directly on a given ``noise`` tensor."""
        return self.noise_flow.forward(noise, **kwargs)

    def _loss_s(self, x, **kwargs):
        """Delegate to ``self.noise_flow._loss``."""
        return self.noise_flow._loss(x, **kwargs)

    def loss_s(self, x, **kwargs):
        """Delegate to ``self.noise_flow.loss``."""
        return self.noise_flow.loss(x, **kwargs)

    def mse_loss(self, noisy, clean, **kwargs):
        """Evaluate the denoiser against a clean reference.

        Returns:
            ``(mse, psnr)``: the MSE of the unclipped denoiser output as a
            Python float, and ``batch_PSNR(denoised, clean, 1.)``. Extra keyword
            arguments are accepted and ignored.
        """
        denoised = self.denoise(noisy, clip=False)
        mse = self.denoiser_loss(denoised, clean)
        psnr = batch_PSNR(denoised, clean, 1.)
        return mse.item(), psnr

    def sample(self, eps_std=None, **kwargs):
        """Delegate sampling to ``self.noise_flow.sample``."""
        return self.noise_flow.sample(eps_std, **kwargs)


In [None]:
def init_params(npcam=3, c_i=1.0):
    """Return initial noise-model parameters, used as ``param_inits`` below.

    Args:
        npcam: number of rows in ``cam_params_i`` (presumably the number of
            cameras — TODO confirm).
        c_i: divisor applied to the -5.0 initial values.

    Returns:
        ``(c_i, beta1_i, beta2_i, gain_params_i, cam_params_i)`` with
        ``beta1_i = -5 / c_i``, ``beta2_i = 0``, ``gain_params_i`` a length-5
        float64 array filled with ``-5 / c_i``, and ``cam_params_i`` an
        ``(npcam, 5)`` float64 array of ones.
    """
    beta1_i = -5.0 / c_i
    beta2_i = 0.0
    # Allocate filled arrays directly instead of filling uninitialised np.ndarray buffers.
    gain_params_i = np.full(5, -5.0 / c_i)
    cam_params_i = np.ones((npcam, 5))
    return (c_i, beta1_i, beta2_i, gain_params_i, cam_params_i)

# Smoke test: build the model and its inputs on the same device, then run both
# forward paths and print the output shapes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Packed 4-channel input, compatible with the default 'unet' denoiser.
x = torch.randn(16, 4, 64, 64, device=device)

tst = Noise2NoiseFlow(x.shape[1:], arch='gain', flow_permutation=0,
                      param_inits=init_params(), lu_decomp=0).to(device)
mods = list(tst.children())
print(mods)

# forward_s returns noise_flow.forward's (z, objective) pair, with x treated as noise.
z, objective = tst.forward_s(x)
print(z.shape)
print(objective.shape)

z, objective, denoised = tst.forward_u(x)
print(z.shape)
print(objective.shape)
print(denoised.shape)

In [None]:
#| hide
# Regenerate the library modules from the notebook's `#| export` cells.
import nbdev

nbdev.nbdev_export()