In [1]:
# Core jax
import jax
# from jax.config import config
jax.config.update("jax_enable_x64", True)

import jax.numpy as np
import jax.random as jr

# Optimisation
import zodiax as zdx
import optax

# Optics
import dLux as dl
import dLux.utils as dlu

# Plotting/visualisation
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

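# Custom Poisson log-likelihoods: they use the gamma-function (gammaln) analytic continuation
# so that non-integer "counts" (e.g. a noiseless model image used as data) are accepted.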

from jax.scipy.special import gammaln, xlogy

def poisson_loglike_gamma(k, mu):
    """
    Element-wise Poisson log-pmf, analytically continued to real-valued ``k``.

    Computes ``k * log(mu) - mu - log(Gamma(k + 1))``. This equals the usual
    Poisson log-pmf for integer ``k`` and is also defined for non-integer
    counts, via ``gammaln(k + 1)`` in place of ``log(k!)``.

    Parameters
    ----------
    k : array_like
        Observed counts. An entry is valid only if it is finite and >= 0.
    mu : array_like
        Poisson rate (expected counts). An entry is valid only if it is
        finite and > 0.

    Returns
    -------
    logp : Array
        Broadcast array of log-probabilities. Entries where ``k`` or ``mu``
        is invalid are ``-inf``.
    """
    k = np.asarray(k)
    mu = np.asarray(mu)
    valid = (k >= 0) & np.isfinite(k) & (mu > 0) & np.isfinite(mu)
    # "Double where": swap in harmless values for invalid entries *before*
    # evaluating xlogy/gammaln. A single outer np.where still back-propagates
    # through the discarded branch. There 0 * inf (or 0 * nan) turns the
    # gradient, and any Hessian/covariance built from it, into NaN.
    k_safe = np.where(valid, k, 0.0)
    mu_safe = np.where(valid, mu, 1.0)
    logp = xlogy(k_safe, mu_safe) - mu_safe - gammaln(k_safe + 1.0)
    return np.where(valid, logp, -np.inf)

def poisson_loglike_unnormalized(k, mu):
    """
    Poisson log-likelihood with the ``-log(k!)`` normalisation term dropped.

    Returns ``k * log(mu) - mu`` element-wise, using ``xlogy`` so that
    ``0 * log(0)`` evaluates to 0. The dropped term does not depend on
    ``mu``, so derivatives with respect to ``mu`` are unaffected by omitting it.
    """
    log_rate_term = xlogy(k, mu)
    return log_rate_term - mu

# define the likelihood function:
def poiss_loglike(pytree, data) -> float:
    """
    Poissonian log likelihood of the pytree given the data. Assumes the pytree
    has a .model() function.

    The data are the observed counts ``k`` and the model output is the
    Poisson rate ``mu``, i.e. this returns ``sum(log P(data | model))``.

    Parameters
    ----------
    pytree : Base
        Pytree with a .model() function.
    data : Array
        Data to compare the model to (observed counts; may be non-integer).

    Returns
    -------
    log_likelihood : Array
        Log likelihood of the pytree given the data, summed over all elements.
    """
    expected_counts = pytree.model()
    # poisson_loglike_gamma takes (k, mu) = (observed counts, rate). Passing the
    # model as `k` instead would differentiate the gammaln(k + 1) term w.r.t.
    # the model parameters. That changes the Hessian, and so any covariance /
    # Fisher matrix built from it, away from the Poisson sum(m'^2 / m) form.
    return poisson_loglike_gamma(data, expected_counts).sum()

# Draw images with array element [0, 0] in the lower-left corner.
plt.rcParams.update({"image.origin": "lower"})

from pathlib import Path

In [2]:
# Parameters whose covariance the loss measures (passed to fim_loss_func as
# `parameters`).
marginal_params = 'position'

# Parameters the optimiser updates; the loss gradient is taken w.r.t. these.
opt_param = 'pupil.coefficients'


@zdx.filter_jit
@zdx.filter_value_and_grad(opt_param)
def fim_loss_func(model, parameters):
    """
    Log of the trace of the covariance matrix of ``parameters``.

    The covariance comes from ``zdx.covariance_matrix`` applied to
    ``poiss_loglike``. The noiseless model image is used as the data, so it is
    evaluated at the expected data of the current model.

    Because of ``filter_value_and_grad(opt_param)``, calling this returns
    ``(loss, grads)``, with gradients taken with respect to ``opt_param``.
    """
    # Expected data: the model's own noiseless output.
    expected_data = model.model()
    covariance = zdx.covariance_matrix(
        model, parameters, poiss_loglike, data=expected_data
    )
    # Loss: log of the summed variances (diagonal) of `parameters`.
    return np.log(np.trace(covariance))

In [3]:
from dLuxToliman import ApplyBasisCLIMB, TolimanOpticalSystem

# --- Configuration shared by every run ---------------------------------------
# For this to work, best to just keep a single wavelength of 585e-9 m: the
# Toliman optical system gives an OPD layer which only corresponds to a pi
# phase shift at 585 nm.
wavelength = 585e-9
wf_npix = 256
diameter = 0.125  # This is default of toliman optical system. Can change.

### Can apply these to the input if you just want a clear aperture
m2_diameter = 0
strut_width = 0
###

# PSF size in Nyquist-sampled pixels. Each run uses pixels `fac` times
# coarser and `fac` times fewer of them, so the field of view is the same
# for every factor.
psf_npix_nyquist = 256

n_steps = 100       # gradient-descent iterations per run
learning_rate = 2e1

# The mask basis does not change between runs: read it from disk once.
climb_basis = np.load("../xk4/files/basis.npy")

# Iterate over plain Python ints (not a jax array), so pixel counts, the
# oversampling factor and the output file names are ordinary integers.
for fac in (4, 8, 16):
    for idx in range(10):
        # Random seed key
        key = jr.PRNGKey(idx)

        # This is to what degree we undersample Nyquist
        nyquist_factor = fac
        psf_npix = psf_npix_nyquist // nyquist_factor  # integer pixel count
        psf_pixel_scale = nyquist_factor * dlu.rad2arcsec(wavelength / (2 * diameter))
        oversample = nyquist_factor * 4  # sample at 4x nyquist originally

        # Place the source in the centre of the top-right-from-centre pixel
        # (half a pixel offset in x and y); that is what `position` does here.
        source = dl.PointSource(
            flux=1e6,
            position=dlu.arcsec2rad(np.array([psf_pixel_scale / 2, psf_pixel_scale / 2])),
            wavelengths=np.array([wavelength]),
        )

        # Optimisable binary mask with random initial coefficients.
        coefficients = 100 * jr.normal(key, [len(climb_basis)])
        mean_wl = source.wavelengths.mean()  # have to make sure this is in meters
        mask_layer = ApplyBasisCLIMB(climb_basis, mean_wl, coefficients)

        # Construct instrument
        optics = TolimanOpticalSystem(
            wf_npixels=wf_npix, psf_npixels=psf_npix, mask=mask_layer,
            radial_orders=[2, 3], psf_pixel_scale=psf_pixel_scale, oversample=oversample,
        )
        detector_layers = [('downsample', dl.detector_layers.Downsample(oversample))]
        detector = dl.LayeredDetector(layers=detector_layers)
        model = dl.Telescope(optics, source, detector)

        optim, opt_state = zdx.get_optimiser(model, opt_param, optax.adam(learning_rate))

        # models_out[i] is the model on which losses[i] was evaluated
        # (models_out starts with the initial model).
        losses, models_out = [], [model]
        with tqdm(range(n_steps), desc='Gradient Descent') as t:
            for _ in t:
                loss, grads = fim_loss_func(model, marginal_params)
                updates, opt_state = optim.update(grads, opt_state)
                model = zdx.apply_updates(model, updates)
                models_out.append(model)
                losses.append(loss)
                t.set_description("Loss: {:.3f}".format(float(loss)))  # update the progress bar

        # Keep the phase of the lowest-loss model seen during the run.
        final_phase = models_out[int(np.argmin(np.array(losses)))].get_binary_phase()

        # Make output directory
        out_dir = Path("phases") / f"{fac}_subsample"
        out_dir.mkdir(parents=True, exist_ok=True)
        np.save(out_dir / f"index_{idx}.npy", final_phase)
        print("saved")

Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved


Gradient Descent:   0%|          | 0/100 [00:00<?, ?it/s]

saved
