In [2]:
import torch

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
import numpy as np
import cv2
import matplotlib.pyplot as plt

def imread(img_path):
  """Read an image file from disk and return it in RGB channel order.

  Args:
    img_path: path of the image file to load.

  Returns:
    The decoded image array, converted from OpenCV's BGR order to RGB so it
    can be passed straight to matplotlib.

  Raises:
    FileNotFoundError: if OpenCV cannot read ``img_path``. cv2.imread signals
      failure by returning None instead of raising, which would otherwise
      surface later as an opaque cvtColor error.
  """
  img = cv2.imread(img_path)
  if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale an image into [0, 1] floats and convert it with basicsr's ``img2tensor``.

    Args:
        img: image array, cast to float32 and divided by 255 (assumes 8-bit
            0..255 data, e.g. the RGB array returned by ``imread`` above).
        bgr2rgb: forwarded to basicsr; False by default because ``imread``
            in this notebook already converts to RGB.
        float32: forwarded to basicsr.

    Returns:
        Whatever basicsr's ``img2tensor`` returns for the scaled array
        (presumably a CHW torch tensor — confirm against basicsr).
    """
    scaled = img.astype(np.float32) / 255.
    return _img2tensor(scaled, float32=float32, bgr2rgb=bgr2rgb)

def display(img1, img2):
  """Show the input image and the NAFNet output side by side, axes hidden.

  Args:
    img1: image drawn in the left panel ("Input image").
    img2: image drawn in the right panel ("NAFNet output").

  Note: in a Jupyter session this definition shadows IPython's own
  ``display`` helper for the rest of the notebook.
  """
  _fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(25, 10))
  panels = ((ax_in, 'Input image', img1), (ax_out, 'NAFNet output', img2))
  for ax, title, image in panels:
    ax.set_title(title, fontsize=16)
    ax.axis('off')
    ax.imshow(image)

def single_image_inference(model, img, save_path):
    """Run ``model`` on one image tensor, save the restored image, and return it.

    Args:
        model: a basicsr model exposing ``feed_data``, ``test``,
            ``get_current_visuals`` and an ``opt`` dict; ``grids`` and
            ``grids_inverse`` are also called when ``opt['val']['grids']`` is set.
        img: a single, unbatched image tensor; a leading batch dimension is
            added here with ``unsqueeze(dim=0)``.
        save_path: file path the restored image is written to via ``imwrite``.

    Returns:
        The array produced by ``tensor2img`` — the same array written to
        ``save_path``. NOTE(review): basicsr's tensor2img presumably converts
        to BGR for cv2 by default; convert before showing it with matplotlib.
    """
    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    # Optional tiling: split the input into grids before testing and stitch
    # the result back afterwards, as requested by the validation config.
    use_grids = model.opt['val'].get('grids', False)
    if use_grids:
        model.grids()

    model.test()

    if use_grids:
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals['result']])
    imwrite(sr_img, save_path)
    return sr_img


## Create Model

In [3]:
# Build the restoration model from the REDS NAFNet-width64 test options file
# (the path is resolved relative to the notebook's working directory).
opt_path = './options/test/REDS/NAFNet-width64.yml'
opt = parse(opt_path, is_train=False)
# 'dist' presumably toggles distributed (multi-process) mode; this notebook
# runs in a single process, so it is forced off before building the model.
opt['dist'] = False
NAFNet = create_model(opt)

2024-05-07 17:15:49,075 INFO: Model [ImageRestorationModel] is created.


 load net keys <built-in method keys of dict object at 0x000001FFE46BF440>


## Run Inference and Show Results

In [None]:
# Frame to restore, and the file the restored result will be written to.
input_path, output_path = './input/00000000.png', './output/00000000.png'

img_input = imread(input_path)
# imread has already converted to RGB, so no channel swap is requested here.
inp = img2tensor(img_input, bgr2rgb=False)

In [None]:
# single_image_inference adds the batch dimension itself, applies grid tiling
# when the config's val.grids option is set, and writes the result to
# output_path. It does not rebind `inp`, so this cell is safe to re-run.
# (To batch several images, combine them with torch.stack([...]);
# torch.tensor() cannot build a tensor from a list of multi-element tensors.)
single_image_inference(NAFNet, inp, output_path)

# The array written by imwrite is in cv2 channel order (presumably BGR);
# reading the saved PNG back through imread yields RGB for matplotlib.
sr_img = imread(output_path)

In [None]:
# Left panel: original input; right panel: restored output.
display(img1=img_input, img2=sr_img)