In [1]:
# 1. Clone the repository, skipping the clone when it is already present.
#    On a re-run the cwd is already inside the repo, so an unguarded
#    `git clone` would nest a second, unpatched copy inside it.
import os

REPO_DIR = "/content/MuLA_GAN"
if not os.path.isdir(REPO_DIR):
    !git clone https://github.com/AhsanBaidar/MuLA_GAN.git {REPO_DIR}

# 2. Move into the repository. The absolute path keeps this idempotent; a
#    relative `%cd MuLA_GAN` resolves differently once the cwd has changed.
%cd {REPO_DIR}

# 3. Mount Google Drive (the config expects data under /content/drive/MyDrive).
from google.colab import drive
drive.mount('/content/drive')

Cloning into 'MuLA_GAN'...
remote: Enumerating objects: 125, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 125 (delta 30), reused 78 (delta 15), pack-reused 28 (from 1)[K
Receiving objects: 100% (125/125), 144.94 MiB | 34.59 MiB/s, done.
Resolving deltas: 100% (33/33), done.
/content/MuLA_GAN
Mounted at /content/drive


In [2]:
%%writefile utils/data_utils.py
import os
import glob
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import yaml

class Dataloader(Dataset):
    """Paired (input, ground-truth) image dataset for MuLA-GAN training.

    Folder locations come from the YAML file at ``config_path``:
    ``dataset_path`` (falls back to ``root``), plus ``TRAIN_INPUT`` and
    ``TRAIN_GT`` relative to it. Input and ground-truth files are paired by
    their position in the sorted directory listings, so both folders should
    contain matching file names; a warning is printed when they do not.

    Args:
        root: Base directory used when the config has no ``dataset_path``.
        dataset_name: Accepted for interface compatibility; not used here.
        transforms_: List of torchvision transforms applied to both images.
            When None, only ``transforms.ToTensor()`` is applied.
        config_path: YAML file providing the paths (and ``im_height`` /
            ``im_width`` for the placeholder returned on read errors).

    Raises:
        FileNotFoundError: If the config cannot be read or parsed.
    """

    def __init__(self, root, dataset_name, transforms_=None, config_path="configs/train_MuLA-GAN.yaml"):
        # Compose(None) only fails later, inside __getitem__, where the error is
        # caught and every sample silently becomes all zeros; fall back to a
        # plain PIL -> tensor conversion instead.
        if transforms_ is None:
            transforms_ = [transforms.ToTensor()]
        self.transform = transforms.Compose(transforms_)
        try:
            with open(config_path) as f:
                cfg = yaml.load(f, Loader=yaml.FullLoader)
            input_folder_rel = cfg.get("TRAIN_INPUT", 'train/trainA')
            gt_folder_rel = cfg.get("TRAIN_GT", 'train/trainB')
            base_path = cfg.get("dataset_path", root)
            # Placeholder size for unreadable samples, taken from the config rather
            # than hard-coded. Assumes the transforms resize to
            # im_height x im_width (TODO confirm in train_MuLA_GAN.py).
            self.dummy_shape = (3, int(cfg.get("im_height", 256)), int(cfg.get("im_width", 256)))
        except Exception as e:
            print(f"Error reading config file {config_path} in Dataloader: {e}")
            # Chain the original error so the real cause (missing file vs. bad YAML)
            # stays visible in the traceback.
            raise FileNotFoundError(f"Could not read/parse paths from {config_path}") from e

        input_path = os.path.join(base_path, input_folder_rel)
        gt_path = os.path.join(base_path, gt_folder_rel)

        print("--- Dataloader Init ---")
        print(f"Base Path (from config): {base_path}")
        print(f"Constructed Input Path: {input_path}")
        print(f"Constructed GT Path: {gt_path}")

        self.filesA = sorted(glob.glob(input_path + "/*.*"))
        self.filesB = sorted(glob.glob(gt_path + "/*.*"))

        print(f"Found {len(self.filesA)} input files.")
        print(f"Found {len(self.filesB)} GT files.")

        if not self.filesA or not self.filesB:
            print("Error: Did not find files in one or both directories.")
            self.len = 0
        else:
            self.len = min(len(self.filesA), len(self.filesB))
            self._warn_on_name_mismatch()

        if self.len == 0:
            print("Critical Error: Dataset length is 0. Cannot train.")

        print(f"Setting dataset length to: {self.len}")
        print("-----------------------")

    def _warn_on_name_mismatch(self):
        """Warn when sorted input/GT files do not share the same file stem.

        Pairs are formed purely by sorted position, so one missing or extra file
        in either folder shifts every later input onto the wrong target.
        """
        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        # zip() stops at the shorter list, i.e. exactly self.len pairs.
        mismatched = sum(stem(a) != stem(b) for a, b in zip(self.filesA, self.filesB))
        if mismatched:
            print(f"Warning: {mismatched} of {self.len} input/GT pairs have different file "
                  f"names; pairs are formed by sorted order, so they may be misaligned.")

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _dummy_sample(self):
        """All-zero placeholder pair returned when a sample cannot be produced."""
        dummy_tensor = torch.zeros(getattr(self, "dummy_shape", (3, 256, 256)))
        return {"A": dummy_tensor, "B": dummy_tensor}

    def __getitem__(self, index):
        """Return ``{"A": input, "B": ground_truth}`` for the pair at ``index``.

        Indices wrap modulo the dataset length. Read or transform failures are
        printed and replaced by an all-zero pair (best effort, so one bad file
        does not stop training).

        Raises:
            IndexError: If the dataset is empty.
        """
        if self.len == 0:
            raise IndexError("Dataset is empty, cannot get item.")
        actual_index = index % self.len
        try:
            img_A = self._load_rgb(self.filesA[actual_index])
            img_B = self._load_rgb(self.filesB[actual_index])
        except Exception as e:
            print(f"Error opening image at index {actual_index}: {e}")
            return self._dummy_sample()

        # Random horizontal flip (p = 0.5), applied to both images. Both flipped
        # copies are built before either is assigned, so a failure can never
        # leave only one image of the pair flipped.
        if np.random.random() < 0.5:
            try:
                flipped_A = Image.fromarray(np.array(img_A)[:, ::-1, :])
                flipped_B = Image.fromarray(np.array(img_B)[:, ::-1, :])
                img_A, img_B = flipped_A, flipped_B
            except Exception as e:
                print(f"Error during random flip at index {actual_index}: {e}")

        try:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)
        except Exception as e:
            print(f"Error transforming image at index {actual_index}: {e}")
            return self._dummy_sample()

        return {"A": img_A, "B": img_B}

    def __len__(self):
        """Number of usable pairs: the smaller of the two folder file counts (0 if either is empty)."""
        return self.len

Overwriting utils/data_utils.py


In [3]:
%%writefile configs/train_MuLA-GAN.yaml
# --- Training Configuration for MuLA-GAN on LSUI (Split) ---
dataset_path: '/content/drive/MyDrive/LSUI_Split' # Base path of the train/test split
TRAIN_INPUT: 'train/trainA'      # Relative path to degraded training inputs
TRAIN_GT: 'train/trainB'         # Relative path to ground-truth (reference) training targets
TEST_INPUT: 'test/testA'         # Relative path to degraded test inputs
TEST_GT: 'test/testB'            # Relative path to ground-truth (reference) test targets
dataset_name: "LSUI-Split"       # Dataset name; appears to name the checkpoint sub-folder (checkpoints/LSUI-Split/)

# Image Info
im_width: 256
im_height: 256
chans: 3

# Model Info
MODEL_NAME: 'MuLA_GAN'
GENERATOR: 'mula_gan_g'
DISCRIMINATOR: 'mula_gan_d'

# Training Params
BATCH_SIZE: 8           # <<< WARNING: May cause OOM. Change to 4 or 2 if needed.
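NUM_EPOCHS: 200         # NOTE: the training run logged "Epoch 300/301"; verify that train_MuLA_GAN.py actually reads this key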
LR_G: 0.0002
LR_D: 0.0002
B1: 0.5
B2: 0.999
WEIGHT_DECAY: 0.0001
LAMBDA_L1: 100
LAMBDA_ADV: 1
LAMBDA_PERCEPTUAL: 10

# Checkpoint/Logging Params
ckpt_interval: 20
SAMPLE_INTERVAL: 1000
LOG_INTERVAL: 50
val_interval: 1000

# Output Paths
CHECKPOINT_DIR: 'checkpoints'
SAMPLE_DIR: 'samples'
RESULTS_PATH: 'results'

Overwriting configs/train_MuLA-GAN.yaml


In [4]:
# Train MuLA-GAN. The dataloader written above reads its data paths from
# configs/train_MuLA-GAN.yaml by default.
# NOTE(review): the log reports "Epoch 300/301" although the config sets
# NUM_EPOCHS: 200, so the epoch count seems to come from elsewhere in
# train_MuLA_GAN.py. Confirm before relying on the config value.
!python train_MuLA_GAN.py

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100% 548M/548M [00:02<00:00, 196MB/s]
@@@@@@@@@@@@@@
--- Dataloader Init ---
Base Path (from config): /content/drive/MyDrive/LSUI_Split
Constructed Input Path: /content/drive/MyDrive/LSUI_Split/train/trainA
Constructed GT Path: /content/drive/MyDrive/LSUI_Split/train/trainB
Found 3424 input files.
Found 3424 GT files.
Setting dataset length to: 3424
-----------------------
  valid = Variable(Tensor(np.ones((imgs_distorted.size(0), *patch))), requires_grad=False)
[Epoch 300/301: batch 1700/1712] [DLoss: 0.730, GLoss: 2.230, AdvLoss: 1.069]

In [5]:
# Run the trained generator on the LSUI test inputs and write the results to OUTPUT_DIR.
# Before running:
# 1. Point WEIGHTS_PATH at the generator checkpoint to evaluate. The training run
#    above wrote checkpoints/LSUI-Split/generator_300.pth. The sub-folder appears to
#    follow `dataset_name` in the config, so change it if your folder differs.
# 2. OUTPUT_DIR is the folder the evaluation cells below score. Images left there by
#    an earlier run are presumably not cleared by test.py (verify), and the
#    evaluation scripts would include them.
WEIGHTS_PATH = "checkpoints/LSUI-Split/generator_300.pth"
# Same Drive root as `dataset_path` in the config ("MyDrive", not "My Drive").
TEST_INPUT_DIR = "/content/drive/MyDrive/LSUI_Split/test/testA/"
OUTPUT_DIR = "./output_lsui/"

!python test.py \
  --weights_path "{WEIGHTS_PATH}" \
  --data_dir "{TEST_INPUT_DIR}" \
  --sample_dir "{OUTPUT_DIR}"

Loaded model from checkpoints/LSUI-Split/generator_300.pth
Tested: /content/drive/My Drive/LSUI_Split/test/testA/0.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1010.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1027.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1028.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1030.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1041.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1042.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1045.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1052.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1077.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1083.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1089.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1093.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1095.jpg
Tested: /content/drive/My Drive/LSUI_Split/test/testA/1096.jpg

In [6]:
%%writefile Evaluation/measure_ssim_psnr.py
"""
# > Script for measuring quantitative performances in terms of
#   - Structural Similarity Metric (SSIM)
#   - Peak Signal to Noise Ratio (PSNR)
# > Maintainer: https://github.com/xahidbuffon
"""
## python libs
import numpy as np
from PIL import Image
from glob import glob
from os.path import join, exists # Added exists
from ntpath import basename
import os
## local libs
try:
    from Evaluation.imqual_utils import getSSIM, getPSNR
except ImportError:
    # When executed as `python Evaluation/measure_ssim_psnr.py`, sys.path[0] is the
    # script's own folder, so the package-qualified import can fail while the
    # plain module import works; presumably that is why this fallback exists.
    try:
        from imqual_utils import getSSIM, getPSNR
    except ImportError as e:
        # SystemExit(<str>) prints the message to stderr and exits with status 1.
        # The bare exit() used before exited with status 0, reporting success on failure.
        raise SystemExit(f"Error importing imqual_utils: {e}")

## compares avg ssim and psnr
def SSIMs_PSNRs(gtr_dir, gen_dir, im_res=(256, 256), gen_exts=(".png", ".jpg", ".jpeg")):
    """Compute per-image SSIM and PSNR between ground-truth and generated images.

    Every file in ``gtr_dir`` is matched to a file in ``gen_dir`` with the same
    stem (name without its final extension), trying ``gen_exts`` in order.
    Both images are converted to RGB and resized to ``im_res``. SSIM is computed
    on the RGB arrays and PSNR on their grayscale ("L") versions.

    Args:
        gtr_dir: Directory of ground-truth images.
        gen_dir: Directory of generated images.
        im_res: Size tuple passed to ``PIL.Image.resize``.
        gen_exts: Extensions tried, in priority order, for the generated file.

    Returns:
        ``(ssims, psnrs)`` as 1-D numpy arrays of the finite values obtained.
        Both are empty when no ground truth exists or nothing could be matched.
    """
    gtr_paths = sorted(glob(join(gtr_dir, "*.*")))
    print(f"Found {len(gtr_paths)} potential ground truth images in: {gtr_dir}")
    print(f"Looking for corresponding generated images in: {gen_dir}")
    ssims, psnrs = [], []
    processed_count = 0
    if not gtr_paths:
        print(f"Error: No image files found in ground truth directory: {gtr_dir}")
        return np.array([]), np.array([])
    for gtr_path in gtr_paths:
        # splitext strips only the final extension: "img.v2.png" -> "img.v2".
        # The previous split('.')[0] gave "img" and missed the generated file.
        gtr_f_base = os.path.splitext(basename(gtr_path))[0]
        gen_path = None
        for ext in gen_exts:
            candidate = join(gen_dir, gtr_f_base + ext)
            if exists(candidate):
                gen_path = candidate
                break
        if gen_path is None:
            continue
        processed_count += 1
        try:
            # Convert to RGB before resizing: Pillow forces nearest-neighbour
            # resampling on palette ("P") images, which would skew the metrics.
            # The `with` blocks close both file handles.
            with Image.open(gtr_path) as r_src, Image.open(gen_path) as g_src:
                r_im = r_src.convert('RGB').resize(im_res)
                g_im = g_src.convert('RGB').resize(im_res)
            ssim_val = getSSIM(np.array(r_im), np.array(g_im))
            if np.isfinite(ssim_val):
                ssims.append(ssim_val)
            psnr_val = getPSNR(np.array(r_im.convert("L")), np.array(g_im.convert("L")))
            if np.isfinite(psnr_val):
                psnrs.append(psnr_val)
        except Exception as e:
            print(f"Error processing {basename(gtr_path)}: {e}")
    if processed_count == 0:
        print("\nError: No matching image pairs found.")
        return np.array([]), np.array([])
    # Means are taken over matched pairs only; make silently skipped files visible.
    unmatched = len(gtr_paths) - processed_count
    if unmatched:
        print(f"Warning: {unmatched} ground truth images had no generated counterpart and were skipped.")
    return np.array(ssims), np.array(psnrs)

# --- Define YOUR paths here ---
gtr_dir = "/content/drive/MyDrive/LSUI_Split/test/testB/"  # LSUI test ground-truth images
gen_dir = "./output_lsui/"                                # images generated by test.py (--sample_dir)
# -----------------------------

### compute SSIM (RGB) and PSNR (grayscale) over the matched pairs
SSIM_measures, PSNR_measures = SSIMs_PSNRs(gtr_dir, gen_dir)

n_ssim = len(SSIM_measures)
if n_ssim:
    ssim_mean, ssim_std = np.mean(SSIM_measures), np.std(SSIM_measures)
    print("\n--- Results ---")
    print("SSIM on {0} matched samples".format(n_ssim))
    print("Mean: {0:.4f} std: {1:.4f}".format(ssim_mean, ssim_std))

n_psnr = len(PSNR_measures)
if n_psnr:
    psnr_mean, psnr_std = np.mean(PSNR_measures), np.std(PSNR_measures)
    print("\nPSNR on {0} matched samples".format(n_psnr))
    print("Mean: {0:.2f} std: {1:.2f}".format(psnr_mean, psnr_std))

Overwriting Evaluation/measure_ssim_psnr.py


In [7]:
# Full-reference SSIM/PSNR of ./output_lsui/ against the LSUI test ground truth.
# The directories are hard-coded inside the script written above.
!python Evaluation/measure_ssim_psnr.py

Found 855 potential ground truth images in: /content/drive/MyDrive/LSUI_Split/test/testB/
Looking for corresponding generated images in: ./output_lsui/

--- Results ---
SSIM on 855 matched samples
Mean: 0.8389 std: 0.0839

PSNR on 855 matched samples
Mean: 26.57 std: 3.85


In [8]:
%%writefile Evaluation/measure_uiqm.py
"""
# > Script for measuring quantitative performance in terms of UIQM
# > Maintainer: https://github.com/xahidbuffon
"""
## python libs
import numpy as np
from PIL import Image, ImageOps
from glob import glob
from os.path import join
from ntpath import basename
import os
## local libs
try:
    from Evaluation.uqim_utils import getUIQM
except ImportError:
    # When executed as `python Evaluation/measure_uiqm.py`, sys.path[0] is the
    # script's own folder, so the package-qualified import can fail while the
    # plain module import works; presumably that is why this fallback exists.
    try:
        from uqim_utils import getUIQM
    except ImportError as e:
        # SystemExit(<str>) prints the message to stderr and exits with status 1.
        # The bare exit() used before exited with status 0, reporting success on failure.
        raise SystemExit(f"Error importing uqim_utils: {e}")

def measure_UIQMs(dir_name, im_res=(256, 256)):
    """Compute the UIQM score of every file matching ``*.*`` in ``dir_name``.

    Each image is converted to RGB, resized to ``im_res`` and scored with
    ``getUIQM``. Unreadable files are reported and skipped; non-finite scores
    are dropped.

    Returns:
        1-D numpy array of finite UIQM values. Empty if the directory has no
        files or every image failed.
    """
    paths = sorted(glob(join(dir_name, "*.*")))
    print(f"Found {len(paths)} files in: {dir_name}")
    if not paths:
        print(f"Error: No images found in directory: {dir_name}")
        return np.array([])
    uqims = []
    for img_path in paths:
        try:
            # Convert before resizing so palette ("P") images are not forced onto
            # nearest-neighbour resampling. `with` closes the file handle.
            with Image.open(img_path) as src:
                im = src.convert('RGB').resize(im_res)
            uiqm = getUIQM(np.array(im))
            if np.isfinite(uiqm):
                uqims.append(uiqm)
        except Exception as e:
            print(f"Error processing {basename(img_path)}: {e}")
    if not uqims:
        print("\nError: UIQM calculation failed for all images.")
        return np.array([])
    return np.array(uqims)

# --- Define YOUR path here ---
gen_dir = "./output_lsui/" # images generated by test.py (--sample_dir)
# -----------------------------

### compute UIQM (no-reference: only the generated images are needed)
gen_uqims = measure_UIQMs(gen_dir)
n_uiqm = len(gen_uqims)
if n_uiqm:
    uiqm_mean, uiqm_std = np.mean(gen_uqims), np.std(gen_uqims)
    print("\n--- Results ---")
    print("UIQM on {0} samples".format(n_uiqm))
    print("Mean: {0:.4f} std: {1:.4f}".format(uiqm_mean, uiqm_std))

Overwriting Evaluation/measure_uiqm.py


In [9]:
# No-reference UIQM score of the generated images in ./output_lsui/ (no ground truth used).
!python Evaluation/measure_uiqm.py

Found 855 files in: ./output_lsui/

--- Results ---
UIQM on 855 samples
Mean: 2.8478 std: 0.3748
