In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sunpy.map import Map
from astropy.visualization import ImageNormalize, AsinhStretch

import torch
from torchmetrics import MeanAbsoluteError
from torchmetrics.regression import PearsonCorrCoef, ConcordanceCorrCoef
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

import sys
sys.path.append('../data/sdo')
from preprocess_aia import NormalizeEditor

In [None]:
class Validator:
    """Score one model's generated AIA 193 maps against the ground-truth maps.

    Each call takes the ``idx``-th target file of a split and the model's
    matching ``*_fake.npz`` file. It also loads the corresponding two-channel
    input file (171 and 304). Every map is de-normalised with
    ``NormalizeEditor(0, 14).inverse`` and image-quality metrics are computed
    on the de-normalised data.

    Files read by ``__call__``::

        real_root/<stage>/target/<stem>.npz
        real_root/<stage>/input/<name of the fake file with "_193_fake" -> "_171_304">
        fake_root/<stage>/<stem>_fake.npz

    Args:
        real_root: Root directory of the preprocessed dataset.
        fake_root: Directory holding this model's generated outputs.
        concordance: If True, ``"cc"`` uses the concordance correlation
            coefficient (``ConcordanceCorrCoef``); otherwise the Pearson
            correlation coefficient.
    """

    def __init__(
        self,
        real_root,
        fake_root,
        concordance=False,
    ):
        self.real_root = Path(real_root)
        self.fake_root = Path(fake_root)

        self.mae = MeanAbsoluteError()                 # 0.0 is best
        self.psnr = PeakSignalNoiseRatio()             # +inf is best
        self.ssim = StructuralSimilarityIndexMeasure() # 1.0 is best
        if concordance:
            self.cc = ConcordanceCorrCoef()            # 1.0 is best
        else:
            self.cc = PearsonCorrCoef()                # 1.0 is best

        # stage -> sorted list of target files, filled lazily by _target_files
        # so repeated calls don't re-glob the same directory.
        self._target_files_cache = {}

    def inverse(self, x):
        """Undo the dataset normalisation on map ``x`` via ``NormalizeEditor(0, 14)``."""
        return NormalizeEditor(0, 14).inverse(x)

    def _target_files(self, stage):
        """Return the sorted ``*.npz`` files in ``real_root/<stage>/target``.

        The listing is cached per stage for the lifetime of this instance, so
        files added afterwards are only seen by a newly built Validator.
        """
        if stage not in self._target_files_cache:
            target_dir = self.real_root / stage / "target"
            self._target_files_cache[stage] = sorted(target_dir.glob("*.npz"))
        return self._target_files_cache[stage]

    def _load_maps(self, path, count):
        """Return the first ``count`` maps stored in ``path``, de-normalised.

        The archive holds parallel ``"data"`` and ``"metas"`` arrays. Entry
        ``i`` becomes ``self.inverse(Map(data[i], metas[i]))``.
        """
        # allow_pickle is kept because the metadata entries are presumably
        # stored as Python objects; only load files from this project's pipeline.
        with np.load(path, allow_pickle=True) as npz:
            # Read each key exactly once: every NpzFile lookup re-reads the
            # array from the archive. The context manager closes the file handle.
            data = npz["data"]
            metas = npz["metas"]
        return [self.inverse(Map(data[i], metas[i])) for i in range(count)]

    def _compute_metrics(self, fake_map, real_map):
        """Compute MAE, pixel-wise CC, PSNR and SSIM for one (fake, real) pair.

        Each map's data gets two leading singleton dims (batch, channel) for
        the image metrics. Map data is presumably 2-D (H, W); TODO confirm.
        The correlation is computed over the flattened pixels. Metric state is
        always reset afterwards, so each call is independent of the previous one.
        """
        fake_data = torch.Tensor(fake_map.data)[None, None]
        real_data = torch.Tensor(real_map.data)[None, None]
        try:
            return {
                "mae": self.mae(fake_data, real_data),
                "cc": self.cc(fake_data.flatten(), real_data.flatten()),
                "psnr": self.psnr(fake_data, real_data),
                "ssim": self.ssim(fake_data, real_data),
            }
        finally:
            for metric in (self.mae, self.psnr, self.ssim, self.cc):
                metric.reset()

    def __call__(self, idx, stage):
        """Evaluate the ``idx``-th sample (in sorted target order) of ``stage``.

        Returns:
            ``(metrics, fake_map, real_res)``:

            - ``metrics`` maps ``"mae"``, ``"cc"``, ``"psnr"`` and ``"ssim"``
              to the values returned by the torchmetrics modules.
            - ``fake_map`` is the de-normalised generated map.
            - ``real_res`` holds ``"timestamp"``, ``"real_map"``,
              ``"real_map_input_171"`` and ``"real_map_input_304"``.
        """
        real_file = self._target_files(stage)[idx]
        # Drop the trailing 4-character channel tag from the stem
        # (presumably "_193"; TODO confirm naming).
        timestamp = real_file.stem[:-4]
        fake_file = self.fake_root / stage / f"{real_file.stem}_fake.npz"

        (real_map,) = self._load_maps(real_file, 1)
        (fake_map,) = self._load_maps(fake_file, 1)
        metrics = self._compute_metrics(fake_map, real_map)

        # The matching input archive stores the 171 channel first and 304 second.
        input_name = fake_file.name.replace("_193_fake", "_171_304")
        input_file = self.real_root / stage / "input" / input_name
        real_map_input_171, real_map_input_304 = self._load_maps(input_file, 2)

        real_res = {
            "timestamp": timestamp,
            "real_map": real_map,
            "real_map_input_171": real_map_input_171,
            "real_map_input_304": real_map_input_304,
        }
        return metrics, fake_map, real_res

In [None]:
# Location of the preprocessed AIA dataset and the split to evaluate.
real_root = Path("/home/mgj/workspace/mgjeon/image-to-image/data/sdo/aia_dataset")
stage = "test"

# Ground-truth target files for this split. They are sorted so that a given
# index refers to the same file here and inside Validator, which sorts the same glob.
real_files = sorted((real_root / stage / "target").glob("*.npz"))
len(real_files)

244

In [None]:
# Output folders for the comparison figures: plain weights and EMA weights.
res_path = Path("res")
res_ema_path = Path("res_ema")
for _out_dir in (res_path, res_ema_path):
    _out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from tqdm import tqdm

# Root under which every model wrote its generated maps. Each run directory
# below is read as <run>/version_0/<stage>/<target stem>_fake.npz by Validator.
METRICS_ROOT = Path("/home/mgj/workspace/mgjeon/image-to-image/metrics/sdo")

# Display name -> run directory relative to METRICS_ROOT. The EMA variant of a
# run lives in "<run>_ema". Insertion order fixes the order of the prediction panels.
MODEL_RUNS = {
    "pix2pix": "pix2pix/unet_patchgan_1024_small",
    "pix2pixHD": "pix2pixHD/default_small",
    "pix2pixCC": "pix2pixCC/default_small",
    "ddpm_noise": "diffusion/ddpm_noise",
    "ddpm_x0": "diffusion/ddpm_x0",
    "fast_ddpm_noise": "diffusion/fast_ddpm_noise",
    "fast_ddpm_x0": "diffusion/fast_ddpm_x0",
}

# Figure grid (1-based subplot slots):
# - row 1: input 171, input 304, target 193
# - row 2: pix2pix, pix2pixHD, pix2pixCC
# - row 3: ddpm_noise, ddpm_x0
# - row 4: fast_ddpm_noise, fast_ddpm_x0
# The last column of rows 3 and 4 stays empty.
NROWS, NCOLS = 4, 3
MODEL_PANEL_SLOTS = (4, 5, 6, 7, 8, 10, 11)


def _build_validators(suffix):
    """Create one ``(display name, Validator)`` pair per entry of MODEL_RUNS.

    ``suffix`` is appended to both the run directory and the display name:
    ``""`` selects the plain weights and ``"_ema"`` the EMA weights.
    """
    return [
        (
            name + suffix,
            Validator(
                real_root=real_root,
                fake_root=METRICS_ROOT / f"{run}{suffix}" / "version_0",
            ),
        )
        for name, run in MODEL_RUNS.items()
    ]


def _evaluate_models(validators, idx):
    """Run every validator on sample ``idx`` of ``stage``.

    Returns ``(results, real_res)``:

    - ``results`` is a list of ``{"model", "metrics", "fake_maps"}`` dicts in
      validator order.
    - ``real_res`` is the ground-truth bundle from the first validator. All
      validators share ``real_root``, so it is identical for every model.
    """
    results = []
    real_res = None
    for name, validator in validators:
        metrics, fake_map, sample_res = validator(idx, stage=stage)
        if real_res is None:
            real_res = sample_res
        results.append({"model": name, "metrics": metrics, "fake_maps": fake_map})
    return results, real_res


def _draw_panel(fig, slot, image, cmap, norm, title):
    """Add one axis-less image panel at 1-based grid position ``slot``."""
    ax = fig.add_subplot(NROWS, NCOLS, slot)
    ax.imshow(image, cmap=cmap, origin="lower", norm=norm)
    ax.axis("off")
    ax.set_title(title)


def _plot_comparison(results, real_res, out_file):
    """Draw the inputs, the target and every model's prediction, then save to ``out_file``.

    Raises:
        ValueError: if there are more model results than prediction slots.
    """
    if len(results) > len(MODEL_PANEL_SLOTS):
        raise ValueError(
            f"{len(results)} models but only {len(MODEL_PANEL_SLOTS)} panel slots"
        )

    norm193 = ImageNormalize(vmin=0, vmax=2000, stretch=AsinhStretch(0.04))
    norm171 = ImageNormalize(vmin=0, vmax=2000, stretch=AsinhStretch(0.02))
    norm304 = ImageNormalize(vmin=0, vmax=2000, stretch=AsinhStretch(0.005))

    fig = plt.figure(figsize=(5 * NCOLS, 5 * NROWS))
    try:
        reference_panels = [
            (1, real_res["real_map_input_171"], "sdoaia171", norm171, "Input 171"),
            (2, real_res["real_map_input_304"], "sdoaia304", norm304, "Input 304"),
            (3, real_res["real_map"], "sdoaia193", norm193, "Target 193"),
        ]
        for slot, smap, cmap, norm, title in reference_panels:
            _draw_panel(fig, slot, smap.data, cmap, norm, title)

        for slot, res in zip(MODEL_PANEL_SLOTS, results):
            scores = res["metrics"]
            title = f"{res['model']} \n CC {scores['cc']:.4f} \n SSIM {scores['ssim']:.4f}"
            _draw_panel(fig, slot, res["fake_maps"].data, "sdoaia193", norm193, title)

        fig.suptitle(f"{real_res['timestamp']}")
        fig.tight_layout()
        fig.savefig(out_file)
    finally:
        # Always release the figure; hundreds are produced in the loop below.
        plt.close(fig)


plt.rcParams["font.size"] = 15

# (validators, output directory, filename prefix) for the plain and EMA weights.
# Validators are built once and reused for every sample instead of per iteration.
runs = [
    (_build_validators(""), res_path, "res"),
    (_build_validators("_ema"), res_ema_path, "res_ema"),
]

for idx in tqdm(range(len(real_files))):
    for validators, out_dir, prefix in runs:
        results, real_res = _evaluate_models(validators, idx)
        out_file = out_dir / f"{prefix}_{real_res['timestamp']}.png"
        _plot_comparison(results, real_res, out_file)

100%|██████████| 18/18 [01:16<00:00,  4.27s/it]
