# Noise2NoiseFlow

> Noise2NoiseFlow: a denoiser combined with a NoiseFlow model of the residual noise


In [18]:
#| default_exp noise2noiseflow

In [19]:
#| hide
from nbdev.showdoc import *

In [20]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj, **kwargs):
    """Replace a display handle's output by clearing the cell and re-displaying `obj`.

    Installed below as `DisplayHandle.update`, so `handle.update(obj)` first
    calls `clear_output(wait=True)` and then `handle.display(obj)`. Extra
    keyword arguments are forwarded to `display`, keeping the same call
    signature as the `DisplayHandle.update(obj, **kwargs)` method it replaces.
    """
    clear_output(wait=True)
    self.display(obj, **kwargs)
DisplayHandle.update = update_patch

In [21]:
#| export

from fastai.vision.all import nn, torch, np
from Noise2Model.utils import attributesFromDict
from Noise2Model.models import DnCNN, UNet
from Noise2Model.utils import gaussian_diag, batch_PSNR, weights_init_orthogonal #, weights_init_kaiming
from Noise2Model.noiseflow import NoiseFlow  


### Noise2NoiseFlow


In [22]:
# denoiser = UNet(in_channels=4, out_channels=4)
# denoiser = DnCNN(x_shape[0], dncnn_num_layers=9)

In [23]:
#| export

class Noise2NoiseFlow(nn.Module):
    """Denoiser paired with a NoiseFlow model of the residual noise.

    `denoise` produces a clean estimate from a noisy image. `self.noise_flow`
    (a `NoiseFlow`) models the residual ``noisy - denoised`` and receives the
    estimate through its ``clean`` keyword argument. `loss_u` combines the
    flow NLL with a denoiser MSE term computed from a pair of noisy images
    (presumably two noisy observations of the same scene; the pairing is up
    to the caller). `loss_s` returns the flow loss of a given noise tensor
    directly.

    Args:
        x_shape: shape forwarded to `NoiseFlow` (assumed to exclude the batch
            dimension, e.g. ``(C, H, W)``; TODO confirm against NoiseFlow).
        arch: architecture spec forwarded to `NoiseFlow` (e.g. ``'gain'``).
        denoiser: denoising module. ``None`` (default) builds a fresh
            ``UNet(depth=3, in_channels=1, out_channels=1)`` for this instance.
        lmbda: weight of the denoiser MSE term in `loss_u` (default 2**18).
    """

    def __init__(self,
                 x_shape,
                 arch,
                 denoiser=None,
                 lmbda=262144):
        # Explicit two-argument super(): a zero-argument super() would add an
        # implicit `__class__` free variable, which locals() below would include.
        super(Noise2NoiseFlow, self).__init__()
        # Build the default denoiser per instance. A module used directly as the
        # default argument is created once at class-definition time and then
        # shared (weights included) by every instance that relies on it.
        if denoiser is None:
            denoiser = UNet(depth=3, in_channels=1, out_channels=1)
        # Presumably copies the constructor arguments onto `self`; note that
        # `self.denoiser` is only ever assigned through this call. It reads
        # locals() wholesale, so no other locals may be defined before it.
        attributesFromDict(locals())

        self.noise_flow = NoiseFlow(x_shape, arch)

        # DnCNN denoisers (matched by class name) get orthogonal weight init.
        if denoiser._get_name() == 'DnCNN':
            self.denoiser.apply(weights_init_orthogonal)

        self.denoiser_loss = nn.MSELoss(reduction='mean')
        self.lmbda = lmbda

    def denoise(self, noisy, clip=True):
        """Run the denoiser on `noisy`; clamp the output to [0, 1] when `clip`."""
        denoised = self.denoiser(noisy)
        if clip:
            denoised = torch.clamp(denoised, 0., 1.)
        return denoised

    def forward(self, noisy, **kwargs):
        """Denoise `noisy` and pass the residual noise through the flow.

        The clipped estimate is given to the flow as ``clean`` (overriding any
        ``clean`` in `kwargs`); the flow's output is returned unchanged.
        """
        denoised = self.denoise(noisy)
        noise = noisy - denoised
        # Call the module rather than `.forward` so registered hooks run.
        return self.noise_flow(noise, **{**kwargs, 'clean': denoised})

    def _symmetric_nll(self, noisy1, noisy2, denoised1, denoised2, kwargs):
        """Average flow NLL of the cross residuals of a noisy pair.

        Each noisy image's residual is taken against the *other* image's
        estimate (``noisy1 - denoised2`` and ``noisy2 - denoised1``), and that
        estimate is passed to the flow as ``clean``. `kwargs` is a plain dict
        of extra flow arguments, passed positionally so its keys cannot clash
        with this helper's own parameter names.
        """
        nll1, _ = self.noise_flow.loss(noisy1 - denoised2, **{**kwargs, 'clean': denoised2})
        nll2, _ = self.noise_flow.loss(noisy2 - denoised1, **{**kwargs, 'clean': denoised1})
        return (nll1 + nll2) / 2

    def symmetric_loss(self, noisy1, noisy2, **kwargs):
        """Symmetric cross-residual flow NLL using clipped denoiser estimates."""
        denoised1 = self.denoise(noisy1)
        denoised2 = self.denoise(noisy2)
        return self._symmetric_nll(noisy1, noisy2, denoised1, denoised2, kwargs)

    def symmetric_loss_with_mse(self, noisy1, noisy2, **kwargs):
        """Return ``(nll, mse)`` for a pair of noisy images.

        `mse` averages the MSE of each *unclipped* estimate against the other
        noisy image. `nll` is the symmetric cross-residual flow NLL computed
        from the estimates after clamping them to [0, 1].
        """
        denoised1 = self.denoise(noisy1, clip=False)
        denoised2 = self.denoise(noisy2, clip=False)

        mse_loss = (self.denoiser_loss(denoised1, noisy2)
                    + self.denoiser_loss(denoised2, noisy1)) / 2

        nll = self._symmetric_nll(noisy1, noisy2,
                                  torch.clamp(denoised1, 0., 1.),
                                  torch.clamp(denoised2, 0., 1.),
                                  kwargs)
        return nll, mse_loss

    def _loss_u(self, noisy1, noisy2, **kwargs):
        """One-sided variant of `symmetric_loss_with_mse`; returns ``(nll, mse)``.

        `mse` compares the unclipped estimate of `noisy1` with `noisy2`. `nll`
        is the flow loss of ``noisy1 - clamp(denoised1)`` with the clamped
        estimate passed as ``clean``.
        """
        denoised1 = self.denoise(noisy1, clip=False)
        mse_loss = self.denoiser_loss(denoised1, noisy2)

        denoised1 = torch.clamp(denoised1, 0., 1.)
        nll, _ = self.noise_flow.loss(noisy1 - denoised1, **{**kwargs, 'clean': denoised1})
        return nll, mse_loss

    def loss_u(self, noisy1, noisy2, **kwargs):
        """Training loss from a noisy pair.

        Returns ``(nll + lmbda * mse, nll, mse)``. The first element is a
        tensor; the other two are Python floats obtained via `.item()`.
        """
        nll, mse = self.symmetric_loss_with_mse(noisy1, noisy2, **kwargs)
        return nll + self.lmbda * mse, nll.item(), mse.item()

    def forward_s(self, noise, **kwargs):
        """Pass `noise` straight through the flow (no denoiser involved)."""
        return self.noise_flow(noise, **kwargs)

    def _loss_s(self, x, **kwargs):
        """Delegate to `NoiseFlow._loss` on `x`."""
        return self.noise_flow._loss(x, **kwargs)

    def loss_s(self, x, **kwargs):
        """Delegate to `NoiseFlow.loss` on `x`."""
        return self.noise_flow.loss(x, **kwargs)

    def mse_loss(self, noisy, clean, **kwargs):
        """Evaluate the denoiser against `clean` targets.

        Returns ``(mse, psnr)``: `mse` is a Python float, and `psnr` is whatever
        `batch_PSNR(denoised, clean, 1.)` returns for the unclipped output.
        `kwargs` is accepted for interface compatibility and ignored.
        """
        # Evaluation only: no gradients are needed, so skip building a graph.
        # NOTE(review): assumes no caller backpropagates through `psnr`.
        with torch.no_grad():
            denoised = self.denoise(noisy, clip=False)
            mse = self.denoiser_loss(denoised, clean)
            psnr = batch_PSNR(denoised, clean, 1.)
        return mse.item(), psnr

    def sample(self, eps_std=None, **kwargs):
        """Delegate to `NoiseFlow.sample`."""
        return self.noise_flow.sample(eps_std, **kwargs)


In [24]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

# Smoke test: build the model around a small UNet denoiser, list its top-level
# submodules, and check that the flow output keeps the input batch's shape.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

x = torch_randn(16, 1, 64, 64).to(device)
xdim = x.dim() - 2

tst = Noise2NoiseFlow(x.shape[1:], arch='gain', denoiser=UNet(2)).to(device)
mods = [*tst.children()]
print(mods)

z = tst.forward(x)
test_eq(z.shape, x.shape)

|-Gain
[UNet(
  (net_recurse): _Net_recurse(
    (sub_conv_more): Sequential(
      (0): ConvLayer(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): ConvLayer(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (sub_u): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): _Net_recurse(
        (sub_conv_more): Sequential(
          (0): ConvLayer(
            (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (1): ConvLay

In [25]:
#| hide
import nbdev
nbdev.nbdev_export()