In [None]:
!pip install einops

In [None]:
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

# from basicsr.utils import get_root_logger

# try:
#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
#                                         modulated_deform_conv)
# except ImportError:
#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
#     #       'Otherwise install BasicSR with compiling dcn.')
#

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights in place.

    Conv2d / Linear weights get Kaiming-normal init multiplied by ``scale``;
    BatchNorm weights are set to 1. Every bias that exists is filled with
    ``bias_fill``. All sub-modules (``module.modules()``) are visited.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module): Modules
            to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for ``init.kaiming_normal_``.
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # BatchNorm built with affine=False has weight/bias set to None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` fresh instances of ``basic_block``.

    Each instance is constructed independently as ``basic_block(**kwarg)``,
    so no parameters are shared between the stacked blocks.

    Args:
        basic_block (type[nn.Module]): block class to instantiate.
        num_basic_block (int): how many blocks to stack.

    Returns:
        nn.Sequential: the blocks in construction order.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)


class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Computes ``x + res_scale * conv2(relu(conv1(x)))``::

        ---Conv-ReLU-Conv-+-
         |________________|

    Both convs are 3x3 / stride 1 / padding 1 with ``num_feat`` in and out
    channels, so the output has the same shape as the input.

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Multiplier applied to the residual branch.
            Default: 1.
        pytorch_init (bool): Keep PyTorch's default init when True; otherwise
            re-initialize both convs with ``default_init_weights`` at scale
            0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module built from conv + PixelShuffle stages.

    For ``scale = 2**n`` it stacks ``n`` pairs of (3x3 conv to ``4 * num_feat``
    channels, PixelShuffle(2)); for ``scale = 3`` a single (3x3 conv to
    ``9 * num_feat`` channels, PixelShuffle(3)). The channel count is
    ``num_feat`` again after every stage.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: if ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # scale > 0 guard: 0 passes the bit trick but has no logarithm.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        # Zero-argument super(): the notebook later rebinds the global name
        # ``Upsample`` to another class, which would make
        # ``super(Upsample, self)`` raise TypeError at construction time.
        super().__init__(*m)


def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with a dense optical flow.

    Every output pixel (i, j) samples ``x`` at ``(j + flow[..., 0],
    i + flow[..., 1])``, i.e. the flow is given in pixels, x-offset first.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), pixel offsets.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Passed to ``F.grid_sample``. Default: True
            (the pre-1.3 PyTorch default).

    Returns:
        Tensor: Warped image or feature map, same shape as ``x``.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, height, width = x.size()

    # Base sampling grid in pixel coordinates, stacked as (x, y) -> (h, w, 2).
    rows, cols = torch.meshgrid(
        torch.arange(0, height).type_as(x),
        torch.arange(0, width).type_as(x))
    base_grid = torch.stack((cols, rows), 2).float()
    base_grid.requires_grad = False

    target = base_grid + flow
    # Normalize to grid_sample's [-1, 1] range. NOTE: this mapping matches
    # align_corners=True; with align_corners=False it is presumably off by
    # half a pixel (open TODO inherited from the original code).
    norm_x = 2.0 * target[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * target[:, :, :, 1] / max(height - 1, 1) - 1.0
    sampling_grid = torch.stack((norm_x, norm_y), dim=3)

    return F.grid_sample(
        x,
        sampling_grid,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)


def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow field by a ratio or to an explicit shape.

    The flow vectors are rescaled along with the grid: the x component
    (channel 0) by ``out_w / in_w`` and the y component (channel 1) by
    ``out_h / in_h``. The input tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): ``[ratio_h, ratio_w]`` when
            ``size_type == 'ratio'`` (values < 1.0 downsample, > 1.0
            upsample), or ``[out_h, out_w]`` when ``size_type == 'shape'``.
        interp_mode (str): Interpolation mode. Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: if ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, src_h, src_w = flow.size()
    if size_type == 'ratio':
        dst_h = int(src_h * sizes[0])
        dst_w = int(src_w * sizes[1])
    elif size_type == 'shape':
        dst_h, dst_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    scaled = flow.clone()
    scaled[:, 0, :, :] *= dst_w / src_w
    scaled[:, 1, :, :] *= dst_h / src_h
    return F.interpolate(
        scaled,
        size=(dst_h, dst_w),
        mode=interp_mode,
        align_corners=align_corners)


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle (space-to-depth).

    Every ``scale x scale`` spatial patch is moved into the channel axis, so
    the output has ``c * scale**2`` channels at ``1/scale`` resolution.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); ``hh`` and
            ``hw`` must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature, shape
        (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    patches = x.view(batch, channels, out_h, scale, out_w, scale)
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, channels * scale * scale, out_h, out_w)


# class DCNv2Pack(ModulatedDeformConvPack):
#     """Modulated deformable conv for deformable alignment.
#
#     Different from the official DCNv2Pack, which generates offsets and masks
#     from the preceding features, this DCNv2Pack takes another different
#     features to generate offsets and masks.
#
#     Ref:
#         Delving Deep into Deformable Alignment in Video Super-Resolution.
#     """
#
#     def forward(self, x, feat):
#         out = self.conv_offset(feat)
#         o1, o2, mask = torch.chunk(out, 3, dim=1)
#         offset = torch.cat((o1, o2), dim=1)
#         mask = torch.sigmoid(mask)
#
#         offset_absmean = torch.mean(torch.abs(offset))
#         if offset_absmean > 50:
#             logger = get_root_logger()
#             logger.warning(
#                 f'Offset abs mean is {offset_absmean}, larger than 50.')
#
#         return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
#                                      self.stride, self.padding, self.dilation,
#                                      self.groups, self.deformable_groups)


In [None]:
## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881


import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
import numbers

from einops import rearrange



##########################################################################
## Layer Norm

def to_3d(x):
    """Flatten a (b, c, h, w) feature map into a (b, h*w, c) token sequence."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)

def to_4d(x, h, w):
    """Inverse of ``to_3d``: reshape (b, h*w, c) tokens back to (b, c, h, w)."""
    spatial = {'h': h, 'w': w}
    return rearrange(x, 'b (h w) c -> b c h w', **spatial)

class BiasFree_LayerNorm(nn.Module):
    """Layer norm over the last axis without mean-centering or bias.

    Computes ``x / sqrt(var(x) + 1e-5) * weight`` with the biased variance
    taken over the last dimension.

    Args:
        normalized_shape (int | sequence of one int): size of the last axis.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        scale = torch.sqrt(variance + 1e-5)
        return x / scale * self.weight

class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last axis with learnable weight and bias.

    Computes ``(x - mean) / sqrt(var + 1e-5) * weight + bias`` using the
    biased variance of the last dimension.

    Args:
        normalized_shape (int | sequence of one int): size of the last axis.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        return centered / torch.sqrt(variance + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    """Channel-wise layer norm for (b, c, h, w) feature maps.

    The map is flattened to (b, h*w, c), normalized over the channel axis by
    the selected body, then reshaped back.

    Args:
        dim (int): number of channels.
        LayerNorm_type (str): 'BiasFree' selects ``BiasFree_LayerNorm``; any
            other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        normalized = self.body(to_3d(x))
        return to_4d(normalized, height, width)



##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated-Dconv Feed-Forward Network (GDFN).

    A 1x1 conv expands to ``2 * hidden`` channels and a 3x3 depth-wise conv
    mixes spatially. The result is split into halves ``(a, b)`` and gated as
    ``gelu(a) * b``, then a 1x1 conv projects back to ``dim`` channels.

    Args:
        dim (int): input/output channels.
        ffn_expansion_factor (float): ``hidden = int(dim * factor)``.
        bias (bool): whether the convs use bias.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        expanded = hidden_features * 2

        self.project_in = nn.Conv2d(dim, expanded, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(expanded, expanded, kernel_size=3, stride=1, padding=1,
                                groups=expanded, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)



##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA).

    q, k and v come from a 1x1 conv followed by a 3x3 depth-wise conv. The
    attention map is computed across channels rather than pixels: per head,
    L2-normalized q and k give a (c_head x c_head) map, scaled by a learnable
    per-head temperature and softmax-ed.

    Args:
        dim (int): input/output channels (split evenly across heads).
        num_heads (int): number of attention heads.
        bias (bool): whether the convs use bias.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1,
                                    groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        to_heads = 'b (head c) h w -> b head c (h w)'
        q, k, v = (rearrange(t, to_heads, head=self.num_heads)
                   for t in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1))

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        weights = logits.softmax(dim=-1)

        merged = rearrange(torch.matmul(weights, v), 'b head c (h w) -> b (head c) h w',
                           head=self.num_heads, h=h, w=w)
        return self.project_out(merged)



##########################################################################
class TransformerBlock(nn.Module):
    """Pre-norm Restormer block: ``x + MDTA(LN(x))`` then ``x + GDFN(LN(x))``.

    Args:
        dim (int): channels.
        num_heads (int): attention heads.
        ffn_expansion_factor (float): GDFN hidden-width multiplier.
        bias (bool): whether the convs use bias.
        LayerNorm_type (str): 'WithBias' or 'BiasFree'.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))



##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding: one 3x3, stride-1, padding-1 conv.

    Maps ``in_c`` input channels to ``embed_dim`` channels while keeping the
    spatial size unchanged.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)



##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """2x downsampling: 3x3 conv to ``n_feat // 2`` channels, then PixelUnshuffle(2).

    PixelUnshuffle multiplies channels by 4, so the output has
    ``2 * (n_feat // 2)`` channels at half the resolution.
    """

    def __init__(self, n_feat):
        super().__init__()
        reduce = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(reduce, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    """2x upsampling: 3x3 conv to ``2 * n_feat`` channels, then PixelShuffle(2).

    PixelShuffle divides channels by 4, so the output has ``n_feat // 2``
    channels at twice the resolution.
    """

    def __init__(self, n_feat):
        super().__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

##########################################################################
##---------- Restormer -----------------------
class RainEncoder(nn.Module):
    """Restormer-style 4-level transformer encoder with a global projection head.

    Args:
        inp_channels (int): channels of the input image. Default: 3.
        out_channels (int): unused here; kept for a uniform constructor signature.
        dim (int): channel width of level 1; level k uses ``dim * 2**(k-1)``.
        num_blocks (list[int]): TransformerBlocks for levels 1-3 and the latent level.
        num_refinement_blocks (int): unused here.
        heads (list[int]): attention heads per level.
        ffn_expansion_factor (float): GDFN hidden-width multiplier.
        bias (bool): whether convs inside the TransformerBlocks use bias.
        LayerNorm_type (str): 'WithBias' or 'BiasFree'.
        dual_pixel_task (bool): unused here.

    ``forward`` returns ``(latent, fea, out_enc_level1, out_enc_level2,
    out_enc_level3)``. ``latent`` is the level-4 feature map (``dim * 8``
    channels). ``fea`` is the 256-d head projection of its global average.
    The ``out_enc_level*`` maps are the skip features of levels 1-3. Input
    height/width presumably must be divisible by 8, because there are three
    PixelUnshuffle(2) stages.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [2,3,3,4],
        num_refinement_blocks = 2,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(RainEncoder, self).__init__()

        self.GAP = nn.AdaptiveAvgPool2d(1)

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        latent_dim = int(dim*2**3)
        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=latent_dim, num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        # Projection head on the globally pooled latent. Its input width must
        # follow the latent channel count (384 for the default dim=48);
        # a hard-coded 384 breaks every other `dim`.
        self.head = nn.Sequential(
                        nn.Linear(latent_dim, 512),
                        nn.ReLU(inplace=True),
                        nn.Linear(512, 256)
                    )

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # Global average pool -> (b, latent_dim) -> 256-d feature.
        fea = self.GAP(latent)
        fea = fea.reshape(fea.size(0), -1)
        fea = self.head(fea)

        return latent, fea, out_enc_level1, out_enc_level2, out_enc_level3

In [None]:
class RainDecoder(nn.Module):
    """Skip-free Restormer decoder that maps a level-4 latent to a rain layer.

    Three Upsample stages (each halving channels, doubling resolution) go
    from ``dim * 8`` back to ``dim`` channels, with TransformerBlocks after
    each stage and a final 3x3 conv to ``out_channels``. The arguments
    ``inp_channels``, ``num_refinement_blocks`` and ``dual_pixel_task`` are
    accepted but unused.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [2,3,3,4],
        num_refinement_blocks = 2,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):
        super().__init__()

        def stage(width, n_heads, depth):
            # `depth` TransformerBlocks of constant channel width.
            return nn.Sequential(*[
                TransformerBlock(dim=width, num_heads=n_heads, ffn_expansion_factor=ffn_expansion_factor,
                                 bias=bias, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        width1 = int(dim)
        width2 = int(dim * 2 ** 1)
        width3 = int(dim * 2 ** 2)
        width4 = int(dim * 2 ** 3)

        self.up4_3 = Upsample(width4)  # level 4 -> level 3
        self.decoder_level3 = stage(width3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(width3)  # level 3 -> level 2
        self.decoder_level2 = stage(width2, heads[1], num_blocks[1])

        self.up2_1 = Upsample(width2)  # level 2 -> level 1
        self.decoder_level1 = stage(width1, heads[0], num_blocks[0])

        self.output = nn.Conv2d(width1, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, latent):
        x = self.decoder_level3(self.up4_3(latent))
        x = self.decoder_level2(self.up3_2(x))
        x = self.decoder_level1(self.up2_1(x))
        return self.output(x)


In [None]:
class BackgroundDecoder(nn.Module):
    """U-Net style Restormer decoder that reconstructs the background image.

    At levels 3 and 2 the upsampled map is concatenated with the matching
    encoder skip and reduced back with a 1x1 conv. At level 1 the
    concatenation is kept, so decoder level 1, the refinement stage and the
    output conv all run at ``2 * dim`` channels.

    NOTE(review): when ``dual_pixel_task`` is True a ``skip_conv`` layer is
    created, but ``forward`` never uses it.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [2,3,3,4],
        num_refinement_blocks = 2,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):
        super().__init__()

        def stack(width, n_heads, depth):
            # `depth` TransformerBlocks of constant channel width.
            return nn.Sequential(*[
                TransformerBlock(dim=width, num_heads=n_heads, ffn_expansion_factor=ffn_expansion_factor,
                                 bias=bias, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        ch2 = int(dim * 2 ** 1)
        ch3 = int(dim * 2 ** 2)
        ch4 = int(dim * 2 ** 3)

        self.up4_3 = Upsample(ch4)  # level 4 -> level 3
        self.reduce_chan_level3 = nn.Conv2d(ch4, ch3, kernel_size=1, bias=bias)
        self.decoder_level3 = stack(ch3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(ch3)  # level 3 -> level 2
        self.reduce_chan_level2 = nn.Conv2d(ch3, ch2, kernel_size=1, bias=bias)
        self.decoder_level2 = stack(ch2, heads[1], num_blocks[1])

        self.up2_1 = Upsample(ch2)  # level 2 -> level 1, no channel reduction after the concat
        self.decoder_level1 = stack(ch2, heads[0], num_blocks[0])

        self.refinement = stack(ch2, heads[0], num_refinement_blocks)

        self.dual_pixel_task = dual_pixel_task
        if dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, ch2, kernel_size=1, bias=bias)

        self.output = nn.Conv2d(ch2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, latent, out_enc_level1, out_enc_level2, out_enc_level3):
        x = self.up4_3(latent)
        x = self.decoder_level3(self.reduce_chan_level3(torch.cat([x, out_enc_level3], 1)))

        x = self.up3_2(x)
        x = self.decoder_level2(self.reduce_chan_level2(torch.cat([x, out_enc_level2], 1)))

        x = self.up2_1(x)
        x = self.decoder_level1(torch.cat([x, out_enc_level1], 1))

        return self.output(self.refinement(x))

In [None]:
class DerainingFramework(nn.Module):
    """Rain / background decomposition network.

    * ``inp_rain_encoder`` + ``inp_rain_decoder`` reconstruct a rain layer
      from the rainy input.
    * ``inp_bg_encoder`` + ``inp_bg_decoder`` reconstruct the background by
      decoding ``latent_bg - latent_rain`` with skip features from the
      background encoder.
    * ``out_rain_encoder`` / ``out_bg_encoder`` are only used in training
      mode, to embed ``weakly_img``.

    Args:
        is_train (bool): if True, ``forward`` returns every output needed for
            training; otherwise it returns only the background image.
        inp_channels, out_channels, dim, num_blocks, num_refinement_blocks,
        heads, ffn_expansion_factor, bias, LayerNorm_type, dual_pixel_task:
            NOTE(review): accepted but currently NOT forwarded; the
            sub-networks are always built with their own defaults.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        is_train = False,
        dim = 48,
        num_blocks = [2,3,3,4],
        num_refinement_blocks = 2,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(DerainingFramework, self).__init__()
        self.is_train = is_train
        self.inp_rain_encoder = RainEncoder()
        self.inp_rain_decoder = RainDecoder()
        self.out_rain_encoder = RainEncoder()
        self.inp_bg_encoder = RainEncoder()
        self.inp_bg_decoder = BackgroundDecoder()
        self.out_bg_encoder = RainEncoder()

    def forward(self, inp_img, weakly_img=None):
        """Run the network.

        Args:
            inp_img (Tensor): rainy input batch.
            weakly_img (Tensor): weakly-labelled target batch. Required when
                ``is_train`` is True, ignored otherwise.

        Returns:
            Training (``is_train=True``): the tuple ``(inp_rain_fea,
            weakly_rain_fea, inp_bg_fea, weakly_bg_fea, re_rain_img, bg_img,
            rain_img)``, where ``re_rain_img = rain_img + bg_img``.
            Inference: the background image ``bg_img`` only.

        Raises:
            ValueError: if ``is_train`` is True and ``weakly_img`` is missing.
        """
        if self.is_train:
            if weakly_img is None or (isinstance(weakly_img, (list, tuple)) and len(weakly_img) == 0):
                raise ValueError('weakly_img is required when is_train=True.')
            _, weakly_rain_fea, _, _, _ = self.out_rain_encoder(weakly_img)
            _, weakly_bg_fea, _, _, _ = self.out_bg_encoder(weakly_img)

            latent_rain_inp, inp_rain_fea, _, _, _ = self.inp_rain_encoder(inp_img)
            rain_img = self.inp_rain_decoder(latent_rain_inp)
            bg_img, inp_bg_fea = self._decode_background(inp_img, latent_rain_inp)
            re_rain_img = rain_img + bg_img

            return inp_rain_fea, weakly_rain_fea, inp_bg_fea, weakly_bg_fea, re_rain_img, bg_img, rain_img

        # Inference mirrors the training path for the background. RainEncoder
        # returns 5 values, and the rain latent must be subtracted before
        # decoding exactly as during training. The rain image itself is not
        # needed here, so the rain decoder is skipped.
        latent_rain_inp = self.inp_rain_encoder(inp_img)[0]
        bg_img, _ = self._decode_background(inp_img, latent_rain_inp)
        return bg_img

    def _decode_background(self, inp_img, latent_rain):
        """Encode with the background branch, remove ``latent_rain``, decode; return ``(bg_img, bg_fea)``."""
        latent_bg, bg_fea, enc_level1, enc_level2, enc_level3 = self.inp_bg_encoder(inp_img)
        bg_img = self.inp_bg_decoder(latent_bg - latent_rain, enc_level1, enc_level2, enc_level3)
        return bg_img, bg_fea


In [None]:
def contrastiveloss(x1, x2, label, margin: float = 1.5):
    """Mean contrastive loss between paired feature batches.

    Per pair with Euclidean distance ``d`` (``pairwise_distance``):
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2``. Pairs labelled 0
    are pulled together; pairs labelled 1 are only penalized when closer
    than ``margin``. The mean over all pairs is returned.
    """
    distance = torch.nn.functional.pairwise_distance(x1, x2)
    pull = (1 - label) * distance.pow(2)
    push = label * torch.clamp(margin - distance, min=0.0).pow(2)
    return (pull + push).mean()

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss with a fixed pair label.

    Per pair with Euclidean distance ``d``:
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2``.

    Args:
        loss_weight (float): Loss weight for contrastive loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        label (int | float | Tensor): 0 pulls pairs together, 1 pushes them
            apart (beyond ``margin``). Default: 0.
        margin (float): margin for dissimilar pairs. Default: 1.5.

    Raises:
        ValueError: if ``reduction`` is not one of the supported choices.
    """

    _REDUCTION_MODES = ('none', 'mean', 'sum')

    def __init__(self, loss_weight=1.0, reduction='mean', label=0, margin=1.5):
        super(ContrastiveLoss, self).__init__()
        if reduction not in self._REDUCTION_MODES:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {list(self._REDUCTION_MODES)}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.label = label
        self.margin = margin

    def forward(self, fea1, fea2, **kwargs):
        """
        Args:
            fea1 (Tensor): first features of each pair, shape (N, D).
            fea2 (Tensor): second features of each pair, shape (N, D).

        Returns:
            Tensor: weighted loss; a scalar for 'mean'/'sum', shape (N,) for 'none'.
        """
        dist = torch.nn.functional.pairwise_distance(fea1, fea2)
        per_pair = ((1 - self.label) * dist.pow(2)
                    + self.label * torch.clamp(self.margin - dist, min=0.0).pow(2))
        if self.reduction == 'mean':
            loss = per_pair.mean()
        elif self.reduction == 'sum':
            loss = per_pair.sum()
        else:
            loss = per_pair
        return self.loss_weight * loss

In [None]:
def l1_loss(pred, target):
    """Element-wise absolute error ``|pred - target|`` (no reduction)."""
    return F.l1_loss(input=pred, target=target, reduction='none')

In [None]:
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: if ``reduction`` is not one of the supported choices.
    """

    _REDUCTION_MODES = ('none', 'mean', 'sum')

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in self._REDUCTION_MODES:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {list(self._REDUCTION_MODES)}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): element-wise weights broadcastable to
                the loss, e.g. (N, C, H, W) or (N, 1, H, W). With 'mean' the
                weighted sum is divided by the total (broadcast) weight.
                Default: None.
        """
        if weight is None:
            return self.loss_weight * F.l1_loss(pred, target, reduction=self.reduction)
        # F.l1_loss has no weight argument (its third positional parameter is
        # the deprecated `size_average`), so the weighting is applied here.
        loss = F.l1_loss(pred, target, reduction='none') * weight
        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.sum() / weight.expand_as(loss).sum()
        return self.loss_weight * loss

In [None]:
def mse_loss(pred, target):
    """Element-wise squared error ``(pred - target)**2`` (no reduction)."""
    return F.mse_loss(input=pred, target=target, reduction='none')

In [None]:
class MSELoss(nn.Module):
    """MSE (mean squared error, L2) loss.

    Args:
        loss_weight (float): Loss weight for L2 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: if ``reduction`` is not one of the supported choices.
    """

    _REDUCTION_MODES = ('none', 'mean', 'sum')

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in self._REDUCTION_MODES:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {list(self._REDUCTION_MODES)}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): element-wise weights broadcastable to
                the loss, e.g. (N, C, H, W) or (N, 1, H, W). With 'mean' the
                weighted sum is divided by the total (broadcast) weight.
                Default: None.
        """
        if weight is None:
            return self.loss_weight * F.mse_loss(pred, target, reduction=self.reduction)
        # F.mse_loss has no weight argument (its third positional parameter is
        # the deprecated `size_average`), so the weighting is applied here.
        loss = F.mse_loss(pred, target, reduction='none') * weight
        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.sum() / weight.expand_as(loss).sum()
        return self.loss_weight * loss

In [None]:
# class SparsityLoss(nn.Module):
#     def __init__(self, loss_weight=1.0, reduction='mean'):
#         super(SparsityLoss, self).__init__()
#         self.loss_weight = loss_weight
#         self.reduction = reduction

#     def forward(self, pred, weight=None, **kwargs):
#         return self.loss_weight * (torch.norm(pred, p=1, dim=(2, 3)).sum(dim=1))# + 0.2*(256*256*3-torch.count_nonzero(torch.clamp(pred-0.1, min=0, max=1), dim=(2,3)).sum(dim=1)))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import DatasetFolder
from torchvision.utils import save_image
import torchvision

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import DatasetFolder, ImageFolder
from torchvision.io import read_image
import torchvision.transforms.functional as FT

In [None]:
# NOTE(review): `image_size` is not referenced by any visible cell; the crop
# size used for training is set inside DerainingDataset.
image_size = 256

# Path to the folder holding both the rainy input images and the ground-truth images
root_dataset = '/kaggle/input/downscaleraindataset/RealTrafficRain/train/TrafficRain'

# No transform is passed here; DerainingDataset resizes, crops and converts
# the images to tensors itself.
dataset = ImageFolder(root=root_dataset)

In [None]:
# Inspect the label of sample 6800. DerainingDataset treats indices >= 6800 as
# the ground-truth half, so this presumably checks which ImageFolder class
# that half belongs to — verify.
dataset[6800][1]

In [None]:
import random
import numpy as np

In [None]:
class DerainingDataset(torch.utils.data.Dataset):
    """Paired (rainy input, ground truth) dataset on top of an ImageFolder.

    The wrapped dataset is expected to hold ``num_pairs`` rainy images at
    indices ``[0, num_pairs)`` followed by their ground-truth counterparts at
    ``[num_pairs, 2 * num_pairs)`` in the same order — TODO confirm against
    the folder layout. Each item is resized, then both images receive the
    same random crop so they stay pixel-aligned.

    Args:
        dataset: indexable dataset whose items are ``(image, label)`` pairs.
        num_pairs (int): number of input/ground-truth pairs. Default: 6800.
        resize_size (tuple[int, int]): size passed to ``transforms.Resize``.
            NOTE(review): torchvision reads a 2-tuple as (height, width), so
            the default produces 640x480 portrait images — confirm intent.
        crop_size (tuple[int, int]): random crop size. Default: (256, 256).
    """

    def __init__(self, dataset, num_pairs=6800, resize_size=(640, 480), crop_size=(256, 256)):
        self.dataset = dataset
        self.num_pairs = num_pairs
        self.resize_size = resize_size
        self.crop_size = crop_size

    def transform(self, image, mask):
        """Resize both images, apply one shared random crop, convert to tensors."""
        resize = transforms.Resize(size=self.resize_size)
        image = resize(image)
        mask = resize(mask)

        # The crop parameters are drawn once and applied to both images, so no
        # RNG re-seeding is needed. Re-seeding the global `random`/torch RNGs
        # from np.random.randint(42) on every sample limited the crops to 42
        # possible outcomes and clobbered global RNG state.
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=self.crop_size)
        image = FT.crop(image, i, j, h, w)
        mask = FT.crop(mask, i, j, h, w)

        return FT.to_tensor(image), FT.to_tensor(mask)

    def __getitem__(self, index):
        pair_index = index % self.num_pairs
        input_image = self.dataset[pair_index][0]
        ground_truth_image = self.dataset[self.num_pairs + pair_index][0]
        return self.transform(input_image, ground_truth_image)

    def __len__(self):
        return self.num_pairs

# Create the deraining dataset object.
# NOTE(review): this seeds only Python's `random` module; DataLoader shuffling
# is driven by torch's RNG, so this likely does not make the batch order
# reproducible — confirm.
random.seed(42)
deraining_dataset = DerainingDataset(dataset)

# Create the DataLoader for training (shuffled, 4 worker processes).
batch_size = 2
deraining_dataloader = DataLoader(deraining_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Sanity check: number of paired samples and number of batches per epoch.
print(f"Number of training examples: {len(deraining_dataset)}")
print(f"Number of batches in the dataloader: {len(deraining_dataloader)}")

In [None]:
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = DerainingFramework(is_train=False)
# # Wrap the model with DataParallel
# if torch.cuda.device_count() > 1:
#     print("Using", torch.cuda.device_count(), "GPUs!")
#     model = nn.DataParallel(model, device_ids=[0, 1])
# model.load_state_dict(torch.load('/kaggle/input/2004verlightblue/ver1804_derain2024_7.pth'))

# model = model.to(device)

In [None]:
# plt.imshow(dataset[0][0].resize((1280,720)))

In [None]:
# !mkdir /kaggle/working/outputimg
# !mkdir /kaggle/working/rainmapimg
# !mkdir /kaggle/working/inputimg

In [None]:
from PIL import Image

In [None]:
# Converts an image tensor back to a PIL image; used by the (commented-out)
# visualisation / export cells in this notebook.
topil = transforms.ToPILImage()
# idclear = 0
# idrain = 0
# idinput = 0
# for i, (input_images, weakly_grths) in enumerate(tqdm(deraining_dataloader)):
#         input_images = input_images.to(device)
#         weakly_grths = weakly_grths.to(device)
        
#         clear_imgs = model(input_images)
#         clear_imgs = clear_imgs
# #         print(clear_imgs.shape)
#         for clear_img in clear_imgs:
#             print(torch.min(clear_img))
#             print(torch.max(clear_img))
#             print(torch.count_nonzero(clear_img))
#             pil_clear_img = topil(clear_img.detach().cpu())
#             pil_clear_img.save(f'/kaggle/working/outputimg/{idclear}.png')
#             img = Image.open(f'/kaggle/working/outputimg/{idclear}.png')
#             plt.imshow(img)
#             plt.show()
#             idclear += 1
# #         for rainmap_img in clear_imgs:
# #             pil_rainmap_img = topil(rainmap_img)
# #             pil_rainmap_img.save(f'/kaggle/working/rainmapimg/{idrain}.png')
# #             idrain += 1
#         for input_img in input_images:
#             pil_input_img = topil(input_img)
#             pil_input_img.save(f'/kaggle/working/inputimg/{idinput}.png')
# #             img2 = Image.open(f'/kaggle/working/outputimg/{idinput}.png')
# #             plt.imshow(img2)
# #             plt.show()
#             idinput += 1
            
#         if idinput == 30:
#             break
            

In [None]:
# clear_img[0][0][:3,:128,:128].shape

In [None]:
# for i, (input_images, weakly_grths) in enumerate(tqdm(deraining_dataloader)):
#         input_images = input_images.to(device)
#         weakly_grths = weakly_grths.to(device)
#         print(torch.norm(input_images, p=1, dim=(2, 3)).sum(dim=1))
#         print(torch.count_nonzero(input_images, dim=(2,3)).sum(dim=1))

In [None]:
def split_image_into_patches(image_tensor, patch_size, stride):
    """
    Split a batch of images into (possibly overlapping) patches.

    Args:
        image_tensor (torch.Tensor): Input image tensor of shape (N, C, H, W).
        patch_size (int or tuple): Size of the patches, either a single integer or a tuple (patch_height, patch_width).
        stride (int or tuple): Stride of the sliding window, either a single integer or a tuple (vertical_stride, horizontal_stride).

    Returns:
        torch.Tensor: Patches of shape (N, num_patches, C, patch_height, patch_width),
        where num_patches = n_rows * n_cols and patches are ordered row-major
        (left to right, then top to bottom). Windows that would extend past the
        bottom/right border are dropped (``Tensor.unfold`` semantics).
    """
    # If patch_size or stride is an integer, use it for both dimensions.
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    patch_h, patch_w = patch_size
    stride_h, stride_w = stride
    batch, channels = image_tensor.shape[0], image_tensor.shape[1]

    # Slide over H (dim 2) and then W (dim 3):
    # (N, C, H, W) -> (N, C, n_rows, n_cols, patch_h, patch_w)
    patches = image_tensor.unfold(2, patch_h, stride_h).unfold(3, patch_w, stride_w)
    n_rows, n_cols = patches.shape[2], patches.shape[3]

    # Move channels next to the patch pixels and merge the grid axes:
    # -> (N, n_rows, n_cols, C, patch_h, patch_w) -> (N, n_rows * n_cols, C, patch_h, patch_w).
    # reshape (not view) because the permuted tensor is non-contiguous; the
    # explicit patch count keeps this valid for an empty batch.
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(batch, n_rows * n_cols, channels, patch_h, patch_w)

In [None]:
# for i, (input_images, weakly_grths) in enumerate(tqdm(deraining_dataloader)):
#     print(input_images[0].shape)
#     pil_inp = topil(input_images[0])
#     plt.imshow(pil_inp)
#     plt.show()
#     patch_outputs = split_image_into_patches(input_images, (64,64),(32,32)).detach().cpu()
#     print(patch_outputs.shape)
#     for img1, img2 in zip(patch_outputs[0], patch_outputs[1]):
#         pil_img1 = topil(img1)
#         pil_img2 = topil(img2)
#         plt.imshow(pil_img1)
#         plt.show()

#         plt.imshow(pil_img2)
#         plt.show()
#         break



#     break

In [None]:
import torch

def compute_l2_distance_per_patch_pair(patch_tensor1, patch_tensor2):
    """
    Compute the L2 (Euclidean) distance between corresponding patches of two patch sets.

    Patch ``k`` of image ``n`` in the first set is compared only with patch ``k``
    of image ``n`` in the second set (element-wise pairing, not all pairs).

    Args:
        patch_tensor1 (torch.Tensor): First set of patches, shape (N, num_patches, C, patch_height, patch_width).
        patch_tensor2 (torch.Tensor): Second set of patches, same shape (or broadcastable to it).
        Any number (>= 1) of per-patch dimensions after (N, num_patches) is accepted;
        all of them are reduced.

    Returns:
        torch.Tensor: L2 distances for each pair of patches, shape (N, num_patches).
    """
    # Element-wise squared difference (standard broadcasting if the shapes differ).
    squared_diff = (patch_tensor1 - patch_tensor2) ** 2

    # Collapse every per-patch dimension (C, H, W, ...) into one axis and sum it.
    sum_squared_diff = squared_diff.flatten(start_dim=2).sum(dim=2)

    # Square root of the sum of squares is the L2 distance.
    return torch.sqrt(sum_squared_diff)

# Quick smoke test on random data: one image, 49 pairs of 3x64x64 patches.
patch_tensor1 = torch.randn(1, 49, 3, 64, 64)  # first patch set
patch_tensor2 = torch.randn(1, 49, 3, 64, 64)  # second patch set
l2_distances = compute_l2_distance_per_patch_pair(patch_tensor1, patch_tensor2)

# A full slice does not change the shape, so report it directly.
print("L2 distances shape:", l2_distances.shape)

In [None]:
# l2 = l2_distances[:,:]

# indices = torch.topk(l2, dim=1, k=5).indices

In [None]:
# indices

In [None]:
# patch_tensor1[torch.arange(patch_tensor1.size(0)).unsqueeze(1), indices].shape

In [None]:
# print(patch_tensor1[0][indices[0]].shape)

In [None]:
# Fine-tuning cell: build the deraining model, resume it from a checkpoint, and
# train for `num_epochs` epochs, writing a checkpoint after every epoch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DerainingFramework(is_train=True)

# Wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model, device_ids=[0, 1])

# map_location lets a checkpoint written on GPU be restored on whatever device is
# available here; without it torch.load fails on a CPU-only host.
# NOTE(review): the checkpoint's key names must match the wrapping above — a
# state_dict saved from an nn.DataParallel model has "module."-prefixed keys, so
# loading it without the wrapper (single GPU / CPU) would mismatch; confirm.
model.load_state_dict(torch.load('/kaggle/input/2304verocean/wo_spar_ver1205_derain2024_7.pth', map_location=device))

model = model.to(device)

learning_rate = 1e-5
# Define the loss function and optimizers
sparsity_loss = L1Loss(loss_weight=0.1)
mse_loss = MSELoss()
w_l1_loss = MSELoss(loss_weight=0.1)  # only used by the commented-out weak patch loss below
cl_pos = ContrastiveLoss(label=0, loss_weight=1)
cl_neg = ContrastiveLoss(label=1, loss_weight=1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    # Running sums (not means) of each loss term over the epoch, printed at epoch end.
    training_loss = 0.0
    rain_loss = 0.0
    background_loss = 0.0
    con_loss = 0.0
    w_con_loss = 0.0  # stays 0.0 while the weak patch loss is commented out
    spar_loss = 0.0
    for i, (input_images, weakly_grths) in enumerate(tqdm(deraining_dataloader)):
        input_images = input_images.to(device)
#         cp_weakly_grths = weakly_grths
        weakly_grths = weakly_grths.to(device)
        
        optimizer.zero_grad()

        # The model takes (input, weak ground truth) and returns seven tensors.
        inp_rain_fea, weakly_rain_fea, inp_bg_fea, weakly_bg_fea, re_rain_img, bg_img, rain_img = model(input_images, weakly_grths)

#         patch_outputs = split_image_into_patches(bg_img, (64, 64), (32, 32)).detach().cpu()
#         patch_inputs = split_image_into_patches(input_images, (64, 64), (32, 32)).detach().cpu()
#         patch_grths = split_image_into_patches(weakly_grths, (64, 64), (32, 32)).detach().cpu()

#         l2_distances = compute_l2_distance_per_patch_pair(patch_inputs, patch_grths)
# #         l2_dis = torch.diagonal(l2_distances, dim1=1, dim2=2)
        
#         indices = torch.topk(l2_distances, dim=1, k=5, largest=False).indices
        
#         topk_output_patches = patch_outputs[torch.arange(patch_outputs.size(0)).unsqueeze(1), indices]
#         topk_input_patches = patch_inputs[torch.arange(patch_inputs.size(0)).unsqueeze(1), indices]
#         topk_grth_patches = patch_grths[torch.arange(patch_grths.size(0)).unsqueeze(1), indices]
#         topk_output_patches = topk_output_patches.to(device)
#         topk_input_patches = topk_input_patches.to(device)
#         topk_grth_patches = topk_grth_patches.to(device)
#         pil_img1 = topil(topk_input_patches[0][1])
#         pil_img2 = topil(topk_grth_patches[0][1])
#         print("k")
#         plt.imshow(pil_img1)
#         plt.show()
#         print("k")
#         plt.imshow(pil_img2)
#         plt.show()
#         break
        
        # Reconstruction term: MSE between the re-composed rainy image and the input.
        l_con = mse_loss(re_rain_img, input_images)
#         w_l_con = w_l1_loss(topk_output_patches, topk_grth_patches)
        l_rain = cl_neg(inp_rain_fea, inp_bg_fea) + 0.01*cl_neg(weakly_bg_fea, weakly_rain_fea)
        l_bg = cl_pos(inp_bg_fea, weakly_bg_fea)# + cl_neg(weakly_bg_fea, weakly_rain_fea).mean()
        # Computed for logging only; it is not part of `loss` below.
        l_sparsity = sparsity_loss(input_images, bg_img)

        loss = l_rain + l_bg + l_con# + 0.0*l_sparsity# + w_l_con
        training_loss += loss.item()
        rain_loss += l_rain.item()
        background_loss += l_bg.item()
        con_loss += l_con.item()
#         w_con_loss += w_l_con.item()
        spar_loss += l_sparsity.item()
#         print(f'sparsity loss: {spar_loss}, weak l1 loss: {w_con_loss}, rain loss: {rain_loss}, back loss: {background_loss}, mse loss: {con_loss}, sum loss: {training_loss}')


#         print(f'sparsity loss: {spar_loss}, rain loss: {rain_loss}, back loss: {background_loss}, l1 loss: {con_loss}, weak l1 loss: {w_con_loss}, sum loss: {training_loss}')


        loss.backward()
        optimizer.step()
#     break

    print(f'sparsity loss: {spar_loss}, weak l1 loss: {w_con_loss}, rain loss: {rain_loss}, back loss: {background_loss}, mse loss: {con_loss}, sum loss: {training_loss}')
    # Checkpoint after every epoch (keys keep any DataParallel "module." prefix).
    des = f'/kaggle/working/derain2024_{epoch}.pth'
    torch.save(model.state_dict(), des)    