In [1]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import cv2
from pathlib import Path
from tqdm import tqdm
from torchvision import transforms

In [15]:
mirnetv2_repo_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'MIRNetv2'))
if mirnetv2_repo_path not in sys.path:
    sys.path.insert(0, mirnetv2_repo_path)

In [16]:
from basicsr.archs.mirnet_v2_arch import MIRNet_v2

In [17]:
data_root  = Path("../data/patches")
hr_dir = data_root / "HR_patches"
lr_dir = data_root / "mixed_LR_patches"

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
patch_size = 192
batch_size = 8
num_epochs = 10
learning_rate = 1e-5

In [19]:
class PatchDataset(Dataset):
    def __init__(self,lr_dir,hr_dir):
        self.lr_dir = Path(lr_dir)
        self.hr_dir = Path(hr_dir)
        self.files = list(self.lr_dir.glob("*.png"))
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        lr_path = self.files[idx]
        hr_path = Path(self.hr_dir) / lr_path.name
        lr_img = cv2.imread(str(lr_path))
        hr_img = cv2.imread(str(hr_path))
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB)
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
        lr_tensor = self.to_tensor(lr_img)
        hr_tensor = self.to_tensor(hr_img)
        return lr_tensor, hr_tensor

In [20]:
dataset = PatchDataset(lr_dir,hr_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [21]:
model = MIRNet_v2().to(device)

In [27]:
pretrained_path = "../MIRNetv2/pretrained_models/enhancement_fivek.pth"
state_dict = torch.load(pretrained_path, map_location=device)
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for MIRNet_v2:
	Missing key(s) in state_dict: "conv_in.weight", "body.0.body.0.dau_top.body.0.weight", "body.0.body.0.dau_top.body.2.weight", "body.0.body.0.dau_top.gcnet.conv_mask.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_mid.body.0.weight", "body.0.body.0.dau_mid.body.2.weight", "body.0.body.0.dau_mid.gcnet.conv_mask.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_bot.body.0.weight", "body.0.body.0.dau_bot.body.2.weight", "body.0.body.0.dau_bot.gcnet.conv_mask.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.0.down2.body.0.bot.1.weight", "body.0.body.0.down4.0.body.0.bot.1.weight", "body.0.body.0.down4.1.body.0.bot.1.weight", "body.0.body.0.up21_1.body.0.bot.0.weight", "body.0.body.0.up21_2.body.0.bot.0.weight", "body.0.body.0.up32_1.body.0.bot.0.weight", "body.0.body.0.up32_2.body.0.bot.0.weight", "body.0.body.0.conv_out.weight", "body.0.body.0.skff_top.conv_du.0.weight", "body.0.body.0.skff_top.fcs.0.weight", "body.0.body.0.skff_top.fcs.1.weight", "body.0.body.0.skff_mid.conv_du.0.weight", "body.0.body.0.skff_mid.fcs.0.weight", "body.0.body.0.skff_mid.fcs.1.weight", "body.0.body.1.dau_top.body.0.weight", "body.0.body.1.dau_top.body.2.weight", "body.0.body.1.dau_top.gcnet.conv_mask.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_mid.body.0.weight", "body.0.body.1.dau_mid.body.2.weight", "body.0.body.1.dau_mid.gcnet.conv_mask.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_bot.body.0.weight", "body.0.body.1.dau_bot.body.2.weight", "body.0.body.1.dau_bot.gcnet.conv_mask.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.1.down2.body.0.bot.1.weight", "body.0.body.1.down4.0.body.0.bot.1.weight", "body.0.body.1.down4.1.body.0.bot.1.weight", "body.0.body.1.up21_1.body.0.bot.0.weight", "body.0.body.1.up21_2.body.0.bot.0.weight", "body.0.body.1.up32_1.body.0.bot.0.weight", "body.0.body.1.up32_2.body.0.bot.0.weight", "body.0.body.1.conv_out.weight", "body.0.body.1.skff_top.conv_du.0.weight", "body.0.body.1.skff_top.fcs.0.weight", "body.0.body.1.skff_top.fcs.1.weight", "body.0.body.1.skff_mid.conv_du.0.weight", "body.0.body.1.skff_mid.fcs.0.weight", "body.0.body.1.skff_mid.fcs.1.weight", "body.0.body.2.weight", "body.1.body.0.dau_top.body.0.weight", "body.1.body.0.dau_top.body.2.weight", "body.1.body.0.dau_top.gcnet.conv_mask.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_mid.body.0.weight", "body.1.body.0.dau_mid.body.2.weight", "body.1.body.0.dau_mid.gcnet.conv_mask.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_bot.body.0.weight", "body.1.body.0.dau_bot.body.2.weight", "body.1.body.0.dau_bot.gcnet.conv_mask.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.0.down2.body.0.bot.1.weight", "body.1.body.0.down4.0.body.0.bot.1.weight", "body.1.body.0.down4.1.body.0.bot.1.weight", "body.1.body.0.up21_1.body.0.bot.0.weight", "body.1.body.0.up21_2.body.0.bot.0.weight", "body.1.body.0.up32_1.body.0.bot.0.weight", "body.1.body.0.up32_2.body.0.bot.0.weight", "body.1.body.0.conv_out.weight", "body.1.body.0.skff_top.conv_du.0.weight", "body.1.body.0.skff_top.fcs.0.weight", "body.1.body.0.skff_top.fcs.1.weight", "body.1.body.0.skff_mid.conv_du.0.weight", "body.1.body.0.skff_mid.fcs.0.weight", "body.1.body.0.skff_mid.fcs.1.weight", "body.1.body.1.dau_top.body.0.weight", "body.1.body.1.dau_top.body.2.weight", "body.1.body.1.dau_top.gcnet.conv_mask.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_mid.body.0.weight", "body.1.body.1.dau_mid.body.2.weight", "body.1.body.1.dau_mid.gcnet.conv_mask.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_bot.body.0.weight", "body.1.body.1.dau_bot.body.2.weight", "body.1.body.1.dau_bot.gcnet.conv_mask.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.1.down2.body.0.bot.1.weight", "body.1.body.1.down4.0.body.0.bot.1.weight", "body.1.body.1.down4.1.body.0.bot.1.weight", "body.1.body.1.up21_1.body.0.bot.0.weight", "body.1.body.1.up21_2.body.0.bot.0.weight", "body.1.body.1.up32_1.body.0.bot.0.weight", "body.1.body.1.up32_2.body.0.bot.0.weight", "body.1.body.1.conv_out.weight", "body.1.body.1.skff_top.conv_du.0.weight", "body.1.body.1.skff_top.fcs.0.weight", "body.1.body.1.skff_top.fcs.1.weight", "body.1.body.1.skff_mid.conv_du.0.weight", "body.1.body.1.skff_mid.fcs.0.weight", "body.1.body.1.skff_mid.fcs.1.weight", "body.1.body.2.weight", "body.2.body.0.dau_top.body.0.weight", "body.2.body.0.dau_top.body.2.weight", "body.2.body.0.dau_top.gcnet.conv_mask.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_mid.body.0.weight", "body.2.body.0.dau_mid.body.2.weight", "body.2.body.0.dau_mid.gcnet.conv_mask.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_bot.body.0.weight", "body.2.body.0.dau_bot.body.2.weight", "body.2.body.0.dau_bot.gcnet.conv_mask.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.0.down2.body.0.bot.1.weight", "body.2.body.0.down4.0.body.0.bot.1.weight", "body.2.body.0.down4.1.body.0.bot.1.weight", "body.2.body.0.up21_1.body.0.bot.0.weight", "body.2.body.0.up21_2.body.0.bot.0.weight", "body.2.body.0.up32_1.body.0.bot.0.weight", "body.2.body.0.up32_2.body.0.bot.0.weight", "body.2.body.0.conv_out.weight", "body.2.body.0.skff_top.conv_du.0.weight", "body.2.body.0.skff_top.fcs.0.weight", "body.2.body.0.skff_top.fcs.1.weight", "body.2.body.0.skff_mid.conv_du.0.weight", "body.2.body.0.skff_mid.fcs.0.weight", "body.2.body.0.skff_mid.fcs.1.weight", "body.2.body.1.dau_top.body.0.weight", "body.2.body.1.dau_top.body.2.weight", "body.2.body.1.dau_top.gcnet.conv_mask.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_mid.body.0.weight", "body.2.body.1.dau_mid.body.2.weight", "body.2.body.1.dau_mid.gcnet.conv_mask.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_bot.body.0.weight", "body.2.body.1.dau_bot.body.2.weight", "body.2.body.1.dau_bot.gcnet.conv_mask.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.1.down2.body.0.bot.1.weight", "body.2.body.1.down4.0.body.0.bot.1.weight", "body.2.body.1.down4.1.body.0.bot.1.weight", "body.2.body.1.up21_1.body.0.bot.0.weight", "body.2.body.1.up21_2.body.0.bot.0.weight", "body.2.body.1.up32_1.body.0.bot.0.weight", "body.2.body.1.up32_2.body.0.bot.0.weight", "body.2.body.1.conv_out.weight", "body.2.body.1.skff_top.conv_du.0.weight", "body.2.body.1.skff_top.fcs.0.weight", "body.2.body.1.skff_top.fcs.1.weight", "body.2.body.1.skff_mid.conv_du.0.weight", "body.2.body.1.skff_mid.fcs.0.weight", "body.2.body.1.skff_mid.fcs.1.weight", "body.2.body.2.weight", "body.3.body.0.dau_top.body.0.weight", "body.3.body.0.dau_top.body.2.weight", "body.3.body.0.dau_top.gcnet.conv_mask.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_mid.body.0.weight", "body.3.body.0.dau_mid.body.2.weight", "body.3.body.0.dau_mid.gcnet.conv_mask.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_bot.body.0.weight", "body.3.body.0.dau_bot.body.2.weight", "body.3.body.0.dau_bot.gcnet.conv_mask.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.0.down2.body.0.bot.1.weight", "body.3.body.0.down4.0.body.0.bot.1.weight", "body.3.body.0.down4.1.body.0.bot.1.weight", "body.3.body.0.up21_1.body.0.bot.0.weight", "body.3.body.0.up21_2.body.0.bot.0.weight", "body.3.body.0.up32_1.body.0.bot.0.weight", "body.3.body.0.up32_2.body.0.bot.0.weight", "body.3.body.0.conv_out.weight", "body.3.body.0.skff_top.conv_du.0.weight", "body.3.body.0.skff_top.fcs.0.weight", "body.3.body.0.skff_top.fcs.1.weight", "body.3.body.0.skff_mid.conv_du.0.weight", "body.3.body.0.skff_mid.fcs.0.weight", "body.3.body.0.skff_mid.fcs.1.weight", "body.3.body.1.dau_top.body.0.weight", "body.3.body.1.dau_top.body.2.weight", "body.3.body.1.dau_top.gcnet.conv_mask.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_mid.body.0.weight", "body.3.body.1.dau_mid.body.2.weight", "body.3.body.1.dau_mid.gcnet.conv_mask.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_bot.body.0.weight", "body.3.body.1.dau_bot.body.2.weight", "body.3.body.1.dau_bot.gcnet.conv_mask.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.1.down2.body.0.bot.1.weight", "body.3.body.1.down4.0.body.0.bot.1.weight", "body.3.body.1.down4.1.body.0.bot.1.weight", "body.3.body.1.up21_1.body.0.bot.0.weight", "body.3.body.1.up21_2.body.0.bot.0.weight", "body.3.body.1.up32_1.body.0.bot.0.weight", "body.3.body.1.up32_2.body.0.bot.0.weight", "body.3.body.1.conv_out.weight", "body.3.body.1.skff_top.conv_du.0.weight", "body.3.body.1.skff_top.fcs.0.weight", "body.3.body.1.skff_top.fcs.1.weight", "body.3.body.1.skff_mid.conv_du.0.weight", "body.3.body.1.skff_mid.fcs.0.weight", "body.3.body.1.skff_mid.fcs.1.weight", "body.3.body.2.weight", "conv_out.weight". 
	Unexpected key(s) in state_dict: "params". 

In [35]:
# First, make sure these imports are at the top of your script
import os
import cv2
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from skimage.util import img_as_ubyte
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from runpy import run_path

# Define task and parameters
task = 'deblurring'  # Make sure this variable is defined
parameters = {  # Define the parameters dictionary with required model parameters
    'inp_channels': 3,
    'out_channels': 3,
    'n_feat': 80,
    'chan_factor': 1.5,
    'n_MRB': 2,
    'height': 3,
    'width': 2,
    'bias': False,
    'scale': 1,
    'task': task
}

def get_weights_and_parameters(task, parameters):
    if task == 'deblurring':
        weights = os.path.join('..', 'MIRNetv2', 'pretrained_models', 'motion_deblurring.pth')
    else:
        raise ValueError("Only 'deblurring' task is supported in this script.")
    return weights, parameters

# Load model
weights, parameters = get_weights_and_parameters(task, parameters)

# Make sure the path to the architecture file is correct
arch_path = os.path.join('..', 'MIRNetv2', 'basicsr', 'models', 'archs', 'mirnet_v2_arch.py')
if not os.path.exists(arch_path):
    raise FileNotFoundError(f"Architecture file not found at: {arch_path}")

load_arch = run_path(arch_path)
model = load_arch['MIRNet_v2'](**parameters)

# Check if CUDA is available and use it only if it's available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Load checkpoint with proper error handling
try:
    checkpoint = torch.load(weights, map_location=device)
    # Check if the keys in checkpoint match the model's state_dict
    if 'params' not in checkpoint:
        print("Warning: 'params' key not found in checkpoint. Available keys:", checkpoint.keys())
        # Try to load directly if it's just the state dict
        model.load_state_dict(checkpoint)
    else:
        model.load_state_dict(checkpoint['params'])
    model.eval()
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    # Print model's state_dict keys for debugging
    print("Model expects these keys:")
    for key in model.state_dict().keys():
        print(f"  {key}")
    raise

# Input LR and HR image paths - verify these paths exist
lr_img_path = "C:/Users/manas/image-sharpness/data/patches/LR_patches/0004_037.png"
hr_img_path = "C:/Users/manas/image-sharpness/data/patches/HR_patches/0004_037.png"

if not os.path.exists(lr_img_path):
    raise FileNotFoundError(f"LR image not found at: {lr_img_path}")
if not os.path.exists(hr_img_path):
    raise FileNotFoundError(f"HR image not found at: {hr_img_path}")

# Load images
lr_img = cv2.imread(lr_img_path)
lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB)
hr_img = cv2.imread(hr_img_path)
hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

# Inference
img_input = TF.to_tensor(lr_img).unsqueeze(0).to(device)
with torch.no_grad():
    restored = model(img_input)
restored = torch.clamp(restored, 0, 1)
restored_img = restored.squeeze(0).permute(1, 2, 0).cpu().numpy()
restored_img_ubyte = img_as_ubyte(restored_img)

# Compute PSNR and SSIM
psnr_val = psnr(hr_img, restored_img_ubyte, data_range=255)
ssim_val = ssim(hr_img, restored_img_ubyte, channel_axis=2, data_range=255)

print(f"PSNR: {psnr_val:.2f} dB")
print(f"SSIM: {ssim_val:.4f}")

# Show LR, Restored, HR
plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.imshow(lr_img)
plt.title('LR Patch')
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(restored_img_ubyte)
plt.title('Restored Patch')
plt.axis('off')

plt.subplot(1,3,3)
plt.imshow(hr_img)
plt.title('HR Patch')
plt.axis('off')

plt.tight_layout()
plt.show()

Error loading checkpoint: Error(s) in loading state_dict for MIRNet_v2:
	Missing key(s) in state_dict: "conv_in.weight", "body.0.body.0.dau_top.body.0.weight", "body.0.body.0.dau_top.body.2.weight", "body.0.body.0.dau_top.gcnet.conv_mask.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_mid.body.0.weight", "body.0.body.0.dau_mid.body.2.weight", "body.0.body.0.dau_mid.gcnet.conv_mask.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_bot.body.0.weight", "body.0.body.0.dau_bot.body.2.weight", "body.0.body.0.dau_bot.gcnet.conv_mask.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.0.down2.body.0.bot.1.weight", "body.0.body.0.down4.0.body.0.bot.1.weight", "body.0.body.0.down4.1.body.0.bot.1.weight", "body.0.body.0.up21_1.body

RuntimeError: Error(s) in loading state_dict for MIRNet_v2:
	Missing key(s) in state_dict: "conv_in.weight", "body.0.body.0.dau_top.body.0.weight", "body.0.body.0.dau_top.body.2.weight", "body.0.body.0.dau_top.gcnet.conv_mask.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_mid.body.0.weight", "body.0.body.0.dau_mid.body.2.weight", "body.0.body.0.dau_mid.gcnet.conv_mask.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_bot.body.0.weight", "body.0.body.0.dau_bot.body.2.weight", "body.0.body.0.dau_bot.gcnet.conv_mask.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.0.down2.body.0.bot.1.weight", "body.0.body.0.down4.0.body.0.bot.1.weight", "body.0.body.0.down4.1.body.0.bot.1.weight", "body.0.body.0.up21_1.body.0.bot.0.weight", "body.0.body.0.up21_2.body.0.bot.0.weight", "body.0.body.0.up32_1.body.0.bot.0.weight", "body.0.body.0.up32_2.body.0.bot.0.weight", "body.0.body.0.conv_out.weight", "body.0.body.0.skff_top.conv_du.0.weight", "body.0.body.0.skff_top.fcs.0.weight", "body.0.body.0.skff_top.fcs.1.weight", "body.0.body.0.skff_mid.conv_du.0.weight", "body.0.body.0.skff_mid.fcs.0.weight", "body.0.body.0.skff_mid.fcs.1.weight", "body.0.body.1.dau_top.body.0.weight", "body.0.body.1.dau_top.body.2.weight", "body.0.body.1.dau_top.gcnet.conv_mask.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_mid.body.0.weight", "body.0.body.1.dau_mid.body.2.weight", "body.0.body.1.dau_mid.gcnet.conv_mask.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_bot.body.0.weight", "body.0.body.1.dau_bot.body.2.weight", "body.0.body.1.dau_bot.gcnet.conv_mask.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.1.down2.body.0.bot.1.weight", "body.0.body.1.down4.0.body.0.bot.1.weight", "body.0.body.1.down4.1.body.0.bot.1.weight", "body.0.body.1.up21_1.body.0.bot.0.weight", "body.0.body.1.up21_2.body.0.bot.0.weight", "body.0.body.1.up32_1.body.0.bot.0.weight", "body.0.body.1.up32_2.body.0.bot.0.weight", "body.0.body.1.conv_out.weight", "body.0.body.1.skff_top.conv_du.0.weight", "body.0.body.1.skff_top.fcs.0.weight", "body.0.body.1.skff_top.fcs.1.weight", "body.0.body.1.skff_mid.conv_du.0.weight", "body.0.body.1.skff_mid.fcs.0.weight", "body.0.body.1.skff_mid.fcs.1.weight", "body.0.body.2.weight", "body.1.body.0.dau_top.body.0.weight", "body.1.body.0.dau_top.body.2.weight", "body.1.body.0.dau_top.gcnet.conv_mask.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_mid.body.0.weight", "body.1.body.0.dau_mid.body.2.weight", "body.1.body.0.dau_mid.gcnet.conv_mask.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_bot.body.0.weight", "body.1.body.0.dau_bot.body.2.weight", "body.1.body.0.dau_bot.gcnet.conv_mask.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.0.down2.body.0.bot.1.weight", "body.1.body.0.down4.0.body.0.bot.1.weight", "body.1.body.0.down4.1.body.0.bot.1.weight", "body.1.body.0.up21_1.body.0.bot.0.weight", "body.1.body.0.up21_2.body.0.bot.0.weight", "body.1.body.0.up32_1.body.0.bot.0.weight", "body.1.body.0.up32_2.body.0.bot.0.weight", "body.1.body.0.conv_out.weight", "body.1.body.0.skff_top.conv_du.0.weight", "body.1.body.0.skff_top.fcs.0.weight", "body.1.body.0.skff_top.fcs.1.weight", "body.1.body.0.skff_mid.conv_du.0.weight", "body.1.body.0.skff_mid.fcs.0.weight", "body.1.body.0.skff_mid.fcs.1.weight", "body.1.body.1.dau_top.body.0.weight", "body.1.body.1.dau_top.body.2.weight", "body.1.body.1.dau_top.gcnet.conv_mask.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_mid.body.0.weight", "body.1.body.1.dau_mid.body.2.weight", "body.1.body.1.dau_mid.gcnet.conv_mask.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_bot.body.0.weight", "body.1.body.1.dau_bot.body.2.weight", "body.1.body.1.dau_bot.gcnet.conv_mask.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.1.down2.body.0.bot.1.weight", "body.1.body.1.down4.0.body.0.bot.1.weight", "body.1.body.1.down4.1.body.0.bot.1.weight", "body.1.body.1.up21_1.body.0.bot.0.weight", "body.1.body.1.up21_2.body.0.bot.0.weight", "body.1.body.1.up32_1.body.0.bot.0.weight", "body.1.body.1.up32_2.body.0.bot.0.weight", "body.1.body.1.conv_out.weight", "body.1.body.1.skff_top.conv_du.0.weight", "body.1.body.1.skff_top.fcs.0.weight", "body.1.body.1.skff_top.fcs.1.weight", "body.1.body.1.skff_mid.conv_du.0.weight", "body.1.body.1.skff_mid.fcs.0.weight", "body.1.body.1.skff_mid.fcs.1.weight", "body.1.body.2.weight", "body.2.body.0.dau_top.body.0.weight", "body.2.body.0.dau_top.body.2.weight", "body.2.body.0.dau_top.gcnet.conv_mask.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_mid.body.0.weight", "body.2.body.0.dau_mid.body.2.weight", "body.2.body.0.dau_mid.gcnet.conv_mask.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_bot.body.0.weight", "body.2.body.0.dau_bot.body.2.weight", "body.2.body.0.dau_bot.gcnet.conv_mask.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.0.down2.body.0.bot.1.weight", "body.2.body.0.down4.0.body.0.bot.1.weight", "body.2.body.0.down4.1.body.0.bot.1.weight", "body.2.body.0.up21_1.body.0.bot.0.weight", "body.2.body.0.up21_2.body.0.bot.0.weight", "body.2.body.0.up32_1.body.0.bot.0.weight", "body.2.body.0.up32_2.body.0.bot.0.weight", "body.2.body.0.conv_out.weight", "body.2.body.0.skff_top.conv_du.0.weight", "body.2.body.0.skff_top.fcs.0.weight", "body.2.body.0.skff_top.fcs.1.weight", "body.2.body.0.skff_mid.conv_du.0.weight", "body.2.body.0.skff_mid.fcs.0.weight", "body.2.body.0.skff_mid.fcs.1.weight", "body.2.body.1.dau_top.body.0.weight", "body.2.body.1.dau_top.body.2.weight", "body.2.body.1.dau_top.gcnet.conv_mask.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_mid.body.0.weight", "body.2.body.1.dau_mid.body.2.weight", "body.2.body.1.dau_mid.gcnet.conv_mask.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_bot.body.0.weight", "body.2.body.1.dau_bot.body.2.weight", "body.2.body.1.dau_bot.gcnet.conv_mask.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.1.down2.body.0.bot.1.weight", "body.2.body.1.down4.0.body.0.bot.1.weight", "body.2.body.1.down4.1.body.0.bot.1.weight", "body.2.body.1.up21_1.body.0.bot.0.weight", "body.2.body.1.up21_2.body.0.bot.0.weight", "body.2.body.1.up32_1.body.0.bot.0.weight", "body.2.body.1.up32_2.body.0.bot.0.weight", "body.2.body.1.conv_out.weight", "body.2.body.1.skff_top.conv_du.0.weight", "body.2.body.1.skff_top.fcs.0.weight", "body.2.body.1.skff_top.fcs.1.weight", "body.2.body.1.skff_mid.conv_du.0.weight", "body.2.body.1.skff_mid.fcs.0.weight", "body.2.body.1.skff_mid.fcs.1.weight", "body.2.body.2.weight", "body.3.body.0.dau_top.body.0.weight", "body.3.body.0.dau_top.body.2.weight", "body.3.body.0.dau_top.gcnet.conv_mask.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_mid.body.0.weight", "body.3.body.0.dau_mid.body.2.weight", "body.3.body.0.dau_mid.gcnet.conv_mask.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_bot.body.0.weight", "body.3.body.0.dau_bot.body.2.weight", "body.3.body.0.dau_bot.gcnet.conv_mask.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.0.down2.body.0.bot.1.weight", "body.3.body.0.down4.0.body.0.bot.1.weight", "body.3.body.0.down4.1.body.0.bot.1.weight", "body.3.body.0.up21_1.body.0.bot.0.weight", "body.3.body.0.up21_2.body.0.bot.0.weight", "body.3.body.0.up32_1.body.0.bot.0.weight", "body.3.body.0.up32_2.body.0.bot.0.weight", "body.3.body.0.conv_out.weight", "body.3.body.0.skff_top.conv_du.0.weight", "body.3.body.0.skff_top.fcs.0.weight", "body.3.body.0.skff_top.fcs.1.weight", "body.3.body.0.skff_mid.conv_du.0.weight", "body.3.body.0.skff_mid.fcs.0.weight", "body.3.body.0.skff_mid.fcs.1.weight", "body.3.body.1.dau_top.body.0.weight", "body.3.body.1.dau_top.body.2.weight", "body.3.body.1.dau_top.gcnet.conv_mask.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_mid.body.0.weight", "body.3.body.1.dau_mid.body.2.weight", "body.3.body.1.dau_mid.gcnet.conv_mask.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_bot.body.0.weight", "body.3.body.1.dau_bot.body.2.weight", "body.3.body.1.dau_bot.gcnet.conv_mask.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.1.down2.body.0.bot.1.weight", "body.3.body.1.down4.0.body.0.bot.1.weight", "body.3.body.1.down4.1.body.0.bot.1.weight", "body.3.body.1.up21_1.body.0.bot.0.weight", "body.3.body.1.up21_2.body.0.bot.0.weight", "body.3.body.1.up32_1.body.0.bot.0.weight", "body.3.body.1.up32_2.body.0.bot.0.weight", "body.3.body.1.conv_out.weight", "body.3.body.1.skff_top.conv_du.0.weight", "body.3.body.1.skff_top.fcs.0.weight", "body.3.body.1.skff_top.fcs.1.weight", "body.3.body.1.skff_mid.conv_du.0.weight", "body.3.body.1.skff_mid.fcs.0.weight", "body.3.body.1.skff_mid.fcs.1.weight", "body.3.body.2.weight", "conv_out.weight". 
	Unexpected key(s) in state_dict: "patch_embed.proj.weight", "encoder_level1.0.norm1.body.weight", "encoder_level1.0.norm1.body.bias", "encoder_level1.0.attn.temperature", "encoder_level1.0.attn.qkv.weight", "encoder_level1.0.attn.qkv_dwconv.weight", "encoder_level1.0.attn.project_out.weight", "encoder_level1.0.norm2.body.weight", "encoder_level1.0.norm2.body.bias", "encoder_level1.0.ffn.project_in.weight", "encoder_level1.0.ffn.dwconv.weight", "encoder_level1.0.ffn.project_out.weight", "encoder_level1.1.norm1.body.weight", "encoder_level1.1.norm1.body.bias", "encoder_level1.1.attn.temperature", "encoder_level1.1.attn.qkv.weight", "encoder_level1.1.attn.qkv_dwconv.weight", "encoder_level1.1.attn.project_out.weight", "encoder_level1.1.norm2.body.weight", "encoder_level1.1.norm2.body.bias", "encoder_level1.1.ffn.project_in.weight", "encoder_level1.1.ffn.dwconv.weight", "encoder_level1.1.ffn.project_out.weight", "encoder_level1.2.norm1.body.weight", "encoder_level1.2.norm1.body.bias", "encoder_level1.2.attn.temperature", "encoder_level1.2.attn.qkv.weight", "encoder_level1.2.attn.qkv_dwconv.weight", "encoder_level1.2.attn.project_out.weight", "encoder_level1.2.norm2.body.weight", "encoder_level1.2.norm2.body.bias", "encoder_level1.2.ffn.project_in.weight", "encoder_level1.2.ffn.dwconv.weight", "encoder_level1.2.ffn.project_out.weight", "encoder_level1.3.norm1.body.weight", "encoder_level1.3.norm1.body.bias", "encoder_level1.3.attn.temperature", "encoder_level1.3.attn.qkv.weight", "encoder_level1.3.attn.qkv_dwconv.weight", "encoder_level1.3.attn.project_out.weight", "encoder_level1.3.norm2.body.weight", "encoder_level1.3.norm2.body.bias", "encoder_level1.3.ffn.project_in.weight", "encoder_level1.3.ffn.dwconv.weight", "encoder_level1.3.ffn.project_out.weight", "down1_2.body.0.weight", "encoder_level2.0.norm1.body.weight", "encoder_level2.0.norm1.body.bias", "encoder_level2.0.attn.temperature", "encoder_level2.0.attn.qkv.weight", "encoder_level2.0.attn.qkv_dwconv.weight", "encoder_level2.0.attn.project_out.weight", "encoder_level2.0.norm2.body.weight", "encoder_level2.0.norm2.body.bias", "encoder_level2.0.ffn.project_in.weight", "encoder_level2.0.ffn.dwconv.weight", "encoder_level2.0.ffn.project_out.weight", "encoder_level2.1.norm1.body.weight", "encoder_level2.1.norm1.body.bias", "encoder_level2.1.attn.temperature", "encoder_level2.1.attn.qkv.weight", "encoder_level2.1.attn.qkv_dwconv.weight", "encoder_level2.1.attn.project_out.weight", "encoder_level2.1.norm2.body.weight", "encoder_level2.1.norm2.body.bias", "encoder_level2.1.ffn.project_in.weight", "encoder_level2.1.ffn.dwconv.weight", "encoder_level2.1.ffn.project_out.weight", "encoder_level2.2.norm1.body.weight", "encoder_level2.2.norm1.body.bias", "encoder_level2.2.attn.temperature", "encoder_level2.2.attn.qkv.weight", "encoder_level2.2.attn.qkv_dwconv.weight", "encoder_level2.2.attn.project_out.weight", "encoder_level2.2.norm2.body.weight", "encoder_level2.2.norm2.body.bias", "encoder_level2.2.ffn.project_in.weight", "encoder_level2.2.ffn.dwconv.weight", "encoder_level2.2.ffn.project_out.weight", "encoder_level2.3.norm1.body.weight", "encoder_level2.3.norm1.body.bias", "encoder_level2.3.attn.temperature", "encoder_level2.3.attn.qkv.weight", "encoder_level2.3.attn.qkv_dwconv.weight", "encoder_level2.3.attn.project_out.weight", "encoder_level2.3.norm2.body.weight", "encoder_level2.3.norm2.body.bias", "encoder_level2.3.ffn.project_in.weight", "encoder_level2.3.ffn.dwconv.weight", "encoder_level2.3.ffn.project_out.weight", "encoder_level2.4.norm1.body.weight", "encoder_level2.4.norm1.body.bias", "encoder_level2.4.attn.temperature", "encoder_level2.4.attn.qkv.weight", "encoder_level2.4.attn.qkv_dwconv.weight", "encoder_level2.4.attn.project_out.weight", "encoder_level2.4.norm2.body.weight", "encoder_level2.4.norm2.body.bias", "encoder_level2.4.ffn.project_in.weight", "encoder_level2.4.ffn.dwconv.weight", "encoder_level2.4.ffn.project_out.weight", "encoder_level2.5.norm1.body.weight", "encoder_level2.5.norm1.body.bias", "encoder_level2.5.attn.temperature", "encoder_level2.5.attn.qkv.weight", "encoder_level2.5.attn.qkv_dwconv.weight", "encoder_level2.5.attn.project_out.weight", "encoder_level2.5.norm2.body.weight", "encoder_level2.5.norm2.body.bias", "encoder_level2.5.ffn.project_in.weight", "encoder_level2.5.ffn.dwconv.weight", "encoder_level2.5.ffn.project_out.weight", "down2_3.body.0.weight", "encoder_level3.0.norm1.body.weight", "encoder_level3.0.norm1.body.bias", "encoder_level3.0.attn.temperature", "encoder_level3.0.attn.qkv.weight", "encoder_level3.0.attn.qkv_dwconv.weight", "encoder_level3.0.attn.project_out.weight", "encoder_level3.0.norm2.body.weight", "encoder_level3.0.norm2.body.bias", "encoder_level3.0.ffn.project_in.weight", "encoder_level3.0.ffn.dwconv.weight", "encoder_level3.0.ffn.project_out.weight", "encoder_level3.1.norm1.body.weight", "encoder_level3.1.norm1.body.bias", "encoder_level3.1.attn.temperature", "encoder_level3.1.attn.qkv.weight", "encoder_level3.1.attn.qkv_dwconv.weight", "encoder_level3.1.attn.project_out.weight", "encoder_level3.1.norm2.body.weight", "encoder_level3.1.norm2.body.bias", "encoder_level3.1.ffn.project_in.weight", "encoder_level3.1.ffn.dwconv.weight", "encoder_level3.1.ffn.project_out.weight", "encoder_level3.2.norm1.body.weight", "encoder_level3.2.norm1.body.bias", "encoder_level3.2.attn.temperature", "encoder_level3.2.attn.qkv.weight", "encoder_level3.2.attn.qkv_dwconv.weight", "encoder_level3.2.attn.project_out.weight", "encoder_level3.2.norm2.body.weight", "encoder_level3.2.norm2.body.bias", "encoder_level3.2.ffn.project_in.weight", "encoder_level3.2.ffn.dwconv.weight", "encoder_level3.2.ffn.project_out.weight", "encoder_level3.3.norm1.body.weight", "encoder_level3.3.norm1.body.bias", "encoder_level3.3.attn.temperature", "encoder_level3.3.attn.qkv.weight", "encoder_level3.3.attn.qkv_dwconv.weight", "encoder_level3.3.attn.project_out.weight", "encoder_level3.3.norm2.body.weight", "encoder_level3.3.norm2.body.bias", "encoder_level3.3.ffn.project_in.weight", "encoder_level3.3.ffn.dwconv.weight", "encoder_level3.3.ffn.project_out.weight", "encoder_level3.4.norm1.body.weight", "encoder_level3.4.norm1.body.bias", "encoder_level3.4.attn.temperature", "encoder_level3.4.attn.qkv.weight", "encoder_level3.4.attn.qkv_dwconv.weight", "encoder_level3.4.attn.project_out.weight", "encoder_level3.4.norm2.body.weight", "encoder_level3.4.norm2.body.bias", "encoder_level3.4.ffn.project_in.weight", "encoder_level3.4.ffn.dwconv.weight", "encoder_level3.4.ffn.project_out.weight", "encoder_level3.5.norm1.body.weight", "encoder_level3.5.norm1.body.bias", "encoder_level3.5.attn.temperature", "encoder_level3.5.attn.qkv.weight", "encoder_level3.5.attn.qkv_dwconv.weight", "encoder_level3.5.attn.project_out.weight", "encoder_level3.5.norm2.body.weight", "encoder_level3.5.norm2.body.bias", "encoder_level3.5.ffn.project_in.weight", "encoder_level3.5.ffn.dwconv.weight", "encoder_level3.5.ffn.project_out.weight", "down3_4.body.0.weight", "latent.0.norm1.body.weight", "latent.0.norm1.body.bias", "latent.0.attn.temperature", "latent.0.attn.qkv.weight", "latent.0.attn.qkv_dwconv.weight", "latent.0.attn.project_out.weight", "latent.0.norm2.body.weight", "latent.0.norm2.body.bias", "latent.0.ffn.project_in.weight", "latent.0.ffn.dwconv.weight", "latent.0.ffn.project_out.weight", "latent.1.norm1.body.weight", "latent.1.norm1.body.bias", "latent.1.attn.temperature", "latent.1.attn.qkv.weight", "latent.1.attn.qkv_dwconv.weight", "latent.1.attn.project_out.weight", "latent.1.norm2.body.weight", "latent.1.norm2.body.bias", "latent.1.ffn.project_in.weight", "latent.1.ffn.dwconv.weight", "latent.1.ffn.project_out.weight", "latent.2.norm1.body.weight", "latent.2.norm1.body.bias", "latent.2.attn.temperature", "latent.2.attn.qkv.weight", "latent.2.attn.qkv_dwconv.weight", "latent.2.attn.project_out.weight", "latent.2.norm2.body.weight", "latent.2.norm2.body.bias", "latent.2.ffn.project_in.weight", "latent.2.ffn.dwconv.weight", "latent.2.ffn.project_out.weight", "latent.3.norm1.body.weight", "latent.3.norm1.body.bias", "latent.3.attn.temperature", "latent.3.attn.qkv.weight", "latent.3.attn.qkv_dwconv.weight", "latent.3.attn.project_out.weight", "latent.3.norm2.body.weight", "latent.3.norm2.body.bias", "latent.3.ffn.project_in.weight", "latent.3.ffn.dwconv.weight", "latent.3.ffn.project_out.weight", "latent.4.norm1.body.weight", "latent.4.norm1.body.bias", "latent.4.attn.temperature", "latent.4.attn.qkv.weight", "latent.4.attn.qkv_dwconv.weight", "latent.4.attn.project_out.weight", "latent.4.norm2.body.weight", "latent.4.norm2.body.bias", "latent.4.ffn.project_in.weight", "latent.4.ffn.dwconv.weight", "latent.4.ffn.project_out.weight", "latent.5.norm1.body.weight", "latent.5.norm1.body.bias", "latent.5.attn.temperature", "latent.5.attn.qkv.weight", "latent.5.attn.qkv_dwconv.weight", "latent.5.attn.project_out.weight", "latent.5.norm2.body.weight", "latent.5.norm2.body.bias", "latent.5.ffn.project_in.weight", "latent.5.ffn.dwconv.weight", "latent.5.ffn.project_out.weight", "latent.6.norm1.body.weight", "latent.6.norm1.body.bias", "latent.6.attn.temperature", "latent.6.attn.qkv.weight", "latent.6.attn.qkv_dwconv.weight", "latent.6.attn.project_out.weight", "latent.6.norm2.body.weight", "latent.6.norm2.body.bias", "latent.6.ffn.project_in.weight", "latent.6.ffn.dwconv.weight", "latent.6.ffn.project_out.weight", "latent.7.norm1.body.weight", "latent.7.norm1.body.bias", "latent.7.attn.temperature", "latent.7.attn.qkv.weight", "latent.7.attn.qkv_dwconv.weight", "latent.7.attn.project_out.weight", "latent.7.norm2.body.weight", "latent.7.norm2.body.bias", "latent.7.ffn.project_in.weight", "latent.7.ffn.dwconv.weight", "latent.7.ffn.project_out.weight", "up4_3.body.0.weight", "reduce_chan_level3.weight", "decoder_level3.0.norm1.body.weight", "decoder_level3.0.norm1.body.bias", "decoder_level3.0.attn.temperature", "decoder_level3.0.attn.qkv.weight", "decoder_level3.0.attn.qkv_dwconv.weight", "decoder_level3.0.attn.project_out.weight", "decoder_level3.0.norm2.body.weight", "decoder_level3.0.norm2.body.bias", "decoder_level3.0.ffn.project_in.weight", "decoder_level3.0.ffn.dwconv.weight", "decoder_level3.0.ffn.project_out.weight", "decoder_level3.1.norm1.body.weight", "decoder_level3.1.norm1.body.bias", "decoder_level3.1.attn.temperature", "decoder_level3.1.attn.qkv.weight", "decoder_level3.1.attn.qkv_dwconv.weight", "decoder_level3.1.attn.project_out.weight", "decoder_level3.1.norm2.body.weight", "decoder_level3.1.norm2.body.bias", "decoder_level3.1.ffn.project_in.weight", "decoder_level3.1.ffn.dwconv.weight", "decoder_level3.1.ffn.project_out.weight", "decoder_level3.2.norm1.body.weight", "decoder_level3.2.norm1.body.bias", "decoder_level3.2.attn.temperature", "decoder_level3.2.attn.qkv.weight", "decoder_level3.2.attn.qkv_dwconv.weight", "decoder_level3.2.attn.project_out.weight", "decoder_level3.2.norm2.body.weight", "decoder_level3.2.norm2.body.bias", "decoder_level3.2.ffn.project_in.weight", "decoder_level3.2.ffn.dwconv.weight", "decoder_level3.2.ffn.project_out.weight", "decoder_level3.3.norm1.body.weight", "decoder_level3.3.norm1.body.bias", "decoder_level3.3.attn.temperature", "decoder_level3.3.attn.qkv.weight", "decoder_level3.3.attn.qkv_dwconv.weight", "decoder_level3.3.attn.project_out.weight", "decoder_level3.3.norm2.body.weight", "decoder_level3.3.norm2.body.bias", "decoder_level3.3.ffn.project_in.weight", "decoder_level3.3.ffn.dwconv.weight", "decoder_level3.3.ffn.project_out.weight", "decoder_level3.4.norm1.body.weight", "decoder_level3.4.norm1.body.bias", "decoder_level3.4.attn.temperature", "decoder_level3.4.attn.qkv.weight", "decoder_level3.4.attn.qkv_dwconv.weight", "decoder_level3.4.attn.project_out.weight", "decoder_level3.4.norm2.body.weight", "decoder_level3.4.norm2.body.bias", "decoder_level3.4.ffn.project_in.weight", "decoder_level3.4.ffn.dwconv.weight", "decoder_level3.4.ffn.project_out.weight", "decoder_level3.5.norm1.body.weight", "decoder_level3.5.norm1.body.bias", "decoder_level3.5.attn.temperature", "decoder_level3.5.attn.qkv.weight", "decoder_level3.5.attn.qkv_dwconv.weight", "decoder_level3.5.attn.project_out.weight", "decoder_level3.5.norm2.body.weight", "decoder_level3.5.norm2.body.bias", "decoder_level3.5.ffn.project_in.weight", "decoder_level3.5.ffn.dwconv.weight", "decoder_level3.5.ffn.project_out.weight", "up3_2.body.0.weight", "reduce_chan_level2.weight", "decoder_level2.0.norm1.body.weight", "decoder_level2.0.norm1.body.bias", "decoder_level2.0.attn.temperature", "decoder_level2.0.attn.qkv.weight", "decoder_level2.0.attn.qkv_dwconv.weight", "decoder_level2.0.attn.project_out.weight", "decoder_level2.0.norm2.body.weight", "decoder_level2.0.norm2.body.bias", "decoder_level2.0.ffn.project_in.weight", "decoder_level2.0.ffn.dwconv.weight", "decoder_level2.0.ffn.project_out.weight", "decoder_level2.1.norm1.body.weight", "decoder_level2.1.norm1.body.bias", "decoder_level2.1.attn.temperature", "decoder_level2.1.attn.qkv.weight", "decoder_level2.1.attn.qkv_dwconv.weight", "decoder_level2.1.attn.project_out.weight", "decoder_level2.1.norm2.body.weight", "decoder_level2.1.norm2.body.bias", "decoder_level2.1.ffn.project_in.weight", "decoder_level2.1.ffn.dwconv.weight", "decoder_level2.1.ffn.project_out.weight", "decoder_level2.2.norm1.body.weight", "decoder_level2.2.norm1.body.bias", "decoder_level2.2.attn.temperature", "decoder_level2.2.attn.qkv.weight", "decoder_level2.2.attn.qkv_dwconv.weight", "decoder_level2.2.attn.project_out.weight", "decoder_level2.2.norm2.body.weight", "decoder_level2.2.norm2.body.bias", "decoder_level2.2.ffn.project_in.weight", "decoder_level2.2.ffn.dwconv.weight", "decoder_level2.2.ffn.project_out.weight", "decoder_level2.3.norm1.body.weight", "decoder_level2.3.norm1.body.bias", "decoder_level2.3.attn.temperature", "decoder_level2.3.attn.qkv.weight", "decoder_level2.3.attn.qkv_dwconv.weight", "decoder_level2.3.attn.project_out.weight", "decoder_level2.3.norm2.body.weight", "decoder_level2.3.norm2.body.bias", "decoder_level2.3.ffn.project_in.weight", "decoder_level2.3.ffn.dwconv.weight", "decoder_level2.3.ffn.project_out.weight", "decoder_level2.4.norm1.body.weight", "decoder_level2.4.norm1.body.bias", "decoder_level2.4.attn.temperature", "decoder_level2.4.attn.qkv.weight", "decoder_level2.4.attn.qkv_dwconv.weight", "decoder_level2.4.attn.project_out.weight", "decoder_level2.4.norm2.body.weight", "decoder_level2.4.norm2.body.bias", "decoder_level2.4.ffn.project_in.weight", "decoder_level2.4.ffn.dwconv.weight", "decoder_level2.4.ffn.project_out.weight", "decoder_level2.5.norm1.body.weight", "decoder_level2.5.norm1.body.bias", "decoder_level2.5.attn.temperature", "decoder_level2.5.attn.qkv.weight", "decoder_level2.5.attn.qkv_dwconv.weight", "decoder_level2.5.attn.project_out.weight", "decoder_level2.5.norm2.body.weight", "decoder_level2.5.norm2.body.bias", "decoder_level2.5.ffn.project_in.weight", "decoder_level2.5.ffn.dwconv.weight", "decoder_level2.5.ffn.project_out.weight", "up2_1.body.0.weight", "decoder_level1.0.norm1.body.weight", "decoder_level1.0.norm1.body.bias", "decoder_level1.0.attn.temperature", "decoder_level1.0.attn.qkv.weight", "decoder_level1.0.attn.qkv_dwconv.weight", "decoder_level1.0.attn.project_out.weight", "decoder_level1.0.norm2.body.weight", "decoder_level1.0.norm2.body.bias", "decoder_level1.0.ffn.project_in.weight", "decoder_level1.0.ffn.dwconv.weight", "decoder_level1.0.ffn.project_out.weight", "decoder_level1.1.norm1.body.weight", "decoder_level1.1.norm1.body.bias", "decoder_level1.1.attn.temperature", "decoder_level1.1.attn.qkv.weight", "decoder_level1.1.attn.qkv_dwconv.weight", "decoder_level1.1.attn.project_out.weight", "decoder_level1.1.norm2.body.weight", "decoder_level1.1.norm2.body.bias", "decoder_level1.1.ffn.project_in.weight", "decoder_level1.1.ffn.dwconv.weight", "decoder_level1.1.ffn.project_out.weight", "decoder_level1.2.norm1.body.weight", "decoder_level1.2.norm1.body.bias", "decoder_level1.2.attn.temperature", "decoder_level1.2.attn.qkv.weight", "decoder_level1.2.attn.qkv_dwconv.weight", "decoder_level1.2.attn.project_out.weight", "decoder_level1.2.norm2.body.weight", "decoder_level1.2.norm2.body.bias", "decoder_level1.2.ffn.project_in.weight", "decoder_level1.2.ffn.dwconv.weight", "decoder_level1.2.ffn.project_out.weight", "decoder_level1.3.norm1.body.weight", "decoder_level1.3.norm1.body.bias", "decoder_level1.3.attn.temperature", "decoder_level1.3.attn.qkv.weight", "decoder_level1.3.attn.qkv_dwconv.weight", "decoder_level1.3.attn.project_out.weight", "decoder_level1.3.norm2.body.weight", "decoder_level1.3.norm2.body.bias", "decoder_level1.3.ffn.project_in.weight", "decoder_level1.3.ffn.dwconv.weight", "decoder_level1.3.ffn.project_out.weight", "refinement.0.norm1.body.weight", "refinement.0.norm1.body.bias", "refinement.0.attn.temperature", "refinement.0.attn.qkv.weight", "refinement.0.attn.qkv_dwconv.weight", "refinement.0.attn.project_out.weight", "refinement.0.norm2.body.weight", "refinement.0.norm2.body.bias", "refinement.0.ffn.project_in.weight", "refinement.0.ffn.dwconv.weight", "refinement.0.ffn.project_out.weight", "refinement.1.norm1.body.weight", "refinement.1.norm1.body.bias", "refinement.1.attn.temperature", "refinement.1.attn.qkv.weight", "refinement.1.attn.qkv_dwconv.weight", "refinement.1.attn.project_out.weight", "refinement.1.norm2.body.weight", "refinement.1.norm2.body.bias", "refinement.1.ffn.project_in.weight", "refinement.1.ffn.dwconv.weight", "refinement.1.ffn.project_out.weight", "refinement.2.norm1.body.weight", "refinement.2.norm1.body.bias", "refinement.2.attn.temperature", "refinement.2.attn.qkv.weight", "refinement.2.attn.qkv_dwconv.weight", "refinement.2.attn.project_out.weight", "refinement.2.norm2.body.weight", "refinement.2.norm2.body.bias", "refinement.2.ffn.project_in.weight", "refinement.2.ffn.dwconv.weight", "refinement.2.ffn.project_out.weight", "refinement.3.norm1.body.weight", "refinement.3.norm1.body.bias", "refinement.3.attn.temperature", "refinement.3.attn.qkv.weight", "refinement.3.attn.qkv_dwconv.weight", "refinement.3.attn.project_out.weight", "refinement.3.norm2.body.weight", "refinement.3.norm2.body.bias", "refinement.3.ffn.project_in.weight", "refinement.3.ffn.dwconv.weight", "refinement.3.ffn.project_out.weight", "output.weight". 

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')