In [None]:
from pathlib import Path

In [None]:
# Dataset root and split; list the paired input/target .npz files of that split.
real_root = Path("/home/mgj/workspace/mgjeon/image-to-image/data/sdo/aia_dataset")
stage = "test"
input_files, target_files = (
    sorted((real_root / stage / subdir).glob("*.npz"))
    for subdir in ("input", "target")
)

print(len(input_files), len(target_files))

244 244


In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sunpy.map import Map
from astropy.visualization import ImageNormalize, AsinhStretch

import torch
from torchmetrics import MeanAbsoluteError
from torchmetrics.regression import PearsonCorrCoef, ConcordanceCorrCoef
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

import sys
sys.path.append('../data/sdo')
from preprocess_aia import NormalizeEditor

In [None]:
class Validator:
    """Score a generated AIA 193 image against its real target on the solar disk.

    For ``__call__(idx, stage)`` the ``idx``-th file (sorted order) of
    ``<real_root>/<stage>/target/*.npz`` is paired with
    ``<fake_root>/<stage>/<stem>_fake.npz``. Both are loaded as sunpy Maps from
    the first entry of their ``data``/``metas`` arrays and passed through
    ``NormalizeEditor(0, 14).inverse``. Pixels farther than ``meta['r_sun']``
    from the target map's centre are zeroed in both images. The result is
    scored with MAE, pixel-to-pixel correlation, PSNR and SSIM.

    Note: by default MAE and CC are computed over the whole frame, including the
    zeroed off-disk pixels. MAE is therefore the on-disk MAE scaled by the
    disk's area fraction, and CC also sees the matching zeros. Pass
    ``disk_only=True`` to compute MAE and CC from on-disk pixels only. PSNR and
    SSIM always use the full masked 2-D image.

    Parameters
    ----------
    real_root : str or Path
        Dataset root containing ``<stage>/target``.
    fake_root : str or Path
        Prediction root containing ``<stage>``.
    concordance : bool
        Use Lin's concordance correlation instead of Pearson for ``"cc"``.
    plot : bool
        Show target and prediction side by side on every call.
    disk_only : bool
        Restrict MAE and CC to on-disk pixels (default False keeps the
        original full-frame behaviour).
    """

    def __init__(
        self,
        real_root,
        fake_root,
        concordance=False,
        plot=False,
        disk_only=False,
    ):
        self.real_root = Path(real_root)
        self.fake_root = Path(fake_root)

        self.mae = MeanAbsoluteError()                 # 0.0 is best
        self.psnr = PeakSignalNoiseRatio()             # +inf is best
        self.ssim = StructuralSimilarityIndexMeasure() # 1.0 is best
        # Either correlation: 1.0 is best.
        self.cc = ConcordanceCorrCoef() if concordance else PearsonCorrCoef()
        self.plot = plot
        self.disk_only = disk_only
        # stage -> sorted target files. Filled lazily so each target directory is
        # globbed once instead of on every call.
        self._real_files = {}

    def inverse(self, x):
        """Undo the dataset normalisation via ``NormalizeEditor(0, 14).inverse``."""
        x = NormalizeEditor(0, 14).inverse(x)
        return x

    def _target_files(self, stage):
        """Return the sorted ``*.npz`` files of ``<real_root>/<stage>/target`` (cached per stage)."""
        if stage not in self._real_files:
            target_dir = self.real_root / stage / "target"
            self._real_files[stage] = sorted(target_dir.glob("*.npz"))
        return self._real_files[stage]

    @staticmethod
    def _load_map(path):
        """Build a sunpy Map from the first ``data``/``metas`` entries of an .npz file."""
        # allow_pickle is required presumably because ``metas`` is an object array of
        # headers. Only load files produced by this project's own pipeline.
        with np.load(path, allow_pickle=True) as arrays:
            return Map(arrays["data"][0], arrays["metas"][0])

    @staticmethod
    def _disk_mask(smap):
        """Boolean 2-D mask of pixels within ``meta['r_sun']`` of the map centre."""
        Y, X = np.ogrid[:smap.data.shape[0], :smap.data.shape[1]]
        xc, yc = smap.wcs.world_to_pixel(smap.center)
        dist = np.sqrt((X - xc) ** 2 + (Y - yc) ** 2)
        # NOTE(review): compares a pixel distance with meta['r_sun']; assumes the
        # header stores the solar radius in pixels — confirm for this dataset.
        return dist <= smap.meta['r_sun']

    @staticmethod
    def _plot(real_map, fake_map, real_data, fake_data, title):
        """Show masked target and prediction side by side on their map projections."""
        fig = plt.figure(figsize=(10, 5))
        norm = ImageNormalize(vmin=0, vmax=2000, stretch=AsinhStretch(0.04))
        panels = (
            (1, real_map, real_data, "Target 193"),
            (2, fake_map, fake_data, "AI-generated 193"),
        )
        for position, smap, data, label in panels:
            ax = fig.add_subplot(1, 2, position, projection=smap)
            ax.imshow(data.squeeze(), cmap="sdoaia193", origin="lower", norm=norm)
            ax.axis("off")
            ax.set_title(label)
        fig.suptitle(f"{title}")
        fig.tight_layout()
        plt.show()
        # Release the figure so plotting many images does not accumulate open figures.
        plt.close(fig)

    def __call__(self, idx, stage):
        """Score the ``idx``-th target file of ``stage`` against its prediction.

        Returns
        -------
        dict
            ``{"mae", "cc", "psnr", "ssim"}`` -> values returned by the torchmetrics objects.

        Raises
        ------
        IndexError
            If ``idx`` is outside the target file list.
        ValueError
            If the target and prediction images differ in shape.
        """
        real_file = self._target_files(stage)[idx]
        # Stem without its last 4 characters (presumably a channel suffix — confirm).
        timestamp = real_file.stem[:-4]
        fake_file = self.fake_root / stage / (real_file.stem + "_fake.npz")

        real_map = self.inverse(self._load_map(real_file))
        fake_map = self.inverse(self._load_map(fake_file))
        if real_map.data.shape != fake_map.data.shape:
            raise ValueError(
                f"shape mismatch for {real_file.name}: target {real_map.data.shape} "
                f"vs prediction {fake_map.data.shape}"
            )

        mask = self._disk_mask(real_map)
        # Zero off-disk pixels in both images, then add batch/channel axes -> (1, 1, H, W).
        fake_data = torch.as_tensor(fake_map.data * mask, dtype=torch.float32)[None, None]
        real_data = torch.as_tensor(real_map.data * mask, dtype=torch.float32)[None, None]

        if self.disk_only:
            on_disk = torch.from_numpy(mask)
            fake_pts = fake_data[0, 0][on_disk]
            real_pts = real_data[0, 0][on_disk]
            mae_value = self.mae(fake_pts, real_pts)
        else:
            fake_pts = fake_data.flatten()
            real_pts = real_data.flatten()
            mae_value = self.mae(fake_data, real_data)
        pixel_to_pixel_cc = self.cc(fake_pts, real_pts)
        psnr_value = self.psnr(fake_data, real_data)
        ssim_value = self.ssim(fake_data, real_data)

        # Clear accumulated state so each call scores exactly one image pair.
        for metric in (self.mae, self.psnr, self.ssim, self.cc):
            metric.reset()

        if self.plot:
            self._plot(real_map, fake_map, real_data, fake_data, timestamp)

        return {
            "mae": mae_value,
            "cc": pixel_to_pixel_cc,
            "psnr": psnr_value,
            "ssim": ssim_value,
        }

In [None]:
# Evaluation split and its sorted target files; the cell displays how many there are.
real_root = Path("/home/mgj/workspace/mgjeon/image-to-image/data/sdo/aia_dataset")
stage = "test"
real_files = sorted(real_root.joinpath(stage, "target").glob("*.npz"))
len(real_files)

244

In [None]:
from tqdm import tqdm

In [None]:
def get_metrics(stage, real_root, fake_root):
    """Score every target image of ``stage`` against its prediction in ``fake_root``.

    The number of images is taken from ``<real_root>/<stage>/target/*.npz`` — the
    directory ``Validator`` indexes into. It is no longer taken from the
    notebook-global ``real_files``, so the loop always matches the requested
    stage and root.

    Parameters
    ----------
    stage : str
        Dataset split subdirectory (e.g. ``"test"``).
    real_root : str or Path
        Dataset root containing ``<stage>/target``.
    fake_root : str or Path
        Prediction root containing ``<stage>``.

    Returns
    -------
    dict
        ``"mae"``, ``"cc"``, ``"psnr"``, ``"ssim"`` -> list of per-image values as
        returned by ``Validator``, in sorted target-file order.
    """
    keys = ("mae", "cc", "psnr", "ssim")
    metrics = {key: [] for key in keys}
    n_files = len(list((Path(real_root) / stage / "target").glob("*.npz")))
    # One Validator for the whole run: it resets its metric objects after every
    # call, so re-creating it per image only adds construction cost.
    val = Validator(
        real_root=real_root,
        fake_root=fake_root,
    )
    for idx in tqdm(range(n_files)):
        ms = val(idx, stage=stage)
        for key in keys:
            metrics[key].append(ms[key])
    return metrics

In [None]:
def get_mean_metrics(results):
    """Average each model's per-image metrics and print a short report.

    ``results`` maps a model name to a dict holding per-image value sequences
    under ``"mae"``, ``"cc"``, ``"psnr"`` and ``"ssim"``. For every model, this
    prints the name, then its four means, then a blank line. The returned dict
    maps each model name to ``{"mae", "cc", "psnr", "ssim"} -> np.mean`` of
    the corresponding sequence.
    """
    # (printed label, result key, format spec): MAE/PSNR with 2 decimals, CC/SSIM with 4.
    report_rows = (
        ("MAE", "mae", ".2f"),
        ("CC", "cc", ".4f"),
        ("PSNR", "psnr", ".2f"),
        ("SSIM", "ssim", ".4f"),
    )
    summary = {}
    for model_name, per_image in results.items():
        print(model_name)
        means = {key: np.mean(per_image[key]) for _, key, _ in report_rows}
        summary[model_name] = means
        for label, key, spec in report_rows:
            print(f"{label}: {means[key]:{spec}}")
        print()
    return summary

In [None]:
# Per-image metrics for every evaluated run, keyed by run name. Runs are scored
# in table order; a failure part-way keeps the runs already scored in `results`.
results = {}
results.update(
    (run_name, get_metrics(stage, real_root, run_dir))
    for run_name, run_dir in (
        ("pix2pix", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pix/unet_patchgan_1024_small/version_0"),
        ("pix2pixHD", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pixHD/default_small/version_0"),
        ("pix2pixCC", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pixCC/default_small/version_0"),
        ("ddpm_noise", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/ddpm_noise/version_0"),
        ("ddpm_x0", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/ddpm_x0/version_0"),
        ("fast_ddpm_noise", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/fast_ddpm_noise/version_0"),
        ("fast_ddpm_x0", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/fast_ddpm_x0/version_0"),
        # `_ema` runs (presumably evaluated with EMA-averaged weights).
        ("pix2pix_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pix/unet_patchgan_1024_small_ema/version_0"),
        ("pix2pixHD_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pixHD/default_small_ema/version_0"),
        ("pix2pixCC_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/pix2pixCC/default_small_ema/version_0"),
        ("ddpm_noise_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/ddpm_noise_ema/version_0"),
        ("ddpm_x0_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/ddpm_x0_ema/version_0"),
        ("fast_ddpm_noise_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/fast_ddpm_noise_ema/version_0"),
        ("fast_ddpm_x0_ema", "/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo/diffusion/fast_ddpm_x0_ema/version_0"),
    )
)

100%|██████████| 244/244 [00:40<00:00,  5.97it/s]
100%|██████████| 244/244 [00:33<00:00,  7.20it/s]
100%|██████████| 244/244 [00:30<00:00,  8.02it/s]
100%|██████████| 244/244 [00:32<00:00,  7.60it/s]
100%|██████████| 244/244 [00:34<00:00,  7.08it/s]
100%|██████████| 244/244 [00:33<00:00,  7.21it/s]
100%|██████████| 244/244 [00:32<00:00,  7.60it/s]
100%|██████████| 244/244 [00:37<00:00,  6.55it/s]
100%|██████████| 244/244 [00:33<00:00,  7.30it/s]
100%|██████████| 244/244 [00:31<00:00,  7.65it/s]
100%|██████████| 244/244 [00:32<00:00,  7.54it/s]
100%|██████████| 244/244 [00:37<00:00,  6.53it/s]
100%|██████████| 244/244 [00:31<00:00,  7.79it/s]
100%|██████████| 244/244 [00:31<00:00,  7.64it/s]


In [None]:
# Print and collect the per-run averages of every metric.
mean_metrics = get_mean_metrics(results=results)

pix2pix
MAE: 29.36
CC: 0.9439
PSNR: 40.46
SSIM: 0.9731

pix2pixHD
MAE: 30.07
CC: 0.9403
PSNR: 39.82
SSIM: 0.9739

pix2pixCC
MAE: 30.47
CC: 0.9385
PSNR: 39.80
SSIM: 0.9720

ddpm_noise
MAE: 67.46
CC: 0.7957
PSNR: 33.70
SSIM: 0.8968

ddpm_x0
MAE: 28.47
CC: 0.9539
PSNR: 41.22
SSIM: 0.9766

fast_ddpm_noise
MAE: 53.74
CC: 0.9115
PSNR: 36.07
SSIM: 0.9340

fast_ddpm_x0
MAE: 28.98
CC: 0.9528
PSNR: 41.09
SSIM: 0.9759

pix2pix_ema
MAE: 29.49
CC: 0.9439
PSNR: 40.43
SSIM: 0.9729

pix2pixHD_ema
MAE: 30.33
CC: 0.9419
PSNR: 39.95
SSIM: 0.9730

pix2pixCC_ema
MAE: 30.62
CC: 0.9414
PSNR: 39.75
SSIM: 0.9737

ddpm_noise_ema
MAE: 58.24
CC: 0.8701
PSNR: 34.87
SSIM: 0.9264

ddpm_x0_ema
MAE: 28.62
CC: 0.9547
PSNR: 41.20
SSIM: 0.9770

fast_ddpm_noise_ema
MAE: 52.12
CC: 0.9234
PSNR: 36.33
SSIM: 0.9413

fast_ddpm_x0_ema
MAE: 28.78
CC: 0.9545
PSNR: 41.20
SSIM: 0.9764



In [None]:
import pandas as pd

In [None]:
# One row per run, one column per metric.
df = pd.DataFrame(mean_metrics).transpose()
df

Unnamed: 0,mae,cc,psnr,ssim
pix2pix,29.355825,0.943943,40.456551,0.973117
pix2pixHD,30.074755,0.9403,39.81522,0.973901
pix2pixCC,30.468699,0.938501,39.803967,0.971977
ddpm_noise,67.459549,0.795691,33.696991,0.896771
ddpm_x0,28.472622,0.953945,41.220028,0.976552
fast_ddpm_noise,53.743309,0.911523,36.067909,0.934042
fast_ddpm_x0,28.981403,0.952785,41.085266,0.975857
pix2pix_ema,29.490059,0.943885,40.427376,0.972947
pix2pixHD_ema,30.326172,0.941855,39.946705,0.973003
pix2pixCC_ema,30.618757,0.94137,39.753212,0.973672


In [None]:
# Lowest MAE first; ties broken by higher CC, then higher PSNR, then higher SSIM.
df = df.sort_values(
    by=["mae", "cc", "psnr", "ssim"],
    ascending=[True, False, False, False],
)
df

Unnamed: 0,mae,cc,psnr,ssim
ddpm_x0,28.472622,0.953945,41.220028,0.976552
ddpm_x0_ema,28.617199,0.954708,41.20454,0.976973
fast_ddpm_x0_ema,28.78344,0.954463,41.203133,0.976352
fast_ddpm_x0,28.981403,0.952785,41.085266,0.975857
pix2pix,29.355825,0.943943,40.456551,0.973117
pix2pix_ema,29.490059,0.943885,40.427376,0.972947
pix2pixHD,30.074755,0.9403,39.81522,0.973901
pix2pixHD_ema,30.326172,0.941855,39.946705,0.973003
pix2pixCC,30.468699,0.938501,39.803967,0.971977
pix2pixCC_ema,30.618757,0.94137,39.753212,0.973672


In [None]:
# Persist the MAE-ranked table next to the notebook.
df.to_csv(path_or_buf="metrics_within_disk.csv")

In [None]:
df.sort_values(by="cc", ascending=False)  # highest correlation first

Unnamed: 0,mae,cc,psnr,ssim
ddpm_x0_ema,28.617199,0.954708,41.20454,0.976973
fast_ddpm_x0_ema,28.78344,0.954463,41.203133,0.976352
ddpm_x0,28.472622,0.953945,41.220028,0.976552
fast_ddpm_x0,28.981403,0.952785,41.085266,0.975857
pix2pix,29.355825,0.943943,40.456551,0.973117
pix2pix_ema,29.490059,0.943885,40.427376,0.972947
pix2pixHD_ema,30.326172,0.941855,39.946705,0.973003
pix2pixCC_ema,30.618757,0.94137,39.753212,0.973672
pix2pixHD,30.074755,0.9403,39.81522,0.973901
pix2pixCC,30.468699,0.938501,39.803967,0.971977


In [None]:
df.sort_values(by="psnr", ascending=False)  # highest PSNR first

Unnamed: 0,mae,cc,psnr,ssim
ddpm_x0,28.472622,0.953945,41.220028,0.976552
ddpm_x0_ema,28.617199,0.954708,41.20454,0.976973
fast_ddpm_x0_ema,28.78344,0.954463,41.203133,0.976352
fast_ddpm_x0,28.981403,0.952785,41.085266,0.975857
pix2pix,29.355825,0.943943,40.456551,0.973117
pix2pix_ema,29.490059,0.943885,40.427376,0.972947
pix2pixHD_ema,30.326172,0.941855,39.946705,0.973003
pix2pixHD,30.074755,0.9403,39.81522,0.973901
pix2pixCC,30.468699,0.938501,39.803967,0.971977
pix2pixCC_ema,30.618757,0.94137,39.753212,0.973672


In [None]:
df.sort_values(by="ssim", ascending=False)  # highest SSIM first

Unnamed: 0,mae,cc,psnr,ssim
ddpm_x0_ema,28.617199,0.954708,41.20454,0.976973
ddpm_x0,28.472622,0.953945,41.220028,0.976552
fast_ddpm_x0_ema,28.78344,0.954463,41.203133,0.976352
fast_ddpm_x0,28.981403,0.952785,41.085266,0.975857
pix2pixHD,30.074755,0.9403,39.81522,0.973901
pix2pixCC_ema,30.618757,0.94137,39.753212,0.973672
pix2pix,29.355825,0.943943,40.456551,0.973117
pix2pixHD_ema,30.326172,0.941855,39.946705,0.973003
pix2pix_ema,29.490059,0.943885,40.427376,0.972947
pix2pixCC,30.468699,0.938501,39.803967,0.971977
