# New section

# Load the utils files:

In [1]:
# Mount Google Drive in this Colab runtime; the project data is presumably
# kept under /content/drive/MyDrive/FYPData (see the optional copy commands).
# !cp -r /content/drive/MyDrive/FYPData/ /content/
from google.colab import drive
drive.mount('/content/drive')



# Optional (uncomment to use): copy the zipped training set from Drive to
# /content and extract it there.
# !cp /content/drive/MyDrive/FYPData/Train.zip /content/
# !unzip /content/Train.zip -d /content/


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# NoiseDegradation.py


In [2]:
import torch
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor

import random
import numpy as np

class NoiseDegradation(object):
    """Synthesize noisy training patches by adding Gaussian noise.

    Degradation ids match the noise entries of ``DEG_MAP``:
    0 -> sigma 15, 7 -> sigma 25, 8 -> sigma 50 (sigma in 0-255 pixel units).
    """

    # Degradation id -> standard deviation of the added Gaussian noise.
    _SIGMA_BY_TYPE = {0: 15, 7: 25, 8: 50}

    def __init__(self, args):
        super(NoiseDegradation, self).__init__()
        self.args = args
        # toTensor / crop_transform are kept as public attributes; the methods
        # of this class do not use them.
        self.toTensor = ToTensor()
        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

    def _add_gaussian_noise(self, clean_patch, sigma):
        """Return ``(noisy_patch, clean_patch)`` with N(0, sigma^2) noise added.

        The noisy result is clipped to [0, 255] and cast to uint8, so
        ``clean_patch`` is expected to hold 0-255 pixel values.
        """
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
        return noisy_patch, clean_patch

    def _add_noise_degradation_by_level(self, clean_patch, degrade_type):
        """Apply the noise level selected by ``degrade_type``.

        Returns ``(degraded_patch, clean_patch)``. For an unknown
        ``degrade_type`` no noise is added and the degraded patch is the
        clean patch itself.
        """
        for level, sigma in self._SIGMA_BY_TYPE.items():
            if degrade_type == level:
                return self._add_gaussian_noise(clean_patch, sigma=sigma)
        return clean_patch, clean_patch

    def add_noise_degradation(self, clean_patch, degrade_type=None):
        """Return a noisy copy of ``clean_patch``.

        Args:
            clean_patch: numpy image patch with 0-255 pixel values.
            degrade_type: 0, 7 or 8 selects sigma 15 / 25 / 50. Any other value
                (including None) picks one of the three levels at random.

        Returns:
            uint8 numpy array with the same shape as ``clean_patch``.
        """
        if degrade_type not in tuple(self._SIGMA_BY_TYPE):
            # random.choice returns a single id. random.choices would return a
            # one-element list that matches no level, so no noise would be added.
            degrade_type = random.choice(list(self._SIGMA_BY_TYPE))
        degraded_patch, _ = self._add_noise_degradation_by_level(clean_patch, degrade_type)
        return degraded_patch


# utils.py


In [3]:
import imageio
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import os
import argparse
import random
import torch


def seed_everything(SEED=42, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        SEED: seed value shared by all RNGs.
        deterministic: when False (default, the historical behaviour) cuDNN
            autotuning is enabled. That is faster, but GPU results can differ
            between runs. When True, cuDNN is switched to deterministic kernels
            and autotuning is disabled so runs can be reproduced.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # manual_seed_all also covers the current device, so a separate
    # torch.cuda.manual_seed call is unnecessary.
    torch.cuda.manual_seed_all(SEED)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

# def saveImage(filename, image):
#     imageTMP = np.clip(image * 255.0, 0, 255).astype('uint8')
#     imageio.imwrite(filename, imageTMP)

def save_rgb(img, filename):
    """Write an image with values in [0, 1] to ``filename`` as an 8-bit file.

    Values are clipped to [0, 1], scaled to 0-255, rounded and stored as
    uint8, so the written pixels are 8-bit for every output format. 3-channel
    input is treated as RGB and converted to OpenCV's BGR order. 2-D or
    single-channel input is written unchanged.

    Returns:
        bool: the success flag returned by ``cv2.imwrite``.
    """
    img = np.clip(np.asarray(img, dtype=np.float32), 0.0, 1.0)
    img = np.rint(img * 255.0).astype(np.uint8)
    if img.ndim == 3 and img.shape[-1] != 1:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return cv2.imwrite(filename, img)


def load_img(filename, norm=True, size=(320, 480)):
    """Load an image file as an RGB numpy array.

    Args:
        filename: path of the image to read.
        norm: if True, scale to [0, 1] and return float32. Otherwise return
            the uint8 pixel array.
        size: ``(width, height)`` passed to ``PIL.Image.resize``. The default
            keeps the fixed 320x480 resize used so far. Pass ``None`` to keep
            the original resolution.

    Returns:
        np.ndarray of shape (height, width, 3).
    """
    # The context manager closes the file handle opened by Image.open.
    with Image.open(filename) as source:
        rgb = source.convert("RGB")
    if size is not None:
        rgb = rgb.resize(tuple(size))
    img = np.array(rgb)

    if norm:
        img = (img / 255.).astype(np.float32)
    return img

def plot_all(images, figsize=(20,10), axis='off', names=None):
    """Show ``images`` side by side in one row.

    Args:
        images: sequence of images accepted by ``Axes.imshow``.
        figsize: matplotlib figure size in inches.
        axis: value passed to ``Axes.axis`` for every panel (e.g. 'off').
        names: optional per-image titles, indexed like ``images``.
    """
    nplots = len(images)
    if nplots == 0:
        # plt.subplots rejects zero columns; nothing to draw.
        return
    # squeeze=False always returns a 2-D array of Axes. Without it a single
    # image yields a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, nplots, figsize=figsize, dpi=80,
                            constrained_layout=True, squeeze=False)
    for i, (ax, image) in enumerate(zip(axs[0], images)):
        ax.imshow(image)
        if names:
            ax.set_title(names[i])
        ax.axis(axis)
    plt.show()

def modcrop(img_in, scale=2):
    """Crop a copy of ``img_in`` so its height and width are multiples of ``scale``.

    Accepts HW or HWC numpy arrays. The excess rows/columns are dropped from
    the bottom/right. Any other rank raises ValueError.
    """
    cropped = np.copy(img_in)
    if cropped.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(cropped.ndim))
    height, width = cropped.shape[:2]
    keep_h = height - height % scale
    keep_w = width - width % scale
    return cropped[:keep_h, :keep_w, ...]

def dict2namespace(config):
    """Recursively convert a (nested) dict into an ``argparse.Namespace``.

    Dict values become nested namespaces. Every other value is attached as-is.
    """
    return argparse.Namespace(**{
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    })


########## MODEL

def count_params(model):
    """Return how many parameter elements of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# TrainDataset.py

In [4]:

from torch.utils.data import Dataset

import torchvision.transforms.functional as TF
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
from PIL import Image
import numpy as np
import os
import random

# from utils import load_img, modcrop
# from torchvision import transforms
# from training_utils.NoiseDegradation import NoiseDegradation



# Degradation name -> integer id. The ids match the "de_type" values that
# InstructIRTrainDataset stores per sample (0/7/8 noise, 2 rain, 3 haze) and
# that NoiseDegradation maps to sigma 15/25/50.
DEG_MAP = {
    "noise_15" : 0,
    "blur"     : 1,
    "rain"     : 2,
    "haze"     : 3,
    "lol"      : 4,
    "sr"       : 5,
    "en"       : 6,
    "noise_25" : 7,
    "noise_50" : 8
}

# Degradation family -> task name. All noise levels share the single
# "noise" key.
DEG2TASK = {
    "noise": "denoising",
    "blur" : "deblurring",
    "rain" : "deraining",
    "haze" : "dehazing",
    "lol"  : "lol",
    "sr"   : "sr",
    "en"   : "enhancement"
}

def crop_img(image, base=16):
    """
    Mod crop the image to ensure the dimension is divisible by base. Also done by SwinIR, Restormer and others.

    The excess rows/columns are removed symmetrically (centre crop). Works for
    (H, W) as well as (H, W, C) arrays; any trailing axes are kept untouched.
    """
    h, w = image.shape[:2]
    crop_h = h % base
    crop_w = w % base
    top, left = crop_h // 2, crop_w // 2
    return image[top:h - crop_h + top, left:w - crop_w + left, ...]

def data_augmentation(image, mode):
    """Apply one of the 8 rotation/flip transforms to ``image``.

    Args:
        image: array-like (numpy array or CPU tensor) whose first two axes are
            spatial.
        mode: 0 identity, 1 flip up/down, 2 rot90, 3 rot90 + flip,
            4 rot180, 5 rot180 + flip, 6 rot270, 7 rot270 + flip
            (counter-clockwise rotations, as in ``np.rot90``).

    Returns:
        np.ndarray, possibly a view of ``image``.

    Raises:
        ValueError: for a mode outside 0-7 (a subclass of ``Exception``, so
            existing ``except Exception`` handlers still catch it).
    """
    if mode not in range(8):
        raise ValueError('Invalid choice of image transformation')
    # Even modes are pure rotations by 90 * (mode // 2) degrees. Odd modes add
    # a vertical flip after that rotation.
    quarter_turns, flip = divmod(int(mode), 2)
    # np.asarray works for numpy arrays and CPU tensors alike (mode 0 used
    # to call image.numpy(), which fails on numpy input).
    out = np.rot90(np.asarray(image), k=quarter_turns)
    if flip:
        out = np.flipud(out)
    return out

def random_augmentation(*args):
    """Apply the same randomly drawn transform (modes 1-7, never identity) to every input.

    Returns a list with one independent copy per input, in argument order.
    """
    mode = random.randint(1, 7)
    return [data_augmentation(item, mode).copy() for item in args]




################# DATASETS


class InstructIRTrainDataset(Dataset):
    """
    Dataset for Image Restoration having low-quality image and the reference image.
    Implemented tasks: synthetic denoising (sigma 15/25/50), deraining and dehazing.

    ``args`` must provide:
        de_type: collection of task tags among 'denoise_15', 'denoise_25',
            'denoise_50', 'derain', 'dehaze'.
        patch_size: side of the square training patches.
        data_file_dir: directory holding the image list files.
        denoise_dir / derain_dir / dehaze_dir: image roots for the enabled
            tasks. They are concatenated with the listed names as plain
            strings, so they should end with a path separator.

    Each item is ``([clean_name, deg_id], degraded_tensor, clean_tensor)``.
    """

    # Task tags that enable synthetic denoising samples (de_type 0 / 7 / 8).
    _NOISE_TAGS = ("denoise_15", "denoise_25", "denoise_50")

    def __init__(self, args):

        super(InstructIRTrainDataset, self).__init__()

        self.toTensor = ToTensor()
        self.noise_gradation_generator = NoiseDegradation(args)
        self.args = args
        self.de_type = args.de_type
        print(self.de_type)

        self._init_ids()
        self._merge_ids()

        # Random crop applied to clean images before synthetic noise is added.
        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size)
        ])

    def _read_list(self, filename):
        """Return the stripped, non-empty lines of ``filename`` in ``args.data_file_dir``."""
        # os.path.join works whether or not data_file_dir ends with a separator.
        path = os.path.join(self.args.data_file_dir, filename)
        with open(path) as handle:
            return [line.strip() for line in handle if line.strip()]

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as RGB and mod-crop it to a multiple of 16; the file is closed afterwards."""
        with Image.open(path) as image:
            return crop_img(np.array(image.convert('RGB')), base=16)

    def _crop_patch(self, img_1, img_2):
        """Crop the same random ``patch_size`` x ``patch_size`` window from both images.

        ``random.randint`` raises ValueError if an image is smaller than the patch.
        """
        size = self.args.patch_size
        ind_H = random.randint(0, img_1.shape[0] - size)
        ind_W = random.randint(0, img_1.shape[1] - size)

        patch_1 = img_1[ind_H:ind_H + size, ind_W:ind_W + size]
        patch_2 = img_2[ind_H:ind_H + size, ind_W:ind_W + size]

        return patch_1, patch_2

    def _get_original_rain_name(self, rainy_name):
        """Map '<root>rainy/rain-<id>' to its ground truth '<root>original/norain-<id>'."""
        og_name = rainy_name.split("rainy")[0] + 'original/norain-' + rainy_name.split('rain-')[-1]
        return og_name

    def _init_clean_image_for_noise_degradation(self):
        """Build one shuffled sample list per enabled noise level.

        Only files of ``args.denoise_dir`` whose names appear in
        ``clean_image_for_denoise.txt`` are used. The lists are ``s15_ids`` /
        ``s25_ids`` / ``s50_ids`` with de_type 0 / 7 / 8. A level that is not
        enabled keeps an empty list.
        """
        listed = set(self._read_list("clean_image_for_denoise.txt"))
        clean_ids = [self.args.denoise_dir + name
                     for name in os.listdir(self.args.denoise_dir)
                     if name.strip() in listed]

        self.s15_ids = []
        self.s25_ids = []
        self.s50_ids = []

        if 'denoise_15' in self.de_type:
            self.s15_ids = [{"clean_id": x, "de_type": 0} for x in clean_ids]
            random.shuffle(self.s15_ids)
            self.s15_counter = 0
        if 'denoise_25' in self.de_type:
            self.s25_ids = [{"clean_id": x, "de_type": 7} for x in clean_ids]
            random.shuffle(self.s25_ids)
            self.s25_counter = 0
        if 'denoise_50' in self.de_type:
            self.s50_ids = [{"clean_id": x, "de_type": 8} for x in clean_ids]
            random.shuffle(self.s50_ids)
            self.s50_counter = 0

        print(f"Noisy Sigma 15 images len: {len(self.s15_ids)}")
        print(f"Noisy Sigma 25 images len: {len(self.s25_ids)}\n")
        print(f"Noisy Sigma 50 images len: {len(self.s50_ids)}\n")

    def _init_rs_ids(self):
        """Collect the rainy images listed in ``rainy.txt`` (names relative to ``args.derain_dir``)."""
        temp_ids = [self.args.derain_dir + name for name in self._read_list("rainy.txt")]
        self.rs_ids = [{"clean_id": x, "de_type": 2} for x in temp_ids]

        self.rl_counter = 0
        self.num_rl = len(self.rs_ids)
        print("Total Rainy Images : {}".format(self.num_rl))

    def _init_hazy_ids(self):
        """Collect the hazy images listed in ``hazy_outside.txt`` (names relative to ``args.dehaze_dir``)."""
        temp_ids = [self.args.dehaze_dir + name for name in self._read_list("hazy_outside.txt")]
        self.hazy_ids = [{"clean_id": x, "de_type": 3} for x in temp_ids]

        self.hazy_counter = 0

        self.num_hazy = len(self.hazy_ids)
        print("Total Hazy Images : {}".format(self.num_hazy))

    def _get_nonhazy_name(self, hazy_name):
        """Map '<root>synthetic/.../<id>_<params>.<ext>' to '<root>original/<id>.<ext>'."""
        dir_name = hazy_name.split("synthetic")[0] + 'original/'
        name = hazy_name.split('/')[-1].split('_')[0]
        suffix = '.' + hazy_name.split('.')[-1]
        nonhazy_name = dir_name + name + suffix
        return nonhazy_name

    def __len__(self):
        return len(self.dataset_ids)

    def _init_ids(self):
        """Load the sample lists of every task enabled in ``de_type``."""
        # de_type is only used for membership tests, so it is not shuffled.
        # Shuffling would mutate the caller's args.de_type and fails for tuples.
        if any(tag in self.de_type for tag in self._NOISE_TAGS):
            self._init_clean_image_for_noise_degradation()
        if 'derain' in self.de_type:
            self._init_rs_ids()
        if 'dehaze' in self.de_type:
            self._init_hazy_ids()

    def _merge_ids(self):
        """Concatenate the sample lists of the enabled tasks into ``dataset_ids``."""
        self.dataset_ids = []
        # Each noise level is added on its own tag, so e.g. enabling only
        # 'denoise_25' still contributes its samples.
        if "denoise_15" in self.de_type:
            self.dataset_ids += self.s15_ids
        if "denoise_25" in self.de_type:
            self.dataset_ids += self.s25_ids
        if "denoise_50" in self.de_type:
            self.dataset_ids += self.s50_ids
        if "derain" in self.de_type:
            self.dataset_ids += self.rs_ids
        if "dehaze" in self.de_type:
            self.dataset_ids += self.hazy_ids

        print(f"Dataset_ids length: {len(self.dataset_ids)}")

    def __getitem__(self, idx):
        """Return ``([clean_name, deg_id], degraded_patch, clean_patch)`` as tensors.

        Raises:
            ValueError: if the sample's de_type is not 0, 2, 3, 7 or 8.
        """
        dataset = self.dataset_ids[idx]
        hq_path = dataset["clean_id"]
        deg_id = dataset["de_type"]

        if deg_id in (0, 7, 8):
            # Synthetic denoising: crop and augment the clean image, then add noise.
            clean_img = self._load_rgb(hq_path)
            clean_patch = np.array(self.crop_transform(clean_img))
            clean_name = hq_path.split("/")[-1].split('.')[0]
            clean_patch = random_augmentation(clean_patch)[0]
            degrad_patch = self.noise_gradation_generator.add_noise_degradation(clean_patch, deg_id)
        elif deg_id in (2, 3):
            # 2: rain streak removal, 3: dehazing (SOTS outdoor training set).
            # The stored path is the degraded image; the clean one is derived from it.
            degrad_img = self._load_rgb(hq_path)
            if deg_id == 2:
                clean_name = self._get_original_rain_name(hq_path)
            else:
                clean_name = self._get_nonhazy_name(hq_path)
            clean_img = self._load_rgb(clean_name)
            degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
        else:
            raise ValueError(f"Unsupported de_type {deg_id!r} for sample {idx}")

        clean_patch = self.toTensor(clean_patch)
        degrad_patch = self.toTensor(degrad_patch)

        return [clean_name, deg_id], degrad_patch, clean_patch

# Scheduler.py


In [5]:
import math
from torch.optim.lr_scheduler import _LRScheduler
import warnings
from typing import List
from torch.optim import Optimizer

# ------------------------------------------------------------------------
# Copyright (c) 2023 va1shn9v. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/va1shn9v/PromptIR

'''
@inproceedings{potlapalli2023promptir,
  title={PromptIR: Prompting for All-in-One Image Restoration},
  author={Potlapalli, Vaishnav and Zamir, Syed Waqas and Khan, Salman and Khan, Fahad},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}
'''
class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
    and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.
    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
        It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
        :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
        epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
        train and validation methods.
    .. note::
        NOTE(review): edge cases visible in the code below. With
        ``max_epochs == warmup_epochs`` the cosine phase of ``get_lr`` takes a
        modulo / division by zero. With ``warmup_epochs == 1``,
        ``_get_closed_form_lr`` divides by zero at epoch 0.
    Example:
        >>> layer = nn.Linear(10, 1)
        >>> optimizer = Adam(layer.parameters(), lr=0.02)
        >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
        >>> #
        >>> # the default case
        >>> for epoch in range(40):
        ...     # train(...)
        ...     # validate(...)
        ...     scheduler.step()
        >>> #
        >>> # passing epoch param case
        >>> for epoch in range(40):
        ...     scheduler.step(epoch)
        ...     # train(...)
        ...     # validate(...)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        # The attributes above must exist before the base constructor runs,
        # since it performs an initial step that calls get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler.

        Most branches derive the new lr from the current ``group["lr"]``, so the
        result is only correct when ``step()`` is called once per step, in order.
        """
        # Warn when called directly rather than from step(); the flag is
        # presumably maintained by the base scheduler's step().
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
                UserWarning,
            )

        # First step: start of the warmup.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Linear warmup: add a constant increment each step so the lr reaches
        # base_lr at last_epoch == warmup_epochs - 1.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # End of warmup: use base_lr exactly.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # NOTE(review): appears to mirror the restart branch of PyTorch's
        # CosineAnnealingLR chainable form. It fires when last_epoch equals
        # max_epochs + 1 plus a multiple of 2 * (max_epochs - warmup_epochs),
        # i.e. only after training has run past max_epochs.
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Cosine phase: rescale the distance to eta_min by the ratio of
        # consecutive (1 + cos) terms. This is the recursive form of the
        # closed-form cosine used in _get_closed_form_lr.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler.

        Warmup interpolates linearly from warmup_start_lr towards base_lr. After
        warmup: eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * progress)),
        with progress = (last_epoch - warmup_epochs) / (max_epochs - warmup_epochs).
        """
        # NOTE(review): divides by (warmup_epochs - 1), which is zero when
        # warmup_epochs == 1 and last_epoch == 0.
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]


# OffsetGenerator.py


In [6]:
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn  import functional as F
import torch

class OffsetGenerator(nn.Module):
    """Map a prompt feature map to a flat offset vector bounded to [-1, 1].

    Pipeline: 3x3 conv -> global average pool -> linear -> tanh.
    """

    def __init__(self, in_prompt_dim, out_conv_shapes, hidden_channels=32):
        """
        in_prompt_dim: The number of channels in the incoming prompt (e.g. prompt_dim).
        out_conv_shapes: Length of the produced offset vector, i.e. how many
                         weights are offset (e.g. #channels_out * #channels_in * kernel_dim^2).
        hidden_channels: Output channels of the 3x3 refinement conv (default 32,
                         the previously hard-coded value).
        """
        super().__init__()

        # refine local patterns in the prompt
        self.conv = nn.Conv2d(in_prompt_dim, hidden_channels, kernel_size=3, padding=1, stride=1)

        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.output_conv_size = out_conv_shapes

        self.flatten_layer = nn.Linear(in_features=hidden_channels, out_features=self.output_conv_size)

    def forward(self, prompt):
        """
        prompt: tensor of shape (B, in_prompt_dim, H, W), e.g. the prompt-generation output.
        returns: tensor of shape (B, out_conv_shapes) with values in [-1, 1].
        """
        features = self.global_average_pool(self.conv(prompt))
        # (B, C, 1, 1) -> (B, C). flatten, unlike view, also accepts
        # non-contiguous tensors.
        features = torch.flatten(features, start_dim=1)
        return torch.tanh(self.flatten_layer(features))



# nafnet.py


In [7]:
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
  title={Simple Baselines for Image Restoration},
  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
  journal={arXiv preprint arXiv:2204.04676},
  year={2022}
}
'''

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
# from torch.nn.modules.batchnorm import _BatchNorm
# from models.nafnet_utils import Local_Base, LayerNorm2d
# from models.cbin_weight import CBINorm_Conv2d

# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LayerNormFunction(torch.autograd.Function):
    """Channel-wise layer norm for NCHW tensors with a hand-written backward.

    Each spatial position is normalised over the channel axis (dim 1) and then
    scaled/shifted by the per-channel ``weight`` / ``bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        # Unpacking enforces a 4-D (N, C, H, W) input.
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # The backward pass only needs the normalised output, the variance and
        # the weight.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        _, C, _, _ = grad_output.size()
        # saved_tensors replaces the deprecated saved_variables alias.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        grad_weight = (grad_output * y).sum(dim=(0, 2, 3))
        grad_bias = grad_output.sum(dim=(0, 2, 3))
        # eps is a plain float hyper-parameter and receives no gradient.
        return gx, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """Layer normalisation over the channel axis of NCHW feature maps.

    Holds a learnable per-channel scale (``weight``, initialised to 1) and
    shift (``bias``, initialised to 0). The computation is delegated to
    ``LayerNormFunction``.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        normalize = LayerNormFunction.apply
        return normalize(x, self.weight, self.bias, self.eps)



class AvgPool2d(nn.Module):
    """Local (sliding-window) average pooling.

    ``replace_layers`` substitutes it for ``nn.AdaptiveAvgPool2d(1)`` (TLSC,
    see the reference below). When ``kernel_size`` is None it is derived
    lazily on the first forward call from ``base_size`` scaled by the ratio of
    the input size to ``train_size``. With ``auto_pad`` the output is
    replicate-padded back to the input's spatial size.
    """
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        # rs: candidate subsampling factors, tried from largest to smallest.
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # NOTE(review): the reported stride is kernel_size, but the exact path
        # in forward computes a stride-1 sliding mean. The label looks misleading.
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        # Lazy kernel sizing, done once: kernel = base_size * input_size // train_size.
        # NOTE(review): if both kernel_size and base_size are None,
        # self.kernel_size[0] below raises TypeError.
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # Window covers the whole input: fall back to global average pooling
        # (N, C, 1, 1). auto_pad is not applied on this path.
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            # NOTE(review): unreachable - the identical condition already
            # returned above.
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Largest factor in rs dividing h / w (1 always qualifies).
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # Summed-area table of the subsampled input.
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                # NOTE(review): k1 / k2 become 0 when the kernel is smaller
                # than r1 / r2 or the subsampled size is 1, giving empty
                # slices and a division by zero.
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                # Upsample back by the subsampling factors.
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            # Exact path: summed-area table with a leading zero row/column,
            # box sums by inclusion-exclusion, divided by the window area.
            # The result is the mean over every valid k1 x k2 window (stride 1).
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            # Replicate-pad (centred) so the output matches the input's H x W.
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every ``nn.AdaptiveAvgPool2d`` child of ``model`` for an ``AvgPool2d``.

    Compound children are visited first (depth-first). Each replaced pool must
    have ``output_size == 1``; otherwise an AssertionError is raised.
    """
    for child_name, child in model.named_children():
        if any(True for _ in child.children()):
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue
        local_pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        assert child.output_size == 1
        setattr(model, child_name, local_pool)


'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class Local_Base():
    """Mixin that converts a model to local (TLSC) average pooling.

    ``convert`` replaces every ``nn.AdaptiveAvgPool2d`` through
    ``replace_layers``. It then runs one no-grad forward pass on a random
    tensor of shape ``train_size`` so that layers sized lazily on their first
    call (such as ``AvgPool2d``) get initialised.
    """
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        with torch.no_grad():
            self.forward(torch.rand(train_size))


class SimpleGate(nn.Module):
    """NAFNet gate: split the channel axis into two halves and multiply them element-wise."""
    def forward(self, x):
        gate_input, gate = torch.chunk(x, chunks=2, dim=1)
        return gate_input * gate

class NAFBlock(nn.Module):
    """NAFNet block: a gated depth-wise conv branch followed by a gated FFN branch.

    Both branches are residual and scaled by the learnable ``beta`` / ``gamma``,
    which are initialised to zero. ``forward`` optionally takes a per-sample
    ``offset_vector`` that is added to the weights of the final 1x1 conv
    (``conv5``).

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion of the depth-wise branch.
        FFN_Expand: channel expansion of the feed-forward branch.
        drop_out_rate: dropout after each branch (disabled when 0).
        modify_conv_weights: accepted for API compatibility; currently unused.
    """
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0., modify_conv_weights=False):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp, offset_vector=None):
        """
        inp: (B, c, H, W) feature map.
        offset_vector: optional tensor with B rows and conv5.weight.numel()
            elements per row. Row i is added to conv5's weights for sample i.
        returns: (B, c, H, W).

        Raises ValueError if offset_vector's batch size differs from inp's.
        """
        x = self.norm1(inp)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        if offset_vector is None:
            x = self.conv5(x)
        else:
            x = self._modulated_conv5(x, offset_vector)
        x = self.dropout2(x)

        return y + x * self.gamma

    def _modulated_conv5(self, x, offset_vector):
        """Apply conv5 with per-sample weights ``conv5.weight + offset_vector[i]``.

        The batch is folded into the channel axis, so a single grouped
        convolution (groups = batch * conv5.groups) applies every sample's own
        kernel without a Python loop over the batch.
        """
        conv = self.conv5
        batch = x.shape[0]
        if offset_vector.shape[0] != batch:
            raise ValueError(
                f"offset_vector batch ({offset_vector.shape[0]}) must match input batch ({batch})"
            )
        if batch == 0:
            return conv(x)
        # (B, out, in/groups, kh, kw): one modulated kernel per sample.
        weight = conv.weight.unsqueeze(0) + offset_vector.reshape(batch, *conv.weight.shape)
        weight = weight.reshape(batch * conv.out_channels, *conv.weight.shape[1:])
        bias = conv.bias.repeat(batch) if conv.bias is not None else None
        out = F.conv2d(
            x.reshape(1, batch * conv.in_channels, *x.shape[2:]),
            weight,
            bias=bias,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=batch * conv.groups,
        )
        return out.reshape(batch, conv.out_channels, *out.shape[2:])


class NAFNet(nn.Module):
    """U-shaped NAFNet: intro conv, encoder levels, middle blocks, decoder levels, ending conv.

    Args:
        img_channel: input/output image channels.
        width: channels after the intro conv. Each encoder level doubles them.
        middle_blk_num: number of NAFBlocks at the bottleneck.
        enc_blk_nums: NAFBlocks per encoder level (iterable of ints).
        dec_blk_nums: NAFBlocks per decoder level (iterable of ints).

    The network predicts a residual that is added to the (padded) input. The
    output is cropped back to the input's H x W.
    """

    # Tuple defaults avoid the shared-mutable-default pitfall. The sequences
    # are only iterated, so any iterable of ints (e.g. a list) still works.
    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=()):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Registered here (and replaced below) so the submodule order, and
        # therefore state_dict / parameters() order, stays stable.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Strided 2x2 conv halves the resolution and doubles the channels.
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            # 1x1 conv + PixelShuffle(2) doubles the resolution and halves the channels.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        # Inputs are padded to a multiple of 2**levels so every down/up step divides evenly.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Decoder levels consume the encoder skips from deepest to shallowest.
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad H and W (bottom/right) up to the next multiple of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

class NAFNetLocal(Local_Base, NAFNet):
    """NAFNet variant prepared for inference via ``self.convert``.

    Both parents are initialised explicitly. The network is then put in eval
    mode, and ``self.convert`` (presumably provided by ``Local_Base``, defined
    elsewhere) is called under ``torch.no_grad()``. The call uses a base size
    1.5x the spatial size of ``train_size``.
    """

    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        _, _, train_h, train_w = train_size
        base_size = tuple(int(side * 1.5) for side in (train_h, train_w))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


def create_nafnet(input_channels = 3, width = 32, enc_blks = (2, 2, 4, 8), middle_blk_num = 12, dec_blks = (2, 2, 2, 2)):
    """
    Create a NAFNet model.

    The defaults follow the SIDD width-32 configuration:
    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml

    Args:
        input_channels: number of image channels (passed as ``img_channel``).
        width: channel width of the first encoder stage.
        enc_blks: number of NAFBlocks per encoder stage (one entry per stage).
        middle_blk_num: number of NAFBlocks in the bottleneck.
        dec_blks: number of NAFBlocks per decoder stage.

    Returns:
        NAFNet: the freshly constructed network (no weights are loaded here).
    """
    # Defaults are immutable tuples. A fresh list is passed on, so neither the
    # defaults nor the caller's sequences are ever shared with the network.
    net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=list(enc_blks), dec_blk_nums=list(dec_blks))
    return net

# InstructIR.py

# InstructIR model definition

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
# from models.OffsetGenerator import OffsetGenerator
# from models.nafnet import NAFBlock

# ------------------------------------------------------------------------
# Copyright (c) 2023 va1shn9v. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/va1shn9v/PromptIR

'''
@inproceedings{potlapalli2023promptir,
  title={PromptIR: Prompting for All-in-One Image Restoration},
  author={Potlapalli, Vaishnav and Zamir, Syed Waqas and Khan, Salman and Khan, Fahad},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}
'''
##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """PromptIR prompt-generation block.

    Holds ``prompt_len`` learnable prompt components, each of shape
    (prompt_dim, prompt_size, prompt_size). For an input feature map it:
      1. predicts one softmax weight per component from the spatially averaged features;
      2. takes the weighted sum of the components;
      3. resizes the result bilinearly to the input's (H, W);
      4. refines it with a shape-preserving 3x3 convolution.

    Args:
        prompt_dim: channels of each prompt component (and of the output).
        prompt_len: number of prompt components.
        prompt_size: spatial side length of each (square) component.
        lin_dim: input size of the weight-predicting linear layer; must equal
            the channel count C of the features given to ``forward``.
    """

    def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        self.prompt_dim= prompt_dim
        # Learnable components: (1, prompt_len, prompt_dim, prompt_size, prompt_size).
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        # Maps the (B, lin_dim) global descriptor to one logit per component.
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        # Stride 1 / padding 1 keep the spatial size unchanged.
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)

    def forward(self,x):
        """
        Args:
            x: feature map of shape (B, C, H, W) with C == lin_dim.
        Returns:
            Tensor of shape (B, prompt_dim, H, W): a per-sample mixture of the
            learned prompt components.
        """
        _, _, H, W = x.shape
        # Global average pool over (H, W) -> (B, C).
        emb = x.mean(dim=(-2, -1))
        # Per-sample importance weights over the components, summing to 1: (B, prompt_len).
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        # Weighted sum over components -> (B, prompt_dim, prompt_size, prompt_size).
        # einsum contracts the component axis directly. The previous repeat()
        # formulation first allocated B copies of prompt_param and then a second
        # (B, L, D, S, S) product tensor, both also kept for autograd.
        prompt = torch.einsum("bl,ldhw->bdhw", prompt_weights, self.prompt_param[0])
        # Resize to the feature map's spatial size -> (B, prompt_dim, H, W).
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        return self.conv3x3(prompt)
        # If multiple degradations exist in the input image feature, the returned prompt will encode a mixture of the degradations it has learned during training





##########################################################################

class ICB(nn.Module):
    """
    Instruction Condition Block (ICB), Paper Section 3.3.

    Computes ``Block((x * gamma + beta) * sigmoid(W_c e)) + x``, where:
      - ``e`` is the instruction (text) embedding;
      - ``W_c`` is ``self.fc``;
      - ``Block`` is a NAFBlock.

    ``gamma`` and ``beta`` are zero-initialised, so the block's input is all
    zeros at initialisation.
    """

    def __init__(self, feature_dim, text_dim=768):
        super().__init__()
        # W_c: text embedding -> one gate logit per feature channel.
        self.fc = nn.Linear(text_dim, feature_dim)
        # Feature-enhancement block applied to the gated features.
        self.block = NAFBlock(feature_dim)
        # Learned per-channel shift (beta) and scale (gamma).
        channel_shape = (1, feature_dim, 1, 1)
        self.beta = nn.Parameter(torch.zeros(channel_shape))
        self.gamma = nn.Parameter(torch.zeros(channel_shape))

    def forward(self, x, text_embedding):
        # m_c = sigmoid(W_c e), with two trailing singleton axes to broadcast over H, W.
        gate = torch.sigmoid(self.fc(text_embedding))[..., None, None]
        # Modulate, route by the text gate, enhance, then add the skip connection.
        modulated = (x * self.gamma + self.beta) * gate
        return self.block(modulated) + x


class InstructIR(nn.Module):
    """
    InstructIR model using NAFNet (ECCV 2022) as backbone.
    The model takes as input an RGB image and a text embedding (encoded instruction).
    Described in Paper Section 3.3

    This variant adds PromptIR-style PromptGenBlocks and OffsetGenerators.
    When ``include_offset`` is truthy, a prompt is generated from the features
    and turned into an offset vector by an OffsetGenerator (defined elsewhere).
    That vector is passed as a second argument to every NAFBlock of the middle
    stage and of the first (up to) three decoder stages. NAFBlock must
    therefore accept ``(x, offset_vector)``.

    Args:
        img_channel: number of input/output image channels.
        width: channel width of the first encoder stage (doubled per stage).
        middle_blk_num: number of NAFBlocks in the bottleneck (must be >= 1,
            since ``self.middle_blks[0]`` is inspected).
        enc_blk_nums: NAFBlocks per encoder stage (one entry per stage).
            Only iterated, never mutated.
        dec_blk_nums: NAFBlocks per decoder stage. Only iterated, never mutated.
        txtdim: size of the text embedding fed to every ICB.
        include_offset: enable the prompt/offset path in ``forward``.

    NOTE(review): the decoder prompt blocks have width*8, width*4 and width*2
    channels, for decoder stages 0, 1 and 2. This matches the decoder channel
    counts only when ``enc_blk_nums`` has 4 entries, as in create_model's
    defaults. Confirm before using other depths.

    NOTE: do not reorder the submodule assignments in ``__init__``. Registration
    order fixes the order of ``parameters()``, and saved optimizer state_dicts
    rely on that order.
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], txtdim=768, include_offset=False):
        super().__init__()

        # 3x3 convs mapping image space <-> feature space (img_channel <-> width).
        self.intro  = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders    = nn.ModuleList()
        self.decoders    = nn.ModuleList()
        # Placeholder: it is reassigned to an nn.Sequential below. Reassignment
        # keeps the entry's position in the module registry, so this line still
        # determines where middle_blks sits in parameters() order.
        self.middle_blks = nn.ModuleList()
        self.ups         = nn.ModuleList()
        self.downs       = nn.ModuleList()
        self.enc_cond    = nn.ModuleList()
        self.dec_cond    = nn.ModuleList()

        self.include_offset = include_offset

        chan = width

        # Prompt blocks are built unconditionally, even when include_offset is
        # False, so the parameter set does not depend on the flag. Here
        # chan == width, so the levels have width*2, width*4 and width*8 channels.
        self.prompt_block_level1 = PromptGenBlock(prompt_dim=chan*2,prompt_len=3,prompt_size = 128,lin_dim = chan*2)
        self.prompt_block_level2 = PromptGenBlock(prompt_dim=chan*4,prompt_len=3,prompt_size = 64,lin_dim = chan*4)
        self.prompt_block_level3 = PromptGenBlock(prompt_dim=chan*8,prompt_len=3,prompt_size = 32,lin_dim = chan*8)

        # prompt_dim_level3 = chan*2**3

        # Deepest-first ordering so promptBlocks[i] serves decoder stage i.
        # The decoder loop starts at the lowest resolution.
        # The same modules are registered under both names, so state_dict()
        # lists their tensors under both prefixes.
        self.promptBlocks = nn.ModuleList()
        self.promptBlocks.append(self.prompt_block_level3)
        self.promptBlocks.append(self.prompt_block_level2)
        self.promptBlocks.append(self.prompt_block_level1)

        # self.promptBlocks = [self.prompt_block_level3, self.prompt_block_level2, self.prompt_block_level1]

        for num in enc_blk_nums:
            #  Each encoder applies multiple NAFBlocks
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Each encoding layer is modulated using a corresponding ICB
            # which incorporates the provided text embeddings.
            self.enc_cond.append(ICB(chan, txtdim))
            # Downsampling layers that reduce spatial resolution while increasing channel dimensions.
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2
        # Middle blocks: a series of NAFBlocks applied to the deepest,
        # most abstract representation of the image features.
        # print(f"middle block chan: {chan}")
        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        # Offset vector length = in * out * kh * kw of the first middle block's
        # conv5. That equals conv5.weight.numel() only if conv5 has groups=1.
        # NOTE(review): NAFBlock/conv5 are defined elsewhere; confirm.
        naf_block = self.middle_blks[0]
        cIn = naf_block.conv5.in_channels
        cOut = naf_block.conv5.out_channels
        kernel_size = naf_block.conv5.kernel_size
        vector_size = cIn * cOut * kernel_size[0] * kernel_size[1]

        # OffsetGenerator is defined elsewhere (see the commented-out import at the top of this cell).
        self.middleblock_offsetGen=OffsetGenerator(in_prompt_dim=chan, out_conv_shapes=vector_size)


        # Bottleneck prompt block: chan is now the deepest channel count.
        self.prompt_block_middle_blks = PromptGenBlock(prompt_dim=chan,prompt_len=3,prompt_size = 32,lin_dim = chan)

        # decoding path
        for num in dec_blk_nums:
            # Upsampling layers: a 1x1 conv doubles the channels, then
            # PixelShuffle(2) doubles H and W and divides the channels by 4.
            # Net effect: half the channels at twice the resolution.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            # Sequentially processes upsampled features using multiple NAFBlocks.
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Add text embedding as modulation
            self.dec_cond.append(ICB(chan, txtdim))

        # Inputs are padded to a multiple of 2**(number of downsamplings).
        self.padder_size = 2 ** len(self.encoders)


        # self.offset_generators = []
        self.offset_generators = nn.ModuleList()

        # Built unconditionally (like the prompt blocks): one OffsetGenerator for
        # each of the first three decoder stages, sized from that stage's first
        # NAFBlock conv5 and fed by promptBlocks[i].
        for i, decoder in enumerate(self.decoders):
            if(i <3):
                naf_block = decoder[0]

                cIn = naf_block.conv5.in_channels
                cOut = naf_block.conv5.out_channels
                kernel_size = naf_block.conv5.kernel_size

                vector_size = cIn * cOut * kernel_size[0] * kernel_size[1]

                # in_prompt_dim = self.promptBlocks[i].prompt_dim

                self.offset_generators.append(OffsetGenerator(in_prompt_dim=self.promptBlocks[i].prompt_dim, out_conv_shapes=vector_size))


    def forward(self, inp, txtembd):
        """Restore ``inp`` conditioned on the instruction embedding ``txtembd``.

        Args:
            inp: image batch; unpacked as (B, C, H, W).
            txtembd: text embedding passed to every ICB (its last dim must be txtdim).
        Returns:
            Restored image cropped back to the original (H, W).
        """
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        # intro = a convolutional layer to preprocess the input image
        x = self.intro(inp)
        encs = []

        # Encoder: NAFBlocks -> text modulation (ICB) -> keep skip -> downsample.
        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
            x = encoder(x)
            x = enc_mod(x, txtembd)
            encs.append(x)
            x = down(x)


        # With offsets, one offset vector is computed from the bottleneck prompt
        # and shared by every middle NAFBlock. The blocks are therefore called
        # one by one instead of through nn.Sequential.
        if self.include_offset:
            # print("middle block include_offset")
            middle_block_prompt = self.prompt_block_middle_blks(x)
            # print("after prompt_block_level3")
            offset_vector = self.middleblock_offsetGen(middle_block_prompt)
            # print("after offset_generators[0](middle_block_prompt)")
            for naf_block in self.middle_blks:
                x = naf_block(x, offset_vector)
        else:
            x = self.middle_blks(x)

        # index counts decoder stages that received an offset (at most 3).
        index = 0
        for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):

            # NOTE(review): this uses `is True` while the middle-block check above
            # uses plain truthiness. A truthy non-bool include_offset would take
            # the offset path in the middle stage but not in the decoder.
            if self.include_offset is True:
                # print(f"self.include_offset: {self.include_offset}")
                x = up(x)
                # offset_vector = None
                # if (index < 3):
                #     degradation_aware_prompt = self.promptBlocks[index](x)
                #     offset_vector = self.offset_generators[index](degradation_aware_prompt)
                #     index += 1

                x = x + enc_skip

                # The prompt is generated from the skip-fused features. It is
                # shared by all NAFBlocks of this decoder stage.
                offset_vector = None
                if (index < 3):
                    degradation_aware_prompt = self.promptBlocks[index](x)
                    offset_vector = self.offset_generators[index](degradation_aware_prompt)
                    index += 1

                if offset_vector is not None:
                    for naf_block in decoder:
                        x = naf_block(x, offset_vector)
                else:
                    x = decoder(x)

                x = dec_mod(x, txtembd)
            else:
                x = up(x)
                x = x + enc_skip
                x = decoder(x)
                x = dec_mod(x, txtembd)

        # ending = conv layer to postprocess the final decoded feature into the desired image format.
        x = self.ending(x)
        # Global residual: the network predicts a correction to the (padded) input.
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad ``x`` on the bottom/right so H and W are multiples of ``self.padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


def create_model(input_channels = 3, width = 32, enc_blks = (2, 2, 4, 8), middle_blk_num = 12, dec_blks = (2, 2, 2, 2), txtdim=768, include_offset=False):
    """Build an InstructIR network.

    Args:
        input_channels: number of image channels (passed as ``img_channel``).
        width: channel width of the first encoder stage.
        enc_blks: NAFBlocks per encoder stage. InstructIR's prompt blocks are
            sized for 4 encoder stages, as in the default.
        middle_blk_num: NAFBlocks in the bottleneck.
        dec_blks: NAFBlocks per decoder stage.
        txtdim: size of the text embedding consumed by the ICBs.
        include_offset: enable the prompt/offset-generator path.

    Returns:
        InstructIR: the freshly constructed network (no weights are loaded here).
    """
    # Defaults are immutable tuples. A fresh list is passed on, so neither the
    # defaults nor the caller's sequences are shared with the model.
    net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=list(enc_blks), dec_blk_nums=list(dec_blks), txtdim=txtdim,
                     include_offset=include_offset)

    return net

# Text model

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os

# Hugging Face model ids whose sentence embedding is the attention-masked mean of
# the token embeddings (see mean_pooling), followed by L2 normalisation. Other
# models use the first-token hidden state. LanguageModel matches these ids by
# substring against its model name.
POOL_MODELS = set(("sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"))

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: transformer output whose first element holds the token
            embeddings, shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask, 1 for real tokens and 0 for padding.

    Returns:
        (batch, hidden) tensor of masked means. The token count is clamped to
        1e-9, so an all-padding row yields zeros instead of NaN.
    """
    tokens = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    summed = (tokens * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class LanguageModel(nn.Module):
    """Frozen Hugging Face text encoder that maps instructions to sentence embeddings.

    Args:
        model: Hugging Face model id (or local path), loaded with
            ``AutoTokenizer`` / ``AutoModel``.

    The embedding depends on the model name:
      - names containing "clip": the vision tower is dropped and
        ``get_text_features`` is returned;
      - names containing an entry of ``POOL_MODELS``: mask-aware mean pooling
        followed by L2 normalisation;
      - otherwise: the first-token hidden state.

    The encoder's parameters are frozen and it is put in eval mode on
    construction. A later ``.train()`` on this module (or a parent) would
    switch it back to training mode.
    """

    def __init__(self, model='distilbert-base-uncased'):
        super(LanguageModel, self).__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # CLIP checkpoints bundle a vision tower that is never used here.
        if "clip" in model:
            self.model.vision_model = None
        # Used as a fixed feature extractor: no gradients for its weights.
        self.model.requires_grad_(False)
        self.model.eval()

    def forward(self, text_batch):
        """Encode a list of strings into a (batch, dim) embedding tensor on the model's device."""
        encoded = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        # Put the token tensors on whichever device holds the encoder's weights.
        target_device = next(self.model.parameters()).device
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        with torch.no_grad():
            if "clip" in self.model_name:
                return self.model.get_text_features(**encoded)
            outputs = self.model(**encoded)

        uses_mean_pool = any(pool_name in self.model_name for pool_name in POOL_MODELS)
        if not uses_mean_pool:
            return outputs.last_hidden_state[:, 0, :]
        pooled = mean_pooling(outputs, encoded['attention_mask'])
        return F.normalize(pooled, p=2, dim=1)


class LMHead(nn.Module):
    """Projection head on top of a language-model sentence embedding.

    ``fc1`` projects the input to ``hidden_dim``, and the result is
    L2-normalised per row to give the instruction embedding. ``fc2`` maps that
    embedding to ``num_classes`` logits.
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super(LMHead, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return ``(embedding, logits)`` for a (batch, embedding_dim) input."""
        projected = F.normalize(self.fc1(x), p=2, dim=1)
        return projected, self.fc2(projected)

# options.py

In [10]:
import argparse

# Command-line style options. The notebook parses them with parse_args([]), so
# in practice the defaults below are what training uses.
parser = argparse.ArgumentParser()

# Batch sizes. main() derives accumulation_steps as instructir_batch_size / dataloader_batch_size.
parser.add_argument('--patch_size', type=int, default=256, help='patchsize of input.')
parser.add_argument('--instructir_batch_size', type=int, default=32, help='batch size of InstructIR.')
parser.add_argument('--dataloader_batch_size', type=int, default=32, help='batch size of dataloader.')


parser.add_argument('--epochs', type=int, default=300, help='Number of epochs for training')

parser.add_argument('--num_workers', type=int, default=8, help='number of workers.')
parser.add_argument("--checkpoint_dir",type=str, default="/content/drive/MyDrive/FYPData/train_ckpt",help = "Name of the Directory where the checkpoint is to be saved")
parser.add_argument('--lm_head',      type=str, default="/content/drive/MyDrive/FYPData/models/lm_instructir-7d.pt", help='Path to the language model weights')
# NOTE(review): main() currently builds its config from the in-notebook `config`
# dict (the YAML load is commented out), so this path is not read there.
parser.add_argument('--config',  type=str, default='configs/eval5d.yml', help='Path to config file')
parser.add_argument('--device',  type=int, default=0, help="GPU device")

parser.add_argument("--wblogger",type=str,default="instructir",help = "Determine to log to wandb or not and the project name")

parser.add_argument('--data_file_dir',  type=str, default="/content/drive/MyDrive/FYPData/train_data_names/", help="Files that contains the training file names")

# Training images are read directly from Google Drive.
parser.add_argument('--denoise_dir',  type=str, default="/content/drive/MyDrive/FYPData/Train/denoise/", help="Directory containing the images for denoise")

parser.add_argument('--dehaze_dir',  type=str, default="/content/drive/MyDrive/FYPData/Train/dehaze/", help="Directory containing the images for dehaze")

parser.add_argument('--derain_dir',  type=str, default="/content/drive/MyDrive/FYPData/Train/derain/", help="Directory containing the images for derain")


parser.add_argument('--de_type', nargs='+', default=['derain', 'denoise_15', 'dehaze', 'denoise_25', 'denoise_50'],
                    help='which type of degradations is training and testing for.')
parser.add_argument('--trained_model_weights', type=str, default="/content/drive/MyDrive/FYPData/trained_weights", help="File name of model state_dict()")

parser.add_argument('--image_model', type=str, default="/content/drive/MyDrive/FYPData/models/im_instructir-7d.pt", help='Path to the image model weights')


parser.add_argument('--initial_lr',  type=float, default=5e-4, help="Learning Rate for optimizer")

parser.add_argument('--warmup_lr',  type=float, default=5e-6, help="Learning Rate for Scheduler")

parser.add_argument('--eta_min',  type=float, default=5e-6, help="Final Learning Rate for Scheduler")
# NOTE(review): a value of 30 was noted here previously; confirm which is intended.
parser.add_argument('--warmup_epochs',  type=int, default=20, help="warmup epochs for scheduler")
# Trailing ';' suppresses the notebook echo of the returned _StoreAction.
parser.add_argument('--chkpt_epoch',  type=int, default=20, help="Epochs for each Checkpoint ");


_StoreAction(option_strings=['--chkpt_epoch'], dest='chkpt_epoch', nargs=None, const=None, default=20, type=<class 'int'>, choices=None, required=False, help='Epochs for each Checkpoint ', metavar=None)

# Install the `transformers` module

In [11]:
%pip install transformers



In [12]:
# Hyper-parameters that would otherwise come from a YAML file. main() converts
# this dict to a namespace with dict2namespace.
config = dict(
    llm=dict(
        model="TaylorAI/bge-micro-v2",  # See Paper Sec. 3.2 and Appendix
        model_dim=384,
        embd_dim=256,
        nclasses=7,  # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
        weights=False,
    ),
    model=dict(
        arch="instructir",
        use_text=True,
        in_ch=3,
        out_ch=3,
        width=32,
        enc_blks=[2, 2, 4, 8],
        middle_blk_num=4,
        dec_blks=[2, 2, 2, 2],
        textdim=256,
        weights=False,
    ),
    test=dict(
        batch_size=1,
        num_workers=3,
        dn_datapath="test-data/denoising_testsets/",
        dn_datasets=["CBSD68", "urban100", "Kodak24"],
        dn_sigmas=[15, 25, 50],
        rain_targets=["test-data/Rain100L/target/"],
        rain_inputs=["test-data/Rain100L/input/"],
        haze_targets="test-data/SOTS/GT/",
        haze_inputs="test-data/SOTS/IN/",
        lol_targets="test-data/LOL/high/",
        lol_inputs="test-data/LOL/low/",
        gopro_targets="test-data/GoPro/target/",
        gopro_inputs="test-data/GoPro/input/",
    ),
)


# Train Modification

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import argparse
import os
import numpy as np
import yaml
import random
import json
from datetime import datetime

# from training_utils.TrainDataset import InstructIRTrainDataset
# from training_utils.scheduler import LinearWarmupCosineAnnealingLR
# import sys
# from options import parser
# from models.instructir import InstructIR, create_model
# from text.models import LanguageModel, LMHead


def add_date_to_filename(filepath):
    """Return ``filepath`` with today's date (``_DD_MM_YYYY``) inserted before the extension.

    Example: ``"runs/loss.txt"`` on 8 March 2025 -> ``"runs/loss_08_03_2025.txt"``.
    A path without an extension simply gets the date suffix appended.
    """
    stem, extension = os.path.splitext(filepath)
    date_tag = datetime.today().strftime("%d_%m_%Y")
    return "{}_{}{}".format(stem, date_tag, extension)

def create_training_checkpoint_folder(base_folder, now=None):
    """
    Build, but do NOT create, a per-run checkpoint folder path:
    base_folder/train_date_DD_MM_YYYY/train_time_HH_MM

    The caller is responsible for creating the folder, e.g.
    ``os.makedirs(path, exist_ok=True)``.

    Args:
        base_folder (str): The path to the base checkpoint folder.
        now (datetime, optional): Timestamp to use; defaults to ``datetime.now()``.
            A single timestamp feeds both the date and the time components, so
            they cannot straddle midnight.

    Returns:
        str: The full path of the training checkpoint folder.
    """
    if now is None:
        now = datetime.now()
    # e.g. "08_03_2025" and "01_20"
    today_date = now.strftime("%d_%m_%Y")
    start_time = now.strftime("%H_%M")
    return os.path.join(base_folder, f"train_date_{today_date}", f"train_time_{start_time}")

# Append-only log file; train_instructir_model writes its per-epoch summaries here.
DEST_PATH = "/content/drive/MyDrive/FYPData/train_loss/instructir_wm_train_loss.txt"
print("DEST_PATH: {}".format(DEST_PATH))

def save_performance_to_file(dest_path, text):
    """
    Append ``text`` followed by a newline to the file at ``dest_path``.

    The file is created if it does not exist, and so is its parent folder.
    A bare filename (no directory component) is written relative to the
    current working directory.

    Args:
        dest_path (str): The path to the destination file.
        text (str): The text to be appended to the file.
    """
    directory = os.path.dirname(dest_path)
    # os.makedirs("") raises FileNotFoundError, so only create a folder when
    # the path actually has a directory component.
    if directory:
        os.makedirs(directory, exist_ok=True)

    print(f"Saving Training Loss to {dest_path}")
    with open(dest_path, "a", encoding="utf-8") as f:
        f.write(text)
        f.write("\n")
    print("Saving Training Loss done!")


def train_instructir_model(instructir_model, language_model, lm_head, train_data_loader,human_instructions, optimizer, scheduler, loss_fn, opt, device,
    accumulation_steps, image_batch_size, ini_message = "Stage One:", check_point_epoch=20):
    """
    Train the InstructIR model.

    For every batch, one human instruction per image is drawn at random. The
    instructions are encoded by ``language_model`` + ``lm_head`` and used to
    condition ``instructir_model``.

    Gradients are accumulated over ``accumulation_steps`` batches before each
    optimizer step. The last, possibly incomplete, group of an epoch is always
    applied.

    One log entry per epoch is appended to ``DEST_PATH``. Model, optimizer and
    scheduler state is saved every ``check_point_epoch`` epochs.

    Args:
        instructir_model: The image restoration model to train.
        language_model: Callable mapping a list of instruction strings to sentence embeddings.
        lm_head: Head applied to the language-model output; returns a ``(text_embd, deg_pred)`` pair.
        train_data_loader: DataLoader yielding ``([clean_name, deg_id], degrad_patch, clean_patch)`` batches.
        human_instructions: List of human instruction strings.
        optimizer: The optimizer (e.g., Adam).
        scheduler: Learning rate scheduler stepped once per epoch, or None.
        loss_fn: Loss function (e.g., L1 loss).
        opt: Training options; ``opt.epochs`` and ``opt.checkpoint_dir`` are read.
        device: The torch device (e.g., 'cuda' or 'cpu').
        accumulation_steps: Number of batches over which to accumulate gradients.
            The value is rounded to the nearest integer; values below 1, or None, mean 1.
        image_batch_size: Nominal number of images per batch, kept for API
            compatibility. Instructions are drawn per actual batch size, so a
            smaller final batch cannot mismatch the text embeddings.
        ini_message: Header written to the log before training starts.
        check_point_epoch: Save a checkpoint every this many epochs.
    Returns:
        The trained instructir_model.
    """
    # Offset added to displayed / checkpoint epoch numbers (presumably for resumed runs).
    base = 0
    steps_per_update = max(1, int(round(accumulation_steps))) if accumulation_steps else 1
    save_performance_to_file(DEST_PATH, ini_message)
    checkpoint_dir = create_training_checkpoint_folder(opt.checkpoint_dir)
    num_batches = len(train_data_loader)

    for epoch in range(opt.epochs):
        epoch_num = epoch + 1 + base
        now = datetime.now()
        epoch_msg = f"{now.strftime('%Y-%m-%d %H:%M:%S')} --- Start of Epoch: {epoch_num}\n"
        print(epoch_msg)

        instructir_model.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for batch_idx, batch in enumerate(train_data_loader):
            # Unpack batch (clean_name / deg_id are currently unused).
            [clean_name, deg_id], degrad_patch, clean_patch = batch
            # One instruction per image actually present in this batch.
            # NOTE(review): instructions are drawn from the pooled list regardless
            # of deg_id, so an image may be paired with an instruction for a
            # different degradation. Confirm this is intended.
            human_instruction = [random.choice(human_instructions) for _ in range(degrad_patch.size(0))]

            lm_embd = language_model(human_instruction)
            lm_embd = lm_embd.to(device)
            text_embd, deg_pred = lm_head(lm_embd)
            # Forward pass through the InstructIR model
            degrad_patch = degrad_patch.to(device)
            clean_patch = clean_patch.to(device)
            restored_image = instructir_model(degrad_patch, text_embd)
            loss = loss_fn(restored_image, clean_patch)

            # Scale so the accumulated gradient is the mean over the group.
            # With steps_per_update == 1 this is a plain per-batch update.
            (loss / steps_per_update).backward()
            if (batch_idx + 1) % steps_per_update == 0 or (batch_idx + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            running_loss += loss.item()

            # Every 100 batches the running average is added to the epoch log.
            # It is printed only when the index is also a multiple of 250,
            # i.e. every 500 batches.
            if (batch_idx + 1) % 100 == 0:
                batch_msg = f"Epoch [{epoch_num}/{opt.epochs}], Batch [{batch_idx+1}/{num_batches}], Running Avg Loss: {running_loss/(batch_idx+1):.20f}\n"
                if (batch_idx + 1) % 250 == 0:
                    print(batch_msg)
                epoch_msg += batch_msg

        # Step the scheduler after each epoch.
        # NOTE(review): passing the epoch index to step() is deprecated in
        # torch.optim.lr_scheduler. It is kept to preserve the existing schedule.
        if scheduler is not None:
            scheduler.step(epoch)
        avg_epoch_loss = running_loss / num_batches if num_batches else float("nan")
        avg_epoch_loss_str = f"Epoch [{epoch_num}/{opt.epochs}], Average Loss: {avg_epoch_loss:.20f}\n"
        epoch_msg += avg_epoch_loss_str
        print(avg_epoch_loss_str)

        # Save a checkpoint for the current epoch
        if epoch_num % check_point_epoch == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path_model = os.path.join(checkpoint_dir, f"64_chan_epoch_{epoch_num}_model.pt")
            checkpoint_path_scheduler = os.path.join(checkpoint_dir, f"epoch_{epoch_num}_scheduler.pt")
            checkpoint_path_optimizer = os.path.join(checkpoint_dir, f"epoch_{epoch_num}_optimizer.pt")

            saved_paths = [checkpoint_path_model]
            torch.save(instructir_model.state_dict(), checkpoint_path_model)
            # The scheduler is optional (see the step above); only save it when present.
            if scheduler is not None:
                torch.save(scheduler.state_dict(), checkpoint_path_scheduler)
                saved_paths.append(checkpoint_path_scheduler)
            torch.save(optimizer.state_dict(), checkpoint_path_optimizer)
            saved_paths.append(checkpoint_path_optimizer)
            print("Checkpoint saved at " + "\n".join(saved_paths))
        # print end of epoch
        now = datetime.now()
        end_epoch_msg = f"{now.strftime('%Y-%m-%d %H:%M:%S')} --- End of Epoch: {epoch_num}\n"
        print(end_epoch_msg)
        epoch_msg += end_epoch_msg
        # save current epochs statistics
        save_performance_to_file(DEST_PATH, epoch_msg)
    return instructir_model

def seed_everything(SEED=42, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        SEED: seed value applied to every RNG.
        deterministic: controls cuDNN behaviour.
            - False (default, the previous behaviour): cuDNN autotuning is
              enabled (``cudnn.benchmark = True``). This is usually faster but
              may pick non-deterministic kernels, so runs are not bit-reproducible.
            - True: autotuning is disabled and cuDNN is restricted to
              deterministic algorithms.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = not deterministic
    if deterministic:
        torch.backends.cudnn.deterministic = True

def count_params(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def dict2namespace(config):
    """Recursively convert a (nested) dict into ``argparse.Namespace`` objects.

    Nested dicts become nested namespaces. Every other value, lists included,
    is stored unchanged, so ``cfg.model.width`` mirrors ``config["model"]["width"]``.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

import os
def _gradient_accumulation_steps(effective_batch_size, loader_batch_size):
    """Return how many DataLoader batches make up one effective InstructIR batch.

    Raises:
        ValueError: if ``effective_batch_size`` is not a positive multiple of
            ``loader_batch_size``; a fractional count cannot be used as a whole
            number of accumulation steps.
    """
    steps, remainder = divmod(effective_batch_size, loader_batch_size)
    if steps < 1 or remainder != 0:
        raise ValueError(
            f"instructir_batch_size ({effective_batch_size}) must be a positive multiple of "
            f"dataloader_batch_size ({loader_batch_size})"
        )
    return steps


def _unfreeze_adapter_params(instructir_model):
    """Freeze the whole InstructIR model, then re-enable gradients on the adapter modules.

    The adapter modules are the prompt blocks, the offset generators and their
    middle-block counterparts. Returns the list of their parameters, in that
    order, for the optimizer.
    """
    for param in instructir_model.parameters():
        param.requires_grad = False

    adapter_modules = (
        instructir_model.promptBlocks,
        instructir_model.offset_generators,
        instructir_model.prompt_block_middle_blks,
        instructir_model.middleblock_offsetGen,
    )
    params_to_train = []
    for module in adapter_modules:
        for param in module.parameters():
            param.requires_grad = True
            params_to_train.append(param)
    return params_to_train


def main():
    """Fine-tune the InstructIR adapter modules (prompt blocks + offset generators).

    Steps:
      1. Seed the RNGs and parse the default options. ``parser`` is defined in an
         earlier cell. Build ``cfg`` from the global ``config`` dict, which is
         presumably loaded from ``opt.config`` in an earlier cell (TODO confirm).
      2. Load the human-instruction pool and build the training DataLoader.
      3. Load the language model and LM projection head, and freeze both.
      4. Build InstructIR and load the pretrained image-model weights. Freeze
         the backbone and train only the adapter modules.
      5. Train with AdamW + LinearWarmupCosineAnnealingLR, then save the final weights.
    """
    now = datetime.now()
    # Format datetime as "YYYY-MM-DD HH:MM:SS"
    formatted_datetime = now.strftime("%Y-%m-%d %H:%M:%S")

    training_message = f"{formatted_datetime} Training Log for Re-train InstructIR (output channel in OffsetGenertor conv = 32) - last 20 epochs\n\n"

    # NOTE: PyTorch reads this variable only when the CUDA caching allocator is first
    # initialised; if an earlier cell already allocated GPU memory it has no effect.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
    # Seed RNGs (see seed_everything for the cuDNN determinism trade-off).
    seed_everything()
    print(f"Torch version: {torch.__version__}")
    print(f"cuda version: {torch.version.cuda}")
    print("CUDA available:", torch.cuda.is_available())

    # An empty argv is passed on purpose: inside a notebook sys.argv holds kernel
    # arguments, so only the parser defaults are used.
    opt = parser.parse_args([])
    print("Options:")
    print(opt)

    LANGUAGE_HEAD = opt.lm_head
    IMAGE_BATCH_SIZE_DATALOADER = opt.dataloader_batch_size
    INITIAL_LEARNING_RATE = opt.initial_lr
    MAX_EPOCHS = opt.epochs
    WARMUP_EPOCHS = opt.warmup_epochs
    WARMUP_START_LR = opt.warmup_lr
    ETA_MIN = opt.eta_min
    CHECK_POINT_EPOCH = opt.chkpt_epoch
    # Channel width of the InstructIR backbone. It must match the pretrained
    # checkpoint (load_state_dict rejects shape mismatches even with strict=False),
    # so it is deliberately independent of any batch-size option.
    MODEL_WIDTH = 32
    # Integer number of loader batches per effective InstructIR batch.
    accumulation_steps = _gradient_accumulation_steps(opt.instructir_batch_size, IMAGE_BATCH_SIZE_DATALOADER)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n\n*************** device: {device} ***************\n\n")

    training_message += f"options: {opt}\n\n*************** device: {device} ***************\n\n"

    # `config` is a module-level dict; the yaml load from opt.config is disabled here:
    # with open(os.path.join(opt.config), "r") as f:
    #     config = yaml.safe_load(f)
    cfg = dict2namespace(config)

    # Pool of human-written instructions for the three supported tasks.
    instruction_file_path = "/content/drive/MyDrive/FYPData/human_instructions.json"
    with open(instruction_file_path, "r") as f:
        data = json.load(f)
    human_instructions = data["denoising"] + data["deraining"] + data["dehazing"]
    random.shuffle(human_instructions)
    print(f"Human Instruction Set Length: {len(human_instructions)}\n")

    training_message += f"Human Instruction Set Length: {len(human_instructions)}\n"

    # Create training dataset and dataloader
    training_dataset = InstructIRTrainDataset(args=opt)
    train_data_loader = DataLoader(
        training_dataset,
        batch_size=IMAGE_BATCH_SIZE_DATALOADER,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
        num_workers=opt.num_workers,
        # DataLoader raises if persistent_workers=True while num_workers == 0.
        persistent_workers=opt.num_workers > 0,
    )

    training_message += f"Loading Language Model: {cfg.llm.model}\n"
    # Load language model and LM head
    print(f"\nLoading Language Model: {cfg.llm.model}")
    language_model = LanguageModel(model=cfg.llm.model).to(device)
    lm_head = LMHead(
        embedding_dim=cfg.llm.model_dim,
        hidden_dim=cfg.llm.embd_dim,
        num_classes=cfg.llm.nclasses
    ).to(device)

    lm_nparams = count_params(lm_head)
    print("Projection Head CKPT Path:", LANGUAGE_HEAD)
    lm_head.load_state_dict(torch.load(LANGUAGE_HEAD, map_location=device), strict=True)
    print("\nProjection Head loaded weights!", lm_nparams/1e6, "M\n")

    training_message += f"Projection Head CKPT Path: {LANGUAGE_HEAD}\nProjection Head loaded weights: {lm_nparams/1e6}M\nFreezing language model and project head\n"

    # The language model and projection head are only used frozen.
    for frozen_module in (language_model, lm_head):
        for param in frozen_module.parameters():
            param.requires_grad = False

    print(f"------ Training for {opt.de_type}")

    training_message += f"------ Training for {opt.de_type}\nDataset:\n"

    # Per-degradation dataset sizes, logged in the order noise, haze, rain.
    if any(noise_type in opt.de_type for noise_type in ("denoise_15", "denoise_25", "denoise_50")):
        training_message += f"Noise sigma 15: {len(training_dataset.s15_ids)}, Noise sigma 25: {len(training_dataset.s25_ids)}, Noise sigma 50: {len(training_dataset.s50_ids)}\n"
    if "dehaze" in opt.de_type:
        training_message += f"Haze Images Length: {len(training_dataset.hazy_ids)}\n"
    if "derain" in opt.de_type:
        training_message += f"Rain Images Length: {len(training_dataset.rs_ids)}\n"

    print("----------------- Retrain InstructIR started... ----------------------")
    print("Start loading the pre-trained InstructIR model...")
    # Create the InstructIR model and move to device
    instructir_model = create_model(txtdim=256, middle_blk_num=4, width=MODEL_WIDTH, include_offset=True)
    instructir_model = instructir_model.to(device)
    print("middle_block = 4")

    # Pretrained image-model weights (configurable via opt.image_model).
    instructir_stage_one_path = opt.image_model
    print("\nInstructIR Path:", instructir_stage_one_path, "\n")
    # strict=False presumably because the adapter modules are not in the checkpoint;
    # report what was skipped so that loading a wrong checkpoint does not go unnoticed.
    load_result = instructir_model.load_state_dict(torch.load(instructir_stage_one_path, map_location=device), strict=False)
    load_report = (f"Missing keys (left at init): {len(load_result.missing_keys)}, "
                   f"unexpected keys (ignored): {len(load_result.unexpected_keys)}\n")
    print(load_report)

    training_message += f"\nInstructIR Path: {instructir_stage_one_path}\n"
    training_message += load_report

    # Trainable-parameter count taken before freezing, i.e. the whole model.
    instructir_model_nparams = count_params(instructir_model)
    total_params_msg = f"\n---- InstructIR + Modifcation: {instructir_model_nparams/1e6}M parameters\n"

    # Freeze the backbone; only the prompt blocks / offset generators are trained.
    params_to_train = _unfreeze_adapter_params(instructir_model)
    n_trainable = sum(p.numel() for p in params_to_train if p.requires_grad is True)

    mod_params_msg = f"\n---- OffsetGenerators + PromptGenBlocks: {n_trainable/1e6}M parameters\n"

    print(total_params_msg, mod_params_msg)

    training_message += total_params_msg
    training_message += mod_params_msg

    # Sanity check: exactly the adapter parameters are trainable.
    assert count_params(instructir_model) == n_trainable

    # Set up new optimizer and learning rate scheduler
    optimizer = optim.AdamW(params_to_train, lr=INITIAL_LEARNING_RATE)
    scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_epochs=WARMUP_EPOCHS, max_epochs=MAX_EPOCHS,
                                              warmup_start_lr=WARMUP_START_LR, eta_min=ETA_MIN)

    training_params = f"""\nLearning Rate: {INITIAL_LEARNING_RATE}, Image Model Batch_size: {opt.instructir_batch_size}, Dataloader Batch_size: {IMAGE_BATCH_SIZE_DATALOADER}, Acccumulation steps: {accumulation_steps}\n"""

    scheduler_setting = (f"Scheduler: {str(scheduler)}\nwarmup_epochs: {WARMUP_EPOCHS}, max_epochs: {MAX_EPOCHS}, "
                         f"warmup_start_lr: {WARMUP_START_LR}, eta_min: {ETA_MIN}\n")

    training_message += training_params
    training_message += scheduler_setting

    print(training_params, scheduler_setting)

    #################################################### Resume optimizer / scheduler (disabled) ##############################

    # optimizer_checkpoint_path = "/content/drive/MyDrive/FYPData/trained_weights/epoch_280_optimizer.pth"
    # scheduler_checkpoint_path = "/content/drive/MyDrive/FYPData/trained_weights/epoch_280_scheduler.pth"

    # # Load the optimizer checkpoint with map_location set to device
    # optimizer.load_state_dict(torch.load(optimizer_checkpoint_path, map_location=device))

    # # Load the scheduler checkpoint with map_location set to device
    # scheduler.load_state_dict(torch.load(scheduler_checkpoint_path, map_location=device))

    # training_message += f"Loaded optimizer checkpoint at: {optimizer_checkpoint_path}\nLoaded scheduler checkpoint at: {scheduler_checkpoint_path}"

    ###########################################################################################################################

    # Loss function
    loss_fn = nn.L1Loss()

    # Train the offset generators and prompt blocks
    instructir_model = train_instructir_model(instructir_model=instructir_model, language_model=language_model, lm_head=lm_head,
        train_data_loader=train_data_loader, human_instructions=human_instructions, optimizer=optimizer, scheduler=scheduler,
        loss_fn=loss_fn, opt=opt, device=device, accumulation_steps=accumulation_steps, image_batch_size=IMAGE_BATCH_SIZE_DATALOADER,
        ini_message=training_message, check_point_epoch=CHECK_POINT_EPOCH
    )

    # Save the final model weights
    model_weights_file_name = "instructir_with_weight_modulation_32ch.pt"
    os.makedirs(opt.trained_model_weights, exist_ok=True)
    model_file_path = os.path.join(opt.trained_model_weights, model_weights_file_name)
    torch.save(instructir_model.state_dict(), model_file_path)
    print(f"Retraining finished, final model weights saved at: {model_file_path}\n")


if __name__ == "__main__":
    main()


DEST_PATH: /content/drive/MyDrive/FYPData/train_loss/instructir_wm_train_loss.txt
Torch version: 2.6.0+cu124
cuda version: 12.4
CUDA available: True
Options:
Namespace(patch_size=256, instructir_batch_size=32, dataloader_batch_size=32, epochs=300, num_workers=8, checkpoint_dir='/content/drive/MyDrive/FYPData/train_ckpt', lm_head='/content/drive/MyDrive/FYPData/models/lm_instructir-7d.pt', config='configs/eval5d.yml', device=0, wblogger='instructir', data_file_dir='/content/drive/MyDrive/FYPData/train_data_names/', denoise_dir='/content/drive/MyDrive/FYPData/Train/denoise/', dehaze_dir='/content/drive/MyDrive/FYPData/Train/dehaze/', derain_dir='/content/drive/MyDrive/FYPData/Train/derain/', de_type=['derain', 'denoise_15', 'dehaze', 'denoise_25', 'denoise_50'], trained_model_weights='/content/drive/MyDrive/FYPData/trained_weights', image_model='/content/drive/MyDrive/FYPData/models/im_instructir-7d.pt', initial_lr=0.0005, warmup_lr=5e-06, eta_min=5e-06, warmup_epochs=20, chkpt_epoch=20)