# Definitions of plotting routines

In [1]:
from pathlib import Path

import galsim
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import StrMethodFormatter

# Tweak plot resolution and styling.
# IPython magic (only valid inside an IPython/Jupyter kernel): render inline
# figures in "retina" (high pixel density) PNG format.
%config InlineBackend.figure_format = "retina"
# Seaborn "white" theme; `palette=None` leaves the colour cycle at seaborn's
# default, and the rc override sets the axes spine width to 1.
sns.set(style="white", palette=None, rc={"axes.linewidth": 1})
# Default colormap for colormapped artists (e.g. imshow) in this notebook.
plt.rc("image", cmap="viridis")

In [4]:
def true_vs_pred(y_true, y_pred, snr, path="./graphs/pred.png"):
    """Scatter true against predicted values for each galaxy parameter.

    Draws one panel per parameter (flux, Sérsic index, Sérsic radius, g1).
    Each point is coloured by its SNR, and a dashed ``y = x`` line marks
    perfect predictions. The figure is saved to ``path``.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, 4)
        True and predicted parameters, with columns in the order of the panel
        titles. Column 0 (flux) is displayed in units of 1e5.
    snr : array-like, shape (n_samples,)
        Per-sample signal-to-noise ratio, used as the point colour.
    path : str or path-like, optional
        Output image file. Missing parent directories are created.
    """
    titles = ['Flux [$10^5$]', 'Sérsic Index', 'Sérsic Radius', 'g1']

    # Display Flux in 10^5. Float copies keep the caller's arrays untouched
    # and let the in-place division work for integer-typed input too.
    y_true = np.array(y_true, dtype=float)
    y_true[:, 0] /= 10**5
    y_pred = np.array(y_pred, dtype=float)
    y_pred[:, 0] /= 10**5

    # Ensure all axes are displayed as floats
    formatter = StrMethodFormatter("{x:.1f}")

    fig, axes = plt.subplots(2, 2, figsize=(13.5, 11.), constrained_layout=True)
    # axes.flat walks the 2x2 grid row-major, matching the order of `titles`.
    for i, (ax, title) in enumerate(zip(axes.flat, titles)):
        im = ax.scatter(y_true[:, i], y_pred[:, i], c=snr, edgecolor="w", cmap='RdYlBu')
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.set_title("SNR")
        ax.set_title(title, fontsize=17)
        ax.set_xlabel('True Value', fontsize=14)
        ax.set_ylabel('Predicted Value', fontsize=14)
        ax.xaxis.set_major_formatter(formatter)
        ax.yaxis.set_major_formatter(formatter)

        # Identity line in *data* coordinates. A diagonal in axes coordinates
        # (transAxes) coincides with y = x only if the autoscaled x and y
        # limits happen to be equal, which they generally are not here.
        values = np.concatenate([y_true[:, i], y_pred[:, i]])
        finite = values[np.isfinite(values)]
        if finite.size:
            lo, hi = finite.min(), finite.max()
            ax.plot([lo, hi], [lo, hi], color="tab:red", linestyle="--")

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=100)

In [None]:
def true_vs_error(y_true, y_pred, snr, set_limits=False, path="./graphs/err-snr.png"):
    """Scatter prediction error against the true value for each parameter.

    The error is ``y_pred - y_true``. There is one panel per parameter (flux,
    Sérsic index, Sérsic radius, g1), points are coloured by SNR, and a dashed
    horizontal line marks zero error. The figure is saved to ``path``.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, 4)
        True and predicted parameters, with columns in the order of the panel
        titles. Column 0 (flux) is displayed in units of 1e5.
    snr : array-like, shape (n_samples,)
        Per-sample signal-to-noise ratio, used as the point colour.
    set_limits : bool, optional
        Accepted for backward compatibility; currently ignored.
    path : str or path-like, optional
        Output image file. Missing parent directories are created.
    """
    titles = ['Flux [$10^5$]', 'Sérsic Index', 'Sérsic Radius', 'g1']

    # Display Flux in 10^5. Float copies keep the caller's arrays untouched
    # and let the in-place division work for integer-typed input too.
    y_true = np.array(y_true, dtype=float)
    y_true[:, 0] /= 10**5
    y_pred = np.array(y_pred, dtype=float)
    y_pred[:, 0] /= 10**5

    error = y_pred - y_true

    # Ensure all axes are displayed as floats
    formatter = StrMethodFormatter("{x:.1f}")

    fig, axes = plt.subplots(2, 2, figsize=(13.5, 11.), constrained_layout=True)
    # axes.flat walks the 2x2 grid row-major, matching the order of `titles`.
    for i, (ax, title) in enumerate(zip(axes.flat, titles)):
        im = ax.scatter(y_true[:, i], error[:, i], c=snr, edgecolor='w', cmap='RdYlBu')
        cbar = fig.colorbar(im, ax=ax)
        cbar.ax.set_title("SNR")
        ax.set_title(title, fontsize=17)
        ax.set_xlabel('True Value', fontsize=14)
        ax.set_ylabel('Error', fontsize=14)
        ax.xaxis.set_major_formatter(formatter)
        ax.yaxis.set_major_formatter(formatter)
        ax.axhline(0, color="tab:red", linestyle="--")

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=100)

In [None]:
def error_hist(y_true, y_pred, path="./graphs/err-dist.png"):
    """Plot the distribution of prediction errors for each galaxy parameter.

    The error is ``y_pred - y_true``. Each parameter gets a density-normalized
    50-bin histogram with a KDE overlay and a dashed vertical line at zero.
    The figure is saved to ``path``.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, 4)
        True and predicted parameters, with columns in the order of the panel
        titles. Column 0 (flux) errors are displayed in units of 1e5.
    path : str or path-like, optional
        Output image file. Missing parent directories are created.
    """
    titles = ['Flux [$10^5$]', 'Sérsic Index', 'Sérsic Radius', 'g1']

    # The subtraction yields a fresh float array, so the in-place flux
    # rescaling below never touches the caller's data and works for ints too.
    error = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)

    # Display Flux in 10^5
    error[:, 0] /= 10**5

    # Ensure all axes are displayed as floats
    formatter = StrMethodFormatter("{x:.1f}")

    fig, axes = plt.subplots(2, 2, figsize=(13.5, 11.), constrained_layout=True)
    # axes.flat walks the 2x2 grid row-major, matching the order of `titles`.
    for i, (ax, title) in enumerate(zip(axes.flat, titles)):
        if hasattr(sns, "histplot"):
            # seaborn >= 0.11 deprecates distplot; histplot with a density
            # stat and KDE overlay is its documented replacement.
            sns.histplot(error[:, i], bins=50, stat="density", kde=True, ax=ax)
        else:
            sns.distplot(error[:, i], bins=50, ax=ax)
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Error', fontsize=14)
        ax.set_ylabel('Density', fontsize=14)
        ax.xaxis.set_major_formatter(formatter)
        ax.yaxis.set_major_formatter(formatter)
        ax.axvline(0, color="tab:red", linestyle="--")

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=100)

In [None]:
def galsim_image(flux, sersic_index, sersic_radius, g1, g2, psf_r, *,
                 image_size=64, pixel_scale=0.23, psf_beta=2):
    """Generate a noiseless image using GalSim.

    A Sérsic profile with the given total flux and half-light radius is
    sheared by (g1, g2), convolved with a unit-flux Moffat PSF, and drawn
    onto a square single-precision image.

    Parameters
    ----------
    flux : float
        Total flux of the galaxy.
    sersic_index : float
        Sérsic index n.
    sersic_radius : float
        Half-light radius, presumably in the same angular units as
        ``pixel_scale`` (arcsec).
    g1, g2 : float
        Reduced shear components.
    psf_r : float
        FWHM of the Moffat PSF.
    image_size : int, keyword-only
        Side length of the square image in pixels (default 64).
    pixel_scale : float, keyword-only
        Arcsec per pixel (default 0.23).
    psf_beta : float, keyword-only
        Moffat beta parameter (default 2).

    Returns
    -------
    numpy.ndarray
        The ``(image_size, image_size)`` pixel array.
    """
    gal = (
        galsim.Sersic(sersic_index, half_light_radius=sersic_radius)
        .withFlux(flux)
        .shear(g1=g1, g2=g2)
    )
    psf = galsim.Moffat(beta=psf_beta, flux=1.0, fwhm=psf_r)
    final = galsim.Convolve([psf, gal])
    image = galsim.ImageF(image_size, image_size, scale=pixel_scale)
    final.drawImage(image=image)

    return image.array

In [None]:
def plot_galsim_reconstruction(noisy, noiseless, reconstructed, snr=60):
    """Show a noisy galaxy image, its noiseless truth, a reconstruction and the residual.

    Draws a 2x2 grid with one colorbar per panel. The "Difference" panel is
    ``reconstructed - noiseless`` (reconstruction minus truth).

    Parameters
    ----------
    noisy, noiseless, reconstructed : numpy.ndarray
        2-D images of identical shape.
    snr : float or str, optional
        SNR value shown in the figure title. Defaults to 60, which reproduces
        the previously hard-coded title.
    """
    difference = reconstructed - noiseless
    panels = [
        ("Noisy Image", noisy),
        ("Noiseless Image", noiseless),
        ("Reconstruction", reconstructed),
        ("Difference", difference),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(9.5, 8), constrained_layout=True)
    for (title, image), ax in zip(panels, axes.flat):
        im = ax.imshow(image)
        ax.axis("off")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    fig.suptitle("Galaxy Reconstruction for SNR {}".format(snr))

In [None]:
def _plot_image_row(data, title, path, equalize=False):
    """Draw the images in ``data`` (n, H, W) side by side, save to ``path`` and show."""
    n_images = len(data)
    # squeeze=False keeps `axes` a 2-D array even when n_images == 1.
    fig, axes = plt.subplots(
        1, n_images, figsize=(2.5 * n_images, 3), constrained_layout=True, squeeze=False
    )
    fig.suptitle(title, fontsize=25)

    # When equalizing, every panel shares one colour range spanning the row.
    limits = {"vmin": data.min(), "vmax": data.max()} if equalize else {}
    for image, ax in zip(data, axes.flat):
        ax.imshow(image, **limits)
        ax.axis("off")

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=100)
    plt.show()


def generate_images(AE, val_ds_AE, idx, max_images=8):
    """Plot and save image strips comparing inputs, ground truth and AE reconstructions.

    Takes the first batch of ``val_ds_AE`` and runs ``AE`` on it. For the
    batch positions in ``idx`` it draws one strip per view:
    - noisy input;
    - noiseless truth;
    - reconstruction;
    - log-scaled truth and reconstruction;
    - absolute and signed residuals.
    Each strip is saved under ``./graphs/`` and shown.

    Parameters
    ----------
    AE : model
        Presumably a Keras model. ``AE.predict([images, stats])`` must return
        ``(predicted_images, predicted_labels)``, with images carrying a
        trailing channel axis of size 1. TODO confirm.
    val_ds_AE : dataset
        Presumably a ``tf.data.Dataset`` whose elements are
        ``((images, stats), (clean, labels))`` batches. TODO confirm.
    idx : sequence of int
        Batch positions to display, one panel per index.
    max_images : int or None, optional
        Only the first ``max_images`` entries of ``idx`` are used (default 8,
        the previous fixed count). ``None`` uses all of them.

    Raises
    ------
    ValueError
        If no indices remain to display.
    """
    idx = list(idx)[:max_images]
    if not idx:
        raise ValueError("idx must contain at least one sample index")

    (images, stats), (clean, _labels) = next(iter(val_ds_AE.take(1)))
    pred_img, _pred_lab = AE.predict([images, stats])
    # Drop the trailing channel axis: (batch, H, W, 1) -> (batch, H, W).
    pred = pred_img.reshape(pred_img.shape[:-1])
    height, width = pred.shape[1:3]

    noisy = np.stack([images[k, :, :, 0].numpy() for k in idx])
    truth = np.stack([clean[k].numpy().reshape(height, width) for k in idx])
    generated = np.stack([pred[k] for k in idx])
    residual = truth - generated

    # Non-positive pixels give -inf/nan under log; imshow masks those values,
    # so the warnings are just noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_truth = np.log(truth)
        log_generated = np.log(generated)

    # (title, images, output path, share one colour range across the row)
    rows = [
        ("Original Noisy Images", noisy, "./graphs/ori.png", False),
        ("True Noiseless Images", truth, "./graphs/cle.png", False),
        ("Generated Images", generated, "./graphs/gen.png", False),
        # NOTE(review): despite the "ori" prefix this file holds the clean
        # (noiseless) images; name kept for compatibility with existing outputs.
        ("True Noiseless Images, Log Scale", log_truth, "./graphs/ori-log.png", False),
        ("Generated Images, Log Scale", log_generated, "./graphs/gen-log.png", False),
        ("Residuals, Absolute Values", np.abs(residual), "./graphs/resi-abs.png", True),
        ("Residuals (True minus Reconstructed)", residual, "./graphs/resi-scl.png", True),
    ]

    for title, data, path, equalize in rows:
        _plot_image_row(data, title, path, equalize)