In [2]:
import random
import numpy as np
import cv2
import os
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from os.path import join
from os import listdir
from torchsummary import summary
import time
import argparse
import models.DnCNN as DnCNN, models.ResNet as ResNet, models.RFDN as RFDN, models.ResNet_ED as ResNetED, models.DRLN as DRLN, models.pix2pix as pix2pix
import models.restormer_arch as restormer
from models.network_swinir import SwinIR as net
from models.swin_transformer_v2 import SwinTransformerV2 as net2
from models.kbnet_s_arch import KBNet_s
from models.kbnet_l_arch import KBNet_l
from models.restormer_arch import Restormer
from utils.param import param_check, seed_everything

# Names (keys of `models`) whose parameter counts and summaries are printed.
# Other models available but currently disabled:
#   'DnCNN', 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152',
#   'RFDN', 'ResNetED', 'DRLN', 'pix2pix', 'SwinIR32', 'SwinIRv232',
#   'restormer'
model_list = [
    'SwinIR128',
    'SwinIR64',
    'KBNet_s',
    'KBNet_l',
    'Restormer',
]
# Registry of candidate models, keyed by the names used in `model_list`.
# NOTE: every constructor below runs when this literal is evaluated, so all of
# these networks are instantiated here, including those not enabled in
# `model_list`. 'pix2pix' has no entry; its generator/discriminator are built
# separately by the reporting loop.
models = {'DnCNN': DnCNN.DnCNN(), 
          'ResNet18': ResNet.ResNet18(), 
          'ResNet34': ResNet.ResNet34(), 
          'ResNet50': ResNet.ResNet50(), 
          'ResNet101': ResNet.ResNet101(), 
          'ResNet152': ResNet.ResNet152(), 
          'RFDN': RFDN.RFDN(),
          # ResNet18 constructor from the models.ResNet_ED module (not models.ResNet).
          'ResNetED' : ResNetED.ResNet18(),
          'DRLN' : DRLN.DRLN(),
          # SwinIR variants with upscale=1 (same-resolution output); the name
          # suffix is the img_size passed to the constructor.
          'SwinIR128' : net(upscale=1, in_chans=3, img_size=128, window_size=8, \
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], \
                    mlp_ratio=2, upsampler='', resi_connection='1conv'),
          # NOTE(review): depths has 5 stages but num_heads lists 6 values;
          # presumably SwinIR reads one head count per stage and ignores the
          # extra one -- confirm which stage count is intended.
          'SwinIR64' : net(upscale=1, in_chans=3, img_size=64, window_size=8, \
                    img_range=1., depths=[6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], \
                    mlp_ratio=2, upsampler='', resi_connection='1conv'),
          'SwinIR32' : net(upscale=1, in_chans=3, img_size=32, window_size=8, \
                    img_range=1., depths=[6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6], \
                    mlp_ratio=2, upsampler='', resi_connection='1conv'),
          # Same Restormer class as the 'Restormer' entry below (restormer_arch),
          # with default arguments except num_refinement_blocks.
          'restormer' : restormer.Restormer(num_refinement_blocks=4),
          # NOTE(review): despite the "v2" in its name this uses `net` (SwinIR),
          # not `net2` (SwinTransformerV2), and its arguments are identical to
          # 'SwinIR32' -- confirm this is intended.
          'SwinIRv232' : net(upscale=1, in_chans=3, img_size=32, window_size=8, \
                    img_range=1., depths=[6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6], \
                    mlp_ratio=2, upsampler='', resi_connection='1conv'),
          'KBNet_s' : KBNet_s(middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8], dec_blk_nums=[2, 2, 2, 2],lightweight=False), 
          # NOTE(review): built with the KBNet_s class (lightweight=True, fewer
          # blocks) even though KBNet_l is imported -- confirm whether "l" means
          # "lightweight" here or whether KBNet_l was intended.
          'KBNet_l' : KBNet_s(middle_blk_num=1, enc_blk_nums=[2, 2, 4], dec_blk_nums=[2, 2, 2],lightweight=True),
          # Custom Restormer configuration (dim=30, per-level block/head counts).
          'Restormer' : Restormer(dim = 30, num_blocks = [2,3,6,8], num_refinement_blocks = 4, heads = [1,2,4,8])
          }

def param_print(model, input_size=(3, 128, 128)):
    """Report parameter counts and a layer-by-layer summary of *model*.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_size: ``(channels, height, width)`` of the dummy input fed to
            ``torchsummary.summary`` (no batch dimension). Defaults to the
            3x128x128 size used for the non-pix2pix models in this script.

    Returns:
        None; all results are printed to stdout.
    """
    # First count every parameter, then (second flag) only the gradient-tracking
    # ones, judging by the "grad인 parameter" header param_check prints.
    param_check(model)
    param_check(model, True)

    # torchsummary runs a forward pass on a dummy input that it places according
    # to its `device` argument (default 'cuda'). That input must live on the same
    # device as the model's weights, so derive it from the model rather than
    # relying on the default, which fails for a CPU model on a CUDA machine.
    first_param = next(model.parameters(), None)
    on_cuda = first_param is not None and first_param.device.type == 'cuda'

    # summary() prints the table itself; printing its return value would only
    # append a stray line (e.g. "None") after the table.
    summary(model, input_size, device='cuda' if on_cuda else 'cpu')

# Preferred compute device: CUDA, then Apple MPS, then CPU.
# NOTE(review): `device` is not used in this cell; the models above are never
# moved to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'mps:0' if torch.backends.mps.is_available() else 'cpu')

# torchsummary's dummy input must sit on the same device as the model. Every
# model reported here is constructed on the CPU and never moved, so request a
# CPU input explicitly; torchsummary's 'cuda' default breaks as soon as a GPU
# is present.
summary_device = 'cpu'

for m in model_list:
    # Header: "<name> 모델은 다음과 같다." = "the <name> model is as follows".
    print('\n\n',m,' 모델은 다음과 같다.')
    if m == 'pix2pix':
        # pix2pix is reported as generator + discriminator together.
        G = pix2pix.Generator()
        D = pix2pix.Discriminator()
        # '총' = "total". NOTE(review): assumes param_check returns a numeric
        # count that can be summed -- confirm against utils.param.
        print('총 : ',param_check(G) + param_check(D))
        print('총 : ',param_check(G, True) + param_check(D, True))
        # summary() prints its own table; do not print its return value.
        summary(G, (3, 256, 256), device=summary_device)
        # 6 input channels -- presumably two 3-channel images stacked; confirm
        # against pix2pix.Discriminator.
        summary(D, (6, 256, 256), device=summary_device)
    else:
        model = models[m]
        param_check(model)        # all parameters
        param_check(model, True)  # gradient-tracking parameters only
        summary(model, (3, 128, 128), device=summary_device)



 SwinIR128  모델은 다음과 같다.
모든 parameter 개수기준임
Number of parameters: 11504163
!!!!!!10M 1천만보다 1504163개 초과했음!!!!!!
grad인 parameter 개수기준임
Number of parameters: 11504163
!!!!!!10M 1천만보다 1504163개 초과했음!!!!!!
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 180, 128, 128]           5,040
         LayerNorm-2           [-1, 16384, 180]             360
        PatchEmbed-3           [-1, 16384, 180]               0
           Dropout-4           [-1, 16384, 180]               0
         LayerNorm-5           [-1, 16384, 180]             360
            Linear-6              [-1, 64, 540]          97,740
           Softmax-7            [-1, 6, 64, 64]               0
           Dropout-8            [-1, 6, 64, 64]               0
            Linear-9              [-1, 64, 180]          32,580
          Dropout-10              [-1, 64, 180]               0
  WindowAttention-11          