# Definitions of plotting routines

In [1]:
import galsim
import matplotlib.pyplot as plt
import seaborn as sns

# Tweak plot resolution and styling.
# IPython magic (notebook-only): request "retina" figures from the inline backend.
%config InlineBackend.figure_format = "retina"
# Apply the seaborn "white" style with an axes line width of 1.
# NOTE(review): sns.set rewrites matplotlib rcParams, so the plt.rc call below
# presumably has to stay after it to take effect -- confirm before reordering.
sns.set(style="white", palette=None, rc={"axes.linewidth": 1})
# Default colormap for image plots that do not pass their own cmap.
plt.rc("image", cmap="viridis")

In [None]:
def true_vs_pred(y_true, y_pred, snr, set_limits=False):
    """Scatter predicted against true values, one panel per parameter.

    Draws a 2x2 grid of scatter plots (flux, Sersic index, Sersic radius, g1)
    colored by ``snr``, each with its own colorbar and a dashed red y = x
    reference line.

    Parameters
    ----------
    y_true, y_pred : 2-D arrays
        True and predicted values, indexed as ``[:, i]`` for ``i`` in 0..3.
        The column order must match the panel titles.
    snr : array-like
        Values passed to the scatter plot as the point colors.
    set_limits : bool, default False
        If True, both axes of each panel are set to the fixed range listed in
        ``limits``. If False, the autoscaled ranges are kept.
    """
    titles = ['Flux', 'Sersic Index', 'Sersic Radius', 'g1']
    limits = [
        [15_000, 415_000],  # Flux
        [0.2, 6.3],  # Sersic index
        [0.08, 0.62],  # Sersic radius
        [-0.72, 0.72],  # g1
    ]

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))
    for i, ax in enumerate(axes.flat):
        im = ax.scatter(y_true[:, i], y_pred[:, i], c=snr, marker='.', cmap='RdYlBu')
        fig.colorbar(im, ax=ax)
        ax.set_title(f"{titles[i]} Colored by SNR", fontsize=20)
        ax.set_xlabel('True Value', fontsize=16)
        ax.set_ylabel('Predicted Value', fontsize=16)

        if set_limits:
            xlim = ylim = limits[i]
        else:
            # Keep whatever view the scatter plot autoscaled to.
            xlim, ylim = ax.get_xlim(), ax.get_ylim()

        # Draw the identity line in data coordinates. A corner-to-corner line in
        # axes coordinates is only y = x when both axes span the same range,
        # which is not the case for autoscaled axes.
        lo = min(xlim[0], ylim[0])
        hi = max(xlim[1], ylim[1])
        ax.plot([lo, hi], [lo, hi], color="tab:red", linestyle="--")

        # Pin the view so the reference line does not enlarge the axes.
        ax.set(xlim=xlim, ylim=ylim)

In [None]:
def true_vs_error(y_true, y_pred, snr):
    """Scatter the prediction error against the true value, one panel per parameter.

    The error is ``y_pred - y_true``. Draws a 2x2 grid (flux, Sersic index,
    Sersic radius, g1) with points colored by ``snr`` and a dashed red line at
    zero error.

    Parameters
    ----------
    y_true, y_pred : 2-D arrays
        True and predicted values, indexed as ``[:, i]`` for ``i`` in 0..3.
        The column order must match the panel titles.
    snr : array-like
        Values used for the point hue; the hue normalization is fixed to (10, 100).
    """
    titles = ['Flux', 'Sersic Index', 'Sersic Radius', 'g1']
    error = y_pred - y_true

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))
    for i, ax in enumerate(axes.flat):
        # x and y must be keywords: seaborn >= 0.12 rejects positional data
        # vectors, and keywords are accepted by older releases as well.
        sns.scatterplot(
            x=y_true[:, i],
            y=error[:, i],
            hue=snr,
            hue_norm=(10, 100),
            ax=ax,
            palette='RdYlBu',
        )
        ax.set_title(titles[i], fontsize=20)
        ax.set_xlabel('True Value', fontsize=16)
        ax.set_ylabel('Error', fontsize=16)
        ax.legend(title="SNR", loc="upper right")
        ax.axhline(0, color="tab:red", linestyle="--")

In [None]:
def error_hist(y_true, y_pred):
    """Plot the distribution of the prediction error, one panel per parameter.

    The error is ``y_pred - y_true``. Draws a 2x2 grid (flux, Sersic index,
    Sersic radius, g1) of 50-bin density histograms with a KDE curve and a
    dashed red line at zero error.

    Parameters
    ----------
    y_true, y_pred : 2-D arrays
        True and predicted values, indexed as ``[:, i]`` for ``i`` in 0..3.
        The column order must match the panel titles.
    """
    titles = ['Flux', 'Sersic Index', 'Sersic Radius', 'g1']
    error = y_pred - y_true

    # sns.distplot is deprecated; prefer sns.histplot where the installed
    # seaborn provides it and fall back to distplot on older releases.
    use_histplot = hasattr(sns, "histplot")

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))
    for i, ax in enumerate(axes.flat):
        if use_histplot:
            # Settings follow seaborn's distplot migration guidance: density
            # normalization, KDE overlay with cut=3, and distplot's bar styling.
            sns.histplot(
                error[:, i],
                bins=50,
                stat="density",
                kde=True,
                kde_kws={"cut": 3},
                alpha=0.4,
                edgecolor=(1, 1, 1, 0.4),
                ax=ax,
            )
        else:
            sns.distplot(error[:, i], bins=50, ax=ax)
        ax.set_title(titles[i], fontsize=20)
        ax.set_xlabel('Error', fontsize=16)
        ax.set_ylabel('Density', fontsize=16)
        ax.axvline(0, color="tab:red", linestyle="--")

In [None]:
def galsim_image(flux, sersic_index, sersic_radius, g1, g2, psf_r, sigma,
                 image_size=64, pixel_scale=0.23, psf_beta=2):
    """Generate a noiseless image using GalSim.

    Builds a sheared Sersic profile, convolves it with a unit-flux Moffat PSF
    and draws the result onto a square single-precision image.

    Parameters
    ----------
    flux : float
        Flux assigned to the Sersic profile.
    sersic_index : float
        Sersic index of the galaxy profile.
    sersic_radius : float
        Half-light radius of the galaxy profile.
    g1, g2 : float
        Shear components applied to the galaxy.
    psf_r : float
        FWHM of the Moffat PSF.
    sigma : float
        Not used by this function; it is accepted only so the signature stays
        compatible with existing callers. NOTE(review): presumably a noise
        level -- confirm against the callers.
    image_size : int, default 64
        Side length of the image in pixels (image is image_size x image_size).
    pixel_scale : float, default 0.23
        Pixel scale in arcsec / pixel.
    psf_beta : float, default 2
        Moffat beta parameter of the PSF.

    Returns
    -------
    The ``array`` attribute of the drawn ``galsim.ImageF``.
    """
    gal = galsim.Sersic(sersic_index, half_light_radius=sersic_radius)
    gal = gal.withFlux(flux)
    gal = gal.shear(g1=g1, g2=g2)
    psf = galsim.Moffat(beta=psf_beta, flux=1.0, fwhm=psf_r)
    final = galsim.Convolve([psf, gal])
    image = galsim.ImageF(image_size, image_size, scale=pixel_scale)
    final.drawImage(image=image)

    return image.array

In [None]:
def plot_galsim_reconstruction(noisy, noiseless, reconstructed, snr=60):
    """Show the noisy, noiseless and reconstructed images with their difference.

    Draws a 2x2 grid of image panels, each with its own colorbar and with the
    axes hidden. The fourth panel is ``reconstructed - noiseless``.

    Parameters
    ----------
    noisy, noiseless, reconstructed : 2-D arrays
        Images to display. ``reconstructed`` and ``noiseless`` must support
        element-wise subtraction.
    snr : default 60
        Value shown in the figure title. It only affects the title text; the
        default reproduces the previously hard-coded "SNR 60".
    """
    difference = reconstructed - noiseless
    images = [noisy, noiseless, reconstructed, difference]
    titles = ["Noisy Image", "Noiseless Image", "Reconstruction", "Difference"]

    fig, axes = plt.subplots(2, 2, figsize=(9.5, 8), constrained_layout=True)
    for image, title, ax in zip(images, titles, axes.flat):
        im = ax.imshow(image)
        ax.axis("off")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    fig.suptitle(f"Galaxy Reconstruction for SNR {snr}")