In [14]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import cv2
from pathlib import Path
from tqdm import tqdm
from torchvision import transforms

In [15]:
mirnetv2_repo_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'MIRNetv2'))
if mirnetv2_repo_path not in sys.path:
    sys.path.insert(0, mirnetv2_repo_path)

In [16]:
from basicsr.archs.mirnet_v2_arch import MIRNet_v2

In [17]:
data_root  = Path("../data/patches")
hr_dir = data_root / "HR_patches"
lr_dir = data_root / "mixed_LR_patches"

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
patch_size = 192
batch_size = 8
num_epochs = 10
learning_rate = 1e-5

In [19]:
class PatchDataset(Dataset):
    def __init__(self,lr_dir,hr_dir):
        self.lr_dir = Path(lr_dir)
        self.hr_dir = Path(hr_dir)
        self.files = list(self.lr_dir.glob("*.png"))
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        lr_path = self.files[idx]
        hr_path = Path(self.hr_dir) / lr_path.name
        lr_img = cv2.imread(str(lr_path))
        hr_img = cv2.imread(str(hr_path))
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB)
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
        lr_tensor = self.to_tensor(lr_img)
        hr_tensor = self.to_tensor(hr_img)
        return lr_tensor, hr_tensor

In [20]:
dataset = PatchDataset(lr_dir,hr_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [21]:
model = MIRNet_v2().to(device)

In [27]:
pretrained_path = "../MIRNetv2/pretrained_models/enhancement_fivek.pth"
state_dict = torch.load(pretrained_path, map_location=device)
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for MIRNet_v2:
	Missing key(s) in state_dict: "conv_in.weight", "body.0.body.0.dau_top.body.0.weight", "body.0.body.0.dau_top.body.2.weight", "body.0.body.0.dau_top.gcnet.conv_mask.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_mid.body.0.weight", "body.0.body.0.dau_mid.body.2.weight", "body.0.body.0.dau_mid.gcnet.conv_mask.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.0.dau_bot.body.0.weight", "body.0.body.0.dau_bot.body.2.weight", "body.0.body.0.dau_bot.gcnet.conv_mask.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.0.down2.body.0.bot.1.weight", "body.0.body.0.down4.0.body.0.bot.1.weight", "body.0.body.0.down4.1.body.0.bot.1.weight", "body.0.body.0.up21_1.body.0.bot.0.weight", "body.0.body.0.up21_2.body.0.bot.0.weight", "body.0.body.0.up32_1.body.0.bot.0.weight", "body.0.body.0.up32_2.body.0.bot.0.weight", "body.0.body.0.conv_out.weight", "body.0.body.0.skff_top.conv_du.0.weight", "body.0.body.0.skff_top.fcs.0.weight", "body.0.body.0.skff_top.fcs.1.weight", "body.0.body.0.skff_mid.conv_du.0.weight", "body.0.body.0.skff_mid.fcs.0.weight", "body.0.body.0.skff_mid.fcs.1.weight", "body.0.body.1.dau_top.body.0.weight", "body.0.body.1.dau_top.body.2.weight", "body.0.body.1.dau_top.gcnet.conv_mask.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_mid.body.0.weight", "body.0.body.1.dau_mid.body.2.weight", "body.0.body.1.dau_mid.gcnet.conv_mask.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.0.body.1.dau_bot.body.0.weight", "body.0.body.1.dau_bot.body.2.weight", "body.0.body.1.dau_bot.gcnet.conv_mask.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.0.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.0.body.1.down2.body.0.bot.1.weight", "body.0.body.1.down4.0.body.0.bot.1.weight", "body.0.body.1.down4.1.body.0.bot.1.weight", "body.0.body.1.up21_1.body.0.bot.0.weight", "body.0.body.1.up21_2.body.0.bot.0.weight", "body.0.body.1.up32_1.body.0.bot.0.weight", "body.0.body.1.up32_2.body.0.bot.0.weight", "body.0.body.1.conv_out.weight", "body.0.body.1.skff_top.conv_du.0.weight", "body.0.body.1.skff_top.fcs.0.weight", "body.0.body.1.skff_top.fcs.1.weight", "body.0.body.1.skff_mid.conv_du.0.weight", "body.0.body.1.skff_mid.fcs.0.weight", "body.0.body.1.skff_mid.fcs.1.weight", "body.0.body.2.weight", "body.1.body.0.dau_top.body.0.weight", "body.1.body.0.dau_top.body.2.weight", "body.1.body.0.dau_top.gcnet.conv_mask.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_mid.body.0.weight", "body.1.body.0.dau_mid.body.2.weight", "body.1.body.0.dau_mid.gcnet.conv_mask.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.0.dau_bot.body.0.weight", "body.1.body.0.dau_bot.body.2.weight", "body.1.body.0.dau_bot.gcnet.conv_mask.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.0.down2.body.0.bot.1.weight", "body.1.body.0.down4.0.body.0.bot.1.weight", "body.1.body.0.down4.1.body.0.bot.1.weight", "body.1.body.0.up21_1.body.0.bot.0.weight", "body.1.body.0.up21_2.body.0.bot.0.weight", "body.1.body.0.up32_1.body.0.bot.0.weight", "body.1.body.0.up32_2.body.0.bot.0.weight", "body.1.body.0.conv_out.weight", "body.1.body.0.skff_top.conv_du.0.weight", "body.1.body.0.skff_top.fcs.0.weight", "body.1.body.0.skff_top.fcs.1.weight", "body.1.body.0.skff_mid.conv_du.0.weight", "body.1.body.0.skff_mid.fcs.0.weight", "body.1.body.0.skff_mid.fcs.1.weight", "body.1.body.1.dau_top.body.0.weight", "body.1.body.1.dau_top.body.2.weight", "body.1.body.1.dau_top.gcnet.conv_mask.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_mid.body.0.weight", "body.1.body.1.dau_mid.body.2.weight", "body.1.body.1.dau_mid.gcnet.conv_mask.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.1.body.1.dau_bot.body.0.weight", "body.1.body.1.dau_bot.body.2.weight", "body.1.body.1.dau_bot.gcnet.conv_mask.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.1.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.1.body.1.down2.body.0.bot.1.weight", "body.1.body.1.down4.0.body.0.bot.1.weight", "body.1.body.1.down4.1.body.0.bot.1.weight", "body.1.body.1.up21_1.body.0.bot.0.weight", "body.1.body.1.up21_2.body.0.bot.0.weight", "body.1.body.1.up32_1.body.0.bot.0.weight", "body.1.body.1.up32_2.body.0.bot.0.weight", "body.1.body.1.conv_out.weight", "body.1.body.1.skff_top.conv_du.0.weight", "body.1.body.1.skff_top.fcs.0.weight", "body.1.body.1.skff_top.fcs.1.weight", "body.1.body.1.skff_mid.conv_du.0.weight", "body.1.body.1.skff_mid.fcs.0.weight", "body.1.body.1.skff_mid.fcs.1.weight", "body.1.body.2.weight", "body.2.body.0.dau_top.body.0.weight", "body.2.body.0.dau_top.body.2.weight", "body.2.body.0.dau_top.gcnet.conv_mask.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_mid.body.0.weight", "body.2.body.0.dau_mid.body.2.weight", "body.2.body.0.dau_mid.gcnet.conv_mask.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.0.dau_bot.body.0.weight", "body.2.body.0.dau_bot.body.2.weight", "body.2.body.0.dau_bot.gcnet.conv_mask.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.0.down2.body.0.bot.1.weight", "body.2.body.0.down4.0.body.0.bot.1.weight", "body.2.body.0.down4.1.body.0.bot.1.weight", "body.2.body.0.up21_1.body.0.bot.0.weight", "body.2.body.0.up21_2.body.0.bot.0.weight", "body.2.body.0.up32_1.body.0.bot.0.weight", "body.2.body.0.up32_2.body.0.bot.0.weight", "body.2.body.0.conv_out.weight", "body.2.body.0.skff_top.conv_du.0.weight", "body.2.body.0.skff_top.fcs.0.weight", "body.2.body.0.skff_top.fcs.1.weight", "body.2.body.0.skff_mid.conv_du.0.weight", "body.2.body.0.skff_mid.fcs.0.weight", "body.2.body.0.skff_mid.fcs.1.weight", "body.2.body.1.dau_top.body.0.weight", "body.2.body.1.dau_top.body.2.weight", "body.2.body.1.dau_top.gcnet.conv_mask.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_mid.body.0.weight", "body.2.body.1.dau_mid.body.2.weight", "body.2.body.1.dau_mid.gcnet.conv_mask.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.2.body.1.dau_bot.body.0.weight", "body.2.body.1.dau_bot.body.2.weight", "body.2.body.1.dau_bot.gcnet.conv_mask.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.2.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.2.body.1.down2.body.0.bot.1.weight", "body.2.body.1.down4.0.body.0.bot.1.weight", "body.2.body.1.down4.1.body.0.bot.1.weight", "body.2.body.1.up21_1.body.0.bot.0.weight", "body.2.body.1.up21_2.body.0.bot.0.weight", "body.2.body.1.up32_1.body.0.bot.0.weight", "body.2.body.1.up32_2.body.0.bot.0.weight", "body.2.body.1.conv_out.weight", "body.2.body.1.skff_top.conv_du.0.weight", "body.2.body.1.skff_top.fcs.0.weight", "body.2.body.1.skff_top.fcs.1.weight", "body.2.body.1.skff_mid.conv_du.0.weight", "body.2.body.1.skff_mid.fcs.0.weight", "body.2.body.1.skff_mid.fcs.1.weight", "body.2.body.2.weight", "body.3.body.0.dau_top.body.0.weight", "body.3.body.0.dau_top.body.2.weight", "body.3.body.0.dau_top.gcnet.conv_mask.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_mid.body.0.weight", "body.3.body.0.dau_mid.body.2.weight", "body.3.body.0.dau_mid.gcnet.conv_mask.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.0.dau_bot.body.0.weight", "body.3.body.0.dau_bot.body.2.weight", "body.3.body.0.dau_bot.gcnet.conv_mask.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.0.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.0.down2.body.0.bot.1.weight", "body.3.body.0.down4.0.body.0.bot.1.weight", "body.3.body.0.down4.1.body.0.bot.1.weight", "body.3.body.0.up21_1.body.0.bot.0.weight", "body.3.body.0.up21_2.body.0.bot.0.weight", "body.3.body.0.up32_1.body.0.bot.0.weight", "body.3.body.0.up32_2.body.0.bot.0.weight", "body.3.body.0.conv_out.weight", "body.3.body.0.skff_top.conv_du.0.weight", "body.3.body.0.skff_top.fcs.0.weight", "body.3.body.0.skff_top.fcs.1.weight", "body.3.body.0.skff_mid.conv_du.0.weight", "body.3.body.0.skff_mid.fcs.0.weight", "body.3.body.0.skff_mid.fcs.1.weight", "body.3.body.1.dau_top.body.0.weight", "body.3.body.1.dau_top.body.2.weight", "body.3.body.1.dau_top.gcnet.conv_mask.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_top.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_mid.body.0.weight", "body.3.body.1.dau_mid.body.2.weight", "body.3.body.1.dau_mid.gcnet.conv_mask.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_mid.gcnet.channel_add_conv.2.weight", "body.3.body.1.dau_bot.body.0.weight", "body.3.body.1.dau_bot.body.2.weight", "body.3.body.1.dau_bot.gcnet.conv_mask.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.0.weight", "body.3.body.1.dau_bot.gcnet.channel_add_conv.2.weight", "body.3.body.1.down2.body.0.bot.1.weight", "body.3.body.1.down4.0.body.0.bot.1.weight", "body.3.body.1.down4.1.body.0.bot.1.weight", "body.3.body.1.up21_1.body.0.bot.0.weight", "body.3.body.1.up21_2.body.0.bot.0.weight", "body.3.body.1.up32_1.body.0.bot.0.weight", "body.3.body.1.up32_2.body.0.bot.0.weight", "body.3.body.1.conv_out.weight", "body.3.body.1.skff_top.conv_du.0.weight", "body.3.body.1.skff_top.fcs.0.weight", "body.3.body.1.skff_top.fcs.1.weight", "body.3.body.1.skff_mid.conv_du.0.weight", "body.3.body.1.skff_mid.fcs.0.weight", "body.3.body.1.skff_mid.fcs.1.weight", "body.3.body.2.weight", "conv_out.weight". 
	Unexpected key(s) in state_dict: "params". 

In [32]:
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from runpy import run_path
from skimage import img_as_ubyte
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt

# Define task
task = 'deblurring'

# Set parameters
# Parameters dictionary (I'm assuming this was defined earlier in your code)
parameters = {
    'inp_channels': 6,
    'n_feat': 80,
    'chan_factor': 1.5,
    'n_RRG': 4,
    'n_MRB': 2,
    'height': 3,
    'width': 2,
    'bias': False,
    'scale': 1,
    'task': task  # Assuming 'task' was defined earlier
}

def get_weights_and_parameters(task, parameters):
    if task == 'deblurring':
        weights = os.path.join('..','MIRNetv2', 'pretrained_models', 'dual_pixel_defocus_deblurring.pth')
    else:
        raise ValueError("Only 'deblurring' task is supported in this script.")
    return weights, parameters

# Load model
weights, parameters = get_weights_and_parameters(task, parameters)
load_arch = run_path(os.path.join('..','MIRNetv2','basicsr','models','archs','mirnet_v2_arch.py'))
model = load_arch['MIRNet_v2'](**parameters)  # Removed .cuda()

# Check if CUDA is available and use it only if it's available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

checkpoint = torch.load(weights, map_location=device)  # Added map_location to load on CPU if needed
model.load_state_dict(checkpoint['params'])
model.eval()

# Input LR and HR image paths
lr_img_path = "C:/Users/manas/image-sharpness/data/patches/LR_patches/0004_037.png"
hr_img_path = "C:/Users/manas/image-sharpness/data/patches/HR_patches/0004_037.png"

# Load images
lr_img = cv2.imread(lr_img_path)
lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB)
hr_img = cv2.imread(hr_img_path)
hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

# Inference
img_input = TF.to_tensor(lr_img).unsqueeze(0).to(device)  # Changed .cuda() to .to(device)
with torch.no_grad():
    restored = model(img_input)
restored = torch.clamp(restored, 0, 1)
restored_img = restored.squeeze(0).permute(1, 2, 0).cpu().numpy()
restored_img_ubyte = img_as_ubyte(restored_img)

# Compute PSNR and SSIM
psnr_val = psnr(hr_img, restored_img_ubyte, data_range=255)
ssim_val = ssim(hr_img, restored_img_ubyte, channel_axis=2, data_range=255)

print(f"PSNR: {psnr_val:.2f} dB")
print(f"SSIM: {ssim_val:.4f}")

# Show LR, Restored, HR
plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.imshow(lr_img)
plt.title('LR Patch')
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(restored_img_ubyte)
plt.title('Restored Patch')
plt.axis('off')

plt.subplot(1,3,3)
plt.imshow(hr_img)
plt.title('HR Patch')
plt.axis('off')

plt.tight_layout()
plt.show()

RuntimeError: Error(s) in loading state_dict for MIRNet_v2:
	size mismatch for conv_in.weight: copying a param with shape torch.Size([80, 6, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 3, 3, 3]).