In [None]:
# default_exp architecture.DeblurGANv2

In [None]:
# export
import torch
import torch.nn as nn
import numpy as np
import functools
import torch.nn.functional as F

'''
TODO: as soon as (ever?) torchvision has this, switch to torchvision
'''
from pretrainedmodels import inceptionresnetv2

'''
Copied from DeblurGANv2 repo.
https://github.com/VITA-Group/DeblurGANv2
TODO: simplify the code. seriously.
'''
# Default normalisation factory: InstanceNorm2d without affine parameters,
# tracking running statistics.
instance_norm = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
# Defines the PatchGAN discriminator with the specified arguments.
# with n_layers=5 it is the full-gan
class Discriminator(nn.Module):
    """PatchGAN discriminator producing a single-channel map of per-patch scores.

    Layout (all convolutions are 4x4 with LeakyReLU(0.2) activations):
      * ``n_layers`` stride-2 convolutions; the first has no norm, the others
        are followed by ``norm_layer``. Channel width is ``ndf * min(2**i, 8)``.
      * one stride-1 convolution + ``norm_layer``.
      * a final stride-1 convolution down to 1 channel (optionally + Sigmoid).

    Args:
        input_nc: number of input channels.
        ndf: base number of filters.
        n_layers: number of stride-2 convolutions.
        norm_layer: normalisation layer factory (a class or a functools.partial).
        use_sigmoid: append ``nn.Sigmoid`` to the output.
        use_parallel: stored on the instance; not read by this class.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=5, norm_layer=instance_norm, use_sigmoid=False, use_parallel=True):
        super().__init__()
        self.use_parallel = use_parallel  # stored only; not used by this class

        # Conv bias is enabled only when the norm layer is InstanceNorm2d (the bare
        # class or a functools.partial wrapping it); e.g. BatchNorm2d gets bias=False.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Remaining stride-2 stages; channel multiplier doubles up to a cap of 8.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One stride-1 stage before the scoring conv.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Per-patch score map with a single channel.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return the per-patch score map for ``input``."""
        return self.model(input)
    
    
class Generator(nn.Module):
    """DeblurGANv2 generator: FPN over Inception-ResNet-v2, four heads and a residual output.

    The final convolution's output goes through ``tanh``, is added to the input
    and clamped to [-1, 1]. Inputs are therefore presumably expected to be scaled
    to [-1, 1] as well (TODO confirm against the data pipeline).

    Args:
        norm_layer: normalisation layer factory for the FPN and smoothing blocks.
        output_ch: channels of the final convolution. The result is added to the
            input, so it must be compatible with the input channel count.
        num_filters: channels of the FPN heads and of ``smooth``. ``smooth``'s
            output is added to ``map0``, which has ``num_filters_fpn // 2``
            channels, so the two values must agree (true for the defaults).
        num_filters_fpn: channels of the FPN pyramid maps.
    """

    # The FPN's hard-coded paddings and the final residual addition only line up
    # when both spatial dims are multiples of this value.
    _SIZE_MULTIPLE = 32

    def __init__(self, norm_layer=instance_norm, output_ch=3, num_filters=128, num_filters_fpn=256):
        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters_fpn` filters for all feature maps.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)

        # The segmentation heads on top of the FPN
        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )

        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        """Make the FPN backbone parameters trainable."""
        self.fpn.unfreeze()

    def forward(self, x):
        """Deblur ``x``.

        Args:
            x: image batch whose height and width are multiples of 32 (pad first).

        Returns:
            ``clamp(tanh(final) + x, -1, 1)``.

        Raises:
            ValueError: if the height or width is not a multiple of 32. Without
                this check the failure surfaces as an opaque tensor-size mismatch
                deep inside the FPN.
        """
        height, width = x.shape[-2:]
        if height % self._SIZE_MULTIPLE or width % self._SIZE_MULTIPLE:
            raise ValueError(
                f"Generator input height and width must be multiples of {self._SIZE_MULTIPLE} "
                f"(pad the image first); got {height}x{width}."
            )

        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every head output to map1's resolution (1/4 of the input) so they
        # can be concatenated; map1 is already there, so it needs no resampling.
        map4 = nn.functional.interpolate(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.interpolate(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.interpolate(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = self.head1(map1)

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.interpolate(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.interpolate(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)
        # Residual prediction: the network outputs a correction added to the input.
        res = torch.tanh(final) + x

        return torch.clamp(res, min=-1, max=1)
    
    
class FPNHead(nn.Module):
    """Two bias-free 3x3 conv + ReLU layers applied to one FPN pyramid level.

    Padding keeps the spatial size unchanged. The attribute names ``block0`` and
    ``block1`` are the state-dict keys; renaming them would change checkpoint keys.
    """

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()

        same_size_conv = dict(kernel_size=3, padding=1, bias=False)
        self.block0 = nn.Conv2d(num_in, num_mid, **same_size_conv)
        self.block1 = nn.Conv2d(num_mid, num_out, **same_size_conv)

    def forward(self, x):
        out = x
        for conv in (self.block0, self.block1):
            # In-place ReLU only overwrites the freshly computed conv output.
            out = nn.functional.relu(conv(out), inplace=True)
        return out
    
    
class FPN(nn.Module):
    """Feature Pyramid Network on a frozen, ImageNet-pretrained Inception-ResNet-v2.

    ``forward`` returns ``(map0, map1, map2, map3, map4)``. ``map0`` has
    ``num_filters // 2`` channels and the others have ``num_filters``.

    The hard-coded paddings only make the top-down additions line up for some
    input sizes. For example, 736x1312 works, while 720x1280 fails at the
    ``map3`` addition (45 vs 46 rows).
    NOTE(review): working through the backbone's output sizes, H and W that are
    multiples of 32 line up; confirm before relying on other sizes.
    """

    def __init__(self, norm_layer=instance_norm, num_filters=256):
        """Creates an `FPN` instance for feature extraction.

        The backbone is always built with ImageNet weights
        (``pretrained='imagenet'``) and its parameters are frozen; call
        ``unfreeze`` to make them trainable.

        Args:
          norm_layer: normalisation layer factory used in the top-down blocks
            ``td1``..``td3``.
          num_filters: the number of filters in each output pyramid level
            (the returned ``map0`` gets ``num_filters // 2``).
        """

        super().__init__()
        self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')

        # Encoder stages reuse (share) the backbone's submodules.
        self.enc0 = self.inception.conv2d_1a
        self.enc1 = nn.Sequential(
            self.inception.conv2d_2a,
            self.inception.conv2d_2b,
            self.inception.maxpool_3a,
        ) # 64 output channels
        self.enc2 = nn.Sequential(
            self.inception.conv2d_3b,
            self.inception.conv2d_4a,
            self.inception.maxpool_5a,
        )  # 192 output channels
        self.enc3 = nn.Sequential(
            self.inception.mixed_5b,
            self.inception.repeat,
            self.inception.mixed_6a,
        )   # 1088 output channels
        self.enc4 = nn.Sequential(
            self.inception.repeat_1,
            self.inception.mixed_7a,
        ) # 2080 output channels
        # Top-down smoothing blocks applied after each lateral + upsampled addition.
        self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        # Adds one pixel on every side; used to align lateral sizes in forward().
        self.pad = nn.ReflectionPad2d(1)
        # 1x1 lateral projections from each encoder stage's channel count.
        self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)

        # Freeze the pretrained backbone; see unfreeze().
        for param in self.inception.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Make the backbone parameters trainable again (they are frozen in ``__init__``)."""
        for param in self.inception.parameters():
            param.requires_grad = True
        print("Unfreeze successful.")

    def forward(self, x):

        # Bottom-up pathway through the Inception-ResNet-v2 backbone
        enc0 = self.enc0(x)

        enc1 = self.enc1(enc0) # 64 channels

        enc2 = self.enc2(enc1) # 192 channels

        enc3 = self.enc3(enc2) # 1088 channels

        enc4 = self.enc4(enc3) # 2080 channels

        # Lateral connections; the reflection paddings align each level with
        # the 2x-upsampled level above it.

        lateral4 = self.pad(self.lateral4(enc4))
        lateral3 = self.pad(self.lateral3(enc3))
        lateral2 = self.lateral2(enc2)
        lateral1 = self.pad(self.lateral1(enc1))
        lateral0 = self.lateral0(enc0)

        # Top-down pathway
        pad = (1, 2, 1, 2)  # F.pad order (left, right, top, bottom): 1 before, 2 after in W and H
        pad1 = (0, 1, 0, 1)  # one extra column on the right and row at the bottom
        map4 = lateral4
        map3 = self.td1(lateral3 + nn.functional.interpolate(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.interpolate(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + nn.functional.interpolate(map2, scale_factor=2, mode="nearest"))
        return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4


In [None]:
# Regenerate the library modules from the notebooks with nbdev
# (the output below lists every notebook that was converted).
from nbdev.export import *
notebook2script()

Converted 01_model.ipynb.
Converted 02_architecture_common.ipynb.
Converted 03_architecture_MSResNet.ipynb.
Converted 04_dataset_common.ipynb.
Converted 05_dataset_MSResNet.ipynb.
Converted 06_trainer_MSResNet.ipynb.
Converted 07_metrics.ipynb.
Converted 08_architecture_DeblurGANv2.ipynb.
Converted 09_dataset_DeblurGANv2.ipynb.
Converted 10_losses.ipynb.
Converted 99_basemodel.ipynb.
Converted 99_dataset_DeblurGANv2_clean.ipynb.
Converted 99_diffaugment.ipynb.
Converted 99_model_DeblurGANv2_clean.ipynb.
Converted 99_model_MSResNet.ipynb.
Converted DeblurGANv2_lightning-from-vanilla-Copy1.ipynb.
Converted DeblurGANv2_lightning-from-vanilla.ipynb.
Converted DeblurGANv2_vanilla.ipynb.
Converted Tutorial_without_lightning.ipynb.
Converted fuckit.ipynb.
Converted hmmm.ipynb.
Converted model_without_lightning.ipynb.
Converted trials.ipynb.


In [None]:
# PatchGAN discriminator with 3 stride-2 layers and non-affine InstanceNorm.
patch_disc = Discriminator(n_layers=3,
                           norm_layer=functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True),
                           use_sigmoid=False)

In [None]:
# Generator with the same norm configuration; its FPN loads ImageNet backbone weights.
generator = Generator(norm_layer=functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True))

In [None]:
# Dummy 720x1280 batch; 720 is not a multiple of 32, so it must be padded first.
x = torch.rand(1,3,720,1280)

In [None]:
_, c, h, w = x.shape
block_size= 32

In [None]:
 min_height = (h // block_size + 1) * block_size

In [None]:
# NOTE(review): this (like min_height) adds a whole extra block when the size is
# already a multiple of block_size, e.g. w=1280 -> 1312 in the output below.
min_width = (w // block_size + 1) * block_size

In [None]:
min_width, min_height

(1312, 736)

In [None]:
pad_w = (min_width - w)//2
pad_h = (min_height - h)//2

In [None]:
 pad_params = {'mode': 'constant',
                      'value': 0,
                      'pad': (pad_w, pad_w, pad_h, pad_h)
                      }
x = F.pad(x, **pad_params)

In [None]:
# Padded input: both spatial dims should now be multiples of block_size.
x.shape

torch.Size([1, 3, 736, 1312])

In [None]:
# Full generator forward pass on the padded 736x1312 input.
generator(x)

tensor([[[[ 0.0153,  0.1499, -0.2312,  ..., -0.3292, -0.1599, -0.0901],
          [-0.1745, -0.0168,  0.2872,  ..., -0.3346, -0.1249,  0.0216],
          [-0.5284, -0.3765, -0.1533,  ..., -0.1632, -0.0262,  0.0801],
          ...,
          [ 0.0500, -0.0424,  0.0689,  ..., -0.1204, -0.0469,  0.0082],
          [ 0.0857,  0.0094,  0.0410,  ..., -0.2651, -0.2845, -0.1377],
          [ 0.0975, -0.0058, -0.2135,  ..., -0.3063, -0.3515, -0.0448]],

         [[-0.3389,  0.0734, -0.0703,  ..., -0.0618, -0.0698,  0.1708],
          [ 0.1390,  0.3178,  0.3994,  ...,  0.1895,  0.0259,  0.2192],
          [-0.2641, -0.0761,  0.3369,  ...,  0.4438,  0.0495,  0.1800],
          ...,
          [-0.0663,  0.3616,  0.3557,  ..., -0.4078, -0.3815, -0.0998],
          [-0.2381,  0.0448,  0.0123,  ..., -0.5223, -0.4879, -0.3241],
          [-0.1543, -0.1696, -0.2505,  ..., -0.2275, -0.2418, -0.0958]],

         [[ 0.2883,  0.0787, -0.2844,  ..., -0.2128, -0.1868,  0.0233],
          [ 0.1258, -0.5242, -

In [None]:
# Step through FPN.forward by hand to find where the spatial sizes diverge.
# NOTE(review): the shapes printed further down (enc0 359x639) match a 720x1280
# input, so x was presumably the unpadded tensor when these cells ran.
enc0 = generator.fpn.enc0(x)
enc1 = generator.fpn.enc1(enc0)
enc2 = generator.fpn.enc2(enc1)
enc3 = generator.fpn.enc3(enc2)
enc4 = generator.fpn.enc4(enc3)

In [None]:
# Lateral 1x1 projections, padded exactly as in FPN.forward.
lateral4 = generator.fpn.pad(generator.fpn.lateral4(enc4))
lateral3 = generator.fpn.pad(generator.fpn.lateral3(enc3))
lateral2 = generator.fpn.lateral2(enc2)
lateral1 = generator.fpn.pad(generator.fpn.lateral1(enc1))
lateral0 = generator.fpn.lateral0(enc0)

In [None]:
pad = (1, 2, 1, 2)  # F.pad order (left, right, top, bottom): 1 before, 2 after in W and H
pad1 = (0, 1, 0, 1)
map4 = lateral4
# For the 720x1280 input this raises: lateral3 has 45 rows, upsampled map4 has 46.
map3 = generator.fpn.td1(lateral3 + nn.functional.interpolate(map4, scale_factor=2, mode="nearest"))

RuntimeError: The size of tensor a (45) must match the size of tensor b (46) at non-singleton dimension 2

In [None]:
# `upsample` is a deprecated alias; `interpolate` matches what FPN.forward uses.
nn.functional.interpolate(map4, scale_factor=2, mode="nearest").shape

torch.Size([1, 256, 46, 80])

In [None]:
# Print every intermediate shape to see where the pyramid levels stop lining up.
print('enc0', enc0.shape)
print('enc1', enc1.shape)
print('enc2', enc2.shape)
print('enc3', enc3.shape)
print('enc4', enc4.shape)
print('lateral4', lateral4.shape)
print('lateral3', lateral3.shape)
print('lateral2', lateral2.shape)
print('lateral1', lateral1.shape)
print('lateral0', lateral0.shape)
print('map4', map4.shape)
print('map3', map3.shape)

enc0 torch.Size([1, 32, 359, 639])
enc1 torch.Size([1, 64, 178, 318])
enc2 torch.Size([1, 192, 87, 157])
enc2 torch.Size([1, 192, 87, 157])
enc3 torch.Size([1, 1088, 43, 78])
enc4 torch.Size([1, 2080, 21, 38])
lateral4 torch.Size([1, 256, 23, 40])
lateral3 torch.Size([1, 256, 45, 80])
lateral2 torch.Size([1, 256, 87, 157])
lateral1 torch.Size([1, 256, 180, 320])
lateral0 torch.Size([1, 128, 359, 639])
map4 torch.Size([1, 256, 23, 40])
map3 torch.Size([1, 256, 16, 16])


In [None]:
# NOTE(review): map3 printed above is 16x16 even though the map3 cell raised, so
# these cells ran on values left over from an earlier, different run; the shapes
# below do not come from the 720x1280 input.
map2 = generator.fpn.td2(F.pad(lateral2, pad, "reflect") + nn.functional.interpolate(map3, scale_factor=2, mode="nearest"))
map1 = generator.fpn.td3(lateral1 + nn.functional.interpolate(map2, scale_factor=2, mode="nearest"))

In [None]:
print('map2', map2.shape)
print('map1', map1.shape)

map2 torch.Size([1, 256, 32, 32])
map1 torch.Size([1, 256, 64, 64])
