# Querying butler WEP outputs

Owner: **Guillem Megias** ([@gmegh](https://github.com/lsst-ts/ts_aos_analysis/issues/new?body=@gmegh)) 

Last Verified to Run: **2024-08-22**

Software Versions:
* lsst_distrib: **w_2024_32**

Requirements:
* [summit_utils](https://github.com/lsst-sitcom/summit_utils)


## Notebook Objective
The goal of this notebook is to show the user how to query and access basic Wavefront Estimation Pipeline (WEP) outputs from the butler. 

## Logistics
This notebook can be run from USDF or the Summit. The current version uses collections present in the Summit butler (`/repo/LSSTComCam`); if they have been wiped by the time you run this notebook, you will need to use a different collection and/or the USDF butler.

If running from USDF, the following butler and collection are recommended as an example:
* collection: 'sitcomtn-135/directDetectTimingTest_SourceLimit5'
* butler: '/sdf/data/rubin/repo/aos_imsim'

## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as pf
from lsst.daf.butler import Butler
from lsst.ip.isr import IsrTask, IsrTaskConfig
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
from lsst.summit.utils.plotting import plot
%matplotlib inline

## Setting up the butler and collections from MTAOS runs

In [None]:
# Open the ComCam repository with raw data, calibrations and the quickLook
# collection as defaults. Earlier MTAOS run collections (for example
# 'u/saluser/ra_wep_testing4') can be swapped in here instead.
defaultCollections = ["LSSTComCam/raw/all", "LSSTComCam/calib", "LSSTComCam/quickLook"]
butler = Butler('/repo/LSSTComCam', collections=defaultCollections)
registry = butler.registry

In [None]:
# Registry.queryDataIds requires the dimensions to enumerate; calling it with
# no arguments raises TypeError. Restrict the query to one night and show only
# the first few data IDs so the output stays readable.
dataIds = registry.queryDataIds(["exposure", "detector"], instrument="LSSTComCam", day_obs=20241112)
list(dataIds)[:10]

In [None]:
# Count the zernikeEstimateAvg datasets for a single visit in `collections`.
collections = ["LSSTComCam/raw/all", "LSSTComCam/calib", "LSSTComCam/quickLook"]
butler = Butler('/repo/LSSTComCam', collections=collections)
registry = butler.registry
dayObs = 20241112
visit = 2024111200266
name = 'zernikeEstimateAvg'
# queryDatasetTypes accepts a dataset type name directly, so there is no need
# to scan every registered type; if `name` is not registered the loop body
# simply does not run and nothing is printed.
for dtype in registry.queryDatasetTypes(name):
    datasetRefs = list(registry.queryDatasets(datasetType=dtype, collections=collections, day_obs=dayObs, visit=visit))
    print(len(datasetRefs), dtype)


In [None]:
# List every dataset type with at least one dataset on `dayObs` in `collections`.
registry = butler.registry
dayObs = 20241112
for dtype in registry.queryDatasetTypes():
    try:
        datasetRefs = list(registry.queryDatasets(datasetType=dtype, collections=collections, day_obs=dayObs))
    except Exception as err:
        # Some dataset types fail to query with this constraint; report why and
        # keep going. (A bare `except:` also trapped KeyboardInterrupt, which
        # made this long loop impossible to interrupt.)
        print("Error", dtype, f"({type(err).__name__}: {err})")
        continue
    if datasetRefs:
        print(len(datasetRefs), dtype)

In [None]:
# Display the name of the last dataset type visited by the loop above.
lastDatasetTypeName = dtype.name
lastDatasetTypeName

## Get PostISR and plot exposure

In [None]:
# Every postISRCCD ref for detector 0 in the selected collections, the exposure
# IDs they cover, and the image behind the first ref returned.
postisr = [ref for ref in registry.queryDatasets('postISRCCD', collections=collections, detector=0)]
exposure_ids = np.array([ref.dataId['exposure'] for ref in postisr])
firstPostIsrRef = postisr[0]
exposure = butler.get(firstPostIsrRef)

In [None]:
# Full-frame view of the exposure loaded above, drawn with the summit_utils
# `plot` helper using the 'ccs' stretch.
fig = plt.figure(figsize=(10, 10))
_ = plot(exposure, stretch='ccs', figure=fig)

### Plot intra focal postISRCCDs for each detector

In [None]:
# 3x3 grid with the intra-focal postISRCCD image of each of the 9 detectors.
# The intra-focal exposure is taken to be the smallest exposure ID found above.
# NOTE(review): this assumes the collections hold a single intra/extra pair;
# with many nights of data, min/max are not a matched pair — confirm.
intraExposure = np.min(exposure_ids)
plt.figure(figsize=(10, 10))
for panel, selected_detector in enumerate(range(9), start=1):
    plt.subplot(3, 3, panel)
    list_of_postisrs = list(
        registry.queryDatasets('postISRCCD', collections=collections,
                               detector=selected_detector, exposure=intraExposure)
    )
    pixels = butler.get(list_of_postisrs[0]).image.array
    plt.imshow(pixels, norm=LogNorm(vmax=3e3), cmap='gray')
    plt.title(f'Detector {selected_detector + 1} - Intra focal')
    plt.xlabel('pixels')
    plt.ylabel('pixels')
plt.tight_layout()


### Plot extra focal postISRCCDs for each detector

In [None]:
# 3x3 grid with the extra-focal postISRCCD image of each of the 9 detectors.
# The extra-focal exposure is taken to be the largest exposure ID found above.
# NOTE(review): this assumes the collections hold a single intra/extra pair;
# with many nights of data, min/max are not a matched pair — confirm.
extraExposure = np.max(exposure_ids)
plt.figure(figsize=(10, 10))
for panel, selected_detector in enumerate(range(9), start=1):
    plt.subplot(3, 3, panel)
    list_of_postisrs = list(
        registry.queryDatasets('postISRCCD', collections=collections,
                               detector=selected_detector, exposure=extraExposure)
    )
    pixels = butler.get(list_of_postisrs[0]).image.array
    plt.imshow(pixels, norm=LogNorm(vmax=3e3), cmap='gray')
    plt.title(f'Detector {selected_detector + 1} - Extra focal')
    plt.xlabel('pixels')
    plt.ylabel('pixels')
plt.tight_layout()

## Donut Stamps

### Donut stamps object

In [None]:
# Extra-focal donut-stamp datasets for one detector of one visit.
selected_detector = 4
visit = 2024102900053
list_stamps_cutout = list(
    registry.queryDatasets(
        "donutStampsExtra",
        collections=collections,
        detector=selected_detector,
        visit=visit,
    )
)

In [None]:
# Display the first DatasetRef returned by the donut-stamp query above.
list_stamps_cutout[0]

In [None]:
# Inspect all data ID values attached to that ref; the next cell indexes into
# this sequence positionally, so check the ordering here.
list_stamps_cutout[0].dataId.full_values

In [None]:
# Print the third data ID value of every ref.
# NOTE(review): positional indexing depends on the ordering of full_values —
# presumably this is the visit; confirm against the previous cell's output
# (ref.dataId['visit'] would be explicit).
for stampRef in list_stamps_cutout:
    thirdValue = stampRef.dataId.full_values[2]
    print(thirdValue)

In [None]:
# Load the first donutStampsIntra dataset found for detector 2 (no visit
# constraint) and print its first stamp to show what a stamp contains.
selected_detector = 2
list_stamps_cutout = list(registry.queryDatasets("donutStampsIntra", collections=collections, detector=selected_detector))
firstStampsRef = list_stamps_cutout[0]
list_of_stamps = butler.get(firstStampsRef)
print(list_of_stamps[0])

### Intra focal stamps

In [None]:
# For each detector, plot every stamp of the first donutStampsIntra dataset
# returned, in rows of `n_cols` stamps.
n_cols = 25
for selected_detector in range(9):
    list_stamps_cutout = list(registry.queryDatasets("donutStampsIntra", collections=collections, detector=selected_detector))
    if not list_stamps_cutout:
        print(f"No donutStampsIntra found for detector {selected_detector}; skipping.")
        continue
    list_of_stamps = butler.get(list_stamps_cutout[0])
    total_stamps = len(list_of_stamps)
    # Ceiling division: `total_stamps // n_cols + 1` added an empty row
    # whenever the count was an exact multiple of n_cols.
    n_rows = max(1, -(-total_stamps // n_cols))

    plt.figure(figsize=(17, 4))
    for idx in range(total_stamps):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(list_of_stamps[idx].stamp_im.image.array)
        plt.axis('off')

    plt.suptitle(f"Intra stamps - Detector {selected_detector + 1}")
    plt.tight_layout()

### Extra focal stamps

In [None]:
# For each detector, plot every stamp of the first donutStampsExtra dataset
# returned, in rows of `n_cols` stamps.
n_cols = 25
for selected_detector in range(9):
    list_stamps_cutout = list(registry.queryDatasets("donutStampsExtra", collections=collections, detector=selected_detector))
    if not list_stamps_cutout:
        print(f"No donutStampsExtra found for detector {selected_detector}; skipping.")
        continue
    list_of_stamps = butler.get(list_stamps_cutout[0])
    total_stamps = len(list_of_stamps)
    # Ceiling division: `total_stamps // n_cols + 1` added an empty row
    # whenever the count was an exact multiple of n_cols.
    n_rows = max(1, -(-total_stamps // n_cols))

    plt.figure(figsize=(17, 4))
    for idx in range(total_stamps):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(list_of_stamps[idx].stamp_im.image.array)
        plt.axis('off')

    plt.suptitle(f"Extra stamps - Detector {selected_detector + 1}")
    plt.tight_layout()

## Zernike estimates from WEP


In [None]:
from pathlib import Path

# One subplot per detector showing Noll 4-10 from every zernikeEstimateAvg
# dataset found (queried with the butler's default collections). All estimates
# are also collected into `zk_avg` with shape (dataset, detector, term).
detectors_label = ['R22_S00', 'R22_S01', 'R22_S02', 'R22_S10', 'R22_S11', 'R22_S12', 'R22_S20', 'R22_S21', 'R22_S22']

# Query first so `zk_avg` can be sized from what exists; the previous fixed
# (11, 9, 25) array raised IndexError as soon as a detector had >11 datasets.
refs_by_detector = [
    list(registry.queryDatasets('zernikeEstimateAvg', detector=detector))
    for detector in range(len(detectors_label))
]
n_datasets = max((len(refs) for refs in refs_by_detector), default=0)
zk_avg = None  # allocated once the number of Zernike terms is known

# Create a figure and subplots with shared x and y axes
fig, axs = plt.subplots(3, 3, figsize=(15, 10), sharex=True, sharey=True)
xaxis = list(range(4, 11))
for detector, (ax, refs) in enumerate(zip(axs.flatten(), refs_by_detector)):
    for idx, element in enumerate(refs):
        data = butler.get(element)
        if zk_avg is None:
            # NaN marks (dataset, detector) slots with no estimate, so they
            # cannot be mistaken for genuine zero coefficients.
            zk_avg = np.full((n_datasets, len(detectors_label), np.shape(data)[-1]), np.nan)
        zk_avg[idx, detector, :] = data
        ax.plot(xaxis, data[0][0:len(xaxis)], '.', label=element.run)

    ax.set_title(f'Detector {detector + 1} - {detectors_label[detector]} ')

    # Label only the bottom row and the leftmost column
    if detector % 3 == 0:  # First column
        ax.set_ylabel('um')
    if detector >= 6:  # Bottom row
        ax.set_xlabel('Zernike Noll Index')

plt.tight_layout()
# Write under the current user's home instead of a hard-coded /home/<user> path.
output_dir = Path.home() / "MTAOS" / "zernikes"
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / "ComCam_Zernikes_24Oct24.png")

In [None]:
# Show the first row of the last zernikeEstimateAvg dataset loaded in the loop above.
data[0]

## Plot the donuts for a single visit pair

In [None]:
from pathlib import Path

# Side-by-side intra- and extra-focal donut (the first stamp of each) for one
# detector of one visit in the quickLook collection.
collections = ['LSSTComCam/quickLook']
butler = Butler('/repo/LSSTComCam', collections=collections)
registry = butler.registry
selected_detector = 5
visit = 2024111100285

fig = plt.figure(figsize=(8, 4))
plt.suptitle(f"Simonyi ComCam Donuts {visit}")
# The extra-focal query runs last, so list_stamps_cutout / list_of_stamps hold
# the extra-focal results afterwards (the metadata cells below use them).
for panel, (side, title) in enumerate([("Intra", "Intra-Focus"), ("Extra", "Extra-Focus")], start=1):
    list_stamps_cutout = list(registry.queryDatasets(f"donutStamps{side}", collections=collections,
                                                     detector=selected_detector, visit=visit))
    print(f"donutStamps{side}: {len(list_stamps_cutout)} dataset(s)")
    if not list_stamps_cutout:
        raise LookupError(f"No donutStamps{side} for visit {visit}, detector {selected_detector} in {collections}")
    list_of_stamps = butler.get(list_stamps_cutout[0])
    plt.subplot(1, 2, panel)
    plt.title(title)
    plt.imshow(list_of_stamps[0].stamp_im.image.array)

# Write under the current user's home instead of a hard-coded /home/<user> path.
output_dir = Path.home() / "MTAOS" / "zernikes"
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / f"ComCam_Donuts_{visit}_{selected_detector}.png")

In [None]:
# Reload the stamps behind the first ref of the most recent stamp query (the
# extra-focal one in the previous cell) so their metadata can be inspected.
stampsRef = list_stamps_cutout[0]
test = butler.get(stampsRef)

In [None]:
# Metadata object attached to the stamps loaded in the previous cell.
md = test.metadata

In [None]:
# Dump every metadata entry as "key value".
for mdKey in md.keys():
    mdValue = md[mdKey]
    print(mdKey, mdValue)

## Plot the Zernikes for a single visit pair

In [None]:
%matplotlib inline
#collections = ['u/saluser/ra_wep_testing4']
collections = ['LSSTComCam/quickLook']
butler = Butler('/repo/LSSTComCam', collections=collections)
registry = butler.registry
detector=4
visit = 2024111100285
list_of_collection_estimates = list(registry.queryDatasets('zernikeEstimateAvg', detector = detector, visit=visit))
element = list_of_collection_estimates[0]
data = butler.get(element)  

# Now plot it
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
xaxis = list(range(4, 16))
ax.plot(xaxis, data[0][0:12], marker='x', ls='',ms=10, mew=4)
ax.set_xlabel('Zernike Noll Index', fontsize=12)
ax.set_xticks(xaxis)
ax.set_ylabel('Zernike value(um)', fontsize=12)
ax.set_title(f"Zernike Plot {visit} Detector {detector}", fontsize=14)
plt.savefig(f"/home/cslage/MTAOS/zernikes/ComCam_Zernikes_{visit}_{detector}.png")

In [None]:
%matplotlib inline
collectionsA = ['u/saluser/ra_wep_testing4']
collectionsB = ['LSSTComCam/raw/all','LSSTComCam/calib','LSSTComCam/quickLook']
collectionsList = [collectionsA, collectionsB]
#butler = Butler('/repo/LSSTComCam', collections=collections)
#registry = butler.registry
detector=4
visits = [2024110500201, 2024111100286]
seqNums = [201, 266]
stamps = ['Intra', 'Extra']
markers = ['x', '+']
colors = ['red', 'green']
xaxis = list(range(4, 16))
fig = plt.figure(figsize=(12,10))
grid = plt.GridSpec(4, 4, wspace=0, hspace=0.2)
axs = []
axs.append(plt.subplot(grid[0:2, 0:2]))
axs.append(plt.subplot(grid[0, 2]))

axs.append(plt.subplot(grid[0, 3]))
axs.append(plt.subplot(grid[1, 2]))
axs.append(plt.subplot(grid[1, 3]))
for ax in axs[1:5]:
    ax.set_xticks([])
    ax.set_yticks([])

for i in range(2):
    butler = Butler('/repo/LSSTComCam', collections=collectionsList[i])
    registry = butler.registry

    list_of_collection_estimates = list(registry.queryDatasets('zernikeEstimateAvg', detector = detector, visit=visits[i]))
    element = list_of_collection_estimates[0]
    data = butler.get(element)  
    axs[0].plot(xaxis, data[0][0:12], marker=markers[i], color=colors[i], ls='',ms=10, mew=4, label=visits[i])
    axs[0].legend()
    axs[0].set_xlabel('Zernike Noll Index', fontsize=12)
    axs[0].set_xticks(xaxis)
    axs[0].axhline(0.0, ls='--', color='black')
    axs[0].set_ylabel('Zernike value(um)', fontsize=12)
    axs[0].set_ylim(-3,4)

plotCounter = 1
for i in range(2):
    for j in range(2):
        list_stamps_cutout = list(registry.queryDatasets(f"donutStamps{stamps[j]}", collections=collectionsList[i], \
                                                         detector = detector, visit=visits[i]))
        list_of_stamps = butler.get(list_stamps_cutout[0])
        axs[plotCounter].imshow(list_of_stamps[0].stamp_im.image.array)
        axs[plotCounter].set_title(f"{seqNums[i]}-{stamps[j]}")
        plotCounter += 1

plt.savefig(f"/home/cslage/MTAOS/zernikes/ComCam_Zernikes_Comparison_12Nov24.png")

In [None]:
%matplotlib inline
collectionsA = ['u/saluser/ra_wep_testing4']
collectionsB = ['LSSTComCam/quickLook']
collectionsList = [collectionsA, collectionsB]
butler = Butler('/repo/LSSTComCam', collections=collections)
registry = butler.registry
detector=5
visits = [2024110500201, 2024110800242]
seqNums = [201, 242]
stamps = ['Intra', 'Extra']
markers = ['x', '+']
colors = ['red', 'green']
xaxis = list(range(4, 16))
fig, ax = plt.subplots(1,1,figsize=(5,5))
ax.set_title("Zernike Comparison")

for i in range(2):
    butler = Butler('/repo/LSSTComCam', collections=collectionsList[i])
    registry = butler.registry

    list_of_collection_estimates = list(registry.queryDatasets('zernikeEstimateAvg', detector = detector, visit=visits[i]))
    element = list_of_collection_estimates[0]
    data = butler.get(element)  
    ax.plot(xaxis, data[0][0:12], marker=markers[i], color=colors[i], ls='',ms=10, mew=4, label=visits[i])
    ax.legend()
    ax.set_xlabel('Zernike Noll Index', fontsize=12)
    ax.set_xticks(xaxis)
    ax.set_ylabel('Zernike value(um)', fontsize=12)
    ax.set_ylim(-4,4)
    ax.axhline(0.0, ls='--', color='black')
    


plt.savefig(f"/home/cslage/MTAOS/zernikes/ComCam_Zernikes_Comparison_08Nov24.png")