In [None]:
import bhnerf
import bhnerf.constants as consts
import numpy as np
import os
from datetime import datetime
from astropy import units
import ehtim as eh
from bhnerf.optimization import LogFn
from flax.training import checkpoints
from tqdm.auto import tqdm

import xarray as xr
import importlib
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional

# Root of the ensemble checkpoints. The cells below iterate over its entries and treat each
# one as the directory of a single ensemble member (passed on as `seed_dir`).
ROOT = Path('/srv/tmp/kyle/bhnerf/checkpoints/ensemble')

In [None]:
"""
Generate synthetic observations of a hot-spot
"""
# --- Scene configuration ------------------------------------------------------------------
fov_M = 16.0                        # field of view in units of GM/c^2 (see `fov=` below)
spin = 0.2                          # spin value handed to the geodesic solver and to isco_pro
inclination = np.deg2rad(60.0)      # 60 degrees, converted to radians
nt = 64                             # number of movie frames / observation time samples

array = 'ngEHT'                     # selects the array file '../eht_arrays/<array>.txt' below
flux_scale = 0.1                    # scale image-plane fluxes to `reasonable` values in Jy
tstart = 2.0 * units.hour           # observation start time
tstop = tstart + 40.0 * units.min   # observation stop time

# Compute geodesics (see Tutorial1)
# 64 x 64 image-plane grid spanning [-fov_M/2, fov_M/2] along both alpha and beta.
geos = bhnerf.kgeo.image_plane_geos(
    spin, inclination, 
    num_alpha=64, num_beta=64, 
    alpha_range=[-fov_M/2, fov_M/2],
    beta_range=[-fov_M/2, fov_M/2]
)
# Angular velocity evaluated at the radii geos.r. The sign factor follows the sign of `spin`;
# adding machine epsilon makes spin == 0 count as positive.
Omega = np.sign(spin + np.finfo(float).eps) * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
# Injection time is minus geos.r_o (presumably the observer radius — TODO confirm).
t_injection = -float(geos.r_o)

# Generate hotspot measurements (see Tutorial2) 
# Initial emission volume on a 64^3 grid, multiplied by flux_scale.
emission_0 = flux_scale * bhnerf.emission.generate_hotspot_xr(
    resolution=(64, 64, 64), 
    rot_axis=[0.0, 0.0, 1.0], 
    rot_angle=0.0,
    orbit_radius=5.5,
    std=0.7,
    r_isco=bhnerf.constants.isco_pro(spin),
    fov=(fov_M, 'GM/c^2')
)
# Keyword arguments for the empty observation; start/stop times are passed as plain hours.
obs_params = {
    'mjd': 57851,                       # night of april 6-7, 2017
    'timetype': 'GMST',
    'nt': nt,                           # number of time samples 
    'tstart': tstart.to('hr').value,    # start of observations
    'tstop': tstop.to('hr').value,      # end of observation 
    'tint': 30.0,                       # integration time,
    'array': eh.array.load_txt('../eht_arrays/{}.txt'.format(array))
}
obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
# Angular field of view: fov_M times GM_c2(sgra_mass), divided by the source distance in metres
# (presumably GM_c2 returns GM/c^2 as a length — TODO confirm).
fov_rad = (fov_M * consts.GM_c2(consts.sgra_mass) / consts.sgra_distance.to('m')) * units.rad
# Pixel size = angular FOV / number of alpha samples (assumes geos.alpha is the 1-D pixel
# coordinate — TODO confirm).
psize = fov_rad.value / geos.alpha.size 
obs_args = {'psize': psize, 'ra': obs_empty.ra, 'dec': obs_empty.dec, 'rf': obs_empty.rf, 'mjd': obs_empty.mjd}
# nt evenly spaced frame times from tstart to tstop (both end points included).
t_frames = np.linspace(tstart, tstop, nt)
image_plane = bhnerf.emission.image_plane_dynamics(emission_0, geos, Omega, t_frames, t_injection)
# The movie receives the frame times as bare numbers (t_frames.value).
movie = eh.movie.Movie(image_plane, times=t_frames.value, **obs_args)
# NOTE(review): seed=None is forwarded to observe_same; presumably the noise realisation is then
# not reproducible between runs — confirm and pin a seed if reproducibility matters.
obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)

# Bundle of ray-tracing inputs reused by every prediction cell below; the first frame time is
# passed as the last argument (presumably the observation start time — TODO confirm).
raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0])

In [None]:
def find_run_and_checkpoint(seed_dir, ROOT=Path('/srv/tmp/kyle/bhnerf/checkpoints/ensemble')):
    """Locate the newest run directory of one ensemble member and its latest checkpoint.

    Args:
    seed_dir: str or Path
        Directory of a single ensemble member; every sub-directory is treated as a run.
    ROOT: Path
        Unused. Kept so that existing calls which pass it keep working.

    Returns:
    (run_dir, ckpt_dir): tuple[Path, Path]
        The most recently modified run directory and, inside it, the `checkpoint_<step>`
        entry with the largest integer step.

    Raises:
    ValueError
        If `seed_dir` contains no sub-directory, or the chosen run contains no
        `checkpoint_<step>` entry whose step is an integer.
    """
    def _step(path):
        # Step number encoded after the last underscore. Entries such as 'checkpoint_tmp'
        # (leftover / temporary files) yield None and are ignored instead of crashing int().
        try:
            return int(path.name.split('_')[-1])
        except ValueError:
            return None

    seed_dir = Path(seed_dir)
    run_dirs = [d for d in seed_dir.iterdir() if d.is_dir()]
    if not run_dirs:
        raise ValueError(f'[skip] {seed_dir.name}: no run directory found')
    run_dir = max(run_dirs, key=lambda p: p.stat().st_mtime)   # newest run

    steps = {ckpt: _step(ckpt) for ckpt in run_dir.glob('checkpoint_*')}
    numbered = [ckpt for ckpt, step in steps.items() if step is not None]
    if not numbered:
        raise ValueError(f'[skip] {run_dir}: no checkpoints')
    return run_dir, max(numbered, key=steps.get)

def load_predictor(checkpoint_dir, run_dir):
    """Build the predictor described in `run_dir` and restore its trained parameters.

    Returns a `(predictor, params)` pair: the NeRF_Predictor object itself (callers use
    `predictor.apply` to evaluate it) and the 'params' entry of the checkpoint restored
    from `checkpoint_dir`.
    """
    model = bhnerf.network.NeRF_Predictor.from_yml(run_dir)
    restored = checkpoints.restore_checkpoint(checkpoint_dir, None)
    return model, restored['params']

def omega_grid_kepler_cyl(fov_M: float, R: int, spin: float, M: float = 1.0) -> np.ndarray:
    """Keplerian angular frequency Ω(ρ) on a cubic (R, R, R) grid in geometric units.

    ρ = sqrt(x^2 + y^2) is the cylindrical radius about the z-axis, so the field is
    constant along z. The value is the Kerr-like equatorial orbital frequency

        Ω = sign(a) * sqrt(M) / (ρ^{3/2} + a * sqrt(M)),   with sign(0) taken as +1.

    Args:
        fov_M: side length of the cube; each axis is sampled on [-fov_M/2, fov_M/2].
        R: number of samples per axis.
        spin: spin parameter `a`.
        M: mass (1.0 gives Ω = 1 / (ρ^{3/2} + a)).

    Returns:
        float32 array of shape (R, R, R) indexed as [x, y, z]. The dtype is float32 for
        every spin value, including spin == 0.
    """
    side = np.linspace(-fov_M/2, fov_M/2, R, dtype=np.float32)
    # Ω does not depend on z, so evaluate it on the (x, y) plane only and replicate afterwards.
    X, Y = np.meshgrid(side, side, indexing='ij')
    rho = np.sqrt(X*X + Y*Y) + 1e-9    # small offset keeps ρ strictly positive on the axis
    a = np.float32(spin)
    Ms = np.sqrt(np.float32(M))
    # Keep the sign factor in float32: a float64 NumPy scalar here would promote the whole
    # result to float64 under NumPy >= 2 promotion rules.
    direction = np.sign(a) if a != 0 else np.float32(1.0)
    omega_xy = (direction * Ms / (rho**1.5 + a*Ms)).astype(np.float32)
    return np.repeat(omega_xy[:, :, np.newaxis], R, axis=2)

def predict_movie_2d_or_3d(apply_fn, params, raytracing_args, t_frames, fov, Omega_3d=None, twoD=False, resolution=64, chunk=16):
    """
    Evaluate a network on *all* requested times in either 2D or 3D

    Args:
    apply_fn: callable
        The network's apply function (callers pass `predictor.apply`), a coordinate-based
        neural net predicting the emission values as a continuous function
    params: dict
        the network's parameters, i.e. state['params']
    raytracing_args: OrderedDict
        raytracing arguments for rendering lensed emissions. 3D mode only reads 't_start_obs';
        2D mode additionally reads 'coords', 'Omega', 'J', 'g', 'dtau', 'Sigma', 't_geos'
        and 't_injection'
    t_frames: list[units.Quantity]
        the frames to evaluate the network across
    fov: float
        world field of view (used in 3D mode only)
    Omega_3d: optional
        forwarded as `Omega` to bhnerf.network.sample_3d_grid. Required when twoD is False,
        ignored when twoD is True
    twoD: bool
        if True compute a 2D image-plane movie; otherwise (the default) compute a 3D movie
    resolution: int
        If rendering in 3d, the resolution of the coordinate grid to evaluate the network on
    chunk: int
        how many frames are handed to the network per call in 2D mode

    Returns:
    np.ndarray
        2D mode: the per-chunk predictions concatenated along axis 0.
        3D mode: one volume per frame, stacked along a new leading axis.

    Raises:
    ValueError
        if twoD is False and Omega_3d is None, or if twoD is True and chunk < 1
    """
    # Validate before touching any heavy machinery so that misuse fails fast and clearly.
    if not twoD and Omega_3d is None:
        raise ValueError('Omega_3d is required when twoD=False')
    if twoD and chunk < 1:
        raise ValueError(f'chunk must be a positive integer, got {chunk}')

    t_start_obs = raytracing_args['t_start_obs']

    if not twoD:
        # 3D: sample the network on a regular grid, one frame at a time.
        frames = [
            np.asarray(bhnerf.network.sample_3d_grid(apply_fn=apply_fn, params=params, t_frame=t, t_start_obs=t_start_obs, fov=fov, Omega=Omega_3d, resolution=resolution))
            for t in tqdm(t_frames, desc='frame', leave=False)
        ]
        return np.stack(frames, axis=0)

    # 2D: render image-plane frames through the geodesics, `chunk` frames per call.
    coords, Omega = raytracing_args['coords'], raytracing_args['Omega']
    J, g, dtau = raytracing_args['J'],  raytracing_args['g'], raytracing_args['dtau']
    Sigma = raytracing_args['Sigma']
    t_geos = raytracing_args['t_geos']
    t_injection = raytracing_args['t_injection']
    # Unit of the observation start time, or None when it is a bare number.
    t_units = t_start_obs.unit if isinstance(t_start_obs, units.Quantity) else None

    frames = []
    for k in tqdm(range(0, len(t_frames), chunk), desc='frame (chunked)', leave=False):
        batch = bhnerf.network.image_plane_prediction(
            params, apply_fn, t_frames[k:k + chunk], coords, Omega, J, g, dtau, Sigma,
            t_start_obs, t_geos, t_injection, t_units
        )
        frames.append(np.asarray(batch))
    return np.concatenate(frames, axis=0)

In [None]:
"""1) Find 3D uncertainty for one frame"""
# Evaluate every ensemble member on a 3D grid at the first frame only.
t_frame = t_frames[0]
resolution = 128
samples = []
for seed_dir in tqdm(sorted(ROOT.iterdir()), desc='ensemble member'):
    run_dir, ckpt_dir = find_run_and_checkpoint(seed_dir)
    # Same loader as the other ensemble cells, instead of a second inline copy of it.
    predictor, params = load_predictor(ckpt_dir, run_dir)

    # A one-element frame list yields a leading axis of length 1, removed by np.squeeze.
    # NOTE(review): Omega_3d=0 is passed here; t_frame is the same time that was given to
    # raytracing_args, so presumably no rotation is applied — confirm.
    emission = np.squeeze(predict_movie_2d_or_3d(predictor.apply, params, raytracing_args, [t_frame], fov_M, Omega_3d=0, twoD=False, resolution=resolution))
    samples.append(emission.astype(np.float32))


In [None]:
# Collapse the per-member volumes into ensemble statistics; axis 0 indexes the ensemble member.
samples_stack = np.stack(samples)
mean_vol = np.mean(samples_stack, axis=0)
std_vol = np.std(samples_stack, axis=0)

# Sanity check of the array layouts, one shape per line.
print(samples_stack.shape, mean_vol.shape, std_vol.shape, sep='\n')

In [None]:
"""Render 3D image of uncertainty for one frame"""
# Reload so that edits to the visualization module are picked up without restarting the kernel.
importlib.reload(bhnerf.visualization)
import bhnerf.visualization
# Min-max rescale both volumes to [0, 1].
# NOTE(review): a constant volume (max == min, e.g. the std of a single-member ensemble) makes
# the denominator zero — confirm this cannot happen with the ensembles used here.
std_norm = np.clip((std_vol - std_vol.min())/(std_vol.max()-std_vol.min()), 0, 1)
mean_norm = np.clip((mean_vol - mean_vol.min()) / (mean_vol.max()-mean_vol.min()), 0, 1)

#bhnerf.visualization.show_uncert_volume(std_norm, fov_M, cmap_name="inferno", level_norm=[.0,.4,.8,1.0], opacity =[0,.1,.5,.9])

# IPython magic: switch matplotlib to the interactive widget backend.
%matplotlib widget
# Render the normalised std and mean volumes together; the titles hard-code "frame 0", which
# matches the single frame sampled for these statistics.
fig = bhnerf.visualization.render_volumes_plotly((std_norm, mean_norm), fov_M, titles=('Standard deviation emission (frame 0)', 'Mean emission (frame 0)'), percentiles=[(80, 99.5), (80, 99.5)], surface_count=6)
fig.show()

In [None]:
"""2) Find 2D uncertainty across time"""
# One image-plane movie per ensemble member, collected in a list and stacked afterwards.
movies = []
for seed_dir in tqdm(sorted(ROOT.iterdir()), desc='ensemble member'):
    run_dir, ckpt_dir = find_run_and_checkpoint(seed_dir)

    predictor, params = load_predictor(ckpt_dir, run_dir)
    # predict_movie_2d_or_3d takes an Omega_3d argument that only its 3D branch uses; it is
    # passed explicitly so the call is valid whether or not the helper declares a default.
    movie = predict_movie_2d_or_3d(apply_fn=predictor.apply, params=params, raytracing_args=raytracing_args, t_frames=t_frames, fov=fov_M, Omega_3d=None, twoD=True, resolution=128, chunk=4)
    movies.append(movie.astype(np.float32))

# Stack along a new leading axis that indexes the ensemble member.
movies = np.stack(movies)

In [None]:
"""2) show 2D uncertainty movie"""
print(movies.shape)
# NOTE: this rebinds the notebook-level `nt` to the number of frames actually predicted.
Nens, nt, Ny, Nx = movies.shape
y, x = np.arange(Ny), np.arange(Nx)   # pixel-index coordinates for the two image axes
plt.close("all")

# Name the axes (ensemble member, time, y, x) so the reductions below can be done by name.
mov_da = xr.DataArray(movies, dims=('ens', 't', 'y', 'x'), coords={'t':t_frames, 'y':y, 'x':x})

# IPython magic: switch matplotlib to the interactive widget backend before creating the figure.
%matplotlib widget
mean_da = mov_da.mean('ens')   # ensemble mean for every frame and pixel
std_da = mov_da.std('ens')     # ensemble standard deviation for every frame and pixel
fig, axes = plt.subplots(1, 2, figsize=(8, 6))

# Output path handed to animate_movies_synced below.
out = Path("/srv/tmp/kyle/bhnerf/results/naive_ensemble/tutorial4_mean_and_uncertainty_framescaled2D.mp4")
#m = mean_da.bv_viz.animate(cmap='afmhot', ax=axes[0], output=out_m, fps=20)
#s = std_da.bv_viz.animate(cmap='plasma', ax=axes[1], output=out_s, fps=20)
anim = bhnerf.visualization.animate_movies_synced([mean_da, std_da], axes=axes, cmaps=['afmhot', 'plasma'], titles=['Naive ensemble mean', 'Naive ensemble std'], output=out)

# The same titles were already passed via `titles=` above; they are set again on the axes here.
axes[0].set_title('Naive ensemble mean')
axes[1].set_title('Naive ensemble std')
plt.tight_layout()

In [None]:
"""3) build 3D ensemble emissions across time"""
import bhnerf.uncertainty

# Camera / rendering parameters.
bh_radius = 2.0
cam_r = 37.
linewidth = 0.1
azimuth = 120.0
zenith=np.pi/3

resolution = 128

# Size the result from the data rather than from hard-coded counts.
seed_dirs = sorted(ROOT.iterdir())
if not seed_dirs:
    raise ValueError(f'no ensemble members found in {ROOT}')
n_ens = len(seed_dirs)
nt = len(t_frames)

# The angular-velocity grid is identical for every member, so build it once outside the loop.
# (`omega_grid_kepler_cyl` defined earlier in this notebook is a local alternative.)
Omega_3d = bhnerf.uncertainty.omega_grid_kepler(fov_M, resolution, spin, float(geos.M))

# `volumes` collects every member's 3D movie along a new leading axis. It is allocated from the
# first prediction so no output shape has to be assumed.
# Memory: n_ens * nt * resolution**3 * 4 bytes if each frame is a resolution^3 float32 volume
# (about 0.54 GB per member for nt=64, resolution=128) — make sure the machine can hold it.
volumes = None
for i, seed_dir in enumerate(tqdm(seed_dirs, desc='ensemble member')):
    run_dir, ckpt_dir = find_run_and_checkpoint(seed_dir)
    predictor, params = load_predictor(ckpt_dir, run_dir)
    vols = predict_movie_2d_or_3d(predictor.apply, params, raytracing_args, t_frames=t_frames, fov=fov_M, Omega_3d=Omega_3d, twoD=False, resolution=resolution)
    if volumes is None:
        volumes = np.empty((n_ens,) + vols.shape, dtype=np.float32)
    volumes[i] = vols
    


In [None]:
# FIXME: unfinished cell, disabled until the intended right-hand side is written.
# The loop ended in a dangling `emission_0 =`, which is a SyntaxError (it breaks
# Restart & Run All) and would also overwrite the ground-truth `emission_0` hot-spot created
# in the synthetic-observation cell. Use a different target name when completing it.
#
# for i, seed_dir in enumerate(tqdm(sorted(ROOT.iterdir()), desc='ensemble member')):
#     run_dir, ckpt_dir = find_run_and_checkpoint(seed_dir)
#     predictor, params = load_predictor(ckpt_dir, run_dir)
#     emission_0 = ...

In [None]:
# Sanity check: the first axis indexes the ensemble member (one entry per `volumes[i] = vols`
# assignment) and the second the time frame.
print(volumes.shape)

In [None]:
"""3) render the time varying 3D movie"""

# Reload so that edits to the visualization module are picked up without restarting the kernel.
importlib.reload(bhnerf.visualization)
import bhnerf.visualization

# Reduce over axis 0 of `volumes` (the ensemble member), leaving one volume per time frame.
# NOTE: this rebinds std_vol / mean_vol, which the single-frame cells defined earlier.
std_vol = np.std(volumes, axis=0)
mean_vol = np.mean(volumes, axis=0)

# Min-max rescale to [0, 1] using the extrema over all frames at once, so frames stay comparable.
std_norm = np.clip((std_vol - std_vol.min())/(std_vol.max()-std_vol.min()), 0, 1)
mean_norm = np.clip((mean_vol - mean_vol.min()) / (mean_vol.max()-mean_vol.min()), 0, 1)

visualizer = bhnerf.visualization.VolumeVisualizer(128, 128, 128)

# IPython magic: switch matplotlib to the interactive widget backend.
%matplotlib widget
# Only the std volume is animated here. `predictor` is whatever the most recently executed
# ensemble loop left behind (its last member); only its `rmax` attribute is read.
# The literal cam_r / bh_radius / linewidth values repeat the constants set in the build cell.
anim = bhnerf.visualization.render_3d_movie(
    std_norm, t_frames, visualizer,
    cam_r = 37.0,
    rmax = predictor.rmax,
    bh_radius = 2.0,
    linewidth = 0.1,
    fps = 20,
)

In [None]:
# Reload so that edits to the visualization module are picked up without restarting the kernel.
importlib.reload(bhnerf.visualization)
import bhnerf.visualization
# Pool over axes 0 and 1 of `volumes` (ensemble member and time frame) to get a single
# time-independent mean / std volume. Unlike the per-frame statistics elsewhere in this
# notebook, the std here uses ddof=1 (sample standard deviation).
mean_all = volumes.mean(axis=(0, 1))
std_all = volumes.std(axis=(0, 1), ddof=1)

# Min-max rescale to [0, 1]; this rebinds std_norm / mean_norm used by the rendering cells.
std_norm = np.clip((std_all - std_all.min())/(std_all.max()-std_all.min()), 0, 1)
mean_norm = np.clip((mean_all - mean_all.min())/(mean_all.max()-mean_all.min()), 0, 1)

print(std_norm.shape)
#bhnerf.visualization.show_uncert_volume_plotly(std_norm, fov_M, cmap="Plasma")
#bhnerf.visualization.show_multi_volumes_plotly((std_norm, mean_norm), fov_M, ('Standard deviation', 'Mean emission'), )

In [None]:
# Reload so that edits to the visualization module are picked up without restarting the kernel.
importlib.reload(bhnerf.visualization)
import bhnerf.visualization

# Render the current std_norm / mean_norm volumes; the figure is only built here and is
# displayed by the following cell.
fig = bhnerf.visualization.render_volumes_plotly((std_norm, mean_norm), fov_M, ('Standard deviation', 'Mean emission'), interactive_sync=True)

In [None]:
from IPython.display import display
# Show the figure returned by render_volumes_plotly as rich notebook output.
display(fig)

In [None]:
# Write the animation currently bound to `anim` to disk. Both the 2D and the 3D movie cells
# assign `anim`, so this saves whichever of them ran last.
out = Path("/srv/tmp/kyle/bhnerf/results/naive_ensemble/tutorial7_uncertainty_3D_time_fixed.mp4")
# Create the results directory up front; the save fails if it does not exist.
out.parent.mkdir(parents=True, exist_ok=True)
anim.save(out, writer='ffmpeg', fps=20)

In [None]:
# Diagnostic: report what makes the saved notebook file large (root metadata and heaviest cells).
import json
from pathlib import Path

P = Path("/srv/tmp/kyle/bhnerf/tutorials/Tutorial7 - Estimate uncertainty via naive ensemble.ipynb")
# Notebook files are UTF-8 JSON; decode explicitly instead of relying on the locale default.
nb = json.loads(P.read_text(encoding="utf-8"))

def jsz(x):  # approximate size in bytes if saved as JSON
    """Return the byte length of `x` serialised as compact, UTF-8 encoded JSON."""
    return len(json.dumps(x, separators=(',',':')).encode('utf-8'))

print(f"Notebook size on disk: {P.stat().st_size/1e6:.1f} MB")

# root metadata that often bloats
root_meta = nb.get("metadata", {})
print("Root metadata keys:", list(root_meta.keys()))
for k in ("widgets","widget_state","varInspector","latex_envs"):
    if k in root_meta:
        print(f"  -> metadata['{k}'] size ~ {jsz(root_meta[k])/1e6:.2f} MB")

# top 10 heaviest cells
cells = [(i, c["cell_type"], jsz(c)) for i,c in enumerate(nb["cells"])]
cells.sort(key=lambda t: t[2], reverse=True)
print("\nHeaviest cells:")
for i, typ, sz in cells[:10]:
    c = nb["cells"][i]
    outs = len(c.get("outputs", []))
    atts = len(c.get("attachments", {}) or {})
    print(f"  Cell {i:>3} | {typ:<8} | {sz/1e6:6.2f} MB | outputs={outs} | attachments={atts}")
    if outs:
        # Collect the MIME types present in this cell's outputs (only outputs carrying 'data').
        mimes=set()
        for o in c["outputs"]:
            if "data" in o: mimes |= set(o["data"].keys())
        print("      output mimes:", sorted(mimes))
