<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/BSRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture
!nvidia-smi
!git clone https://github.com/cszn/BSRGAN.git
%cd BSRGAN


!pip install scikit-image pyyaml tqdm gdown

import os
if not os.path.exists('model_zoo'):
    os.makedirs('model_zoo')


!gdown --id 1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv -O model_zoo/BSRGAN.pth



In [None]:
#@title ##**Upload images** { display-mode: "form" }
import os
import shutil
import sys

from google.colab import files
from tqdm import tqdm

upload_folder = '/content/BSRGAN/upload'

# Start from an empty folder so the Run cell only sees this upload.
if os.path.isdir(upload_folder):
    shutil.rmtree(upload_folder)
os.makedirs(upload_folder)

# files.upload() returns a mapping of uploaded filename -> file contents (bytes).
uploaded = files.upload()

for filename, data in uploaded.items():
    # basename() guarantees every file lands directly inside upload_folder,
    # even if a supplied name carries directory components.
    with open(os.path.join(upload_folder, os.path.basename(filename)), 'wb') as f:
        f.write(data)

if uploaded:
    print(f"Images have been uploaded to {upload_folder}")
else:
    print("No files were uploaded.")


In [None]:
#@title ##**Run** { display-mode: "form" }
import os
import torch
import shutil
from utils import utils_image as util
from models.network_rrdbnet import RRDBNet as net

model_ids = {
    "BSRGAN": "1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv",
    "BSRGANx2": "1J-6NX3DB6GA0G8AN95vmgWs_2l9l4dJf",
    "BSRNet": "1JGJLiENPkOqi39bvQYa_jlIPlMk24iKH"
}

scale = "4"  #@param ["1","2","3","4"]
selected_model = "BSRGAN"  #@param ["BSRGAN", "BSRGANx2", "BSRNet"]

upload_folder = '/content/BSRGAN/upload'
result_folder = '/content/BSRGAN/results'

if os.path.isdir(result_folder):
    shutil.rmtree(result_folder)
os.makedirs(result_folder)

if selected_model in model_ids:
    model_id = model_ids[selected_model]
    model_path = f'model_zoo/{selected_model}.pth'

    !gdown --id {model_id} -O {model_path}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sf = int(scale)
    model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=sf)
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    for filename in os.listdir(upload_folder):
        img_path = os.path.join(upload_folder, filename)
        img_L = util.imread_uint(img_path, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        with torch.no_grad():
            img_E = model(img_L)

        img_E = util.tensor2uint(img_E)

        output_image_name = os.path.splitext(filename)[0] + f'_{selected_model}.png'
        util.imsave(img_E, os.path.join(result_folder, output_image_name))

    print(f"Images have been processed and saved to {result_folder}")
else:
    print("Selected model is not available.")


In [None]:
#@title ##**Visualize** { display-mode: "form" }
from PIL import Image
import numpy as np
import os
from IPython.display import display

# Maximum width (in pixels) of each panel in the side-by-side comparison.
max_width = 500


def _display_size(size, max_width):
    """Return (width, height) scaled down proportionally so width <= max_width."""
    width, height = size
    if width > max_width:
        return max_width, int(height * max_width / width)
    return width, height


for filename in sorted(os.listdir(upload_folder)):
    # Pair each upload with the exact file the Run cell wrote for it, rather
    # than zipping two independently sorted listings (which can mis-pair names
    # and would also pick up stray files such as an old download.zip).
    output_name = os.path.splitext(filename)[0] + f'_{selected_model}.png'
    output_path = os.path.join(result_folder, output_name)
    if not os.path.isfile(output_path):
        continue

    with Image.open(os.path.join(upload_folder, filename)) as src, Image.open(output_path) as dst:
        # np.hstack needs equal heights and channel counts. The result is an
        # upscaled copy of the upload, so one display size fits both panels;
        # converting to RGB also handles RGBA / grayscale uploads.
        size = _display_size(dst.size, max_width)
        image_original = src.convert('RGB').resize(size)
        image_restore = dst.convert('RGB').resize(size)

    # Combine and display images
    display(Image.fromarray(np.hstack((np.asarray(image_original), np.asarray(image_restore)))))
    print("")

In [None]:
#@title ##**Download results** { display-mode: "form" }
import os
import zipfile
from google.colab import files

zip_file = "download.zip"
# Build the archive next to (not inside) result_folder so it never becomes
# part of the results that other cells list.
zip_path = os.path.join(os.path.dirname(result_folder), zip_file)

# Regular files only; also ignore a download.zip left inside the folder by older runs.
files_in_folder = sorted(
    name for name in os.listdir(result_folder)
    if os.path.isfile(os.path.join(result_folder, name)) and name != zip_file
)

if not files_in_folder:
    print(f"No results found in {result_folder}; run the Run cell first.")
elif len(files_in_folder) == 1:
    files.download(os.path.join(result_folder, files_in_folder[0]))
else:
    # Mode 'w' overwrites any previous archive. arcname=name stores each file
    # without its directory, like `zip -j`.
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for name in files_in_folder:
            archive.write(os.path.join(result_folder, name), arcname=name)
    files.download(zip_path)