In [1]:
import jax.numpy as np
import dLux as dl
import dLuxToliman as dlT
import dLux.utils as dlu
import matplotlib.pyplot as plt
import jax

# --- Telescope construction ---------------------------------------------
# Wavefront sampling is kept deliberately coarse so the model stays cheap.
wf_npixels = 256
diameter = 0.125            # aperture diameter; presumably metres — confirm
period = 304e-6             # grating period used by SideLobeTelescope; presumably metres — confirm
difference = np.pi * 0.348  # phase difference used by SideLobeTelescope; presumably radians — confirm

# Single-layer optical train: just the Toliman pupil.
apertureLayer = dlT.TolimanApertureLayer(wf_npixels)
layers = [('aperture', apertureLayer)]

# Detector-plane sampling.
psf_npixels = 300
true_pixel_scale = 0.375  # units assumed arcsec/pixel — confirm
oversample = 4            # high oversampling factor

optics = dl.AngularOpticalSystem(
    wf_npixels,
    diameter,
    layers,
    psf_npixels,
    true_pixel_scale,
    oversample,
)

In [2]:
import numpy as onp
import pandas as pd
from scipy.stats import binned_statistic

# Step 1: Load the alpha Cen A spectrum from CSV.
df = pd.read_csv("alpha_cen_A_spectrum.csv")  # replace with correct path if needed

# Step 2: Keep only the 5200–6500 Å band (wavelength column is in Ångstroms).
mask = (df["wavelength"] >= 5200) & (df["wavelength"] <= 6500)
if not mask.any():
    raise ValueError("No spectrum samples fall inside the 5200–6500 Å band.")

# Step 3: Extract the band and convert Å -> m (1 Å = 1e-10 m).
# .loc avoids chained indexing on the DataFrame.
wavelengths = df.loc[mask, "wavelength"].to_numpy() * 1e-10  # meters
weights = df.loc[mask, "flux"].to_numpy()

# Step 4: Downsample into n_bins equal-width wavelength bins (mean flux per bin).
n_bins = 500
bin_means, bin_edges, _ = binned_statistic(
    wavelengths, weights, statistic='mean', bins=n_bins
)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

# Bins that receive no samples come back as NaN. One NaN weight would poison
# every PSF and the loss, so fill them by linear interpolation from populated
# neighbours. This keeps exactly n_bins entries, which later cells rely on.
empty_bins = onp.isnan(bin_means)
if empty_bins.any():
    bin_means = onp.interp(
        bin_centers, bin_centers[~empty_bins], bin_means[~empty_bins]
    )

# Step 5: Create the point source.
alpha_cen_a_source = dl.PointSource(
    wavelengths=bin_centers,
    weights=bin_means,
    flux=2.909*10**7 # from max charles code
)

In [3]:
# Pair the optical system with the alpha Cen A point source ...
alpha_cen_a_scope = dl.Telescope(optics, alpha_cen_a_source)

# ... then wrap it in the side-lobe telescope, using `period` and `difference`
# from the telescope-construction cell.
sidelobescope = dlT.SideLobeTelescope(
    alpha_cen_a_scope,
    period,
    difference,
)

In [4]:
# 585 nm is the midpoint of the 520–650 nm band selected from the spectrum.
center_wl = 585e-9
# Pixel scale the side-lobe model *assumes*. It has the same value as
# true_pixel_scale but is presumably kept separate so the two can be varied
# independently — confirm.
pixel_scale = 0.375

# Model all four side-lobes. downsample must equal the optics oversample.
sidelobes_raw = sidelobescope.model_4_sidelobes(
    center_wavelength=center_wl,
    assumed_pixel_scale=pixel_scale,
    downsample=oversample,
)

Model time: 25.9373 seconds.
Model time: 25.3961 seconds.
Model time: 21.0553 seconds.
Model time: 27.4330 seconds.


In [5]:
import jax.random as jr

# Draw one Poisson-noise realisation of the noiseless side-lobe model.
# The result is the simulated data that loss_fn is evaluated against.
key = 0  # integer seed for the JAX PRNG
prng_key = jr.PRNGKey(key)
sidelobes_poisson = jr.poisson(prng_key, sidelobes_raw)

In [10]:
# Optimisation libraries
import zodiax as zdx
import optax

# Recreate the figure by fitting the spectral weights directly.
# The model needs no Poisson noise; only the data (sidelobes_poisson) is noisy.
weights_path = 'telescope.source.spectrum.weights'

# Start the fit from a flat spectrum: n_bins equal weights that sum to one.
flat_weights = np.ones(n_bins) / n_bins
model = sidelobescope.set(weights_path, flat_weights)

# Adam optimiser for the weights leaf. zodiax builds it from weights_path,
# presumably updating only that leaf — confirm in zodiax docs.
weights_optimiser = optax.adam(1e-4)
optim, opt_state = zdx.get_optimiser(model, weights_path, weights_optimiser)

In [11]:
@zdx.filter_jit
@zdx.filter_value_and_grad(weights_path)
def loss_fn(model, data):
    """Poisson negative log-likelihood of ``data`` under ``model``.

    The decorators JIT-compile the function and make it return
    ``(loss, grads)``, with gradients taken w.r.t. the leaf at ``weights_path``.

    NOTE: if the predicted counts are exactly 0 where ``data`` > 0, the
    log-pmf is -inf and the loss becomes inf.
    """
    predicted = model.model_4_sidelobes(
        center_wavelength=center_wl,
        assumed_pixel_scale=pixel_scale,
        downsample=oversample,
    )
    log_probs = jax.scipy.stats.poisson.logpmf(data, predicted)
    return -np.sum(log_probs)

In [13]:
%%time
loss, grads = loss_fn(model, sidelobes_poisson) # Compile
print("Initial Loss: {}".format(int(loss)))

Model time: 6.9274 seconds.
Model time: 6.6153 seconds.
Model time: 6.9445 seconds.
Model time: 7.0203 seconds.


2025-05-22 13:41:52.462851: E external/xla/xla/service/slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_loss_fn] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-05-22 13:41:52.698723: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 2m0.237762s

********************************
[Compiling module jit_loss_fn] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Initial Loss: 409894
CPU times: user 10min 43s, sys: 55.2 s, total: 11min 39s
Wall time: 7min 18s


In [None]:
from tqdm import tqdm

# Projected gradient descent on the spectral weights. After every Adam step
# the weights are clamped to >= min_weight so the spectrum stays positive.
n_steps = 100       # number of optimisation steps
min_weight = 1e-6   # positivity floor for every spectral weight

losses, models_out = [], []
with tqdm(range(n_steps), desc='Gradient Descent') as t:
    for _ in t:
        # Loss/grads are for the model *before* this step's update, so
        # losses[i] describes the model that produced models_out[i].
        loss, grads = loss_fn(model, sidelobes_poisson)
        updates, opt_state = optim.update(grads, opt_state)
        model = zdx.apply_updates(model, updates)

        # Project onto the feasible set in one step, addressing the leaf via
        # weights_path rather than a hard-coded attribute chain. Local names
        # are used so the cell-2 global `weights` (raw CSV flux) is not clobbered.
        model = model.set(weights_path, np.maximum(model.get(weights_path), min_weight))

        losses.append(loss)
        models_out.append(model)
        t.set_description("Log Loss: {:.3f}".format(np.log10(loss))) # update the progress bar

Log Loss: 5.589:   2%|▏         | 2/100 [05:17<4:18:33, 158.30s/it]