In [1]:
# Mount Google Drive so files under /content/drive are accessible from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# NoiseDegradation.py


In [2]:
import torch
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor

import random
import numpy as np

class NoiseDegradation(object):
    """Synthesise additive Gaussian-noise degradations for denoising training.

    Degradation ids (the same numbering as the "noise_*" entries of DEG_MAP):
        0 -> sigma 15, 7 -> sigma 25, 8 -> sigma 50 (pixel units on a 0-255 scale).
    """

    # degradation id -> standard deviation of the added Gaussian noise
    _SIGMA_BY_DEGRADE_TYPE = {0: 15, 7: 25, 8: 50}

    def __init__(self, args):
        super(NoiseDegradation, self).__init__()
        self.args = args
        self.toTensor = ToTensor()
        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

    def _add_gaussian_noise(self, clean_patch, sigma):
        """Add N(0, sigma^2) noise to ``clean_patch``.

        Returns:
            ``(noisy_patch, clean_patch)``. ``noisy_patch`` is clipped to [0, 255]
            and cast (truncated) to uint8; ``clean_patch`` is returned unchanged.
        """
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
        return noisy_patch, clean_patch

    def _add_noise_degradation_by_level(self, clean_patch, degrade_type):
        """Apply the noise level selected by ``degrade_type``.

        Returns:
            ``(degraded_patch, clean_patch)``. For an unknown id the clean patch is
            returned as the "degraded" one (best-effort behaviour kept on purpose).
        """
        # Compare with == (not a dict lookup) so unhashable/odd inputs fall through
        # to the best-effort branch instead of raising TypeError.
        for level_id, sigma in self._SIGMA_BY_DEGRADE_TYPE.items():
            if degrade_type == level_id:
                return self._add_gaussian_noise(clean_patch, sigma)
        return clean_patch, clean_patch

    def add_noise_degradation(self, clean_patch, degrade_type=None):
        """Return a noisy copy of ``clean_patch``.

        Args:
            clean_patch: image array (values on a 0-255 scale).
            degrade_type: 0, 7 or 8 to pick sigma 15/25/50; any other value
                (including None) picks one of the three levels uniformly at random.
        """
        valid_ids = tuple(self._SIGMA_BY_DEGRADE_TYPE)
        if degrade_type not in valid_ids:
            # random.choice yields a single id. The previous random.choices call
            # returned a one-element list that matched no id, so no noise was added.
            degrade_type = random.choice(valid_ids)

        degraded_patch, _ = self._add_noise_degradation_by_level(clean_patch, degrade_type)
        return degraded_patch


# utils.py


In [3]:
import imageio
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import os
import argparse
import random
import torch


def seed_everything(SEED=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        SEED: seed applied to every RNG.
        deterministic: if False (default, previous behaviour), cuDNN autotuning
            is enabled (``cudnn.benchmark = True``). That is faster, but cuDNN may
            then pick non-deterministic kernels, so runs are not bit-reproducible
            even with fixed seeds. If True, autotuning is disabled and cuDNN is
            restricted to deterministic algorithms.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # manual_seed_all also covers the current device (superset of cuda.manual_seed).
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = not deterministic
    if deterministic:
        torch.backends.cudnn.deterministic = True

def saveImage(filename, image):
    """Write ``image`` (float values expected in [0, 1]) to ``filename`` with imageio.

    Values are scaled by 255, clipped to [0, 255] and truncated to uint8.
    """
    scaled = image * 255.0
    as_bytes = np.clip(scaled, 0, 255).astype('uint8')
    imageio.imwrite(filename, as_bytes)

def save_rgb(img, filename):
    """Save an RGB image to ``filename`` with OpenCV.

    The image is clipped to [0, 1], rescaled to [0, 255], cast to float32 and
    reordered RGB -> BGR (OpenCV's channel order) before writing.
    """
    clipped = np.clip(img, 0., 1.)
    # After clipping the max is at most 1, so in practice this always rescales.
    needs_rescale = np.max(clipped) <= 1
    scaled = clipped * 255 if needs_rescale else clipped
    # NOTE(review): a float32 array is handed to cv2.imwrite; confirm the target
    # format / OpenCV build converts it to 8-bit as intended.
    bgr = cv2.cvtColor(scaled.astype(np.float32), cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, bgr)


def load_img(filename, norm=True, size=(320, 480)):
    """Load an image file as an RGB numpy array.

    Args:
        filename: path to the image.
        norm: if True, return float32 values scaled to [0, 1]; otherwise the
            raw array from PIL.
        size: ``(width, height)`` passed to ``PIL.Image.resize``. The default
            (320, 480) is the size previously hard-coded here (originally no
            resize was done). Pass None to keep the native resolution.

    Returns:
        ``(height, width, 3)`` numpy array.
    """
    # The context manager closes the file handle that Image.open keeps open lazily.
    with Image.open(filename) as src:
        rgb = src.convert("RGB")
    if size is not None:
        rgb = rgb.resize(tuple(size))
    img = np.array(rgb)

    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img

def plot_all(images, figsize=(20,10), axis='off', names=None):
    """Show ``images`` side by side in one row and call ``plt.show()``.

    Args:
        images: sequence of images accepted by ``Axes.imshow``. Nothing is drawn
            if it is empty.
        figsize: matplotlib figure size in inches.
        axis: argument forwarded to ``Axes.axis`` (e.g. 'off').
        names: optional per-image titles, indexed in the same order as ``images``.
    """
    nplots = len(images)
    if nplots == 0:
        return
    # squeeze=False keeps ``axs`` 2-D even for a single image. With the default
    # squeeze=True, plt.subplots(1, 1) returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(1, nplots, figsize=figsize, dpi=80, constrained_layout=True, squeeze=False)
    for i, (ax, image) in enumerate(zip(axs[0], images)):
        ax.imshow(image)
        if names:
            ax.set_title(names[i])
        ax.axis(axis)
    plt.show()

def modcrop(img_in, scale=2):
    """Crop an HW or HWC array so that its height and width are multiples of ``scale``.

    Surplus rows/columns are removed from the bottom/right. The input is copied
    first, so the result never aliases ``img_in``.

    Raises:
        ValueError: if ``img_in`` is not 2-D or 3-D.
    """
    cropped = np.copy(img_in)
    if cropped.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(cropped.ndim))
    height, width = cropped.shape[:2]
    keep_h = height - height % scale
    keep_w = width - width % scale
    return cropped[:keep_h, :keep_w, ...]

def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an ``argparse.Namespace``.

    Dict values become nested Namespaces; every other value is stored as-is.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        converted = dict2namespace(value) if isinstance(value, dict) else value
        setattr(namespace, key, converted)
    return namespace


########## MODEL

def count_params(model):
    """Return the number of trainable scalars (parameters with ``requires_grad``) in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# TrainDataset.py

In [4]:

from torch.utils.data import Dataset

import torchvision.transforms.functional as TF
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
from PIL import Image
import numpy as np
import os
import random

# from utils import load_img, modcrop
# from torchvision import transforms
# from training_utils.NoiseDegradation import NoiseDegradation



# Degradation name -> integer id stored in each sample's "de_type" field.
# Ids 0, 7 and 8 are the synthetic Gaussian-noise levels (sigma 15 / 25 / 50)
# produced by NoiseDegradation; id 2 is handled as rain-streak removal by the
# training dataset. Other ids are not handled by the dataset code in this notebook.
DEG_MAP = {
    "noise_15" : 0,
    "blur"     : 1,
    "rain"     : 2,
    "haze"     : 3,
    "lol"      : 4,
    "sr"       : 5,
    "en"       : 6,
    "noise_25" : 7,
    "noise_50" : 8
}

# Degradation family -> task name.
# NOTE(review): keys are families ("noise"), not the per-level keys of DEG_MAP
# ("noise_15", "noise_25", "noise_50"); look-ups must map between the two -
# confirm at the call sites.
DEG2TASK = {
    "noise": "denoising",
    "blur" : "deblurring",
    "rain" : "deraining",
    "haze" : "dehazing",
    "lol"  : "lol",
    "sr"   : "sr",
    "en"   : "enhancement"
}

def crop_img(image, base=16):
    """
    Mod crop the image to ensure the dimension is divisible by base. Also done by SwinIR, Restormer and others.

    The crop is centred: the ``h % base`` surplus rows (and ``w % base`` surplus
    columns) are split between the two edges. When the surplus is odd, the extra
    row/column is taken from the bottom/right.

    Args:
        image: array indexed as ``image[row, col, ...]`` (HWC or HW).
        base: value that both output spatial dimensions must be divisible by.

    Returns:
        A view of ``image`` with shape ``(h - h % base, w - w % base, ...)``.
    """
    h, w = image.shape[:2]
    crop_h = h % base
    crop_w = w % base
    top, left = crop_h // 2, crop_w // 2
    # Ellipsis keeps any trailing (channel) axes, so 2-D grayscale arrays work too.
    return image[top:h - crop_h + top, left:w - crop_w + left, ...]

def data_augmentation(image, mode):
    """Apply one of the 8 dihedral transforms (rotation plus optional flip) to ``image``.

    ``mode`` selects ``mode // 2`` counter-clockwise 90-degree rotations
    (``np.rot90`` over the first two axes), followed by an up-down flip when
    ``mode`` is odd:

        0 identity        1 flip
        2 rot90           3 rot90 + flip
        4 rot180          5 rot180 + flip
        6 rot270          7 rot270 + flip

    Args:
        image: numpy array (HW / HWC) or CPU torch tensor.
        mode: integer in [0, 7].

    Returns:
        A numpy array, possibly a view of ``image``. Call ``.copy()`` if an
        independent, contiguous buffer is needed.

    Raises:
        ValueError (a subclass of Exception): if ``mode`` is not in [0, 7].
    """
    if mode not in range(8):
        raise ValueError('Invalid choice of image transformation')
    rotations, flip = divmod(int(mode), 2)
    # asanyarray accepts numpy arrays and CPU tensors alike. Mode 0 previously
    # called image.numpy(), which crashed on numpy input.
    out = np.asanyarray(image)
    if rotations:
        out = np.rot90(out, k=rotations)
    if flip:
        out = np.flipud(out)
    return out

def random_augmentation(*args):
    """Apply the same randomly drawn ``data_augmentation`` mode (1-7) to every array in ``args``.

    Mode 0 (identity) is never drawn. Returns a list with an independent copy of
    each transformed input, in argument order.
    """
    mode = random.randint(1, 7)
    return [data_augmentation(item, mode).copy() for item in args]




################# DATASETS


class InstructIRTrainDataset(Dataset):
    """
    Dataset for Image Restoration having low-quality image and the reference image.

    Only one DataLoader can be used for training, so every enabled task is merged
    into a single sample list, ``self.dataset_ids``. Each sample is a dict
    ``{"clean_id": <image path>, "de_type": <int id as in DEG_MAP>}``. Paths come
    from list files, so no lq/hq path lists are passed in.

    Tasks currently handled (entries of ``args.de_type``):
        * 'denoise_15' / 'denoise_25' / 'denoise_50': Gaussian noise of that sigma
          is synthesised on the fly on a clean crop (de_type 0 / 7 / 8).
        * 'derain': paired rainy / clean images on disk (de_type 2).

    ``args`` attributes read here: ``de_type``, ``patch_size``, ``data_file_dir``,
    ``denoise_dir``, ``derain_dir``.
    """

    # de_type ids of the synthetic-noise samples (sigma 15 / 25 / 50)
    _NOISE_DEG_IDS = (0, 7, 8)
    # de_type id of the rain-streak-removal samples
    _RAIN_DEG_ID = 2

    def __init__(self, args):
        super(InstructIRTrainDataset, self).__init__()

        self.toTensor = ToTensor()
        self.noise_gradation_generator = NoiseDegradation(args)
        self.args = args
        # Private copy: _init_ids shuffles it, which previously mutated args.de_type in place.
        self.de_type = list(args.de_type)
        print(self.de_type)

        self._init_ids()
        self._merge_ids()

        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size)
        ])

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def _crop_patch(self, img_1, img_2):
        """Crop the same random ``patch_size`` x ``patch_size`` window from two aligned images."""
        H = img_1.shape[0]
        W = img_1.shape[1]
        ind_H = random.randint(0, H - self.args.patch_size)
        ind_W = random.randint(0, W - self.args.patch_size)

        patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
        patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]

        return patch_1, patch_2

    def _get_original_rain_name(self, rainy_name):
        """Map '.../rainy/rain-<k>' to the matching clean path '.../original/norain-<k>'."""
        og_name = rainy_name.split("rainy")[0] + 'original/norain-' + rainy_name.split('rain-')[-1]
        return og_name

    @staticmethod
    def _make_noise_samples(clean_ids, de_type):
        """Return a shuffled list of noise samples with the given de_type id."""
        samples = [{"clean_id": x, "de_type": de_type} for x in clean_ids]
        random.shuffle(samples)
        return samples

    def _init_clean_image_for_noise_degradation(self):
        """Collect clean images for the requested noise levels.

        Keeps the files in ``denoise_dir`` whose names are listed in
        ``<data_file_dir>/clean_image_for_denoise.txt``.
        """
        ref_file = os.path.join(self.args.data_file_dir, "clean_image_for_denoise.txt")
        with open(ref_file) as f:
            wanted_names = {line.strip() for line in f}
        clean_ids = [self.args.denoise_dir + name
                     for name in os.listdir(self.args.denoise_dir)
                     if name.strip() in wanted_names]

        self.s15_ids = []
        self.s25_ids = []
        self.s50_ids = []

        if 'denoise_15' in self.de_type:
            self.s15_ids = self._make_noise_samples(clean_ids, 0)
            self.s15_counter = 0
        if 'denoise_25' in self.de_type:
            self.s25_ids = self._make_noise_samples(clean_ids, 7)
            self.s25_counter = 0
        if 'denoise_50' in self.de_type:
            self.s50_ids = self._make_noise_samples(clean_ids, 8)
            self.s50_counter = 0

        print(f"Noisy Sigma 15 images len: {len(self.s15_ids)}")
        print(f"Noisy Sigma 25 images len: {len(self.s25_ids)}\n")
        print(f"Noisy Sigma 50 images len: {len(self.s50_ids)}\n")

    def _init_rs_ids(self):
        """Collect rainy image paths listed in ``<data_file_dir>/rainy.txt`` (relative to ``derain_dir``)."""
        rs = os.path.join(self.args.data_file_dir, "rainy.txt")
        with open(rs) as f:
            # Blank lines would otherwise become the bare directory path.
            rainy_paths = [self.args.derain_dir + line.strip() for line in f if line.strip()]
        self.rs_ids = [{"clean_id": x, "de_type": self._RAIN_DEG_ID} for x in rainy_paths]

        self.rl_counter = 0
        self.num_rl = len(self.rs_ids)
        print("Total Rainy Ids : {}".format(self.num_rl))

    def __len__(self):
        return len(self.dataset_ids)

    def _init_ids(self):
        """Build the per-task sample lists for every task named in ``self.de_type``."""
        if any(task in self.de_type for task in ('denoise_15', 'denoise_25', 'denoise_50')):
            self._init_clean_image_for_noise_degradation()
        if 'derain' in self.de_type:
            self._init_rs_ids()

        # Shuffles only this instance's copy; the merge order below does not depend on it.
        random.shuffle(self.de_type)

    def _merge_ids(self):
        """Concatenate all per-task sample lists into ``self.dataset_ids``."""
        self.dataset_ids = []
        # Each noise list is empty unless its level was requested, so all three are
        # merged unconditionally. They used to be merged only when 'denoise_15' was
        # enabled, which silently dropped sigma-25/50 data otherwise.
        for attr in ("s15_ids", "s25_ids", "s50_ids"):
            self.dataset_ids += getattr(self, attr, [])
        if "derain" in self.de_type:
            self.dataset_ids += self.rs_ids

        print(f"Dataset_ids length: {len(self.dataset_ids)}")

    def __getitem__(self, idx):
        """Return ``([clean_name, de_type], degraded_patch, clean_patch)``.

        Both patches are converted with ``ToTensor``. For noise samples,
        ``clean_name`` is the clean file's basename without extension; for rain
        samples, it is the full path of the matching clean image.

        Raises:
            ValueError: if the sample's de_type is neither a noise id nor the rain id.
        """
        sample = self.dataset_ids[idx]
        deg_id = sample["de_type"]

        if deg_id in self._NOISE_DEG_IDS:
            hq_path = sample["clean_id"]
            clean_img = crop_img(self._load_rgb(hq_path), base=16)
            clean_patch = np.array(self.crop_transform(clean_img))
            clean_name = hq_path.split("/")[-1].split('.')[0]

            clean_patch = random_augmentation(clean_patch)[0]
            degrad_patch = self.noise_gradation_generator.add_noise_degradation(clean_patch, deg_id)
        elif deg_id == self._RAIN_DEG_ID:
            degrad_img = crop_img(self._load_rgb(sample["clean_id"]), base=16)
            clean_name = self._get_original_rain_name(sample["clean_id"])
            clean_img = crop_img(self._load_rgb(clean_name), base=16)

            degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
        else:
            # Previously this fell through to an UnboundLocalError on degrad_img.
            raise ValueError(f"Unsupported de_type: {deg_id}")

        clean_patch = self.toTensor(clean_patch)
        degrad_patch = self.toTensor(degrad_patch)

        return [clean_name, deg_id], degrad_patch, clean_patch




In [5]:
# from google.colab import drive
# drive.mount('/content/drive')

# Scheduler.py


In [6]:
import math
from torch.optim.lr_scheduler import _LRScheduler
import warnings
from typing import List
from torch.optim import Optimizer

class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
    and base_lr followed by a cosine annealing schedule between base_lr and eta_min.

    Schedule, with t = ``last_epoch`` (number of ``step()`` calls since construction):
        * 0 <= t <= warmup_epochs - 1: linear ramp from warmup_start_lr (t = 0)
          to base_lr (t = warmup_epochs - 1).
        * t = warmup_epochs: base_lr.
        * t > warmup_epochs: eta_min + 0.5 * (base_lr - eta_min)
          * (1 + cos(pi * (t - warmup_epochs) / (max_epochs - warmup_epochs))).
          If stepped past max_epochs this keeps oscillating (period
          2 * (max_epochs - warmup_epochs)) rather than holding eta_min.

    Edge cases the code does not guard:
        * warmup_epochs == 0: the chainable path returns warmup_start_lr at t = 0 and
          then scales that value recursively, so the lr never reaches base_lr
          (it stays at 0 with the default warmup_start_lr = eta_min = 0).
        * warmup_epochs == 1: ``_get_closed_form_lr`` divides by
          ``warmup_epochs - 1 == 0`` at t = 0 (ZeroDivisionError).
        * max_epochs == warmup_epochs: the cosine terms divide by zero.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.
    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
        It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
        :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
        epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
        train and validation methods.
    Example:
        >>> layer = nn.Linear(10, 1)
        >>> optimizer = Adam(layer.parameters(), lr=0.02)
        >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
        >>> #
        >>> # the default case
        >>> for epoch in range(40):
        ...     # train(...)
        ...     # validate(...)
        ...     scheduler.step()
        >>> #
        >>> # passing epoch param case
        >>> for epoch in range(40):
        ...     scheduler.step(epoch)
        ...     # train(...)
        ...     # validate(...)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        # Attributes must be set before the base constructor, which performs the initial step.
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler.

        Each branch derives the new lr from the current ``group["lr"]`` (not from
        scratch), so the result depends on the lr set at the previous step.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
                UserWarning,
            )

        # First step: start of the warmup ramp.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Warmup: add one constant linear increment to the current lr.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # End of warmup: snap exactly to base_lr.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # This condition holds exactly when (last_epoch - warmup_epochs - 1) is an odd
        # multiple of (max_epochs - warmup_epochs), i.e. when the denominator of the
        # recursive ratio below would be 0 (lr at its eta_min trough). Step back up by
        # one cosine increment instead of dividing by zero.
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Recursive cosine annealing: scale (lr - eta_min) by the ratio of the cosine
        # factor at this step to the one at the previous step.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler.

        Computes the lr directly from ``last_epoch`` and ``base_lrs`` (no dependence on
        the previous lr). Raises ZeroDivisionError when warmup_epochs == 1 and
        last_epoch == 0.
        """
        # Linear warmup: warmup_start_lr at t = 0, base_lr at t = warmup_epochs - 1.
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        # Cosine annealing from base_lr (t = warmup_epochs) to eta_min (t = max_epochs).
        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]


# OffsetGenerator.py


In [7]:
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn  import functional as F
import torch

class OffsetGenerator(nn.Module):
    """Map a prompt feature map to a bounded offset vector.

    Pipeline: 3x3 conv (in_prompt_dim -> 32 channels) -> global average pool ->
    linear layer (32 -> out_conv_shapes) -> tanh, so every output entry lies in [-1, 1].
    """

    def __init__(self, in_prompt_dim, out_conv_shapes):
        """
        in_prompt_dim: number of channels in the incoming prompt (e.g. prompt_dim).
        out_conv_shapes: length of the produced offset vector, e.g. the number of
                         weights to offset (#channels_out * #channels_in * kernel_dim^2).
        """
        super().__init__()
        hidden_channels = 32

        # 3x3 conv refines local patterns in the prompt before pooling
        self.conv = nn.Conv2d(in_prompt_dim, hidden_channels, kernel_size=3, padding=1, stride=1)
        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.output_conv_size = out_conv_shapes
        self.flatten_layer = nn.Linear(in_features=hidden_channels, out_features=out_conv_shapes)

    def forward(self, prompt):
        """
        prompt: (B, in_prompt_dim, H, W) tensor, e.g. the output of a prompt-generation module.
        returns: (B, out_conv_shapes) tensor with values in [-1, 1].
        """
        features = self.conv(prompt)
        pooled = self.global_average_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return torch.tanh(self.flatten_layer(flat))



# nafnet.py


In [8]:
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
  title={Simple Baselines for Image Restoration},
  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
  journal={arXiv preprint arXiv:2204.04676},
  year={2022}
}
'''

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
# from torch.nn.modules.batchnorm import _BatchNorm
# from models.nafnet_utils import Local_Base, LayerNorm2d
# from models.cbin_weight import CBINorm_Conv2d

# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LayerNormFunction(torch.autograd.Function):
    """Channel-wise layer norm for NCHW tensors, with a hand-written backward pass.

    Each spatial position is normalised over the channel dimension (dim 1), then
    scaled and shifted by the per-channel ``weight`` and ``bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()  # also enforces a 4-D (N, C, H, W) input
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalised (pre-affine) activations for backward.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (x, weight, bias, eps); eps gets None."""
        eps = ctx.eps
        C = grad_output.size(1)
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight/bias are per-channel: reduce over batch and spatial dims.
        grad_weight = (grad_output * y).sum(dim=(0, 2, 3))
        grad_bias = grad_output.sum(dim=(0, 2, 3))
        return gx, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """Layer normalisation over the channel dimension of an NCHW tensor.

    Module wrapper around ``LayerNormFunction`` that holds the learnable
    per-channel scale ``weight`` (initialised to 1) and shift ``bias`` (initialised to 0).
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)



class AvgPool2d(nn.Module):
    """Local average pooling used to replace ``nn.AdaptiveAvgPool2d(1)`` at test time.

    Instead of one global mean, each output position holds the mean of a
    ``kernel_size`` window. With ``auto_pad`` the result is replicate-padded back
    to the input's H x W. If the window covers the whole input, the global mean
    (shape N x C x 1 x 1, unpadded) is returned instead.

    Args:
        kernel_size: window size, indexed as [0] (height) and [1] (width). If
            None, it is derived on the first forward from ``base_size`` scaled by
            input size / ``train_size``, then cached on the module; later calls
            reuse it whatever their input size.
        base_size: window size (int or pair) at the training resolution.
        auto_pad: replicate-pad the output back to the input's spatial size.
        fast_imp: use the approximate, subsampled implementation.
        train_size: shape whose last two entries are the training H and W;
            required when ``kernel_size`` is None.

    NOTE(review): if both ``kernel_size`` and ``base_size`` are None, ``forward``
    fails when it indexes ``kernel_size``.
    """
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        # candidate subsampling factors, tried from largest to smallest
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # NOTE(review): reports stride=kernel_size, although the exact path below
        # computes a stride-1 sliding-window mean.
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        # Lazily derive (and cache) the kernel size from the first input's size.
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # Window covers the whole input: fall back to the global mean (no padding).
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            # Unreachable in practice: the same condition already returned above.
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Largest factor in self.rs that divides h (resp. w), capped by max_r.
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # Integral image of the subsampled input.
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                # Upsample back by the subsampling factors.
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            # Exact stride-1 box mean via a zero-padded integral image: output
            # is (h - k1 + 1) x (w - k2 + 1).
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        # Replicate-pad the (smaller) pooled map back to the input's H x W.
        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every ``nn.AdaptiveAvgPool2d`` child of ``model`` for a local ``AvgPool2d``.

    The model is modified in place. Each replaced pool must have output_size == 1
    (asserted). Extra ``kwargs`` are only forwarded to the recursive calls.
    """
    for child_name, child in model.named_children():
        if any(True for _ in child.children()):
            # compound module: descend into it first
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue
        local_pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        assert child.output_size == 1
        setattr(model, child_name, local_pool)


'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class Local_Base():
    """Mixin that converts a model to local average pooling.

    ``convert`` replaces every ``nn.AdaptiveAvgPool2d`` via ``replace_layers`` and
    then runs one forward pass, under ``torch.no_grad()``, on a random tensor of
    shape ``train_size``.
    """
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # The dummy pass lets each new AvgPool2d derive and cache its kernel size.
        dummy_input = torch.rand(train_size)
        with torch.no_grad():
            self.forward(dummy_input)


class SimpleGate(nn.Module):
    """Split the channel dimension (dim 1) into two halves and return their elementwise product."""
    def forward(self, x):
        first_half, second_half = x.chunk(2, dim=1)
        return first_half * second_half

class NAFBlock(nn.Module):
    """NAFNet block made of two residual branches.

    1. norm -> 1x1 conv -> depthwise 3x3 conv -> SimpleGate -> simplified channel
       attention -> 1x1 conv, added back through the per-channel scale ``beta``.
    2. norm -> 1x1 conv -> SimpleGate -> 1x1 conv (``conv5``), added back through
       the per-channel scale ``gamma``.
    ``beta`` and ``gamma`` start at zero.

    ``forward`` optionally takes ``offset_vector``: a per-sample additive offset to
    ``conv5``'s weight (e.g. produced by ``OffsetGenerator``).

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion factor of the first branch.
        FFN_Expand: channel expansion factor of the second branch.
        drop_out_rate: dropout applied to both branch outputs (0 disables it).
        modify_conv_weights: accepted for API compatibility; stored but not used by this block.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0., modify_conv_weights=False):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

        # Plain attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.modify_conv_weights = modify_conv_weights

    def forward(self, inp, offset_vector=None):
        """
        Args:
            inp: (B, c, H, W) tensor.
            offset_vector: optional (B, conv5.weight.numel()) tensor; row i is added
                to conv5's weight when processing sample i.
        Returns:
            (B, c, H, W) tensor.
        Raises:
            ValueError: if offset_vector's batch size differs from inp's.
        """
        x = self.norm1(inp)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)

        if offset_vector is None:
            x = self.conv5(x)
        else:
            x = self._modulated_conv5(x, offset_vector)

        x = self.dropout2(x)

        return y + x * self.gamma

    def _modulated_conv5(self, x, offset_vector):
        """Apply conv5 with weight ``conv5.weight + offset_vector[i]`` to each sample i of ``x``."""
        batch = x.shape[0]
        if offset_vector.shape[0] != batch:
            raise ValueError(
                f"offset_vector batch size {offset_vector.shape[0]} does not match input batch size {batch}"
            )
        base_weight = self.conv5.weight  # (out, in // groups, kh, kw)
        # Reshape with the real weight shape so grouped convs are handled too.
        modulated_weight = base_weight[None, ...] + offset_vector.view(batch, *base_weight.shape)

        out_channels = base_weight.shape[0]
        # Fold the batch into the channel axis and run ONE grouped convolution.
        # Group block i sees only sample i's channels and sample i's kernel, which is
        # equivalent to a per-sample loop of conv2d calls but needs a single launch.
        out = F.conv2d(
            x.reshape(1, batch * x.shape[1], x.shape[2], x.shape[3]),
            modulated_weight.reshape(batch * out_channels, *base_weight.shape[1:]),
            bias=None if self.conv5.bias is None else self.conv5.bias.repeat(batch),
            stride=self.conv5.stride,
            padding=self.conv5.padding,
            dilation=self.conv5.dilation,
            groups=batch * self.conv5.groups,
        )
        return out.reshape(batch, out_channels, out.shape[2], out.shape[3])


class NAFNet(nn.Module):
    """U-shaped NAFNet restoration network with a global residual (output = body(input) + input).

    Args:
        img_channel: channels of the input/output image.
        width: channels after the intro conv; doubled at every encoder level.
        middle_blk_num: number of NAFBlocks at the bottleneck.
        enc_blk_nums: NAFBlocks per encoder level. Each entry adds one 2x downsampling.
        dec_blk_nums: NAFBlocks per decoder level. Each entry adds one 2x upsampling.
            It must have the same length as ``enc_blk_nums`` so that the final conv
            receives ``width`` channels.

    The defaults are immutable empty tuples rather than shared ``[]`` literals; any
    iterable of ints is accepted.
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=()):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Registered here (and replaced below) to keep the original submodule order,
        # which fixes the parameter order seen by optimizers and their checkpoints.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # strided 2x2 conv: halve resolution, double channels
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            # 1x1 conv + PixelShuffle(2): double resolution, halve channels
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        # Input H and W are padded to a multiple of this before the encoder.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        """Restore ``inp`` of shape (B, img_channel, H, W); the output has the same shape."""
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # additive skip connections, deepest encoder level first
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        # drop the padding added by check_image_size
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad ``x`` on the bottom/right so H and W are multiples of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

class NAFNetLocal(Local_Base, NAFNet):
    """NAFNet that is converted with ``Local_Base.convert`` right after construction.

    The conversion runs in eval mode under ``torch.no_grad()``. Its base size is
    1.5x the training crop's height and width.

    Args:
        *args, **kwargs: forwarded to ``NAFNet.__init__``.
        train_size: (N, C, H, W) of the training crops; forwarded to ``convert``.
        fast_imp: forwarded to ``convert``.
    """

    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        _, _, train_h, train_w = train_size
        base_size = tuple(int(side * 1.5) for side in (train_h, train_w))

        # The model is left in eval mode after conversion.
        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


def create_nafnet(input_channels = 3, width = 32, enc_blks = None, middle_blk_num = 12, dec_blks = None):
    """
    Create Nafnet model
    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml

    Args:
        input_channels: image channels (``img_channel`` of NAFNet).
        width: channel width after the intro convolution.
        enc_blks: NAFBlocks per encoder level; None means [2, 2, 4, 8].
        middle_blk_num: number of bottleneck NAFBlocks.
        dec_blks: NAFBlocks per decoder level; None means [2, 2, 2, 2].

    Returns:
        NAFNet: the constructed network.
    """
    # None sentinels: a list default would be a single object shared by every call.
    if enc_blks is None:
        enc_blks = [2, 2, 4, 8]
    if dec_blks is None:
        dec_blks = [2, 2, 2, 2]

    net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
    return net

#InstructIR.py

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
# from models.OffsetGenerator import OffsetGenerator
# from models.nafnet import NAFBlock

# ------------------------------------------------------------------------
# Copyright (c) 2023 va1shn9v. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/va1shn9v/PromptIR

'''
@inproceedings{potlapalli2023promptir,
  title={PromptIR: Prompting for All-in-One Image Restoration},
  author={Potlapalli, Vaishnav and Zamir, Syed Waqas and Khan, Salman and Khan, Fahad},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}
'''
##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """Build a spatial prompt as an input-conditioned mixture of learned prompt components.

    Adapted from PromptIR. The input feature map is globally average-pooled. A linear
    layer plus softmax then produces one mixing weight per component. The weighted
    sum of components is bilinearly resized to the input's H x W and refined by a
    3x3 convolution.

    Args:
        prompt_dim: channels of each component and of the returned prompt.
        prompt_len: number of learnable prompt components that are mixed.
        prompt_size: side length of each stored (square) component before resizing.
        lin_dim: channel count of the input ``x``; the pooled feature feeds
            ``Linear(lin_dim, prompt_len)``.
    """

    def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        self.prompt_dim= prompt_dim
        # Learnable components, shape (1, prompt_len, prompt_dim, prompt_size, prompt_size).
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        # Maps the pooled (B, lin_dim) descriptor to one logit per component.
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        # Shape-preserving 3x3 refinement (stride 1, padding 1).
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)

    def forward(self,x):
        """Return a prompt of shape (B, prompt_dim, H, W) for a (B, lin_dim, H, W) input ``x``."""
        _, _, H, W = x.shape
        # Global average pooling -> (B, C) descriptor of the input.
        emb = x.mean(dim=(-2,-1))
        # Softmax over components: per-sample mixing weights that sum to 1.
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1)
        # (B, L, 1, 1, 1) broadcast against the stored (1, L, D, S, S) components.
        # This gives the same product as repeating the parameter B times, without
        # materialising that extra (B, L, D, S, S) copy.
        weighted = prompt_weights[:, :, None, None, None] * self.prompt_param
        # Sum over components -> (B, prompt_dim, prompt_size, prompt_size).
        prompt = weighted.sum(dim=1)
        # Resize to the feature map's spatial size, then refine.
        prompt = F.interpolate(prompt,(H,W),mode="bilinear")
        return self.conv3x3(prompt)





##########################################################################

class ICB(nn.Module):
    """
    Instruction Condition Block (ICB)
    Paper Section 3.3

    Computes ``x + block((x * gamma + beta) * sigmoid(fc(text_embedding)))``.
    This is the paper's f' = Block(f * m_c) + f with m_c = sigmoid(W_c * e),
    plus a learned per-channel affine (gamma, beta) applied to f first.

    Args:
        feature_dim: channels of the feature map being conditioned.
        text_dim: dimension of the instruction embedding.
    """

    def __init__(self, feature_dim, text_dim=768):
        super(ICB, self).__init__()
        # One gate logit per feature channel, predicted from the instruction embedding.
        self.fc    = nn.Linear(text_dim, feature_dim)
        self.block = NAFBlock(feature_dim)
        # Per-channel affine modulation; both start at zero.
        self.beta  = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)

    def forward(self, x, text_embedding):
        # m_c: sigmoid gates with two trailing singleton dims so they broadcast over H and W.
        channel_gates = torch.sigmoid(self.fc(text_embedding))[..., None, None]
        # Learned scaling, then soft text-driven routing, then block enhancement.
        modulated = (x * self.gamma + self.beta) * channel_gates
        return self.block(modulated) + x


class InstructIR(nn.Module):
    """
    InstructIR model using NAFNet (ECCV 2022) as backbone.
    The model takes as input an RGB image and a text embedding (encoded instruction).
    Described in Paper Section 3.3

    Every encoder and decoder level is followed by an ICB that conditions the features
    on the text embedding. This variant also builds PromptIR-style PromptGenBlocks and
    OffsetGenerators. When ``include_offset`` is truthy, the bottleneck blocks and up to
    the first three decoder stages are called as ``block(x, offset_vector)``. The offset
    vector is generated from a prompt of the current features.

    Args:
        img_channel: channels of the input and output image.
        width: channels after ``intro``; doubled at every encoder level.
        middle_blk_num: number of bottleneck NAFBlocks. Must be >= 1, because the
            first one's ``conv5`` sizes the bottleneck offset generator.
        enc_blk_nums: NAFBlocks per encoder level (None -> no encoder levels).
        dec_blk_nums: NAFBlocks per decoder level (None -> no decoder levels).
        txtdim: dimension of the text embedding consumed by each ICB.
        include_offset: enable the prompt/offset-conditioned path in ``forward``.

    NOTE(review): the level prompt blocks are hard-coded to width*8, width*4 and
    width*2 channels. That matches decoder stages 0, 1, 2 only when there are four
    encoder levels. Confirm this before using other depths with include_offset.
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=None, dec_blk_nums=None, txtdim=768, include_offset=False):
        super().__init__()
        # None sentinels instead of mutable list defaults; None means "no levels".
        enc_blk_nums = [] if enc_blk_nums is None else enc_blk_nums
        dec_blk_nums = [] if dec_blk_nums is None else dec_blk_nums

        self.intro  = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders    = nn.ModuleList()
        self.decoders    = nn.ModuleList()
        # Placeholder, replaced by an nn.Sequential below. Assigning it here keeps the
        # submodule registration order (and thus parameter/state_dict order) stable.
        self.middle_blks = nn.ModuleList()
        self.ups         = nn.ModuleList()
        self.downs       = nn.ModuleList()
        self.enc_cond    = nn.ModuleList()
        self.dec_cond    = nn.ModuleList()

        self.include_offset = include_offset

        chan = width

        # Prompt generators for the first three decoder stages. They are built regardless
        # of include_offset, so the state_dict keys do not depend on that flag.
        self.prompt_block_level1 = PromptGenBlock(prompt_dim=chan*2,prompt_len=3,prompt_size = 128,lin_dim = chan*2)
        self.prompt_block_level2 = PromptGenBlock(prompt_dim=chan*4,prompt_len=3,prompt_size = 64,lin_dim = chan*4)
        self.prompt_block_level3 = PromptGenBlock(prompt_dim=chan*8,prompt_len=3,prompt_size = 32,lin_dim = chan*8)

        # Deepest first, so index i matches decoder stage i.
        self.promptBlocks = nn.ModuleList([self.prompt_block_level3, self.prompt_block_level2, self.prompt_block_level1])

        for num in enc_blk_nums:
            # Each encoder level: `num` NAFBlocks, then an ICB that conditions on the
            # text embedding, then a stride-2 conv that halves H/W and doubles channels.
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.enc_cond.append(ICB(chan, txtdim))
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        # Bottleneck NAFBlocks on the deepest features.
        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        self.middleblock_offsetGen = OffsetGenerator(in_prompt_dim=chan, out_conv_shapes=self._conv5_weight_count(self.middle_blks[0]))

        self.prompt_block_middle_blks = PromptGenBlock(prompt_dim=chan,prompt_len=3,prompt_size = 32,lin_dim = chan)

        for num in dec_blk_nums:
            # Each decoder level: 1x1 conv + PixelShuffle(2) doubles H/W and halves
            # channels, then `num` NAFBlocks, then an ICB on the text embedding.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.dec_cond.append(ICB(chan, txtdim))

        self.padder_size = 2 ** len(self.encoders)

        # One offset generator per decoder stage that has a prompt block (at most three).
        self.offset_generators = nn.ModuleList()
        for prompt_block, decoder in zip(self.promptBlocks, self.decoders):
            self.offset_generators.append(
                OffsetGenerator(in_prompt_dim=prompt_block.prompt_dim,
                                out_conv_shapes=self._conv5_weight_count(decoder[0])))

    @staticmethod
    def _conv5_weight_count(naf_block):
        """Return in_channels * out_channels * kH * kW of ``naf_block.conv5`` (used as ``out_conv_shapes``)."""
        conv = naf_block.conv5
        return conv.in_channels * conv.out_channels * conv.kernel_size[0] * conv.kernel_size[1]

    def forward(self, inp, txtembd):
        """Restore ``inp`` conditioned on the instruction embedding ``txtembd``.

        Args:
            inp: image batch, unpacked as (B, C, H, W).
            txtembd: instruction embedding passed to every ICB.

        Returns:
            The restored image, cropped back to the original H x W.
        """
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)
        x = self.intro(inp)

        encs = []
        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
            x = encoder(x)
            x = enc_mod(x, txtembd)
            encs.append(x)
            x = down(x)

        if self.include_offset:
            # Offset vector for all bottleneck blocks, from a prompt of the bottleneck features.
            middle_block_prompt = self.prompt_block_middle_blks(x)
            offset_vector = self.middleblock_offsetGen(middle_block_prompt)
            for naf_block in self.middle_blks:
                x = naf_block(x, offset_vector)
        else:
            x = self.middle_blks(x)

        decoder_stages = zip(self.decoders, self.ups, encs[::-1], self.dec_cond)
        for level, (decoder, up, enc_skip, dec_mod) in enumerate(decoder_stages):
            x = up(x)
            x = x + enc_skip
            # Same truthiness test as the bottleneck branch, so a truthy non-bool flag
            # (e.g. 1) enables offsets in both places.
            if self.include_offset and level < len(self.offset_generators):
                degradation_aware_prompt = self.promptBlocks[level](x)
                offset_vector = self.offset_generators[level](degradation_aware_prompt)
                for naf_block in decoder:
                    x = naf_block(x, offset_vector)
            else:
                x = decoder(x)
            x = dec_mod(x, txtembd)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad bottom/right so H and W are multiples of ``self.padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


def create_model(input_channels = 3, width = 32, enc_blks = None, middle_blk_num = 12, dec_blks = None, txtdim=768, include_offset=False):
    """Build an InstructIR restoration network.

    Args:
        input_channels: image channels (``img_channel`` of InstructIR).
        width: channel width after the intro convolution.
        enc_blks: NAFBlocks per encoder level; None means [2, 2, 4, 8].
        middle_blk_num: number of bottleneck NAFBlocks.
        dec_blks: NAFBlocks per decoder level; None means [2, 2, 2, 2].
        txtdim: dimension of the instruction embedding.
        include_offset: enable the prompt/offset-conditioned path.

    Returns:
        InstructIR: the constructed network (no weights are loaded here).
    """
    # None sentinels: a list default would be a single object shared by every call.
    if enc_blks is None:
        enc_blks = [2, 2, 4, 8]
    if dec_blks is None:
        dec_blks = [2, 2, 2, 2]

    net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim, include_offset=include_offset)

    return net

#Text model

In [10]:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os

# Models that use mean pooling.
# Matched by substring against the model name in LanguageModel.forward. For these
# models the token embeddings are mean-pooled (see mean_pooling) and L2-normalised.
# Other non-CLIP models use the first token's last hidden state instead.
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring padded positions.

    Args:
        model_output: encoder output whose element 0 holds the token embeddings
            (assumed (B, T, D), matching the mask broadcast below).
        attention_mask: (B, T) mask with 1 for real tokens and 0 for padding.

    Returns:
        (B, D) tensor. A row whose mask is all zeros yields zeros, because the
        denominator is clamped to 1e-9.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class LanguageModel(nn.Module):
    """Frozen Hugging Face text encoder that turns instruction strings into sentence embeddings.

    The wrapped encoder's parameters have ``requires_grad=False``. The encoder is kept
    in eval mode even when ``.train()`` is called on this module or on a parent.

    Args:
        model: model id passed to ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.
    """

    def __init__(self, model='distilbert-base-uncased'):
        super(LanguageModel, self).__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # Remove the CLIP vision tower; only the text features are used.
        if "clip" in self.model_name:
            self.model.vision_model = None
        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def train(self, mode=True):
        """Set this module's training flag but keep the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into children. Without this override, any
        ``.train()`` call (including one on a parent model) would re-enable dropout in
        the frozen encoder and make the embeddings stochastic. ``eval()`` routes
        through here as ``train(False)``.
        """
        super(LanguageModel, self).train(mode)
        self.model.eval()
        return self

    def forward(self, text_batch):
        """Encode ``text_batch`` (whatever the tokenizer accepts) into sentence embeddings.

        Returns CLIP text features for CLIP models. For models listed in POOL_MODELS,
        returns L2-normalised mean-pooled token embeddings. Otherwise returns the
        first token's last hidden state. No gradients flow through the encoder.
        """
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        # Move the tokenized tensors to the device holding the encoder weights.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if "clip" in self.model_name:
                return self.model.get_text_features(**inputs)

            outputs = self.model(**inputs)

        if any(pool_name in self.model_name for pool_name in POOL_MODELS):
            sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
            # Normalize embeddings
            sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
        else:
            sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding


class LMHead(nn.Module):
    """Two-layer head on top of a sentence embedding.

    ``fc1`` projects the embedding, and the projection is L2-normalised along dim 1.
    ``fc2`` maps that normalised projection to ``num_classes`` logits (presumably
    degradation classes, given the original variable name ``deg_pred``).

    Returns:
        (normalised projection, logits)
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super(LMHead, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        projected = F.normalize(self.fc1(x), p=2, dim=1)
        logits = self.fc2(projected)
        return projected, logits

#Install transformers & setup "config"

In [11]:
!pip install transformers
!pip install pytorch-msssim



#Test Performance

#metrics.py


In [12]:
import numpy as np
import math
import cv2
import torch


def np_psnr(y_true, y_pred, maxval=1.):
    """PSNR in dB between two array-likes, with peak value ``maxval``.

    Returns ``np.inf`` when the inputs are identical.
    """
    # Promote to float64 first. With integer arrays (e.g. uint8), the subtraction and
    # squaring would otherwise wrap around and produce a wrong MSE.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return np.inf

    return 20 * np.log10(maxval / np.sqrt(mse))

def pt_psnr(y_true, y_pred, maxval=1.):
    """Per-sample PSNR in dB for batched tensors.

    The MSE is averaged over dims 1, 2 and 3 (expects (N, C, H, W)). Returns an
    (N, 1) tensor; identical samples give ``inf``.
    """
    per_sample_mse = (y_true - y_pred).pow(2).mean(dim=(1, 2, 3))
    psnr = 20 * torch.log10(maxval / per_sample_mse.sqrt())
    return psnr[:, None]


############# SWINIR METRICS
# https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/utils/util_calculate_psnr_ssim.py#L243


def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) with a peak value of 255.

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Pixels dropped from every edge before comparison.
        input_order (str): 'HWC' or 'CHW'. Default: 'HWC'.
        test_y_channel (bool): Compare only the Y channel of YCbCr. Default: False.

    Returns:
        float: PSNR in dB, or ``inf`` for identical images.
    """
    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a, b = (reorder_image(im, input_order=input_order).astype(np.float64) for im in (img1, img2))

    if crop_border != 0:
        window = (slice(crop_border, -crop_border), slice(crop_border, -crop_border), Ellipsis)
        a, b = a[window], b[window]

    if test_y_channel:
        a, b = to_y_channel(a), to_y_channel(b)

    mse = np.mean((a - b) ** 2)
    return float('inf') if mse == 0 else 20. * np.log10(255. / np.sqrt(mse))


def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for a pair of single-channel images.

    Helper for ``calculate_ssim``, which passes one 2-D channel at a time. It uses an
    11x11 Gaussian window (sigma 1.5) and drops a 5-pixel margin on every side of
    each filtered map before averaging.

    Args:
        img1 (ndarray): 2-D image with range [0, 255].
        img2 (ndarray): 2-D image with range [0, 255], same shape as ``img1``.

    Returns:
        float: mean of the SSIM map.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def blur(arr):
        # Gaussian-weighted local mean, keeping only pixels whose 11x11 window lies
        # fully inside the image.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = blur(x)
    mu_y = blur(y)
    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = blur(x ** 2) - mu_x_sq
    var_y = blur(y ** 2) - mu_y_sq
    cov_xy = blur(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()


def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity), averaged over channels.

    Ref:
    Image quality assessment: From error visibility to structural similarity
    (the upstream SwinIR implementation states that its results match the official
    MATLAB code at https://ece.uwaterloo.ca/~z70wang/research/ssim/).

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Pixels dropped from every edge before comparison.
        input_order (str): 'HWC' or 'CHW'. Default: 'HWC'.
        test_y_channel (bool): Compare only the Y channel of YCbCr. Default: False.

    Returns:
        float: mean of the per-channel SSIM values.
    """
    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    a, b = (reorder_image(im, input_order=input_order).astype(np.float64) for im in (img1, img2))

    if crop_border != 0:
        window = (slice(crop_border, -crop_border), slice(crop_border, -crop_border), Ellipsis)
        a, b = a[window], b[window]

    if test_y_channel:
        a, b = to_y_channel(a), to_y_channel(b)

    per_channel = [_ssim(a[..., ch], b[..., ch]) for ch in range(a.shape[2])]
    return np.array(per_channel).mean()


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input shape is (h, w), return (h, w, 1) regardless of ``input_order``.
    If the input_order is (c, h, w), return (h, w, c).
    If the input_order is (h, w, c), return it as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            For a 2-D input it has no effect. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # Return early: transposing the (h, w, 1) result as CHW would give (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Integer images are always treated as range [0, 255]. Float images are treated as
    [0, 255] when their maximum exceeds 1, and as [0, 1] otherwise (a heuristic).
    Images with 3 channels in the last axis go through ``bgr2ycbcr`` (BGR order) and
    come back as a single Y channel. Any other image is only rescaled.

    NOTE(review): the max-based heuristic is applied per image. Two images compared by
    calculate_psnr/ssim may be scaled differently if only one has max <= 1. Confirm
    the callers' value range.

    Args:
        img (ndarray): Image, normally with range [0, 255].

    Returns:
        (ndarray): float32 image with range [0, 255], not rounded.
    """
    # bgr2ycbcr only accepts float32 or uint8, so always hand it float32.
    # calculate_psnr/ssim pass float64, and a near-black float64 image (max <= 1)
    # would otherwise raise TypeError there.
    if np.issubdtype(img.dtype, np.integer) or np.max(img) > 1.:
        img = img.astype(np.float32) / 255.
    else:
        img = img.astype(np.float32)

    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    convertion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace convertion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr using the ITU-R BT.601 (SDTV) coefficients.

    This is the BGR counterpart of rgb2ycbcr. It differs from OpenCV's
    ``cv2.cvtColor`` BGR <-> YCrCb, which implements the JPEG conversion. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion and
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): np.uint8 with range [0, 255] or np.float32 with range [0, 1],
            channels last in B, G, R order.
        y_only (bool): Return only the Y channel. Default: False.

    Returns:
        ndarray: YCbCr (or Y) image with the same dtype and range as the input.
    """
    original_dtype = img.dtype
    normalized = _convert_input_type_range(img)
    if y_only:
        # B, G, R weights of the BT.601 luma row, plus the +16 offset.
        ycbcr = np.dot(normalized, [24.966, 128.553, 65.481]) + 16.0
    else:
        transform = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]
        ycbcr = np.matmul(normalized, transform) + [16, 128, 128]
    return _convert_output_type_range(ycbcr, original_dtype)



#test.py

In [13]:
import os
import gc
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import os
from datetime import datetime
# from metrics import pt_psnr, calculate_ssim, calculate_psnr
from pytorch_msssim import ssim
# from utils import save_rgb
# Date stamp in DD_MM_YYYY format, evaluated once when this cell runs.
today = datetime.today().strftime("%d_%m_%Y")

# Text file on the mounted Google Drive (see drive.mount in the first cell) meant to
# collect evaluation statistics. Presumably passed to save_performance_to_file;
# TODO confirm, since the call site is not visible here.
DEST_PATH = "/content/drive/MyDrive/FYPData/eval_results/instructir_wm_enabled_performance.txt"

print(f"DEST_PATH: {DEST_PATH}")

def save_performance_to_file(dest_path, text):
    """
    Append ``text`` followed by a newline to the file at ``dest_path``.

    The file is created if it does not exist, and so is its parent directory.
    A bare file name (no directory part) is written relative to the current
    working directory. The text is written as UTF-8.

    Args:
        dest_path (str): The path to the destination file.
        text (str): The text to be appended to the file.
    """
    directory = os.path.dirname(dest_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    print(f"Saving performance statistics to {dest_path}")
    with open(dest_path, "a", encoding="utf-8") as f:
        f.write(text)
        f.write("\n")
    print("Saving performance statistics done!")

def return_map(deg):
    """Map a degradation label to its restoration task name by substring match.

    Keywords are checked in order: "rain" -> "deraining", "nois" (which also covers
    "noise") -> "denoising", "haze" -> "dehazing". Anything else falls back to
    "deraining".
    """
    keyword_to_task = (
        ("rain", "deraining"),
        ("nois", "denoising"),
        ("haze", "dehazing"),
    )
    for keyword, task in keyword_to_task:
        if keyword in deg:
            return task
    return "deraining"

def get_wrong_degradation(prompt):
    """Return a degradation label, chosen with ``random.choice``, that differs from ``prompt``.

    ``prompt`` must be one of "noise", "rain" or "haze"; any other value raises
    ValueError (from ``list.remove``).
    """
    alternatives = ["noise", "rain", "haze"]
    alternatives.remove(prompt)  # ValueError for an unknown label
    return random.choice(alternatives)

def augment_prompt(prompt):
    """Turn a degradation keyword into a randomized natural-language instruction.

    A template "<opener> <prompt> <closer>" is always sampled first, using the global
    NumPy RNG. If ``prompt`` contains "lol", "sr" or "en" as a substring (checked in
    that order), the template is replaced by a random canned instruction for that
    task. Because this is a substring match, any prompt containing "en" selects the
    enhancement set. The result is stripped, one pass of double spaces is collapsed,
    and newlines are removed.
    """
    ### special prompts
    special_prompts = (
        ("lol", ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"]),
        ("sr", ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"]),
        ("en", ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photograper"]),
    )

    opener = np.random.choice(["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the",])
    closer = np.random.choice(["please", "fast", "now", "in the photo", "in the picture", "in the image", ""])
    instruction = f"{opener} {prompt} {closer}"

    for keyword, candidates in special_prompts:
        if keyword in prompt:
            instruction = np.random.choice(candidates)
            break

    return instruction.strip().replace("  ", " ").replace("\n", "")

def test_model(model, language_model, lm_head, testsets, device, promptify, savepath="results/", initial_message="", dest_path=None):
    """Evaluate the restoration model on each test set and write a text report.

    For every dataset in ``testsets`` (objects exposing ``name``,
    ``degradation`` and ``deg_class``, e.g. ``RefDegImage``), each batch
    ``(hq, lq, filename)`` is restored with ``model(lq, text_embd)``. Here
    ``text_embd`` is the first output of ``lm_head(language_model(prompts))``
    and the prompts come from ``promptify(testset.degradation)``.

    Per dataset, the report accumulates:
    - the mean PSNR/SSIM of restored vs. HQ;
    - the mean PSNR of the untouched LQ input ("base");
    - for derain benchmarks, Y-channel PSNR/SSIM as well.

    Args:
        model: restoration network, called as ``model(lq, text_embd)``.
        language_model: text encoder applied to the list of prompts.
        lm_head: head returning ``(text_embd, deg_pred)``.
        testsets: iterable of datasets to evaluate.
        device: torch device the images and embeddings are moved to.
        promptify: callable mapping a degradation keyword to an instruction.
        savepath: if truthy, restored images are written under
            ``savepath/<testset.name>/``.
        initial_message: text placed at the start of the report.
        dest_path: report file path; defaults to the notebook-level
            ``DEST_PATH`` when None.

    Raises:
        ValueError: on the first batch if ``language_model`` is falsy, since
            the model cannot be conditioned without a text embedding.
    """
    model.eval()
    if language_model:
        language_model.eval()
        lm_head.eval()

    # Benchmarks that are additionally scored on the Y (luma) channel.
    derain_datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']

    statistics_message = initial_message

    with torch.no_grad():

        for testset in testsets:

            if savepath:
                dt_results_path = os.path.join(savepath, testset.name)
                os.makedirs(dt_results_path, exist_ok=True)

            eval_message = f"\n>>> Eval on {testset.name} for {testset.degradation}(class={testset.deg_class})\n"
            statistics_message += eval_message

            print(eval_message)

            testset_name = testset.name
            test_dataloader = DataLoader(testset, batch_size=1, num_workers=4, drop_last=True, shuffle=False)
            psnr_dataset = []
            ssim_dataset = []
            psnr_noisy   = []
            use_y_channel = testset.name in derain_datasets

            if use_y_channel:
                psnr_y_dataset = []
                ssim_y_dataset = []

            statistics_message += "The input human instructions (first 5):\n"

            for idx, batch in enumerate(test_dataloader):

                x = batch[0].to(device) # HQ image
                y = batch[1].to(device) # LQ image
                f = batch[2][0]         # filename
                t = [promptify(testset.degradation) for _ in range(x.shape[0])]

                if not language_model:
                    # Without a language model there is no text embedding to
                    # condition on; fail clearly instead of using an undefined x_hat.
                    raise ValueError("test_model requires language_model and lm_head to compute the text embedding")

                if idx < 5:
                    # Record the first few input prompts for debugging.
                    statistics_message += f"{t}\n"
                    print("\nInput prompt:", t)

                lm_embd = language_model(t).to(device)
                text_embd, _deg_pred = lm_head(lm_embd)  # second output (deg_pred) unused here
                x_hat = model(y, text_embd)

                psnr_restore = torch.mean(pt_psnr(x, x_hat))
                psnr_dataset.append(psnr_restore.item())
                ssim_restore = ssim(x, x_hat, data_range=1., size_average=True)
                ssim_dataset.append(ssim_restore.item())
                psnr_base    = torch.mean(pt_psnr(x, y))
                psnr_noisy.append(psnr_base.item())

                if use_y_channel:
                    _x_hat = np.clip(x_hat[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32)
                    _x     = np.clip(x[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32)
                    # NOTE(review): astype(np.uint8) truncates rather than rounds;
                    # left unchanged so Y-channel numbers stay comparable with prior runs.
                    _x_hat = (_x_hat*255).astype(np.uint8)
                    _x     = (_x*255).astype(np.uint8)

                    psnr_y = calculate_psnr(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True)
                    ssim_y = calculate_ssim(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True)
                    psnr_y_dataset.append(psnr_y)
                    ssim_y_dataset.append(ssim_y)

                ## SAVE RESULTS
                if savepath:
                    restored_img = np.clip(x_hat[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32)
                    img_name = f.split("/")[-1]
                    save_rgb(restored_img, os.path.join(dt_results_path, img_name))

            if len(psnr_dataset) > 0:
                result_str = f"{testset_name}_base {np.mean(psnr_noisy)} Total images: {len(psnr_dataset)}\n{testset_name}_psnr {np.mean(psnr_dataset)}\n{testset_name}_ssim {np.mean(ssim_dataset)}\n"
                print(result_str)
                y_channel_result_str = ""
                if use_y_channel:
                    y_channel_result_str =f"{testset_name}_psnr-Y {np.mean(psnr_y_dataset)} {len(psnr_y_dataset)}\n{testset_name}_ssim-Y {np.mean(ssim_y_dataset)}\n"
                    print(y_channel_result_str)

                statistics_message += result_str
                statistics_message += y_channel_result_str
                divide_line = 25 * "***"
                statistics_message += divide_line
                statistics_message += "\n\n"
                print()
                print(divide_line)

            # Release the loader before moving on to the next test set.
            del test_dataloader
            gc.collect()

    save_performance_to_file(DEST_PATH if dest_path is None else dest_path, statistics_message)

DEST_PATH: /content/drive/MyDrive/FYPData/eval_results/instructir_wm_enabled_performance.txt


#datasets.py


In [14]:
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import json
import os
from glob import glob
from PIL import Image

# from utils import load_img, modcrop
from torchvision import transforms

# Degradation keyword -> class index (create_testsets stores it as deg_class).
# The 7 entries match the nclasses value in the config cell.
DEG_MAP = {
    deg_name: class_idx
    for class_idx, deg_name in enumerate(("noise", "blur", "rain", "haze", "lol", "sr", "en"))
}

# Degradation keyword -> human-readable task name.
DEG2TASK = dict(
    noise="denoising",
    blur="deblurring",
    rain="deraining",
    haze="dehazing",
    lol="lol",
    sr="sr",
    en="enhancement",
)

def augment_prompt(prompt):
    """Turn a degradation keyword into a randomized natural-language instruction.

    A generic "<opener> <prompt> <closer>" sentence is always sampled first
    (two draws from NumPy's global RNG). If ``prompt`` contains "lol", "sr" or
    "en" (checked by substring, in that order), that sentence is discarded and
    a task-specific instruction is sampled instead. The result is stripped,
    one round of double spaces is collapsed and newlines are removed.

    This redefines the identical ``augment_prompt`` from an earlier cell.
    """
    ### special prompts
    special_prompts = {
        "lol": ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"],
        "sr": ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"],
        "en": ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photograper"],
    }
    openers = ["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the"]
    closers = ["please", "fast", "now", "in the photo", "in the picture", "in the image", ""]

    # Both generic draws happen unconditionally so the RNG stream advances the
    # same way regardless of which branch is taken below.
    opener = np.random.choice(openers)
    closer = np.random.choice(closers)
    sentence = f"{opener} {prompt} {closer}"

    for key in ("lol", "sr", "en"):
        if key in prompt:
            sentence = np.random.choice(special_prompts[key])
            break

    return sentence.strip().replace("  ", " ").replace("\n", "")

def get_deg_name(path):
    """
    Get the degradation name from the path.

    Substring rules are tried in a fixed priority order (blur, haze, lol, en,
    sr, rain) and the first match wins; a path matching none of them is
    treated as "noise".
    """
    priority_rules = (
        ("blur", ("gopro", "GoPro", "blur", "Blur", "RealBlur")),
        ("haze", ("SOTS", "haze", "sots", "RESIDE")),
        ("lol",  ("LOL",)),
        ("en",   ("fiveK",)),
        ("sr",   ("super", "classicalSR")),
        ("rain", ("Rain100", "rain13k", "Rain13k")),
    )
    for deg_name, markers in priority_rules:
        if any(marker in path for marker in markers):
            return deg_name
    return "noise"

def crop_img(image, base=16):
    """
    Center mod-crop an image so its height and width are divisible by ``base``.
    Also done by SwinIR, Restormer and others.

    Args:
        image: array indexed as (H, W) or (H, W, ...), e.g. an (H, W, C) image.
        base: the height and width are reduced to the largest multiple of it.

    Returns:
        A slice of ``image`` (a view for NumPy arrays) with ``H - H % base``
        rows and ``W - W % base`` columns; trailing axes are untouched. When
        the excess is odd, the extra row/column is removed from the bottom/right.
    """
    h, w = image.shape[:2]
    crop_h = h % base
    crop_w = w % base
    top = crop_h // 2
    left = crop_w // 2
    # The Ellipsis keeps any trailing axes, so 2-D (grayscale) arrays work as
    # well as (H, W, C) ones.
    return image[top:h - crop_h + top, left:w - crop_w + left, ...]


################# DATASETS


class RefDegImage(Dataset):
    """
    Dataset for Image Restoration having low-quality image and the reference image.
    Tasks: synthetic denoising, deblurring, super-res, etc.

    Each item is ``(hq_tensor, lq_tensor, hq_path)``. Images are read with
    ``load_img``, cast to float32 and converted with ``ToTensor``. ToTensor
    does not rescale float input, so the value range is whatever ``load_img``
    returns (presumably [0, 1] — TODO confirm). In validation mode both images
    are first center-cropped by ``crop_img`` so H and W are multiples of 16.

    ``augmentations`` is kept in ``self.augs`` (forced to None when ``val``
    is truthy) but is not applied by ``__getitem__``.
    """

    def __init__(self, hq_img_paths, lq_img_paths, augmentations=None, val=False, name="test", deg_name="noise", deg_class=0):

        assert len(hq_img_paths) == len(lq_img_paths)

        # Identification used in evaluation reports.
        self.name = name
        self.degradation = deg_name
        self.deg_class = deg_class

        # Paired file lists, aligned by index.
        self.hq_paths = hq_img_paths
        self.lq_paths = lq_img_paths

        self.val = val
        # No augmentations during validation/test.
        self.augs = None if val else augmentations

        self.totensor = torchvision.transforms.ToTensor()
        # Fixed-size resize + ToTensor pipeline; not used by __getitem__.
        self.resize_and_totensor = transforms.Compose([
            transforms.Resize((320, 480)),  # Resize to fixed dimensions
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.hq_paths)

    def _load_as_tensor(self, path):
        """Load one image, mod-crop it in validation mode, return a float32 tensor."""
        image = load_img(path)
        if self.val:
            # crop_img trims H and W to multiples of 16, e.g. 321x189 -> 320x176.
            image = crop_img(image)
        return self.totensor(image.astype(np.float32))

    def __getitem__(self, idx):
        hq_path = self.hq_paths[idx]
        lq_path = self.lq_paths[idx]
        hq_tensor, lq_tensor = (self._load_as_tensor(p) for p in (hq_path, lq_path))
        return hq_tensor, lq_tensor, hq_path



def create_testsets(testsets, debug=False):
    """
    Given a list of testsets create pytorch datasets for each.
    The method requires the paths to references and noisy images.

    Args:
        testsets: non-empty list of ``[path_hq, path_lq]`` pairs (reference
            directory first, degraded directory second).
        debug: print the number of pairs and each pair while building.

    Returns:
        One ``RefDegImage`` (``val=True``) per pair. Each is named after the
        benchmark inferred from ``path_hq`` and tagged with the degradation
        from ``get_deg_name`` and its ``DEG_MAP`` class.

    Raises:
        AssertionError: on an empty list, an unrecognised dataset path,
            mismatched HQ/LQ counts, or a benchmark whose image count differs
            from its expected size.
    """
    assert len(testsets) > 0

    # Derain benchmark names (the same list test_model uses to decide on
    # Y-channel metrics); used to name rain datasets from their path.
    derain_names = ('Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800')

    if debug:
        print(20*'****')
        print("Creating Testsets", len(testsets))

    datasets = []
    for testdt in testsets:

        path_hq, path_lq = testdt[0], testdt[1]
        if debug:
            print(path_hq, path_lq)

        if ("denoising" in path_hq) or ("jpeg" in path_hq):
            # e.g. .../CBSD68 with .../CBSD68_25 -> "CBSD68_25"
            dataset_name  = path_hq.split("/")[-1]
            dataset_sigma = path_lq.split("/")[-1].split("_")[-1].split(".")[0]
            dataset_name  = dataset_name + f"_{dataset_sigma}"
        elif "Rain" in path_hq:
            if "Rain100L" in path_hq:
                dataset_name = "Rain100L"
            else:
                # A fixed component index lands on "MyDrive" for absolute Colab
                # paths (/content/drive/MyDrive/...), so prefer a component that
                # names a known derain benchmark.
                known = [part for part in path_hq.split("/") if part in derain_names]
                dataset_name = known[-1] if known else path_hq.split("/")[3]

        elif ("gopro" in path_hq) or ("GoPro" in path_hq):
            dataset_name = "GoPro"
        elif "LOL" in path_hq:
            dataset_name = "LOL"
        elif "SOTS" in path_hq:
            dataset_name = "SOTS"
        elif "fiveK" in path_hq:
            dataset_name = "MIT5K"
        else:
            assert False, f"{path_hq} - unknown dataset"

        hq_img_paths = sorted(glob(os.path.join(path_hq, "*")))
        lq_img_paths = sorted(glob(os.path.join(path_lq, "*")))

        if "SOTS" in path_hq:
            # Haze removal SOTS test dataset: the .jpg references are paired with
            # the file at the same path with "GT" replaced by "IN".
            dataset_name = "SOTS"
            hq_img_paths = sorted(glob(os.path.join(path_hq, "*.jpg")))
            assert len(hq_img_paths) == 500

            lq_img_paths = [file.replace("GT", "IN") for file in hq_img_paths]

        if "fiveK" in path_hq:
            # MIT5K: the test split is listed, one id per line, in a text file.
            dataset_name = "MIT5K"
            testf = "test-data/mit5k/test.txt"
            with open(testf, "r") as id_file:
                test_ids = [line.strip() for line in id_file]
            hq_img_paths = [os.path.join(path_hq, f"{x}.jpg") for x in test_ids]
            lq_img_paths = [x.replace("expertC", "input") for x in hq_img_paths]
            assert len(hq_img_paths) == 498

        # NOTE(review): only lowercase "gopro" paths are size-checked here.
        if "gopro" in path_hq:
            assert len(hq_img_paths) == 1111

        if "LOL" in path_hq:
            assert len(hq_img_paths) == 15

        assert len(hq_img_paths) == len(lq_img_paths)

        deg_name  = get_deg_name(path_hq)
        deg_class = DEG_MAP[deg_name]

        valdts = RefDegImage(hq_img_paths = hq_img_paths,
                            lq_img_paths  = lq_img_paths,
                            val = True, name= dataset_name, deg_name=deg_name, deg_class=deg_class)

        datasets.append(valdts)

    assert len(datasets) == len(testsets)
    print(20*'****')

    return datasets

#options.py

# Command-line options (parsed with their defaults for notebook use)

In [15]:
import argparse

# Notebook stand-in for the script's command line: the options are declared
# as usual, but an empty argv is parsed so every option takes its default.
parser = argparse.ArgumentParser()

_option_specs = (
    ('--patch_size', dict(type=int, default=256, help='patchsize of input.')),
    ('--batch_size', dict(type=int, default=32, help='batch size of input.')),
    ('--num_workers', dict(type=int, default=8, help='number of workers.')),
    ("--checkpoint_dir", dict(type=str, default="/content/drive/MyDrive/FYPData/train_ckpt", help="Name of the Directory where the checkpoint is to be saved")),
    ('--lm', dict(type=str, default="/content/drive/MyDrive/FYPData/models/lm_instructir-7d.pt", help='Path to the language model weights')),
    ('--config', dict(type=str, default='configs/eval5d.yml', help='Path to config file')),
    ('--promptify', dict(type=str, default="simple_augment")),
    ('--debug', dict(action='store_true', help="Debug mode")),
    ('--save', dict(type=str, default='/content/drive/MyDrive/FYPData/performance_results/', help="Path to save the resultant images")),
)
for _flag, _spec in _option_specs:
    parser.add_argument(_flag, **_spec)
del _flag, _spec, _option_specs

args = parser.parse_args([])

#eval_instructir.py

In [17]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as TF

import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import yaml
import random
import gc
from datetime import datetime

# from utils import *
# from models import instructir

# from text.models import LanguageModel, LMHead

# Evaluation configuration, used in place of the YAML file named by
# args.config (the main cell converts it with dict2namespace).
config = dict(
    llm=dict(
        model="TaylorAI/bge-micro-v2",  # See Paper Sec. 3.2 and Appendix
        model_dim=384,
        embd_dim=256,
        nclasses=7,  # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
        weights=False,
    ),
    model=dict(
        arch="instructir",
        use_text=True,
        in_ch=3,
        out_ch=3,
        width=32,
        enc_blks=[2, 2, 4, 8],
        middle_blk_num=4,
        dec_blks=[2, 2, 2, 2],
        textdim=256,
        weights=False,
    ),
    test=dict(
        batch_size=1,
        num_workers=8,
        dn_datapath="/content/drive/MyDrive/FYPData/test-data/denoising_testsets/",
        dn_datasets=["CBSD68", "Kodak24"],
        dn_sigmas=[15, 25, 50],
        rain_targets=["/content/drive/MyDrive/FYPData/test-data/Rain100L/original/"],
        rain_inputs=["/content/drive/MyDrive/FYPData/test-data/Rain100L/rainy/"],
        haze_targets="/content/drive/MyDrive/FYPData/test-data/SOTS/GT/",
        haze_inputs="/content/drive/MyDrive/FYPData/test-data/SOTS/IN/",
        lol_targets="/content/drive/MyDrive/FYPData/test-data/LOL/high/",
        lol_inputs="/content/drive/MyDrive/FYPData/test-data/LOL/low/",
        gopro_targets="/content/drive/MyDrive/FYPData/test-data/GoPro/target/",
        gopro_inputs="/content/drive/MyDrive/FYPData/test-data/GoPro/input/",
    ),
)


def seed_everything(SEED=42, deterministic=True):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        SEED: seed value applied to every generator.
        deterministic: when True (default), configure cuDNN for reproducible
            runs (``benchmark=False``, ``deterministic=True``). When False,
            enable cuDNN autotuning (``benchmark=True``), which can be faster
            but lets the chosen kernels differ between runs.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Autotuning picks cuDNN algorithms by timing them, and that choice may
    # vary run to run, which would undermine the seeding above.
    torch.backends.cudnn.benchmark = not deterministic
    if deterministic:
        torch.backends.cudnn.deterministic = True


if __name__=="__main__":

    # The CUDA caching allocator reads this when it initialises, so it must be
    # set before anything touches the GPU. In a notebook it only takes effect
    # if no earlier cell has already used CUDA in this kernel.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"

    SEED=42
    seed_everything(SEED=SEED)
    torch.backends.cudnn.deterministic = True

    now = datetime.now()

    DEBUG      = args.debug

    IMAGE_MODEL_NAME = "/content/drive/MyDrive/FYPData/models/im_instructir-7d.pt"
    # Alternative checkpoint: "/content/drive/MyDrive/FYPData/models/instructir_with_weight_modulation_32chan.pt"
    CONFIG     = args.config
    LM_HEAD_MODEL   = args.lm
    SAVE_PATH  = args.save

    testing_message = f"{now.strftime('%Y-%m-%d %H:%M:%S')} Testing Original InstructIR Correct Human Instructions -Ablation study on training human instructions as input\n"

    print('CUDA GPU available: ', torch.cuda.is_available())

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    print('CUDA visible devices: ' + str(torch.cuda.device_count()))
    if torch.cuda.is_available():
        print('CUDA current device: ', torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device()))

    # Settings come from the in-notebook `config` dict rather than the YAML
    # file named by args.config.
    cfg = dict2namespace(config)

    use_offset = False

    testing_message += f"---- Weight Modification Enabled: {use_offset}\n"

    print(f"---- Weight Modification Enabled: {use_offset}\n")

    infor = f"Image Model Name: {IMAGE_MODEL_NAME}, Projection Head: {LM_HEAD_MODEL}\nDevice: {device}, Human Instruction Source: {args.promptify}\nConfig: {CONFIG}\n\n"

    testing_message += infor

    print(20*"****")
    print("EVALUATION")
    print(infor)
    print(20*"****")

    ################### TESTING DATASET
    # Each entry is [clean/reference dir, degraded dir]. A benchmark whose
    # config entries are missing or malformed is skipped with a message; other
    # errors (including KeyboardInterrupt) propagate.
    missing_cfg_errors = (AttributeError, KeyError, TypeError)

    # Denoising: one testset per (dataset, sigma) pair.
    denoise_testsets = []
    try:
        for testset in cfg.test.dn_datasets:
            for sigma in cfg.test.dn_sigmas:
                noisy_testpath = os.path.join(cfg.test.dn_datapath, testset + f"_{sigma}")
                clean_testpath = os.path.join(cfg.test.dn_datapath, testset)
                if not denoise_testsets:  # print only the first pair
                    print(f"clean_testpath:{clean_testpath}, noisy_testpath:{noisy_testpath} ")
                denoise_testsets.append([clean_testpath, noisy_testpath])
    except missing_cfg_errors as err:
        print(f"Skipping denoising testsets: {err!r}")
        denoise_testsets = []

    # RAIN
    rain_testsets = []
    try:
        for noisy_testpath, clean_testpath in zip(cfg.test.rain_inputs, cfg.test.rain_targets):
            if not rain_testsets:  # print only the first pair
                print(f"clean_testpath:{clean_testpath}, noisy_testpath:{noisy_testpath} ")
            rain_testsets.append([clean_testpath, noisy_testpath])
    except missing_cfg_errors as err:
        print(f"Skipping deraining testsets: {err!r}")
        rain_testsets = []

    # HAZE
    try:
        haze_testsets = [[cfg.test.haze_targets, cfg.test.haze_inputs]]
    except missing_cfg_errors as err:
        print(f"Skipping dehazing testsets: {err!r}")
        haze_testsets = []

    # Blur (GoPro), LOL and MIT5K benchmarks are currently not evaluated.
    TESTSETS = denoise_testsets + rain_testsets + haze_testsets

    if len(denoise_testsets) > 0:
        testing_message += f"Denoise testset length: {len(denoise_testsets)}\n"

    if len(rain_testsets) > 0:
        testing_message += f"Derain testset length: {len(rain_testsets)}\n"

    if len(haze_testsets) > 0:
        testing_message += f"Dehaze testset length: {len(haze_testsets)}\n"

    testset_len = f"Total testsets length: {len(TESTSETS)}\n"

    testing_message += testset_len

    print(testset_len)
    print(20 * "----")


    ################### RESTORATION MODEL

    print("Creating InstructIR")
    model = create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
                    middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks,
                    txtdim=cfg.model.textdim, include_offset=use_offset)

    ################### LOAD IMAGE MODEL

    assert IMAGE_MODEL_NAME, "Model weights required for evaluation"

    print("IMAGE MODEL CKPT:", IMAGE_MODEL_NAME)
    # strict=False tolerates architecture differences; report what did not
    # match so a partially loaded model does not go unnoticed.
    load_result = model.load_state_dict(torch.load(IMAGE_MODEL_NAME, map_location=device), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Checkpoint mismatch: {len(load_result.missing_keys)} missing keys, "
              f"{len(load_result.unexpected_keys)} unexpected keys")

    def _set_trainable(modules, flag):
        """Set requires_grad to ``flag`` on every parameter of each module."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    _set_trainable([model], True)
    model = model.to(device)
    total_params = count_params(model)

    # Prompt generator blocks and offset generators (the weight-modification
    # parts). Their parameter counts are measured by toggling requires_grad.
    modification_modules = [*model.promptBlocks, *model.offset_generators,
                            model.prompt_block_middle_blks, model.middleblock_offsetGen]

    # Count with the modification modules frozen -> base restoration model.
    _set_trainable(modification_modules, False)
    base_ir_model_params = count_params(model)

    # Count with only the modification modules trainable.
    _set_trainable([model], False)
    _set_trainable(modification_modules, True)
    params_of_modification = count_params(model)

    _set_trainable([model], True)


    model_params_message = f"""\nTotal params: {total_params/1e6}M,
    base image restoration model params: {base_ir_model_params/1e6}M\n
    OffsetGenerator & PromptGenBlock params: {params_of_modification/1e6}M\n"""

    testing_message += model_params_message

    print(model_params_message)
    ################### LANGUAGE MODEL

    if cfg.model.use_text:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Initialize the LanguageModel class
        LMODEL = cfg.llm.model
        language_model = LanguageModel(model=LMODEL).to(device)
        lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
        lm_head = lm_head.to(device)
        lm_nparams   = count_params(lm_head)

        print("LMHEAD MODEL CKPT:", LM_HEAD_MODEL)
        lm_head.load_state_dict(torch.load(LM_HEAD_MODEL, map_location=device), strict=True)
        print("Loaded weights!")

    else:
        LMODEL = None
        language_model = None
        lm_head = None
        lm_nparams = 0

    print(20 * "----")

    ################### TESTING !!

    if args.promptify == "simple_augment":
        promptify = augment_prompt
    elif args.promptify == "chatgpt":
        instruction_file_path = "/content/drive/MyDrive/FYPData/human_instructions.json"
        with open(instruction_file_path, "r") as f:
            prompts = json.load(f)
        for deg in prompts.keys():
            random.shuffle(prompts[deg])

        def promptify(deg):
            return random.choice(prompts[deg])

        print("--- Using ChatGPT generated human instructions\n")
        testing_message += "--- Using ChatGPT generated human instructions\n"
    else:
        # Any other value is used verbatim as the instruction for every image.
        def promptify(deg):
            return args.promptify

    test_datasets = create_testsets(TESTSETS, debug=True)

    test_model(model, language_model, lm_head, test_datasets, device, promptify, savepath=None, initial_message=testing_message)


CUDA GPU available:  True
CUDA visible devices: 1
CUDA current device:  0 NVIDIA A100-SXM4-40GB
---- Weight Modification Enabled: False

********************************************************************************
EVALUATION
Image Model Name: /content/drive/MyDrive/FYPData/models/im_instructir-7d.pt, Projection Head: /content/drive/MyDrive/FYPData/models/lm_instructir-7d.pt
Device: cuda, Human Instruction Source: simple_augment
Config: configs/eval5d.yml


********************************************************************************
clean_testpath:/content/drive/MyDrive/FYPData/test-data/denoising_testsets/CBSD68, noisy_testpath:/content/drive/MyDrive/FYPData/test-data/denoising_testsets/CBSD68_15 
clean_testpath:/content/drive/MyDrive/FYPData/test-data/Rain100L/original/, noisy_testpath:/content/drive/MyDrive/FYPData/test-data/Rain100L/rainy/ 
Total testsets length: 8

--------------------------------------------------------------------------------
Creating InstructIR
IMAGE MOD

KeyboardInterrupt: 