# Models

> Model registry, the NMFlow / NMFlowGAN noise models, and the denoiser wrappers trained on their samples.


In [1]:
#| default_exp models

In [2]:
#| hide
from nbdev.showdoc import *
!gpustat

[1m[37mmerlin                    [m  Fri Jun 14 20:51:39 2024  [1m[30m525.147.05[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 48°C[m, [32m  0 %[m | [36m[1m[33m    6[m / [33m24564[m MB |
[36m[1][m [34mNVIDIA GeForce RTX 4090[m |[31m 48°C[m, [32m  0 %[m | [36m[1m[33m  229[m / [33m24564[m MB |


In [3]:
#| export

import os

from typing import Iterator
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from Noise2Model.layers import get_flow_layer
from Noise2Model.networks import get_network_class, UNet, DnCNN
from Noise2Model.utils import StandardNormal, attributesFromDict


In [4]:
# from torch import randn as torch_randn
# from fastai.vision.all import test_eq

# Sanity check: report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

True


In [5]:
#| export

# Global registry mapping a lowercased class name to the model class itself.
model_class_dict = {}

def regist_model(model_class):
    """Class decorator that records `model_class` in `model_class_dict`.

    The registry key is the lowercased class name. Registering a second class
    under an existing key trips an assertion. The class is returned unchanged
    so the function can be used as a decorator.
    """
    key = model_class.__name__.lower()
    already_registered = key in model_class_dict
    assert not already_registered, 'there is already registered model: %s in model_class_dict.' % key
    model_class_dict[key] = model_class
    return model_class

def get_model_class(model_name:str):
    """Look up a registered model class by name, ignoring case.

    Raises `KeyError` when nothing is registered under the lowercased name.
    """
    return model_class_dict[model_name.lower()]

## NMFlow

### Noise Modeler

In [6]:
#| export

@regist_model
class NMFlow(nn.Module):
    def __init__(
        self,
        in_ch=1,
        ch_exp_coef = 1.,
        width_exp_coef = 2.,
        num_bits=16,
        conv_net_feats=16,
        pre_arch="UD",
        arch="NE|SAL|SDL|CL2|SAL|SDL|CL2",
        device='cpu',
        codes=None
    ):
        super(NMFlow, self).__init__()
        attributesFromDict(locals()) # stores all the input parameters in self
        
        if codes==None:
            self.codes= {
                'camera': torch.tensor([2], dtype=torch.float32).to(device)
            }
        
        self.pre_bijectors = list()
        pre_arch_lyrs = pre_arch.split('|')
        for lyr in pre_arch_lyrs:
            self.pre_bijectors.append(self.get_flow_layer(lyr))
        self.pre_bijectors = nn.Sequential(*self.pre_bijectors)

        self.bijectors = list()
        arch_lyrs = arch.split('|')
        for lyr in arch_lyrs:
            self.bijectors.append(self.get_flow_layer(lyr))
        self.bijectors = nn.Sequential(*self.bijectors)
        self.dist = StandardNormal()

    def internal_channels(self):
        return int(self.in_ch * self.ch_exp_coef)
    
    def internal_widths(self):
        return int(self.in_ch * self.width_exp_coef)

    def get_flow_layer(self, name):
        match name:
            case "UD":
                return get_flow_layer("UniformDequantization")(device=self.device, num_bits=self.num_bits)
            
            case "NE":
                return get_flow_layer("NoiseExtraction")(device=self.device)    
            
            case "CL2":
                return get_flow_layer("ConditionalLinearExp2")(
                    in_ch=self.internal_channels(),
                    device=self.device,
                    codes=self.codes,
                )
                
            case "SDL":
                return get_flow_layer("SignalDependentConditionalLinear")(
                    meta_encoder=lambda in_features, out_features: get_network_class("ResidualNet")(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=5,
                        num_blocks=3,
                        use_batch_norm=True,
                        dropout_probability=0.0
                    ),
                    scale_and_bias=lambda in_features, out_features: get_flow_layer("PointwiseConvs")(
                        in_features=in_features,
                        out_features=out_features,
                        feats=self.conv_net_feats
                    ),
                    in_ch=self.internal_channels(),
                    device=self.device,
                    codes=self.codes,
                )
                
            case "SAL":
                return get_flow_layer("StructureAwareConditionalLinearLayer")(
                    meta_encoder=lambda in_features, out_features: get_network_class("ResidualNet")(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=5,
                        num_blocks=3,
                        use_batch_norm=True,
                        dropout_probability=0.0
                    ),
                    structure_encoder=lambda in_features, out_features: get_flow_layer("SpatialConvs")(
                        in_features=in_features,
                        out_features=out_features,
                        receptive_field=9,
                        feats=self.conv_net_feats
                    ),
                    in_ch=self.internal_channels(),
                    codes=self.codes,
                    device=self.device
                )
            
            case _: 
                assert False, f"Invalid layer name : {name}"

    def forward(self, noisy, clean, kwargs=dict()):
        x = noisy
        kwargs['clean'] = clean.clone()

        objectives = 0.
        for bijector in self.pre_bijectors:
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'], _ = bijector._forward_and_log_det_jacobian(kwargs['clean'])

            x, ldj = bijector._forward_and_log_det_jacobian(x, **kwargs)
            objectives += ldj

        for bijector in self.bijectors:
            x, ldj = bijector._forward_and_log_det_jacobian(x, **kwargs)
            objectives += ldj
        return x, objectives

    def sample(self, kwargs=dict()):
        for bijector in self.pre_bijectors:
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'], _ = bijector._forward_and_log_det_jacobian(kwargs['clean'], **kwargs)

        b,_,h,w = kwargs['clean'].shape
        x = self.dist.sample((b,self.internal_channels(),h,w))
        for bijector in reversed(self.bijectors):
            x = bijector._inverse(x, **kwargs)

        for bijector in reversed(self.pre_bijectors):
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'] = bijector._inverse(kwargs['clean'], **kwargs)
            x = bijector._inverse(x, **kwargs)
        x = torch.clip(x, 0, 2**self.num_bits)
        return x 


In [7]:
# Smoke test: build a shorter flow and display its module tree.
NMFlow(arch="NE|SAL|SDL|CL2")

NMFlow(
  (pre_bijectors): Sequential(
    (0): UniformDequantization()
  )
  (bijectors): Sequential(
    (0): NoiseExtraction()
    (1): StructureAwareConditionalLinearLayer(
      (meta_encoder): ResidualNet(
        (initial_layer): Linear(in_features=1, out_features=5, bias=True)
        (blocks): ModuleList(
          (0-2): 3 x ResidualBlock(
            (batch_norm_layers): ModuleList(
              (0-1): 2 x BatchNorm1d(5, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            )
            (linear_layers): ModuleList(
              (0-1): 2 x Linear(in_features=5, out_features=5, bias=True)
            )
          )
        )
        (final_layer): Linear(in_features=5, out_features=2, bias=True)
      )
      (structure_encoder): SpatialConvs(
        (body): Sequential(
          (conv_in): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
          (relu_in): ReLU(inplace=True)
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding

### Denoiser

In [8]:
#| export

class NMFlowDenoiser(nn.Module):
    """Denoiser paired with a pre-trained `NMFlow` noise model.

    `forward` takes a clean image, synthesises a noisy version with the noise
    model (no gradients) and returns the denoiser's output for it. Only the
    denoiser's parameters are exposed through `parameters`.

    Args:
        denoiser: module applied to images normalised to `[0, 1]`.
        kwargs_flow: keyword arguments for the `NMFlow` constructor.
        flow_pth_path: checkpoint file; weights are read from
            `['model_weight']['noise_model']`.
        num_bits: bit depth of the images handled by the denoiser.
    """
    def __init__(
            self,
            denoiser,
            kwargs_flow,
            flow_pth_path,
            num_bits=8,
        ):
        super().__init__()
        attributesFromDict(locals()) # stores all the input parameters in self

        self.noise_model = get_model_class("NMFlow")(**kwargs_flow)
        self._load_checkpoint(self.noise_model, flow_pth_path)

    def _load_checkpoint(self, module, path, name='noise_model'):
        """Load `pth['model_weight'][name]` from `path` into `module` and set eval mode."""
        assert os.path.exists(path), f"{path} is not exist."
        # Deserialize onto CPU so a checkpoint saved from a GPU can be read on any
        # host; load_state_dict then copies the values into the module's tensors.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        pth = torch.load(path, map_location='cpu')
        module.load_state_dict(pth['model_weight'][name])
        module.eval()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Return only the denoiser's parameters."""
        return self.denoiser.parameters(recurse) # the parameters of denoiser will be trained only.

    def forward(self, x, kwargs=None):
        """Synthesise noise on clean image `x` and denoise the result.

        Args:
            x: clean image in the denoiser's range (`0 ~ 2**num_bits`).
            kwargs: conditioning dict handed to `noise_model.sample`. A dict
                supplied by the caller is updated in place (`'clean'` is
                written); when omitted a fresh dict is used.

        Returns:
            Denoiser output scaled by `2**num_bits`.
        """
        # A `dict()` default would be shared between calls and keep stale entries.
        if kwargs is None:
            kwargs = {}
        # x: clean image
        x_scaled = x / (2**self.num_bits) # x_scaled: 0 ~ 1
        x_scaled = x_scaled * (2**self.noise_model.num_bits) # x_scaled: 0 ~ noise model's max GL.

        kwargs['clean'] = x_scaled
        with torch.no_grad():
            n = self.noise_model.sample(kwargs) # noisy image

        n_scaled = n / (2**self.noise_model.num_bits) # n_scaled: 0 ~ 1
        n_scaled = torch.clip(n_scaled, 0., 1.)
        y = self.denoiser(n_scaled)
        y = y * (2**self.num_bits) # y: 0 ~ denoiser's max GL.
        return y

    def denoise(self, x, kwargs=None):
        """Denoise noisy image `x`.

        `kwargs['num_bits']`, when present, overrides `self.num_bits` for the
        input/output scaling. The denoiser output is clipped to `[0, 1]` before
        being scaled back.
        """
        # x: noisy image
        if kwargs is None or 'num_bits' not in kwargs: num_bits = self.num_bits
        else: num_bits = kwargs['num_bits']

        x_scaled = x / (2**num_bits) # x_scaled: 0 ~ 1
        y =  self.denoiser(x_scaled)
        y = torch.clip(y, 0., 1.)
        y *= (2**num_bits) # y: 0 ~ denoiser's max GL.
        return y

    def sample(self, x, kwargs=None):
        """Generate a noisy version of clean image `x` with the noise model.

        Only `kwargs['num_bits']` is read from `kwargs`; the noise model is
        called with a new dict that contains just `'clean'`.
        """
        # x: clean image
        if kwargs is None or 'num_bits' not in kwargs: num_bits = self.num_bits
        else: num_bits = kwargs['num_bits']

        x_scaled = x / (2**num_bits) # x_scaled: 0 ~ 1
        x_scaled = x_scaled * (2**self.noise_model.num_bits) # x_scaled: 0 ~ noise model's max GL.

        kwargs = dict()
        kwargs['clean'] = x_scaled
        n = self.noise_model.sample(kwargs) # n: 0 ~ noise model's max GL.

        n_scaled = n / (2**self.noise_model.num_bits) # n_scaled: 0 ~ 1
        n_scaled = n_scaled * (2**num_bits) # n_scaled: 0 ~ denoiser's max GL.
        return n_scaled

## NMFlowGAN

### Generator

In [9]:
#| export

@regist_model
class NMFlowGANGenerator(nn.Module):
    def __init__(
        self,
        kwargs_unet,
        kwargs_flow,
    ):
        super(NMFlowGANGenerator, self).__init__()
        self._flow_init(
            **kwargs_flow
        )
        self.generator = UNet(
             **kwargs_unet
        )  
         
    def _flow_init(
        self,
        in_ch=1,
        ch_exp_coef = 1.,
        width_exp_coef = 2.,
        num_bits=16,
        conv_net_feats=16,
        pre_arch="UD",
        arch="NE|SAL|SDL|CL2|SAL|SDL|CL2",
        device='cpu',
        codes=None
    ):
        attributesFromDict(locals()) # stores all the input parameters in self
        
        if codes==None:
            self.codes= {
                'camera': torch.tensor([2], dtype=torch.float32).to(device)
            }
        
        self.pre_bijectors = list()
        pre_arch_lyrs = pre_arch.split('|')
        for lyr in pre_arch_lyrs:
            self.pre_bijectors.append(self.get_flow_layer(lyr))
        self.pre_bijectors = nn.Sequential(*self.pre_bijectors)

        self.bijectors = list()
        arch_lyrs = arch.split('|')
        for lyr in arch_lyrs:
            self.bijectors.append(self.get_flow_layer(lyr))
        self.bijectors = nn.Sequential(*self.bijectors)
        self.dist = StandardNormal()

    def internal_channels(self):
        return int(self.in_ch * self.ch_exp_coef)
    
    def internal_widths(self):
        return int(self.in_ch * self.width_exp_coef)

    def get_flow_layer(self, name):
        match name:
            case "UD":
                return get_flow_layer("UniformDequantization")(device=self.device, num_bits=self.num_bits)
            
            case "NE":
                return get_flow_layer("NoiseExtraction")(device=self.device)    
            
            case "CL2":
                return get_flow_layer("ConditionalLinearExp2")(
                    in_ch=self.internal_channels(),
                    device=self.device,
                    codes=self.codes,
                )
                
            case "SDL":
                return get_flow_layer("SignalDependentConditionalLinear")(
                    meta_encoder=lambda in_features, out_features: get_network_class("ResidualNet")(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=5,
                        num_blocks=3,
                        use_batch_norm=True,
                        dropout_probability=0.0
                    ),
                    scale_and_bias=lambda in_features, out_features: get_flow_layer("PointwiseConvs")(
                        in_features=in_features,
                        out_features=out_features,
                        feats=self.conv_net_feats
                    ),
                    in_ch=self.internal_channels(),
                    device=self.device,
                    codes=self.codes,
                )
                
            case "SAL":
                return get_flow_layer("StructureAwareConditionalLinearLayer")(
                    meta_encoder=lambda in_features, out_features: get_network_class("ResidualNet")(
                        in_features=in_features,
                        out_features=out_features,
                        hidden_features=5,
                        num_blocks=3,
                        use_batch_norm=True,
                        dropout_probability=0.0
                    ),
                    structure_encoder=lambda in_features, out_features: get_flow_layer("SpatialConvs")(
                        in_features=in_features,
                        out_features=out_features,
                        receptive_field=9,
                        feats=self.conv_net_feats
                    ),
                    in_ch=self.internal_channels(),
                    codes=self.codes,
                    device=self.device
                )
            
            case _: 
                assert False, f"Invalid layer name : {name}"

    def _flow_forward(self, noisy, clean, kwargs=dict()):
        x = noisy
        kwargs['clean'] = clean.clone()

        objectives = 0.
        for bijector in self.pre_bijectors:
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'], _ = bijector._forward_and_log_det_jacobian(kwargs['clean'])

            x, ldj = bijector._forward_and_log_det_jacobian(x, **kwargs)
            objectives += ldj

        for bijector in self.bijectors:
            x, ldj = bijector._forward_and_log_det_jacobian(x, **kwargs)
            objectives += ldj
        return x, objectives

    def _flow_sample(self, kwargs=dict()):
        for bijector in self.pre_bijectors:
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'], _ = bijector._forward_and_log_det_jacobian(kwargs['clean'], **kwargs)

        b,_,h,w = kwargs['clean'].shape
        x = self.dist.sample((b,self.internal_channels(),h,w))
        for bijector in reversed(self.bijectors):
            x = bijector._inverse(x, **kwargs)

        for bijector in reversed(self.pre_bijectors):
            if isinstance(bijector, get_flow_layer("UniformDequantization")):
                kwargs['clean'] = bijector._inverse(kwargs['clean'], **kwargs)
            x = bijector._inverse(x, **kwargs)
        x = torch.clip(x, 0, 2**self.num_bits)
        return x 

    def forward(self, noisy, clean, kwargs=dict()):
        z, objectives = self._flow_forward(noisy, clean, kwargs)
        kwargs['clean']=clean
        with torch.no_grad():
            x = self._flow_sample(kwargs) - kwargs['clean'] 
        x_scaled = x / (2**self.num_bits) # x_scaled: -1 ~ 1
        y = (self.generator(x_scaled) * (2**self.num_bits) + kwargs['clean']).requires_grad_(True)
        return z, objectives, y, x
    
    def sample(self, kwargs=dict()):
        x = self._flow_sample(kwargs) - kwargs['clean'] # pixelwise noise
        x_scaled = x / (2**self.num_bits) # x_scaled: -1 ~ 1
        y = self.generator(x_scaled) * (2**self.num_bits) + kwargs['clean']
        y = torch.clip(y, 0, 2**self.num_bits)
        return y

In [10]:
# Minimal UNet config; the flow part is built with its defaults (empty kwargs).
kwargs_unet = dict(depth=1)
print(kwargs_unet)
model = NMFlowGANGenerator(kwargs_unet, {})

{'depth': 1}


In [11]:
# Smoke test on a tiny random batch: forward must return the 4-tuple (z, objectives, y, x).
device = 'cpu'
noisy = torch.randint(256, [5, 1, 2, 2])
clean = torch.randint(256, [5, 1, 2, 2])
kwargs = {'camera': torch.tensor([2], dtype=torch.float32, device=device)}

output = model.forward(noisy, clean, kwargs=kwargs)
assert len(output) == 4

### Critic

In [12]:
#| export

@regist_model
class NMFlowGANCritic(nn.Module):
    """Critic that rescales its input by `2**num_bits` and scores it with `Discriminator_96`.

    Args:
        in_ch: number of input channels passed to `Discriminator_96`.
        nc: base feature width passed to `Discriminator_96`.
        num_bits: bit depth used to normalise the input.
    """
    def __init__(
            self,
            in_ch=1,
            nc=64,
            num_bits=8
    ):
        super().__init__()
        self.critic = Discriminator_96(in_ch, nc)
        self.num_bits = num_bits

    def forward(self, x):
        """Return the critic score for `x / 2**num_bits`."""
        normalised = x / (2**self.num_bits)
        return self.critic(normalised)


In [13]:
#| export

class Discriminator_96(nn.Module):
    """Discriminator with 96x96 input, refer to Kai Zhang, https://github.com/cszn/KAIR

    Args:
        in_nc: number of input channels.
        nc: base feature width; the last conv stage has `nc * 8` channels.

    The five stride-2 stages reduce a 96x96 input to 3x3, so the classifier
    expects `nc * 8 * 3 * 3` flattened features and returns one score per sample.
    """
    def __init__(self, in_nc=3, nc=64):
        super(Discriminator_96, self).__init__()
        conv0 = nn.Conv2d(in_nc, nc, kernel_size=7, padding=3)
        conv1 = self._get_basic_module(nc, nc, kernel_size=4, stride=2)
        # 48, nc
        conv2 = self._get_basic_module(nc, nc*2, kernel_size=3, stride=1)
        conv3 = self._get_basic_module(nc*2, nc*2, kernel_size=4, stride=2)
        # 24, nc*2
        conv4 = self._get_basic_module(nc*2, nc*4, kernel_size=3, stride=1)
        conv5 = self._get_basic_module(nc*4, nc*4, kernel_size=4, stride=2)
        # 12, nc*4
        conv6 = self._get_basic_module(nc*4, nc*8, kernel_size=3, stride=1)
        conv7 = self._get_basic_module(nc*8, nc*8, kernel_size=4, stride=2)
        # 6, nc*8
        conv8 = self._get_basic_module(nc*8, nc*8, kernel_size=3, stride=1)
        conv9 = self._get_basic_module(nc*8, nc*8, kernel_size=4, stride=2)
        # 3, nc*8
        self.features = nn.Sequential(*[conv0, conv1, conv2, conv3, conv4,
                                     conv5, conv6, conv7, conv8, conv9])

        # classifier: input width follows `nc` (equals 512*3*3 for the default nc=64)
        self.classifier = nn.Sequential(
            nn.Linear(nc * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def _get_basic_module(self, in_ch, out_ch, kernel_size=1, stride=1, padding=1, negative_slope=0.2):
        """Conv2d -> InstanceNorm2d(affine) -> LeakyReLU building block."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.InstanceNorm2d(out_ch, affine=True), #batch normalization?
            nn.LeakyReLU(negative_slope, inplace=True)
        )

    def forward(self, x):
        """Score a batch: conv features, flatten per sample, then the classifier."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    

### GAN Denoiser

In [14]:
#| export

class NMFlowGANDenoiser(nn.Module):
    """Denoiser paired with a pre-trained `NMFlowGANGenerator` noise model.

    `forward` takes a clean image, synthesises a noisy version with the noise
    model (no gradients) and returns the denoiser's output for it. Only the
    denoiser's parameters are exposed through `parameters`.

    Args:
        denoiser: module applied to images normalised to `[0, 1]`.
        kwargs_flow: flow keyword arguments for `NMFlowGANGenerator`.
        kwargs_unet: UNet keyword arguments for `NMFlowGANGenerator`.
        pretrained_path: checkpoint file; weights are read from
            `['model_weight']['generator']`. A missing file only prints a warning.
        num_bits: bit depth of the images handled by the denoiser.
    """
    def __init__(
            self,
            denoiser,
            kwargs_flow,
            kwargs_unet,
            pretrained_path,
            num_bits=8,
        ):
        super().__init__()
        self.denoiser = denoiser
        self.kwargs_flow = kwargs_flow
        self.pretrained_path = pretrained_path
        self.num_bits = num_bits
        self.noise_model = get_model_class("NMFlowGANGenerator")(kwargs_unet, kwargs_flow)
        self._load_checkpoint(self.noise_model, self.pretrained_path)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Return only the denoiser's parameters."""
        return self.denoiser.parameters(recurse) # the parameters of denoiser will be trained only.

    def _load_checkpoint(self, module, path):
        """Load `pth['model_weight']['generator']` from `path` into `module` and set eval mode.

        Best effort: when `path` does not exist a warning is printed and
        `module` is left untouched (neither loaded nor switched to eval mode).
        """
        if not os.path.exists(path):
            print(f"WARNING: {path} does not exist; checkpoint not loaded.")
            return
        # Deserialize onto CPU so a checkpoint saved from a GPU can be read on any
        # host; load_state_dict then copies the values into the module's tensors.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        pth = torch.load(path, map_location='cpu')
        module.load_state_dict(pth['model_weight']['generator'])
        module.eval()

    def forward(self, x, kwargs=None):
        """Synthesise noise on clean image `x` and denoise the result.

        Args:
            x: clean image in the denoiser's range (`0 ~ 2**num_bits`).
            kwargs: conditioning dict handed to `noise_model.sample`. A dict
                supplied by the caller is updated in place (`'clean'` is
                written); when omitted a fresh dict is used.

        Returns:
            Denoiser output scaled by `2**num_bits`.
        """
        # A `dict()` default would be shared between calls and keep stale entries.
        if kwargs is None:
            kwargs = {}
        # x: clean image
        x_scaled = x / (2**self.num_bits) # x_scaled: 0 ~ 1
        x_scaled = x_scaled * (2**self.noise_model.num_bits) # x_scaled: 0 ~ noise model's max GL.

        kwargs['clean'] = x_scaled
        with torch.no_grad():
            n = self.noise_model.sample(kwargs) # noisy image

        n_scaled = n / (2**self.noise_model.num_bits) # n_scaled: 0 ~ 1
        n_scaled = torch.clip(n_scaled, 0., 1.)
        y = self.denoiser(n_scaled)
        y = y * (2**self.num_bits) # y: 0 ~ denoiser's max GL.
        return y

    def denoise(self, x, kwargs=None):
        """Denoise noisy image `x`.

        `kwargs['num_bits']`, when present, overrides `self.num_bits` for the
        input/output scaling. The denoiser output is clipped to `[0, 1]` before
        being scaled back.
        """
        # x: noisy image
        if kwargs is None or 'num_bits' not in kwargs: num_bits = self.num_bits
        else: num_bits = kwargs['num_bits']

        x_scaled = x / (2**num_bits) # x_scaled: 0 ~ 1
        y =  self.denoiser(x_scaled)
        y = torch.clip(y, 0., 1.)
        y *= (2**num_bits) # y: 0 ~ denoiser's max GL.
        return y

    def sample(self, x, kwargs=None):
        """Generate a noisy version of clean image `x` with the noise model.

        Only `kwargs['num_bits']` is read from `kwargs`; the noise model is
        called with a new dict that contains just `'clean'`. The normalised
        sample is clipped to `[0, 1]` before being scaled to `2**num_bits`.
        """
        # x: clean image
        if kwargs is None or 'num_bits' not in kwargs: num_bits = self.num_bits
        else: num_bits = kwargs['num_bits']

        x_scaled = x / (2**num_bits) # x_scaled: 0 ~ 1
        x_scaled = x_scaled * (2**self.noise_model.num_bits) # x_scaled: 0 ~ noise model's max GL.

        kwargs = dict()
        kwargs['clean'] = x_scaled
        n = self.noise_model.sample(kwargs) # n: 0 ~ noise model's max GL.

        n_scaled = n / (2**self.noise_model.num_bits) # n_scaled: 0 ~ 1
        n_scaled = torch.clip(n_scaled, 0., 1.)
        n_scaled = n_scaled * (2**num_bits) # n_scaled: 0 ~ denoiser's max GL.
        return n_scaled

In [15]:
#| export

@regist_model
class DnCNNFlowGAN(NMFlowGANDenoiser):
    """`NMFlowGANDenoiser` whose denoiser is a `DnCNN` built from `kwargs_dncnn`."""
    def __init__(
        self,
        kwargs_dncnn,
        kwargs_unet,
        kwargs_flow,
        pretrained_path,
        num_bits=8
        ):
        dncnn = DnCNN(**kwargs_dncnn)
        # Note the parent's argument order: flow kwargs come before UNet kwargs.
        super().__init__(
            denoiser=dncnn,
            kwargs_flow=kwargs_flow,
            kwargs_unet=kwargs_unet,
            pretrained_path=pretrained_path,
            num_bits=num_bits,
        )

In [16]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out as library code.
import nbdev; nbdev.nbdev_export()