<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/NAFNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture
# Everything in this cell runs with its output captured (hidden) by the
# %%capture magic above, including the GPU report printed by nvidia-smi.
!nvidia-smi
%cd /content

# NOTE(review): on a re-run /content/NAFNet already exists, so this clone
# fails (silently, because output is captured) and the existing checkout is reused.
!git clone https://github.com/megvii-research/NAFNet

%cd /content/NAFNet

# Project requirements, then gdown (used below to fetch the pretrained weights
# from Google Drive).
!pip install -r requirements.txt
!pip install --upgrade --no-cache-dir gdown
# Install the repo in development mode so `basicsr` (imported below) resolves
# to this checkout; --no_cuda_ext presumably skips compiling custom CUDA
# extensions (per the flag name) — confirm in setup.py.
!python3 setup.py develop --no_cuda_ext

%cd /content/NAFNet

import gdown

import torch

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
from google.colab import files
import shutil
import glob

import PIL.Image
import numpy as np
from IPython.display import display
from IPython.display import display, clear_output

%cd /content/NAFNet

# Deblur
if not os.path.exists("./experiments/pretrained_models/NAFNet-REDS-width64.pth"):
  gdown.download('https://drive.google.com/uc?id=14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X', "./experiments/pretrained_models/", quiet=False)

def imread(img_path):
  """Read an image file and return it as an RGB array.

  Args:
    img_path: Path of the image file to read.

  Returns:
    The decoded image as a numpy array with channels in RGB order
    (OpenCV's default colour mode yields 3 channels).

  Raises:
    ValueError: If OpenCV cannot read or decode the file (missing file,
      non-image upload, or a format OpenCV does not support).
  """
  img = cv2.imread(img_path)
  # cv2.imread signals failure by returning None rather than raising; without
  # this check the failure surfaces as an opaque assertion error in cvtColor.
  if img is None:
    raise ValueError(f"Could not read image: {img_path}")
  # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img2tensor(img, bgr2rgb=False, float32=True):
  """Scale an 8-bit image to [0, 1] floats and hand it to basicsr's img2tensor.

  Args:
    img: Image array (presumably uint8 in 0..255, as returned by ``imread``).
    bgr2rgb: Forwarded unchanged to basicsr's ``img2tensor``.
    float32: Forwarded unchanged to basicsr's ``img2tensor``.

  Returns:
    Whatever basicsr's ``img2tensor`` returns for the scaled array
    (presumably a CHW torch tensor — confirm in basicsr.utils).
  """
  as_float = img.astype(np.float32)
  scaled = np.divide(as_float, 255.)
  return _img2tensor(scaled, float32=float32, bgr2rgb=bgr2rgb)

def display(img1, img2):
  """Plot an input image and the NAFNet output side by side in one figure.

  NOTE: this definition shadows ``IPython.display.display`` imported at the
  top of this cell; code needing IPython's version must import it under
  another name.

  Args:
    img1: Image shown on the left, titled 'Input image'.
    img2: Image shown on the right, titled 'NAFNet output'.
  """
  fig = plt.figure(figsize=(25, 10))
  panels = (('Input image', img1), ('NAFNet output', img2))
  for slot, (caption, picture) in enumerate(panels, start=1):
    axes = fig.add_subplot(1, 2, slot)
    axes.set_title(caption, fontsize=16)
    axes.axis('off')
    axes.imshow(picture)

def single_image_inference(model, img, save_path):
  """Restore one image with ``model`` and write the result to ``save_path``.

  When the model options enable ``val.grids``, ``test()`` is wrapped in
  ``grids()`` / ``grids_inverse()``.

  Args:
    model: basicsr model created by ``create_model`` (uses ``feed_data``,
      ``test``, ``get_current_visuals`` and ``opt``).
    img: Image tensor without a batch axis (presumably CHW, as produced by
      ``img2tensor``); a batch axis of size 1 is added here.
    save_path: Output file path passed to basicsr's ``imwrite``.
  """
  def grids_enabled():
    return model.opt['val'].get('grids', False)

  model.feed_data(data={'lq': img.unsqueeze(dim=0)})

  if grids_enabled():
    model.grids()
  model.test()
  if grids_enabled():
    model.grids_inverse()

  restored = model.get_current_visuals()['result']
  imwrite(tensor2img([restored]), save_path)

def sr_display(LR_l, LR_r, SR_l, SR_r):
  """Show a stereo super-resolution result as a 2x2 grid.

  Top row: left input / left output; bottom row: right input / right output.
  Both inputs are upscaled with bicubic interpolation to the size of ``SR_l``
  so inputs and outputs are shown at the same scale.

  Args:
    LR_l, LR_r: Left and right low-resolution input images (numpy arrays).
    SR_l, SR_r: Left and right network outputs; ``SR_l``'s height and width
      set the display size.
  """
  h, w = SR_l.shape[:2]
  LR_l = cv2.resize(LR_l, (w, h), interpolation=cv2.INTER_CUBIC)
  LR_r = cv2.resize(LR_r, (w, h), interpolation=cv2.INTER_CUBIC)
  # About one inch of figure per 40 output pixels, clamped to at least one inch
  # so images narrower/shorter than 40 px don't yield a zero-size figure.
  fig = plt.figure(figsize=(max(w // 40, 1), max(h // 40, 1)))
  panels = (
      ('Input image (Left)', LR_l),
      ('NAFSSR output (Left)', SR_l),
      ('Input image (Right)', LR_r),
      ('NAFSSR output (Right)', SR_r),
  )
  for position, (title, image) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 2, position)
    ax.set_title(title, fontsize=16)
    ax.axis('off')
    ax.imshow(image)

  plt.subplots_adjust(wspace=0.04, hspace=0.04)

def stereo_image_inference(model, img_l, img_r, save_path):
  """Restore a left/right image pair with a stereo model and save both views.

  The two views are concatenated along dim 0 (presumably the channel axis of
  CHW tensors — confirm) and batched. The first three output channels are
  saved as the left view and the remaining ones as the right view.

  Args:
    model: basicsr model created by ``create_model``.
    img_l: Left-view image tensor without a batch axis.
    img_r: Right-view image tensor without a batch axis.
    save_path: Format string with one ``{}`` placeholder, filled with 'L'
      and 'R' to name the two output files.
  """
  def grids_enabled():
    return model.opt['val'].get('grids', False)

  pair = torch.cat([img_l, img_r], dim=0)
  model.feed_data(data={'lq': pair.unsqueeze(dim=0)})

  if grids_enabled():
    model.grids()
  model.test()
  if grids_enabled():
    model.grids_inverse()

  output = model.get_current_visuals()['result']
  left_view, right_view = tensor2img([output[:, :3], output[:, 3:]])

  for tag, view in (('L', left_view), ('R', right_view)):
    imwrite(view, save_path.format(tag))

%cd /content/NAFNet

# Build the deblurring model from the REDS NAFNet-width64 test options.
# opt_path is relative to the repo root, hence the %cd above (the weights path
# inside the YAML is presumably relative as well — confirm).
opt_path = 'options/test/REDS/NAFNet-width64.yml'
opt = parse(opt_path, is_train=False)
# Presumably disables basicsr's distributed (multi-process) mode for
# single-process notebook inference — confirm in basicsr.
opt['dist'] = False
# Global model used by the Run cell.
NAFNet = create_model(opt)
# Clear this cell's output; wait=True defers clearing until new output arrives.
clear_output(wait=True)

In [None]:
#@title ##**Upload images** {display-mode: "form" }

#input_folder = "test_images/old"
#output_folder = "output"

# Working folders shared with the Run / Visualize / Download cells.
upload_folder = "/content/upload"
result_folder = "/content/results"

# Recreate the upload folder so it holds only the files from this upload.
if os.path.isdir(upload_folder):
    shutil.rmtree(upload_folder)
os.makedirs(upload_folder)

# The move below assumes files.upload() stores each chosen file in the current
# working directory under the returned name; relocate them into upload_folder.
basepath = os.getcwd()
uploaded = files.upload()
for filename in uploaded:
    shutil.move(
        os.path.join(basepath, filename),
        os.path.join(upload_folder, filename),
    )



In [None]:
#@title ##**Run** { display-mode: "form" }
%cd /content/NAFNet
import glob
import os
import time
from tqdm import tqdm

if os.path.isdir(result_folder):
    shutil.rmtree(result_folder)
os.makedirs(result_folder)

upload_folder = "/content/upload"
result_folder = "/content/results"

upload_files = set()
while not upload_files:
    try:
        upload_files = set(os.listdir(upload_folder))
    except OSError as e:
        print(f"Error: {e}, retrying...")
        time.sleep(2)

result_files = set(os.listdir(result_folder))

input_list = [file_path for file_path in glob.glob(os.path.join(upload_folder, '*'))
              if os.path.basename(file_path) not in result_files and os.path.basename(file_path) in upload_files]


for input_path in tqdm(input_list, desc="Processing images"):
    img_input = imread(input_path)
    inp = img2tensor(img_input)
    output_path = os.path.join(result_folder, os.path.basename(input_path))
    single_image_inference(NAFNet, inp, output_path)


In [None]:
#@title ##**Visualize** { display-mode: "form" }
from IPython.display import display as ipy_display
import os

import numpy as np
import PIL.Image
import PIL.ImageOps

# Maximum width (pixels) of each half of the side-by-side comparison.
max_width = 500


def _load_rgb(path):
    """Open ``path`` with PIL, apply its EXIF orientation, and return an RGB copy.

    cv2.imread (used by the Run cell) honours the EXIF orientation tag, so the
    saved result is already upright; applying the tag here keeps the original
    aligned with it. Converting to RGB lets palette, grayscale and RGBA uploads
    be stacked next to the 3-channel result.
    """
    with PIL.Image.open(path) as img:
        return PIL.ImageOps.exif_transpose(img).convert("RGB")


def _display_size(size, limit):
    """Return ``size`` (width, height) scaled down proportionally to width <= ``limit``."""
    width, height = size
    if width <= limit:
        return width, height
    return limit, max(1, int(height * limit / width))


filenames_upload = sorted(os.listdir(upload_folder))
filenames_upload_output = sorted(os.listdir(result_folder))

# Pair each result with the upload of the same name (the Run cell saves results
# under the input's file name). Zipping the two sorted listings would misalign
# every later pair as soon as one folder has an extra or missing file, e.g. a
# skipped upload or the download.zip written by the Download cell.
for filename in filenames_upload_output:
    original_path = os.path.join(upload_folder, filename)
    if not os.path.isfile(original_path):
        print(f"Skipping {filename}: no matching uploaded image.")
        continue

    image_original = _load_rgb(original_path)
    image_restore = _load_rgb(os.path.join(result_folder, filename))

    # Give both halves the same display size so np.hstack sees equal heights.
    size = _display_size(image_original.size, max_width)
    image_original = image_original.resize(size)
    image_restore = image_restore.resize(size)

    # Combine and display images
    combined_image = PIL.Image.fromarray(
        np.hstack((np.asarray(image_original), np.asarray(image_restore)))
    )
    ipy_display(combined_image)
    print("")


In [None]:
#@title ##**Download results** { display-mode: "form" }
#output_folder = os.path.join(upload_output_path, "final_output")
import os
import zipfile

zip_file = "download.zip"
zip_path = os.path.join(result_folder, zip_file)

# Result files only: leave out an archive left over from a previous run of
# this cell so it is neither counted nor zipped into itself.
files_in_folder = sorted(
    name for name in os.listdir(result_folder)
    if name != zip_file and os.path.isfile(os.path.join(result_folder, name))
)

if not files_in_folder:
    print(f"No results found in {result_folder}. Run the 'Run' cell first.")
elif len(files_in_folder) == 1:
    # A single result is downloaded as-is, without an archive.
    files.download(os.path.join(result_folder, files_in_folder[0]))
else:
    # Build the archive with zipfile instead of shelling out to `zip`: no
    # dependency on the zip binary and no file names passed through a shell.
    # Entries are stored flat (arcname = bare file name), like `zip -j`.
    # Mode "w" overwrites any previous archive.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for name in files_in_folder:
            archive.write(os.path.join(result_folder, name), arcname=name)
    files.download(zip_path)