<a href="https://colab.research.google.com/github/esraaelbaz/graduation_project/blob/main/O_HAZE_CANT_Haze_v10_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive

In [None]:
# Colab-only step: mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Dependencies

In [None]:
## To install pytorch 2.1.0 run this code block twice!
# try:
#   from torch import compile as t_compile
# except:
#   !pip install numpy --pre torch[dynamo] torchvision torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
#   import os
#   os.kill(os.getpid(), 9)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image as imwrite

import numpy as np
from torchvision import transforms
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms.functional as TF
import random
from math import exp, log10, ceil
import gc

try:
  import einops
  from einops import rearrange
except:
  !pip install einops
  import einops
  from einops import rearrange

import os
import gc

try:
  from torchmetrics import PeakSignalNoiseRatio as PSNR
  from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
except:
  !pip install torchmetrics
  from torchmetrics import PeakSignalNoiseRatio as PSNR
  from torchmetrics import StructuralSimilarityIndexMeasure as SSIM

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4


# Patchify

## Custom Size

In [None]:
def custom_patchify(frame_in, crops_size, overlap_size):
    """Cut a frame into a 2-D grid of overlapping crops.

    Args:
        frame_in: image object exposing ``.shape`` whose last two dimensions are
            (height, width); it is handed unchanged to ``TF.crop``.
        crops_size: ``[width, height]`` of a regular crop.
        overlap_size: ``[width, height]`` shared by two neighbouring crops.

    Returns:
        ``(patch_list, crops_per_row, crops_per_col)``. ``patch_list[i][j]`` is the crop
        at vertical position ``i`` and horizontal position ``j``. Despite its name,
        ``crops_per_row`` counts the crops stacked along the height, and
        ``crops_per_col`` counts the crops laid out along the width.
    """
    tile_h, tile_w = crops_size[1], crops_size[0]
    lap_h, lap_w = overlap_size[1], overlap_size[0]
    frame_h, frame_w = frame_in.shape[-2], frame_in.shape[-1]
    assert (tile_h >= 2 * lap_h) and (tile_w >= 2 * lap_w), "Crops size should be at least 2x greater than overlap size"

    # Distance between the origins of two consecutive crops.
    stride_h = tile_h - lap_h
    stride_w = tile_w - lap_w

    # Keep adding crops until the covered extent reaches the frame height.
    crops_per_row = 1
    covered_h = tile_h
    while covered_h < frame_h:
        covered_h += stride_h
        crops_per_row += 1

    # Same along the width.
    crops_per_col = 1
    covered_w = tile_w
    while covered_w < frame_w:
        covered_w += stride_w
        crops_per_col += 1

    # The last crop along each axis only gets whatever is left of the frame
    # after the origin of that last crop.
    tail_h = int(frame_h - (covered_h - tile_h))
    tail_w = int(frame_w - (covered_w - tile_w))
    last_h = tail_h if tail_h != 0 else tile_h
    last_w = tail_w if tail_w != 0 else tile_w

    patch_list = []
    top = 0
    for row_idx in range(crops_per_row):
        height = last_h if row_idx == crops_per_row - 1 else tile_h
        row = []
        left = 0
        for col_idx in range(crops_per_col):
            width = last_w if col_idx == crops_per_col - 1 else tile_w
            row.append(TF.crop(frame_in, top, left, height, width))  # top, left, height, width
            left += stride_w
        patch_list.append(row)
        top += stride_h

    return patch_list, crops_per_row, crops_per_col

def custom_unpatchify(patch_list, overlap_size, crops_per_row, crops_per_col):
    """Reassemble a frame from a grid of overlapping patches.

    Neighbouring patches are stitched together; inside every overlap band the two
    contributing patches are averaged.

    Args:
        patch_list: nested list where ``patch_list[i][j]`` is the patch at vertical
            position ``i`` and horizontal position ``j``. The last two dimensions of
            every patch are (height, width); any leading dimensions are kept as-is.
        overlap_size: ``[width, height]`` of the overlap between neighbouring patches.
        crops_per_row: number of patches stacked along the height.
        crops_per_col: number of patches laid out along the width.

    Returns:
        A single tensor containing the stitched frame.
    """
    overlap_h = overlap_size[1]
    overlap_w = overlap_size[0]
    # Where the trailing overlap band starts inside a regular (first) patch.
    end_w = patch_list[0][0].shape[-1] - overlap_w
    end_h = patch_list[0][0].shape[-2] - overlap_h

    rows = [
        _stitch_patch_row(patch_list[i], overlap_w, end_w, crops_per_col)
        for i in range(crops_per_row)
    ]

    pieces = []
    last_row = crops_per_row - 1
    for i, row in enumerate(rows):
        # Only rows that have a neighbour above drop their leading overlap band:
        # that band is emitted as the blend appended after the previous row.
        # (Previously a lone row was also trimmed, losing `overlap_h` rows.)
        top = 0 if i == 0 else overlap_h
        if i == last_row:
            pieces.append(row[..., top:, :])
        else:
            pieces.append(row[..., top:end_h, :])
            blended = (row[..., end_h:, :] + rows[i + 1][..., :overlap_h, :]) / 2
            pieces.append(blended)

    return torch.cat(pieces, -2)


def _stitch_patch_row(row_patches, overlap_w, end_w, crops_per_col):
    """Concatenate one row of patches along the width, averaging the overlap bands."""
    pieces = []
    last_col = crops_per_col - 1
    for j in range(crops_per_col):
        patch = row_patches[j]
        left = 0 if j == 0 else overlap_w
        if j == last_col:
            pieces.append(patch[..., left:])
        else:
            pieces.append(patch[..., left:end_w])
            blended = (patch[..., end_w:] + row_patches[j + 1][..., :overlap_w]) / 2
            pieces.append(blended)
    return torch.cat(pieces, -1)

# ----------- testing the function ----------- #
# crop_size = [384,384]
# overlap_size = [100,100]
# transform = transforms.ToTensor()
# img = Image.open(Path("/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/52_GT.png")).convert("RGB")
# img_t = transform(img).unsqueeze(0)
# print(img_t.shape)
# transform = transforms.ToPILImage()
# #transform(img_t.squeeze(0))

# patch_list, crops_per_row, crops_per_col= custom_patchify(img_t, crop_size, overlap_size)
# # print(crops_per_row, crops_per_col)
# frameout=custom_unpatchify(patch_list, overlap_size, crops_per_row, crops_per_col)
# print(frameout.shape)

# transform(frameout.squeeze(0))

## Grid

In [None]:
def grid_patchify(frame_in, crops_amount):
    """Split a frame into two columns of horizontal bands.

    The frame is cut once vertically at ``width // 2`` and ``crops_amount / 2`` times
    along the height. The last band absorbs any rows left over when the height is not
    divisible by the number of bands, and the right-hand column absorbs the extra pixel
    column of odd-width frames, so the patches always tile the whole frame.

    Args:
        frame_in: tensor whose last two dimensions are (height, width).
        crops_amount: total number of patches; must be at least 2 (a smaller value
            gives zero bands and raises ``ZeroDivisionError``).

    Returns:
        ``(patch_list_top, patch_list_bottom)``: despite the names, the first list holds
        the bands of the LEFT half of the frame and the second list the bands of the
        RIGHT half, both ordered top to bottom.
    """
    crops_per_row = int(crops_amount / 2)
    frame_h = frame_in.shape[-2]
    frame_w = frame_in.shape[-1]
    band_h = frame_h // crops_per_row
    half_w = frame_w // 2

    patch_list_top = []
    patch_list_bottom = []
    for i in range(crops_per_row):
        top = i * band_h
        # The final band runs to the bottom edge so leftover rows are not dropped.
        bottom = frame_h if i == crops_per_row - 1 else top + band_h
        patch_list_top.append(frame_in[..., top:bottom, :half_w])
        # Slice to the right edge (not `half_w` columns) so odd widths keep their last column.
        patch_list_bottom.append(frame_in[..., top:bottom, half_w:])

    return patch_list_top, patch_list_bottom

def unpatchify(patch_list_T, patch_list_B):
    """Rebuild a frame from the two patch lists produced by ``grid_patchify``.

    Each list is stacked along the height (dim -2); the two resulting columns are
    then placed side by side along the width (dim -1), first list on the left.
    """
    columns = [torch.cat(tuple(column), -2) for column in (patch_list_T, patch_list_B)]
    return torch.cat(tuple(columns), -1)

# patch_list_T, patch_list_B=patchify(hazy,4)
# frameout=unpatchify(patch_list_T,patch_list_B)
# frameout.shape

## Horizontal

In [None]:
def horizontal_patchify(img, desiredpatchsize=512):
    """Split an image into horizontal bands of ``desiredpatchsize`` rows.

    Bands are taken along dimension 1 of ``img``. When the height is not a multiple of
    ``desiredpatchsize``, one extra band is appended that is aligned with the bottom
    edge of the image and therefore overlaps the previous band.

    Args:
        img: 3-D array/tensor indexed as ``img[:, rows, :]``.
        desiredpatchsize: number of rows per band.

    Returns:
        ``(patches, oversize)`` where ``oversize`` is the number of rows of the final
        band that duplicate rows already present in the previous band (0 when there is
        no overlap). Presumably the caller drops these rows when reassembling - verify
        against the call site.
    """
    height = img.shape[1]
    full_patches = height // desiredpatchsize
    patches = [
        img[:, k * desiredpatchsize:(k + 1) * desiredpatchsize, :]
        for k in range(full_patches)
    ]

    oversize = 0
    if height % desiredpatchsize != 0:
        if full_patches == 0:
            # Image shorter than one band: the whole image is the only patch and nothing
            # overlaps. (Previously the start index went negative, yielding a truncated
            # patch together with a non-zero `oversize`.)
            patches.append(img[:, :, :])
        else:
            oversize = desiredpatchsize - (height % desiredpatchsize)
            patches.append(img[:, height - desiredpatchsize:height, :])
    return patches, oversize

# Data Loader

In [None]:
class CustomDataLoader(Dataset):
    """Dataset of (hazy, ground-truth) image pairs read from two folders.

    Files with the extensions png / jpg / JPG are collected from each folder (sorted
    within each extension) and paired by position, so both folders must list their
    images in the same order.
    """

    def __init__(self, HAZY_path = None, GT_path = None, image_size = (64,64), crop = False, resize = None):
        """
        Args:
            HAZY_path: folder containing the hazy images.
            GT_path: folder containing the ground-truth images.
            image_size: target size used by the ``crop`` / ``resize`` transforms.
            crop: if truthy, convert to tensor and apply ``TenCrop(image_size)``.
            resize: if truthy (and ``crop`` is falsy), resize to ``image_size`` before
                converting to tensor.
        """
        self.HAZY_path = Path(HAZY_path)
        self.GT_path = Path(GT_path)
        self.HAZY_Image = []
        self.GT_Image = []
        for extension in ['png', 'jpg', 'JPG']:
            self.HAZY_Image.extend(sorted(self.HAZY_path.glob('*.' + extension))) # list all the files present in HAZY images folder...
            self.GT_Image.extend(sorted(self.GT_path.glob('*.' + extension))) # list all the files present in GT images folder...
        self.HAZY_Image_Name = []

        assert len(self.HAZY_Image) == len(self.GT_Image), (
            f"Found {len(self.HAZY_Image)} hazy images but {len(self.GT_Image)} ground-truth images"
        )
        print(f"Dataset has: {len(self.HAZY_Image)} images")
        self.crop = crop
        self.resize = resize
        if self.crop:
            self.train_transforms = transforms.Compose([transforms.ToTensor(),
                                                        transforms.TenCrop(image_size)])
        elif self.resize:
            self.train_transforms = transforms.Compose([transforms.Resize(image_size),
                                                        transforms.ToTensor()])
        else:
            self.train_transforms = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _open_rgb(image_path):
        """Load an image as RGB and close the underlying file before returning."""
        # Image.open is lazy and keeps the file open; convert() forces the pixel data to
        # be read and returns an independent image, so the file can be closed right away.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def load_image(self, index: int, image_type = "HAZY") -> "Image.Image":
        """Open the hazy (``"HAZY"``) or ground-truth (``"GT"``) image at ``index``.

        The image is returned exactly as ``Image.open`` provides it (not converted).

        Raises:
            ValueError: if ``image_type`` is neither ``"HAZY"`` nor ``"GT"``.
        """
        if image_type == "HAZY":
            image_path = self.HAZY_Image[index]
        elif image_type == "GT":
            image_path = self.GT_Image[index]
        else:
            # Previously fell through and failed with UnboundLocalError on `image_path`.
            raise ValueError(f"image_type must be 'HAZY' or 'GT', got {image_type!r}")

        return Image.open(image_path)

    def __len__(self):
        """Number of image pairs."""
        return len(self.HAZY_Image)

    def __getitem__(self, index):
        """Return ``(hazy, ground_truth, stem)`` with both images passed through the transforms."""
        hazy = self._open_rgb(self.HAZY_Image[index])
        ground_truth = self._open_rgb(self.GT_Image[index])
        return self.train_transforms(hazy), self.train_transforms(ground_truth), self.HAZY_Image[index].stem

#data augmentation for image rotate
def custom_augment(hazy, clean):
    """Apply one randomly selected augmentation identically to a hazy/clean pair.

    Six equally likely outcomes are drawn from: rotation by 90/180/270 degrees,
    vertical flip, horizontal flip, and (three outcomes) no change. The rotation angle
    is always drawn, even when the selected outcome does not use it.

    Returns:
        The ``(hazy, clean)`` pair after augmentation.
    """
    selected = random.choice([0, 1, 2, 3, 4, 5])
    angle = random.choice([90, 180, 270])

    if selected == 0:
        # Rotate both images by the same angle.
        rotated_hazy = transforms.functional.rotate(hazy, angle)
        rotated_clean = transforms.functional.rotate(clean, angle)
        return rotated_hazy, rotated_clean

    if selected == 1 or selected == 2:
        # p=1 makes the "random" flip unconditional.
        if selected == 1:
            flip = torchvision.transforms.RandomVerticalFlip(p=1)
        else:
            flip = torchvision.transforms.RandomHorizontalFlip(p=1)
        flipped_hazy = flip(hazy)
        flipped_clean = flip(clean)
        return flipped_hazy, flipped_clean

    # Outcomes 3, 4 and 5: leave the pair untouched.
    return hazy, clean

class custom_dehaze_train_dataset(Dataset):
    """Dataset of (hazy, clean) tensor pairs with optional random cropping and augmentation.

    Files with the extensions png / jpg / JPG are collected from each folder (sorted
    within each extension) and paired by position.
    NOTE(review): unlike ``CustomDataLoader``, the two folders are not checked to hold
    the same number of images - confirm whether that is intentional.
    """

    def __init__(self, HAZY_path = None, GT_path = None, Image_Size = (256,256), is_train = True, random_crops = False, random_crop_sizes = [256, 512, 768]):
        """
        Args:
            HAZY_path: folder containing the hazy images.
            GT_path: folder containing the clean (ground-truth) images.
            Image_Size: crop size used in training when ``random_crops`` is falsy.
            is_train: if truthy, crop and augment; otherwise return the full images.
            random_crops: if truthy, draw the crop width and height independently
                from ``random_crop_sizes`` for every sample.
            random_crop_sizes: candidate crop edge lengths.
        """
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.HAZY_path = Path(HAZY_path)
        self.GT_path = Path(GT_path)
        self.HAZY_Image = []
        self.GT_Image = []
        for extension in ['png', 'jpg', 'JPG']:
            self.HAZY_Image.extend(sorted(self.HAZY_path.glob('*.' + extension))) # list all the files present in HAZY images folder...
            self.GT_Image.extend(sorted(self.GT_path.glob('*.' + extension))) # list all the files present in GT images folder...
        self.Image_Size = Image_Size
        self.is_train = is_train
        self.random_crops = random_crops
        self.random_crop_sizes = random_crop_sizes
        print(f"Dataset has: {len(self.HAZY_Image)} images")

    @staticmethod
    def _open_rgb(image_path):
        """Load an image as RGB and close the underlying file before returning."""
        # Image.open is lazy and keeps the file open; convert() reads the pixel data and
        # returns an independent image, so the file handle can be released immediately.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return the ``(hazy, clean)`` tensors for ``index``.

        In training mode the same random crop is taken from both images and the pair is
        passed through ``custom_augment`` before conversion to tensors.
        """
        hazy = self._open_rgb(self.HAZY_Image[index])
        clean = self._open_rgb(self.GT_Image[index])

        if not self.is_train:
            hazy = self.transform(hazy)
            clean = self.transform(clean)
            return hazy, clean

        # Pick the crop window once so both images are cropped identically.
        if self.random_crops:
            width = random.choice(self.random_crop_sizes)
            height = random.choice(self.random_crop_sizes)
            output_size = (height, width)
        else:
            output_size = self.Image_Size
        i, j, h, w = transforms.RandomCrop.get_params(hazy, output_size = output_size)
        hazy_crop = TF.crop(hazy, i, j, h, w)
        clean_crop = TF.crop(clean, i, j, h, w)

        # Data augmentation, applied jointly to the pair.
        hazy_aug, clean_aug = custom_augment(hazy_crop, clean_crop)
        return self.transform(hazy_aug), self.transform(clean_aug)

    def __len__(self):
        """Number of hazy images."""
        return len(self.HAZY_Image)

class dehaze_test_dataset(Dataset):
    """Test-time dataset of hazy images with optional ground truth and optional patching.

    Files with the extensions png / jpg / JPG are collected from each folder (sorted
    within each extension). An empty ground-truth folder is allowed; in that case no
    clean image is returned.
    """

    def __init__(self, HAZY_PATH = None, GT_PATH = None, Patchify = False):
        """
        Args:
            HAZY_PATH: folder containing the hazy images.
            GT_PATH: folder containing the ground-truth images (may hold no images).
            Patchify: if truthy, each hazy image is split with ``horizontal_patchify``.
        """
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.root_hazy = Path(HAZY_PATH)
        self.root_GT = Path(GT_PATH)
        self.patchify = Patchify
        self.list_test = []
        self.list_GT = []
        for extension in ['png', 'jpg', 'JPG']:
            self.list_test.extend(sorted(self.root_hazy.glob('*.' + extension))) # list all the files present in HAZY images folder...
            self.list_GT.extend(sorted(self.root_GT.glob('*.' + extension))) # list all the files present in GT images folder...
        self.file_len = len(self.list_test)

    @staticmethod
    def _open_rgb(image_path):
        """Load an image as RGB and close the underlying file before returning."""
        # Image.open is lazy and keeps the file open; convert() reads the pixel data and
        # returns an independent image, so the file handle can be released immediately.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index, is_train=True):
        """Return the sample at ``index``; ``is_train`` is accepted but unused.

        The shape of the returned tuple depends on the configuration:
            * no GT, patchify:      ``(patches, oversize, name)``
            * no GT, no patchify:   ``(hazy, name)``
            * GT, patchify:         ``(patches, oversize, clean, name)``
            * GT, no patchify:      ``(hazy, clean, name)``
        """
        hazy = self.transform(self._open_rgb(self.list_test[index]))
        if self.patchify:
            hazy, oversize = horizontal_patchify(hazy)
        name = self.list_test[index].stem

        if len(self.list_GT) == 0:
            if self.patchify:
                return hazy, oversize, name
            return hazy, name

        clean = self.transform(self._open_rgb(self.list_GT[index]))
        if self.patchify:
            return hazy, oversize, clean, name
        return hazy, clean, name

    def __len__(self):
        """Number of hazy images."""
        return self.file_len

# DWT

## DWT functions

In [None]:
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.

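"""
Custom PyTorch functions implementing the DWT and IDWT for 1D, 2D, and 3D tensors; boundary extension is not considered.
Exact reconstruction is only possible when the image has an even number of rows and columns and the low-frequency
filter of the reconstruction filter bank has length 2; otherwise there are errors at the boundaries.
"""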
import torch
from torch.autograd import Function


class DWTFunction_1D(Function):
    """Autograd function for a 1D DWT written as two matrix products."""

    @staticmethod
    def forward(ctx, input, matrix_Low, matrix_High):
        """Multiply ``input`` by the transposed low/high matrices along its last dimension.

        Returns the pair ``(low, high)``.
        """
        ctx.save_for_backward(matrix_Low, matrix_High)
        low_band = input @ matrix_Low.t()
        high_band = input @ matrix_High.t()
        return low_band, high_band

    @staticmethod
    def backward(ctx, grad_L, grad_H):
        """Gradient w.r.t. ``input``; the two matrices receive no gradient."""
        low_matrix, high_matrix = ctx.saved_tensors
        grad_input = grad_L @ low_matrix + grad_H @ high_matrix
        return grad_input, None, None


class IDWTFunction_1D(Function):
    """Autograd function for a 1D inverse DWT written as two matrix products."""

    @staticmethod
    def forward(ctx, input_L, input_H, matrix_L, matrix_H):
        """Return ``input_L @ matrix_L + input_H @ matrix_H``."""
        ctx.save_for_backward(matrix_L, matrix_H)
        from_low = input_L @ matrix_L
        from_high = input_H @ matrix_H
        return from_low + from_high

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients w.r.t. the two input bands; the matrices receive no gradient."""
        low_matrix, high_matrix = ctx.saved_tensors
        return grad_output @ low_matrix.t(), grad_output @ high_matrix.t(), None, None


class DWTFunction_2D(Function):
    """Autograd function for a 2D DWT returning the four sub-bands LL, LH, HL, HH."""

    @staticmethod
    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Left-multiply by the ``*_0`` matrices, then right-multiply by the ``*_1`` matrices."""
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1,
                              matrix_High_0, matrix_High_1)
        low_rows = matrix_Low_0 @ input
        high_rows = matrix_High_0 @ input
        return (low_rows @ matrix_Low_1,
                low_rows @ matrix_High_1,
                high_rows @ matrix_Low_1,
                high_rows @ matrix_High_1)

    @staticmethod
    def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH):
        """Gradient w.r.t. ``input``; the four matrices receive no gradient."""
        low_0, low_1, high_0, high_1 = ctx.saved_tensors
        # Undo the right-multiplications first, then the left-multiplications.
        grad_low_rows = grad_LL @ low_1.t() + grad_LH @ high_1.t()
        grad_high_rows = grad_HL @ low_1.t() + grad_HH @ high_1.t()
        grad_input = low_0.t() @ grad_low_rows + high_0.t() @ grad_high_rows
        return grad_input, None, None, None, None


class DWTFunction_2D_tiny(Function):
    """Autograd function for a 2D DWT that only produces the LL sub-band."""

    @staticmethod
    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Return ``matrix_Low_0 @ input @ matrix_Low_1``; the high matrices are only saved."""
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1,
                              matrix_High_0, matrix_High_1)
        low_rows = matrix_Low_0 @ input
        return low_rows @ matrix_Low_1

    @staticmethod
    def backward(ctx, grad_LL):
        """Gradient w.r.t. ``input``; the four matrices receive no gradient."""
        low_0, low_1, _high_0, _high_1 = ctx.saved_tensors
        grad_low_rows = grad_LL @ low_1.t()
        return low_0.t() @ grad_low_rows, None, None, None, None


class IDWTFunction_2D(Function):
    """Autograd function for a 2D inverse DWT that merges LL, LH, HL, HH into one tensor."""

    @staticmethod
    def forward(ctx, input_LL, input_LH, input_HL, input_HH,
                matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):
        """Right-multiply by the transposed ``*_1`` matrices, then left-multiply by the transposed ``*_0`` matrices."""
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1,
                              matrix_High_0, matrix_High_1)
        low_rows = input_LL @ matrix_Low_1.t() + input_LH @ matrix_High_1.t()
        high_rows = input_HL @ matrix_Low_1.t() + input_HH @ matrix_High_1.t()
        return matrix_Low_0.t() @ low_rows + matrix_High_0.t() @ high_rows

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients w.r.t. the four sub-bands; the four matrices receive no gradient."""
        low_0, low_1, high_0, high_1 = ctx.saved_tensors
        grad_low_rows = low_0 @ grad_output
        grad_high_rows = high_0 @ grad_output
        band_grads = (grad_low_rows @ low_1,
                      grad_low_rows @ high_1,
                      grad_high_rows @ low_1,
                      grad_high_rows @ high_1)
        return band_grads + (None, None, None, None)


class DWTFunction_3D(Function):
    """Autograd function for a 3D DWT returning the eight sub-bands LLL ... HHH."""

    @staticmethod
    def forward(ctx, input,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """Filter the last two dimensions with the ``*_0`` / ``*_1`` matrices, then dimension 2 with ``*_2``.

        Dimensions 2 and 3 are swapped around the ``*_2`` products so that the matrix
        multiplication acts on dimension 2 of the original layout.

        Returns:
            ``(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)``.
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)
        L = torch.matmul(matrix_Low_0, input)
        H = torch.matmul(matrix_High_0, input)
        # Four 2D sub-bands, with dims 2 and 3 swapped ready for the third filter.
        planes = [
            torch.matmul(rows, cols).transpose(2, 3)
            for rows in (L, H)
            for cols in (matrix_Low_1, matrix_High_1)
        ]  # order: LL, LH, HL, HH
        low_depth = [torch.matmul(matrix_Low_2, plane).transpose(2, 3) for plane in planes]
        high_depth = [torch.matmul(matrix_High_2, plane).transpose(2, 3) for plane in planes]
        return tuple(low_depth + high_depth)

    @staticmethod
    def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,
                 grad_HLL, grad_HLH, grad_HHL, grad_HHH):
        """Gradient w.r.t. ``input``; the six matrices receive no gradient."""
        matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_tensors

        def depth_adjoint(grad_low, grad_high):
            # Adjoint of the ``*_2`` filtering step for one 2D sub-band.
            merged = torch.add(torch.matmul(matrix_Low_2.t(), grad_low.transpose(2, 3)),
                               torch.matmul(matrix_High_2.t(), grad_high.transpose(2, 3)))
            return merged.transpose(2, 3)

        grad_LL = depth_adjoint(grad_LLL, grad_HLL)
        grad_LH = depth_adjoint(grad_LLH, grad_HLH)
        grad_HL = depth_adjoint(grad_LHL, grad_HHL)
        grad_HH = depth_adjoint(grad_LHH, grad_HHH)
        grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()),
                           torch.matmul(grad_LH, matrix_High_1.t()))
        grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()),
                           torch.matmul(grad_HH, matrix_High_1.t()))
        grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L),
                               torch.matmul(matrix_High_0.t(), grad_H))
        # One gradient per forward argument: ``input`` plus six matrices.
        # (The previous version returned two surplus ``None`` values.)
        return grad_input, None, None, None, None, None, None


class IDWTFunction_3D(Function):
    """Autograd function for a 3D inverse DWT that merges the eight sub-bands into one tensor."""

    @staticmethod
    def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,
                input_HLL, input_HLH, input_HHL, input_HHH,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """Undo the ``*_2`` filtering on dimension 2 first, then the ``*_1`` and ``*_0`` filtering."""
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)

        def merge_depth(low_part, high_part):
            # Dims 2 and 3 are swapped so the product acts on dim 2 of the original layout.
            merged = torch.matmul(matrix_Low_2.t(), low_part.transpose(2, 3)) + \
                torch.matmul(matrix_High_2.t(), high_part.transpose(2, 3))
            return merged.transpose(2, 3)

        plane_LL = merge_depth(input_LLL, input_HLL)
        plane_LH = merge_depth(input_LLH, input_HLH)
        plane_HL = merge_depth(input_LHL, input_HHL)
        plane_HH = merge_depth(input_LHH, input_HHH)
        low_rows = plane_LL @ matrix_Low_1.t() + plane_LH @ matrix_High_1.t()
        high_rows = plane_HL @ matrix_Low_1.t() + plane_HH @ matrix_High_1.t()
        return matrix_Low_0.t() @ low_rows + matrix_High_0.t() @ high_rows

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients w.r.t. the eight sub-bands; the six matrices receive no gradient."""
        low_0, low_1, low_2, high_0, high_1, high_2 = ctx.saved_tensors
        grad_low_rows = low_0 @ grad_output
        grad_high_rows = high_0 @ grad_output
        # Order: LL, LH, HL, HH - each with dims 2 and 3 swapped for the ``*_2`` product.
        plane_grads = [
            (rows @ cols).transpose(2, 3)
            for rows in (grad_low_rows, grad_high_rows)
            for cols in (low_1, high_1)
        ]
        low_depth = tuple((low_2 @ plane).transpose(2, 3) for plane in plane_grads)
        high_depth = tuple((high_2 @ plane).transpose(2, 3) for plane in plane_grads)
        return low_depth + high_depth + (None,) * 6

## DWT Layers

In [None]:
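"""
Custom PyTorch layers implementing the DWT and IDWT for 1D, 2D, and 3D tensors; boundary extension is not considered.
Exact reconstruction is only possible when the image has an even number of rows and columns and the low-frequency
filter of the reconstruction filter bank has length 2; otherwise there are errors at the boundaries.
"""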
import math
# import wave

import numpy as np
import pywt
import torch
from torch.nn import Module

# from .DWT_IDWT_Functions import DWTFunction_1D, IDWTFunction_1D, \
#     DWTFunction_2D_tiny, DWTFunction_2D, IDWTFunction_2D, \
#     DWTFunction_3D, IDWTFunction_3D


__all__ = ['DWT_1D', 'IDWT_1D', 'DWT_2D',
           'IDWT_2D', 'DWT_3D', 'IDWT_3D', 'DWT_2D_tiny']


class DWT_1D(Module):
    """
    1D discrete wavelet transform layer.

    input: the 1D data to be decomposed -- (N, C, Length)
    output: lfc -- (N, C, Length/2)
            hfc -- (N, C, Length/2)
    """

    def __init__(self, wavename):
        """
        1D discrete wavelet transform (DWT) for sequence decomposition.

        :param wavename: a name from pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(DWT_1D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.rec_lo
        self.band_high = wavelet.rec_hi
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    def get_matrix(self, device=None, dtype=None):
        """
        Build the transform matrices for the current ``self.input_height``.

        Each row of a matrix holds the filter taps, shifted by two columns per row; the
        columns that stick out on either side are then trimmed (no boundary extension).

        :param device: device for the matrices; defaults to CUDA when available, else CPU.
        :param dtype: dtype for the matrices; defaults to torch's default dtype.
        :return: None; sets self.matrix_low (low-pass) and self.matrix_high (high-pass).
        """
        L1 = self.input_height
        L = math.floor(L1 / 2)
        matrix_h = np.zeros((L, L1 + self.band_length - 2))
        matrix_g = np.zeros((L1 - L, L1 + self.band_length - 2))
        end = None if self.band_length_half == 1 else (
            - self.band_length_half + 1)
        index = 0
        for i in range(L):
            for j in range(self.band_length):
                matrix_h[i, index + j] = self.band_low[j]
            index += 2
        index = 0
        for i in range(L1 - L):
            for j in range(self.band_length):
                matrix_g[i, index + j] = self.band_high[j]
            index += 2
        matrix_h = matrix_h[:, (self.band_length_half - 1):end]
        matrix_g = matrix_g[:, (self.band_length_half - 1):end]
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if dtype is None:
            dtype = torch.get_default_dtype()
        self.matrix_low = torch.tensor(matrix_h, dtype=dtype, device=device)
        self.matrix_high = torch.tensor(matrix_g, dtype=dtype, device=device)

    def forward(self, input):
        """
        Decompose ``input`` into its low- and high-frequency components.

        :param input: the data to be decomposed, a 3-D tensor (N, C, Length)
        :return: the low-frequency and high-frequency components of the input data
        """
        assert len(input.size()) == 3
        self.input_height = input.size()[-1]
        # Build the matrices on the input's own device (and floating dtype) so the layer
        # also works for CPU tensors on a CUDA machine and for non-default float types.
        matrix_dtype = input.dtype if input.is_floating_point() else None
        self.get_matrix(device=input.device, dtype=matrix_dtype)
        return DWTFunction_1D.apply(input, self.matrix_low, self.matrix_high)


class IDWT_1D(Module):
    """
    1D inverse discrete wavelet transform layer.

    input:  lfc -- (N, C, Length/2)
            hfc -- (N, C, Length/2)
    output: the original data -- (N, C, Length)
    """

    def __init__(self, wavename):
        """
        1D inverse DWT (IDWT) for sequence reconstruction.

        :param wavename: a name from pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(IDWT_1D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.dec_lo
        self.band_high = wavelet.dec_hi
        # The decomposition filters are used in reversed order.
        self.band_low.reverse()
        self.band_high.reverse()
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    def get_matrix(self, device=None, dtype=None):
        """
        Build the transform matrices for the current ``self.input_height``.

        Each row of a matrix holds the filter taps, shifted by two columns per row; the
        columns that stick out on either side are then trimmed (no boundary extension).

        :param device: device for the matrices; defaults to CUDA when available, else CPU.
        :param dtype: dtype for the matrices; defaults to torch's default dtype.
        :return: None; sets self.matrix_low (low-pass) and self.matrix_high (high-pass).
        """
        L1 = self.input_height
        L = math.floor(L1 / 2)
        matrix_h = np.zeros((L, L1 + self.band_length - 2))
        matrix_g = np.zeros((L1 - L, L1 + self.band_length - 2))
        end = None if self.band_length_half == 1 else (
            - self.band_length_half + 1)
        index = 0
        for i in range(L):
            for j in range(self.band_length):
                matrix_h[i, index + j] = self.band_low[j]
            index += 2
        index = 0
        for i in range(L1 - L):
            for j in range(self.band_length):
                matrix_g[i, index + j] = self.band_high[j]
            index += 2
        matrix_h = matrix_h[:, (self.band_length_half - 1):end]
        matrix_g = matrix_g[:, (self.band_length_half - 1):end]
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if dtype is None:
            dtype = torch.get_default_dtype()
        self.matrix_low = torch.tensor(matrix_h, dtype=dtype, device=device)
        self.matrix_high = torch.tensor(matrix_g, dtype=dtype, device=device)

    def forward(self, L, H):
        """
        Reconstruct the data from its two frequency components.

        :param L: the low-frequency component of the original data, a 3-D tensor
        :param H: the high-frequency component of the original data, a 3-D tensor
        :return: the original data
        """
        assert len(L.size()) == len(H.size()) == 3
        self.input_height = L.size()[-1] + H.size()[-1]
        # Build the matrices on the input's own device (and floating dtype) so the layer
        # also works for CPU tensors on a CUDA machine and for non-default float types.
        matrix_dtype = L.dtype if L.is_floating_point() else None
        self.get_matrix(device=L.device, dtype=matrix_dtype)
        return IDWTFunction_1D.apply(L, H, self.matrix_low, self.matrix_high)


class DWT_2D_tiny(Module):
    """
    input: the 2D data to be decomposed -- (N, C, H, W)
    output -- lfc: (N, C, H/2, W/2)
    DWT_2D_tiny only outputs the low-frequency component, which is used in WaveCNet;
    all four components (lfc, hfc_lh, hfc_hl, hfc_hh) can be obtained with DWT_2D,
    which is used in WaveUNet.
    """

    def __init__(self, wavename):
        """
        2D discrete wavelet transform (DWT) for 2D image decomposition
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(DWT_2D_tiny, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.rec_lo
        self.band_high = wavelet.rec_hi
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    @staticmethod
    def _banded_matrix(band, n_rows, n_cols):
        """Return an (n_rows, n_cols) array whose row i holds ``band`` starting at column 2 * i."""
        matrix = np.zeros((n_rows, n_cols))
        rows = np.arange(n_rows)[:, None]
        # Fancy indexing fills all rows at once and still raises IndexError when a
        # band would run past the last column.
        matrix[rows, 2 * rows + np.arange(len(band))] = band
        return matrix

    def get_matrix(self, device=None):
        r"""
        Generate the transform matrices \mathcal{L} and \mathcal{H} for the current
        ``self.input_height`` and ``self.input_width``.

        :param device: device on which to place the matrices; when None, CUDA is
                       used if it is available and the CPU otherwise
        :return: None; sets ``self.matrix_low_0`` / ``self.matrix_high_0`` (height
                 axis) and ``self.matrix_low_1`` / ``self.matrix_high_1`` (width axis,
                 stored transposed)
        """
        height, width = self.input_height, self.input_width
        longest = max(height, width)
        low_rows = longest // 2
        padded_cols = longest + self.band_length - 2
        full_low = self._banded_matrix(self.band_low, low_rows, padded_cols)
        full_high = self._banded_matrix(self.band_high, longest - low_rows, padded_cols)

        # Number of boundary columns removed on each side; none for length-2 bands.
        trim = self.band_length_half - 1
        stop = None if self.band_length_half == 1 else -trim

        def crop(full, n_rows, size):
            # Select the part belonging to an axis of length `size`, then drop the
            # boundary columns so that `size` columns remain.
            return full[:n_rows, :size + self.band_length - 2][:, trim:stop]

        low_0 = crop(full_low, height // 2, height)
        low_1 = np.transpose(crop(full_low, width // 2, width))
        high_0 = crop(full_high, height - height // 2, height)
        high_1 = np.transpose(crop(full_high, width - width // 2, width))

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low_0 = torch.Tensor(low_0).to(device)
        self.matrix_low_1 = torch.Tensor(low_1).to(device)
        self.matrix_high_0 = torch.Tensor(high_0).to(device)
        self.matrix_high_1 = torch.Tensor(high_1).to(device)

    def forward(self, input):
        r"""
        input_lfc = \mathcal{L} * input * \mathcal{L}^T
        (the high-frequency products with \mathcal{H} are not returned by this module)
        :param input: the 2D data to be decomposed, a 4-dimensional tensor
        :return: the low-frequency component of the input 2D data
        """
        assert len(input.size()) == 4
        self.input_height = input.size()[-2]
        self.input_width = input.size()[-1]
        # Build the matrices on the input's device so CPU inputs work on CUDA machines.
        self.get_matrix(device=input.device)
        return DWTFunction_2D_tiny.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)


class DWT_2D(Module):
    """
    input: the 2D data to be decomposed -- (N, C, H, W)
    output -- lfc: (N, C, H/2, W/2)
              hfc_lh: (N, C, H/2, W/2)
              hfc_hl: (N, C, H/2, W/2)
              hfc_hh: (N, C, H/2, W/2)
    """

    def __init__(self, wavename):
        """
        2D discrete wavelet transform (DWT) for 2D image decomposition
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(DWT_2D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.rec_lo
        self.band_high = wavelet.rec_hi
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    @staticmethod
    def _banded_matrix(band, n_rows, n_cols):
        """Return an (n_rows, n_cols) array whose row i holds ``band`` starting at column 2 * i."""
        matrix = np.zeros((n_rows, n_cols))
        rows = np.arange(n_rows)[:, None]
        # Fancy indexing fills all rows at once and still raises IndexError when a
        # band would run past the last column.
        matrix[rows, 2 * rows + np.arange(len(band))] = band
        return matrix

    def get_matrix(self, device=None):
        r"""
        Generate the transform matrices \mathcal{L} and \mathcal{H} for the current
        ``self.input_height`` and ``self.input_width``.

        :param device: device on which to place the matrices; when None, CUDA is
                       used if it is available and the CPU otherwise
        :return: None; sets ``self.matrix_low_0`` / ``self.matrix_high_0`` (height
                 axis) and ``self.matrix_low_1`` / ``self.matrix_high_1`` (width axis,
                 stored transposed)
        """
        height, width = self.input_height, self.input_width
        longest = max(height, width)
        low_rows = longest // 2
        padded_cols = longest + self.band_length - 2
        full_low = self._banded_matrix(self.band_low, low_rows, padded_cols)
        full_high = self._banded_matrix(self.band_high, longest - low_rows, padded_cols)

        # Number of boundary columns removed on each side; none for length-2 bands.
        trim = self.band_length_half - 1
        stop = None if self.band_length_half == 1 else -trim

        def crop(full, n_rows, size):
            # Select the part belonging to an axis of length `size`, then drop the
            # boundary columns so that `size` columns remain.
            return full[:n_rows, :size + self.band_length - 2][:, trim:stop]

        low_0 = crop(full_low, height // 2, height)
        low_1 = np.transpose(crop(full_low, width // 2, width))
        high_0 = crop(full_high, height - height // 2, height)
        high_1 = np.transpose(crop(full_high, width - width // 2, width))

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low_0 = torch.Tensor(low_0).to(device)
        self.matrix_low_1 = torch.Tensor(low_1).to(device)
        self.matrix_high_0 = torch.Tensor(high_0).to(device)
        self.matrix_high_1 = torch.Tensor(high_1).to(device)

    def forward(self, input):
        r"""
        input_lfc = \mathcal{L} * input * \mathcal{L}^T
        input_hfc_lh = \mathcal{H} * input * \mathcal{L}^T
        input_hfc_hl = \mathcal{L} * input * \mathcal{H}^T
        input_hfc_hh = \mathcal{H} * input * \mathcal{H}^T
        :param input: the 2D data to be decomposed, a 4-dimensional tensor
        :return: the low-frequency and high-frequency components of the input 2D data
        """
        assert len(input.size()) == 4
        self.input_height = input.size()[-2]
        self.input_width = input.size()[-1]
        # Build the matrices on the input's device so CPU inputs work on CUDA machines.
        self.get_matrix(device=input.device)
        return DWTFunction_2D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)


class IDWT_2D(Module):
    """
    input:  lfc -- (N, C, H/2, W/2)
            hfc_lh -- (N, C, H/2, W/2)
            hfc_hl -- (N, C, H/2, W/2)
            hfc_hh -- (N, C, H/2, W/2)
    output: the original 2D data -- (N, C, H, W)
    """

    def __init__(self, wavename):
        """
        2D inverse DWT (IDWT) for 2D image reconstruction
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(IDWT_2D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        # The inverse transform uses the decomposition filters in reversed order.
        self.band_low = list(reversed(wavelet.dec_lo))
        self.band_high = list(reversed(wavelet.dec_hi))
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    @staticmethod
    def _banded_matrix(band, n_rows, n_cols):
        """Return an (n_rows, n_cols) array whose row i holds ``band`` starting at column 2 * i."""
        matrix = np.zeros((n_rows, n_cols))
        rows = np.arange(n_rows)[:, None]
        # Fancy indexing fills all rows at once and still raises IndexError when a
        # band would run past the last column.
        matrix[rows, 2 * rows + np.arange(len(band))] = band
        return matrix

    def get_matrix(self, device=None):
        r"""
        Generate the transform matrices \mathcal{L} and \mathcal{H} for the current
        ``self.input_height`` and ``self.input_width``.

        :param device: device on which to place the matrices; when None, CUDA is
                       used if it is available and the CPU otherwise
        :return: None; sets ``self.matrix_low_0`` / ``self.matrix_high_0`` (height
                 axis) and ``self.matrix_low_1`` / ``self.matrix_high_1`` (width axis,
                 stored transposed)
        """
        height, width = self.input_height, self.input_width
        longest = max(height, width)
        low_rows = longest // 2
        padded_cols = longest + self.band_length - 2
        full_low = self._banded_matrix(self.band_low, low_rows, padded_cols)
        full_high = self._banded_matrix(self.band_high, longest - low_rows, padded_cols)

        # Number of boundary columns removed on each side; none for length-2 bands.
        trim = self.band_length_half - 1
        stop = None if self.band_length_half == 1 else -trim

        def crop(full, n_rows, size):
            # Select the part belonging to an axis of length `size`, then drop the
            # boundary columns so that `size` columns remain.
            return full[:n_rows, :size + self.band_length - 2][:, trim:stop]

        low_0 = crop(full_low, height // 2, height)
        low_1 = np.transpose(crop(full_low, width // 2, width))
        high_0 = crop(full_high, height - height // 2, height)
        high_1 = np.transpose(crop(full_high, width - width // 2, width))

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low_0 = torch.Tensor(low_0).to(device)
        self.matrix_low_1 = torch.Tensor(low_1).to(device)
        self.matrix_high_0 = torch.Tensor(high_0).to(device)
        self.matrix_high_1 = torch.Tensor(high_1).to(device)

    def forward(self, LL, LH, HL, HH):
        r"""
        Reconstructing the original 2D data
        the original 2D data = \mathcal{L}^T * lfc * \mathcal{L}
                             + \mathcal{H}^T * hfc_lh * \mathcal{L}
                             + \mathcal{L}^T * hfc_hl * \mathcal{H}
                             + \mathcal{H}^T * hfc_hh * \mathcal{H}
        :param LL: the low-frequency component
        :param LH: the high-frequency component, hfc_lh
        :param HL: the high-frequency component, hfc_hl
        :param HH: the high-frequency component, hfc_hh
        :return: the original 2D data
        """
        assert len(LL.size()) == len(LH.size()) == len(
            HL.size()) == len(HH.size()) == 4
        # The output size along each axis is the sum of the LL and HH sizes.
        self.input_height = LL.size()[-2] + HH.size()[-2]
        self.input_width = LL.size()[-1] + HH.size()[-1]
        # Build the matrices on the input's device so CPU inputs work on CUDA machines.
        self.get_matrix(device=LL.device)
        return IDWTFunction_2D.apply(LL, LH, HL, HH, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)


class DWT_3D(Module):
    """
    input: the 3D data to be decomposed -- (N, C, D, H, W)
    output: lfc -- (N, C, D/2, H/2, W/2)
            hfc_llh -- (N, C, D/2, H/2, W/2)
            hfc_lhl -- (N, C, D/2, H/2, W/2)
            hfc_lhh -- (N, C, D/2, H/2, W/2)
            hfc_hll -- (N, C, D/2, H/2, W/2)
            hfc_hlh -- (N, C, D/2, H/2, W/2)
            hfc_hhl -- (N, C, D/2, H/2, W/2)
            hfc_hhh -- (N, C, D/2, H/2, W/2)
    """

    def __init__(self, wavename):
        """
        3D discrete wavelet transform (DWT) for 3D data decomposition
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(DWT_3D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.rec_lo
        self.band_high = wavelet.rec_hi
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    @staticmethod
    def _banded_matrix(band, n_rows, n_cols):
        """Return an (n_rows, n_cols) array whose row i holds ``band`` starting at column 2 * i."""
        matrix = np.zeros((n_rows, n_cols))
        rows = np.arange(n_rows)[:, None]
        # Fancy indexing fills all rows at once and still raises IndexError when a
        # band would run past the last column.
        matrix[rows, 2 * rows + np.arange(len(band))] = band
        return matrix

    def get_matrix(self, device=None):
        r"""
        Generate the transform matrices \mathcal{L} and \mathcal{H} for the current
        ``self.input_depth``, ``self.input_height`` and ``self.input_width``.

        :param device: device on which to place the matrices; when None, CUDA is
                       used if it is available and the CPU otherwise
        :return: None; sets ``self.matrix_low_0`` / ``self.matrix_high_0`` (height
                 axis), ``self.matrix_low_1`` / ``self.matrix_high_1`` (width axis,
                 stored transposed) and ``self.matrix_low_2`` / ``self.matrix_high_2``
                 (depth axis)
        """
        depth, height, width = self.input_depth, self.input_height, self.input_width
        # The working matrices must cover the longest of ALL three axes; leaving the
        # depth out truncates the depth matrices whenever depth > max(height, width).
        longest = max(depth, height, width)
        low_rows = longest // 2
        padded_cols = longest + self.band_length - 2
        full_low = self._banded_matrix(self.band_low, low_rows, padded_cols)
        full_high = self._banded_matrix(self.band_high, longest - low_rows, padded_cols)

        # Number of boundary columns removed on each side; none for length-2 bands.
        trim = self.band_length_half - 1
        stop = None if self.band_length_half == 1 else -trim

        def crop(full, n_rows, size):
            # Select the part belonging to an axis of length `size`, then drop the
            # boundary columns so that `size` columns remain.
            return full[:n_rows, :size + self.band_length - 2][:, trim:stop]

        low_0 = crop(full_low, height // 2, height)
        low_1 = np.transpose(crop(full_low, width // 2, width))
        low_2 = crop(full_low, depth // 2, depth)
        high_0 = crop(full_high, height - height // 2, height)
        high_1 = np.transpose(crop(full_high, width - width // 2, width))
        high_2 = crop(full_high, depth - depth // 2, depth)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low_0 = torch.Tensor(low_0).to(device)
        self.matrix_low_1 = torch.Tensor(low_1).to(device)
        self.matrix_low_2 = torch.Tensor(low_2).to(device)
        self.matrix_high_0 = torch.Tensor(high_0).to(device)
        self.matrix_high_1 = torch.Tensor(high_1).to(device)
        self.matrix_high_2 = torch.Tensor(high_2).to(device)

    def forward(self, input):
        """
        :param input: the 3D data to be decomposed, a 5-dimensional tensor
        :return: the eight components of the input data, one low-frequency and seven high-frequency components
        """
        assert len(input.size()) == 5
        self.input_depth = input.size()[-3]
        self.input_height = input.size()[-2]
        self.input_width = input.size()[-1]
        # Build the matrices on the input's device so CPU inputs work on CUDA machines.
        self.get_matrix(device=input.device)
        return DWTFunction_3D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
                                    self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)


class IDWT_3D(Module):
    """
    input:  lfc -- (N, C, D/2, H/2, W/2)
            hfc_llh -- (N, C, D/2, H/2, W/2)
            hfc_lhl -- (N, C, D/2, H/2, W/2)
            hfc_lhh -- (N, C, D/2, H/2, W/2)
            hfc_hll -- (N, C, D/2, H/2, W/2)
            hfc_hlh -- (N, C, D/2, H/2, W/2)
            hfc_hhl -- (N, C, D/2, H/2, W/2)
            hfc_hhh -- (N, C, D/2, H/2, W/2)
    output: the original 3D data -- (N, C, D, H, W)
    """

    def __init__(self, wavename):
        """
        3D inverse DWT (IDWT) for 3D data reconstruction
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        """
        super(IDWT_3D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        # The inverse transform uses the decomposition filters in reversed order.
        self.band_low = list(reversed(wavelet.dec_lo))
        self.band_high = list(reversed(wavelet.dec_hi))
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    @staticmethod
    def _banded_matrix(band, n_rows, n_cols):
        """Return an (n_rows, n_cols) array whose row i holds ``band`` starting at column 2 * i."""
        matrix = np.zeros((n_rows, n_cols))
        rows = np.arange(n_rows)[:, None]
        # Fancy indexing fills all rows at once and still raises IndexError when a
        # band would run past the last column.
        matrix[rows, 2 * rows + np.arange(len(band))] = band
        return matrix

    def get_matrix(self, device=None):
        r"""
        Generate the transform matrices \mathcal{L} and \mathcal{H} for the current
        ``self.input_depth``, ``self.input_height`` and ``self.input_width``.

        :param device: device on which to place the matrices; when None, CUDA is
                       used if it is available and the CPU otherwise
        :return: None; sets ``self.matrix_low_0`` / ``self.matrix_high_0`` (height
                 axis), ``self.matrix_low_1`` / ``self.matrix_high_1`` (width axis,
                 stored transposed) and ``self.matrix_low_2`` / ``self.matrix_high_2``
                 (depth axis)
        """
        depth, height, width = self.input_depth, self.input_height, self.input_width
        # The working matrices must cover the longest of ALL three axes; leaving the
        # depth out truncates the depth matrices whenever depth > max(height, width).
        longest = max(depth, height, width)
        low_rows = longest // 2
        padded_cols = longest + self.band_length - 2
        full_low = self._banded_matrix(self.band_low, low_rows, padded_cols)
        full_high = self._banded_matrix(self.band_high, longest - low_rows, padded_cols)

        # Number of boundary columns removed on each side; none for length-2 bands.
        trim = self.band_length_half - 1
        stop = None if self.band_length_half == 1 else -trim

        def crop(full, n_rows, size):
            # Select the part belonging to an axis of length `size`, then drop the
            # boundary columns so that `size` columns remain.
            return full[:n_rows, :size + self.band_length - 2][:, trim:stop]

        low_0 = crop(full_low, height // 2, height)
        low_1 = np.transpose(crop(full_low, width // 2, width))
        low_2 = crop(full_low, depth // 2, depth)
        high_0 = crop(full_high, height - height // 2, height)
        high_1 = np.transpose(crop(full_high, width - width // 2, width))
        high_2 = crop(full_high, depth - depth // 2, depth)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.matrix_low_0 = torch.Tensor(low_0).to(device)
        self.matrix_low_1 = torch.Tensor(low_1).to(device)
        self.matrix_low_2 = torch.Tensor(low_2).to(device)
        self.matrix_high_0 = torch.Tensor(high_0).to(device)
        self.matrix_high_1 = torch.Tensor(high_1).to(device)
        self.matrix_high_2 = torch.Tensor(high_2).to(device)

    def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
        """
        :param LLL: the low-frequency component, lfc
        :param LLH: the high-frequency component, hfc_llh
        :param LHL: the high-frequency component, hfc_lhl
        :param LHH: the high-frequency component, hfc_lhh
        :param HLL: the high-frequency component, hfc_hll
        :param HLH: the high-frequency component, hfc_hlh
        :param HHL: the high-frequency component, hfc_hhl
        :param HHH: the high-frequency component, hfc_hhh
        :return: the original 3D input data
        """
        assert len(LLL.size()) == len(LLH.size()) == len(
            LHL.size()) == len(LHH.size()) == 5
        assert len(HLL.size()) == len(HLH.size()) == len(
            HHL.size()) == len(HHH.size()) == 5
        # The output size along each axis is the sum of the LLL and HHH sizes.
        self.input_depth = LLL.size()[-3] + HHH.size()[-3]
        self.input_height = LLL.size()[-2] + HHH.size()[-2]
        self.input_width = LLL.size()[-1] + HHH.size()[-1]
        # Build the matrices on the input's device so CPU inputs work on CUDA machines.
        self.get_matrix(device=LLL.device)
        return IDWTFunction_3D.apply(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH,
                                     self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
                                     self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)


# if __name__ == '__main__':
#     dwt = DWT_2D("haar")
#     iwt = IDWT_2D("haar")
#     x = torch.randn(3, 3, 24, 24).cuda()
#     xll = x
#     wavelet_list = []
#     for i in range(3):
#         xll, xlh, xhl, xhh = dwt(xll)
#         wavelet_list.append([xll, xlh, xhl, xhh])

#     # xll = wavelet_list[-1] * torch.randn(xll.shape)
#     for i in range(2)[::-1]:
#         xll, xlh, xhl, xhh = wavelet_list[i]
#         xll = iwt(xll, xlh, xhl, xhh)
#         print(xll.shape)

#     print(torch.sum(x - xll))
#     print(torch.sum(x - iwt(*wavelet_list[0])))

## DWT & IDWT

In [None]:
class DWT(nn.Module):
    """Haar ``DWT_2D`` wrapper that stacks the three high-frequency bands on the channel axis."""

    def __init__(self):
        super().__init__()
        self.dwt = DWT_2D("haar")

    def forward(self, x):
        """
        :param x: the data handed to the wrapped ``DWT_2D`` module
        :return: ``(xll, xh)`` where ``xh`` is ``xlh``, ``xhl`` and ``xhh``
                 concatenated along dimension 1 (in that order)
        """
        low_band, band_lh, band_hl, band_hh = self.dwt(x)
        stacked_high = torch.cat((band_lh, band_hl, band_hh), 1)
        return low_band, stacked_high

class IDWT(nn.Module):
    """Haar ``IDWT_2D`` wrapper that takes the three high-frequency bands stacked on the channel axis."""

    def __init__(self):
        super().__init__()
        self.idwt = IDWT_2D("haar")

    def forward(self, xll, high_freq):
        """
        :param xll: the low-frequency component
        :param high_freq: the high-frequency bands concatenated along dimension 1;
                          split into three chunks (``xlh``, ``xhl``, ``xhh``)
        :return: the output of the wrapped ``IDWT_2D`` module
        """
        band_lh, band_hl, band_hh = torch.chunk(high_freq, 3, dim=1)
        return self.idwt(xll, band_lh, band_hl, band_hh)

class DWT_transform(nn.Module):
    """Wavelet decomposition followed by separate 1x1 convolutions on the low and high bands."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dwt = DWT()
        self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels*3, kernel_size=1, padding=0)

    def forward(self, x):
        """
        :param x: 4-dimensional input; an odd height/width is extended by one
                  replicated row/column (bottom/right) before the transform
        :return: ``(low, high, [mod1, mod2])`` where ``mod1``/``mod2`` are
                 ``height % 2`` and ``width % 2`` of the incoming tensor
        """
        _, _, height, width = x.shape
        odd_height = height % 2
        odd_width = width % 2
        if odd_height or odd_width:
            # One replicate pad covers both axes: (left, right, top, bottom).
            x = F.pad(x, (0, odd_width, 0, odd_height), "replicate")

        low_band, high_band = self.dwt(x)
        low_band = self.conv1x1_low(low_band)
        high_band = self.conv1x1_high(high_band)

        return low_band, high_band, [odd_height, odd_width]

class IDWT_transform(nn.Module):
    """Separate 1x1 convolutions on the low and high bands followed by the inverse wavelet transform."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels*3, kernel_size=1, padding=0)
        self.idwt = IDWT()

    def forward(self, x, high_freq):
        """
        :param x: the low-frequency input, passed through ``conv1x1_low``
        :param high_freq: the stacked high-frequency input, passed through ``conv1x1_high``
        :return: the output of ``self.idwt`` applied to the two projected tensors
        """
        projected_low = self.conv1x1_low(x)
        projected_high = self.conv1x1_high(high_freq)
        return self.idwt(projected_low, projected_high)

class DWT_IDWT_transform(nn.Module):
    """Round trip through ``DWT`` and ``IDWT`` with padding for odd spatial sizes."""

    def __init__(self):
        super().__init__()
        self.dwt = DWT()
        self.idwt = IDWT()

    def forward(self, x):
        """
        :param x: 4-dimensional input; an odd height/width is extended by one
                  replicated row/column before the transform
        :return: the inverse-transformed data with the padded row/column removed again
        """
        _, _, height, width = x.shape
        odd_height = height % 2
        odd_width = width % 2
        if odd_height or odd_width:
            # One replicate pad covers both axes: (left, right, top, bottom).
            x = F.pad(x, (0, odd_width, 0, odd_height), "replicate")

        low_band, high_band = self.dwt(x)
        restored = self.idwt(low_band, high_band)

        # Drop the last row/column only on the axes that were padded.
        row_stop = -1 if odd_height else None
        col_stop = -1 if odd_width else None
        return restored[:, :, :row_stop, :col_stop]

## New DWT

In [None]:
try:
  from pytorch_wavelets import DWTForward, DWTInverse
except:
  !pip install pytorch_wavelets
  from pytorch_wavelets import DWTForward, DWTInverse

class DWT_transform_2(nn.Module):
    """``DWTForward`` (one level, Haar, reflect mode) followed by 1x1 convolutions on both outputs."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='haar', mode='reflect')
        self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels*3, kernel_size=1, padding=0)

    def forward(self, x):
        """
        :param x: the data handed to ``self.dwt``
        :return: ``(low, high)``; ``high`` is the list returned by ``self.dwt`` whose
                 first entry has been folded to ``b (c freq) h w`` and passed
                 through ``conv1x1_high``
        """
        low_band, high_bands = self.dwt(x)
        low_band = self.conv1x1_low(low_band)
        # Merge the band axis into the channel axis so a plain Conv2d can be applied.
        folded = rearrange(high_bands[0], 'b c freq h w -> b (c freq) h w', freq=3)
        high_bands[0] = self.conv1x1_high(folded)

        return low_band, high_bands

class IDWT_transform_2(nn.Module):
    """1x1 convolutions on the low and high bands followed by ``DWTInverse`` (Haar, reflect mode)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1x1_low = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv1x1_high = nn.Conv2d(in_channels*3, out_channels*3, kernel_size=1, padding=0)
        self.idwt = DWTInverse(wave='haar', mode='reflect')

    def forward(self, x, high_freq):
        """
        :param x: the low-frequency input, passed through ``conv1x1_low``
        :param high_freq: sequence of high-frequency tensors; the first entry is
                          expected in ``b (c freq) h w`` layout with ``freq == 3``.
                          The sequence itself is left untouched.
        :return: the output of ``self.idwt`` for the projected low band and the
                 projected, unfolded high band
        """
        x = self.conv1x1_low(x)
        first_level = self.conv1x1_high(high_freq[0])
        first_level = rearrange(first_level, 'b (c freq) h w -> b c freq h w', freq=3)
        # Build a fresh list instead of assigning into `high_freq`: writing the
        # result back would overwrite the caller's tensor and make the same list
        # unusable for a second call (its first entry would already be 5-D).
        bands = [first_level, *high_freq[1:]]
        x = self.idwt((x, bands))

        return x

class DWT_IDWT_transform_2(nn.Module):
    """Round trip through ``DWTForward`` and ``DWTInverse`` (one level, Haar, reflect mode)."""

    def __init__(self):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='haar', mode='reflect')
        self.idwt = DWTInverse(wave='haar', mode='reflect')

    def forward(self, x):
        """
        :param x: the data handed to ``self.dwt``
        :return: the output of ``self.idwt`` for the decomposed data
        """
        low_band, high_bands = self.dwt(x)
        # Fold the band axis into the channels and unfold it again.
        folded = rearrange(high_bands[0], 'b c freq h w -> b (c freq) h w', freq=3)
        high_bands[0] = folded
        unfolded = rearrange(high_bands[0], 'b (c freq) h w -> b c freq h w', freq=3)
        high_bands[0] = unfolded
        return self.idwt((low_band, high_bands))

Collecting pytorch_wavelets
  Downloading pytorch_wavelets-1.3.0-py3-none-any.whl (54 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pytorch_wavelets
Successfully installed pytorch_wavelets-1.3.0


# Basic Residual Block

In [None]:
class BasicResBlock(nn.Module):
	"""Residual block: conv-BN-ReLU, then conv-BN, then the input is added back (no final activation)."""

	def __init__(self, channel_num):
		super().__init__()

		# Both stages keep the channel count at channel_num.
		first_stage = [
			nn.Conv2d(channel_num, channel_num, 3, padding=1),
			nn.BatchNorm2d(channel_num),
			nn.ReLU(inplace=True),
		]
		second_stage = [
			nn.Conv2d(channel_num, channel_num, 3, padding=1),
			nn.BatchNorm2d(channel_num),
		]
		self.conv_block1 = nn.Sequential(*first_stage)
		self.conv_block2 = nn.Sequential(*second_stage)

	def forward(self, x):
		"""Return ``conv_block2(conv_block1(x)) + x``."""
		transformed = self.conv_block2(self.conv_block1(x))
		return transformed + x

# Selective Res Block

In [None]:
def _make_conv_layer(in_channels, out_channels, stride=1, dilation=1, norm_type='bn'):
    conv_layer = [
        nn.ReflectionPad2d(dilation),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=0, dilation=dilation)
    ]

    if norm_type == 'bn':
        conv_layer.append(nn.BatchNorm2d(out_channels))
    elif norm_type == 'in':
        conv_layer.append(nn.InstanceNorm2d(out_channels))

    return nn.Sequential(*conv_layer)

class selective_res_block(nn.Module):
    """Residual block whose branch and identity path are mixed with the learnable scalars ``a`` and ``b``."""

    def __init__(self, channels, stride=1, activation=nn.ReLU(inplace=True), norm_type='bn'):
        super().__init__()
        self.conv1 = _make_conv_layer(channels, channels, stride=stride, norm_type=norm_type)
        self.conv2 = _make_conv_layer(channels, channels, stride=stride, norm_type=norm_type)
        self.act = activation
        # Mixing weights for the conv branch (a) and the identity path (b), both start at 1.
        self.a = nn.Parameter(data=torch.ones(1))
        self.b = nn.Parameter(data=torch.ones(1))

    def forward(self, x):
        """Return ``act(a * conv2(act(conv1(x))) + b * x)``."""
        branch = self.conv1(x)
        branch = self.act(branch)
        branch = self.conv2(branch)
        weighted_sum = branch.mul(self.a) + x.mul(self.b)
        return self.act(weighted_sum)

# RESTBlock

In [None]:
import numbers

##########################################################################
## Layer Norm

def to_3d(x):
    """Rearrange ``x`` from ``b c h w`` to ``b (h w) c``."""
    tokens = rearrange(x, 'b c h w -> b (h w) c')
    return tokens

def to_4d(x,h,w):
    """Rearrange ``x`` from ``b (h w) c`` back to ``b c h w`` for the given ``h`` and ``w``."""
    feature_map = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
    return feature_map

class BiasFree_LayerNorm(nn.Module):
    """Scale-only layer norm over the last axis: no mean subtraction and no bias."""

    def __init__(self, normalized_shape):
        """
        :param normalized_shape: an integer or a one-element shape giving the size
                                 of the last axis
        """
        super().__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)

        # Only normalisation over a single (the last) axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``x / sqrt(var(x) + 1e-5) * weight`` with the biased variance over the last axis."""
        variance = x.var(-1, keepdim=True, unbiased=False)
        scale = torch.sqrt(variance+1e-5)
        return x / scale * self.weight

class WithBias_LayerNorm(nn.Module):
    """Layer norm over the last axis with learnable weight and bias."""

    def __init__(self, normalized_shape):
        """
        :param normalized_shape: an integer or a one-element shape giving the size
                                 of the last axis
        """
        super().__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)

        # Only normalisation over a single (the last) axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + 1e-5) * weight + bias`` with the biased variance over the last axis."""
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        centered = x - mean
        return centered / torch.sqrt(variance+1e-5) * self.weight + self.bias


class RESTLayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(RESTLayerNorm, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)



##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated depthwise-convolution feed-forward network.

    A 1x1 convolution expands the input to twice the hidden width, a 3x3
    depthwise convolution mixes each channel spatially, and the result is
    split into two halves: GELU of the first half gates the second half
    before a final 1x1 projection back to `dim` channels.

    Args:
        dim: number of input and output channels.
        ffn_expansion_factor: hidden width is `int(dim * ffn_expansion_factor)`.
        bias: whether the convolutions use a bias term.
    """
    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        hidden = int(dim * ffn_expansion_factor)
        gated = hidden * 2  # both halves of the gate are produced by one projection

        self.project_in = nn.Conv2d(dim, gated, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(gated, gated, kernel_size=3, stride=1, padding=1,
                                groups=gated, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.dwconv(self.project_in(x))
        gate, value = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)



##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-head "transposed" self-attention computed across channels.

    Queries, keys and values come from a 1x1 convolution followed by a 3x3
    depthwise convolution. Each head flattens the spatial positions, so the
    attention matrix has shape (channels_per_head, channels_per_head) rather
    than (positions, positions). Queries and keys are L2-normalised along the
    flattened spatial axis and the scores are scaled by a learnable per-head
    `temperature`.

    Args:
        dim: number of input and output channels (split evenly across heads
            by the `rearrange` pattern).
        num_heads: number of attention heads.
        bias: whether the convolutions use a bias term.
    """
    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        qkv_dim = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_dim, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(qkv_dim, qkv_dim, kernel_size=3, stride=1, padding=1,
                                    groups=qkv_dim, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        heads = self.num_heads

        projected = self.qkv_dwconv(self.qkv(x))
        q, k, v = projected.chunk(3, dim=1)

        # Split channels into heads and flatten the spatial grid.
        split_pattern = 'b (head c) h w -> b head c (h w)'
        q = rearrange(q, split_pattern, head=heads)
        k = rearrange(k, split_pattern, head=heads)
        v = rearrange(v, split_pattern, head=heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Channel-to-channel similarity, one (c x c) matrix per head.
        scores = (q @ k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        mixed = weights @ v
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w', head=heads, h=h, w=w)

        return self.project_out(mixed)



##########################################################################
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Args:
        dim: number of channels of the feature map.
        num_heads: number of attention heads.
        ffn_expansion_factor: hidden-width multiplier of the feed-forward network.
        bias: whether the convolutions inside attention / feed-forward use a bias.
        LayerNorm_type: 'BiasFree' or anything else for the with-bias variant.
    """
    def __init__(self, dim, num_heads, ffn_expansion_factor = 2.66, bias = False, LayerNorm_type = 'WithBias'):
        super().__init__()

        self.norm1 = RESTLayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = RESTLayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))

# Bottleneck

In [None]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)

class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a residual.

    The stride is applied at the 3x3 convolution (the "ResNet V1.5" layout used
    by torchvision) rather than at the first 1x1 convolution as in the original
    paper "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385). See
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    The block outputs `planes * 4` channels. When the input cannot be added to
    the output directly (`stride != 1` or `inplanes != planes * 4`) a
    1x1-conv + norm projection is applied to the shortcut. When the shapes
    already match no projection is created, so the parameter set (and therefore
    existing checkpoints) is identical to the projection-free version.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width before the x4 expansion.
        stride: stride of the 3x3 convolution (and of the shortcut projection).
        groups: number of groups of the 3x3 convolution.
        base_width: width scaling; inner width is
            `int(planes * base_width / 64) * groups`.
        dilation: dilation (and padding) of the 3x3 convolution.
        norm_layer: normalisation layer factory; defaults to `nn.BatchNorm2d`.
    """

    def __init__(self, inplanes, planes, stride=1, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * 4
        # Both self.conv2 and self.downsample downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.LeakyReLU(inplace=True)

        # Created last, and only when needed, so that the common
        # shape-preserving configuration keeps the exact same parameters and
        # random initialisation as before.
        if stride != 1 or inplanes != out_planes:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        # getattr keeps whole-model pickles created before `downsample`
        # existed usable.
        downsample = getattr(self, "downsample", None)
        identity = x if downsample is None else downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = out + identity
        out = self.relu(out)

        return out

# Xception (UNet)

In [None]:
"""
Creates an Xception Model as defined in:

Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf

These weights were ported from the Keras implementation. They achieve the following performance on the validation set:

Loss:0.9173 Prec@1:78.892 Prec@5:94.292

REMEMBER to set your image size to 3x299x299 for both test and validation

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])

The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
import math
import torch.nn as nn
import torch.nn.functional as F
# import torch.utils.model_zoo as model_zoo
from torch.nn import init
import torch

class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: a per-channel conv followed by a 1x1 conv.

    Args:
        in_channels: number of input channels (also the depthwise group count).
        out_channels: number of channels produced by the pointwise convolution.
        kernel_size, stride, padding, dilation: forwarded to the depthwise conv.
        bias: whether both convolutions use a bias term.
    """
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super().__init__()

        # Depthwise stage: one filter per input channel.
        self.conv1 = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=in_channels, bias=bias,
        )
        # Pointwise stage: 1x1 convolution mixing the channels.
        self.pointwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=bias,
        )

    def forward(self,x):
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Xception residual block.

    The main path is a stack of `reps` (ReLU -> SeparableConv2d 3x3 ->
    BatchNorm2d) units, optionally followed by a max-pool when `strides != 1`.
    Its output is added to a skip path that is either the unchanged input or a
    1x1-conv + BatchNorm projection of it.

    NOTE: the positions of the layers inside `self.rep` determine the
    state-dict keys (`rep.<index>...`), so the construction order below must
    not be changed if pretrained weights are to keep loading.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        reps: number of separable-conv units in the main path.
        strides: stride of the trailing max-pool and of the skip projection.
        start_with_relu: if False the leading ReLU is dropped; if True it is
            replaced by a non-inplace ReLU.
        grow_first: if True the channel change (in_filters -> out_filters)
            happens in the first unit, otherwise in the last one.
    """
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        # The skip path needs a projection whenever the main path changes the
        # channel count or the spatial resolution.
        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        # A single in-place ReLU instance is reused at every position of the stack.
        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            # Channel growth happens in the first unit.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        # Remaining units keep the channel count unchanged.
        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            # Channel growth happens in the last unit instead.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            # Drop the leading ReLU entirely.
            rep = rep[1:]
        else:
            # Presumably non-inplace so that `inp`, which forward() also feeds
            # to the skip path, is not overwritten by the first activation.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        """Run the main path on `inp` and add the (possibly projected) skip path."""
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        # In-place residual addition on the main-path output.
        x+=skip
        return x

class PALayer(nn.Module):
    """Pixel attention: rescales every spatial position by a learned gate.

    Two 1x1 convolutions (channel -> channel // 8 -> 1) followed by a sigmoid
    produce a single-channel map that multiplies the input.
    """
    def __init__(self, channel):
        super().__init__()
        reduced = channel // 8
        gate_layers = [
            nn.Conv2d(channel, reduced, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.pa = nn.Sequential(*gate_layers)

    def forward(self, x):
        gate = self.pa(x)
        return x * gate

class CALayer(nn.Module):
    """Channel attention: rescales every channel by a learned gate.

    The input is globally average-pooled to 1x1, passed through two 1x1
    convolutions (channel -> channel // 8 -> channel) and a sigmoid, and the
    resulting per-channel weights multiply the input.
    """
    def __init__(self, channel):
        super().__init__()
        reduced = channel // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(channel, reduced, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.ca = nn.Sequential(*gate_layers)

    def forward(self, x):
        channel_weights = self.ca(self.avg_pool(x))
        return x * channel_weights

class CP_Attention_block(nn.Module):
    """Residual block combining channel attention and pixel attention.

    Flow: conv1 -> ReLU -> (+ input) -> conv2 -> channel attention ->
    pixel attention -> (+ input).

    Args:
        conv: factory called as `conv(in_channels, out_channels, kernel_size, bias=...)`
            that returns a convolution layer.
        dim: number of channels (kept constant through the block).
        kernel_size: kernel size handed to `conv`.
    """
    def __init__(self, conv, dim, kernel_size):
        super().__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        # First residual: activated conv features plus the input.
        features = self.act1(self.conv1(x)) + x
        # Second conv, then the two attention stages.
        refined = self.palayer(self.calayer(self.conv2(features)))
        # Outer residual, added in place to the attention output.
        refined += x
        return refined

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a Conv2d padded by `kernel_size // 2` (size-preserving for odd kernels)."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=half_kernel, bias=bias)

class Xception(nn.Module):
    """
    Xception feature extractor (https://arxiv.org/pdf/1610.02357.pdf) with the
    classification head removed and an attention-based upsampling decoder
    attached.

    The encoder part (conv1 ... bn4) follows the Xception layout and ends with
    2048 channels. The decoder applies PixelShuffle(2) five times, interleaved
    with CP_Attention_block and 3x3 convolutions, and returns a 32-channel
    feature map.
    """
    def __init__(self):
        """ Constructor

        Takes no arguments: the classifier (`fc`) of the original network is
        commented out below, so there is no `num_classes` parameter.
        """
        super(Xception, self).__init__()


        # --- Encoder stem ---
        self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        #do relu here

        # --- Entry blocks: each one halves the resolution (stride 2) ---
        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        # --- Middle blocks: 728 channels, stride 1 ---
        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        # --- Exit block: stride 2, channels grow in the last unit ---
        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        #do relu here
        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        # self.fc = nn.Linear(2048, num_classes)



        #------- init weights --------
        # NOTE: this loop runs before the decoder layers below are created, so
        # only the encoder modules defined above receive this initialisation;
        # the decoder keeps PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        #-----------------------------


        # --- Decoder ---
        # PixelShuffle(2) doubles height/width and divides the channel count by 4;
        # the same (parameter-free) instance is reused at every upsampling step.
        self.up_block= nn.PixelShuffle(2)
        self.attention0 = CP_Attention_block(default_conv, 512, 3)
        self.conv_process0 = nn.Conv2d(512, 1024, kernel_size=3,padding=1)

        #upsample here
        self.attention1 = CP_Attention_block(default_conv, 256, 3)
        self.conv_process1 = nn.Conv2d(256, 512, kernel_size=3,padding=1)

        #upsample here
        self.attention2 = CP_Attention_block(default_conv, 128, 3)
        self.conv_process2 = nn.Conv2d(128, 256, kernel_size=3,padding=1)

        #upsample here
        self.attention3 = CP_Attention_block(default_conv, 64, 3)
        self.conv_process3 = nn.Conv2d(64, 128, kernel_size=3,padding=1)

        #upsample here
        self.attention4 = CP_Attention_block(default_conv, 32, 3)

    def forward(self, x):
        """Encode `x` with the Xception backbone and decode it to 32 channels.

        `x` must have 3 channels (conv1 expects 3 input channels). The encoder
        downsamples with one stride-2 convolution and four stride-2 blocks and
        the decoder upsamples five times by a factor of 2.
        """
        # --- Encoder ---
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # print("1: ", x.size())
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        # print("2: ", x.size())
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)
        # print("3: ", x.size())
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        # print("4: ", x.size())
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x) # channels = 2048
        # print("5: ", x.size())

        # x = F.adaptive_avg_pool2d(x, (1, 1))
        # x = x.view(x.size(0), -1)
        # x = self.fc(x)


        # --- Decoder: upsample -> attention -> conv, repeated ---
        x = self.up_block(x) # channels = 512
        x = self.attention0(x)
        x = self.conv_process0(x) # channels = 1024
        # print("6: ", x.size())

        x = self.up_block(x) # channels = 256
        x = self.attention1(x)
        x = self.conv_process1(x) # channels = 512
        # print("7: ", x.size())

        x = self.up_block(x) # channels = 128
        x = self.attention2(x)
        x = self.conv_process2(x) # channels = 256
        # print("8: ", x.size())

        x = self.up_block(x) # channels = 64
        x = self.attention3(x)
        x = self.conv_process3(x) # channels = 128
        # print("9: ", x.size())

        x = self.up_block(x) # channels = 32
        x = self.attention4(x)
        # print("10: ", x.size())

        return x



# def xception(pretrained=False,**kwargs):
#     """
#     Construct Xception.
#     """

#     model = Xception(**kwargs)
#     if pretrained:
#         model.load_state_dict(model_zoo.load_url(model_urls['xception']), strict = False)
#     return model

# Model (Main UNet)

In [None]:
from torch.functional import Tensor
import torch.nn as nn
import torch
from functools import partial
import math
import warnings
import torch.nn.functional as f

## Residual & Transformer Sequence Block

In [None]:
class BRB_Transformer(nn.Module):
  """A `BasicResBlock` followed by a `TransformerBlock`, both on `channels` channels.

  Args:
    channels: channel count handed to both sub-blocks.
    heads: number of attention heads of the transformer block.
  """
  def __init__(self, channels, heads):
    super().__init__()
    self.ResBlock = BasicResBlock(channel_num = channels)
    self.TransformerBlock = TransformerBlock(dim = channels, num_heads = heads)

  def forward(self, x):
    residual_features = self.ResBlock(x)
    return self.TransformerBlock(residual_features)

## Refinement Block

In [None]:
class WAB(nn.Module):
    """Wide residual block: expand -> BN -> ReLU -> squeeze -> BN, scaled by 0.2.

    Args:
        n_feats: number of input and output channels.
        expand: channel multiplier of the inner (wide) convolution.
    """
    def __init__(self,n_feats,expand=4):
        super().__init__()
        wide = n_feats * expand
        layers = [
            nn.Conv2d(n_feats, wide, 3, 1, 1, bias=True),
            nn.BatchNorm2d(wide),
            nn.ReLU(True),
            nn.Conv2d(wide, n_feats, 3, 1, 1, bias=True),
            nn.BatchNorm2d(n_feats),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # The residual branch is damped by 0.2 before being added to the input.
        damped = self.body(x).mul(0.2)
        return damped + x


class invPixelShuffle(nn.Module):
    """Space-to-depth rearrangement (inverse of `nn.PixelShuffle`).

    A (b, c, H, W) tensor becomes (b, c * ratio**2, H // ratio, W // ratio):
    every `ratio x ratio` spatial patch is moved into the channel dimension.
    """
    def __init__(self, ratio=2):
        super().__init__()
        self.ratio = ratio

    def forward(self, tensor):
        r = self.ratio
        batch, channels, height, width = tensor.shape
        # Both spatial dimensions must be divisible by the ratio.
        assert width % r == 0 and height % r == 0, 'x, y, ratio : {}, {}, {}'.format(width, height, r)
        out_h = height // r
        out_w = width // r
        # Split each spatial axis into (coarse position, offset inside patch) ...
        patches = tensor.view(batch, channels, out_h, r, out_w, r)
        # ... then move the two in-patch offsets next to the channel axis.
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        return patches.view(batch, -1, out_h, out_w)

class RefinementBlock(nn.Module):
    """Refines a feature map at half resolution and restores its original shape.

    Pipeline (channel counts in brackets):
    conv [channels -> 16] -> BN -> invPixelShuffle(2) [16 -> 64, half size] ->
    conv [64 -> 16] -> BN -> 3 x WAB(16) -> conv [16 -> 64] ->
    PixelShuffle(2) [64 -> 16, full size] -> BN -> conv [16 -> channels].

    Because of `invPixelShuffle(2)`, the input height and width must be even.
    """
    def __init__(self, channels):
        super().__init__()
        # Layer order fixes the state-dict indices (`refinement.<index>...`).
        stages = [
            nn.Conv2d(channels, 16, 3, 1, 1, bias=True),
            nn.BatchNorm2d(16),
            invPixelShuffle(2),
            nn.Conv2d(64, 16, 3, 1, 1, bias=True),
            nn.BatchNorm2d(16),
            nn.Sequential(*[WAB(16) for _ in range(3)]),
            nn.Conv2d(16, 64, 3, 1, 1, bias=True),
            nn.PixelShuffle(2),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, channels, 3, 1, 1, bias=True),
        ]
        self.refinement = nn.Sequential(*stages)

    def forward(self, x):
        return self.refinement(x)

## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
  """One encoder level: residual block -> DWT_transform -> stack of BRB_Transformer.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the wavelet transform and kept by the
      transformer stack.
    levels: number of BRB_Transformer blocks in the stack.
    heads: attention heads of every BRB_Transformer.

  forward() returns a 3-tuple: the transformed first output of `self.DWT`,
  plus its second and third outputs unchanged (named `high_freq` and `pad`;
  presumably the high-frequency bands and padding flags -- confirm against
  DWT_transform).
  """
  def __init__(self, in_channels, out_channels, levels = 2, heads = 4):
    super().__init__()

    self.ResBlock = BasicResBlock(channel_num = in_channels)
    self.DWT = DWT_transform(in_channels = in_channels, out_channels = out_channels)
    transformer_stack = [BRB_Transformer(channels = out_channels, heads = heads) for _ in range(levels)]
    self.BRBTransformer = nn.Sequential(*transformer_stack)

  def forward(self, x):
    features = self.ResBlock(x)
    low_freq, high_freq, pad = self.DWT(features)
    return self.BRBTransformer(low_freq), high_freq, pad

## Decoder Block

In [None]:
class DecoderBlock(nn.Module):
  """One decoder level: refine the three skip bands, fuse, upsample x2, refine.

  Args:
    in_channels: channels of the incoming feature map and of each skip band.
    out_channels: channels of the upsampled output.
    levels: number of BRB_Transformer blocks applied to each band.
    heads: attention heads of every BRB_Transformer.
  """
  def __init__(self, in_channels, out_channels, levels = 2, heads = 4):
    super().__init__()

    def band_stack():
      # Independent stack of `levels` BRB_Transformer blocks for one band.
      return nn.Sequential(*[BRB_Transformer(channels = in_channels, heads = heads) for _ in range(levels)])

    self.BRBTransformer_lh = band_stack()
    self.BRBTransformer_hl = band_stack()
    self.BRBTransformer_hh = band_stack()
    # Transposed conv doubles the resolution of the 4-way concatenation.
    upsample_layers = [
        nn.ConvTranspose2d(in_channels * 4, in_channels * 2, kernel_size=2, stride=2, padding=0),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
    ]
    self.Deconv = nn.Sequential(*upsample_layers)
    self.Refinement = RefinementBlock(channels = out_channels)

  def forward(self, x, high_freq, pad):
    """Fuse `x` with the three bands in `high_freq` and upsample by 2.

    `high_freq` is split into three equal channel groups (lh, hl, hh).
    `pad` is indexed as `pad[0]` / `pad[1]`: when truthy, the last row /
    last column of the result is removed (presumably undoing padding added
    by the matching encoder level -- confirm against DWT_transform).
    """
    band_lh, band_hl, band_hh = high_freq.chunk(3, dim = 1)
    band_lh = self.BRBTransformer_lh(band_lh)
    band_hl = self.BRBTransformer_hl(band_hl)
    band_hh = self.BRBTransformer_hh(band_hh)

    fused = torch.cat([x, band_lh, band_hl, band_hh], 1)
    out = self.Refinement(self.Deconv(fused))

    if pad[0]:
      out = out[:, :, :-1, :]
    if pad[1]:
      out = out[:, :, :, :-1]

    return out

## UNet

In [None]:
class CANT_HAZE(nn.Module):
    """Main U-shaped branch: 4 encoder levels, a ResNet bottleneck, 4 decoder levels.

    Encoder widths are `channels`, 2x, 4x and 8x; the decoders mirror them and
    the last decoder outputs 3 channels. Each decoder receives the
    `high_freq` / `pad` outputs of the encoder at the matching depth. No
    global residual with the input is added.

    Args:
        channels: base channel width of the first encoder level.
        levels: number of BRB_Transformer blocks per encoder/decoder stack.
        heads: number of attention heads.
    """
    def __init__(self, channels = 32, levels = 2, heads = 4):
        super().__init__()
        shared = dict(levels = levels, heads = heads)
        c1, c2, c4, c8 = channels, channels * 2, channels * 4, channels * 8

        self.EncoderL1 = EncoderBlock(in_channels = 3,  out_channels = c1, **shared)
        self.EncoderL2 = EncoderBlock(in_channels = c1, out_channels = c2, **shared)
        self.EncoderL3 = EncoderBlock(in_channels = c2, out_channels = c4, **shared)
        self.EncoderL4 = EncoderBlock(in_channels = c4, out_channels = c8, **shared)

        # planes * 4 == c8, so the bottleneck keeps the channel count.
        self.Bottleneck = Bottleneck(inplanes = c8, planes= int(c8/4))

        self.DecoderL1 = DecoderBlock(in_channels = c8, out_channels = c4, **shared)
        self.DecoderL2 = DecoderBlock(in_channels = c4, out_channels = c2, **shared)
        self.DecoderL3 = DecoderBlock(in_channels = c2, out_channels = c1, **shared)
        self.DecoderL4 = DecoderBlock(in_channels = c1, out_channels = 3,  **shared)

    def forward(self, input):
        # Encoder: keep every level's (high_freq, pad) for the mirrored decoder.
        skips = []
        x = input
        for encoder in (self.EncoderL1, self.EncoderL2, self.EncoderL3, self.EncoderL4):
            x, high_freq, pad = encoder(x)
            skips.append((high_freq, pad))

        x = self.Bottleneck(x)

        # Decoder: consume the skips deepest-first.
        for decoder in (self.DecoderL1, self.DecoderL2, self.DecoderL3, self.DecoderL4):
            high_freq, pad = skips.pop()
            x = decoder(x, high_freq, pad)

        return x

# Fusion

In [None]:
class fusion_net(nn.Module):
    """Two-branch dehazing network: CANT_HAZE main branch fused with an Xception branch.

    The main branch produces a 3-channel estimate, the Xception-based
    "knowledge adaptation" branch produces 32 channels; both are concatenated
    (35 channels) and fused by a reflection-padded 7x7 convolution followed
    by Tanh.

    Args:
        xception_weights_path: checkpoint loaded (non-strictly) into the
            Xception branch. Defaults to the project's Drive location. Pass
            `None` to skip loading, e.g. when a full `fusion_net` checkpoint
            is loaded right afterwards.
    """
    def __init__(self, xception_weights_path="/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/xception-43020ad28.pth"):
        super(fusion_net, self).__init__()
        self.main_branch=CANT_HAZE()
        self.knowledge_adaptation_branch=Xception()

        if xception_weights_path is not None:
            # map_location keeps loading independent of the device the file was
            # saved from; load_state_dict copies the values into the module.
            # strict=False: the checkpoint does not contain the decoder layers.
            self.knowledge_adaptation_branch.load_state_dict(
                torch.load(xception_weights_path, map_location="cpu"),
                strict = False)
        self.fusion = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(35, 3, kernel_size=7, padding=0), nn.Tanh())

    @staticmethod
    def _pad_to_multiple(size, multiple=32):
        """Return how many elements must be appended to reach the next multiple."""
        return (-size) % multiple

    def forward(self, x):
        """Dehaze `x` (b, 3, h, w); the output has the same spatial size as `x`."""
        main_branch=self.main_branch(x)

        _, _, h, w = x.shape
        # The Xception branch downsamples five times by 2 and upsamples five
        # times by 2, so its output only matches its input when height and
        # width are multiples of 32. Pad UP TO the next multiple (padding by
        # `h % 32` itself is only correct when the remainder is 0 or 16).
        pad_h = self._pad_to_multiple(h)
        pad_w = self._pad_to_multiple(w)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h), "replicate")

        x=self.knowledge_adaptation_branch(x)

        # Remove the padded rows/columns again.
        x = x[:, :, :h, :w]

        x = torch.cat([main_branch, x], 1)
        x = self.fusion(x)
        return x

# Training

In [None]:
# --- train --- #
# Number of epochs to run in THIS session (training is resumed from the saved
# checkpoint when one exists).
train_epoch = 500 # currently at 320 and should reach 820 after this
# Test PSNR that must be beaten before the "best" checkpoint is overwritten;
# set manually to the best value reached in earlier sessions.
best_psnr = 20.33

# Dataset locations on the mounted Google Drive.
TRAIN_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/O-HAZE/train/hazy/"
TRAIN_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/O-HAZE/train/GT/"
TEST_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/O-HAZE/test/hazy/"
TEST_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/O-HAZE/test/GT/"

# Image size handed to the training dataset (alternatives kept for reference).
# IMAGE_SIZE = (640,640)
# IMAGE_SIZE = (512,512)
IMAGE_SIZE = (384,384)
# IMAGE_SIZE = (256,256)
# RANDOM_CROP_SIZES = [256, 512]
# TRAIN_BATCH_SIZE = 2
TRAIN_BATCH_SIZE = 3
#TRAIN_BATCH_SIZE = 5

# NOTE(review): crop_size / overlap_size are not used anywhere in this cell.
crop_size = [512,512]
overlap_size = [256,256]

VAL_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
NUM_WORKERS = 0
SHUFFLE = True
# --- output picture and check point --- #
# Latest checkpoint (written every 5 epochs) and best-PSNR checkpoint.
G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/O-HAZE_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"
G_best_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/Best_O-HAZE_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"
# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
MyEnsembleNet = fusion_net()
# MyEnsembleNet = t_compile(MyEnsembleNet)
# for name, param in MyEnsembleNet.named_parameters():
#     if param.requires_grad and 'haze_density' in name:
#         param.requires_grad = False
# non_frozen_parameters = [p for p in MyEnsembleNet.parameters() if p.requires_grad]

print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))
# print('Nonfrozen parameters:', sum(param.numel() for param in non_frozen_parameters))

# --- Build optimizer --- #
# The optimizer is created before the model is moved to `device` further down;
# it keeps referring to the same parameter objects.
G_optimizer = torch.optim.Adam(MyEnsembleNet.parameters(), lr=0.0001)
# G_optimizer = torch.optim.Adam(MyEnsembleNet.parameters(), lr=0.0001 * 0.5)
# scheduler_G = torch.optim.lr_scheduler.MultiStepLR(G_optimizer, milestones=[3000, 5000, 6000], gamma=0.5)

# --- Load training data --- #
dataset = custom_dehaze_train_dataset(HAZY_path = TRAIN_HAZY_IMAGES_PATH, GT_path = TRAIN_GT_IMAGES_PATH, is_train = True, Image_Size = IMAGE_SIZE)
                                      # random_crop_sizes = RANDOM_CROP_SIZES, random_crops = True)
train_loader = DataLoader(dataset=dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=SHUFFLE)

# val_dataset = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
#                               GT_path = VAL_GT_IMAGES_PATH,
#                               image_size = IMAGE_SIZE,
#                               crop = True)

# val_loader = DataLoader(val_dataset,
#                           batch_size = VAL_BATCH_SIZE,
#                           num_workers = NUM_WORKERS,
#                           shuffle = False)

# --- Load testing data --- #
# Test images are neither resized nor cropped (resize = False, crop = False).
test_data = CustomDataLoader(HAZY_path = TEST_HAZY_IMAGES_PATH,
                            GT_path = TEST_GT_IMAGES_PATH,
                            # image_size = (768,1024),
                            crop = False,
                            resize = False)

test_loader = DataLoader(test_data,
                        batch_size = TEST_BATCH_SIZE,
                        num_workers = NUM_WORKERS)

MyEnsembleNet = MyEnsembleNet.to(device)
# --- Load the network weight --- #
# Resume from the latest checkpoint when possible. A failure is reported but
# does not stop the run (the first session has no checkpoint yet). Catching
# `Exception` rather than using a bare `except:` lets KeyboardInterrupt
# through, and printing the reason makes a corrupt or mismatching checkpoint
# visible before training restarts from scratch and later overwrites it.
try:
    # map_location allows a checkpoint saved on GPU to be restored on any device.
    MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir, map_location=device))
    # MyEnsembleNet.load_state_dict(torch.load("/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/Best_NH_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"), strict = False)
    print('--- weight loaded ---')
except Exception as load_error:
    print('--- no weight loaded ---')
    print('    reason:', repr(load_error))

# Metrics (also used as part of the loss) and the L1 reconstruction loss.
psnr = PSNR(data_range=1.0).to(device)
ssim = SSIM(data_range=1.0).to(device)
l1_loss = nn.L1Loss().to(device)

# Folder that receives the dehazed test images; created up front so the first
# evaluation cannot fail on a missing directory.
results_dir = '/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/results_O-HAZE/'
os.makedirs(results_dir, exist_ok=True)

# --- Start training --- #
for epoch in range(train_epoch):
    psnr_list = []
    ssim_list = []
    MyEnsembleNet.train()
    avg_loss = 0
    print("We are in epoch: " + str(epoch+1))

    for batch_idx, (hazy, clean) in enumerate(train_loader):
        hazy = hazy.to(device)
        clean = clean.to(device)
        output = MyEnsembleNet(hazy)
        l1_loss_ = l1_loss(output, clean)
        ssim_ = ssim(output, clean)
        psnr_ = psnr(output, clean)
        ssim_loss_ = 1 - ssim_
        # psnr_loss = 1 - psnr_ / 100
        calc_ssim = ssim_.item()
        calc_psnr = psnr_.item()
        # Loss = mean of the L1 error and the SSIM dissimilarity.
        total_loss = (l1_loss_ +  ssim_loss_) / 2
        # total_loss = (l1_loss_ +  ssim_loss_ + psnr_loss) / 3
        avg_loss += total_loss.item()
        # Clear the gradients of exactly the parameters the optimizer updates.
        G_optimizer.zero_grad()
        total_loss.backward()
        G_optimizer.step()
        # print('PSNR: ', calc_psnr, 'SSIM: ', calc_ssim, 'SSIM_Loss: ', ssim_loss_.item(), 'l1_loss: ', l1_loss_.item(), 'total_loss', total_loss.item())
        psnr_list.append(calc_psnr)
        ssim_list.append(calc_ssim)
        # del hazy, clean, output
        # gc.collect()

    avr_psnr = sum(psnr_list) / len(psnr_list)
    avr_ssim = sum(ssim_list) / len(ssim_list)
    print('AVG PSNR: ', avr_psnr, 'AVG SSIM: ', avr_ssim, 'AVG Loss: ', avg_loss / len(psnr_list))

    # with torch.inference_mode():
    #     print("-----Validating-----")
    #     psnr_list = []
    #     ssim_list = []
    #     MyEnsembleNet.eval()

    #     for batch_idx, (hazy, clean) in enumerate(val_loader):
    #         for i in range(len(hazy)):
    #             hazy[i] = hazy[i].to(device)
    #             clean[i] = clean[i].to(device)
    #             output = MyEnsembleNet(hazy[i])
    #             calc_psnr = to_psnr(output, clean[i])
    #             calc_ssim = to_ssim_skimage(output, clean[i])
    #             psnr_list.extend(calc_psnr)
    #             ssim_list.extend(calc_ssim)

    # avr_psnr = sum(psnr_list) / len(psnr_list)
    # avr_ssim = sum(ssim_list) / len(ssim_list)
    # print('AVG PSNR: ', avr_psnr, 'AVG SSIM: ', avr_ssim)

    # Every 5 epochs: evaluate on the test set, save the outputs and checkpoint.
    if (epoch+1) % 5 == 0:
        print("-----Testing-----")
        with torch.inference_mode():
            psnr_list = []
            ssim_list = []
            MyEnsembleNet.eval()
            for batch_idx, (hazy, clean, data_name) in enumerate(test_loader):
                clean = clean.to(device)
                hazy = hazy.to(device)
                frame_out = MyEnsembleNet(hazy)
                # The former `range=(0, 1)` keyword is dropped: it is a
                # deprecated alias of `value_range`, which is only consulted
                # when normalize=True, so the saved images are unchanged.
                imwrite(frame_out, results_dir + ''.join(data_name) + '.png')
                psnr_list.append(psnr(frame_out, clean).item())
                ssim_list.append(ssim(frame_out, clean).item())
        avr_psnr = sum(psnr_list) / len(psnr_list)
        avr_ssim = sum(ssim_list) / len(ssim_list)
        print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim)
        # Always refresh the "latest" checkpoint ...
        torch.save(MyEnsembleNet.state_dict(), G_model_save_dir)
        print("-----Model Saved-----")
        # ... and the "best" one only when the test PSNR improves.
        if(avr_psnr > best_psnr):
            best_psnr = avr_psnr
            torch.save(MyEnsembleNet.state_dict(), G_best_model_save_dir)
            print("-----Best Model Saved-----")

print('Best PSNR: ', best_psnr)

MyEnsembleNet parameters: 59512615
Dataset has: 40 images
Dataset has: 5 images
--- no weight loaded ---
We are in epoch: 1
AVG PSNR:  11.001591341836113 AVG SSIM:  0.17169991242034094 AVG Loss:  0.5297169621501651
We are in epoch: 2
AVG PSNR:  12.446638720376152 AVG SSIM:  0.3784032868487494 AVG Loss:  0.4062989034823009
We are in epoch: 3
AVG PSNR:  13.667108263288226 AVG SSIM:  0.4471408554485866 AVG Loss:  0.3623412123748234
We are in epoch: 4
AVG PSNR:  12.668119634900775 AVG SSIM:  0.43018369376659393 AVG Loss:  0.3868142089673451
We are in epoch: 5
AVG PSNR:  14.391375269208636 AVG SSIM:  0.48070192337036133 AVG Loss:  0.3427429028919765
-----Testing-----
PSNR:  14.28794460296631 SSIM:  0.4824086368083954
-----Model Saved-----
We are in epoch: 6
AVG PSNR:  15.364476203918457 AVG SSIM:  0.5425954737833568 AVG Loss:  0.3006124123930931
We are in epoch: 7
AVG PSNR:  15.807620184762138 AVG SSIM:  0.573413188968386 AVG Loss:  0.27995134464332033
We are in epoch: 8
AVG PSNR:  17.05919

# Testing (Full Dimensions)

In [None]:
# Full-resolution evaluation cell.
# Runs the network on every validation image without resizing or cropping,
# writes each output image to OUTPUT_DIR, and prints the mean PSNR, mean SSIM
# and mean forward-pass time in milliseconds.
import time

# --- Validation data: keep exactly one HAZY/GT pair uncommented --- #
# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/HAZY/"
VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/valid_dense/GT/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_compressed/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_compressed/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"

VAL_BATCH_SIZE = 1
NUM_WORKERS = 0

# Directory (relative to the working directory) that receives the output PNGs.
# NOTE: the other testing cells in this notebook write to the same directory.
OUTPUT_DIR = 'test3/'

# --- output picture and check point --- #
# G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v8/weights/Best_model_weights.pth"
G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/Best_DENSE_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"
# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
MyEnsembleNet = fusion_net()
print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# --- Load testing data --- #
val_data = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
                            GT_path = VAL_GT_IMAGES_PATH,
                            # image_size = (720,1280),
                            resize = False,
                            crop = False)

val_loader = DataLoader(val_data,
                        batch_size = VAL_BATCH_SIZE,
                        num_workers = NUM_WORKERS)

MyEnsembleNet = MyEnsembleNet.to(device)

# --- Load the network weight --- #
# Loading stays best-effort (the cell continues with the initial weights on
# failure), but the reason is printed so a bad path or a mismatched checkpoint
# cannot pass unnoticed.
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    state_dict = torch.load(G_model_save_dir, map_location=device)
    load_result = MyEnsembleNet.load_state_dict(state_dict, strict = False)
    # MyEnsembleNet = torch.load("/content/drive/MyDrive/Graduation Project/CANT_Haze_v8/weights2/full_model_test.pth")
    print('--- weight loaded ---')
    # strict=False skips mismatched keys without raising, so report them.
    if load_result.missing_keys or load_result.unexpected_keys:
        print('    missing keys:', len(load_result.missing_keys),
              '| unexpected keys:', len(load_result.unexpected_keys))
except Exception as err:
    print('--- no weight loaded ---')
    print('    reason:', type(err).__name__, '-', err)

psnr = PSNR(data_range=1.0).to(device)
ssim = SSIM(data_range=1.0).to(device)

# Create the output directory once instead of checking on every iteration.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Start testing --- #
print("-----Testing-----")
with torch.inference_mode():
    psnr_list = []
    ssim_list = []
    avg_time = 0
    MyEnsembleNet.eval()
    for batch_idx, (hazy, clean, data_name) in enumerate(val_loader):
        clean = clean.to(device)
        hazy = hazy.to(device)
        # CUDA kernels run asynchronously; synchronise before reading the
        # clock so the measured interval covers the whole forward pass.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        frame_out = MyEnsembleNet(hazy)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        end = time.perf_counter()
        avg_time += (end - start) * 10 ** 3  # seconds -> milliseconds
        # The former `range=(0, 1)` argument is dropped: torchvision only uses
        # it when normalize=True and it was deprecated in favour of
        # `value_range`, so the saved image is unchanged.
        imwrite(frame_out, OUTPUT_DIR + ''.join(data_name) + '.png')
        psnr_list.append(psnr(frame_out, clean).item())
        ssim_list.append(ssim(frame_out, clean).item())

if not psnr_list:
    raise RuntimeError('No images were evaluated (val_loader is empty); '
                       'check VAL_HAZY_IMAGES_PATH and VAL_GT_IMAGES_PATH.')

avr_psnr = sum(psnr_list) / len(psnr_list)
avr_ssim = sum(ssim_list) / len(ssim_list)
avg_time /= len(psnr_list)
print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim, 'Time: ', avg_time,'(ms)')

# Testing (Custom Shaped Crops)

In [None]:
# Patch-based evaluation cell (custom crop size with overlap).
# Each validation image is split by custom_patchify, every patch is passed
# through the network, and custom_unpatchify reassembles the result. The
# output image is written to OUTPUT_DIR, and the mean PSNR, mean SSIM and mean
# per-image time (patchify + inference + unpatchify, in ms) are printed.
import time

crop_size = [512,512]
overlap_size = [256,256]

# --- Validation data: keep exactly one HAZY/GT pair uncommented --- #
VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_original/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_original/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"

VAL_BATCH_SIZE = 1
NUM_WORKERS = 0

# Directory (relative to the working directory) that receives the output PNGs.
# NOTE: the other testing cells in this notebook write to the same directory.
OUTPUT_DIR = 'test3/'

# --- output picture and check point --- #
G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/Best_NH_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"
# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
MyEnsembleNet = fusion_net()
print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# --- Load testing data --- #
val_data = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
                            GT_path = VAL_GT_IMAGES_PATH,
                            # image_size = (720,1280),
                            # resize = True,
                            crop = False)

val_loader = DataLoader(val_data,
                        batch_size = VAL_BATCH_SIZE,
                        num_workers = NUM_WORKERS)

MyEnsembleNet = MyEnsembleNet.to(device)

# --- Load the network weight --- #
# Loading stays best-effort (the cell continues with the initial weights on
# failure), but the reason is printed so a bad path or a mismatched checkpoint
# cannot pass unnoticed.
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir, map_location=device))
    print('--- weight loaded ---')
except Exception as err:
    print('--- no weight loaded ---')
    print('    reason:', type(err).__name__, '-', err)

psnr = PSNR(data_range=1.0).to(device)
ssim = SSIM(data_range=1.0).to(device)

# Create the output directory once instead of checking on every iteration.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Start testing --- #
print("-----Testing-----")
with torch.inference_mode():
    psnr_list = []
    ssim_list = []
    avg_time = 0
    MyEnsembleNet.eval()
    for batch_idx, (hazy, clean, data_name) in enumerate(val_loader):
        clean = clean.to(device)
        hazy = hazy.to(device)
        # CUDA kernels run asynchronously; synchronise before reading the
        # clock so the measured interval covers all of the work below.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        patch_list, crops_per_row, crops_per_col = custom_patchify(hazy, crop_size, overlap_size)

        # Replace every patch in the grid with the network output for it.
        for i in range(crops_per_row):
            for j in range(crops_per_col):
                patch_list[i][j] = MyEnsembleNet(patch_list[i][j])

        frame_out = custom_unpatchify(patch_list, overlap_size, crops_per_row, crops_per_col)

        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        end = time.perf_counter()
        avg_time += (end - start) * 10 ** 3  # seconds -> milliseconds
        # The former `range=(0, 1)` argument is dropped: torchvision only uses
        # it when normalize=True and it was deprecated in favour of
        # `value_range`, so the saved image is unchanged.
        imwrite(frame_out, OUTPUT_DIR + ''.join(data_name) + '.png')
        psnr_list.append(psnr(frame_out, clean).item())
        ssim_list.append(ssim(frame_out, clean).item())

if not psnr_list:
    raise RuntimeError('No images were evaluated (val_loader is empty); '
                       'check VAL_HAZY_IMAGES_PATH and VAL_GT_IMAGES_PATH.')

avr_psnr = sum(psnr_list) / len(psnr_list)
avr_ssim = sum(ssim_list) / len(ssim_list)
avg_time /= len(psnr_list)
print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim, 'Time: ', avg_time,'(ms)')

# Testing (Grid Shaped Crops)

In [None]:
# Patch-based evaluation cell (grid-shaped crops).
# Each validation image is split by grid_patchify into two lists of patches,
# every patch is passed through the network, and unpatchify reassembles the
# result. The output image is written to OUTPUT_DIR, and the mean PSNR, mean
# SSIM and mean per-image time (patchify + inference + unpatchify, in ms) are
# printed.
import time

# --- Validation data: keep exactly one HAZY/GT pair uncommented --- #
VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_Hazy/"
VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NH-HAZE/Test_GT/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_original/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/Real_original/"

# VAL_HAZY_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"
# VAL_GT_IMAGES_PATH = "/content/drive/MyDrive/Graduation Project/data/NTIRE23/TEST/"

VAL_BATCH_SIZE = 1
NUM_WORKERS = 0

# Directory (relative to the working directory) that receives the output PNGs.
# NOTE: the other testing cells in this notebook write to the same directory.
OUTPUT_DIR = 'test3/'

# --- output picture and check point --- #
G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v10/weights/Best_NH_BRB_DWT_RESBOTNCK_RESTBLOCK_DECONV_REFINE_XCEPTION.pth"
# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
MyEnsembleNet = fusion_net()
print('MyEnsembleNet parameters:', sum(param.numel() for param in MyEnsembleNet.parameters()))

# --- Load testing data --- #
val_data = CustomDataLoader(HAZY_path = VAL_HAZY_IMAGES_PATH,
                            GT_path = VAL_GT_IMAGES_PATH,
                            # image_size = (720,1280),
                            # resize = True,
                            crop = False)

val_loader = DataLoader(val_data,
                        batch_size = VAL_BATCH_SIZE,
                        num_workers = NUM_WORKERS)

MyEnsembleNet = MyEnsembleNet.to(device)

# --- Load the network weight --- #
# Loading stays best-effort (the cell continues with the initial weights on
# failure), but the reason is printed so a bad path or a mismatched checkpoint
# cannot pass unnoticed.
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    MyEnsembleNet.load_state_dict(torch.load(G_model_save_dir, map_location=device))
    print('--- weight loaded ---')
except Exception as err:
    print('--- no weight loaded ---')
    print('    reason:', type(err).__name__, '-', err)

psnr = PSNR(data_range=1.0).to(device)
ssim = SSIM(data_range=1.0).to(device)

# Create the output directory once instead of checking on every iteration.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Start testing --- #
print("-----Testing-----")
with torch.inference_mode():
    psnr_list = []
    ssim_list = []
    avg_time = 0
    MyEnsembleNet.eval()
    for batch_idx, (hazy, clean, data_name) in enumerate(val_loader):
        clean = clean.to(device)
        hazy = hazy.to(device)
        # CUDA kernels run asynchronously; synchronise before reading the
        # clock so the measured interval covers all of the work below.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        # NOTE(review): the meaning of the second argument (2) is defined by
        # grid_patchify elsewhere in the notebook - confirm before changing it.
        patch_list_T, patch_list_B = grid_patchify(hazy, 2)

        # Both lists are indexed in parallel; element 0 of each network output
        # is kept (presumably dropping the batch dimension, since
        # VAL_BATCH_SIZE is 1 - confirm against grid_patchify/unpatchify).
        for x in range(len(patch_list_T)):
            patch_list_T[x] = MyEnsembleNet(patch_list_T[x])[0]
            patch_list_B[x] = MyEnsembleNet(patch_list_B[x])[0]

        # unsqueeze(0) adds a leading dimension back before saving/scoring.
        frame_out = unpatchify(patch_list_T, patch_list_B).unsqueeze(0)

        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        end = time.perf_counter()
        avg_time += (end - start) * 10 ** 3  # seconds -> milliseconds
        # The former `range=(0, 1)` argument is dropped: torchvision only uses
        # it when normalize=True and it was deprecated in favour of
        # `value_range`, so the saved image is unchanged.
        imwrite(frame_out, OUTPUT_DIR + ''.join(data_name) + '.png')
        psnr_list.append(psnr(frame_out, clean).item())
        ssim_list.append(ssim(frame_out, clean).item())

if not psnr_list:
    raise RuntimeError('No images were evaluated (val_loader is empty); '
                       'check VAL_HAZY_IMAGES_PATH and VAL_GT_IMAGES_PATH.')

avr_psnr = sum(psnr_list) / len(psnr_list)
avr_ssim = sum(ssim_list) / len(ssim_list)
avg_time /= len(psnr_list)
print('PSNR: ', avr_psnr, 'SSIM: ', avr_ssim, 'Time: ', avg_time,'(ms)')

# Edge Impulse

In [None]:
# #ONNX_ML=1
# import torch
# !pip install onnx
# # !pip install onnxruntime
# # !pip install onnx_tf
# import onnx
# # from onnx_tf.backend import prepare
# # import onnxruntime as ort

# G_model_save_dir = "/content/drive/MyDrive/Graduation Project/CANT_Haze_v9/weights/Best_NH_BRB_DWT_RESBOTNCK_RESTBLOCK_LEAKYRELU_DEEP.pth"
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# trained_model = CANT_HAZE().to(device)
# trained_model.load_state_dict(torch.load(G_model_save_dir))
# trained_model.eval()
# with torch.inference_mode():

#   dummy_input = torch.randn(3, 1200, 1600).unsqueeze(0).to(device) # define a random input example image

#   torch.onnx.export(trained_model,
#                     dummy_input,
#                     "fixed_model_opv9.onnx",
#                     export_params=True,
#                     do_constant_folding=True,
#                     opset_version=9,)
#                     # input_names=['input'],
#                     # output_names=['output'],
#                     # dynamic_axes={'input': [2, 3], 'output' : [2, 3]}) # convert pytorch to ONNX

#   model = onnx.load('fixed_model_opv9.onnx') # Load ONNX model

#   # model = ort.InferenceSession("model.onnx") # Load ONNX model using onnx runtime

# # model = prepare(model) # convert ONNX to tensorflow

In [None]:
# !pip install edgeimpulse
# import edgeimpulse as ei
# model = onnx.load('model.onnx') # Load ONNX model

# ei.API_KEY = os.environ["EI_API_KEY"]  # SECURITY: literal key redacted; never commit credentials, and revoke the key that was previously committed here
# for device_name in ei.model.list_profile_devices():
#   print(f"------------ Profiling for device {device_name} ------------")
#   try:

#     profile = ei.model.profile(model = model, device = device_name)
#     print(profile.summary())
#   except Exception as e:
#     print(f"Could not profile {device_name}: {e}")
#   print("------------------------------------\n")

# Weight Extraction

## Edit weights

In [None]:
# import torch
# from collections import OrderedDict

# # load the model
# model = torch.load('/content/NH_MSRB_DWT2D.pth')

# # separate backbone layers
# new_state_dict = {}
# for name, param in model.items():
#     # print(name)
#     if 'dwt.layer3' not in name and 'encoder.SK' not in name:
#         #extract certain name
#         # new_state_dict[name.replace('dwt_branch.','')] = param
#         #remove certain parameter
#         new_state_dict[name] = param

## Save weights

In [None]:
# new_state_dict = OrderedDict(new_state_dict)

# # save the new model which has only backbone layers, only uncomment when needed
# ###torch.save(new_state_dict, '/content/NH_MSRB_DWT2D.pth')