In [1]:
import amigo
from amigo.optical_models import AMIOptics
import os
import sys
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
import dLux.utils as dlu
import dLux as dl

# Turn on JAX's `jax_enable_x64` flag so that arrays can be held in 64-bit precision
# (the models printed further down report f64 arrays).
jax.config.update("jax_enable_x64", True)


from jax.scipy.special import gammaln, xlogy

def poisson_loglike_gamma(k, mu):
    """
    Poisson log-probability extended to non-integer counts.

    The factorial in the normalisation is replaced by ``gammaln(k + 1)``, so
    `k` does not have to be an integer.

    Parameters
    ----------
    k : Array
        Counts. Entries that are negative or non-finite are treated as invalid.
    mu : Array
        Rates. Entries that are not strictly positive or are non-finite are
        treated as invalid.

    Returns
    -------
    logp : Array
        ``k * log(mu) - mu - gammaln(k + 1)`` element-wise, with ``-inf``
        wherever either input is invalid.
    """
    counts, rate = np.asarray(k), np.asarray(mu)

    # Validity masks for each input, combined in the final selection.
    counts_ok = np.isfinite(counts) & (counts >= 0)
    rate_ok = np.isfinite(rate) & (rate > 0)

    log_prob = xlogy(counts, rate) - rate - gammaln(counts + 1.0)
    return np.where(counts_ok & rate_ok, log_prob, -np.inf)

def poisson_loglike_unnormalized(k, mu):
    """
    Poisson log-likelihood without its normalisation constant.

    Parameters
    ----------
    k : Array
        Counts.
    mu : Array
        Rates.

    Returns
    -------
    logp : Array
        ``k * log(mu) - mu`` element-wise; the ``gammaln(k + 1)`` term is
        left out.
    """
    count_times_log_rate = xlogy(k, mu)
    return count_times_log_rate - mu

# define the likelihood function:
def poiss_loglike(pytree, data) -> float:
    """
    Poissonian log likelihood of the data given the pytree's model. Assumes
    the pytree has a .model() function.

    Parameters
    ----------
    pytree : Base
        Pytree with a .model() function.
    data : Array
        Data to compare the model to.

    Returns
    -------
    log_likelihood : Array
        Log likelihood of the data given the pytree's model, summed over all
        elements.
    """
    # The observed data are the counts (k) and the forward model is the
    # Poisson rate (mu). The helper takes (k, mu), so data must come first.
    return poisson_loglike_gamma(data, pytree.model()).sum()

# Make matplotlib draw images with the origin at the lower edge by default.
plt.rcParams["image.origin"] = 'lower'

# import alpha cen
from dLuxToliman import AlphaCen, TolimanOpticalSystem
import zodiax as zdx

In [2]:
# Build the AMI optical model with a 40-pixel PSF.
ami_optics = AMIOptics(psf_npixels=40)
# 40 pixels because we are simulating the bottom-right array, which is the best one according to Ben.
# Print the model to inspect its layers and default parameter values.
print(ami_optics)

AMIOptics(
  wf_npixels=1024,
  diameter=6.603464,
  layers={
    'InvertY': Flip(axes=0),
    'pupil_mask':
    StaticApertureMask(
      transmission=f64[1024,1024],
      normalise=True,
      abb_basis=f64[7,10,180,180],
      abb_coeffs=f64[7,10],
      amp_basis=f64[7,10,180,180],
      amp_coeffs=f64[7,10],
      corners=i64[7,2]
    )
  },
  psf_npixels=40,
  oversample=3,
  psf_pixel_scale=0.065524085,
  filters={'F380M': f64[2,9], 'F430M': f64[2,9], 'F480M': f64[2,9]},
  defocus_type='fft',
  defocus=f64[],
  corners=i64[7,2],
  psf_upsample=3
)


In [3]:
# I wonder if I can make it a dLux telescope?
# Table taken from the F480M filter: row 0 holds the wavelengths and row 1 the matching
# weights. It has the same (2, 9) layout as the `filters` entries printed for AMIOptics.
wlweights = np.array([[4.58283333e-06, 4.64250000e-06, 4.70216667e-06, 4.76183333e-06,
        4.82150000e-06, 4.88116667e-06, 4.94083333e-06, 5.00050000e-06,
        5.06016667e-06],
       [4.83897537e-03, 4.27569626e-02, 1.79674591e-01, 1.84285712e-01,
        2.00072602e-01, 1.85606063e-01, 1.65577628e-01, 3.61394316e-02,
        1.04803396e-03]])

# Split the table into its wavelength row and its weight row.
wavelengths = wlweights[0]
weights = wlweights[1]

# Instead of the tabulated weights, build a polynomial spectrum from
# (wavelengths, coefficients).
# NOTE(review): assuming PolySpectrum evaluates sum(c_i * wavelength**i), the line
# -8.54 + 2083333 * wavelength rises from ~1 at the shortest wavelength to ~2 at the
# longest, i.e. the spectral weight doubles over the bandpass — confirm.
spectrum = dl.PolySpectrum(wavelengths, np.array([-8.54,2083333]))

In [4]:
# Candidate log10 fluxes: linspace(7, 10, 4) gives 7, 8, 9 and 10.
log_fluxes = np.linspace(7,10,4)

flux = 10**(log_fluxes[2]) # log_fluxes[2] is 9, so flux = 1e9
# Binary source.
# Going to have to import AlphaCen because it has individual x and y positions? No.
# source = AlphaCen(wavelengths=wavelengths, log_flux=np.log10(flux))
separation = 2*4.80e-6/6.603464 # 2 * lambda / D, with lambda = 4.80e-6 and D = 6.603464 (the diameter used below)

# 100:1 contrast ratio?
contrast = 100
# Binary source at the field centre with a position angle of pi/2.
source = dl.BinarySource(spectrum=spectrum, position=dlu.arcsec2rad(np.array([0,0])), 
                        mean_flux=flux, separation=separation, position_angle=np.pi/2, contrast=contrast)

## could just make a new point source
# source = dl.PointSource(wavelengths=wavelengths, position = dlu.arcsec2rad(np.array([0,0])), flux = flux)

# Optional detector: a single layer that downsamples by a factor of 3
# (the same factor as `oversample=3` in the optics below).
detector_layers = [
    (
        'downsample',
        dl.detector_layers.Downsample(3)
    )
]
detector = dl.LayeredDetector(detector_layers)

# Try extracting the AMI optics and putting it into a normal dLux system.. this is SO JANK.
# Only the pupil transmission array is reused from `ami_optics`.
transmission = ami_optics.transmission
# The second positional argument (True) is the layer's `normalise` flag, as the printed model shows.
aperture_layer = dl.TransmissiveLayer(transmission, True)

layers = [
    (
        'aperture',
        aperture_layer
    )
]
# wf_npixels, diameter, oversample and psf_pixel_scale repeat the values printed for
# `ami_optics`.
# NOTE(review): psf_npixels is 80 here but 40 in `ami_optics` — confirm this is intended.
optics = dl.AngularOpticalSystem(wf_npixels=1024, diameter=6.603464, psf_npixels=80, oversample=3, psf_pixel_scale=0.065524085, layers=layers)

# The dithers are set up in the next cell.

In [9]:
# Two dither offsets, (7, 0) and (0, -7) arcsec, converted to radians.
dithers = dlu.arcsec2rad(np.array([[7,0],[0,-7]]))
# Wrap the optics, source and detector in a dithered model.
model = dl.Dither(dithers, optics, source, detector)

# Output of the model itself, used as the "data" in the covariance calculation further down.
data = model.model()
# plt.imshow(np.log10(data[0]))
# plt.show()

In [10]:
# Model parameters handed to the covariance calculation.
marginal_params = ['separation', 'position_angle', 'mean_flux', 'contrast', 'wavelengths', 'coefficients', 'position']
# 'wavelengths' is given shape (1,) — presumably so the nine wavelengths are perturbed
# together as one parameter (e.g. the mean wavelength) — TODO confirm.
shape_dict = {'wavelengths': (1,)}

# Covariance matrix of the listed parameters under `poiss_loglike`, evaluated with the
# model's own output as the data.
cov = zdx.covariance_matrix(model, marginal_params, poiss_loglike, data=data, shape_dict=shape_dict)

In [11]:
# Print the assembled model to check its structure and parameter values.
print(model)

Dither(
  optics=AngularOpticalSystem(
    wf_npixels=1024,
    diameter=6.603464,
    layers={
      'aperture': TransmissiveLayer(transmission=f64[1024,1024], normalise=True)
    },
    psf_npixels=80,
    oversample=3,
    psf_pixel_scale=0.065524085
  ),
  source=BinarySource(
    spectrum=PolySpectrum(wavelengths=f64[9], coefficients=f64[2]),
    position=f64[2],
    mean_flux=1000000000.0,
    separation=1.4537824390350277e-06,
    position_angle=1.5707963267948966,
    contrast=100.0
  ),
  detector=LayeredDetector(layers={'downsample': Downsample(kernel_size=3)}),
  dithers=f64[2,2]
)


In [15]:
# --- hyperparams you can tune ---
MAX_FLUX = 30000     # photons/pixel (or whatever your units/normalisation are)
PENALTY_WEIGHT = 1e-3   # increase until the cap is respected
SMOOTH_TAU = 0.0       # set >0 for smooth hinge (e.g. 0.01 * MAX_FLUX)

def pixel_cap_penalty(img, cap=MAX_FLUX, weight=PENALTY_WEIGHT, tau=SMOOTH_TAU):
    """
    Penalize *every* pixel that exceeds `cap`.

    Parameters
    ----------
    img : Array
        Image (any shape) whose pixel values are compared with `cap`.
    cap : float
        Pixel value above which the penalty starts to apply.
    weight : float
        Multiplier applied to the mean squared excess.
    tau : float
        Smoothing scale, a plain Python number (it is branched on directly).
        tau = 0.0  -> hard hinge, max(img - cap, 0)
        tau > 0.0  -> smooth hinge, tau * softplus((img - cap) / tau),
        with sharpness ~ 1/tau

    Returns
    -------
    penalty : Array
        ``weight * mean(excess**2)`` over all pixels.
    """
    overshoot = img - cap
    if tau > 0.0:
        # smooth ReLU: tau * log(1 + exp(x / tau)). jax.numpy has no `softplus`,
        # so it is written as logaddexp(0, x), which is the same function.
        excess = tau * np.logaddexp(0.0, overshoot / tau)
    else:
        # hard hinge ReLU; `maximum` avoids the deprecated `a_min` keyword of
        # jax.numpy.clip
        excess = np.maximum(overshoot, 0.0)
    # mean over all pixels keeps penalty magnitude roughly scale-invariant to image size
    return weight * np.mean(excess**2)

In [20]:
# Loss function
# Name of the model parameter the loss is differentiated with respect to.
opt_param = 'dithers'


@zdx.filter_jit
@zdx.filter_value_and_grad(opt_param)
def fim_loss_func(model, parameters):
    """
    Covariance-based loss with a pixel saturation penalty.

    The model's own output is used as the data for the covariance matrix of
    `parameters`. The loss is log10 of the first diagonal entry of that
    matrix plus the pixel-cap penalty on the same image. Only ``cov[0, 0]``
    enters the loss, not a trace over several entries.

    Parameters
    ----------
    model : Base
        Model with a .model() function.
    parameters : list of str
        Parameter names passed to ``zdx.covariance_matrix``.

    Returns
    -------
    loss : Array
        Scalar loss. Because of the ``filter_value_and_grad`` decorator,
        calling this function yields ``(loss, grads)``, as used by the
        optimisation loop.
    """
    image = model.model()

    # Uses the notebook-level `shape_dict` and `poiss_loglike`.
    covariance = zdx.covariance_matrix(
        model,
        parameters,
        poiss_loglike,
        data=image,
        shape_dict=shape_dict,
    )

    # presumably the variance of the first entry in `parameters` — TODO confirm ordering
    log_variance = np.log10(covariance[0, 0])

    saturation_penalty = pixel_cap_penalty(
        image, cap=MAX_FLUX, weight=PENALTY_WEIGHT, tau=SMOOTH_TAU
    )

    return log_variance + saturation_penalty

In [None]:
# working
import optax
from tqdm import tqdm

# Adam optimiser acting on `opt_param`; the learning rate is 0.3 arcsec expressed in radians.
optim, opt_state = zdx.get_optimiser(
    model,
    opt_param,
    optax.adam(dlu.arcsec2rad(0.3)),
)

# History of the loss values and of the model (the starting model is entry 0).
losses = []
models_out = [model]

with tqdm(range(50), desc='Gradient Descent') as progress:
    for step in progress:
        # Loss value and its gradient with respect to `opt_param`.
        loss, grads = fim_loss_func(model, marginal_params)

        # Turn the gradients into an update and apply it to the model.
        updates, opt_state = optim.update(grads, opt_state)   # your zdx API
        model = zdx.apply_updates(model, updates)

        # Record this step, then refresh the progress-bar label.
        losses.append(loss)
        models_out.append(model)
        progress.set_description("Loss: {:.6f}".format(float(loss)))

Gradient Descent:   0%|          | 0/50 [02:32<?, ?it/s]


KeyboardInterrupt: 

: 