<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/DRCT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture
# NOTE(review): IPython documents cell magics as having to be the very first
# line of a cell; confirm that %%capture is honoured after the #@title line.
# Query the GPU of the current runtime.
!nvidia-smi
# Fetch the DRCT sources and make the checkout the working directory; the
# absolute /content/DRCT paths used by this notebook assume the clone happens
# in /content -- TODO confirm for non-default runtimes.
!git clone https://github.com/ming053l/DRCT.git
%cd DRCT

# Install the packages listed in the repository's requirements file.
!pip install -r requirements.txt

# basicsr is patched further down in this cell, so it has to be installed first.
!pip install basicsr

# gdown is imported later in this cell to download the checkpoints.
!pip install gdown pillow matplotlib

import os
import basicsr

basicsr_degradations_path = os.path.join(os.path.dirname(basicsr.__file__), 'data', 'degradations.py')

!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/g' {basicsr_degradations_path}

model_folder = '/content/DRCT/models'

if os.path.isdir(model_folder):
    shutil.rmtree(model_folder)
os.makedirs(model_folder)

import gdown


# (Google Drive file id, file name to store the checkpoint under) pairs.
files = [
    ('1jw2UWAersWZecPq-c_g5RM3mDOoc_cbd', 'DRCT_X4.pth'),
    ('1bVxvA6QFbne2se0CQJ-jyHFy94UOi3h5', 'DRCT-L_X4'),
    ('1uLGwmSko9uF82X4OPOMw3xfM3stlnYZ-', 'net_g_latest.pth'),
    ('1rfV_ExLtfjdHygWGJ3VUYgyn9UkzSwbZ', 'net_g_latest (MSEModel).pth'),
]

# Download every checkpoint into the models directory, one after another.
for drive_id, target_name in files:
    source_url = f'https://drive.google.com/uc?id={drive_id}'
    destination = f'/content/DRCT/models/{target_name}'
    gdown.download(source_url, output=destination, quiet=False)

In [None]:
#@title ##**Upload images** { display-mode: "form" }
import os
import shutil
from google.colab import files

# Working directories of the session; the later cells read both names.
upload_folder = "/content/DRCT/upload"
result_folder = "/content/DRCT/results"

# Every run of this cell starts from an empty upload directory.
if os.path.isdir(upload_folder):
    shutil.rmtree(upload_folder)
os.makedirs(upload_folder)

# Each name reported by files.upload() is looked up in the current working
# directory and moved from there into the upload directory.
basepath = os.getcwd()
uploaded = files.upload()
for filename in uploaded:
    source_path = os.path.join(basepath, filename)
    target_path = os.path.join(upload_folder, filename)
    shutil.move(source_path, target_path)

In [None]:
#@title ##**Run** { display-mode: "form" }
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import sys
import shutil
import traceback

# Start every run with an empty results directory (result_folder is defined
# by the upload cell).
if os.path.isdir(result_folder):
    shutil.rmtree(result_folder)
os.makedirs(result_folder)

# Make the cloned repository importable. Guard the append so that re-running
# this cell does not keep adding duplicate entries to sys.path.
if '/content/DRCT' not in sys.path:
    sys.path.append('/content/DRCT')

from drct.models.drct_model import DRCTModel

def get_model_config(model_type='1'):
    """Build the option dictionary used to construct the DRCT model.

    Args:
        model_type: '1' selects the 6-group / 180-dim / 6-head network
            settings; any other value selects the 8-group / 240-dim / 8-head
            settings.

    Returns:
        A freshly built nested dict with the task settings, a test dataset
        stub, the ``network_g`` architecture parameters and the ``path``
        section (no pretrained path set).
    """
    # NOTE(review): the values of the non-'1' variant should be checked
    # against the upstream option file of the checkpoint they are meant for.
    if model_type == '1':
        depths = [6, 6, 6, 6, 6, 6]
        embed_dim = 180
        num_heads = [6, 6, 6, 6, 6, 6]
    else:
        depths = [6, 6, 6, 6, 6, 6, 6, 6]
        embed_dim = 240
        num_heads = [8, 8, 8, 8, 8, 8, 8, 8]

    test_dataset = {
        'name': 'SingleImageDataset',
        'dataroot_lq': None,
        'io_backend': {'type': 'disk'},
    }

    network_g = {
        'type': 'DRCT',
        'upscale': 4,
        'in_chans': 3,
        'img_size': 64,
        'window_size': 16,
        'img_range': 1.,
        'depths': depths,
        'embed_dim': embed_dim,
        'num_heads': num_heads,
        'mlp_ratio': 2,
        'upsampler': 'pixelshuffle',
        'resi_connection': '1conv',
    }

    path_options = {
        'pretrain_network_g': None,
        'strict_load_g': True,
        'param_key_g': 'params',
    }

    return {
        'task': 'sr',
        'scale': 4,
        'model_type': 'DRCTModel',
        'num_gpu': 1,
        'dist': False,
        'is_train': False,
        'datasets': {'test': test_dataset},
        'network_g': network_g,
        'path': path_options,
    }

def get_model_path(model_name):
    """Return the checkpoint path registered for ``model_name``.

    Args:
        model_name: one of the checkpoint names offered in the Run form.

    Returns:
        The absolute path string of the checkpoint, or ``None`` when the
        name is not one of the known checkpoints.
    """
    known_checkpoints = {
        "DRCT_X4.pth": "/content/DRCT/models/DRCT_X4.pth",
        "DRCT-L_X4": "/content/DRCT/models/DRCT-L_X4",
        "net_g_latest.pth": "/content/DRCT/models/net_g_latest.pth",
        "net_g_latest (MSEModel).pth": "/content/DRCT/models/net_g_latest (MSEModel).pth",
    }
    try:
        return known_checkpoints[model_name]
    except KeyError:
        # Unknown names yield None rather than an exception.
        return None

def load_and_preprocess_image(image_path):
    """Open an image as RGB and pad it so both sides are multiples of 16.

    The padding is split between the two sides of each axis (the right and
    bottom sides receive the extra pixel when the amount is odd) and is
    applied by cropping with a box that extends past the image bounds.

    Args:
        image_path: path of the image file to open.

    Returns:
        A tuple ``(img, padding, (w_orig, h_orig))`` where ``img`` is the
        (possibly padded) RGB image, ``padding`` is
        ``(left, top, right, bottom)`` in pixels and the last item is the
        size of the image before padding.
    """
    img = Image.open(image_path).convert('RGB')
    w_orig, h_orig = img.size
    print(f"Original image dimensions: {w_orig}x{h_orig}")

    # Round each side up to the next multiple of 16 (ceiling division).
    new_w = -(-w_orig // 16) * 16
    new_h = -(-h_orig // 16) * 16

    delta_w = new_w - w_orig
    delta_h = new_h - h_orig

    left = delta_w // 2
    top = delta_h // 2
    right = delta_w - left
    bottom = delta_h - top
    padding = (left, top, right, bottom)

    if delta_w == 0 and delta_h == 0:
        print("Image dimensions are already multiples of 16, no padding needed.")
        return img, padding, (w_orig, h_orig)

    print(f"Adding transparent padding to dimensions: {new_w}x{new_h}")
    # Negative / oversized box coordinates enlarge the canvas around the
    # picture; the image is RGB here, so the added border has no alpha
    # channel despite the wording of the log message.
    img = img.crop((-left, -top, w_orig + right, h_orig + bottom))
    return img, padding, (w_orig, h_orig)

def split_image_into_tiles(img, tile_size):
    """Cut an image into non-overlapping tiles in row-major order.

    Args:
        img: image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method.
        tile_size: edge length of a full tile in pixels; tiles on the right
            and bottom borders are clipped to the image bounds.

    Returns:
        ``(tiles, positions)``: the cropped tiles and, at the same index,
        the ``(x, y)`` top-left corner each tile was taken from.
    """
    width, height = img.size

    # Top-left corners, scanning each row left to right, rows top to bottom.
    positions = [
        (left, top)
        for top in range(0, height, tile_size)
        for left in range(0, width, tile_size)
    ]

    tiles = [
        img.crop((left, top, min(left + tile_size, width), min(top + tile_size, height)))
        for left, top in positions
    ]
    return tiles, positions

def merge_tiles(tiles, positions, img_size):
    """Assemble tiles into one RGB image.

    Args:
        tiles: images to paste.
        positions: ``(x, y)`` top-left paste position for the tile at the
            same index.
        img_size: ``(width, height)`` of the assembled image.

    Returns:
        A new RGB image of size ``img_size`` with every tile pasted at its
        position.
    """
    canvas = Image.new('RGB', img_size)
    for tile, origin in zip(tiles, positions):
        left, top = origin
        canvas.paste(tile, (left, top))
    return canvas

def postprocess_image(tensor):
    """Convert a network output tensor into a PIL image.

    The tensor has dimension 0 squeezed away, is moved to the CPU and is
    clamped in place to the range [0, 1] before the conversion.

    Args:
        tensor: model output; presumably shaped (1, C, H, W) -- TODO confirm.

    Returns:
        The image produced by ``transforms.ToPILImage``.
    """
    image_tensor = tensor.squeeze(0).cpu()
    image_tensor.clamp_(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(image_tensor)

model_name = "net_g_latest (MSEModel).pth"  #@param ["DRCT_X4.pth", "DRCT-L_X4", "net_g_latest.pth", "net_g_latest (MSEModel).pth"]

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Loading model to device: {device}")

# These checkpoints use the '1' network settings; every other selection
# uses the '2' settings.
base_checkpoints = ['DRCT_X4.pth', 'net_g_latest.pth', 'net_g_latest (MSEModel).pth']
config_variant = '1' if model_name in base_checkpoints else '2'

opt = get_model_config(config_variant)
model = DRCTModel(opt)
model_path = get_model_path(model_name)

checkpoint = torch.load(model_path, map_location=device)

# Weights may be nested under 'params_ema' (preferred) or 'params';
# otherwise the checkpoint itself is used as the state dict.
if 'params_ema' in checkpoint:
    state_dict = checkpoint['params_ema']
elif 'params' in checkpoint:
    state_dict = checkpoint['params']
else:
    state_dict = checkpoint
model.net_g.load_state_dict(state_dict, strict=True)

model.net_g = model.net_g.to(device)
model.net_g.eval()

# Converts PIL tiles into tensors for the network.
transform = transforms.ToTensor()

# Tiling and upscaling parameters shared by every image of this run.
tile_size = 512
scale = 4

for filename in os.listdir(upload_folder):
    input_path = os.path.join(upload_folder, filename)
    output_path = os.path.join(result_folder, f'upscaled_{filename}')

    try:
        print(f"\nProcessing {filename}")

        # Pad to a multiple of 16, then cut the padded image into tiles.
        img, padding, (w_orig, h_orig) = load_and_preprocess_image(input_path)
        tiles, positions = split_image_into_tiles(img, tile_size)

        # Upscale tile by tile, releasing the tensors after each one.
        upscaled_tiles = []
        for idx, tile in enumerate(tiles):
            print(f"Processing tile {idx+1}/{len(tiles)}")
            input_tensor = transform(tile).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model.net_g(input_tensor)

            upscaled_tiles.append(postprocess_image(output))

            del input_tensor, output
            torch.cuda.empty_cache()

        # Reassemble the tiles on a canvas `scale` times the padded size.
        padded_w, padded_h = img.size
        upscaled_w = padded_w * scale
        upscaled_h = padded_h * scale
        scaled_positions = [(x * scale, y * scale) for x, y in positions]
        upscaled_img = merge_tiles(upscaled_tiles, scaled_positions, (upscaled_w, upscaled_h))

        # Remove the (scaled) padding that was added before tiling.
        pad_left, pad_top, pad_right, pad_bottom = (p * scale for p in padding)
        crop_box = (pad_left, pad_top, upscaled_w - pad_right, upscaled_h - pad_bottom)
        upscaled_img = upscaled_img.crop(crop_box)

        # Force the final size to exactly `scale` times the original size.
        expected_w, expected_h = w_orig * scale, h_orig * scale
        upscaled_img = upscaled_img.resize((expected_w, expected_h), Image.LANCZOS)

        upscaled_img.save(output_path)
        print(f"Image successfully processed and saved to {output_path}")

        del img, upscaled_img, upscaled_tiles
        torch.cuda.empty_cache()

    except Exception as e:
        # A failure on one file is reported and the loop moves on to the next.
        print(f"Error processing {filename}: {str(e)}")
        traceback.print_exc()

In [None]:
#@title ##**Visualize** { display-mode: "form" }
import PIL.Image
import numpy as np
import os

def is_image_file(filename):
    """Return True when the file name ends in a known image extension.

    The comparison is case-insensitive and only looks at the last
    extension of the name.
    """
    _, extension = os.path.splitext(filename.lower())
    # Common image file extensions
    return extension in ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')

def resize_image_maintain_aspect(image, max_width, target_height=None):
    """Shrink an image to a maximum width, optionally forcing its height.

    Args:
        image: image object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method that returns a new image.
        max_width: images wider than this are scaled down to this width,
            keeping the aspect ratio.
        target_height: when given, the (possibly already shrunk) image is
            scaled to exactly this height, keeping the aspect ratio.

    Returns:
        The resized image, or the input image itself when no resize was
        needed.
    """
    width, height = image.size
    if width > max_width:
        # Truncation can yield 0 for very wide, flat images; keep at least
        # one pixel so the resize target stays a valid size.
        new_height = max(1, int(height * max_width / width))
        image = image.resize((max_width, new_height))

    # If target_height is specified, resize to match that height
    if target_height is not None and image.size[1] != target_height:
        # Same safeguard for very tall, narrow images.
        new_width = max(1, int(image.size[0] * target_height / image.size[1]))
        image = image.resize((new_width, target_height))

    return image

# Uploaded files with a recognised image extension, in name order.
filenames_upload = sorted(f for f in os.listdir(upload_folder) if is_image_file(f))

# Result images currently on disk.
filenames_output = sorted(f for f in os.listdir(result_folder) if is_image_file(f))
available_results = set(filenames_output)

max_width = 500

for filename_in in filenames_upload:
    # Results are stored as 'upscaled_<input file name>'. Pair by name
    # instead of by list position so that an input without a result (for
    # example because processing it failed) cannot shift every following
    # comparison onto the wrong image.
    filename_out = f'upscaled_{filename_in}'
    if filename_out not in available_results:
        print(f"No upscaled result found for {filename_in}, skipping.")
        continue

    image_original = PIL.Image.open(os.path.join(upload_folder, filename_in))
    image_restore = PIL.Image.open(os.path.join(result_folder, filename_out))

    # First resize both images to max_width if needed
    image_original = resize_image_maintain_aspect(image_original, max_width)
    image_restore = resize_image_maintain_aspect(image_restore, max_width)

    # Get the minimum height between the two images
    target_height = min(image_original.size[1], image_restore.size[1])

    # Resize both images to have the same height
    image_original = resize_image_maintain_aspect(image_original, max_width, target_height)
    image_restore = resize_image_maintain_aspect(image_restore, max_width, target_height)

    # Convert images to RGB mode if they're not already, so both arrays
    # have the same number of channels when stacked.
    if image_original.mode != 'RGB':
        image_original = image_original.convert('RGB')
    if image_restore.mode != 'RGB':
        image_restore = image_restore.convert('RGB')

    # Create the side-by-side comparison
    comparison = PIL.Image.fromarray(np.hstack((np.array(image_original), np.array(image_restore))))
    display(comparison)
    print("")

In [None]:
#@title ##**Download results** { display-mode: "form" }
import os
from google.colab import files
import zipfile

def is_image_file(filename):
    """Tell whether ``filename`` carries one of the accepted image extensions.

    Matching ignores letter case and considers only the final extension.
    """
    accepted = frozenset(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'))
    lowered = filename.lower()
    suffix = os.path.splitext(lowered)[1]
    return suffix in accepted

# Result images available for download.
files_in_folder = [f for f in os.listdir(result_folder) if is_image_file(f)]
zip_file = "download.zip"

if not files_in_folder:
    # Nothing was produced; avoid building and downloading an empty archive.
    print("No result images found to download.")
elif len(files_in_folder) == 1:
    # A single result is downloaded directly, without an archive.
    file_to_download = os.path.join(result_folder, files_in_folder[0])
    files.download(file_to_download)
else:
    # Several results are bundled into one archive stored next to them;
    # an archive left over from an earlier run is replaced.
    zip_path = os.path.join(result_folder, zip_file)
    if os.path.exists(zip_path):
        os.remove(zip_path)

    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in files_in_folder:
            file_path = os.path.join(result_folder, file)
            # Store each image under its bare file name inside the archive.
            zipf.write(file_path, file)

    files.download(zip_path)