In [None]:
# Clone the upstream Restormer repository into ./Restormer.
# NOTE(review): later cells load the architecture from restormers/pytorch/default/1,
# not from ./Restormer, so this clone may be unused — confirm which copy is intended.
!git clone https://github.com/swz30/Restormer

In [1]:
import sys
import os

# Add the Restormer code directory to sys.path. The path is relative to the
# current working directory.
# NOTE(review): restormer_arch.py is later loaded by file path via run_path, so this
# entry may not be strictly required — confirm before removing.
RESTORMER_DIR = 'restormers/pytorch/default/1'
if RESTORMER_DIR not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(RESTORMER_DIR)


In [2]:
import sys
# Drop the extra arguments Jupyter passes to the kernel (such as '-f'), keeping only
# the program name, so argparse-style code does not choke on them.
# NOTE(review): get_args() returns a fixed dict and nothing visible here parses
# sys.argv, so this line may be vestigial.
sys.argv = sys.argv[:1]

In [3]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import os
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import numpy as np

def get_args():
    """Return the run configuration as a plain dict (stands in for CLI parsing).

    Keys:
        input_dir: folder containing the images to restore.
        result_dir: output root; a per-task subfolder is created beneath it.
        task: name of the Restormer task / checkpoint to use.
        tile, tile_overlap: tiling options.

    A fresh dict is built on every call, so callers may modify it freely.
    """
    config = dict(
        input_dir='./ensemble_private_pseudo2',
        result_dir='./final_output',
        task='Real_Denoising',
        tile=None,
        tile_overlap=32,
    )
    return config


args = get_args()


def load_img(filepath):
    """Read an image from ``filepath`` and return it as an H x W x 3 RGB array.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable or unsupported
            format). ``cv2.imread`` signals this by returning ``None`` instead of raising.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise OSError(f"Could not read image: {filepath}")
    # OpenCV decodes to BGR channel order; convert so callers get RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def save_img(filepath, img):
    """Write an RGB image array to ``filepath`` (format chosen from the extension).

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. It returns ``False`` rather
            than raising, so without this check failed writes would go unnoticed.
    """
    # OpenCV expects BGR channel order on write.
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write image: {filepath}")


def load_gray_img(filepath):
    """Read an image from ``filepath`` as grayscale with shape (H, W, 1).

    Raises:
        OSError: if OpenCV cannot read the file (``cv2.imread`` returns ``None``).
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise OSError(f"Could not read image: {filepath}")
    # Keep an explicit channel axis so gray images have the same HWC layout as RGB ones.
    return np.expand_dims(img, axis=2)


def save_gray_img(filepath, img):
    """Write a single-channel image array to ``filepath`` unchanged.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns ``False`` instead of raising).
    """
    if not cv2.imwrite(filepath, img):
        raise OSError(f"Could not write image: {filepath}")

# Download the checkpoints by following https://github.com/swz30/Restormer and store
# them under ./restomer-checkpoint/ using the file names listed below.
def get_weights_and_parameters(task, parameters):
    """Resolve the checkpoint path and Restormer constructor kwargs for ``task``.

    Args:
        task: task name; must be one of the keys of ``weights_path`` below.
        parameters: base keyword arguments for ``Restormer(**parameters)``.
            This dict is not modified; an adjusted shallow copy is returned.

    Returns:
        ``(weights, parameters)``: the relative checkpoint path and the
        task-adjusted constructor kwargs.

    Raises:
        ValueError: if ``task`` is not a supported task name.
    """
    weights_path = {
        'Single_Image_Defocus_Deblurring': 'restomer-checkpoint/single_image_defocus_deblurring.pth',
        'Motion_Deblurring': 'restomer-checkpoint/motion_deblurring.pth',
        'Deraining': 'restomer-checkpoint/deraining.pth',
        'Real_Denoising': 'restomer-checkpoint/real_denoising.pth',
        'Gaussian_Color_Denoising': 'restomer-checkpoint/gaussian_color_denoising.pth',
        'Gaussian_Gray_Denoising': 'restomer-checkpoint/gaussian_gray_denoising.pth'
    }

    weights = weights_path.get(task)
    if weights is None:
        raise ValueError(f"Task '{task}' not recognized or weights not found.")

    # Work on a copy so the caller's base configuration is never mutated.
    parameters = dict(parameters)
    # Denoising tasks are built with the bias-free LayerNorm variant.
    if task in ['Real_Denoising', 'Gaussian_Color_Denoising', 'Gaussian_Gray_Denoising']:
        parameters['LayerNorm_type'] = 'BiasFree'
    # The gray denoiser takes and produces a single channel.
    if task == 'Gaussian_Gray_Denoising':
        parameters['inp_channels'] = parameters['out_channels'] = 1

    return weights, parameters


# Load model
# Base Restormer hyper-parameters; get_weights_and_parameters returns a task-adjusted copy.
parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8], 'num_refinement_blocks': 4,
              'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66, 'bias': False, 'LayerNorm_type': 'WithBias',
              'dual_pixel_task': False}

weights, parameters = get_weights_and_parameters(args['task'], parameters)

# Execute the architecture file directly and pull the Restormer class out of its namespace.
# NOTE(review): path is relative to the working directory and differs from the ./Restormer
# folder created by the `git clone` cell — confirm which copy is intended.
load_arch = run_path('restormers/pytorch/default/1/basicsr/models/archs/restormer_arch.py')
model = load_arch['Restormer'](**parameters)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved from a GPU
# still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints; recent
# PyTorch versions also accept weights_only=True to restrict what can be unpickled.
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint['params'])
model.eval()


# Processing images
inp_dir = args['input_dir']
out_dir = os.path.join(args['result_dir'], args['task'])
os.makedirs(out_dir, exist_ok=True)

extensions = ['jpg', 'png', 'jpeg', 'bmp']
# Match upper-case extensions as well (glob is case-sensitive on POSIX), de-duplicate in
# case a filesystem matches the same file under both patterns, and sort naturally so the
# processing order is deterministic rather than whatever the directory listing returns.
patterns = [f'*.{ext}' for ext in extensions] + [f'*.{ext.upper()}' for ext in extensions]
files = natsorted({path for pattern in patterns for path in glob(os.path.join(inp_dir, pattern))})

if not files:
    raise FileNotFoundError(f'No image files found in {inp_dir}')

# Spatial dimensions are padded up to this multiple before inference.
# NOTE(review): presumably the architecture's total downsampling factor — confirm in restormer_arch.py.
img_multiple_of = 8

import time

# The gray-denoising model is built with a single input/output channel
# (get_weights_and_parameters sets inp_channels = out_channels = 1), so it needs the
# single-channel loader/saver; every other task works on RGB images.
is_gray = args['task'] == 'Gaussian_Gray_Denoising'
read_image = load_gray_img if is_gray else load_img
write_image = save_gray_img if is_gray else save_img


def _pad_to_multiple(x, multiple):
    """Reflect-pad an (N, C, H, W) tensor on the bottom/right so H and W become multiples of ``multiple``.

    Only the minimum padding is added: a dimension that is already a multiple gets none.
    """
    height, width = x.shape[2], x.shape[3]
    padh = (-height) % multiple
    padw = (-width) % multiple
    return F.pad(x, (0, padw, 0, padh), 'reflect')


def _restore(x, tile=None, tile_overlap=32):
    """Run ``model`` on ``x`` in one pass, or as overlapping square tiles.

    With ``tile=None`` the whole tensor goes through the model at once. Otherwise ``x`` is
    covered by ``tile`` x ``tile`` windows spaced ``tile - tile_overlap`` apart (the last
    window in each direction is flush with the border), and overlapping outputs are
    averaged, so large images can be processed piece by piece.
    """
    if tile is None:
        return model(x)

    _, _, h, w = x.shape
    tile = min(tile, h, w)
    assert tile % img_multiple_of == 0, f"tile size should be a multiple of {img_multiple_of}"
    stride = tile - tile_overlap
    assert stride > 0, "tile_overlap must be smaller than tile"

    h_starts = list(range(0, h - tile, stride)) + [h - tile]
    w_starts = list(range(0, w - tile, stride)) + [w - tile]
    # Accumulators are shaped like the input; every configured task has
    # inp_channels == out_channels.
    summed = torch.zeros_like(x)
    counts = torch.zeros_like(x)
    for top in h_starts:
        for left in w_starts:
            window = (..., slice(top, top + tile), slice(left, left + tile))
            summed[window] += model(x[window])
            counts[window] += 1
    return summed / counts


start_time = time.perf_counter()
with torch.no_grad():
    for file_ in tqdm(files):
        if torch.cuda.is_available():
            # Free cached GPU memory before each image.
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        img = read_image(file_)
        # HWC image -> (1, C, H, W) float tensor scaled by 1/255.
        input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        height, width = input_.shape[2], input_.shape[3]
        input_ = _pad_to_multiple(input_, img_multiple_of)

        restored = _restore(input_, args.get('tile'), args.get('tile_overlap', 32))
        restored = torch.clamp(restored, 0, 1)
        restored = restored[:, :, :height, :width]  # crop the padding back off

        restored = restored.permute(0, 2, 3, 1).cpu().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.basename(file_))[0]
        # NOTE(review): inputs sharing a stem (e.g. a.jpg and a.png) map to the same
        # output file and the later one overwrites the earlier — confirm this cannot occur.
        write_image(os.path.join(out_dir, f + '.png'), restored)

print(f"\nRestored images saved at {out_dir}")
print(f"Elapsed time: {time.perf_counter() - start_time:.1f} s")


# Now you can just run the cell without passing command-line arguments!


  checkpoint = torch.load(weights)
100%|██████████| 731/731 [16:36<00:00,  1.36s/it]


Restored images saved at /kaggle/working/ensemble_pseudo_deblur/Real_Denoising
996.1291255950928



