### Main

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Standard library imports
import random
import json
# Third-party imports
from PIL import Image
import torch
from torch import nn
from torch.nn import init as init
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
from transformers import AutoModel, AutoTokenizer
from torchmetrics.functional import peak_signal_noise_ratio
from tabulate import tabulate
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F

# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training hyper-parameters and output location
num_epochs_lm = 10  # presumably the epoch count for the language-model head stage (used outside this view)
batch_size = 16
num_epochs = 500
checkpoint_dir = "/home/cvpr_ug_2/abhinavinstructir/checkpoint"

############======== IMAGE MODEL =========############
class SimpleGate(nn.Module):
    """Parameter-free gate: split the channel dimension in half and multiply the halves.

    An input with 2k channels along dim 1 produces an output with k channels.
    """

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half

class NAFBlock(nn.Module):
    """NAFNet block: two residual branches, each scaled by a learnable per-channel factor.

    Branch 1: LayerNorm2d -> 1x1 conv (c -> c*DW_Expand) -> 3x3 depthwise conv
    -> SimpleGate (halves channels) -> simplified channel attention
    -> 1x1 conv back to c -> dropout; added to the input scaled by ``beta``.
    Branch 2: LayerNorm2d -> 1x1 conv (c -> c*FFN_Expand) -> SimpleGate
    -> 1x1 conv back to c -> dropout; added to the branch-1 result scaled by ``gamma``.

    ``beta`` and ``gamma`` are initialised to zero, so a freshly built block acts
    as the identity. Output has the same shape as the input (``c`` channels).
    """
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): # DW_Expand / FFN_Expand: channel expansion ratios of the two branches
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) # groups=1 means basic simple convolution (not grouped/depthwise)
        # groups=dw_channel: depthwise 3x3 conv (one filter per channel); padding=1 keeps H x W
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        # Input is dw_channel // 2 because SimpleGate halves the channel count.
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention: global average pool + 1x1 conv -> one weight per channel
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), # Takes an input feature map of shape (N, C, H, W), reduces each channel to a 1×1 spatial size by computing the avg over H×W --> (N, C, 1, 1)
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        # Feed-forward branch: expand to ffn_channel, gate (halves), project back to c.
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        # Dropout only when a positive rate is requested; otherwise a no-op.
        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Zero-initialised residual scales: the block starts as an identity mapping.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        # Branch 1: norm -> expand -> depthwise conv -> gate -> channel attention -> project.
        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        # Branch 2: norm -> expand -> gate -> project, applied to the branch-1 result.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(nn.Module):
    """NAFNet image-restoration U-Net (ECCV 2022).

    Layout: a 3x3 intro conv; one encoder stage per entry of ``enc_blk_nums``
    (a stack of NAFBlocks, then a 2x2 stride-2 conv that halves H/W and doubles
    channels); ``middle_blk_num`` bottleneck NAFBlocks; one decoder stage per
    entry of ``dec_blk_nums`` (1x1 conv + PixelShuffle(2) upsampling that halves
    channels, additive skip from the matching encoder stage, NAFBlocks); a 3x3
    ending conv whose output is added to the (padded) input.

    Args:
        img_channel: Image channels in and out.
        width: Channels after the intro conv.
        middle_blk_num: Number of NAFBlocks in the bottleneck.
        enc_blk_nums: NAFBlocks per encoder stage (sequence of ints).
        dec_blk_nums: NAFBlocks per decoder stage (sequence of ints). Decoder
            stages are paired with encoder skips in reverse order via ``zip``,
            so extra stages on either side are silently ignored in forward().
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=()):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder overwritten below. It is kept so that the submodule
        # registration order (and hence parameters()/state_dict order) stays
        # the same as the original layout.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            # 2x2 stride-2 conv: halves H and W, doubles channels.
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(*[NAFBlock(chan) for _ in range(middle_blk_num)])

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    # (B, chan*2, H, W) -> (B, chan/2, 2H, 2W)
                    nn.PixelShuffle(2),
                )
            )
            chan = chan // 2
            self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

        # Inputs are padded to a multiple of this so every downsampling is exact.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        """Restore a 4-D image batch; the output is cropped back to the input's H x W."""
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        # Global residual: the network predicts a correction to its input.
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad bottom/right so H and W are multiples of ``self.padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class Local_Base:
    """Mixin that converts a model's global average pooling into local AvgPool2d.

    ``convert`` swaps every ``nn.AdaptiveAvgPool2d`` in the model for an
    ``AvgPool2d`` via ``replace_layers``, then runs one dummy forward pass
    at ``train_size`` so the new layers derive their kernel sizes lazily.
    """

    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        dummy_batch = torch.rand(train_size)
        # The dummy pass only initialises pooling state, so gradients are not needed.
        with torch.no_grad():
            self.forward(dummy_batch)

class NAFNetLocal(Local_Base, NAFNet):
    """NAFNet whose global-average-pool layers are replaced by local AvgPool2d.

    The pooling ``base_size`` is 1.5x the training height/width. Kernel sizes
    are fixed during construction by ``Local_Base.convert``, which runs a dummy
    forward pass at ``train_size``. The model is left in eval mode.
    """
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        _, _, train_h, train_w = train_size
        base_size = (int(train_h * 1.5), int(train_w * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


def create_nafnet(input_channels=3, width=32, enc_blks=(2, 2, 4, 8), middle_blk_num=12, dec_blks=(2, 2, 2, 2)):
    """
    Create Nafnet model
    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml

    Stage configurations default to tuples (not lists) so no mutable default
    object is shared between calls.

    Returns:
        A freshly initialised ``NAFNet``.
    """
    return NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                  enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)

# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet



class LayerNormFunction(torch.autograd.Function):
    """Channel-wise LayerNorm for 4-D (N, C, H, W) tensors with a hand-written backward.

    Each spatial location is normalised over the channel dimension (biased
    variance, ``eps`` added inside the square root). The result is then scaled
    by ``weight`` and shifted by ``bias``, both of shape ``(C,)``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        _, C, _, _ = x.size()  # unpacking also enforces a 4-D input
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # backward() only needs the normalised activations, the variance and the weight.
        ctx.save_for_backward(y, var, weight)
        return weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        _, C, _, _ = grad_output.size()
        # `saved_tensors` is the supported accessor (`saved_variables` is deprecated).
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        # Gradient through y = (x - mean) / sqrt(var + eps), with mean/var over channels.
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        grad_weight = (grad_output * y).sum(dim=(0, 2, 3))
        grad_bias = grad_output.sum(dim=(0, 2, 3))
        # eps is a plain Python number, so it receives no gradient.
        return gx, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """LayerNorm over the channel axis of NCHW feature maps.

    Holds a learnable per-channel scale (``weight``, init 1) and shift
    (``bias``, init 0); the computation is delegated to LayerNormFunction.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
    


class AvgPool2d(nn.Module):
    """Local (windowed) average pooling that ``replace_layers`` substitutes for
    ``nn.AdaptiveAvgPool2d(1)``.

    Each output position is the mean over a ``kernel_size`` window, computed
    from 2-D cumulative sums. If the window covers the whole feature map, this
    falls back to global average pooling. When ``kernel_size`` is None it is
    derived lazily on the first forward call: ``base_size`` is scaled by the
    ratio of the input's spatial size to ``train_size``.

    Args:
        kernel_size: Explicit (kh, kw) window, or None to derive it lazily.
        base_size: Window size (int or (h, w)) at ``train_size`` resolution.
        auto_pad: If True, replicate-pad the output back to the input's H x W.
        fast_imp: Use the faster subsampled approximation (not exactly equivalent).
        train_size: Reference input shape; only its last two entries are read.
    """
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        # Candidate subsampling factors, largest first.
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # NOTE(review): "stride" is reported as kernel_size, but forward() evaluates
        # the window at every position (dense output) — confirm the intended label.
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        # Lazily derive the window on the first call by scaling base_size with the
        # ratio of this input's H/W to train_size's H/W. After this, kernel_size is
        # set and reused for every later input, whatever its size.
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # NOTE(review): if kernel_size and base_size are both unset, the next line
        # raises TypeError — callers presumably always supply one of them.
        # Window covers the whole map: this is plain global average pooling.
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            # NOTE(review): unreachable — the identical condition already returned above.
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Largest factor in self.rs dividing h (resp. w) ...
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint: ... capped by max_r1 / max_r2
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                # Cumulative sums of the subsampled map, window sums by inclusion–exclusion.
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                # Upsample back by the subsampling factors (interpolate's default mode, 'nearest').
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            # Integral image over H and W.
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            # Window sum = s4 + s1 - s2 - s3 (four corners); the result is
            # (h - k1 + 1) x (w - k2 + 1), i.e. "valid" pooling.
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # Split the size difference between the two sides (the odd pixel goes
            # right/bottom) and replicate edge values to bring out back to x's H x W.
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every ``nn.AdaptiveAvgPool2d`` in ``model`` for an ``AvgPool2d``.

    The replacement is built from ``base_size``, ``train_size`` and ``fast_imp``.
    ``model`` is modified in place. Extra ``kwargs`` are only forwarded to the
    recursive calls. An AdaptiveAvgPool2d whose ``output_size`` is not 1 triggers
    an AssertionError.
    """
    for child_name, child in model.named_children():
        if any(True for _ in child.children()):
            # Container module: descend into it before checking the child itself.
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue
        replacement = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        assert child.output_size == 1
        setattr(model, child_name, replacement)

class ICB(nn.Module):
    """
    Instruction Condition Block (ICB)
    Paper Section 3.3

    Conditions image features ``x`` on a text embedding:
      1. a learned affine modulation ``x * gamma + beta`` (both zero-initialised);
      2. a per-channel sigmoid gate computed from the text embedding;
      3. a NAFBlock;
      4. a residual connection back to ``x``.
    """

    def __init__(self, feature_dim, text_dim=768):
        super().__init__()
        # Keep creation order (fc, block, beta, gamma): it fixes parameter order and RNG use.
        self.fc = nn.Linear(text_dim, feature_dim)
        self.block = NAFBlock(feature_dim)
        self.beta = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)

    def forward(self, x, text_embedding):
        # Gate in (0, 1), shaped (..., feature_dim, 1, 1) to broadcast over H and W
        # (assumes text_embedding is (B, text_dim) — TODO confirm).
        gate = torch.sigmoid(self.fc(text_embedding))[..., None, None]
        modulated = (x * self.gamma + self.beta) * gate
        return self.block(modulated) + x


class InstructIR(nn.Module):
    """
    InstructIR model using NAFNet (ECCV 2022) as backbone.
    The model takes as input an RGB image and a text embedding (encoded instruction).
    Described in Paper Section 3.3

    Same U-Net layout as NAFNet. In addition, an ICB follows every encoder and
    decoder stage and conditions that stage's features on the text embedding;
    the bottleneck blocks are not text-conditioned.

    Args:
        img_channel: Image channels in and out.
        width: Channels after the intro conv.
        middle_blk_num: Number of NAFBlocks in the bottleneck.
        enc_blk_nums: NAFBlocks per encoder stage (sequence of ints).
        dec_blk_nums: NAFBlocks per decoder stage (sequence of ints).
        txtdim: Size of the text embedding fed to every ICB.
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=(), txtdim=768):
        super().__init__()

        self.intro  = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                              bias=True)

        self.encoders    = nn.ModuleList()
        self.decoders    = nn.ModuleList()
        # Placeholder overwritten below. It is kept so that the submodule
        # registration order (and hence parameters()/state_dict order) stays
        # the same as the original layout.
        self.middle_blks = nn.ModuleList()
        self.ups         = nn.ModuleList()
        self.downs       = nn.ModuleList()
        self.enc_cond    = nn.ModuleList()
        self.dec_cond    = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            # Text conditioning after each encoder stage.
            self.enc_cond.append(ICB(chan, txtdim))
            # 2x2 stride-2 conv: halves H and W, doubles channels.
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(*[NAFBlock(chan) for _ in range(middle_blk_num)])

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    # (B, chan*2, H, W) -> (B, chan/2, 2H, 2W)
                    nn.PixelShuffle(2),
                )
            )
            chan = chan // 2
            self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            # Add text embedding as modulation
            self.dec_cond.append(ICB(chan, txtdim))

        # Inputs are padded to a multiple of this so every downsampling is exact.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, txtembd):
        """Restore a 4-D image batch under the instruction embedding ``txtembd``.

        The output is cropped back to the input's H x W.
        """
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)
        encs = []

        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
            x = encoder(x)
            x = enc_mod(x, txtembd)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)
            x = dec_mod(x, txtembd)

        x = self.ending(x)
        # Global residual: the network predicts a correction to its input.
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad bottom/right so H and W are multiples of ``self.padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


def create_model(input_channels=3, width=32, enc_blks=(2, 2, 4, 8), middle_blk_num=12, dec_blks=(2, 2, 2, 2), txtdim=768):
    """Build an InstructIR restoration model.

    Stage configurations default to tuples (not lists) so no mutable default
    object is shared between calls. ``txtdim`` must match the size of the text
    embedding later passed to ``forward``.

    Returns:
        A freshly initialised ``InstructIR``.
    """
    return InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                      enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim)

############======== LANGUAGE MODEL =========############


# Models whose sentence embedding is the attention-masked mean of token embeddings
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}


def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighting each position by ``attention_mask``.

    With a 0/1 mask this is the mean over non-padding tokens. The denominator
    is clamped to 1e-9, so an all-zero mask yields zeros instead of NaN.
    ``model_output[0]`` holds the per-token embeddings.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class LanguageModel(nn.Module):
    """Frozen pretrained Hugging Face text encoder mapping strings to embeddings.

    Pooling depends on the model:
      - names containing "clip": ``get_text_features``;
      - models listed in POOL_MODELS: attention-masked mean pooling + L2 normalisation;
      - anything else: the first token's last hidden state.

    The encoder's parameters are frozen, and the encoder is kept in eval mode.

    Args:
        model: Hugging Face model id passed to AutoTokenizer / AutoModel.
    """
    def __init__(self, model='TaylorAI/bge-micro-v2'):
        super(LanguageModel, self).__init__()
        # Stored as `tokenizer` because forward() reads self.tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model).to(device)
        self.model_name = model
        # Remove the CLIP vision tower (only the text side is used)
        if "clip" in self.model_name:
            self.model.vision_model = None

        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def train(self, mode=True):
        """Set this module's mode but keep the frozen encoder in eval mode.

        Without this override, a later ``.train()`` call (e.g. from a parent
        module) would put the encoder back into training-mode behaviour.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, text_batch):
        """Encode a list of strings into a 2-D (batch, dim) embedding tensor."""
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():  # the encoder is frozen: no gradients needed
            if "clip" in self.model_name:
                return self.model.get_text_features(**inputs)

            outputs = self.model(**inputs)

            if any(pool_name in self.model_name for pool_name in POOL_MODELS):
                sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
                # Normalize embeddings
                return F.normalize(sentence_embeddings, p=2, dim=1)
            return outputs.last_hidden_state[:, 0, :]
    

class LMHead(nn.Module):
    """Trainable head on top of frozen sentence embeddings.

    ``fc1`` projects to ``hidden_dim`` and the result is L2-normalised; this
    is the instruction embedding. ``fc2`` then maps it to ``num_classes``
    task logits.

    Returns ``(normalised_projection, logits)`` from forward().
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        projected = F.normalize(self.fc1(x), p=2, dim=1)
        logits = self.fc2(projected)
        return projected, logits


###############=============PROMPT DATALOADER============##############
class TaskPromptDataset(Dataset):
    """Prompts for a single task, each paired with that task's integer id.

    Items are ``(stripped_prompt, task_id)``, where ``task_id`` is a scalar
    ``torch.long`` tensor.
    """
    def __init__(self, prompts, task_id):
        self.samples = [(text.strip(), task_id) for text in prompts]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        prompt, task_id = self.samples[idx]
        label = torch.tensor(task_id, dtype=torch.long)
        return prompt, label

class MultiTaskPromptLoader:
    """Dictionary-style loader that provides batched prompts per task.

    Reads a JSON object mapping task keys ('blur', 'haze', 'noise', 'lol',
    'rain') to lists of prompt strings, and builds one shuffled DataLoader per
    task. Each batch is ``(list_of_texts, LongTensor_of_task_ids)``.
    """
    def __init__(self, json_path, batch_size=16):
        raw_data = self._load_json(json_path)

        # Map between task names, JSON keys, and task IDs
        self.task_config = {
            'blur': {'json_key': 'blur', 'id': 2},
            'haze': {'json_key': 'haze', 'id': 4},
            'noise': {'json_key': 'noise', 'id': 0},
            'lol': {'json_key': 'lol', 'id': 3},
            'rain': {'json_key': 'rain', 'id': 1}
        }

        self.loaders = {}
        for task_name, config in self.task_config.items():
            # A missing key yields an empty dataset. Because drop_last=True, a task
            # with fewer than batch_size prompts produces no batches at all.
            prompts = raw_data.get(config['json_key'], [])

            dataset = TaskPromptDataset(prompts, config['id'])
            loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=4,
                collate_fn=self._collate_fn,
                pin_memory=True,
                drop_last=True
            )
            self.loaders[task_name] = loader

        self.iterators = {task: iter(loader) for task, loader in self.loaders.items()}

    @staticmethod
    def _load_json(json_path):
        """Parse the prompt file.

        JSON text is UTF-8, so it is decoded explicitly rather than with the
        platform's default encoding.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _collate_fn(self, batch):
        """Keep texts as a Python list and stack the labels into one tensor."""
        texts, labels = zip(*batch)
        return list(texts), torch.stack(labels)

    def get_batch(self, key):
        """Get the next batch for task ``key``, restarting its iterator when exhausted."""
        try:
            return next(self.iterators[key])
        except StopIteration:
            self.iterators[key] = iter(self.loaders[key])
            return next(self.iterators[key])
        
        
###############=============IMAGE DATALOADER============###############
class PairedTransform:
    """Apply identical random augmentation to an (input, target) PIL image pair.

    If the input is smaller than ``size`` in either dimension, both images are
    resized to exactly ``size``. Otherwise the same random ``size`` crop is taken
    from both. The pair is then flipped horizontally and vertically, each with
    probability 0.5, and converted to tensors with TF.to_tensor.

    Args:
        size: Output (height, width).
    """
    def __init__(self, size=(256, 256)):
        self.size = size

    def __call__(self, input_img, target_img):
        out_h, out_w = self.size[0], self.size[1]
        if input_img.height < out_h or input_img.width < out_w:
            input_img, target_img = TF.resize(input_img, self.size), TF.resize(target_img, self.size)
        else:
            # One set of crop coordinates, shared by both images.
            top, left, crop_h, crop_w = transforms.RandomCrop.get_params(input_img, self.size)
            input_img = TF.crop(input_img, top, left, crop_h, crop_w)
            target_img = TF.crop(target_img, top, left, crop_h, crop_w)

        # Horizontal flip, then vertical flip, each drawn independently.
        for flip in (TF.hflip, TF.vflip):
            if torch.rand(1) < 0.5:
                input_img, target_img = flip(input_img), flip(target_img)

        return TF.to_tensor(input_img), TF.to_tensor(target_img)

class TaskDataset(Dataset):
    """Paired (degraded, clean) image dataset for a single restoration task.

    Expects ``root_dir/input`` and ``root_dir/target`` to contain files with
    matching names. Items are ``(input, target)``: either ``transform``'s
    output or two tensors from TF.to_tensor.
    """
    def __init__(self, root_dir, transform=None):
        self.input_dir = os.path.join(root_dir, 'input')
        self.target_dir = os.path.join(root_dir, 'target')
        # os.listdir order is arbitrary; sort for a reproducible index -> file mapping.
        self.filenames = sorted(
            f for f in os.listdir(self.input_dir)
            if os.path.isfile(os.path.join(self.input_dir, f))
        )
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        input_img = self._load_rgb(os.path.join(self.input_dir, name))
        target_img = self._load_rgb(os.path.join(self.target_dir, name))

        if self.transform:
            return self.transform(input_img, target_img)
        return TF.to_tensor(input_img), TF.to_tensor(target_img)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB, and release the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

class MultiTaskLoader:
    """Dictionary-style access to one shuffled image-pair DataLoader per task.

    ``loader[task]`` returns that task's DataLoader. ``get_batch(task)`` returns
    the next batch and restarts the task's iterator once it is exhausted.
    """
    def __init__(self, root_dir, batch_size=16, transform=None):
        task_dirs = {
            'blur': 'deblurring_dataset',
            'haze': 'dehazing_dataset',
            'lol': 'lol_dataset',
            'noise': 'noise_dataset',
            'rain': 'rainy_image_dataset',
        }
        self.tasks = {
            task: TaskDataset(os.path.join(root_dir, subdir), transform)
            for task, subdir in task_dirs.items()
        }

        self.loaders = {}
        for task, dataset in self.tasks.items():
            self.loaders[task] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=4,
                pin_memory=True,
                drop_last=True,
            )
        self.iterators = {task: iter(loader) for task, loader in self.loaders.items()}

    def __getitem__(self, key):
        return self.loaders[key]

    def get_batch(self, key):
        """Get a batch from specified task"""
        try:
            return next(self.iterators[key])
        except StopIteration:
            pass
        # Exhausted: start a fresh (reshuffled) pass over this task.
        self.iterators[key] = iter(self.loaders[key])
        return next(self.iterators[key])
        

######################====== ACCURACY CALCULATION ======#############################
def calculate_accuracy(prompt_loader, language_model, lm_head, device):
    """Return the LM head's task-classification accuracy over all prompts, in percent.

    Every batch of every per-task loader in ``prompt_loader.loaders`` is
    encoded with ``language_model`` and classified by ``lm_head``. Returns 0.0
    when no samples are seen. Leaves ``lm_head`` in eval mode.
    """
    lm_head.eval()
    correct = 0
    seen = 0

    with torch.no_grad():
        for loader in prompt_loader.loaders.values():
            # Iterating the loader directly starts a fresh pass for each task.
            for texts, labels in loader:
                labels = labels.to(device)
                _, logits = lm_head(language_model(texts))
                correct += (logits.argmax(dim=1) == labels).sum().item()
                seen += labels.shape[0]

    if seen == 0:
        return 0.0
    return correct / seen * 100

######################====== PSNR CALCULATION ======#############################
# Fixed prompts for evaluation: one instruction per task key of the image loader.
FIXED_PROMPTS = {
    'noise': ["Help me with my picture, it's full of tiny spots."],
    'blur': ["Please, clean up this blurry photo."],
    'rain': ["Remove the streaks of falling rain from my photo."],
    'lol': ["Brighten the dark regions in this image without overexposing the highlights."],
    'haze': ["Remove the atmospheric haze from this image."]
}

def evaluate_model(image_model, lm_head, image_loader, device, language_model=None):
    """Calculate PSNR across all tasks using fixed prompts.

    For each task in ``image_loader.tasks``:
      1. the task's FIXED_PROMPTS instruction is encoded;
      2. one batch is drawn with ``image_loader.get_batch``;
      3. the batch is restored and scored against its targets.
    Outputs are clamped to [0, 1], and PSNR uses a fixed ``data_range=1.0``.
    A table is printed.

    Args:
        image_model: Restoration model called as ``image_model(images, text_embedding)``.
        lm_head: Head returning ``(text_embedding, logits)``.
        image_loader: Object exposing ``tasks`` and ``get_batch(task)``.
        device: Device that images are moved to.
        language_model: Frozen text encoder (list of strings -> embeddings).
            Defaults to the module-level ``language_model`` created in the
            ``__main__`` section, for backward compatibility.

    Returns:
        List of per-task PSNR values (dB), in ``image_loader.tasks`` order.
    """
    if language_model is None:
        language_model = globals().get("language_model")
        if language_model is None:
            raise NameError("evaluate_model() needs `language_model`: pass it explicitly "
                            "or define a module-level `language_model`")

    image_model.eval()
    lm_head.eval()

    psnr_results = []
    table_data = []

    with torch.no_grad():
        for task in image_loader.tasks.keys():
            # Encode the task's fixed instruction (a batch of one prompt).
            embeddings = language_model(FIXED_PROMPTS[task])
            text_embd, _ = lm_head(embeddings)

            inputs, targets = image_loader.get_batch(task)
            inputs = inputs.to(device)
            targets = targets.to(device)

            # The residual output can leave the valid range; clamp before scoring.
            outputs = image_model(inputs, text_embd).clamp(0.0, 1.0)

            # Fixed data range: the dataset code converts images with TF.to_tensor
            # ([0, 1]). Without it, PSNR would use the batch's target max - min.
            psnr = peak_signal_noise_ratio(outputs, targets, data_range=1.0).item()
            psnr_results.append(psnr)
            table_data.append([task, f"{psnr:.2f} dB"])

    # Print formatted table
    print("\n" + tabulate(table_data, headers=["Task", "PSNR"], tablefmt="grid"))
    return psnr_results

def save_visuals(image_model, lm_head, image_loader, epoch, device):
    """Save an input/output grid to ``VIS_DIR/epoch_{epoch}.png``.

    The grid has one row per task in ``image_loader.tasks``: column 0 holds
    the first degraded input of a fresh batch, and column 1 holds the model
    output for it. The output is conditioned on that task's FIXED_PROMPTS
    entry through the module-level ``language_model`` and ``lm_head``.

    Args:
        image_model: restoration network, called as ``image_model(inputs, text_embd)``.
        lm_head: head returning ``(text_embd, logits)`` for prompt embeddings.
        image_loader: loader exposing ``tasks`` and ``get_batch(task)``.
        epoch: epoch number, used in the output file name.
        device: device the image batches are moved to.
    """
    image_model.eval()
    lm_head.eval()

    tasks = list(image_loader.tasks.keys())
    if not tasks:
        print("No tasks to visualize; skipping")
        return

    # One 5-inch-tall row per task (matches the old 10x25 figure for 5 tasks).
    # squeeze=False keeps `axs` 2-D even when there is only a single task.
    fig, axs = plt.subplots(len(tasks), 2, figsize=(10, 5 * len(tasks)), squeeze=False)
    try:
        with torch.no_grad():
            for row, task in enumerate(tasks):
                text_embd, _ = lm_head(language_model(FIXED_PROMPTS[task]))

                inputs, _ = image_loader.get_batch(task)
                inputs = inputs.to(device)
                outputs = image_model(inputs, text_embd)

                # CHW tensor -> HWC array for imshow. Clamping to [0, 1]
                # matches what imshow would clip to anyway, but without the
                # warning (assumes images live in [0, 1] — TODO confirm).
                input_img = inputs[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
                output_img = outputs[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0)

                axs[row, 0].imshow(input_img)
                axs[row, 0].set_title(f"{task} Input")
                axs[row, 1].imshow(output_img)
                axs[row, 1].set_title(f"{task} Output")

        fig.tight_layout()
        vis_path = os.path.join(VIS_DIR, f"epoch_{epoch}.png")
        fig.savefig(vis_path)
    finally:
        # Always release the figure, even when a forward pass fails, so long
        # training runs do not accumulate open matplotlib figures.
        plt.close(fig)
    print(f"Saved visualization to {vis_path}")



###############========Language Model Head TRAINING LOOP===========################

if __name__ == "__main__":
    # Tasks sampled during training. random.choice draws from this exact
    # sequence, so keep the order fixed to stay reproducible for a given seed.
    TRAIN_TASKS = ['blur', 'haze', 'noise', 'lol', 'rain']
    # Each "epoch" draws one random batch per batch_size-sized slice of this.
    SAMPLES_PER_EPOCH = 2000
    # Epoch interval for checkpoints/visualisations (the last epoch is always saved).
    CHECKPOINT_EVERY = 50

    transform = PairedTransform(size=(256, 256))
    image_loader = MultiTaskLoader(root_dir='/home/cvpr_ug_2/abhinavDomainIR/Dataset', batch_size=batch_size, transform=transform)
    prompt_loader = MultiTaskPromptLoader('/home/cvpr_ug_2/abhinavDomainIR/prompts.json', batch_size=batch_size)
    ######################====== LANGUAGE MODEL INITIALIZATION ======#############################
    LMODEL = 'TaylorAI/bge-micro-v2'
    # Kept frozen: it is only ever called under torch.no_grad() and is not in any optimizer.
    language_model = LanguageModel(model=LMODEL).to(device).eval()
    lm_head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=5).to(device)
    # Visualization config (save_visuals reads VIS_DIR as a module global).
    VIS_DIR = os.path.join(checkpoint_dir, "visualizations")
    os.makedirs(VIS_DIR, exist_ok=True)  # also creates checkpoint_dir for torch.save below
    ######################====== IMAGE MODEL INITIALIZATION ======#############################
    image_model = create_model(input_channels=3, width=32, enc_blks=[2, 2, 4, 8], middle_blk_num=4, dec_blks=[2, 2, 2, 2], txtdim=256)
    image_model = image_model.to(device)

    ##############=======Parameter Count===========##########
    def count_trainable_parameters(model):
        """Return the number of parameters in ``model`` that have ``requires_grad=True``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    trainable_params = count_trainable_parameters(image_model)
    print(f"Total trainable parameters: {trainable_params:,}")

    ######################====== Loss functions ======#############################
    criterion_class = nn.CrossEntropyLoss()
    criterion_image = nn.L1Loss()

    ######################====== Stage 1: LM head only ======#############################
    optimizer_lm = torch.optim.AdamW(lm_head.parameters(), lr=5e-4)

    print("Starting Stage 1: LM Head Training")
    for epoch in range(1, num_epochs_lm + 1):
        lm_head.train()
        for _ in range(0, SAMPLES_PER_EPOCH, batch_size):
            task = random.choice(TRAIN_TASKS)
            texts, labels = prompt_loader.get_batch(task)
            labels = labels.to(device)

            # Frozen language model: no autograd graph is needed through it.
            with torch.no_grad():
                embeddings = language_model(texts)

            _, logits = lm_head(embeddings)
            lm_loss = criterion_class(logits, labels)

            optimizer_lm.zero_grad()
            lm_loss.backward()
            optimizer_lm.step()

        train_acc = calculate_accuracy(prompt_loader, language_model, lm_head, device)
        print(f"Epoch [{epoch}/{num_epochs_lm}] | Train Acc: {train_acc:.2f}%")

    ######################====== Stage 2: joint training ======#############################
    # lm_head continues from its Stage-1 weights; one AdamW covers both models.
    optimizer = torch.optim.AdamW([
        {'params': lm_head.parameters()},
        {'params': image_model.parameters()},
    ], lr=5e-4)
    # Stepped once per epoch, so T_max=num_epochs spans the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    print("Starting Stage 2: Main Training")
    for epoch in range(1, num_epochs + 1):
        lm_head.train()
        image_model.train()

        for _ in range(0, SAMPLES_PER_EPOCH, batch_size):
            # Prompts and images are drawn for the same task so that the class
            # labels correspond to the degradation present in the images.
            task = random.choice(TRAIN_TASKS)
            texts, labels = prompt_loader.get_batch(task)
            inputs, targets = image_loader.get_batch(task)

            labels = labels.to(device)
            inputs = inputs.to(device)
            targets = targets.to(device)

            with torch.no_grad():
                embeddings = language_model(texts)

            text_embd, logits = lm_head(embeddings)
            outputs = image_model(inputs, text_embd)

            loss_image = criterion_image(outputs, targets)
            loss_class = criterion_class(logits, labels)
            # Restoration loss plus an auxiliary prompt-classification loss (weight 0.3).
            total_loss = loss_image + 0.3 * loss_class

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        scheduler.step()

        print(f"\nEpoch {epoch} Evaluation:")
        train_acc = calculate_accuracy(prompt_loader, language_model, lm_head, device)
        print(f"Classification Accuracy: {train_acc:.2f}%")
        evaluate_model(image_model, lm_head, image_loader, device)  # prints the PSNR table

        if epoch % CHECKPOINT_EVERY == 0 or epoch == num_epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f"joint_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'image_model': image_model.state_dict(),
                'lm_head': lm_head.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, checkpoint_path)

            save_visuals(image_model, lm_head, image_loader, epoch, device)
            print(f"Saved checkpoint and visuals for epoch {epoch}")


In [None]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
from tabulate import tabulate
import csv

# Pick the GPU whenever CUDA is available and fall back to the CPU otherwise
# (the same rule the training script uses).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load your existing components
from main import (
    LanguageModel,
    LMHead,
    create_model,
    PairedTransform
)

class EvalDataset(Dataset):
    """Paired (degraded input, clean target) images matched by file name.

    Each regular file in ``input_dir`` is paired with the file of the same
    name in ``target_dir``. Items are ``(input_tensor, target_tensor, filename)``.
    """

    def __init__(self, input_dir, target_dir):
        self.input_dir = input_dir
        self.target_dir = target_dir
        # os.listdir order is arbitrary; sort so evaluation order (and any
        # per-file output) is reproducible across machines and runs.
        self.filenames = sorted(
            f for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
        )

        # Same transform as training.
        # NOTE(review): if PairedTransform resizes to 256x256 or applies random
        # crops/flips, metrics are not computed at native resolution and may
        # vary between runs — confirm before reporting benchmark numbers.
        self.transform = PairedTransform(size=(256, 256))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        input_img = self._load_rgb(os.path.join(self.input_dir, name))
        target_img = self._load_rgb(os.path.join(self.target_dir, name))

        if self.transform:
            input_tensor, target_tensor = self.transform(input_img, target_img)
        else:
            to_tensor = transforms.ToTensor()
            input_tensor = to_tensor(input_img)
            target_tensor = to_tensor(target_img)

        return input_tensor, target_tensor, name

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image, closing the underlying file promptly."""
        with Image.open(path) as img:
            return img.convert('RGB')

class ModelEvaluator:
    """Load a joint checkpoint and score the restoration model on paired test sets.

    Args:
        checkpoint_path: file saved by the training loop; must contain the
            ``'image_model'`` and ``'lm_head'`` state dicts.
        data_range: dynamic range passed to PSNR and SSIM. The default of 1.0
            assumes tensors in [0, 1] (TODO confirm against PairedTransform).
            ``None`` restores the previous behaviour, in which torchmetrics
            infers the range from each image. Scores then depend on the
            image's own contrast.
    """

    def __init__(self, checkpoint_path, data_range=1.0):
        self.data_range = data_range

        # Architecture must match the training configuration exactly for the
        # state dicts below to load.
        self.language_model = LanguageModel(model='TaylorAI/bge-micro-v2').to(device).eval()
        self.lm_head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=5).to(device)
        self.image_model = create_model(
            input_channels=3, width=32, enc_blks=[2, 2, 4, 8],
            middle_blk_num=4, dec_blks=[2, 2, 2, 2], txtdim=256,
        ).to(device)

        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.image_model.load_state_dict(checkpoint['image_model'])
        self.lm_head.load_state_dict(checkpoint['lm_head'])
        self.image_model.eval()
        self.lm_head.eval()

        # Same wording as the training notebook's fixed prompts, keyed by the
        # task names used in the evaluation config.
        self.task_prompts = {
            'denoising': ["Help me with my picture, it's full of tiny spots."],
            'deblurring': ["Please, clean up this blurry photo."],
            'deraining': ["Remove the streaks of falling rain from my photo."],
            'lowlight': ["Brighten the dark regions in this image without overexposing the highlights."],
            'dehazing': ["Remove the atmospheric haze from this image."]
        }

    def get_task_embeddings(self, task_name):
        """Return ``(text_embd, logits)`` from ``lm_head`` for the task's fixed prompt.

        Raises:
            KeyError: if ``task_name`` has no entry in ``task_prompts``.
        """
        texts = self.task_prompts[task_name]
        with torch.no_grad():
            embeddings = self.language_model(texts)
            text_embd, logits = self.lm_head(embeddings)
        return text_embd, logits

    def evaluate_dataset(self, input_dir, target_dir, task_name):
        """Return the mean ``(PSNR, SSIM)`` over one paired input/target directory.

        Images are processed one at a time. The same prompt embedding is used
        for every image. Returns ``(nan, nan)`` when ``input_dir`` holds no files.
        """
        dataset = EvalDataset(input_dir, target_dir)
        if len(dataset) == 0:
            return float('nan'), float('nan')

        loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
        text_embd, _ = self.get_task_embeddings(task_name)

        psnr_values = []
        ssim_values = []
        with torch.no_grad():
            for inputs, targets, _ in loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = self.image_model(inputs, text_embd)
                if self.data_range is not None:
                    # Score the output as it would be saved: clipped to the valid range.
                    outputs = outputs.clamp(0.0, self.data_range)

                psnr_values.append(
                    peak_signal_noise_ratio(outputs, targets, data_range=self.data_range).item())
                ssim_values.append(
                    structural_similarity_index_measure(outputs, targets, data_range=self.data_range).item())

        return np.mean(psnr_values), np.mean(ssim_values)

def generate_report(checkpoint_path, dataset_paths, output_file="evaluation_report.csv"):
    """Evaluate a checkpoint on every configured test set and write a CSV report.

    Args:
        checkpoint_path: checkpoint passed to ``ModelEvaluator``.
        dataset_paths: mapping of task name to a list of
            ``[target_dir, input_dir, tag?]`` entries. The optional third
            element is written to the "Noise Level" column.
        output_file: CSV destination (overwritten).

    Returns:
        The list of result rows written to the CSV. Each row is a dict keyed
        by the column names.
    """
    evaluator = ModelEvaluator(checkpoint_path)
    fieldnames = ['Task', 'Dataset', 'Noise Level', 'PSNR', 'SSIM']
    results = []

    for task_name, datasets in dataset_paths.items():
        for target_dir, input_dir, *extra in datasets:
            noise_level = extra[0] if extra else ""
            print(f"Evaluating {task_name} - {noise_level if noise_level else ''}")

            psnr, ssim = evaluator.evaluate_dataset(input_dir, target_dir, task_name)
            results.append({
                'Task': task_name,
                'Dataset': os.path.basename(input_dir),
                'Noise Level': noise_level,
                'PSNR': f"{psnr:.2f}",
                'SSIM': f"{ssim:.4f}"
            })

    # The csv module requires newline='' so rows are not separated by blank
    # lines on platforms that translate "\n".
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    table_data = [[r[col] for col in fieldnames] for r in results]
    print("\nEvaluation Summary:")
    print(tabulate(table_data, headers=fieldnames, tablefmt="grid"))
    return results

# Configuration
# Dataset paths: task name -> list of [target_dir, input_dir, tag] entries.
# generate_report unpacks each entry as (target_dir, input_dir, tag). The tag
# is written to the report's "Noise Level" column.
# Change TEST_DATA_ROOT to relocate every test set at once.
TEST_DATA_ROOT = "/kaggle/input/image-restoration-test-data/test-data"
_DENOISE_ROOT = f"{TEST_DATA_ROOT}/denoising_testsets"

dataset_paths = {
    # Each clean set is paired with its noisy copies at sigma 15, 25 and 50.
    "denoising": [
        [f"{_DENOISE_ROOT}/{name}", f"{_DENOISE_ROOT}/{name}_{sigma}", sigma]
        for name in ("CBSD68", "Kodak24", "urban100")
        for sigma in ("15", "25", "50")
    ],
    "deblurring": [
        [f"{TEST_DATA_ROOT}/GoPro/target", f"{TEST_DATA_ROOT}/GoPro/input", "gopro"],
    ],
    "deraining": [
        [f"{TEST_DATA_ROOT}/Rain100L/target", f"{TEST_DATA_ROOT}/Rain100L/input", "rain100l"],
    ],
    "lowlight": [
        [f"{TEST_DATA_ROOT}/LOL/high", f"{TEST_DATA_ROOT}/LOL/low", "lol"],
    ],
    "dehazing": [
        [f"{TEST_DATA_ROOT}/SOTS/GT", f"{TEST_DATA_ROOT}/SOTS/IN", "sots"],
    ],
}

In [None]:
if __name__ == "__main__":
    # Joint checkpoint written by the training loop at epoch 400; point this
    # at a different joint_epoch_*.pth to evaluate another snapshot.
    checkpoint_path = "/home/cvpr_ug_2/abhinavinstructir/checkpoint/joint_epoch_400.pth"
    generate_report(checkpoint_path=checkpoint_path, dataset_paths=dataset_paths)