In [None]:
import os
import random
from typing import Dict, List, Optional

import numpy as np
import torch
import torch
import torch.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
    


In [1]:
import os
import torch
import psutil
from util import format
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import DeepLenseSuperresolutionDataset
import matplotlib.pyplot as plt 
from torchinfo import summary
from util import MSE_Metric, PSNR_Metric, SSIM_Metric
import math

# PyTorch imports
import torch.nn as nn
import torch.nn.functional as F

class CONFIG:
    """Run-wide settings for the DeepLense 2024 task 2A super-resolution experiments."""

    # --- training / evaluation -------------------------------------------
    BATCH_SIZE = 4
    PORTION_OF_DATA_FOR_TRAINING = 0.8

    # Cap on the number of samples loaded, so prototyping iterations stay fast.
    DATA_LIMIT = 200

    # --- hardware: use a CUDA GPU whenever PyTorch can see one -------------
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- experiment bookkeeping ------------------------------------------
    ROUND_NUMBER = 3
    TASK_NAME = "DeepLense2024_task2A"

    # Dataset root, relative to the notebook's working directory.
    DATA_PATH = os.path.join("Data", "Superresolution")

In [None]:
class DeepLenseSuperresolutionDatasetLapSRN(Dataset):
    """Paired low-/high-resolution image dataset held entirely in memory.

    ``folder_path`` must contain a sub-folder whose name ends in ``"LR"`` and
    one whose name ends in ``"HR"``. Every file in the LR folder must have a
    same-named counterpart in the HR folder; each file is read with
    ``np.load`` and converted with ``torch.Tensor``. ``__getitem__`` returns
    the pair ``(LR_tensor, HR_tensor)``.

    Args:
        folder_path: directory containing the LR and HR sub-folders.
        randomize_dataset: if True, shuffle the sample order once after
            loading (uses the stdlib ``random`` module, so ``random.seed``
            controls it).
        preprocess_LR, preprocess_HR: stored on the instance; this class does
            not consult them itself.
        call_preprocess: accepted for interface compatibility; unused here.
        data_limit: if > 0, keep only the first ``data_limit`` samples in
            sorted file-name order (for faster prototyping).
        mean_LR, std_LR, mean_HR, std_HR: optional normalisation statistics,
            stored on the instance; they are not applied automatically.

    Raises:
        FileNotFoundError: if no LR or HR sub-folder exists.
        AssertionError: if the LR and HR folders hold different file names.
        ValueError: if no samples are found.
    """

    def __init__(self, folder_path : str,
                 randomize_dataset : bool = True,
                 preprocess_LR : bool = True,
                 preprocess_HR : bool = True,
                 call_preprocess : bool = True,
                 data_limit=0,
                 mean_LR = None, std_LR=None,
                 mean_HR=None, std_HR=None) -> None:

        self.folder_path = folder_path
        self.preprocess_LR = preprocess_LR
        self.preprocess_HR = preprocess_HR
        # Keep the normalisation statistics rather than silently dropping them.
        self.mean_LR, self.std_LR = mean_LR, std_LR
        self.mean_HR, self.std_HR = mean_HR, std_HR

        # Sorted so the LR/HR folder choice does not depend on filesystem order.
        folders = sorted(os.path.join(folder_path, v) for v in os.listdir(folder_path))
        self.LR = self._find_subfolder(folders, "LR")
        self.HR = self._find_subfolder(folders, "HR")
        self.class_folders = [self.LR, self.HR]

        print(self.LR, self.HR)

        # os.listdir order is arbitrary, so compare sorted names: pairing is by
        # file name, and both folders must hold exactly the same files.
        lr_files = sorted(os.listdir(self.LR))
        hr_files = sorted(os.listdir(self.HR))
        assert lr_files == hr_files, "Low Resolution and High Resolution folders must contain the same sample files"

        # Sorted order makes the data_limit subset identical on every machine.
        self.samples = lr_files
        if data_limit > 0:
            self.samples = self.samples[:data_limit]
        if not self.samples:
            raise ValueError(f"no samples found in {self.LR}")

        LR_data, HR_data = [], []
        for name in tqdm(self.samples, desc="Loading dataset : "):
            LR_data.append(torch.Tensor(np.load(os.path.join(self.LR, name))))
            HR_data.append(torch.Tensor(np.load(os.path.join(self.HR, name))))

        self.samples = np.array(self.samples)
        self.LR_data = torch.stack(LR_data)
        self.HR_data = torch.stack(HR_data)

        if randomize_dataset:
            self.randomize_dataset()

    @staticmethod
    def _find_subfolder(folders, suffix):
        """Return the first path in ``folders`` that ends with ``suffix``."""
        for path in folders:
            if path.endswith(suffix):
                return path
        raise FileNotFoundError(f"no sub-folder ending in '{suffix}' found among {folders}")

    # To override later (if any preprocessing is required)
    def preprocess_LR_func(self, x : torch.Tensor) -> torch.Tensor:
        """Identity hook for low-resolution preprocessing; override in subclasses."""
        return x

    def preprocess_HR_func(self, x : torch.Tensor) -> torch.Tensor:
        """Identity hook for high-resolution preprocessing; override in subclasses."""
        return x

    def randomize_dataset(self):
        """Shuffle samples, LR data and HR data with one shared permutation."""
        idxes = np.arange(len(self.LR_data))
        random.shuffle(idxes)

        # The same permutation is applied to all three so LR/HR pairs stay aligned.
        self.samples = self.samples[idxes]
        self.LR_data = self.LR_data[idxes]
        self.HR_data = self.HR_data[idxes]

    def preprocess_input(self, x : np.ndarray, mean=None, std=None) -> torch.Tensor:
        """Return ``(x - mean) / std`` as a float32 tensor.

        ``mean``/``std`` default to ``self.mean``/``self.std`` when those
        attributes have been set externally (this class never sets them).

        Raises:
            ValueError: if neither an argument nor an attribute provides mean/std.
        """
        if mean is None:
            mean = getattr(self, "mean", None)
        if std is None:
            std = getattr(self, "std", None)
        if mean is None or std is None:
            raise ValueError("preprocess_input needs mean and std: pass them or set self.mean / self.std")
        return torch.tensor((x - mean) / std).float()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.LR_data[idx], self.HR_data[idx]
    

In [None]:
def test_task2(model : nn.Module, val_dataset : DataLoader, cfg : CONFIG, metrics : List[Metric], test_params : Dict[str, int] = {"save_in_total" : None, "save_every" : 0} ,  run = None):        
    # change the model to evaluation
    model.eval()
    
    # get the number of datapoints
    number_of_datapoints = len(val_dataset.dataset)    

    # allocate the memory for these datapoints (no need to keep appending the data, which will make it slower)
    metrics_names = [metric.name for metric in metrics]
    metrics_vals = np.zeros((number_of_datapoints, len(metrics_names)))

    save_images_every = 0
    if "save_in_total" in test_params and test_params["save_in_total"] is not None and test_params["save_in_total"] > 0:
        save_images_every = number_of_datapoints // test_params["save_in_total"]
        
    elif "save_every" in test_params:
        save_images_every = test_params["save_every"]

    if save_images_every > 0:
        shape = val_dataset.dataset[0][1].shape
        
        saved_images_pred = np.zeros((number_of_datapoints // save_images_every, shape[1], shape[2], shape[0]))
        saved_images_true = np.zeros((number_of_datapoints // save_images_every, shape[1], shape[2], shape[0]))
        img_c = 0
    
    # get the number of batches
    dataset_len = len(val_dataset)

    # create the progreess bar 
    pbar = tqdm(val_dataset)

    # variable that will track where we are in terms of all data (after iteration add batch size to it)
    c = 0
    for i, (x,y) in enumerate(pbar): 
        # get the predictions
        pred = model(x.to(CONFIG.DEVICE))
        y = y.to(CONFIG.DEVICE)
 
        # get the batch size
        bs = x.shape[0]

        # calculate the metric for every image in the batch:
        for img_i in range(bs):
            y_pred, y_ = pred[img_i], y[img_i]
            for j, metric in enumerate(metrics):
                metrics_vals[c, j] = metric.eval(torch.stack([y_pred]), torch.stack([y_]))
                                    
            if save_images_every > 0 and c % save_images_every == 0:
                saved_images_pred[img_c] = y_pred.detach().cpu().numpy().transpose(1, 2, 0)
                saved_images_true[img_c] = y_.detach().cpu().numpy().transpose(1, 2, 0)
                img_c += 1
                
            c += 1
                  
        if i % max((dataset_len//10),1) == 0 or i == dataset_len -1:
            s = ""

            for i,metric in enumerate(metrics):
                if metric.average:
                    s += f"{metric.name}={np.mean(metrics_vals[:(c-1), i])} ; "

            pbar.set_description(f"examples seen so far : {c} " + s)
 
    ret = {}
    
    for i,metric in enumerate(metrics):
        ret[metric.name] = metrics_vals[:, i]   
        
    if save_images_every > 0:
        ret["img_pred"] = saved_images_pred
        ret["img_true"] = saved_images_true
    
    return ret 

def report_metrics_task2(results : Dict, epoch : int, metrics : "List[Metric]", WANDB_ON : bool = True, prefix="val", run=None) -> Dict:
    """Average per-sample metrics and, if enabled, log them (and saved images) to wandb.

    Args:
        results: output of ``test_task2``: metric name -> per-sample values,
            plus optional ``"img_pred"`` / ``"img_true"`` image stacks.
        epoch: used to label the logged images.
        metrics: objects exposing ``name`` and ``average``; only metrics with
            ``average`` set are reduced and reported.
        WANDB_ON: when False nothing is sent to wandb (neither metrics nor images).
        prefix: prepended to each reported key as ``f"{prefix}_{name}"``.
        run: unused; kept for interface compatibility.

    Returns:
        dict mapping ``f"{prefix}_{metric.name}"`` to the mean value.
    """
    # NOTE(review): `wandb` is not imported in the visible cells; it must be
    # imported elsewhere before calling this with WANDB_ON=True.
    ret = {}
    for metric in metrics:
        if not metric.average:
            continue
        name_to_save = f"{prefix}_{metric.name}"
        avg = np.average(results[metric.name])
        ret[name_to_save] = avg

        if WANDB_ON:
            wandb.log({name_to_save : avg})

    # Images are only logged when wandb logging is enabled.
    if WANDB_ON and "img_pred" in results and "img_true" in results:
        for pred_img, true_img in zip(results["img_pred"], results["img_true"]):
            # side by side along the width axis: prediction | ground truth
            side_by_side = np.concatenate([pred_img, true_img], axis=1)
            img_to_save = wandb.Image(side_by_side, caption="Left: predicted, right : true")
            wandb.log({f"Epoch={epoch}" : img_to_save})

    return ret
 
def run_experiment_task2(train_dataloader : torch.utils.data.DataLoader,