# Causal-IR-DIR áp dụng trên McMaster dataset bằng $\text{DIR}_{\text{sf}}$

In [1]:
import os
from zipfile import ZipFile

import cv2
import numpy as np

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
from torch.utils import data

  from .autonotebook import tqdm as notebook_tqdm


## Set-up về việc có dùng Google Colab hay không

Nếu đang chạy trên Google Colab thì path đến file ZIP của dataset sẽ được cập nhật theo shortcut trên Google Drive. (Xem thêm trong file `README.md`)

In [3]:
# Detect whether this notebook runs on Google Colab.
# ``get_ipython`` only exists inside an IPython/Jupyter session. On Colab, its
# string form typically contains the ``google.colab`` module path
# (e.g. "<google.colab._shell.Shell object at ...>"). Test for the substring:
# an exact equality with 'google.colab' never matches.
is_gcolab: bool = False
try:
    is_gcolab = 'google.colab' in str(get_ipython())
except NameError:
    # Plain Python interpreter (no IPython shell): not Colab.
    pass

if is_gcolab:
    print('This notebook is running on Google Colab')
else:
    print('This notebook is not running on Google Colab')

This notebook is not running on Google Colab


In [4]:
# Mount Google Drive only on Colab, so the dataset shortcut under /gdrive is reachable.
if not is_gcolab:
    print('Because this is not running with Google Colab, data will not be mounted from Google Drive')
else:
    from google.colab import drive
    drive.mount('/gdrive')
    print('Mounted from Google Drive')

Because this is not running with Google Colab, data will not be mounted from Google Drive


## Load dữ liệu

McMaster dataset

In [13]:
# Set the path to the dataset ZIP file: a Google Drive shortcut on Colab,
# otherwise a file in the working directory (see README.md).
data_zip_filepath = '/gdrive/MyDrive/McM.zip' if is_gcolab else './McM.zip'

# Explicit check instead of ``assert``: asserts are stripped under ``python -O``
# and give no hint about which path was expected.
if not os.path.exists(data_zip_filepath):
    raise FileNotFoundError(f'Dataset ZIP not found: {data_zip_filepath}')

print(data_zip_filepath)

./McM.zip


In [15]:
# Extract all from ZIP file

# Unpack the password-protected dataset archive into ./McM.
data_dirpath = './McM'
archive_password = 'McM_CDM'.encode('utf-8')

with ZipFile(data_zip_filepath, 'r') as archive:
    archive.extractall(path=data_dirpath, pwd=archive_password)

# Sanity check that extraction produced the target directory.
assert data_dirpath is not None and os.path.exists(data_dirpath)

print(data_dirpath)

./McM


## Tiền xử lý

### Cắt ảnh ra theo patch 256x256

In [19]:
# %run -i generate_cropped_DF2K.py

./McM/McM done!


In [None]:
import cv2
import os

# Cut every image in ``folder`` into patchsize x patchsize tiles written to
# ``dest_folder`` as 000001.png, 000002.png, ... (row-major per image,
# images visited in sorted file-name order). Partial tiles at the right and
# bottom edges are dropped.
folder = './McM/McM'
dest_folder = './McM-cropped'

os.makedirs(dest_folder, exist_ok=True)

patchsize = 256
stride = 256

count = 1

for img_n in sorted(os.listdir(folder)):
    img = cv2.imread(os.path.join(folder, img_n))
    if img is None:
        # cv2.imread returns None for files it cannot decode (e.g. stray
        # metadata files); skip them instead of crashing on ``.shape``.
        print('Skipping unreadable file: {}'.format(img_n))
        continue
    h, w = img.shape[:2]
    # Top-left corners of every full patch that fits. With stride == patchsize
    # this yields h // patchsize rows and w // patchsize columns of
    # non-overlapping tiles. A smaller stride gives overlapping tiles.
    for top in range(0, h - patchsize + 1, stride):
        for left in range(0, w - patchsize + 1, stride):
            img_crop = img[top:top + patchsize, left:left + patchsize]
            cv2.imwrite(os.path.join(dest_folder, '{:0>6d}.png'.format(count)), img_crop)
            count += 1
print("{} done!".format(folder))


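## Mã nguồn các module hỗ trợ

Các cell dưới đây chứa nội dung các file `srdata_noise.py`, `utils_logger.py`, `util_calculate_psnr_ssim.py`, `RRDB.py` và `DIL_sf_noise.py` (dataset, logger, các hàm tính PSNR/SSIM, mô hình RRDBNet và script huấn luyện).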

### srdata_noise.py

In [None]:
import os
import glob

import numpy as np
import torch
import torch.utils.data as data
import random
import cv2


class DataCrop(data.Dataset):
    """Training dataset of (noisy, clean) random patches for Gaussian denoising.

    Each item reads one ``*.png`` from ``hr_folder``, converts it to RGB, takes
    a random ``patch_size`` x ``patch_size`` crop, and applies one of the eight
    flip/rotation modes of ``augment_img``. It then adds Gaussian noise with
    standard deviation ``(choose + 1) * 5`` on the [0, 255] scale.

    Args:
        choose (int): Task index; the noise level is ``(choose + 1) * 5``,
            i.e. 5, 10, 15, 20 for ``choose`` in 0..3.
        hr_folder (str): Directory containing the clean ``*.png`` images.
        patch_size (int): Side length of the square crop. Default: 64.

    Items are ``(noisy, clean)`` float32 tensors in CHW order with values in
    [0, 1] (the noisy one is clipped to that range).
    """

    def __init__(self, choose, hr_folder, patch_size=64):
        self.patch_size = patch_size
        self.dir_hr = hr_folder
        self.images_hr = sorted(glob.glob(os.path.join(self.dir_hr, '*.png')))
        self.choose = (choose + 1) * 5  # noise sigma: 5, 10, 15, 20

    def _random_crop(self, img):
        """Return a random ``patch_size`` x ``patch_size`` crop of an HWC image.

        Offsets are drawn from the image's actual height/width with an
        inclusive upper bound, so every valid position can be chosen. This
        includes a crop covering the whole image.

        Raises:
            ValueError: if the image is smaller than ``patch_size``.
        """
        height, width = img.shape[:2]
        ps = self.patch_size
        if height < ps or width < ps:
            raise ValueError(
                f'Image of size {height}x{width} is smaller than patch_size={ps}')
        top = np.random.randint(0, height - ps + 1)
        left = np.random.randint(0, width - ps + 1)
        return img[top:top + ps, left:left + ps, :]

    def __getitem__(self, idx):
        # glob already returned the full path; use it directly rather than
        # splitting on '/' (which is not the separator on Windows).
        path = self.images_hr[idx]
        hr = cv2.imread(path)  # BGR, n_channels=3
        if hr is None:
            raise IOError(f'Failed to read image: {path}')
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB)  # RGB, n_channels=3

        hr = self._random_crop(hr)

        mode = np.random.randint(0, 8)
        hr = augment_img(hr, mode=mode)

        hr = hr.astype(np.float32) / 255.
        lr = hr.copy()

        noise = np.random.randn(*hr.shape) * self.choose / 255.
        lr += noise
        lr = np.clip(lr, 0, 1).astype(np.float32)

        # HWC -> CHW; ascontiguousarray because flips/rotations return views.
        lr = torch.from_numpy(np.ascontiguousarray(lr.transpose(2, 0, 1)))
        hr = torch.from_numpy(np.ascontiguousarray(hr.transpose(2, 0, 1)))

        return lr, hr

    def __len__(self):
        return len(self.images_hr)


class DataTest(data.Dataset):
    """Evaluation dataset: full-size clean images paired with a noisy copy.

    Args:
        hr_folder (str): Directory of clean images; the string ``'default'``
            selects ``'Set5/HR'``. Every entry of the directory is listed.
        level (int | float): Gaussian noise standard deviation on the
            [0, 255] scale. Default: 50.

    Items are ``(noisy, clean, file_name)``. Both tensors are RGB, CHW,
    float32 in [0, 1]; the noisy one is clipped to that range.
    """

    def __init__(self, hr_folder='default', level=50):
        if hr_folder == 'default':
            hr_folder = 'Set5/HR'
        self.dir_hr = hr_folder
        self.name_hr = sorted(os.listdir(hr_folder))
        self.level = level

    def __len__(self):
        return len(self.name_hr)

    def __getitem__(self, idx):
        file_name = self.name_hr[idx]
        bgr = cv2.imread(os.path.join(self.dir_hr, file_name))
        clean = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        noisy = clean.copy()
        noisy += np.random.randn(*clean.shape) * self.level / 255.
        noisy = np.clip(noisy, 0, 1).astype(np.float32)

        def to_chw_tensor(arr):
            # HWC ndarray -> contiguous CHW torch tensor.
            return torch.from_numpy(np.ascontiguousarray(arr.transpose(2, 0, 1)))

        return to_chw_tensor(noisy), to_chw_tensor(clean), file_name


def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of eight flip/rotation combinations to an image.

    Args:
        img (ndarray): Image; rotations and flips act on its first two axes.
        mode (int): Transform index in 0..7; 0 returns ``img`` itself.

    Returns:
        ndarray | None: The transformed array (often a view), or None when
        ``mode`` is not in 0..7.
    '''
    transforms = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: lambda a: np.flipud(a),
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: lambda a: np.rot90(a),
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    transform = transforms.get(mode)
    return None if transform is None else transform(img)

### utils_logger.py

In [None]:
import sys
import datetime
import logging


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def log(*args, **kwargs):
    """Print ``args`` prefixed with the local time as "YYYY-mm-dd HH:MM:SS:".

    Keyword arguments are forwarded unchanged to ``print`` (e.g. ``file=``).
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(timestamp, *args, **kwargs)


'''
# --------------------------------------------
# logger
# --------------------------------------------
'''


def logger_info(logger_name, log_path='default_logger.log', mode='a'):
    ''' set up logger
    modified by Kai Zhang (github: https://github.com/cszn)

    Attach a file handler (``log_path``, opened with ``mode``) and a stream
    handler to the logger ``logger_name``. Both use the same timestamped
    format, and the logger level is set to INFO. If that logger already has
    handlers of its own, nothing is added (only a notice is printed). This
    keeps re-running the setup from duplicating output.
    '''
    logger = logging.getLogger(logger_name)
    # Look only at this logger's own handlers. ``hasHandlers()`` also checks
    # ancestors, so a handler on the root logger (e.g. from
    # ``logging.basicConfig()`` or a test runner) would wrongly skip the setup.
    if logger.handlers:
        print('LogHandlers exist!')
        return

    print('LogHandlers setup!')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_path, mode=mode)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)


'''
# --------------------------------------------
# print to file and std_out simultaneously
# --------------------------------------------
'''


class logger_print(object):
    """Stream-like object that writes every message to stdout and a log file.

    ``sys.stdout`` is captured at construction time. The log file is opened in
    append mode (UTF-8) and stays open until ``close()`` is called.

    Args:
        log_path (str): Path of the log file. Default: "default.log".
    """

    def __init__(self, log_path="default.log"):
        self.terminal = sys.stdout
        self.log = open(log_path, 'a', encoding='utf-8')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)  # write the message

    def flush(self):
        # Flush both sinks so logged lines reach the file promptly instead of
        # sitting in the file buffer (and being lost on a crash).
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left untouched."""
        if not self.log.closed:
            self.log.close()


### util_calculate_psnr_ssim.py

In [None]:
import cv2
import numpy as np
import torch


def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Peak signal-to-noise ratio between two images on the [0, 255] scale.

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): First image, values in [0, 255].
        img2 (ndarray): Second image, same shape as ``img1``.
        crop_border (int): Pixels dropped from each edge before comparing.
        input_order (str): Layout of the inputs, 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Compare only the Y channel of YCbCr; the
            inputs are then treated as BGR. Default: False.

    Returns:
        float: PSNR in dB, or ``inf`` when the images are identical.
    """
    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a = reorder_image(img1, input_order=input_order)
    b = reorder_image(img2, input_order=input_order)
    a, b = a.astype(np.float64), b.astype(np.float64)

    if crop_border != 0:
        trim = np.s_[crop_border:-crop_border, crop_border:-crop_border, ...]
        a, b = a[trim], b[trim]

    if test_y_channel:
        a, b = to_y_channel(a), to_y_channel(b)

    mse = np.mean((a - b) ** 2)
    return float('inf') if mse == 0 else 20. * np.log10(255. / np.sqrt(mse))


def _ssim(img1, img2):
    """SSIM of two single-channel images; helper for ``calculate_ssim``.

    Local statistics use an 11x11 Gaussian window (sigma 1.5). A 5-pixel
    border of every filtered map is discarded before averaging.

    Args:
        img1 (ndarray): Image with range [0, 255].
        img2 (ndarray): Image with range [0, 255], same shape as ``img1``.

    Returns:
        float: Mean of the SSIM map.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def blur(arr):
        # Gaussian-filter the map, then drop the 5-pixel border.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = blur(x)
    mu_y = blur(y)
    mu_xx = mu_x ** 2
    mu_yy = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = blur(x ** 2) - mu_xx
    var_y = blur(y ** 2) - mu_yy
    cov_xy = blur(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_xx + mu_yy + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()


def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Structural similarity (SSIM) between two images.

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    SSIM is computed per channel (via ``_ssim``) and the channel values are
    averaged.

    Args:
        img1 (ndarray): First image, values in [0, 255].
        img2 (ndarray): Second image, same shape as ``img1``.
        crop_border (int): Pixels dropped from each edge before comparing.
        input_order (str): Layout of the inputs, 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Compare only the Y channel of YCbCr; the
            inputs are then treated as BGR. Default: False.

    Returns:
        float: Mean SSIM over channels.
    """
    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a = reorder_image(img1, input_order=input_order).astype(np.float64)
    b = reorder_image(img2, input_order=input_order).astype(np.float64)

    if crop_border != 0:
        trim = np.s_[crop_border:-crop_border, crop_border:-crop_border, ...]
        a, b = a[trim], b[trim]

    if test_y_channel:
        a = to_y_channel(a)
        b = to_y_channel(b)

    per_channel = [_ssim(a[..., ch], b[..., ch]) for ch in range(a.shape[2])]
    return np.array(per_channel).mean()


def _blocking_effect_factor(im):
    block_size = 8

    block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
    block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)

    horizontal_block_difference = (
                (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
        3).sum(2).sum(1)
    vertical_block_difference = (
                (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
        2).sum(1)

    nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
    nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)

    horizontal_nonblock_difference = (
                (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
        3).sum(2).sum(1)
    vertical_nonblock_difference = (
                (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
        3).sum(2).sum(1)

    n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
    n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
    boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
                n_boundary_horiz + n_boundary_vert)

    n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
    n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
    nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
                n_nonboundary_horiz + n_nonboundary_vert)

    scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
    bef = scaler * (boundary_difference - nonboundary_difference)

    bef[boundary_difference <= nonboundary_difference] = 0
    return bef


def calculate_psnrb(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """PSNR-B: PSNR penalised by the blocking effect factor of ``img1``.

    Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
    # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py

    Args:
        img1 (ndarray): Evaluated image, values in [0, 255]; its BEF is used.
        img2 (ndarray): Reference image, same shape as ``img1``.
        crop_border (int): Pixels dropped from each edge before comparing.
        input_order (str): Layout of the inputs, 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Compare only the Y channel of YCbCr; the
            inputs are then treated as BGR. Default: False.

    Returns:
        float: PSNR-B in dB, averaged over channels.
    """
    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a = reorder_image(img1, input_order=input_order).astype(np.float64)
    b = reorder_image(img2, input_order=input_order).astype(np.float64)

    if crop_border != 0:
        trim = np.s_[crop_border:-crop_border, crop_border:-crop_border, ...]
        a, b = a[trim], b[trim]

    if test_y_channel:
        a = to_y_channel(a)
        b = to_y_channel(b)

    # HWC ndarray in [0, 255] -> (1, C, H, W) tensor in [0, 1], as in the
    # referenced psnrb.py implementation.
    a = torch.from_numpy(a).permute(2, 0, 1).unsqueeze(0) / 255.
    b = torch.from_numpy(b).permute(2, 0, 1).unsqueeze(0) / 255.

    n_channels = a.shape[1]
    total = 0
    for ch in range(n_channels):
        pred = a[:, ch:ch + 1, :, :]
        target = b[:, ch:ch + 1, :, :]
        mse = torch.nn.functional.mse_loss(pred, target, reduction='none')
        bef = _blocking_effect_factor(pred)
        per_image_mse = mse.view(mse.shape[0], -1).mean(1)
        total += 10 * torch.log10(1 / (per_image_mse + bef))

    return float(total) / n_channels


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input shape is (h, w), return (h, w, 1) regardless of input_order;
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: if ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # Grayscale: just add a channel axis. Returning here keeps a 'CHW'
        # order from transposing (h, w, 1) into (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Convert a 3-channel (BGR) image to its YCbCr Y channel.

    Inputs that are not 3-D with 3 channels are only cast and rescaled.

    Args:
        img (ndarray): Image with range [0, 255].

    Returns:
        ndarray: float32 image with range [0, 255], not rounded. A 3-channel
        input becomes shape (h, w, 1).
    """
    scaled = img.astype(np.float32) / 255.
    is_color = scaled.ndim == 3 and scaled.shape[2] == 3
    if is_color:
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    convertion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace convertion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr (ITU-R BT.601, as in MATLAB's rgb2ycbcr).

    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
    This differs from cv2.cvtColor's `BGR <-> YCrCb`, which implements the
    JPEG conversion (https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion).

    Args:
        img (ndarray): np.uint8 with range [0, 255] or np.float32 with range
            [0, 1], channels last in B, G, R order.
        y_only (bool): Return only the Y channel. Default: False.

    Returns:
        ndarray: The YCbCr image (or Y channel), with the same dtype and range
        convention as the input.
    """
    in_dtype = img.dtype
    bgr = _convert_input_type_range(img)
    if y_only:
        ycbcr = np.dot(bgr, [24.966, 128.553, 65.481]) + 16.0
    else:
        coeffs = [[24.966, 112.0, -18.214],
                  [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]
        ycbcr = np.matmul(bgr, coeffs) + [16, 128, 128]
    return _convert_output_type_range(ycbcr, in_dtype)


### RRDB.py

In [None]:
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh instances of ``block()`` in an nn.Sequential."""
    return nn.Sequential(*(block() for _ in range(n_layers)))


class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution residual dense block with 0.2 residual scaling.

    Each of conv1..conv4 maps the concatenation of the block input and all
    previous conv outputs to ``gc`` channels (LeakyReLU 0.2). conv5 maps
    everything back to ``nf`` channels. Output: ``conv5(...) * 0.2 + x``.

    Args:
        nf (int): Channels of the block input and output. Default: 64.
        gc (int): Growth channels, i.e. intermediate channels. Default: 32.
        bias (bool): Whether the convolutions use a bias. Default: True.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # conv{k+1} sees nf + k * gc input channels; the last one returns nf.
        for k in range(5):
            in_channels = nf + k * gc
            out_channels = gc if k < 4 else nf
            setattr(self, f'conv{k + 1}', nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.init_weights()

    def init_weights(self):
        """Kaiming-initialise all five convolutions, scaled by 0.1.

        The smaller scale is the empirical choice from "ESRGAN: Enhanced
        Super-Resolution Generative Adversarial Networks" for stability.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            default_init_weights(conv, 0.1)

    def forward(self, x):
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        x5 = self.conv5(torch.cat(features, 1))
        return x5 * 0.2 + x


def default_init_weights(module, scale=1):
    """Kaiming-normal initialise Conv2d/Linear weights, then scale them.

    Every Conv2d and Linear found in the given module(s), including nested
    submodules, gets ``kaiming_normal_`` (fan_in, relu gain) followed by
    multiplication by ``scale``. Biases are left untouched.

    Args:
        module (nn.Module | list[nn.Module] | tuple[nn.Module]): Module, or a
            list/tuple of modules, to initialise.
        scale (float): Multiplier applied after initialisation; smaller values
            (e.g. 0.1) are used for residual blocks. Default: 1.
    """
    modules = module if isinstance(module, (list, tuple)) else [module]
    for root in modules:
        for m in root.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
                # In-place scaling outside autograd (instead of ``.data``).
                with torch.no_grad():
                    m.weight.mul_(scale)


class RRDB(nn.Module):
    '''Residual in Residual Dense Block

    Three ResidualDenseBlock_5C applied in sequence; the result is scaled by
    0.2 and added to the block input.
    '''

    def __init__(self, nf=64, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """RRDB backbone without upsampling (output has the input's spatial size).

    Pipeline: conv_first -> ``nb`` RRDB blocks -> trunk_conv (plus a
    long skip from conv_first) -> LeakyReLU(HRconv) -> conv_last.

    Args:
        in_nc (int): Input channels.
        out_nc (int): Output channels.
        nf (int): Feature channels. Default: 64.
        nb (int): Number of RRDB blocks. Default: 23.
        gc (int): Growth channels inside each dense block. Default: 32.
    """

    def __init__(self, in_nc, out_nc, nf=64, nb=23, gc=32):
        super(RRDBNet, self).__init__()
        make_rrdb = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(make_rrdb, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = self.conv_first(x)
        features = features + self.trunk_conv(self.RRDB_trunk(features))
        return self.conv_last(self.lrelu(self.HRconv(features)))


class RRDBNetX4(nn.Module):
    """RRDB network with x4 upsampling (two nearest-neighbour x2 stages).

    Same trunk as RRDBNet; after the long skip, features go through two
    [nearest x2 interpolate -> conv -> LeakyReLU] stages, then HRconv and
    conv_last.

    Args:
        in_nc (int): Input channels.
        out_nc (int): Output channels.
        nf (int): Feature channels. Default: 64.
        nb (int): Number of RRDB blocks. Default: 23.
        gc (int): Growth channels inside each dense block. Default: 32.
    """

    def __init__(self, in_nc, out_nc, nf=64, nb=23, gc=32):
        super(RRDBNetX4, self).__init__()
        make_rrdb = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(make_rrdb, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # upsampling convolutions (one per x2 stage)
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = self.conv_first(x)
        features = features + self.trunk_conv(self.RRDB_trunk(features))

        for upconv in (self.upconv1, self.upconv2):
            upsampled = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.lrelu(upconv(upsampled))

        return self.conv_last(self.lrelu(self.HRconv(features)))

### DIL_sf_noise.py

In [None]:
import argparse
import os

import numpy as np
import cv2

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils import data
from torch import distributed as dist
import torch.optim as optim
import srdata_noise
import utils_logger
import logging
import util_calculate_psnr_ssim as util

from RRDB import RRDBNet


def synchronize():
    """Barrier across all processes when multi-process distributed training is active.

    Does nothing if torch.distributed is unavailable, has not been
    initialised, or the world size is 1.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() != 1:
        dist.barrier()


def parse_args(argv=None):
    """Parse command-line options for DIL_sf Gaussian-denoising training.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (default) reads
            ``sys.argv[1:]`` as before. Pass an explicit list when calling
            from a notebook, where ``sys.argv`` may hold the kernel's own
            arguments.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='Train an editor')

    parser.add_argument('--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    # Required: main() creates this directory and writes logs/checkpoints there;
    # a None default only failed later inside os.path calls.
    parser.add_argument(
        "--ckpt_save",
        type=str,
        required=True,
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="path to checkpoints for pretrained model",
    )
    parser.add_argument(
        '--distributed',
        action='store_true'
    )
    # Newer launchers pass ``--local-rank`` or only set $LOCAL_RANK; accept
    # both while keeping the ``args.local_rank`` attribute.
    parser.add_argument(
        "--local_rank", "--local-rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get('LOCAL_RANK', 0)),
        help="local rank for distributed training",
    )
    parser.add_argument('--trainset', type=str, required=True, help='path to the train set')
    parser.add_argument('--patch_size', type=int, default=64)
    parser.add_argument('--testset', type=str, default='default', help='path to the test set, default is Set5')

    parser.add_argument('--save_every', type=int, default=1, help='save weights')
    parser.add_argument('--test_every', type=int, default=5, help='evaluate on the test set every N epochs')
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training')
    # ``--num_workers`` is the correctly-spelled alias; the value is still
    # stored as ``args.num_workder`` for existing callers.
    parser.add_argument('--num_workder', '--num_workers', dest='num_workder', type=int, default=8)
    parser.add_argument('--total_epoch', type=int, default=30)

    return parser.parse_args(argv)

def data_sampler(dataset, shuffle=True, distributed=True):
    """Choose a sampler for ``dataset``.

    Returns a DistributedSampler when ``distributed``. Otherwise returns a
    RandomSampler or a SequentialSampler, depending on ``shuffle``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)

def point_grad_to(meta_net, task_net):
    '''
    Set .grad attribute of each parameter to be proportional
    to the difference between self and target.

    For every parameter pair (matched by order of ``.parameters()``), adds
    ``meta_p - task_p`` into ``meta_p.grad``. A zero gradient is created first
    if none exists. A subsequent ``optimizer.step()`` on the meta network then
    moves its weights toward ``task_net``. Existing gradients are accumulated
    into, so the caller is expected to zero them beforehand.
    '''
    with torch.no_grad():
        for meta_p, task_p in zip(meta_net.parameters(), task_net.parameters()):
            if meta_p.grad is None:
                # Same device/dtype/shape as the parameter (no hard-coded CUDA).
                meta_p.grad = torch.zeros_like(meta_p)
            meta_p.grad.add_(meta_p - task_p)


def _seed_numpy_worker(worker_id):
    """DataLoader ``worker_init_fn``: give each worker its own NumPy RNG stream.

    ``srdata_noise.DataCrop`` draws crop offsets, augmentation modes and noise
    from ``np.random``. Forked workers would otherwise inherit an identical
    NumPy state and produce duplicate samples. ``torch.initial_seed()`` is
    already distinct per worker (and per epoch), so it is reused here.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


def _to_bgr_uint8(img):
    """Convert a batch-of-one RGB tensor in [0, 1] to an HWC BGR uint8 array.

    Assumes shape (1, C, H, W), as produced by the batch_size=1 test loader.
    """
    img = img.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
    img = np.clip((img * 255.).round(), 0, 255).astype(np.uint8)
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)


def _evaluate(model, dataloader_test, logger, epoch):
    """Log per-image and mean PSNR/SSIM of ``model`` on ``dataloader_test``.

    Metrics use uint8 BGR images, ``crop_border=0``, all channels. The model
    is put back in train mode afterwards, even if an error occurs.
    """
    model.eval()
    psnr_total = 0.
    ssim_total = 0.
    count = 0
    try:
        for lr_img, hr_img, filename in dataloader_test:
            count += 1
            with torch.no_grad():
                sr_img = model(lr_img.to('cuda'))
            sr_bgr = _to_bgr_uint8(sr_img)
            hr_bgr = _to_bgr_uint8(hr_img)
            psnr = util.calculate_psnr(sr_bgr, hr_bgr, crop_border=0)
            ssim = util.calculate_ssim(sr_bgr, hr_bgr, crop_border=0)
            psnr_total += psnr
            ssim_total += ssim
            logger.info('{}: {}, {}'.format(filename[0], psnr, ssim))
    finally:
        model.train()

    if count == 0:
        logger.info('Epoch: {}, test set is empty; skipping metrics.'.format(epoch))
        return
    logger.info("Epoch: {}, psnr: {}. ssim: {}.".format(epoch, psnr_total / count, ssim_total / count))


def main():
    """Train an RRDBNet Gaussian denoiser with the DIL_sf meta-learning loop.

    There are four tasks: ``srdata_noise.DataCrop`` with noise sigma
    5/10/15/20. Every outer iteration:
      1. copies the meta network's weights into the task network and restores
         the saved task-optimizer state;
      2. takes one task-optimizer step on a batch from each task, in a task
         order shuffled identically on every process;
      3. sets the meta gradient to ``meta - task`` (``point_grad_to``) and
         steps the meta optimizer, pulling the meta weights toward the adapted
         task weights.

    The meta learning rate is halved every 20 epochs while it is above 1e-6
    (otherwise it is set to 1e-6). Checkpointing and evaluation run on CUDA
    device 0 every ``save_every`` / ``test_every`` epochs.
    """
    args = parse_args()

    ## initialize training folder
    checkpoint_save_path = args.ckpt_save
    os.makedirs(checkpoint_save_path, exist_ok=True)

    logger_name = 'train'
    utils_logger.logger_info(logger_name, os.path.join(checkpoint_save_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    ## initialize DDP training
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")

    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    ## initialize model and optimizer
    model_task = RRDBNet(in_nc=3, out_nc=3).to('cuda')
    model_meta = RRDBNet(in_nc=3, out_nc=3).to('cuda')

    optimizer_task = optim.Adam([p for p in model_task.parameters() if p.requires_grad], lr=1.e-4, betas=(0, 0.999))
    optimizer_meta = optim.Adam([p for p in model_meta.parameters() if p.requires_grad], lr=1.e-4)

    if args.resume is not None:
        print("load model: ", args.resume)
        ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model_task.load_state_dict(ckpt['model_task'])
        model_meta.load_state_dict(ckpt['model_meta'])

    loss_fn = torch.nn.L1Loss().to('cuda')

    ## for gaussian denoising, we set task number to 4
    dataset_list = [
        srdata_noise.DataCrop(i, hr_folder=args.trainset, patch_size=args.patch_size)
        for i in range(4)
    ]

    testset = srdata_noise.DataTest(hr_folder=args.testset, level=50)  # you can try 50, 70 ...

    dataloader_test = data.DataLoader(
        testset,
        batch_size=1,
        sampler=data_sampler(testset, shuffle=False, distributed=False),
        num_workers=1,
        pin_memory=True
    )

    dataloader_list = [
        data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            sampler=data_sampler(trainset, shuffle=True, distributed=args.distributed),
            num_workers=args.num_workder,
            pin_memory=True,
            drop_last=True,
            worker_init_fn=_seed_numpy_worker,
        )
        for trainset in dataset_list
    ]

    if args.distributed:
        model_task = DistributedDataParallel(
            model_task,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=True,
        )
        model_meta = DistributedDataParallel(
            model_meta,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=True,
        )

    total_epochs = args.total_epoch
    iters_per_epoch = len(dataset_list[0]) // (args.batch_size * args.gpus)
    state_task = None

    for epoch in range(total_epochs):

        if epoch and not (epoch % 20):
            for param in optimizer_meta.param_groups:
                param['lr'] = (param['lr'] * 0.5) if param['lr'] > 1.e-6 else 1.e-6

        learning_rate_f = optimizer_task.param_groups[0]['lr']
        learning_rate_s = optimizer_meta.param_groups[0]['lr']

        if args.distributed:
            # DistributedSampler shuffles with a seed derived from the epoch;
            # without set_epoch every epoch would see the same order.
            for dataloader in dataloader_list:
                dataloader.sampler.set_epoch(epoch)

        data_loader_train = [iter(dataloader) for dataloader in dataloader_list]

        # Dedicated RNG seeded with 1 every epoch, so the task order is the
        # same on every GPU. The global NumPy RNG is no longer re-seeded, so
        # crops/noise drawn from it do not repeat from epoch to epoch.
        task_order = [0, 1, 2, 3]
        task_order_rng = np.random.RandomState(1)

        for iteration in range(iters_per_epoch):
            model_task.load_state_dict(model_meta.state_dict())

            if state_task is not None:
                optimizer_task.load_state_dict(state_task)

            task_order_rng.shuffle(task_order)
            log_now = torch.cuda.current_device() == 0 and not iteration % args.print_every

            for ind in task_order:
                # ``next(it)``: Python 3 iterators have no ``.next()`` method.
                lr_img, hr_img = next(data_loader_train[ind])

                optimizer_task.zero_grad()
                lr_img = lr_img.to('cuda')
                hr_img = hr_img.to('cuda')
                sr_img = model_task(lr_img)
                loss = loss_fn(sr_img, hr_img)
                loss_print = loss.item()
                loss.backward()
                optimizer_task.step()

                if log_now:
                    logger.info('Epoch: {}\tIter: {}/{}\tTask loss: {}\tTask LR: {:.6f}\tMeta LR: {:.6f}'.format(epoch, iteration, iters_per_epoch, loss_print, learning_rate_f, learning_rate_s))

            state_task = optimizer_task.state_dict()
            optimizer_meta.zero_grad()
            point_grad_to(model_meta, model_task)
            optimizer_meta.step()

            if log_now:
                logger.info('Meta net updated!')

        is_main_device = torch.cuda.current_device() == 0
        m_task = model_task.module if args.distributed else model_task
        m_meta = model_meta.module if args.distributed else model_meta

        # save model
        if not epoch % args.save_every and is_main_device:
            torch.save(
                {
                    'model_meta': m_meta.state_dict(),
                    'model_task': m_task.state_dict(),
                },
                os.path.join(checkpoint_save_path, 'model_{}.pt'.format(epoch + 1))
            )
        # test model: use the unwrapped module so a single-rank forward pass
        # does not go through DistributedDataParallel.
        if not epoch % args.test_every and is_main_device:
            _evaluate(m_meta, dataloader_test, logger, epoch)

    logger.info('Done')


if __name__ == '__main__':
    main()

## Xây dựng mô hình sử dụng $\text{DIR}_{\text{sf}}$

## Huấn luyện


## Đánh giá