In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sunpy.map import Map
from astropy.visualization import ImageNormalize, AsinhStretch

import torch
from torchmetrics import MeanAbsoluteError
from torchmetrics.regression import PearsonCorrCoef, ConcordanceCorrCoef
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

import sys
sys.path.append('../data/sdo')
from preprocess_aia import NormalizeEditor

In [None]:
class Validator:
    """Compute image-similarity metrics between real and generated AIA maps.

    Ground-truth samples are read from ``<real_root>/<stage>/target/*.npz``.
    The matching generated sample is ``<fake_root>/<stage>/<stem>_fake.npz``.
    Each ``.npz`` must contain ``data`` and ``metas`` arrays. Their first
    entries are combined into a ``sunpy.map.Map``.

    Args:
        real_root: Root directory of the ground-truth dataset.
        fake_root: Root directory of the model outputs.
        concordance: If True, the ``"cc"`` metric is the concordance
            correlation coefficient; otherwise it is Pearson's correlation.
    """

    def __init__(
        self,
        real_root,
        fake_root,
        concordance=False,
    ):
        self.real_root = Path(real_root)
        self.fake_root = Path(fake_root)

        self.mae = MeanAbsoluteError()                 # 0.0 is best
        self.psnr = PeakSignalNoiseRatio()             # +inf is best
        self.ssim = StructuralSimilarityIndexMeasure() # 1.0 is best
        # Either correlation coefficient is 1.0 at best.
        self.cc = ConcordanceCorrCoef() if concordance else PearsonCorrCoef()

        # Sorted target file lists, cached per stage so that evaluating N
        # samples scans each target directory once instead of N times.
        # NOTE: files added to disk after the first call for a stage are not
        # picked up by this instance.
        self._real_files = {}

    def inverse(self, x):
        """Undo the dataset normalization applied by ``preprocess_aia``.

        NOTE(review): the ``NormalizeEditor(0, 14)`` arguments presumably must
        match those used at preprocessing time -- confirm against preprocess_aia.
        """
        return NormalizeEditor(0, 14).inverse(x)

    def _get_real_files(self, stage):
        """Return the sorted ``.npz`` target files for ``stage`` (cached)."""
        if stage not in self._real_files:
            target_dir = self.real_root / stage / "target"
            self._real_files[stage] = sorted(target_dir.glob("*.npz"))
        return self._real_files[stage]

    @staticmethod
    def _load_map(path):
        """Load the first (data, meta) pair of an ``.npz`` file as a sunpy Map."""
        # allow_pickle is kept because ``metas`` is presumably stored as an
        # object array (e.g. header dicts). Unpickling can execute arbitrary
        # code, so only use this on trusted files.
        # The context manager closes the NpzFile handle; the arrays read from
        # it are already fully loaded into memory, so they stay valid.
        with np.load(path, allow_pickle=True) as npz:
            return Map(npz["data"][0], npz["metas"][0])

    def __call__(self, idx, stage):
        """Compute metrics for the ``idx``-th target file of ``stage``.

        Args:
            idx: Index into the sorted list of target ``.npz`` files.
            stage: Dataset split subdirectory (e.g. ``"test"``).

        Returns:
            dict with keys ``"mae"``, ``"cc"``, ``"psnr"`` and ``"ssim"``.
            Each value is the tensor returned by the corresponding torchmetrics
            call on the inverse-normalized images.

        Raises:
            IndexError: If ``idx`` is out of range for ``stage``.
            FileNotFoundError: If the matching ``*_fake.npz`` file is missing.
        """
        real_file = self._get_real_files(stage)[idx]
        fake_file = self.fake_root / stage / f"{real_file.stem}_fake.npz"

        real_map = self.inverse(self._load_map(real_file))
        fake_map = self.inverse(self._load_map(fake_file))

        # Prepend batch and channel dims (SSIM needs a batched image tensor).
        # Assumes map data is 2-D (H, W), giving (1, 1, H, W).
        # torch.Tensor(...) produces a float32 tensor.
        fake_data = torch.Tensor(fake_map.data).unsqueeze(0).unsqueeze(0)
        real_data = torch.Tensor(real_map.data).unsqueeze(0).unsqueeze(0)

        metrics = {
            "mae": self.mae(fake_data, real_data),
            # Correlation metrics take 1-D inputs, hence the flatten.
            "cc": self.cc(fake_data.flatten(), real_data.flatten()),
            "psnr": self.psnr(fake_data, real_data),
            "ssim": self.ssim(fake_data, real_data),
        }

        # Each call is scored independently, so clear any accumulated state.
        for metric in (self.mae, self.psnr, self.ssim, self.cc):
            metric.reset()

        return metrics

In [None]:
# Ground-truth AIA dataset and the split that is evaluated below.
real_root = Path("/home/mgj/workspace/mgjeon/image-to-image/data/sdo/aia_dataset")
stage = "test"
# Target files of this split, in sorted (deterministic) order.
real_files = sorted(real_root.joinpath(stage, "target").glob("*.npz"))
len(real_files)

244

In [None]:
from tqdm import tqdm

In [None]:
def get_metrics(stage, real_root, fake_root, concordance=False):
    """Evaluate every target sample of ``stage`` against one model's outputs.

    Args:
        stage: Dataset split subdirectory (e.g. ``"test"``).
        real_root: Root of the ground-truth dataset. Targets are read from
            ``<real_root>/<stage>/target/*.npz``.
        fake_root: Root of the model's generated samples.
        concordance: Forwarded to ``Validator``. If True, ``"cc"`` is the
            concordance correlation coefficient instead of Pearson's.

    Returns:
        dict mapping ``"mae"``, ``"cc"``, ``"psnr"`` and ``"ssim"`` to lists
        holding one per-sample value each.
    """
    metric_names = ("mae", "cc", "psnr", "ssim")
    metrics = {name: [] for name in metric_names}

    # Count the targets of *this* stage/root. A notebook-global file list
    # computed in another cell may belong to a different split.
    target_dir = Path(real_root) / stage / "target"
    n_samples = sum(1 for _ in target_dir.glob("*.npz"))

    # One validator serves the whole run, because Validator resets its
    # torchmetrics state after every sample.
    validator = Validator(
        real_root=real_root,
        fake_root=fake_root,
        concordance=concordance,
    )
    for idx in tqdm(range(n_samples)):
        sample = validator(idx, stage=stage)
        for name in metric_names:
            metrics[name].append(sample[name])
    return metrics

In [None]:
def get_mean_metrics(results):
    """Average per-sample metrics for each model and print a summary.

    Args:
        results: mapping of model name -> dict with ``"mae"``, ``"cc"``,
            ``"psnr"`` and ``"ssim"`` lists of per-sample values.

    Returns:
        dict mapping each model name to a dict of the four metric means
        (as returned by ``np.mean``).
    """
    # (result key, printed label, format spec) in display order.
    columns = (
        ("mae", "MAE", ".2f"),
        ("cc", "CC", ".4f"),
        ("psnr", "PSNR", ".2f"),
        ("ssim", "SSIM", ".4f"),
    )

    summary = {}
    for model_name, per_sample in results.items():
        print(model_name)
        means = {key: np.mean(per_sample[key]) for key, _, _ in columns}
        summary[model_name] = means
        for key, label, spec in columns:
            print(f"{label}: {means[key]:{spec}}")
        print()
    return summary

In [None]:
# Score every model run on the same split. The generated samples of each run
# live in "<METRICS_ROOT>/<run directory>/version_0".
METRICS_ROOT = "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo"

RUN_DIRS = {
    # Run directories without the `_ema` suffix.
    "pix2pix": "pix2pix/unet_patchgan_1024_small",
    "pix2pixHD": "pix2pixHD/default_small",
    "pix2pixCC": "pix2pixCC/default_small",
    "ddpm_noise": "diffusion/ddpm_noise",
    "ddpm_x0": "diffusion/ddpm_x0",
    "fast_ddpm_noise": "diffusion/fast_ddpm_noise",
    "fast_ddpm_x0": "diffusion/fast_ddpm_x0",
    # Run directories with the `_ema` suffix.
    "pix2pix_ema": "pix2pix/unet_patchgan_1024_small_ema",
    "pix2pixHD_ema": "pix2pixHD/default_small_ema",
    "pix2pixCC_ema": "pix2pixCC/default_small_ema",
    "ddpm_noise_ema": "diffusion/ddpm_noise_ema",
    "ddpm_x0_ema": "diffusion/ddpm_x0_ema",
    "fast_ddpm_noise_ema": "diffusion/fast_ddpm_noise_ema",
    "fast_ddpm_x0_ema": "diffusion/fast_ddpm_x0_ema",
}

results = {}
for run_name, run_dir in RUN_DIRS.items():
    results[run_name] = get_metrics(stage, real_root, f"{METRICS_ROOT}/{run_dir}/version_0")

100%|██████████| 244/244 [00:18<00:00, 12.93it/s]
100%|██████████| 244/244 [00:21<00:00, 11.21it/s]
100%|██████████| 244/244 [00:21<00:00, 11.29it/s]
100%|██████████| 244/244 [00:20<00:00, 11.65it/s]
100%|██████████| 244/244 [00:24<00:00,  9.81it/s]
100%|██████████| 244/244 [00:23<00:00, 10.23it/s]
100%|██████████| 244/244 [00:21<00:00, 11.56it/s]
100%|██████████| 244/244 [00:21<00:00, 11.25it/s]
100%|██████████| 244/244 [00:22<00:00, 10.80it/s]
100%|██████████| 244/244 [00:21<00:00, 11.55it/s]
100%|██████████| 244/244 [00:22<00:00, 11.03it/s]
100%|██████████| 244/244 [00:21<00:00, 11.40it/s]
100%|██████████| 244/244 [00:20<00:00, 11.94it/s]
100%|██████████| 244/244 [00:21<00:00, 11.18it/s]


In [None]:
# Average the per-sample scores of every run and print a per-model summary.
mean_metrics = get_mean_metrics(results=results)

pix2pix
MAE: 40.02
CC: 0.9433
PSNR: 39.19
SSIM: 0.9749

pix2pixHD
MAE: 42.17
CC: 0.9440
PSNR: 38.86
SSIM: 0.9673

pix2pixCC
MAE: 43.36
CC: 0.9227
PSNR: 37.96
SSIM: 0.9757

ddpm_noise
MAE: 114.68
CC: 0.7289
PSNR: 31.62
SSIM: 0.7866

ddpm_x0
MAE: 39.47
CC: 0.9565
PSNR: 40.18
SSIM: 0.9713

fast_ddpm_noise
MAE: 86.67
CC: 0.8619
PSNR: 33.29
SSIM: 0.9205

fast_ddpm_x0
MAE: 39.97
CC: 0.9559
PSNR: 40.10
SSIM: 0.9705

pix2pix_ema
MAE: 40.01
CC: 0.9490
PSNR: 39.57
SSIM: 0.9683

pix2pixHD_ema
MAE: 42.27
CC: 0.9455
PSNR: 38.98
SSIM: 0.9665

pix2pixCC_ema
MAE: 42.49
CC: 0.9404
PSNR: 38.54
SSIM: 0.9743

ddpm_noise_ema
MAE: 96.25
CC: 0.8460
PSNR: 32.96
SSIM: 0.8498

ddpm_x0_ema
MAE: 39.31
CC: 0.9577
PSNR: 40.20
SSIM: 0.9721

fast_ddpm_noise_ema
MAE: 87.68
CC: 0.8908
PSNR: 33.44
SSIM: 0.9189

fast_ddpm_x0_ema
MAE: 39.78
CC: 0.9575
PSNR: 40.18
SSIM: 0.9712

