<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/swin2sr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }

!nvidia-smi
! pip -q install timm
! git clone https://github.com/mv-lab/swin2sr.git
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import shutil
from tqdm import tqdm
import shutil, sys
import re
import io
import IPython.display
import numpy as np
import PIL.Image
from google.colab import files
import shutil
import os

# Enter the cloned repository. An absolute path (matching upload_folder /
# result_folder used by later cells) keeps this cell re-runnable: a relative
# "./swin2sr" would resolve against the repository itself once the working
# directory has already been changed by a previous run.
os.chdir("/content/swin2sr")

def load_img(filename, debug=False, norm=True, resize=None):
    """Load an image file from disk as an RGB numpy array.

    Args:
        filename: Path of the image to read.
        debug: If True, print shape, dtype, min and max of the array
            (after optional normalisation, before any resize).
        norm: If True, divide pixel values by 255 and cast to float32.
        resize: Optional ``(width, height)`` pair forwarded to ``cv2.resize``.

    Returns:
        The image with channels converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; without this check cvtColor fails with an opaque error.
        raise FileNotFoundError(f"Could not read image: {filename}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    if debug:
        print(img.shape, img.dtype, img.min(), img.max())

    if resize:
        # cv2.resize expects the target size as (width, height).
        img = cv2.resize(img, (resize[0], resize[1]))

    return img

def plot_all(images, axis='off', figsize=(16, 8)):
    """Show ``images`` side by side in one row using pyplot subplots.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        axis: Value passed to ``plt.axis`` for every subplot ('off' hides axes).
        figsize: Figure size in inches, passed to ``plt.figure``.

    NOTE(review): a second ``plot_all`` is defined later in this same cell and
    replaces this one at runtime.
    """
    plt.figure(figsize=figsize, dpi=80)
    ncols = len(images)
    # Subplot positions are 1-based, hence start=1.
    for position, image in enumerate(images, start=1):
        plt.subplot(1, ncols, position)
        plt.axis(axis)
        plt.imshow(image)
    plt.show()

# Clean and create the inputs/ directory from scratch

!rm -r inputs
!mkdir inputs

# Put some images into inputs/
!cp testsets/real-inputs/* inputs/ 

# check the images in input/
!ls inputs

def plot_all(images, axis="off", figsize=(20, 20)):
    """Show ``images`` side by side in a single row of subplots.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        axis: Value passed to ``Axes.axis`` for every subplot (default 'off'
            hides the axes). Accepted for compatibility with the earlier
            ``plot_all`` signature.
        figsize: Figure size in inches.

    Does nothing when ``images`` is empty.
    """
    if len(images) == 0:
        # plt.subplots rejects a grid with zero columns.
        return
    # squeeze=False keeps `axes` a 2-D array even for a single image; otherwise
    # plt.subplots returns a bare Axes object and zip() over it raises TypeError.
    _, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis(axis)
    plt.show()

In [None]:
#@title ##**Upload images** { display-mode: "form" }


# Fixed locations (inside the cloned repository) for model inputs and outputs;
# later cells read both names.
upload_folder = "/content/swin2sr/inputs"
result_folder = "/content/swin2sr/results"

# Recreate an empty inputs/ so that only images uploaded now get processed.
if os.path.isdir(upload_folder):
    shutil.rmtree(upload_folder)
os.makedirs(upload_folder)

# Move every uploaded file from the current working directory (where this cell
# expects files.upload() to have written it) into inputs/.
basepath = os.getcwd()
uploaded = files.upload()
for uploaded_name in uploaded:
    source_path = os.path.join(basepath, uploaded_name)
    target_path = os.path.join(upload_folder, uploaded_name)
    shutil.move(source_path, target_path)

In [None]:
#@title ##**Run** { display-mode: "form" }

# Start from an empty results/ directory so stale outputs are never shown later.
if os.path.isdir(result_folder):
    shutil.rmtree(result_folder)
os.makedirs(result_folder)

model_name = "compressed_sr" #@param ["compressed_sr","real_sr","lightweight_sr","classical_sr_2","classical_sr_4"]

# Extra CLI arguments for each selectable model: training patch size, upscale
# factor, task name and checkpoint path.
MODEL_ARGS = {
    "compressed_sr": "--training_patch_size 48 --scale 4 --task compressed_sr --model_path model_zoo/swin2sr/Swin2SR_CompressedSR_X4_48.pth",
    "real_sr": "--training_patch_size 64 --scale 4 --task real_sr --model_path model_zoo/swin2sr/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth",
    "lightweight_sr": "--training_patch_size 64 --scale 2 --task lightweight_sr --model_path model_zoo/swin2sr/Swin2SR_Lightweight_X2_64.pth",
    "classical_sr_2": "--training_patch_size 64 --scale 2 --task classical_sr --model_path model_zoo/swin2sr/Swin2SR_ClassicalSR_X2_64.pth",
    "classical_sr_4": "--training_patch_size 64 --scale 4 --task classical_sr --model_path model_zoo/swin2sr/Swin2SR_ClassicalSR_X4_64.pth",
}

# Fail loudly instead of silently launching the script without any task arguments.
if model_name not in MODEL_ARGS:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_ARGS)}"
    )

command = "python main_test_swin2sr.py --folder_lq ./inputs/ --save_img_only " + MODEL_ARGS[model_name]

print(command)
# os.system returns a non-zero status when the inference script fails (e.g. a
# missing checkpoint or GPU error); surface it rather than continuing silently.
exit_status = os.system(command)
if exit_status != 0:
    raise RuntimeError(
        f"Swin2SR inference failed (os.system returned {exit_status}); see the output above."
    )

In [None]:
#@title ##**Visualize** { display-mode: "form" }

def _stem(path):
    """Return the file name of ``path`` without its directory or last extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _find_result(input_path, result_paths):
    """Return the result image produced for ``input_path``, or None if none exists.

    Prefers a result whose file stem equals the input stem or starts with
    ``"<stem>_"``. Otherwise it falls back to any result whose file name contains
    the stem. Only file names are compared, never full paths, so a stem that
    happens to appear in a directory name cannot cause a false match.
    """
    stem = _stem(input_path)
    for candidate in result_paths:
        candidate_stem = _stem(candidate)
        if candidate_stem == stem or candidate_stem.startswith(stem + "_"):
            return candidate
    for candidate in result_paths:
        if stem in os.path.basename(candidate):
            return candidate
    return None


def _side_by_side(image_original, image_restore, max_width=500):
    """Return a single image with ``image_original`` left of ``image_restore``.

    Both images are converted to RGB first. An RGBA, grayscale or palette upload
    would otherwise have a different channel count from the (presumably RGB)
    result, and ``np.concatenate`` would fail. When the sizes differ, both are
    resized to ``max_width`` pixels wide and a common height. That height is the
    larger of the two aspect-preserving heights, so the arrays can be joined
    horizontally.
    """
    image_original = image_original.convert("RGB")
    image_restore = image_restore.convert("RGB")
    if image_original.size != image_restore.size:
        width_original, height_original = image_original.size
        width_restore, height_restore = image_restore.size
        new_height = max(
            int(height_original * max_width / width_original),
            int(height_restore * max_width / width_restore),
        )
        image_original = image_original.resize((max_width, new_height))
        image_restore = image_restore.resize((max_width, new_height))
    combined_array = np.concatenate(
        [np.asarray(image_original), np.asarray(image_restore)], axis=1
    )
    return PIL.Image.fromarray(combined_array)


# List every file (directories excluded) under inputs/.
inputs = sorted(
    path
    for path in glob(os.path.join(upload_folder, '**', '*'), recursive=True)
    if os.path.isfile(path)
)

# List every .png file under results/.
outputs = sorted(glob(os.path.join(result_folder, '**', '*.png'), recursive=True))

# Pair, combine and display each image. Iterate over the inputs themselves:
# zip(inputs, outputs) would drop trailing inputs whenever there were fewer
# results than inputs.
for filename in inputs:
    matching_result_file = _find_result(filename, outputs)
    if matching_result_file is None:
        print(f"No result found for {os.path.basename(filename)}, skipping.")
        continue
    # Open both files in `with` blocks so their handles are closed; the combined
    # image is built (and its pixel data loaded) while they are still open.
    with PIL.Image.open(filename) as image_original, \
            PIL.Image.open(matching_result_file) as image_restore:
        combined_image = _side_by_side(image_original, image_restore)

    # Display the combined image
    display(combined_image)
    print("")


In [None]:
#@title ##**Download results** { display-mode: "form" }

import zipfile

# Collect every result image. Use the same absolute results path as the other
# cells instead of relying on the current working directory.
outputs_png = sorted(glob(os.path.join(result_folder, '**', '*.png'), recursive=True))

if not outputs_png:
    print("No result images found in results/ - run the Run cell first.")
elif len(outputs_png) == 1:
    files.download(outputs_png[0])
else:
    zip_file = 'results.zip'
    zip_path = os.path.join(result_folder, zip_file)
    # Store files without their directories (like `zip -j`). Keying by file name
    # means the last path in sorted order wins when two results share a name,
    # as with the previous copy-then-zip approach.
    by_name = {os.path.basename(path): path for path in outputs_png}
    # Build the archive in Python rather than via the shell. Result names are
    # presumably derived from uploaded file names, so they may contain spaces or
    # shell metacharacters that broke the unquoted `cp`/`zip` commands.
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
        for arcname, path in sorted(by_name.items()):
            archive.write(path, arcname=arcname)
    files.download(zip_path)