# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
from jax import vmap, jit, numpy as jnp, grad, hessian
from scipy.optimize import minimize
import pandas as pd
import seaborn as sns


# Image geometry: (rows, columns) of the simulated pixel grid.
shape = (32, 32)

# Gaussian widths (in pixels) shared by every point source, along x and y.
sigma_x = sigma_y = 1.5

# Per-pixel noise standard deviation, and the matching constant noise map
# used to weight residuals.
noise_scale = 0.02
noisemap = jnp.ones(shape) * noise_scale

# One fixed realisation of zero-mean Gaussian noise, drawn once at import
# time from NumPy's global (unseeded) random state.
noise = np.random.normal(loc=0, scale=noise_scale, size=shape)

# Pixel coordinate grids; jnp.indices returns the row (y) grid first,
# then the column (x) grid.
y_indices, x_indices = jnp.indices(shape, dtype=jnp.float32)

def point_source(args):
    """Render one Gaussian point source on the module-level pixel grid.

    Args:
        args: A length-3 sequence ``(x, y, flux)`` giving the source centre
            in pixel coordinates and its total flux.

    Returns:
        An array with the shape of ``x_indices`` / ``y_indices`` holding the
        Gaussian profile, with widths ``sigma_x`` and ``sigma_y``.
    """
    x0, y0, total_flux = args
    # sqrt(4 pi^2 sx^2 sy^2) == 2 pi sx sy: the factor that makes the
    # continuous 2-D Gaussian integrate to ``total_flux``.
    norm = jnp.sqrt(4*np.pi**2 * sigma_x**2 * sigma_y**2)
    amplitude = total_flux / norm
    x_term = (x_indices - x0)**2 / (2 * sigma_x**2)
    y_term = (y_indices - y0)**2 / (2 * sigma_y**2)
    return amplitude * jnp.exp(-(x_term + y_term))
    
@jit
def image_model(sources):
    """Sum the rendered profiles of several point sources into one image.

    Args:
        sources: Array of shape ``(N, 3)``; each row is ``(x, y, flux)``.

    Returns:
        The image obtained by adding the ``point_source`` profile of every
        row of ``sources``.
    """
    # Map point_source over the leading (source) axis, then collapse it.
    per_source = vmap(point_source, in_axes=0)(sources)
    return per_source.sum(axis=0)

@jit
def negative_log_likelihood(params, image):
    """Noise-weighted misfit between the point-source model and ``image``.

    Args:
        params: Flat array of source parameters laid out as
            ``(x1, y1, flux1, x2, y2, flux2, ...)``. Its size must be a
            multiple of 3; any number of sources is accepted.
        image: Observed image to compare the model against.

    Returns:
        A scalar: half the sum of squared residuals weighted by
        ``noisemap**2``, divided by ``image.size - params.size``. Because of
        that division the value is half the reduced chi-square rather than
        the raw Gaussian negative log-likelihood.
    """
    # -1 lets the number of sources follow from the length of ``params``
    # instead of being fixed at two.
    sources = params.reshape((-1, 3))
    residuals = image_model(sources) - image
    chi2 = jnp.sum(residuals**2 / noisemap**2)
    return 0.5 * chi2 / (image.size - params.size)



# First and second derivatives of the objective with respect to its first
# argument (the flat parameter vector); the image is treated as constant.
likelihood_grad = grad(negative_log_likelihood, argnums=0)
hess = hessian(negative_log_likelihood, argnums=0)

def interactive_photometry(position):
    """Plot a two-source image and the inverse Hessian of the objective.

    The first source is fixed at ``(x, y, flux) = (12, 16, 20)``; the second
    has the same ``y`` and flux, with its ``x`` coordinate given by
    ``position``. The left panel shows the simulated noisy image, the right
    panel a heatmap of the inverse of the Hessian of
    ``negative_log_likelihood`` evaluated at the true parameters.

    Args:
        position: x coordinate (in pixels) of the second source.
    """
    truth = jnp.asarray([[12., 16., 20.], [position, 16., 20.]])

    # Simulated observation: the fixed noise realisation plus the model.
    observed = noise.copy()
    observed += image_model(truth)

    # Curvature of the objective at the true parameters; its inverse is
    # shown as the parameter covariance.
    curvature = hess(truth.flatten().copy(), observed)
    names = ['x1', 'y1', 'f1', 'x2', 'y2', 'f2']
    cov_table = pd.DataFrame(jnp.linalg.inv(curvature), index=names, columns=names)

    _fig, (image_ax, cov_ax) = plt.subplots(1, 2, figsize=(12, 6))

    image_ax.imshow(observed, origin='lower', cmap='gray',
                    vmin=-2*noise_scale, vmax=0.6)
    image_ax.axis('off')

    # Fixed colour limits keep the heatmap comparable as the slider moves.
    sns.heatmap(cov_table, annot=True, fmt=".2f", cmap='coolwarm',
                linewidths=0.5, ax=cov_ax, vmin=-5, vmax=30)

    plt.show()

# Slider for the x position of the second source, from the first source's
# x (12) out to 22 pixels. The trailing semicolon suppresses the notebook's
# echo of the return value.
interact(
    interactive_photometry,
    position=widgets.FloatSlider(value=22.0, min=12.0, max=22.0, step=0.1),
);
