In [1]:
import os
import shutil
import yaml
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import utils_tool

from basicsr.models.archs.kbnet_s_arch import KBNet_s
from basicsr.utils.util import patch_forward

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

In [2]:
checkpoint = torch.load("pretrained_models/gaussian_gray_denoising_sigma15.pth")
checkpoint

{'params': OrderedDict([('patch_embed.proj.weight',
               tensor([[[[-0.1712,  0.2540, -0.1209],
                         [ 0.0261,  0.1423,  0.1935],
                         [-0.1297, -0.2948, -0.1576]]],
               
               
                       [[[-0.3576, -0.3163,  0.0523],
                         [ 0.0128,  0.1877, -0.0780],
                         [ 0.2454,  0.3798,  0.1534]]],
               
               
                       [[[-0.2436, -0.2267, -0.3085],
                         [ 0.1894,  0.0574,  0.2889],
                         [-0.0174,  0.3591,  0.1546]]],
               
               
                       [[[ 0.0851, -0.0938,  0.1409],
                         [ 0.0965, -0.1214,  0.2261],
                         [-0.0075, -0.2692, -0.1215]]],
               
               
                       [[[ 0.3418,  0.1887, -0.1772],
                         [ 0.3388,  0.0223, -0.3464],
                         [ 0.0638, -0.1782, -0.1764]]],


In [4]:
checkpoint['params'].keys()

odict_keys(['patch_embed.proj.weight', 'encoder_level1.0.norm1.body.weight', 'encoder_level1.0.attn.temperature', 'encoder_level1.0.attn.qkv.weight', 'encoder_level1.0.attn.qkv_dwconv.weight', 'encoder_level1.0.attn.project_out.weight', 'encoder_level1.0.norm2.body.weight', 'encoder_level1.0.ffn.project_in.weight', 'encoder_level1.0.ffn.dwconv.weight', 'encoder_level1.0.ffn.project_out.weight', 'encoder_level1.1.norm1.body.weight', 'encoder_level1.1.attn.temperature', 'encoder_level1.1.attn.qkv.weight', 'encoder_level1.1.attn.qkv_dwconv.weight', 'encoder_level1.1.attn.project_out.weight', 'encoder_level1.1.norm2.body.weight', 'encoder_level1.1.ffn.project_in.weight', 'encoder_level1.1.ffn.dwconv.weight', 'encoder_level1.1.ffn.project_out.weight', 'encoder_level1.2.norm1.body.weight', 'encoder_level1.2.attn.temperature', 'encoder_level1.2.attn.qkv.weight', 'encoder_level1.2.attn.qkv_dwconv.weight', 'encoder_level1.2.attn.project_out.weight', 'encoder_level1.2.norm2.body.weight', 'encode

In [5]:
checkpoint['params']['type']

KeyError: 'type'

In [35]:
from torchinfo import summary

model = KBNet_s(
    width=16,
    middle_blk_num=7, 
    # enc_blk_nums=[1, 1, 2, 4],
    # dec_blk_nums=[1, 1, 1, 1],
    lightweight=True
    )
summary(model)

Layer (type:depth-idx)                             Param #
KBNet_s                                            --
├─Conv2d: 1-1                                      448
├─ModuleList: 1-2                                  --
│    └─Sequential: 2-1                             --
│    │    └─KBBlock_s: 3-1                         22,304
│    │    └─KBBlock_s: 3-2                         22,304
│    └─Sequential: 2-2                             --
│    │    └─KBBlock_s: 3-3                         48,096
│    │    └─KBBlock_s: 3-4                         48,096
│    └─Sequential: 2-3                             --
│    │    └─KBBlock_s: 3-5                         109,888
│    │    └─KBBlock_s: 3-6                         109,888
│    │    └─KBBlock_s: 3-7                         109,888
│    │    └─KBBlock_s: 3-8                         109,888
│    └─Sequential: 2-4                             --
│    │    └─KBBlock_s: 3-9                         276,480
│    │    └─KBBlock_s: 3-10        

In [36]:
model = KBNet_s()
summary(model)

Layer (type:depth-idx)                             Param #
KBNet_s                                            --
├─Conv2d: 1-1                                      1,792
├─ModuleList: 1-2                                  --
│    └─Sequential: 2-1                             --
│    │    └─KBBlock_s: 3-1                         115,712
│    │    └─KBBlock_s: 3-2                         115,712
│    └─Sequential: 2-2                             --
│    │    └─KBBlock_s: 3-3                         288,128
│    │    └─KBBlock_s: 3-4                         288,128
│    └─Sequential: 2-3                             --
│    │    └─KBBlock_s: 3-5                         804,992
│    │    └─KBBlock_s: 3-6                         804,992
│    │    └─KBBlock_s: 3-7                         804,992
│    │    └─KBBlock_s: 3-8                         804,992
│    └─Sequential: 2-4                             --
│    │    └─KBBlock_s: 3-9                         2,526,848
│    │    └─KBBlock_s: 3-10