# WET-013: WEP Performance as a function of exposure time

Owner: **Bryce Kalmbach** [@jbkalmbach](https://github.com/lsst-sitcom/sitcomtn-133/issues/new?body=@jbkalmbach) <br>
Last Verified to Run: **2024-10-17** <br>
Software Version:
  - `ts_wep`: **12.0.0**
  - `lsst_distrib`: **w_2024_42**

## Test Description

This test will look at the WEP output from multiple defocal visits across a range of exposure times to investigate if increasing exposure time helps average out the atmospheric residuals.
We will calculate the average Zernikes for each visit and then measure the scatter in these estimates among visits that share the same exposure time.

# Imports

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from lsst.daf.butler import Butler
from astropy.io import fits
from astropy import units as u
from IPython.utils import io
from lsst.ts.wep.task.pairTask import ExposurePairer, ExposurePairerConfig
from lsst.ts.wep.utils import convertZernikesToPsfWidth
from astropy.table import Table, QTable, unique
from scipy.optimize import curve_fit
%matplotlib inline

In [None]:
# Butler repository used for this analysis.
# Change this path to appropriate butler when on-sky images arrive
path_to_aos_butler = '/sdf/data/rubin/repo/aos_imsim/'
# Butler instance used for every data query in this notebook
butler = Butler(path_to_aos_butler)

## Load Zernike Estimates

When running exposure time tests we will run the Wavefront Estimation Pipeline (WEP) on the images. 
Once this is done all we need is the collection name used when running the pipeline and we can generate our analysis using the code below.

In [None]:
# Name of the butler collection that was used when running the WEP pipeline;
# every butler query in this notebook reads from this collection
collection_name = 'WET-013/directDetectCatalog_RefitWcs'

In [None]:
# Find every (exposure, visit, detector) data id in the collection that has
# a `zernikeEstimateAvg` dataset written by WEP
data_ids = list(
    butler.registry.queryDataIds(
        ('exposure', 'visit', 'detector'),
        datasets='zernikeEstimateAvg',
        collections=collection_name,
    )
)

In [None]:
# Print the `visitInfo` of the first data id to inspect the per-exposure
# metadata (e.g. exposure time) that is gathered for every data id below.
print(butler.get('postISRCCD.visitInfo', dataId=data_ids[0], collections=collection_name))

In [None]:
def getZernAvgFromTable(table, z_min=4, z_max=29):
    """Gather the average Zernikes in microns from a Zernike table.

    Parameters
    ----------
    table : table-like
        Zernike table with a ``label`` column and one ``Z{n}`` column per
        Noll index. The ``Z{n}`` columns must support ``.to(u.um)``.
    z_min : int, optional
        First Noll index to include.
    z_max : int, optional
        Noll index at which to stop (exclusive).

    Returns
    -------
    numpy.ndarray
        Values in microns, taken from the first row labelled ``'average'``,
        for Noll indices ``z_min`` through ``z_max - 1``.
    """
    # Read from the `table` argument, not from a variable in the caller's
    # scope, so the function works for any table passed in.
    avg_row_idx = np.where(table['label'] == 'average')
    zern_avg = [
        table[f'Z{z_num}'].to(u.um)[avg_row_idx].value[0]
        for z_num in range(z_min, z_max)
    ]
    return np.array(zern_avg)

In [None]:
# Collect the visit metadata and averaged Zernikes for every data id and
# combine them into a single Astropy table
exp_time_list, airmass_list = [], []
visit_list, detector_list = [], []
zern_avg_list = []
for data_id in data_ids:
    # Identifiers come straight from the data id
    visit_list.append(data_id['visit'])
    detector_list.append(data_id['detector'])

    # Averaged Zernike estimate for this data id
    zern_table = butler.get(
        'zernikes', dataId=data_id, collections=collection_name
    )
    zern_avg_list.append(getZernAvgFromTable(zern_table))

    # Exposure time and airmass from the exposure metadata
    visitInfo = butler.get(
        'postISRCCD.visitInfo', dataId=data_id, collections=collection_name
    )
    exp_time_list.append(visitInfo.exposureTime)
    airmass_list.append(visitInfo.boresightAirmass)

data_table = QTable(
    [exp_time_list, visit_list, detector_list, airmass_list, zern_avg_list],
    names=['exp_time', 'visit', 'detector', 'airmass', 'zern_avg'],
)

## Exposure Time Analysis

### Examine the dataset

Just take a quick look at the various exposure times used in the data and the number of visits for each exposure time.

In [None]:
# Distinct exposure times found in the data set
exp_times = np.unique(data_table['exp_time'])
# ComCam detector ids (0 through 8)
detector_ids = np.arange(0, 9)

In [None]:
# Count the visits at each exposure time. Only rows of the first detector are
# counted so that each visit is counted once (assumes every visit has a row
# for that detector).
on_first_detector = data_table['detector'] == detector_ids[0]
exp_time_counts = []
for exp_time in exp_times:
    at_exp_time = data_table['exp_time'] == exp_time
    exp_time_counts.append(np.sum(np.logical_and(at_exp_time, on_first_detector)))

plt.plot(exp_times, exp_time_counts, '-o')
plt.xlabel('Exposure Time (seconds)')
plt.ylabel('Number of Visits')
plt.title('Number of Visits with each exposure time')
plt.tight_layout()

### Consistency of Mean Value across Exposure Times

In this first plot we examine the mean value across the different runs. If we are in the same optical state during the different observations then we should see that the mean value will be approximately the same for each Zernike across the different exposure times. We can also separate it by detector to see if there are any effects on detectors with more vignetting than others.

In [None]:
# One panel per detector: mean Zernike estimate (as PSF width) versus Noll
# index, with one curve per exposure time
noll_indices = np.arange(4, 29)
exp_times = np.unique(data_table['exp_time'])
fig = plt.figure(figsize=(18, 12))
for detector in range(9):
    ax = fig.add_subplot(3, 3, detector + 1)
    det_table = data_table[data_table['detector'] == detector]
    for exp_time in exp_times:
        exp_time_table = det_table[det_table['exp_time'] == exp_time]
        zern_avg_array = np.array(exp_time_table['zern_avg'].value)
        # Average over the visits taken at this exposure time
        mean_psf_width = convertZernikesToPsfWidth(np.mean(zern_avg_array, axis=0))
        ax.plot(noll_indices, mean_psf_width, label=f'Exp Time {exp_time} sec')
        ax.set_xlabel('Noll Index')
        ax.set_ylabel('Zernike Estimate (arcsec)')
        ax.set_title(f'Detector {detector}')
    ax.legend(fontsize=8)
fig.suptitle('Mean Zernike Estimate on ComCam sims across Exposure Times by detector')
fig.tight_layout()

In [None]:
# Single panel: mean Zernike estimate (as PSF width) over all detectors and
# visits, with one curve per exposure time
noll_indices = np.arange(4, 29)
exp_times = np.unique(data_table['exp_time'])
fig = plt.figure(figsize=(8, 5))
ax = fig.gca()
for exp_time in exp_times:
    at_exp_time = data_table['exp_time'] == exp_time
    exp_time_table = data_table[at_exp_time]
    zern_avg_array = np.array(exp_time_table['zern_avg'].value)
    mean_psf_width = convertZernikesToPsfWidth(np.mean(zern_avg_array, axis=0))
    ax.plot(noll_indices, mean_psf_width, label=f'Exp Time {exp_time} sec')
    ax.set_xlabel('Noll Index')
    ax.set_ylabel('Zernike Estimate (arcsec)')
    ax.legend(fontsize=8)
ax.set_title('Mean Zernike Estimate on ComCam sims across Exposure Times averaged across all detectors')

### Variability in the measurements for each Zernike

Since all the means across each detector for each exposure time seem fairly consistent we can compare the variability in the measurements for each Zernike on each detector by plotting the standard deviation for each Zernike coefficient separated by the detectors.

In [None]:
# One panel per detector: standard deviation of the Zernike estimates (as PSF
# width) across visits, with one curve per exposure time
noll_indices = np.arange(4, 29)
fig = plt.figure(figsize=(18, 12))
for detector in range(9):
    ax = fig.add_subplot(3, 3, detector + 1)
    det_table = data_table[data_table['detector'] == detector]
    for exp_time in exp_times:
        exp_time_table = det_table[det_table['exp_time'] == exp_time]
        zern_std_array = np.array(exp_time_table['zern_avg'].value)
        # Scatter between the visits taken at this exposure time
        std_psf_width = convertZernikesToPsfWidth(np.std(zern_std_array, axis=0))
        ax.plot(noll_indices, std_psf_width, label=f'Exp Time {exp_time} sec')
        ax.set_xlabel('Noll Index')
        ax.set_ylabel('Standard Deviation (arcsec)')
        ax.set_title(f'Detector {detector}')
    ax.legend(fontsize=8)
fig.suptitle('Standard Deviation of the Zernike estimate across 4 runs at each exposure time', size=18)
fig.tight_layout()

The plot above is rather busy and no clear trend stands out, so the next step is to combine the measurements from all the detectors.
However, each detector has a slightly different true value for each Zernike.
In place of a true value we can subtract the mean value for each detector from each measurement and use this information to calculate a standard deviation across the whole camera.
This translates mathematically to:

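$\sigma_{camera} = \sqrt{\frac{1}{N_{det} N_{visits}} \sum \limits_{i=1}^{N_{det}} \sum \limits_{j=1}^{N_{visits}} (x_{i,j} - \overline{x}_{i})^{2}}$

where $x_{i,j}$ is the Zernike estimate on detector $i$ for visit $j$ and $\overline{x}_{i}$ is the mean over the visits on detector $i$ at the given exposure time.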

In [None]:
# Camera-wide standard deviation for each exposure time: subtract each
# detector's own mean from its measurements, pool the residuals from all
# detectors, and take the RMS for every Zernike.
fig = plt.figure(figsize=(10, 6))
for exp_time in exp_times:
    deviations = []
    for detector in detector_ids:
        use_rows = np.logical_and(data_table['exp_time'] == exp_time, data_table['detector'] == detector)
        detector_zernikes = np.asarray(data_table[use_rows]['zern_avg'])
        # Residuals about this detector's mean over its visits
        deviations.append(detector_zernikes - np.mean(detector_zernikes, axis=0))
    # Stack the per-detector residuals row-wise. Unlike a fixed-shape reshape
    # this also works when detectors have different numbers of visits and
    # does not hard-code the number of Zernike terms.
    deviations = np.concatenate(deviations, axis=0)
    zern_std_array = np.sqrt(1 / len(deviations) * np.sum(np.square(deviations), axis=0))
    plt.plot(np.arange(4, 29), convertZernikesToPsfWidth(zern_std_array), label=f'Exp Time {exp_time} sec')
# Axis labels and legend are the same for every curve, so set them once
plt.xlabel('Noll Index')
plt.ylabel('Standard Deviation  (arcsec)')
plt.legend(fontsize=8)
plt.title('Standard Deviation across 4 runs at each exposure time across all detectors')

The plot shows the general expected trend of the variation decreasing with increased exposure time.
Below we plot just the values at 10, 30, and 90 seconds to make the trend clearer.

In [None]:
# Same camera-wide standard deviation as in the previous plot, restricted to
# the 10, 30 and 90 second exposure times.
fig = plt.figure(figsize=(10, 6))
for exp_time in [10, 30, 90]:
    deviations = []
    for detector in detector_ids:
        use_rows = np.logical_and(data_table['exp_time'] == exp_time, data_table['detector'] == detector)
        detector_zernikes = np.asarray(data_table[use_rows]['zern_avg'])
        # Residuals about this detector's mean over its visits
        deviations.append(detector_zernikes - np.mean(detector_zernikes, axis=0))
    # Stack the per-detector residuals row-wise. Unlike a fixed-shape reshape
    # this also works when detectors have different numbers of visits and
    # does not hard-code the number of Zernike terms.
    deviations = np.concatenate(deviations, axis=0)
    zern_std_array = np.sqrt(1 / len(deviations) * np.sum(np.square(deviations), axis=0))
    plt.plot(np.arange(4, 29), convertZernikesToPsfWidth(zern_std_array), label=f'Exp Time {exp_time} sec')
# Axis labels and legend are the same for every curve, so set them once
plt.xlabel('Noll Index')
plt.ylabel('Standard Deviation (arcsec)')
plt.legend(fontsize=8)
plt.title('Standard Deviation across 4 runs at each exposure time across all detectors')

Finally we look at the same information but take a cross section across each individual Zernike coefficient.

In [None]:
# Build an array with one row per exposure time and one column per Zernike,
# holding the camera-wide standard deviation of the estimates expressed as
# PSF width.
zern_std_exp_times_all = []
for exp_time in exp_times:
    deviations = []
    for detector in detector_ids:
        use_rows = np.logical_and(data_table['exp_time'] == exp_time, data_table['detector'] == detector)
        detector_table = data_table[use_rows]
        # Convert once and reuse for both the values and their mean
        psf_widths = np.asarray(convertZernikesToPsfWidth(detector_table['zern_avg']))
        deviations.append(psf_widths - np.mean(psf_widths, axis=0))
    # Stack the per-detector residuals row-wise. Unlike a fixed-shape reshape
    # this also works when detectors have different numbers of visits and
    # does not hard-code the number of Zernike terms.
    deviations = np.concatenate(deviations, axis=0)
    zern_std_array = np.sqrt(1 / len(deviations) * np.sum(np.square(deviations), axis=0))
    zern_std_exp_times_all.append(zern_std_array)
zern_std_exp_times_all = np.array(zern_std_exp_times_all)

In [None]:
# 5x5 grid with one panel per Zernike (Z4 upward): standard deviation of the
# estimate as a function of exposure time
fig = plt.figure(figsize=(20, 12))
for idx in range(25):
    noll_index = idx + 4
    ax = fig.add_subplot(5, 5, idx + 1)
    ax.scatter(exp_times, zern_std_exp_times_all[:, idx])
    ax.set_xlabel('Exp Time (sec)')
    ax.set_ylabel('Std. Dev. (arcsec)')
    ax.set_title(f'Z{noll_index}')
fig.suptitle('Standard Deviation as function of exposure time')
fig.tight_layout()


Finally, for each Zernike we fit a power law to the variance as a function of exposure time and compare the fitted exponent, shown in each legend, with a fixed $t^{-1}$ power law.

In [None]:
def fit_func(x, a, c):
    """Power law in x, normalised to the first exposure time, with free exponent a."""
    x0 = exp_times[0]
    return c * ((x/x0)**a)


def fit_func_fixed_tm1(x, c):
    """Power law in x, normalised to the first exposure time, with exponent fixed at -1."""
    x0 = exp_times[0]
    return c * ((x/x0)**-1)


# 5x5 grid with one panel per Zernike (Z4 upward): variance versus exposure
# time together with the free-exponent and fixed t^-1 power-law fits
fig = plt.figure(figsize=(20, 12))
for idx in range(25):
    ax = fig.add_subplot(5, 5, idx + 1)
    variance = zern_std_exp_times_all[:, idx]**2
    ax.scatter(exp_times, variance)

    # Fit the amplitude and exponent, then the amplitude only
    fit_exp, fit_var = curve_fit(fit_func, exp_times, variance)
    fit_exp_tm1, fit_var = curve_fit(fit_func_fixed_tm1, exp_times, variance)

    scaled_times = exp_times / exp_times[0]
    ax.plot(exp_times, fit_exp_tm1[0]*scaled_times**-(1), label='t^-1')
    ax.plot(exp_times, fit_exp[1]*scaled_times**fit_exp[0], label=f'Fit: t^{fit_exp[0]:.2f}')
    ax.set_xlabel('Exp Time (sec)')
    ax.set_ylabel('Variance ($arcsec^2$)')
    ax.set_title(f'Z{idx+4}')
    ax.legend()
fig.suptitle('Variance as function of exposure time')
fig.tight_layout()
