In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from mpl_toolkits.mplot3d import Axes3D    # noqa: F401

In [None]:
# Targets and prediction locations saved by the FFNP run.
# NOTE(review): not referenced by any cell shown here -- plotTests_sphere
# reloads these files itself; drop this cell if nothing else uses them.
AOD_Data, AOD_Locs = [
    np.load(npy_path)
    for npy_path in ('results/AOD_data/FFNP/targets.npy',
                     'results/AOD_data/FFNP/pred_locs.npy')
]

In [None]:
# Model outputs from the same FFNP run: predicted mean, Bayesian uncertainty
# (pred_uncer.npy) and conformal uncertainty (pred_uncer_conf.npy).
# NOTE(review): `_var` suffix assumes pred_uncer.npy stores a variance; it may
# be a standard deviation -- confirm against the code that writes it.
STACI_pred, STACI_pred_var, STACI_pred_conf = [
    np.load(npy_path)
    for npy_path in ('results/AOD_data/FFNP/preds.npy',
                     'results/AOD_data/FFNP/pred_uncer.npy',
                     'results/AOD_data/FFNP/pred_uncer_conf.npy')
]

In [None]:
def plotTests_sphere(datapath, save_dir, save_dir2, conformal, prefix,
                     cmap_mean="RdBu_r", cmap_unc="viridis",
                     vmin_mean=-1, vmax_mean=3,
                     vmin_unc=0, vmax_unc=6,
                     sphere_res=(60, 120),   # target resolution (v, u)
                     sphere_size=6,
                     dpi=500):
    """Render saved model outputs as colour maps wrapped onto a unit sphere.

    Reads ``pred_locs.npy``, ``targets.npy``, ``preds.npy``, ``pred_uncer.npy``
    (plus ``pred_uncer_conf.npy`` when ``conformal``) from ``save_dir``,
    scatters each per-location array onto a regular 2-D grid built from the
    first two columns of ``pred_locs``, strides the grids down to roughly
    ``sphere_res`` and writes one PNG per rendered field into ``save_dir2``
    (created if missing).

    Parameters
    ----------
    datapath : str
        Tag appended to every output file name (only used in file names).
    save_dir : str
        Directory holding the ``.npy`` inputs.
    save_dir2 : str
        Output directory for the PNGs.
    conformal : bool
        If True, also render the ground truth and the conformal uncertainty.
    prefix : str
        Tag prepended to every output file name.
    cmap_mean, vmin_mean, vmax_mean
        Colormap and colour limits for ground truth and predicted mean.
    cmap_unc, vmin_unc, vmax_unc
        Colormap and colour limits for the uncertainty fields.
    sphere_res : (int, int)
        Approximate (rows, cols) of the rendered mesh. Rows map to the polar
        angle v and columns to the azimuth u. Grids are only ever strided
        down, never upsampled.
    sphere_size : float
        Figure width/height in inches.
    dpi : int
        Resolution of the saved PNGs.
    """
    # --- scatter the flat per-location arrays onto a regular 2-D grid ---
    pred_locs = np.load(os.path.join(save_dir, 'pred_locs.npy'))
    targets = np.load(os.path.join(save_dir, 'targets.npy'))

    def grid_index(coord):
        """Integer grid index of each coordinate (spacing = smallest gap)."""
        values = np.unique(coord)          # sorted, de-duplicated
        # a single distinct value has no gap; any non-zero step maps it to 0
        step = np.diff(values).min() if values.size > 1 else 1.0
        return np.round((coord - values[0]) / step).astype(int)

    ix = grid_index(pred_locs[:, 0])
    iy = grid_index(pred_locs[:, 1])
    # Size the grid from the integer indices rather than a float
    # np.arange(x0, x1 + dx, dx), which can emit one extra step and leave a
    # spurious all-NaN row/column on the sphere.
    grid_shape = (iy.max() + 1, ix.max() + 1)

    def fill2d(arr):
        """Place one value per location into a NaN-filled grid (NaN = no data)."""
        grid = np.full(grid_shape, np.nan)
        grid[iy, ix] = np.ravel(arr)       # ravel also accepts (n, 1) columns
        return grid

    def load(name):
        return np.load(os.path.join(save_dir, name))

    fields = {
        "ground_truth": fill2d(targets),
        "pred_mean":    fill2d(load('preds.npy')),
        "bayes_uncer":  fill2d(load('pred_uncer.npy')),
    }
    if conformal:
        fields["conf_uncer"] = fill2d(load('pred_uncer_conf.npy'))

    # --- stride each field down to roughly sphere_res for speed ---
    rows, cols = grid_shape
    step_m = max(1, rows // sphere_res[0])
    step_n = max(1, cols // sphere_res[1])
    fields = {key: grid[::step_m, ::step_n] for key, grid in fields.items()}
    M2, N2 = fields["ground_truth"].shape

    # --- unit-sphere mesh: grid columns -> azimuth u, grid rows -> polar v ---
    # NOTE(review): grid row 0 (smallest y) lands at v=0, i.e. the +z pole.
    # If y is ascending latitude the map is drawn upside down -- confirm.
    u = np.linspace(0, 2 * np.pi, N2)
    v = np.linspace(0, np.pi, M2)
    uu, vv = np.meshgrid(u, v)
    Xs = np.sin(vv) * np.cos(uu)
    Ys = np.sin(vv) * np.sin(uu)
    Zs = np.cos(vv)

    os.makedirs(save_dir2, exist_ok=True)

    def _plot_sphere(field2d, fname, cmap, vmin, vmax):
        """Colour the sphere mesh with ``field2d`` and save it to ``save_dir2/fname``."""
        fig = plt.figure(figsize=(sphere_size, sphere_size))
        try:
            ax = fig.add_subplot(1, 1, 1, projection='3d')
            ax.set_axis_off()

            # use the whole figure, zero margins
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.set_position([0, 0, 1, 1])

            # Equal box aspect plus symmetric limits keep the sphere round.
            # `Axes3D.dist` was deprecated in Matplotlib 3.6 in favour of
            # set_box_aspect(zoom=...). Once it is removed, assigning it
            # silently does nothing. zoom = 10/7 approximates the old
            # `dist = 7` (default dist was 10).
            try:
                ax.set_box_aspect((1, 1, 1), zoom=10 / 7)
            except TypeError:              # Matplotlib < 3.6: no `zoom`
                ax.set_box_aspect((1, 1, 1))
                ax.dist = 7
            ax.set_xlim(-1, 1)
            ax.set_ylim(-1, 1)
            ax.set_zlim(-1, 1)
            try:
                ax.set_proj_type('ortho')  # remove perspective
            except AttributeError:         # very old Matplotlib lacks it
                pass

            # clip=True clamps out-of-range values to the colormap ends
            norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
            facecols = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(field2d)

            # NOTE(review): plot_surface appears to colour each quad from its
            # first corner, so the last row/column of field2d is likely not
            # drawn -- confirm if exact edges matter.
            ax.plot_surface(
                Xs, Ys, Zs,
                rstride=1, cstride=1,
                facecolors=facecols,
                linewidth=0, antialiased=False, shade=False
            )
            fig.savefig(os.path.join(save_dir2, fname), dpi=dpi,
                        bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)                 # release the figure even on error

    # --- render and save ---
    # NOTE(review): ground truth is only rendered when `conformal` is True --
    # presumably so it is written once per dataset; confirm this is intended.
    renders = []
    if conformal:
        renders.append(("ground_truth", cmap_mean, vmin_mean, vmax_mean))
    renders.append(("pred_mean", cmap_mean, vmin_mean, vmax_mean))
    renders.append(("bayes_uncer", cmap_unc, vmin_unc, vmax_unc))
    if conformal:
        renders.append(("conf_uncer", cmap_unc, vmin_unc, vmax_unc))

    for key, cmap, vmin, vmax in renders:
        _plot_sphere(fields[key], f"{prefix}_{key}_{datapath}.png",
                     cmap, vmin, vmax)

    print("Fast sphere plots saved to", save_dir2)

In [None]:
# Render the STACI run (FFNP outputs) on a ~400 x 750 sphere mesh.
plotTests_sphere(
    datapath='STACI',
    save_dir='results/AOD_data/FFNP',
    save_dir2='results/AOD_data/Plots_Sphere',
    conformal=True,
    prefix='AOD_data',
    sphere_res=(400, 750),
)