# Group 4 Project Version 1 Submission


## Paper Information and Our Information
### **Paper Title:** SeD Semantic-Aware Discriminator for Image Super-Resolution
### **Paper Description:** 
### Github Repository: [Link](https://github.com/YigitEkin/sed)
<img src="img/framework.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>
In this work, researchers highlight the use of Generative Adversarial Networks (GANs) for image super-resolution tasks, particularly focusing on texture recovery. They note a limitation in existing methods where a single discriminator is employed to teach the super-resolution network the distribution of high-quality real-world images, leading to coarse learning and unexpected output. To address this, they introduce a Semantic-aware Discriminator (SeD), which incorporates image semantics to guide the network in learning fine-grained image distributions.

The SeD leverages image semantics extracted from a trained semantic model, allowing the discriminator to discern real and fake images based on different semantic conditions. By integrating semantic features into the discriminator using spatial cross-attention modules, they aim to enhance the SR network's ability to generate more realistic and visually appealing images. The approach capitalizes on pretrained vision models and extensive datasets to enrich the understanding of image semantics and improve the fidelity of super-resolved images.
         
### **Authors:**  Yigit Ekin and Mustafa Utku Aydogdu
### **Mail:** e270207@metu.edu.tr e270206@metu.edu.tr

## Hyper-parameters of your model

We aim to compare the effect of SeD discriminator with vanilla discriminator. As a result, we have two different training setups. Before reading the hyperparameters, please note that the hyperparameters are the same for both models except for the discriminator part. In addition, the losses used for the model can be seen from the image below where L_s is VGG  perceptual loss, L_p is the pixelwise MSE loss and L_adv is the adverserial loss.

<img src="img/losses.png"> <br/> <br/>
The hyperparameters of the models are as follows:

### Vanilla Discriminator
- **Accelerator**: 'gpu'
- **Device**: 'cuda'
- **PL Trainer**:
  - `max_epochs`: 1000
  - `accelerator`: 'gpu'
  - `log_every_n_steps`: 50
  - `strategy`: DDPStrategy(find_unused_parameters=True)
  - `devices`: Number of available CUDA devices (determined by `torch.cuda.device_count()`)
  - `sync_batchnorm`: True
- **Train Batch Size**: 16
- **Validation Batch Size**: 8
- **Test Batch Size**: 8
- **Image Size**: 256
- **Dataset Module**:
  - `num_workers`: 4
  - `train_batch_size`: 16
  - `val_batch_size`: 8
  - `test_batch_size`: 8
  - **Train Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/dataset_cropped/hr"
    - `image_dir_lr`: "data/dataset_cropped/lr"
    - `downsample_factor`: 4 (downsampling factor for low-resolution images)
    - `mirror_augment_prob`: 0.5 (probability of applying mirroring w.r.t. y axis as a data augmentation)
  - **Validation Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
  - **Test Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
- **Losses**:
  - **VGG**:
    - `weight`: 5e-5 
    - `model_config`:
      - `path`: "pretrained_models/vgg16.pth"
      - `output_layer_idx`: 23 (index of the layer to extract features from)
      - `resize_input`: False
  - **Adversarial_G**:
    - `weight`: 1.0
  - **MSE**:
    - `weight`: 1.0
  - **Adversarial_D**:
    - `r1_gamma`: 10.0 (constant for wasserstein GP)
    - `r2_gamma`: 0.0 (constant for wasserstein GP)
- **Super Resolution Module Configuration**:
  - `generator_learning_rate`: 1e-4
  - `discriminator_learning_rate`: 1e-5
  - `generator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `discriminator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `generator_decay_gamma`: 0.5
  - `discriminator_decay_gamma`: 0.5
  - `clip_generator_outputs`: False (whether to clip generator outputs to valid pixel range [-1,1])
  - `use_sed_discriminator`: False (whether to use SeD discriminator)

### SeD Discriminator
- **Accelerator**: 'gpu'
- **Device**: 'cuda'
- **PL Trainer**:
  - `max_epochs`: 1000
  - `accelerator`: 'gpu'
  - `log_every_n_steps`: 50
  - `strategy`: DDPStrategy(find_unused_parameters=True)
  - `devices`: Number of available CUDA devices (determined by `torch.cuda.device_count()`)
  - `sync_batchnorm`: True
- **Train Batch Size**: 16
- **Validation Batch Size**: 8
- **Test Batch Size**: 8
- **Image Size**: 256
- **Dataset Module**:
  - `num_workers`: 4
  - `train_batch_size`: 16
  - `val_batch_size`: 8
  - `test_batch_size`: 8
  - **Train Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/dataset_cropped/hr"
    - `image_dir_lr`: "data/dataset_cropped/lr"
    - `downsample_factor`: 4
    - `mirror_augment_prob`: 0.5 (probability of applying mirroring w.r.t. y axis as a data augmentation)
  - **Validation and Test Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
- **Losses**:
  - **VGG**:
    - `weight`: 5e-5
    - `model_config`:
      - `path`: "pretrained_models/vgg16.pth"
      - `output_layer_idx`: 23 (index of the layer to extract features from)
      - `resize_input`: False
  - **Adversarial_G**:
    - `weight`: 1.0
  - **MSE**:
    - `weight`: 1.0
  - **Adversarial_D**:
    - `r1_gamma`: 10.0 (constant for wasserstein GP)
    - `r2_gamma`: 0.0 (constant for wasserstein GP)
- **Super Resolution Module Configuration**:
  - `generator_learning_rate`: 1e-4
  - `discriminator_learning_rate`: 1e-5
  - `generator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `discriminator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `generator_decay_gamma`: 0.5
  - `discriminator_decay_gamma`: 0.5
  - `clip_generator_outputs`: False (whether to clip generator outputs to valid pixel range [-1,1])
  - `use_sed_discriminator`: True (whether to use SeD discriminator)

## Training and saving of the model.

### Training with SeD

#### **IMPORTANT NOTE:** the training of the model is done on a remote server where we have not used jupyter notebook. Normally, scripts in the first 3 cells are used to train the model. However, in order to not overly crowd the jupyter notebook for the reviewers, we have included the code that is responsible for training but the training logs will be displayed in the last cell of this section named as training loop which abstracts all this logic

PLEASE DO NOT CHANGE THE FILE STRUCTURE THAT THE SUBMISSION HAS PROVIDED. THIS CAN CAUSE ERRORS IN THE TRAINING OF THE MODEL.

### TRAINING CONFIG

```python
import torch
from pytorch_lightning.strategies import DDPStrategy

accelerator = 'gpu'
device = torch.device("cuda") if accelerator=="gpu" else torch.device("cpu")
if accelerator == 'cpu':
    pl_trainer = dict(max_epochs=1000, accelerator=accelerator, log_every_n_steps=50, strategy=DDPStrategy(find_unused_parameters=True), devices=1, sync_batchnorm=True) # CHECK sync_batchnorm in this and below part !!!
else:
    pl_trainer = dict(max_epochs=1000, accelerator=accelerator, log_every_n_steps=50, strategy=DDPStrategy(find_unused_parameters=True), devices=torch.cuda.device_count(), sync_batchnorm=True)  # CHECK strategy and find_unused_parameters!!!

train_batch_size = 16
val_batch_size = 8
test_batch_size = 8

image_size = 256


###########################
##### Dataset Configs #####
###########################

dataset_module = dict(
    num_workers=4,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
    test_batch_size=test_batch_size,
    train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0.5),
    val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
    test_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
)

##################
##### Losses #####
##################
vgg_ckpt_path="pretrained_models/vgg16.pth"
loss_dict = dict(
    VGG=dict(weight=5e-5, model_config=dict(path=vgg_ckpt_path, output_layer_idx=23, resize_input=False)),
    Adversarial_G=dict(weight=1.0),
    MSE=dict(weight=1.0),
    Adversarial_D=dict(r1_gamma=10.0, r2_gamma=0.0)
)

#########################
##### Model Configs #####
#########################

super_resolution_module_config = dict(loss_dict=loss_dict, 
    generator_learning_rate=1e-4, discriminator_learning_rate=1e-5, 
    generator_decay_steps=[50_000, 100_000, 150_000, 200_000, 250_000], 
    discriminator_decay_steps=[50_000, 100_000, 150_000, 200_000, 250_000], 
    generator_decay_gamma=0.5, discriminator_decay_gamma=0.5,
    clip_generator_outputs=False,
    use_sed_discriminator=True)

#######################
###### Callbacks ######
#######################

ckpt_callback = dict(every_n_train_steps=4000, save_top_k=1, save_last=True, monitor='fid_test', mode='min')
synthesize_callback_train = dict(num_samples=12, eval_every=2000) # TODO: 4000
synthesize_callback_test = dict(num_samples=6, eval_every=2000)
fid_callback = dict(eval_every=4000)
```

### TRAIN.PY training logic implementation for SeD

```python
#TRAIN.PY training logic implementation
import pytorch_lightning as pl
from datasets.dataset_module import DatasetModule
from models.super_resolution_module import SuperResolutionModule
from argparse import ArgumentParser
from utils.config_utils import parse_config
import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from callbacks.logger.image_logger import ImageLoggerCallback
from callbacks.eval.fid import FIDCallback
from pytorch_lightning.utilities import rank_zero_only
import os
import shutil
import torch
import numpy as np
import random
from datetime import datetime

# import warnings
# warnings.filterwarnings('ignore')

def seed_all(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic=True

@rank_zero_only
def log_config_file(config):
    os.makedirs(config.log_path, exist_ok=True)
    shutil.copyfile(args.config_file, os.path.join(config.log_path, "config.py"))


def extend_config_parameters(config, is_debug):
    experiment_name = args.config_file.split("/")[-1].split(".")[0]
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
    experiment_name = f'{date_time}_{experiment_name}'
    config["experiment_name"] = experiment_name
    if is_debug:
        config["log_path"] = os.path.join('logs', '.debug', experiment_name)
    else:
        config["log_path"] = os.path.join('logs', experiment_name)
    config.ckpt_callback["dirpath"] = os.path.join(config["log_path"], 'checkpoint')

def train(args):
    config = parse_config(args.config_file)
    extend_config_parameters(config, is_debug=args.debug)
    model = SuperResolutionModule(**config.super_resolution_module_config)

    ckpt_path = None
    if args.resume_from is not None:
        ckpt_path = os.path.join(args.resume_from, "checkpoint", "last.ckpt")
        config.log_path = args.resume_from
        config.ckpt_callback.dirpath = os.path.join(config.log_path, 'checkpoint')
            
    log_config_file(config)

    data_module = DatasetModule(**config.dataset_module)
    data_module.setup('training')
    data_module.setup('test')

    train_dataloader = data_module.train_dataloader()
    if not args.debug:
        test_dataloader = data_module.test_dataloader()

    csv_logger = pl_loggers.CSVLogger(save_dir=config.log_path, flush_logs_every_n_steps=50)
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=config.log_path+"/tensorboard")

    
    ckpt_callback = ModelCheckpoint(**config.ckpt_callback)
    lr_monitor_callback = LearningRateMonitor(logging_interval="step")
    synthesize_callback_train = ImageLoggerCallback(data_module.train_dataset, "training", **config.synthesize_callback_train)
    synthesize_callback_test = ImageLoggerCallback(data_module.test_dataset, "test", **config.synthesize_callback_test)
    
    if not args.debug:
        fid_callback_test = FIDCallback(test_dataloader, dataset_type="test", **config.fid_callback)

    if not args.debug:
        trainer = pl.Trainer(logger=[csv_logger, tb_logger], 
                            callbacks=[ fid_callback_test, 
                                        synthesize_callback_train, 
                                        synthesize_callback_test, lr_monitor_callback,
                                        ckpt_callback], 
                            **config.pl_trainer)
    else:
        trainer = pl.Trainer(logger=[csv_logger, tb_logger], 
                            callbacks=[synthesize_callback_train, 
                                        synthesize_callback_test, 
                                        lr_monitor_callback, ckpt_callback], 
                            **config.pl_trainer)

    seed_all(seed=0)
    trainer.fit(model, train_dataloaders=train_dataloader, ckpt_path=ckpt_path)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--config_file', type=str, default='configs/default_config.py', help='Path to config file')
    parser.add_argument('--resume_from', type=str, default=None, help='Log folder of the model to be resumed')
    parser.add_argument('--debug', action='store_true', help='Start in debugging mode')
    args = parser.parse_args()
    train(args)
```

### PYTORCH LIGHTNING ALLOWS US TO USE CALLBACK FUNCTIONS DURING TRAINING. HENCE, WE HAVE USED CALLBACKS TO SAVE WEIGHTS OF THE MODEL. THE CALLBACK IS THE FOLLOWING:

```python
from pytorch_lightning.callbacks import ModelCheckpoint
```

### Training loop (next cell is used to train the model)


In [None]:
CFG="configs/patchgan_sed.py" #Training of patchgan discriminator with SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | generator     | RRDBNet                   | 15.4 M
1 | discriminator | PatchDiscriminatorWithSeD | 4.7 M 
2 | clip          | CLIPRN50                  | 23.4 M
------------------------------------------------------------
20.1 M    Trainable params
23.4 M    Non-trainable params
43.5 M    Total params
173.867   Total estimated model para

In [None]:
CFG="configs/patchgan.py" #Training of patchgan discriminator without SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | RRDBNet            | 15.4 M
1 | discriminator | PatchDiscriminator | 2.8 M 
2 | clip          | CLIPRN50           | 23.4 M
-----------------------------------------------------
18.2 M    Trainable params
23.4 M    Non-trainable params
41.5 M    Total params
166.180   Total estimated model params size (MB)
SLURM auto-requeueing enabled

In [2]:
CFG="configs/pixelwise_sed.py" #Training of pixelwise discriminator with SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | generator     | RRDBNet                   | 15.4 M
1 | discriminator | PatchDiscriminatorWithSeD | 4.7 M 
2 | clip          | CLIPRN50                  | 23.4 M
------------------------------------------------------------
20.1 M    Trainable params
23.4 M    Non-trainable params
43.5 M    Total params
173.867   Total estimated model para

### The same training logic applies to this as well did not execute due to incident gökberk hoca is informed about the situtation

In [None]:
CFG="configs/pixelwise.py" #Training of pixelwise discriminator without SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

## Loading a pre-trained model and computing qualitative samples/outputs from that model.

### **IMPORTANT! loss curves and outputs of the training process is logged under the directory /logs during training loop in order to display the results, one can execute the command below at the terminal in order to display the results**

```bash
tensorboard --logdir=logs/<experiment_name_under_the_directory>
```

### To make reviewer's job easier, we have provided the code needed to load a pretrained model and compute qualitative samples from the model but added the tensorboard logs from the training loop that had been executed from the training cells from the previous section.

In [14]:
from models.super_resolution_module import SuperResolutionModule
from datasets.dataset_module import DatasetModule
from tqdm import tqdm
from PIL import Image
import torch 
import numpy as np

def postprocess_image(image, min_val=-1.0, max_val=1.0):
    image = image.astype(np.float64)
    image = np.clip(image, -1, 1)
    image = (image - min_val) * 255 / (max_val - min_val)
    image = image.astype(np.uint8)
    # image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
    image = image.transpose(1, 2, 0)
    return image

torch.manual_seed(1256)
np.random.seed(1256)
ckpt="logs/2024-05-06_00-34-33_patchgan_sed/checkpoint/epoch=4-step=80000.ckpt"
ckpt2="logs/2024-05-06_12-26-29_pixelwise_sed/checkpoint/epoch=2-step=44000.ckpt"
ckpt3="logs/2024-05-06_01-21-09_pixelwise/checkpoint/epoch=4-step=80000.ckpt"
ckpt4="/scratch/users/hpc-yekin/hpc_run/SeD/logs/2024-05-06_00-35-57_patchgan/checkpoint/epoch=4-step=80000.ckpt"

model = SuperResolutionModule.load_from_checkpoint(ckpt) 

train_batch_size = 2  # given as temporary data
val_batch_size = 2 # given as temporary data
test_batch_size = 2 # given as temporary data

model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 256
dataset_module = dict(
    num_workers=4,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
    test_batch_size=test_batch_size,
    train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0.5),
    val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
    test_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
)


data_module_gt = DatasetModule(**dataset_module)
data_module_gt.setup('test')
dataloader = data_module_gt.test_dataloader()


cnt = 0
for batch in tqdm(dataloader, desc=f"Calculating FID on SR images", total=len(dataloader)):
    sr_images = model.make_high_resolution(batch)
    sr_images = sr_images ['generated_super_resolution_image'].to(device)
    #save the sr images to the "sr_pngs" folder
    for i in range(len(sr_images)):
        img = sr_images[i]
        
        img = postprocess_image(img.detach().cpu().numpy())
        img = Image.fromarray(img)
        img.save(f"patchgansed/{cnt}.png")
        cnt += 1

Calculating FID on SR images: 100%|██████████| 55/55 [00:06<00:00,  8.65it/s]


### Results and logs of training SeD patchwise discriminator
### Log curves of the training processes loss functions

#### adverserial loss of generator per epoch
<img src="img/adv_g_sed.png">

#### adverserial loss of discriminator per epoch
<img src="img/adv_d_sed.png">

#### mse loss per epoch
<img src="img/mse_sed.png">

#### vgg perceptual loss per epoch
<img src="img/vgg_sed.png">

#### results of training (input | ground truth | output) (it is important note that the training dataset has size 10 and it is trained 219 epochs)
<img src="img/sed.png">

### Results and logs of training patchwise discriminator
#### adverserial loss of generator per epoch
<img src="img/adv_g.png">

#### adverserial loss of discriminator per epoch
<img src="img/adv_d.png">

#### mse loss per epoch
<img src="img/mse.png">

#### vgg perceptual loss per epoch
<img src="img/vgg.png">

#### results of training (input | ground truth | output) (it is important note that the training dataset has size 10 and it is trained 80 epochs)
<img src="img/without_sed.png">

## Reproducing results

as explained in the goals.txt an updated version will be provided with the trained model weights as soon as possible. The incident that happened is known by gökberk hoca.

In [None]:
"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

#WE HAVE IMPLEMENTED THIS CODE BLOCK BY USING THE REFERENCE AT THE TOP AS A GUIDANCE

import torch
import numpy as np
from tqdm import tqdm
from datasets.dataset_module import DatasetModule
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
import torch
import numpy as np
from tqdm import tqdm
from losses.lpips.lpips import LPIPS
import torch.nn.functional as F

def print_metrics_given_path(path):
    print("calculating metrics for " + path)
    train_batch_size = 2  # given as temporary data
    val_batch_size = 2 # given as temporary data
    test_batch_size = 2 # given as temporary data
    
    
    ################ lpips
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lpips_model = LPIPS(net_type='alex', device=device).to('cpu')
    lpips_model.eval()
    image_size = 256
    dataset_module_gt = dict(
        num_workers=4,
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0),
        val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
        test_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
    )
    
    dataset_module_gt = DatasetModule(**dataset_module_gt)
    dataset_module_gt.setup('test')
    first_dataloader = dataset_module_gt.test_dataloader()
    
    
    dataset_module_sr = dataset_module = dict( #UPDATE DIRS
        num_workers=4,
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0),
        val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
        test_dataset_config=dict(image_size=256, image_dir_hr=path, image_dir_lr="data/evaluation/lr/manga109/"),
    )
    
    data_module_sr = DatasetModule(**dataset_module_sr)
    data_module_sr.setup('test')
    second_dataloader = data_module_sr.test_dataloader()
    
    def get_lpips_mean(dataloader1,dataloader2,lpips_model,device,dataset_type):
        lpips_model.to(device)
        lpips_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} LPIPS on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                lpips = lpips_model(gt_images, sr_images, return_similarity=True)
                lpips_list.append(lpips.cpu())
        lpips_list = torch.cat(lpips_list).numpy()
        lpips_mean = np.nanmean(lpips_list)
        lpips_model.to('cpu')
        return lpips_mean
    
    
    
    lpips_mean = get_lpips_mean(first_dataloader,second_dataloader,lpips_model,device,"lpips")
    
    print("lpips: ",lpips_mean)
    
    #### SSIM
    
    def ssim(img1, img2):
        # Calculate SSIM (Structural Similarity Index)
        ssim_val = torch.mean((2 * img1 * img2 + 1e-8) * (2 * torch.abs(img1 - img2) + 1e-8) / (img1**2 + img2**2 + 1e-8), dim=(1, 2, 3))
        return ssim_val
    
    def get_ssim_mean(dataloader1,dataloader2,ssim,device,dataset_type):
        ssim_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} SSIM on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                ssim_val = ssim(sr_images, gt_images)
                ssim_list.append(ssim_val.cpu())
        ssim_list = torch.cat(ssim_list).numpy()
        ssim_mean = np.nanmean(ssim_list)
        return ssim_mean
    
    ssim_mean = get_ssim_mean(first_dataloader,second_dataloader,ssim,device,"ssim")
    print("ssim: ",ssim_mean)
    
    #### PSNR
    
    def psnr(img1, img2, max_val=1.0):
        # Convert images to float tensors
        img1 = img1.float()
        img2 = img2.float()
        
        max_val = img1.max()
        # Calculate MSE (Mean Squared Error)
        mse = F.mse_loss(img1, img2)
        
        # Calculate PSNR (Peak Signal-to-Noise Ratio)
        psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
        
        return psnr.item()
    
    def get_psnr_mean(dataloader1,dataloader2,device,dataset_type):
        psnr_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} PSNR on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                psnr_val = psnr(sr_images, gt_images)
                psnr_list.append(psnr_val)
        psnr_mean = np.nanmean(psnr_list)
        return psnr_mean
    
    
    psnr_mean = get_psnr_mean(first_dataloader,second_dataloader,device,"psnr")
    print("psnr: ",psnr_mean)

print_metrics_given_path("patchgan/")
print_metrics_given_path("patchgansed/")
print_metrics_given_path("pixelwise/")
print_metrics_given_path("pixelwise_sed/")


calculating metrics for patchgan/


Calculating lpips LPIPS on sr images: 100%|██████████| 55/55 [00:01<00:00, 41.71it/s]


lpips:  0.6864215


Calculating ssim SSIM on sr images: 100%|██████████| 55/55 [00:01<00:00, 42.31it/s]

ssim:  0.37364906



Calculating psnr PSNR on sr images: 100%|██████████| 55/55 [00:01<00:00, 41.70it/s]


psnr:  7.709990440715443
calculating metrics for patchgansed/


## Challenges we have encountered when implementing the paper

Implementing a super resolution model based solely on a paper, without access to the accompanying code, was challenging due to the complexities of understanding and implementing the loss function, architecture, and performance metrics described in the paper. Dealing with dimensionality inconsistencies in paper. Some are listed below.


### Our Assumptions:
* we assumed that the group normalization has 32 groups (not stated in the paper)
* we assumed that the conv block in patchwise discriminator is a  convolution block that doubles the channel size and with kernel_size of 4, stride=2 and padding=1 followed by a batch normalization block followed by a leaky relu block (not included in the last convolution block) which is not stated in the paper.
* They did not specified the adverserial loss function details. As a result, we have decided to go with wassertein loss with gradient penalty to achieve a more stable training.
* They did not specify how they have preprocessed the dataset. As a result, due to small number of images in the dataset, we have decided to conduct a literature survey on how different models have overcome this issue and found that ESRGAN does combine 2 datasets and crops random patches from each image to increase the number of images.
* We have decided to move with crop size of 400 for hr images and 100 for lr images. This means that during training our model inputs 100x100 crops and tries to generate 400x400 hr version of it.
* For cross attention, we have decided to use single head attention rather than multi head attention
* CLIP preprocessor normally downscales the image to 224x224 before extracting embeddings. We believed that this can downgrade the performance w.r.t hr images as a result, we did not use this preprocessor.
* To obtain same spatial dimensionality with the clip embeddings (for concatenation specified in the image below in part d), we added extra convolution layer that did not change the channel size but decreases the spatial dimensions.
* The authors did not describe the weight (lambda) values of the loss functions as a result, we have decided to go with 1 for mse and 10 for gradient penalty in wasserstein loss