# Querying butler WEP outputs

Owner: **Guillem Megias** ([@gmegh](https://github.com/lsst-ts/ts_aos_analysis/issues/new?body=@gmegh)) 

Last Verified to Run: **2024-08-22**

Software Versions:
* lsst_distrib: **w_2024_32**

Requirements:
* [summit_utils](https://github.com/lsst-sitcom/summit_utils)


## Notebook Objective
The goal of this notebook is to show the user how to query and access basic Wavefront Estimation Pipeline (WEP) outputs from the butler. 

## Logistics
This notebook can be run from USDF or the Summit. The current version uses collections present in the summit butler; if they have been removed by the time you run this notebook, you will have to use a different collection and/or the USDF butler.

If running from USDF, the following butler and collection are recommended as an example:
* collection: 'sitcomtn-135/directDetectTimingTest_SourceLimit5'
* butler: '/sdf/data/rubin/repo/aos_imsim'

## Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as pf
from lsst.daf.butler import Butler
from lsst.ip.isr import IsrTask, IsrTaskConfig
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
from lsst.summit.utils.plotting import plot
%matplotlib inline

  "cipher": algorithms.TripleDES,
  "class": algorithms.TripleDES,


## Setting up the butler and collections from mtaos runs

In [3]:
# Run collection produced by an MTAOS WEP script execution. If it is no longer
# present in this butler, substitute another collection (see Logistics above).
collections = [
    'mtaos_wep_Script:102889_20240502T042949505',
]

# Butler rooted at the LSSTComCam repo, with `collections` as its defaults.
butler = Butler(
    '/repo/LSSTComCam',
    collections=collections,
)
registry = butler.registry

## Get PostISR and plot exposure

In [4]:
# Every postISRCCD dataset ref for detector 0 in the selected collections.
postisr = list(
    registry.queryDatasets('postISRCCD', collections=collections, detector=0)
)

# Exposure ids of those refs. The plotting cells below treat the smallest id
# as the intra-focal exposure and the largest as the extra-focal one.
exposure_ids = np.array([ref.dataId['exposure'] for ref in postisr])

# Load the first returned postISRCCD for display.
exposure = butler.get(postisr[0])

In [None]:
# Display the postISRCCD loaded above with the summit_utils plotting helper.
fig = plt.figure(figsize=(10, 10))
_ = plot(exposure, stretch='ccs', figure=fig)

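### Plot intra-focal postISRCCDs for each detector

The intra-focal exposure is taken to be the one with the smallest exposure id among the detector-0 postISRCCDs queried above.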

In [None]:
# Intra-focal postISRCCD of each of the 9 detectors on a 3x3 grid (row-major).
# The intra-focal exposure is taken to be the smallest id in `exposure_ids`.
intra_exposure = np.min(exposure_ids)

fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for selected_detector, ax in enumerate(axs.flat):
    list_of_postisrs = list(
        registry.queryDatasets(
            'postISRCCD',
            collections=collections,
            detector=selected_detector,
            exposure=intra_exposure,
        )
    )
    postisr_image = butler.get(list_of_postisrs[0]).image.array
    ax.imshow(postisr_image, norm=LogNorm(vmax=3e3), cmap='gray')
    ax.set_title(f'Detector {selected_detector + 1} - Intra focal')
    ax.set_xlabel('pixels')
    ax.set_ylabel('pixels')

fig.tight_layout()


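### Plot extra-focal postISRCCDs for each detector

The extra-focal exposure is taken to be the one with the largest exposure id among the detector-0 postISRCCDs queried above.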

In [None]:
# Extra-focal postISRCCD of each of the 9 detectors on a 3x3 grid (row-major).
# The extra-focal exposure is taken to be the largest id in `exposure_ids`.
extra_exposure = np.max(exposure_ids)

fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for selected_detector, ax in enumerate(axs.flat):
    list_of_postisrs = list(
        registry.queryDatasets(
            'postISRCCD',
            collections=collections,
            detector=selected_detector,
            exposure=extra_exposure,
        )
    )
    postisr_image = butler.get(list_of_postisrs[0]).image.array
    ax.imshow(postisr_image, norm=LogNorm(vmax=3e3), cmap='gray')
    ax.set_title(f'Detector {selected_detector + 1} - Extra focal')
    ax.set_xlabel('pixels')
    ax.set_ylabel('pixels')

fig.tight_layout()

## Donut Stamps

### Donut stamps object

In [7]:
selected_detector = 2

# Intra-focal donut stamp dataset refs for that detector.
list_stamps_cutout = list(
    registry.queryDatasets(
        "donutStampsIntra",
        collections=collections,
        detector=selected_detector,
    )
)
list_of_stamps = butler.get(list_stamps_cutout[0])

# Show the contents of the first stamp: its masked image (image, mask and
# variance planes) and the remaining stamp fields.
print(list_of_stamps[0])

DonutStamp(stamp_im=lsst.afw.image._maskedImage.MaskedImageF=(image=[[ -4.731903    18.62201     16.953766   ... -11.362457     3.6515808
    3.6523743 ]
 [ 48.97763     -1.0666504    5.605835   ... -26.04692     -6.0284424
   40.680237  ]
 [ 31.960602     5.2703247   13.610962   ...  23.661438    20.325958
    0.30908203]
 ...
 [ -6.791687    16.56192     -6.7924194  ...  13.240967    41.600067
   14.9105835 ]
 [ 11.279266    14.615204     7.9422607  ...  54.66574      6.2904663
   -5.385803  ]
 [  8.218292    19.894897     4.8812866  ...  14.905579     8.233765
   11.570801  ]],
mask=[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], maskPlaneDict={'BAD': 0, 'CR': 3, 'DETECTED': 5, 'DETECTED_NEGATIVE': 6, 'EDGE': 4, 'INTRP': 2, 'NO_DATA': 8, 'SAT': 1, 'SUSPECT': 7}
variance=[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ..

### Intra focal stamps

In [None]:
# One figure per detector showing all of its intra-focal donut stamps,
# laid out n_cols stamps per row.
n_cols = 25
for selected_detector in range(9):
    list_stamps_cutout = list(registry.queryDatasets("donutStampsIntra", collections=collections, detector = selected_detector))
    if not list_stamps_cutout:
        # Skip detectors without stamps instead of failing on [0].
        print(f"No donutStampsIntra found for Detector {selected_detector + 1}; skipping")
        continue
    list_of_stamps = butler.get(list_stamps_cutout[0])
    total_stamps = len(list_of_stamps)
    # Ceiling division gives exactly enough rows. The former
    # `total_stamps // 25 + 1` added an empty row whenever total_stamps was a
    # multiple of 25. Keep at least one row so the figure is still valid.
    n_rows = max(1, -(-total_stamps // n_cols))

    plt.figure(figsize = (17,4))
    for idx in range(total_stamps):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(list_of_stamps[idx].stamp_im.image.array)
        plt.axis('off')

    plt.suptitle(f"Intra stamps - Detector {selected_detector + 1}")
    plt.tight_layout()

### Extra focal stamps

In [None]:
# One figure per detector showing all of its extra-focal donut stamps,
# laid out n_cols stamps per row.
n_cols = 25
for selected_detector in range(9):
    list_stamps_cutout = list(registry.queryDatasets("donutStampsExtra", collections=collections, detector = selected_detector))
    if not list_stamps_cutout:
        # Skip detectors without stamps instead of failing on [0].
        print(f"No donutStampsExtra found for Detector {selected_detector + 1}; skipping")
        continue
    list_of_stamps = butler.get(list_stamps_cutout[0])
    total_stamps = len(list_of_stamps)
    # Ceiling division gives exactly enough rows. The former
    # `total_stamps // 25 + 1` added an empty row whenever total_stamps was a
    # multiple of 25. Keep at least one row so the figure is still valid.
    n_rows = max(1, -(-total_stamps // n_cols))

    plt.figure(figsize = (17,4))
    for idx in range(total_stamps):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(list_of_stamps[idx].stamp_im.image.array)
        plt.axis('off')

    plt.suptitle(f"Extra stamps - Detector {selected_detector + 1}")
    plt.tight_layout()

## Zernike estimates from WEP


In [None]:
detectors_label = ['R22_S00', 'R22_S01', 'R22_S02', 'R22_S10', 'R22_S11', 'R22_S12', 'R22_S20', 'R22_S21', 'R22_S22']
n_detectors = len(detectors_label)

# Load every zernikeEstimateAvg dataset first, so zk_avg can be sized from what
# the butler returns. The previous hard-coded (6, 9, 19) array raised IndexError
# when a detector had more than 6 datasets, and a broadcast error when an
# estimate did not have exactly 19 coefficients.
# Collections are passed explicitly, as in every other query in this notebook.
estimates = []  # estimates[detector] -> list of (run name, 1-D coefficient array)
for detector in range(n_detectors):
    refs = registry.queryDatasets('zernikeEstimateAvg', collections=collections, detector=detector)
    # np.ravel is a no-op for 1-D arrays and flattens a (1, N) array, so each
    # estimate is plotted as a single series.
    # NOTE(review): confirm the stored shape of zernikeEstimateAvg.
    estimates.append([(ref.run, np.ravel(butler.get(ref))) for ref in refs])

n_runs = max((len(per_detector) for per_detector in estimates), default=0)
n_terms = max((zk.size for per_detector in estimates for _, zk in per_detector), default=0)
# zk_avg[dataset index, detector, coefficient]; entries without data stay 0.
zk_avg = np.zeros((n_runs, n_detectors, n_terms))

# Create a figure and subplots with shared x and y axes
fig, axs = plt.subplots(3, 3, figsize=(15, 10), sharex=True, sharey=True)
for detector, ax in enumerate(axs.flatten()):
    for idx, (run, zk) in enumerate(estimates[detector]):
        zk_avg[idx, detector, :zk.size] = zk
        ax.plot(zk, '.', label=run)

    ax.set_title(f'Detector {detector + 1} - {detectors_label[detector]} ')
    if estimates[detector]:
        # Series are labelled by run; without a legend those labels are never shown.
        ax.legend(fontsize='x-small')

    # Label only the bottom row and the leftmost column
    if detector % 3 == 0:  # First column
        ax.set_ylabel('um')
    if detector >= 6:  # Bottom row
        # NOTE(review): x values are array positions 0..N-1, not Noll indices;
        # confirm which Noll index the first coefficient corresponds to.
        ax.set_xlabel('Zernike Noll Index')

plt.tight_layout()
plt.show()