# 可视化展示模型去雾效果
## 第一步：设置环境变量

In [7]:
import platform
!BASICSR_JIT=True
platform = platform.system()
if platform == 'Linux':
    !BASICSR_JIT=True
else:
    !set BASICSR_JIT=True
!CUDA_VISIBLE_DEVICES=0





## 第二步：导入相关依赖

In [12]:
from basicsr.archs.itb_arch import FusionRefine
import torch
import os
import pyiqa
import glob
from tqdm import tqdm
import torchvision
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt
from PIL import Image

## 第三步：定义变量

In [8]:
# 指定模型运算设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 指定预训练模型存放位置
pretrained_net_path = "/mnt/e/DeepLearningCopies/2023/RIDCP/pretrained_models/ITB-Train-Best/DENSE-HAZE-2042-06374.pth"
# 指定待评估图片路径
haze_img_path = "/mnt/d/DeepLearning/dataset/Dense-Haze/hazy"
# 指定输出图像保存路径
output_img_path = "/mnt/e/DeepLearningCopies/2023/RIDCP/ohaze_results/DenseHaze"
# 指定图像最大分辨率
# 分辨率过大容易爆显存，超过最大分辨率的将会降采样后交由模型处理
max_size = 1000 * 1000

## 第四步：构建模型

In [9]:
!BASICSR_JIT=True
# 构建模型，加载预训练权重
opt = {
    "LQ_stage": True,
    "use_weight": False,
    "weight_alpha": -21.25
}
sr_model = FusionRefine(opt=opt).to(device)
sr_model.load_state_dict(torch.load(pretrained_net_path)['params'], strict=False)
sr_model.eval()

{'LQ_stage': True, 'use_weight': False, 'weight_alpha': -21.25}
{}


FusionRefine(
  (feature_extract): RIDCPNew(
    (vq_encoder): VQEncoder(
      (in_conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
      (blocks): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (1): ResBlock(
            (conv): Sequential(
              (0): NormLayer(
                (norm): GroupNorm(32, 128, eps=1e-06, affine=True)
              )
              (1): ActLayer(
                (func): SiLU(inplace=True)
              )
              (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (3): NormLayer(
                (norm): GroupNorm(32, 128, eps=1e-06, affine=True)
              )
              (4): ActLayer(
                (func): SiLU(inplace=True)
              )
              (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
          )
          (2): ResBlock(
            (conv): Seque

## 第五步：处理有雾图像

In [10]:
psnr = pyiqa.create_metric("psnr", device=device)
ssim = pyiqa.create_metric("ssim", device=device)

In [11]:
if os.path.isfile(haze_img_path):
    paths = [haze_img_path]
else:
    paths = sorted(glob.glob(os.path.join(haze_img_path, '*')))

In [12]:
images = []
images_metric = []
count = 0
pbar = tqdm(total=len(paths), unit='image')
for idx, path in enumerate(paths):
    img_name = os.path.basename(path)
    save_path = os.path.join(output_img_path, f'{img_name}')
    pbar.set_description(f'处理图像 {img_name} 中')

    input_img = ToTensor()(Image.open(path).convert('RGB')).to(device)[None, ::]
    h, w = input_img.shape[2:]
    if h * w < max_size:
        output, _ = sr_model.test(input_img)
    elif h * w > max_size * 2:
        input_img = torch.nn.UpsamplingBilinear2d((h//3, w//3))(input_img)
        output = sr_model.test_tile(input_img, tile_size=960, tile_pad=64)
        output = torch.nn.UpsamplingBilinear2d((h, w))(output)
    else:
        input_img = torch.nn.UpsamplingBilinear2d((h//2, w//2))(input_img)
        output, _ = sr_model.test(input_img)
        output = torch.nn.UpsamplingBilinear2d((h, w))(output)
        
    torchvision.utils.save_image(output, save_path)
    
    psnr_hl = psnr(input_img, output).item()
    ssim_hl = ssim(input_img, output).item()

    clear_out = input_img.squeeze().permute(1, 2, 0)
    hazy_out = output.squeeze().permute(1, 2, 0)
    images_metric.append({
        "Name": img_name,
        "PSNR": psnr_hl,
        "SSIM": ssim_hl,
    })
    images.append(clear_out)
    images.append(hazy_out)
    count += 1
    pbar.update(1)
pbar.close()


处理图像 01_hazy.png 中:   0%|          | 0/55 [00:39<?, ?image/s]

处理图像 01_hazy.png 中:   0%|          | 0/55 [00:00<?, ?image/s][A

NameError: name 'deform_conv_ext' is not defined

## 第六步：可视化模型去雾效果

In [ ]:
for i in range(count):
    fig, axs = plt.subplots(1, 2)
    for j, ax in enumerate(axs.flat):
        print(images[i * 2 + j])
        ax.imshow(images[i * 2 + j])
        ax.axis('off')
    plt.suptitle("Name: {} SSIM: {:.2f} PSNR: {:.2f}".format(
        images_metric[i]["Name"],
        images_metric[i]["SSIM"],
        images_metric[i]["PSNR"],
    ))
    plt.subplots_adjust(wspace=0, hspace=0, top=0.7)
    plt.tight_layout()
    plt.show()