In [None]:
import bhnerf
import bhnerf.constants as consts
import numpy as np
from astropy import units
import ehtim as eh
from flax.training import checkpoints
import matplotlib.pyplot as plt
import importlib

import numpy as np
from astropy import units
import xarray as xr
import bhnerf
from datetime import datetime
from bhnerf.optimization import LogFn

In [None]:
def shift_z_xr(da: xr.DataArray, dz: float) -> xr.DataArray:
    """Resample a volume at z sample positions offset by ``dz``.

    The volume is re-interpolated on its own grid at the points
    ``(x, y, z - dz)``; the result reuses the coords, dims and attrs of ``da``.

    Args:
        da: Volume exposing ``x``, ``y`` and ``z`` coordinates.
        dz: Offset subtracted from the z sample positions (presumably in the
            same units as ``da.z`` -- confirm). Magnitudes below 1e-12 are
            treated as "no shift".

    Returns:
        ``da`` itself when the offset is negligible, otherwise a new DataArray
        holding the interpolated values.
    """
    offset_is_negligible = abs(dz) < 1e-12
    if offset_is_negligible:
        return da

    grid_x, grid_y, grid_z = np.meshgrid(da.x, da.y, da.z, indexing='ij')
    sample_points = np.stack([grid_x, grid_y, grid_z - dz], axis=-1)
    resampled = bhnerf.emission.interpolate_coords(da, sample_points)
    return xr.DataArray(resampled, coords=da.coords, dims=da.dims, attrs=da.attrs)

def multi_hotspot_emission_xr(spin, fov_M, resolution=(64,64,64), orbit_radius=5.5, blobs=(), flux_scale=0.1
):
    """
    Build a single emission volume as an amplitude-weighted sum of hotspots.

    Every hotspot is created with bhnerf.emission.generate_hotspot_xr (rotated
    about the z axis by its angle, normalize=True); a non-zero dz additionally
    routes it through shift_z_xr, which uses interpolate_coords. The weighted
    sum is finally multiplied by flux_scale / sum(amp).

    Args:
        spin: spin value forwarded to bhnerf.constants.isco_pro.
        fov_M: field of view, passed on as fov=(fov_M, 'GM/c^2').
        resolution: grid resolution forwarded to generate_hotspot_xr.
        orbit_radius: orbit radius shared by all hotspots (cast to float).
        blobs: non-empty iterable of dicts with keys:
          ang_deg (required), amp (default 1.0), std (scalar or (sx,sy,sz), default 0.7), dz (default 0.0)
        flux_scale: overall scale applied after dividing by the amplitude sum.

    Returns:
        The combined volume (same type as returned by generate_hotspot_xr).

    Raises:
        ValueError: if blobs is empty (there is nothing to sum).
    """
    # Materialise once so generators work and emptiness can be checked up front,
    # instead of failing later with `None * float`.
    blobs = list(blobs)
    if not blobs:
        raise ValueError("multi_hotspot_emission_xr needs at least one blob; got an empty `blobs`")

    r_isco = bhnerf.constants.isco_pro(spin)
    total = None
    amp_sum = 0.0

    for b in blobs:
        ang  = np.deg2rad(b['ang_deg'])
        amp  = float(b.get('amp', 1.0))
        std  = b.get('std', 0.7)
        dz   = float(b.get('dz', 0.0))

        da = bhnerf.emission.generate_hotspot_xr(
            resolution=resolution,
            rot_axis=[0.0, 0.0, 1.0],
            rot_angle=ang,
            orbit_radius=float(orbit_radius),
            std=std, r_isco=r_isco,
            fov=(fov_M, 'GM/c^2'),
            normalize=True,
        )

        if abs(dz) > 0.0:
            da = shift_z_xr(da, dz)

        weighted = da * amp
        total = weighted if total is None else (total + weighted)
        amp_sum += amp

    # Guard against a zero amplitude sum so the division stays finite.
    return total * (flux_scale / max(amp_sum, 1e-12))

In [None]:
# --- Scene configuration ---------------------------------------------------
fov_M = 16.0                      # field of view, used below in units of GM/c^2
spin = 0.2
inclination = np.deg2rad(60.0)
nt = 64                           # number of movie frames

array = 'ngEHT'                   # array file name under ../eht_arrays/
flux_scale = 0.1
tstart = 2.0 * units.hour
tstop = tstart + 40.0 * units.min

# --- Geodesics and velocity field -------------------------------------------
geos = bhnerf.kgeo.image_plane_geos(
    spin, inclination,
    num_alpha=64, num_beta=64,
    alpha_range=[-fov_M/2, fov_M/2],
    beta_range=[-fov_M/2, fov_M/2]
)
# Angular velocity sqrt(M) / (r^1.5 + a*sqrt(M)); the eps inside sign() makes spin == 0 map to +1.
Omega = np.sign(spin + np.finfo(float).eps) * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
t_injection = -float(geos.r_o)

# --- Ground-truth emission ---------------------------------------------------
three_spot = [
    {'ang_deg': 25,  'amp': 1.00, 'std': (0.7, 0.7, 0.7), 'dz': +2},  # bright, upper-right-ish
    {'ang_deg': 200, 'amp': 0.85, 'std': (0.85,0.65,0.7), 'dz': -4},  # close pair, side A
    {'ang_deg': 250, 'amp': 0.80, 'std': (0.85,0.65,0.7), 'dz': -5},  # close pair, side A
]

emission_0 = multi_hotspot_emission_xr(
    spin=spin,
    fov_M=fov_M,
    orbit_radius=5.5,
    blobs=three_spot,
    flux_scale=flux_scale   # reuse the configured value instead of repeating the literal
)

# --- Observation template ------------------------------------------------------
obs_params = {
    'mjd': 57581,
    'timetype': 'GMST',
    'nt': nt,
    'tstart': tstart.to('hr').value,
    'tstop': tstop.to('hr').value,
    'tint': 30.0,
    'array': eh.array.load_txt('../eht_arrays/{}.txt'.format(array))
}
obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
# Angular field of view for the Sgr A* mass and distance constants, and the resulting pixel size.
fov_rad = (fov_M * consts.GM_c2(consts.sgra_mass) / consts.sgra_distance.to('m')) * units.rad
psize = fov_rad.value / geos.alpha.size
obs_args = {'psize': psize, 'ra': obs_empty.ra, 'dec': obs_empty.dec, 'rf': obs_empty.rf, 'mjd': obs_empty.mjd}

# --- Synthetic movie and observation ---------------------------------------------
t_frames = np.linspace(tstart, tstop, nt)
image_plane = bhnerf.emission.image_plane_dynamics(emission_0, geos, Omega, t_frames, t_injection)
movie = eh.movie.Movie(image_plane, times=t_frames.value, **obs_args)
# NOTE(review): seed=None presumably means the noise realisation is not reproducible -- confirm and pin a seed if needed.
obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)


In [None]:
# Interactive 3D rendering of the ground-truth emission volume built in the setup cell.
bhnerf.visualization.ipyvolume_3d(emission_0, fov=fov_M, level=[0, 0.2, 0.7])

In [None]:
# Optimization settings and the spatial bounds handed to the predictor.
batchsize = 12
z_width = fov_M                # maximum disk width [M]
rmax = fov_M / 2           # maximum recovery radius 
rmin = float(geos.r.min()) # minimum recovery radius
hparams = {'num_iters': 8000, 'lr_init': 1e-4, 'lr_final': 1e-6}

# Logging: each run gets a timestamped name that is reused for the log and checkpoint directories.
current_time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
runname = 'tutorial8/recovery.vis.{}'.format(current_time)
writer = bhnerf.optimization.SummaryWriter(logdir='../runs/{}'.format(runname))
writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_0), global_step=0)
log_fns = [
    # log10 of the mean training loss (no log_period given).
    LogFn(lambda opt: writer.add_scalar('log_loss/train', np.log10(np.mean(opt.loss)), global_step=opt.step)), 
    # 3D recovery logged against the ground truth, with log_period=200.
    LogFn(lambda opt: writer.recovery_3d(fov_M, emission_true=emission_0)(opt), log_period=200)
]

# Observation parameters: fit complex visibilities of `obs` on the movie's field of view / pixel grid.
chisqdata = eh.imaging.imager_utils.chisqdata_vis
train_step = bhnerf.optimization.TrainStep.eht(t_frames, obs, movie.fovx(), movie.xdim, chisqdata)

# Optimization (long-running: hparams['num_iters'] iterations).
predictor = bhnerf.network.NeRF_Predictor(rmax, rmin, rmax, z_width)
raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0])
optimizer = bhnerf.optimization.Optimizer(hparams, predictor, raytracing_args, checkpoint_dir='../checkpoints/{}'.format(runname))
optimizer.run(batchsize, train_step, raytracing_args, log_fns=log_fns)

In [None]:
# Reload a previously trained network from disk instead of re-running the optimization.
chisqdata = eh.imaging.imager_utils.chisqdata_vis

# NOTE(review): absolute, machine-specific path -- point this at your own run directory
# (the optimization cell writes to '../checkpoints/<runname>').
checkpoint_dir = '/srv/tmp/kyle/bhnerf/checkpoints/tutorial8/recovery.vis.2025-09-02.20:26:19/'
predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
state = checkpoints.restore_checkpoint(checkpoint_dir, None)
if state is None:
    # restore_checkpoint returns the passed-in target (None here) when it finds no checkpoint;
    # fail with a clear message instead of "'NoneType' object is not subscriptable".
    raise FileNotFoundError('No checkpoint found in {}'.format(checkpoint_dir))

params = state['params']
raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0])

In [None]:
# Evaluate the restored network on a 3D grid spanning the field of view and render it
# with the same contour levels used for the ground truth.
emission_estimate = bhnerf.network.sample_3d_grid(predictor.apply, params, fov=fov_M)
bhnerf.visualization.ipyvolume_3d(emission_estimate, fov=fov_M, level=[0, 0.2, 0.7])

In [None]:
# Pick up edits to the uncertainty module without restarting the kernel.
importlib.reload(bhnerf.uncertainty)
import bhnerf.uncertainty as unc

# Per-frame measurement operator and noise built from the observation.
times, sigma, A = unc.build_A_per_frame(obs, t_frames, movie.fovx(), movie.xdim)

# A single group that contains every frame index.
frames_to_include = [list(range(len(t_frames)))]

br_unc = unc.BayesRaysUncertaintyMapper(
    predictor_apply=predictor.apply,
    predictor_params=params,
    raytracing_args=raytracing_args,
    t_frames=t_frames,
    frames_to_include=frames_to_include,
    A=A,
    sigma=sigma,
    fov=fov_M,
    grid_res=(64, 64, 64),
    lam=1e-6/64**3,
)

In [None]:
# Compute H (a Hessian diagonal, per the method name) over all frames, then derive V from it.
H, R_eff = br_unc.compute_hessian_diag_all(sigma, batch_size=1)
V = br_unc.get_covariance(H)

# Dump values and shapes for a quick sanity check.
print("-"*20)
print("H MATRIX:", H, "SHAPE:", H.shape)
print("V MATRIX:", V, "SHAPE:", V.shape)

In [None]:
import pickle

# Load a previously saved Hessian from disk; this replaces any `H` computed in an earlier cell.
# SECURITY: pickle.load can execute arbitrary code -- only open files you produced yourself.
with open('../results/bayesrays/hessian_diag_64_fig8_r1.pkl', 'rb') as hessian_file:
    # The context manager closes the handle even if unpickling fails.
    H = pickle.load(hessian_file)
V = br_unc.get_covariance(H, False)

In [None]:
# Pick up edits to the visualization module without restarting the kernel.
importlib.reload(bhnerf.visualization)
import bhnerf.visualization as vis

uncertainty = br_unc.prep_uncertainty_3d(
    H, V, fov_M,
    min_uncertainty=-3,
    max_uncertainty=6,
    resolution=64,
    squared_weights=False,
    mask=True,
)

# Presumably (value, opacity) control points.
# NOTE(review): this list is not passed to the render call below -- confirm whether it should be.
opacityscale = [
    [0.00, 0.00], [0.10, 0.00], [0.30, 0.15], [0.50, 0.35],
    [0.70, 0.60], [0.85, 0.85], [1.00, 1.00],
]

# Scale both emission volumes by their own maximum before comparing them.
emission_estimate_unit = emission_estimate / emission_estimate.max()
emission_true_unit = emission_0 / emission_0.max()

fig = vis.render_volumes_plotly(
    volumes_norm=(
        uncertainty * emission_estimate_unit,   # uncertainty weighted by the recovered emission
        emission_estimate_unit,
        emission_true_unit,
    ),
    fov=fov_M,
    percentiles=[(20, 100), (0, 100), (0, 100)],
    titles=('bayesrays uncertainty', 'network emission', 'ground_truth'),
    surface_count=6,
)

In [None]:
# Display the comparison figure built in the previous cell.
# NOTE(review): `fig` is returned by render_volumes_plotly, so the matplotlib backend magic
# is presumably unnecessary here -- confirm.
%matplotlib widget
fig.show()

In [None]:
# Quick count / shape diagnostics for the uncertainty mapper.
print("frames:", len(br_unc.fm_per_frame))
print('coords shape', br_unc.coords.shape)
print("nvis per frame:", br_unc.nvis_per_frame, "total:", sum(br_unc.nvis_per_frame))
# Disabled diagnostics (they need R_eff / H from the Hessian cell):
#print("R_eff / total:", R_eff, "/", sum(br_unc.nvis_per_frame))
#print("H min/max:", float(H.min()), float(H.max()))
def weighted_corner_occupancy(u, R):
    """Splat points onto a regular lattice using trilinear corner weights.

    Args:
        u: array-like of shape (..., 3) holding unit-cube coordinates. Values
            outside [0, 1] are clamped to the nearest boundary.
        R: length-3 sequence with the lattice resolution along each axis.

    Returns:
        float32 array of shape ``tuple(R)``. Every input point contributes a
        total weight of 1, split over the 8 corners of its enclosing cell.
    """
    shape = tuple(int(n) for n in R)
    # Clamp so out-of-range points cannot wrap around through negative indexing.
    u = np.clip(np.asarray(u), 0.0, 1.0)

    lower, upper, frac = [], [], []
    for axis in range(3):
        n = shape[axis]
        pos = u[..., axis] * (n - 1)
        # Keep the base corner at most n-2 so that base+1 is always a valid index;
        # a point exactly on the upper face then gets frac == 1 instead of indexing n.
        base = np.clip(np.floor(pos).astype(np.int32), 0, max(n - 2, 0))
        lower.append(base.ravel())
        upper.append(np.minimum(base + 1, n - 1).ravel())
        frac.append((pos - base).ravel())

    occ = np.zeros(shape, dtype=np.float32)
    for hi_x in (False, True):
        idx_x = upper[0] if hi_x else lower[0]
        w_x = frac[0] if hi_x else 1 - frac[0]
        for hi_y in (False, True):
            idx_y = upper[1] if hi_y else lower[1]
            w_y = frac[1] if hi_y else 1 - frac[1]
            for hi_z in (False, True):
                idx_z = upper[2] if hi_z else lower[2]
                w_z = frac[2] if hi_z else 1 - frac[2]
                # np.add.at accumulates correctly when several points hit the same voxel.
                np.add.at(occ, (idx_x, idx_y, idx_z), w_x * w_y * w_z)
    return occ

# Occupancy volume reported by the mapper itself, rendered in 3D.
# NOTE(review): weighted_corner_occupancy defined above is not called in this cell.
occ = br_unc.param_occupancy(fov_M)
bhnerf.visualization.ipyvolume_3d(occ, fov_M)

In [None]:

# Fraction of voxels in `occ` with strictly positive occupancy.
active_mask = (occ > 0)                 # bool
active_frac = float(active_mask.mean())
print("active voxel fraction:", active_frac)

In [None]:
# A. check unit spans & clipping
import jax.numpy as jnp
def _percentiles(a):
    """Return the 0/1/5/25/50/75/95/99/100th percentiles of ``a`` as a NumPy array.

    Uses ``method='nearest'``, so every returned value is an element of ``a``.
    """
    levels = np.array((0, 1, 5, 25, 50, 75, 95, 99, 100))
    picked = jnp.percentile(a, levels, method='nearest')
    return np.array(picked)

# Split the unit coordinates into their three components.
ux, uy, uz = (br_unc.coords_unit[..., axis] for axis in range(3))

# Percentile summary per component.
for header, unit_coord in (("u_x percentiles:", ux),
                           ("u_y percentiles:", uy),
                           ("u_z percentiles:", uz)):
    print(header, _percentiles(np.asarray(unit_coord)))

# Share of samples sitting within 1e-6 of either end of [0, 1], per component.
clipped_fracs = [float((v<=1e-6).mean()+(v>=1-1e-6).mean()) for v in (ux, uy, uz)]
print("clipped frac per axis:", *clipped_fracs)

# Expect: broad spread (e.g., ~0.05..0.95 or so), and small clipped fractions.

# B. lattice occupancy (should be *volumetric*, not a dot)
# NOTE(review): param_occupancy is called without arguments here but with fov_M in an earlier
# cell -- confirm which signature is intended. This also overwrites the earlier `occ`.
occ = br_unc.param_occupancy()
print("nonzero voxels:", int((occ>0).sum()), " / ", int(np.prod(br_unc.grid_res)))

# C. finite-difference sanity on one parameter (a single control point)
from copy import deepcopy
def fd_check(one_i= br_unc.grid_res[0]//2,
             one_j= br_unc.grid_res[1]//2,
             one_k= br_unc.grid_res[2]//2,
             axis = 0, eps=1e-3, ray=0, frame=0):
    """Perturb one deformation control parameter and report the visibility change.

    Args:
        one_i, one_j, one_k: lattice index of the control point to nudge
            (defaults: the grid centre, evaluated when the function is defined).
        axis: component of the control point to nudge.
        eps: amount added to that component.
        ray: index into the model output whose change is reported.
        frame: index into ``br_unc.fm_per_frame`` selecting the forward model.

    Returns:
        complex: perturbed minus unperturbed output at ``ray`` (the raw
        difference, not divided by ``eps``).
    """
    # Build a parameter dict that differs from the defaults in exactly one entry.
    perturbed = deepcopy(br_unc.def_params)
    theta_shifted = perturbed['theta'].at[one_i, one_j, one_k, axis].add(eps)   # theta: (Rx,Ry,Rz,3) per the original note
    perturbed = {'theta': theta_shifted}

    forward = br_unc.fm_per_frame[frame]

    def model_output(deform_params):
        # Deformation offsets are scaled by voxel_world before being added to the coordinates.
        offsets = br_unc.def_grid.apply({"params": deform_params}, br_unc.coords_unit)
        return forward(br_unc.coords + offsets*br_unc.voxel_world)

    baseline = model_output(br_unc.def_params)
    nudged = model_output(perturbed)
    return complex(nudged[ray] - baseline[ray])

# Run the check with its defaults: centre control point, axis 0, ray 0, frame 0.
print("finite-diff dv (should be nonzero):", fd_check())


In [None]:
# Side-by-side view of the array's UV coverage and the simulated visibility amplitudes.
# NOTE: this rebinds `fig`, which an earlier cell used for the plotly comparison figure.
%matplotlib widget
fig, ax = plt.subplots(1, 2, figsize=(8,3.5))
ax[0].set_title('UV coverage'); ax[1].set_title('Visibility amplitudes')
bhnerf.observation.plot_uv_coverage(obs_empty, ax=ax[0], cmap_ticks=[0,0.2,0.4,0.6], fontsize=11)
obs.plotall('uvdist', 'amp', axis=ax[1])
plt.tight_layout()


# Doppler factor along the geodesics for the chosen velocity field
# (fillna=False presumably leaves NaNs in place -- confirm).
g = bhnerf.kgeo.doppler_factor(geos, Omega, fillna=False)
%matplotlib widget
bhnerf.visualization.plot_geodesic_3D(g, geos)   # should show where rays traverse in world coords

# Round-trip sanity check for world_to_image_coords: sample a known ramp volume at its own
# grid points and verify that linear interpolation reproduces it.
# NOTE: `R`, `x`, `y`, `z`, `coords` are notebook-level names; `R` is rebound further below.
import scipy.ndimage as ndi
R = 64
X = np.linspace(-fov_M/2, fov_M/2, R, dtype=np.float32)
x,y,z = np.meshgrid(X,X,X, indexing='ij')
grid = (x - x.min()) / (x.max() - x.min())  # ramp along x only

# sample grid at its own coords via world->image conversion
fov = [fov_M, fov_M, fov_M]; npix = [R,R,R]
coords = np.stack([x,y,z], axis=-1)
# map_coordinates expects the coordinate axis first, hence the moveaxis.
image_idx = np.moveaxis(bhnerf.utils.world_to_image_coords(coords, fov=fov, npix=npix), -1, 0)
rec = ndi.map_coordinates(grid, image_idx, order=1, cval=0.0)

print("max abs error (should be ~0):", float(np.max(np.abs(rec - grid))))
import jax.numpy as jnp
def _pct(a, p):
    """Return the ``p``-th percentile of ``a`` as a Python float."""
    value = jnp.percentile(a, p)
    return float(value)
# Per-axis summary of the unit coordinates: [min, 1st pct, median, 99th pct, max].
u = br_unc.coords_unit
for d,name in enumerate("xyz"):
    uu = u[..., d]
    print(f"u_{name}: [{float(uu.min()):.3f}, {_pct(uu,1):.3f}, {_pct(uu,50):.3f}, {_pct(uu,99):.3f}, {float(uu.max()):.3f}]")

# Count how many distinct lattice cells the samples fall into along each axis.
# NOTE: rebinds `R` (previously the integer 64) to the mapper's grid resolution.
R = np.array(br_unc.grid_res)
i = np.floor(np.asarray(u[...,0])*(R[0]-1)).astype(np.int32)
j = np.floor(np.asarray(u[...,1])*(R[1]-1)).astype(np.int32)
k = np.floor(np.asarray(u[...,2])*(R[2]-1)).astype(np.int32)
print("unique indices (x,y,z):", len(np.unique(i)), len(np.unique(j)), len(np.unique(k)))

import jax
def flatten(tree):
    """Concatenate every leaf of a pytree into one 1-D array (in tree_leaves order)."""
    leaves = jax.tree_util.tree_leaves(tree)
    raveled = [leaf.ravel() for leaf in leaves]
    return jnp.concatenate(raveled)

def grad_energy_for_ray(ray_idx):
    """Normalised gradient energy of one model output w.r.t. the deformation parameters.

    Differentiates the real and imaginary parts of output ``ray_idx`` of the
    first per-frame forward model with respect to ``br_unc.def_params``, sums
    the squared gradients over the last (size-3) channel and divides by the
    maximum.

    Returns:
        NumPy array of shape ``br_unc.grid_res`` with values scaled so the
        largest entry is ~1.
    """
    forward = br_unc.fm_per_frame[0]

    def model_output(deform_params):
        # Offsets are scaled by voxel_world before being added to the coordinates.
        offsets = br_unc.def_grid.apply({"params": deform_params}, br_unc.coords_unit)
        return forward(br_unc.coords + offsets*br_unc.voxel_world)

    def real_part(deform_params):
        return jnp.real(model_output(deform_params))[ray_idx]

    def imag_part(deform_params):
        return jnp.imag(model_output(deform_params))[ray_idx]

    grad_re = flatten(jax.grad(real_part)(br_unc.def_params))
    grad_im = flatten(jax.grad(imag_part)(br_unc.def_params))

    # Collapse the trailing 3-vector channel into a single energy per lattice point.
    energy = (grad_re**2 + grad_im**2).reshape((*br_unc.grid_res, 3)).sum(-1)
    # The small constant keeps the division finite when every gradient is zero.
    return np.asarray(energy / (energy.max() + 1e-12))

# Render the normalised gradient energy of output 0 as nested iso-surfaces.
heat = grad_energy_for_ray(0)
import bhnerf.visualization as vis
vis.ipyvolume_3d(heat, fov_M, level=[0.1,0.3,0.5,0.7])
