<a href="https://colab.research.google.com/github/goslim/aku/blob/main/Real_ESRGAN_photo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title ##**Install** { display-mode: "form" }
%%capture
!nvidia-smi
!git clone https://github.com/xinntao/Real-ESRGAN.git
%cd Real-ESRGAN
# Set up the environment
!pip install basicsr
!pip install facexlib
!pip install gfpgan
!pip install -r requirements.txt
!python setup.py develop

import shutil
from tqdm import tqdm
import os
import shutil, sys
import re
import io
import IPython.display
import numpy as np
import PIL.Image
from google.colab import files
import shutil

# URL of the replacement model weights
new_model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'

# Read the inference script from the local checkout
filename = '/content/Real-ESRGAN/inference_realesrgan.py'
with open(filename, 'r') as script_file:
    script_content = script_file.read()

# Replace the quoted value of every `model_path = <quoted string>` occurrence
# with the new URL; groups 1 and 2 keep the text around the old value
# (including both quote characters) untouched.
model_path_pattern = re.compile(r"(model_path\s*=\s*[\"\']).*?([\"\'])")
new_script_content = model_path_pattern.sub(
    rf"\g<1>{new_model_path}\g<2>",
    script_content,
)

# Write the patched script back over the original file
with open(filename, 'w') as script_file:
    script_file.write(new_script_content)

!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py

# Source text of a replacement `degradations.py` module. Further down in this
# cell the string is written verbatim over basicsr/data/degradations.py, so
# everything between the triple quotes is runtime data, not notebook code.
# It defines: random_add_gaussian_noise_pt, random_add_poisson_noise_pt,
# rgb_to_grayscale, circular_lowpass_kernel, random_mixed_kernels and the two
# private Gaussian-kernel generators.
# NOTE(review): `_scaled_sinc` is called with the whole `dist` tensor, so its
# `if x == 0:` test looks like it would raise for any kernel_size > 1 — confirm
# before relying on circular_lowpass_kernel.
# NOTE(review): presumably only the importability of these names matters for
# inference — verify against the modules that import them.
degradations_code = '''import cv2
import math
import numpy as np
import random
import torch
from torch.nn import functional as F

def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, noise_gray_prob=0, clip=True, rounds=False):
    noise_sigma = random.uniform(*sigma_range)

    if random.random() < gray_prob:
        img = rgb_to_grayscale(img)

    if random.random() < noise_gray_prob:
        noise = torch.randn(*img.shape[1:], device=img.device) * noise_sigma
        noise = noise.unsqueeze(0).repeat(img.shape[0], 1, 1)
    else:
        noise = torch.randn_like(img) * noise_sigma

    out = img + noise

    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.

    return out

def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    scale = random.uniform(*scale_range)

    if random.random() < gray_prob:
        img = rgb_to_grayscale(img)

    noise = torch.poisson(img * scale) / scale - img
    out = img + noise

    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.

    return out

def rgb_to_grayscale(img):
    if img.shape[0] != 3:
        raise ValueError('Input image must have 3 channels')
    rgb_weights = torch.tensor([0.2989, 0.5870, 0.1140], device=img.device)
    grayscale = torch.sum(img * rgb_weights.view(-1, 1, 1), dim=0, keepdim=True)
    return grayscale

def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    if pad_to == 0:
        pad_to = kernel_size
    assert pad_to >= kernel_size, 'Pad size must be larger than kernel size'

    def _scaled_sinc(x):
        if x == 0:
            return torch.tensor(1.)
        x = x * math.pi
        return torch.sin(x) / x

    half_size = (kernel_size - 1) / 2.
    grid = torch.linspace(-half_size, half_size, kernel_size)
    x, y = torch.meshgrid(grid, grid)
    dist = torch.sqrt(x**2 + y**2)

    kernel = _scaled_sinc(dist * cutoff)
    kernel = kernel / kernel.sum()

    if pad_to > kernel_size:
        pad = (pad_to - kernel_size) // 2
        kernel = F.pad(kernel, [pad] * 4)

    return kernel

def random_mixed_kernels(
        kernel_list,
        kernel_prob,
        kernel_size=21,
        blur_sigma=0.1,
        blur_sigma_min=0.1,
        blur_sigma_max=10.0,
        blur_kernel_size=21,
        pad_to=0):
    num_kernels = len(kernel_list)
    kernel_type = np.random.choice(kernel_list, p=kernel_prob)

    if pad_to == 0:
        pad_to = kernel_size

    if kernel_type == 'iso':
        sigma = np.random.uniform(blur_sigma_min, blur_sigma_max)
        kernel = _generate_isotropic_gaussian_kernel(kernel_size, sigma, pad_to)
    elif kernel_type == 'aniso':
        sigma_x = np.random.uniform(blur_sigma_min, blur_sigma_max)
        sigma_y = np.random.uniform(blur_sigma_min, blur_sigma_max)
        rotation = np.random.uniform(-np.pi, np.pi)
        kernel = _generate_anisotropic_gaussian_kernel(kernel_size, sigma_x, sigma_y, rotation, pad_to)
    else:  # general
        kernel = circular_lowpass_kernel(blur_sigma, kernel_size, pad_to)

    return kernel

def _generate_isotropic_gaussian_kernel(kernel_size=21, sigma=0.1, pad_to=0):
    if pad_to == 0:
        pad_to = kernel_size

    half_size = (kernel_size - 1) / 2.
    grid = torch.linspace(-half_size, half_size, kernel_size)
    x, y = torch.meshgrid(grid, grid)
    kernel = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    kernel = kernel / kernel.sum()

    if pad_to > kernel_size:
        pad = (pad_to - kernel_size) // 2
        kernel = F.pad(kernel, [pad] * 4)

    return kernel

def _generate_anisotropic_gaussian_kernel(kernel_size=21, sigma_x=0.1, sigma_y=0.1, rotation=0, pad_to=0):
    if pad_to == 0:
        pad_to = kernel_size

    half_size = (kernel_size - 1) / 2.
    grid = torch.linspace(-half_size, half_size, kernel_size)
    x, y = torch.meshgrid(grid, grid)

    cos_theta = torch.cos(torch.tensor(rotation))
    sin_theta = torch.sin(torch.tensor(rotation))
    x_rot = cos_theta * x - sin_theta * y
    y_rot = sin_theta * x + cos_theta * y

    kernel = torch.exp(-(x_rot**2 / (2 * sigma_x**2) + y_rot**2 / (2 * sigma_y**2)))
    kernel = kernel / kernel.sum()

    if pad_to > kernel_size:
        pad = (pad_to - kernel_size) // 2
        kernel = F.pad(kernel, [pad] * 4)

    return kernel
'''

# Overwrite basicsr's degradations.py with the replacement source held in
# `degradations_code`.
# The install directory is resolved at runtime: a hard-coded
# /usr/local/lib/python3.X/dist-packages path breaks whenever the runtime's
# Python version differs from the one written here.
import importlib
import importlib.util

importlib.invalidate_caches()  # basicsr was pip-installed earlier in this same kernel session
basicsr_spec = importlib.util.find_spec('basicsr')
if basicsr_spec is None or not basicsr_spec.submodule_search_locations:
    raise ModuleNotFoundError("basicsr is not installed; the pip install step above must succeed first")
degradations_path = os.path.join(list(basicsr_spec.submodule_search_locations)[0], 'data', 'degradations.py')

# Keep a one-time backup of the file being replaced; like `cp -n`, an existing
# backup is never overwritten, so re-running the cell preserves the first copy.
backup_path = degradations_path + '.backup'
if os.path.isfile(degradations_path) and not os.path.exists(backup_path):
    shutil.copy(degradations_path, backup_path)

# Write the new module source
with open(degradations_path, 'w') as f:
    f.write(degradations_code)


In [None]:
#@title ##**Upload images** { display-mode: "form" }
%cd /content/Real-ESRGAN
#input_folder = "test_images/old"
#output_folder = "output"

upload_folder = "/content/Real-ESRGAN/upload"
result_folder = "/content/Real-ESRGAN/results"

if os.path.isdir(upload_folder):
    shutil.rmtree(upload_folder)
os.makedirs(upload_folder)

basepath = os.getcwd()
uploaded = files.upload()
for filename in uploaded.keys():
    shutil.move(os.path.join(basepath, filename), os.path.join(upload_folder, filename))

/content/Real-ESRGAN


In [None]:
#@title ##**Run** { display-mode: "form" }
import subprocess

# Start from an empty results directory so stale outputs are never shown or downloaded
if os.path.isdir(result_folder):
    shutil.rmtree(result_folder)
os.makedirs(result_folder)

model_name = "RealESRGAN_x2plus" #@param ["RealESRGAN_x2plus","RealESRGAN_x4plus","RealESRNet_x4plus","realesr-general-x4v3","RealESRGAN_x4plus_anime_6B","realesr-animevideov3"]
scale = "2" #@param ["1","2","3","4"]
face_enhance = "Yes" #@param ["Yes","No"]

# Build the invocation as an argument list (no shell string parsing involved)
command_args = ["python", "inference_realesrgan.py", "-n", model_name, "-i", "upload", "--outscale", scale]
if face_enhance == "Yes":
    command_args.append("--face_enhance")

command = " ".join(command_args)
print(command)

# Run inside the checkout so the relative script/input paths resolve no matter
# what the kernel's current directory is, and stream the log into the cell.
run_env = dict(os.environ, PYTHONUNBUFFERED="1")  # have the child flush its log lines promptly
with subprocess.Popen(
    command_args,
    cwd="/content/Real-ESRGAN",
    env=run_env,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
) as process:
    for log_line in process.stdout:
        print(log_line, end="")

# Leaving the `with` block waits for the process, so returncode is set here
if process.returncode != 0:
    print(f"Error: inference exited with status {process.returncode}; see the log above.")

In [None]:
#@title ##**Visualize** { display-mode: "form" }
import os
import PIL.Image
import numpy as np
from IPython.display import HTML, display
import base64
from io import BytesIO

def is_image_file(filename):
    """Return True when ``filename`` ends in a recognised image extension.

    The comparison is case-insensitive; only the final extension is examined.
    """
    _, extension = os.path.splitext(filename.lower())
    return extension in ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')

def resize_image_maintain_aspect(image, max_width, target_height=None):
    """Shrink ``image`` to fit ``max_width`` and optionally force its height.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (a PIL image in this notebook).
        max_width: Images wider than this are scaled down to exactly this width.
        target_height: When given and different from the current height, the
            image is rescaled to exactly this height.

    Returns:
        The resized image, or the same object when no resize was needed.

    Each computed dimension is clamped to at least 1 pixel: truncating with
    ``int()`` would otherwise produce a 0-pixel side for extreme aspect ratios.
    """
    width, height = image.size
    if width > max_width:
        new_height = max(1, int(height * max_width / width))
        image = image.resize((max_width, new_height), PIL.Image.Resampling.LANCZOS)
    if target_height is not None and image.size[1] != target_height:
        new_width = max(1, int(image.size[0] * target_height / image.size[1]))
        image = image.resize((new_width, target_height), PIL.Image.Resampling.LANCZOS)
    return image

def image_to_base64(image):
    """Serialise ``image`` as PNG and return the bytes as a base64 text string."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode('utf-8')

visualization_method = "Side-by-Side" #@param ["Side-by-Side", "Slider"]


def _pair_uploads_with_results(upload_names, result_names, suffix="out"):
    """Pair every uploaded file name with the name of its restored result.

    Pairing the two sorted listings by position is fragile: "a.png"/"a1.png"
    sort differently from "a1_out.png"/"a_out.png", and one failed image shifts
    every later pair. Results are therefore looked up by name, using the
    inference script's default output naming ``<stem>_<suffix>.<ext>``; a result
    with the expected stem but another extension is accepted as a fallback.

    Returns a list of ``(upload_name, result_name)`` tuples. When no name
    matches at all, the positional pairing of the two lists is returned instead.
    """
    available = set(result_names)
    first_by_stem = {}
    for result_name in result_names:
        first_by_stem.setdefault(os.path.splitext(result_name)[0], result_name)

    pairs = []
    unmatched = []
    for upload_name in upload_names:
        stem, extension = os.path.splitext(upload_name)
        expected_stem = f"{stem}_{suffix}"
        exact_name = expected_stem + extension
        match = exact_name if exact_name in available else first_by_stem.get(expected_stem)
        if match is None:
            unmatched.append(upload_name)
        else:
            pairs.append((upload_name, match))

    if not pairs:
        # Unknown naming scheme: keep the positional behaviour
        return list(zip(upload_names, result_names))
    for upload_name in unmatched:
        print(f"Warning: no result found for {upload_name}; skipping it.")
    return pairs


filenames_upload = sorted([f for f in os.listdir(upload_folder) if is_image_file(f)])
filenames_upload_output = sorted([f for f in os.listdir(result_folder) if is_image_file(f)])

if not filenames_upload or not filenames_upload_output:
    print(f"Error: No images found in {upload_folder} or {result_folder}.")
else:
    for filename, filename_output in _pair_uploads_with_results(filenames_upload, filenames_upload_output):
        try:
            image_original = PIL.Image.open(os.path.join(upload_folder, filename))
            image_restore = PIL.Image.open(os.path.join(result_folder, filename_output))

            if visualization_method == "Side-by-Side":
                # Fit both images to the same width limit, then to the smaller
                # of the two heights so they line up next to each other
                max_width = 500
                image_original = resize_image_maintain_aspect(image_original, max_width)
                image_restore = resize_image_maintain_aspect(image_restore, max_width)
                target_height = min(image_original.size[1], image_restore.size[1])
                image_original = resize_image_maintain_aspect(image_original, max_width, target_height)
                image_restore = resize_image_maintain_aspect(image_restore, max_width, target_height)

                # Paste original (left) and restored (right) onto one canvas
                combined_width = image_original.size[0] + image_restore.size[0]
                combined_image = PIL.Image.new('RGB', (combined_width, target_height))
                combined_image.paste(image_original, (0, 0))
                combined_image.paste(image_restore, (image_original.size[0], 0))
                display(combined_image)

            else:
                # Slider view: the restored image sets the display size (capped at 1000 px wide)
                max_width = min(image_restore.size[0], 1000)
                image_restore = resize_image_maintain_aspect(image_restore, max_width)
                target_height = image_restore.size[1]
                image_original = resize_image_maintain_aspect(image_original, max_width, target_height)

                if image_original.mode != 'RGB':
                    image_original = image_original.convert('RGB')
                if image_restore.mode != 'RGB':
                    image_restore = image_restore.convert('RGB')

                original_base64 = image_to_base64(image_original)
                restore_base64 = image_to_base64(image_restore)

                # Both images are embedded as data URIs; the restored one is
                # clipped with clip-path and the script moves the clip edge
                # together with the draggable handle.
                html_code = f"""
                <div style="position: relative; width: {image_restore.size[0]}px; height: {image_restore.size[1]}px; margin-bottom: 20px;">
                    <div style="position: relative; width: 100%; height: 100%; overflow: hidden;">
                        <img src="data:image/png;base64,{original_base64}" style="position: absolute; width: 100%; height: 100%;">
                        <div style="position: absolute; width: 100%; height: 100%; overflow: hidden; clip-path: inset(0 0 0 50%);">
                            <img src="data:image/png;base64,{restore_base64}" style="position: absolute; width: 100%; height: 100%;">
                        </div>
                    </div>
                    <div class="slider" style="position: absolute; top: 0; bottom: 0; width: 4px; background: white; cursor: ew-resize; left: 50%; box-shadow: 0 0 5px rgba(0,0,0,0.5);">
                        <div style="position: absolute; top: 50%; transform: translateY(-50%); width: 20px; height: 20px; background: white; border-radius: 50%; left: -8px;"></div>
                    </div>
                </div>
                <script>
                    document.querySelectorAll('.slider').forEach(slider => {{
                        let isDragging = false;
                        const container = slider.parentElement.querySelector('div:nth-child(1)');
                        const clipDiv = container.querySelector('div');

                        slider.addEventListener('mousedown', (e) => {{
                            isDragging = true;
                            e.preventDefault();
                        }});

                        document.addEventListener('mouseup', () => {{
                            isDragging = false;
                        }});

                        document.addEventListener('mousemove', (e) => {{
                            if (!isDragging) return;
                            const rect = container.getBoundingClientRect();
                            let x = e.clientX - rect.left;
                            if (x < 0) x = 0;
                            if (x > rect.width) x = rect.width;
                            const percentage = (x / rect.width) * 100;
                            slider.style.left = percentage + '%';
                            clipDiv.style.clipPath = `inset(0 0 0 ${{percentage}}%)`;
                        }});

                        slider.addEventListener('touchstart', (e) => {{
                            isDragging = true;
                            e.preventDefault();
                        }});

                        document.addEventListener('touchend', () => {{
                            isDragging = false;
                        }});

                        document.addEventListener('touchmove', (e) => {{
                            if (!isDragging) return;
                            const rect = container.getBoundingClientRect();
                            let x = e.touches[0].clientX - rect.left;
                            if (x < 0) x = 0;
                            if (x > rect.width) x = rect.width;
                            const percentage = (x / rect.width) * 100;
                            slider.style.left = percentage + '%';
                            clipDiv.style.clipPath = `inset(0 0 0 ${{percentage}}%)`;
                        }});
                    }});
                </script>
                """

                display(HTML(html_code))

        except Exception as e:
            # One unreadable pair is reported and skipped so the remaining pairs still render
            print(f"Error processing {filename} and {filename_output}: {e}")

In [None]:
#@title ##**Download results** { display-mode: "form" }
import zipfile

zip_file = "download.zip"
zip_path = os.path.join(result_folder, zip_file)

# Remove an archive left by a previous run first, so it is neither counted as
# a result nor packed into the new archive.
if os.path.exists(zip_path):
    os.remove(zip_path)

# Top-level, non-hidden regular files only (what the former `zip ... *` glob matched)
files_in_folder = sorted(
    name for name in os.listdir(result_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(result_folder, name))
)

if not files_in_folder:
    print(f"Error: nothing to download, {result_folder} contains no files.")
elif len(files_in_folder) == 1:
    # A single result is sent as-is, without an archive
    files.download(os.path.join(result_folder, files_in_folder[0]))
else:
    # Several results: pack them flat (no directory prefix) into download.zip
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for name in files_in_folder:
            archive.write(os.path.join(result_folder, name), arcname=name)
    files.download(zip_path)

In [None]:
#@title ##**Download results to google drive (optional)** { display-mode: "form" }
from google.colab import drive
drive.mount('/content/drive')

results_dir = "/content/Real-ESRGAN/results"
drive_dir = "/content/drive/MyDrive/"
archive_path = os.path.join(results_dir, "download.zip")

if os.path.isfile(archive_path):
    shutil.copy(archive_path, drive_dir)
else:
    # download.zip is only built when there is more than one result, so with a
    # single result (or when the download cell was skipped) copy the result
    # files themselves instead of failing on a missing archive.
    result_files = sorted(
        name for name in os.listdir(results_dir)
        if os.path.isfile(os.path.join(results_dir, name))
    )
    if not result_files:
        print(f"Error: nothing to copy, {results_dir} contains no files.")
    for name in result_files:
        shutil.copy(os.path.join(results_dir, name), drive_dir)
