In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from collections import defaultdict

import numpy as np

import lsst.daf.butler as dB
import lsst.cp.verify.notebooks.utils as utils
import lsst.afw.display as afwDisplay

In [None]:
# This cell contains parameters that can be automatically set via the papermill package.
# Examples:
#  Update parameters in input.ipynb, writing output.ipynb, but do not execute:
#   papermill --prepare-only -p calibType newBias -p cameraName LSSTCam <input> <output>
#  Disable interactive cells in input.ipynb, execute it, and write output.ipynb:
#   papermill -p interactive False <input> <output>
# Whether to run the interactive image-blinking cell (copied into `blinkResiduals`).
interactive = True

# Butler repository to read from.
repository = '/repo/main/'

# Dataset type of the calibration to analyse (passed directly to butler.get).
calibType = 'crosstalk'
# Detector whose calibration and ratio data are examined in detail.
detectorId = 0

# Instrument name the calibration is for.
cameraName = 'LATISS'

# afwDisplay backend used for the image display.
displayBackend = 'astrowidgets'

# Collection the calibration was constructed in.
genCollection = 'u/czw/DM-30170/crosstalkGen.20220322a'

# Collection containing the verification outputs.
verifyCollection = 'u/czw/DM-30170/verifyCrosstalk.20220322a'

In [None]:
# Open the repository with a default search path of the generation collection
# followed by the verification collection, then load the camera description
# for `cameraName`.
butler = dB.Butler(
    repository,
    collections=[genCollection, verifyCollection],
)
camera = butler.get('camera', instrument=cameraName)

In [None]:
# Fetch the crosstalk calibration under study (generation collection) and the
# residual crosstalk from the verification collection, then report the
# CALIBDATE header of each.
crosstalk = butler.get(
    calibType,
    instrument=cameraName,
    detector=detectorId,
    collections=genCollection,
)
residualCrosstalk = butler.get(
    'verifyCrosstalk',
    instrument=cameraName,
    detector=detectorId,
    collections=verifyCollection,
)

for _label, _calib in (("Crosstalk creation:", crosstalk),
                       ("Residual creation: ", residualCrosstalk)):
    _calibDate = _calib.getMetadata().toDict().get('CALIBDATE', "Bad header")
    print(_label, _calibDate)

In [None]:
# Load the run-level verification statistics and split off the overall
# 'SUCCESS' entry, leaving the remaining entries for the summary table below.
runStats = butler.get(
    'verifyCrosstalkStats',
    instrument=cameraName,
    collections=verifyCollection,
)
runSuccess = runStats.pop('SUCCESS')

In [None]:
# Display a summary table of the verification tests and their failure counts
# (runStats has had its overall 'SUCCESS' entry removed above).
utils.failureTable(runStats)

In [None]:
# Create the embedded image display (backend chosen by `displayBackend`) that
# the blinking cell below draws into.
# This cell may be easier to follow in a new view via the
#     "Create New View for Output" right-click menu.
afwDisplay.setDefaultBackend(displayBackend)
display = afwDisplay.Display(dims=(1000, 1000))
display.embed()

In [None]:
# This block allows the residual images to be scanned for concerns.
# For each 'verifyCrosstalkProc' dataset (sorted by exposure) it shows the
# matching 'cpCrosstalkProc' image, then the 'verifyCrosstalkProc' image,
# pausing at a prompt after each.  Only runs when `interactive` is True.
blinkResiduals = interactive
if blinkResiduals:
    continueDisplay = True
    skipNumber = 0
    # Query the datasets directly, because runStats doesn't retain exposure info.
    datasets = list(butler.registry.queryDatasets('verifyCrosstalkProc'))
    for datasetRef in sorted(datasets, key=lambda ds: ds.dataId['exposure']):
        if skipNumber > 0:
            skipNumber -= 1
            continue
        dataId = datasetRef.dataId
        original = butler.get('cpCrosstalkProc', dataId=dataId)
        residual = butler.get(datasetRef)
        detStats = {}

        display.mtv(original)
        display.scale('linear', 'zscale', None)
        continueDisplay, skipNumber = utils.interactiveBlock(
            f"Orig: {dataId['exposure']} {dataId['detector']}", detStats)
        # Stop immediately on a quit request, before the residual is shown.
        if continueDisplay is False:
            break

        display.mtv(residual)
        display.scale('linear', 'zscale', None)
        continueDisplay, skipResidual = utils.interactiveBlock(
            f"Resid: {dataId['exposure']} {dataId['detector']}", detStats)
        # Keep a skip request from either prompt, so the second prompt cannot
        # silently discard one made at the first.
        skipNumber = max(skipNumber, skipResidual)
        if continueDisplay is False:
            break

In [None]:
# Plot raw crosstalk values for the selected detector: the generated
# calibration (left) beside the residual calibration (right), as coefficients,
# significance, and validity flags.
detector = camera[detectorId]
detName = detector.getName()


def _plotCalibPair(leftImage, rightImage, leftTitle, rightTitle, suptitle,
                   makeNorm=None):
    """Show two matrices side by side, each with its own colorbar.

    Parameters
    ----------
    leftImage, rightImage : array-like
        Arrays passed to ``imshow``.
    leftTitle, rightTitle : `str`
        Per-panel titles.
    suptitle : `str`
        Figure title.
    makeNorm : callable, optional
        Zero-argument factory returning a new normalization for each panel,
        so the two panels never share one norm instance.  `None` uses the
        ``imshow`` default.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2,
                             sharex=True, sharey=False, figsize=[8.0, 4.0],
                             constrained_layout=True)
    for ax, image, title in zip(axes, (leftImage, rightImage),
                                (leftTitle, rightTitle)):
        norm = makeNorm() if makeNorm is not None else None
        pcm = ax.imshow(image, cmap='seismic', norm=norm)
        ax.set_title(title)
        fig.colorbar(pcm, ax=ax)
    fig.suptitle(suptitle)
    plt.show()


def _symLogNorm():
    # Linear for |x| < 1e-6, logarithmic beyond; color limits at +/-1e-2.
    return colors.SymLogNorm(linthresh=1e-6, linscale=1e-7,
                             vmin=-1e-2, vmax=1e-2)


def _significance(calib):
    # NOTE(review): this divides by sqrt(coeffNum), while the ratio-histogram
    # cell quotes coeffErr / sqrt(coeffNum) as the uncertainty, which would
    # make the significance coeffs * sqrt(coeffNum) / coeffErr.  Confirm which
    # is intended before changing the formula or the LogNorm range below.
    return np.abs(calib.coeffs / calib.coeffErr / np.sqrt(calib.coeffNum))


_plotCalibPair(crosstalk.coeffs, residualCrosstalk.coeffs,
               f"Crosstalk coefficients {detName}",
               f"Residual coefficients {detName}",
               "Measured Crosstalk terms",
               makeNorm=_symLogNorm)

_plotCalibPair(_significance(crosstalk), _significance(residualCrosstalk),
               f"Crosstalk significance {detName}",
               f"Residual significance {detName}",
               "Measured Crosstalk Significance",
               makeNorm=lambda: colors.LogNorm(vmin=1e-1, vmax=1e0))

# The right-hand panel shows the residual calibration, so it is titled as such.
_plotCalibPair(crosstalk.coeffValid, residualCrosstalk.coeffValid,
               f"Crosstalk valid {detName}",
               f"Residual valid {detName}",
               "Measured Crosstalk Validity")

In [None]:
# Load the per-exposure ratio and flux measurements for the selected detector
# from the generation collection and, for the residuals, the verification
# collection.  This cell is slow: it reads four datasets per exposure.
#
# fullSet / residualSet:     [targetAmp][sourceAmp] -> list of ratios
# fluxSet / fluxResidualSet: [sourceAmp] -> list of fluxes
#
# Only `detectorId` is loaded.  Later cells compare against `crosstalk` and
# `residualCrosstalk`, which belong to that detector.  These accumulators are
# keyed by amplifier name alone, so data from several detectors could not be
# told apart.
fullSet = defaultdict(lambda: defaultdict(list))
residualSet = defaultdict(lambda: defaultdict(list))
fluxSet = defaultdict(list)
fluxResidualSet = defaultdict(list)

ratioDataRefs = butler.registry.queryDatasets('cpCrosstalkRatio',
                                              instrument=cameraName,
                                              detector=detectorId,
                                              collections=genCollection)

nExposures = 0
for dataRef in ratioDataRefs:
    ratioData = butler.get('cpCrosstalkRatio', dataId=dataRef.dataId,
                           collections=genCollection)
    fluxData = butler.get('crosstalkFluxes', dataId=dataRef.dataId,
                          collections=genCollection)
    residualData = butler.get('verifyCrosstalkRatio', dataId=dataRef.dataId,
                              collections=verifyCollection)
    residualFluxData = butler.get('crosstalkFluxes', dataId=dataRef.dataId,
                                  collections=verifyCollection)
    nExposures += 1

    # Ratios are nested [targetDetector][sourceDetector][targetAmp][sourceAmp].
    # The detector levels are flattened away and values are pooled per amp pair.
    for accumulator, nested in ((fullSet, ratioData),
                                (residualSet, residualData)):
        for targetDetector, sourceData in nested.items():
            for sourceDetector, ratios in sourceData.items():
                for targetAmp, ampData in ratios.items():
                    for sourceAmp, values in ampData.items():
                        accumulator[targetAmp][sourceAmp].extend(values)

    # Fluxes are nested [targetDetector][sourceAmp].
    for accumulator, nested in ((fluxSet, fluxData),
                                (fluxResidualSet, residualFluxData)):
        for targetDetector, sourceData in nested.items():
            for sourceAmp, values in sourceData.items():
                accumulator[sourceAmp].extend(values)

print(f"Loaded crosstalk ratios from {nExposures} exposures for detector {detectorId}.")

In [None]:
# This is a helper function for trimming the far wings from the CT signal.
rejSigma = 2.0


def maskValues(inValues, nSigma=None, nIter=3):
    """Iteratively clip values far from the median, using an IQR-based sigma.

    Parameters
    ----------
    inValues : array-like
        Values to clip.
    nSigma : `float`, optional
        Rejection threshold in units of the robust sigma.  Defaults to the
        module-level ``rejSigma``.
    nIter : `int`, optional
        Maximum number of clipping passes.

    Returns
    -------
    values : `numpy.ndarray`
        The finite input values that survived clipping.  Non-finite entries
        are always dropped.  Clipping stops early when fewer than three values
        remain, when the IQR is zero, or when a pass removes nothing.
    """
    if nSigma is None:
        nSigma = rejSigma
    values = np.asarray(inValues)
    # NaN/inf would make every percentile NaN and so reject every value.
    values = values[np.isfinite(values)]
    for _ in range(nIter):
        if values.size < 3:
            break
        lo, med, hi = np.percentile(values, [25.0, 50.0, 75.0])
        # 0.741 ~= 1/1.349 converts a Gaussian interquartile range to sigma.
        sigma = 0.741 * (hi - lo)
        if sigma <= 0:
            # A zero IQR (e.g. constant data) would reject everything with a
            # strict '<' test; keep the values instead.
            break
        keep = np.abs(values - med) < nSigma * sigma
        if keep.all():
            break
        values = values[keep]
    return values

In [None]:
# Amplifier names of the selected detector.  Later cells use .index() on this
# list to turn amplifier names into indices of the coefficient matrices.  The
# detector is looked up explicitly, because `detector` may have been rebound
# by an earlier loop over the camera.
ordering = [amp.getName() for amp in camera[detectorId]]

In [None]:
# Plot pre- (blue) and post-correction (olive) ratio distributions for each
# source -> target amplifier pair, after trimming outliers with maskValues.
# Vertical lines mark the generated coefficient (cyan) and the residual
# coefficient (black).  Purple and grey mark coeffErr of each, drawn at
# x = coeffErr itself rather than mean + err.  Red marks zero.
# Panel text gives each coefficient and coeffErr / sqrt(coeffNum).
cdf = True
nCols = 4
nRows = max(1, int(np.ceil(len(ordering) / nCols)))

with np.errstate(invalid='ignore', divide='ignore'):
    for targetAmp, targetValues in fullSet.items():
        # squeeze=False keeps `axes` 2-d whatever the grid size.
        fig, axes = plt.subplots(nrows=nRows, ncols=nCols, squeeze=False,
                                 sharex=False, sharey=False, figsize=[12.0, 8.0],
                                 constrained_layout=True)
        targetIt = ordering.index(targetAmp)
        for axIterator, (sourceAmp, values) in enumerate(targetValues.items()):
            ax = axes[axIterator // nCols, axIterator % nCols]
            sourceIt = ordering.index(sourceAmp)

            CTmean = crosstalk.coeffs[sourceIt][targetIt]
            CTerr = crosstalk.coeffErr[sourceIt][targetIt]
            CTN = crosstalk.coeffNum[sourceIt][targetIt]
            CTValid = crosstalk.coeffValid[sourceIt][targetIt]

            CTResidMean = residualCrosstalk.coeffs[sourceIt][targetIt]
            CTResidErr = residualCrosstalk.coeffErr[sourceIt][targetIt]
            CTResidN = residualCrosstalk.coeffNum[sourceIt][targetIt]
            CTResidValid = residualCrosstalk.coeffValid[sourceIt][targetIt]

            residualValues = residualSet[targetAmp][sourceAmp]
            # In CDF mode use one bin per (untrimmed) measurement.
            nBin = int(len(values)) if cdf else 25
            if nBin <= 0:
                nBin = 25

            values = maskValues(values)
            residualValues = maskValues(residualValues)
            ax.hist(values, bins=nBin, color="blue",
                    density=cdf, histtype='step', cumulative=cdf)
            ax.hist(residualValues, bins=nBin, color="olive",
                    density=cdf, histtype='step', cumulative=cdf)

            ax.axvline(CTmean, c='cyan')
            ax.axvline(CTerr, c='purple')

            ax.axvline(CTResidMean, c='black')
            ax.axvline(CTResidErr, c='grey')
            ax.axvline(0.0, c='red')
            # Horizontal reference at 0.5 (the median level of a CDF).
            ax.axline((0.0, 0.5), slope=0.0)

            ax.text(0.1, 0.8,
                    f"{CTmean: 1.3e}\n{CTerr / np.sqrt(CTN): 1.3e}",
                    transform=ax.transAxes,
                    horizontalalignment='left')
            ax.text(0.9, 0.1, f"{CTResidMean: 1.3e}\n{CTResidErr / np.sqrt(CTResidN): 1.3e}",
                    transform=ax.transAxes,
                    horizontalalignment='right')
            # Fit the x-range to whichever trimmed samples exist.  np.min and
            # np.max raise on empty arrays, so an empty sample must not reach them.
            allValues = np.concatenate((np.ravel(values), np.ravel(residualValues)))
            if cdf and allValues.size > 0:
                ax.set_xlim(np.min(allValues), np.max(allValues))

            ax.set_title(f"G: {CTValid} R: {CTResidValid} for {sourceAmp} -> {targetAmp}")
        fig.suptitle(f"Crosstalk ratio data {sourceDetector} -> {targetDetector}")
        plt.show()

In [None]:
# Plot the pre-correction ratios against the source amplifier flux for every
# source -> target pair.  Horizontal lines mark the generated coefficient
# (cyan) and its coeffErr (purple).  Self (source == target) panels get only a
# placeholder.
nCols = 4
nRows = max(1, int(np.ceil(len(ordering) / nCols)))

for targetAmp, sourceValues in fullSet.items():
    # squeeze=False keeps `axes` 2-d whatever the grid size.
    fig, axes = plt.subplots(nrows=nRows, ncols=nCols, squeeze=False,
                             sharex=False, sharey=False, figsize=[8.0, 8.0],
                             constrained_layout=True)
    targetIt = ordering.index(targetAmp)
    for axIterator, (sourceAmp, values) in enumerate(sourceValues.items()):
        ax = axes[axIterator // nCols, axIterator % nCols]
        sourceIt = ordering.index(sourceAmp)

        CTmean = crosstalk.coeffs[sourceIt][targetIt]
        CTerr = crosstalk.coeffErr[sourceIt][targetIt]

        if sourceAmp == targetAmp:
            ax.plot([0.0], [0.0])
        else:
            # np.min/np.max fail on empty input, and NaN would make the
            # limits invalid, so derive the y-range from finite values only.
            finiteValues = np.asarray(values, dtype=float)
            finiteValues = finiteValues[np.isfinite(finiteValues)]
            if finiteValues.size > 0:
                ax.set_ylim(finiteValues.min(), finiteValues.max())
            ax.scatter(fluxSet[sourceAmp], values, s=.1)
            ax.axline((0.0, CTmean), slope=0.0, c='cyan')
            ax.axline((0.0, CTerr), slope=0.0, c='purple')
        ax.set_title(f"{sourceAmp} -> {targetAmp}")
    fig.suptitle(f"Ratio vs. flux: {sourceDetector} -> {targetDetector}")
    plt.show()

In [None]:
# Allow particular pairs to be plotted.
def do_a_plot(sourceAmp, targetAmps, xlim=(0.0, 130000), ylim=(-5e-4, 5e-4)):
    """Scatter ratio against source flux for one source and several targets.

    Generation-collection measurements are drawn in the default color and
    verification (residual) measurements in red.  Horizontal lines mark the
    generated coefficient (cyan), its coeffErr (purple) and zero (indigo,
    dotted).

    Parameters
    ----------
    sourceAmp : `str`
        Name of the source amplifier.
    targetAmps : `str` or iterable of `str`
        Target amplifier name(s); one panel is drawn per target.
    xlim, ylim : `tuple` [`float`, `float`], optional
        Axis limits applied to every panel.
    """
    if isinstance(targetAmps, str):
        targetAmps = [targetAmps]
    else:
        targetAmps = list(targetAmps)
    if not targetAmps:
        return

    nPanels = len(targetAmps)
    # squeeze=False keeps `axes` 2-d even for a single target, so the loop
    # below works for one panel as well as many.
    fig, axes = plt.subplots(nrows=1, ncols=nPanels, squeeze=False,
                             sharex=False, sharey=False,
                             figsize=[8.0, 8.0 / nPanels],
                             constrained_layout=True)

    sourceIt = ordering.index(sourceAmp)
    for ax, target in zip(axes[0], targetAmps):
        targetIt = ordering.index(target)

        CTmean = crosstalk.coeffs[sourceIt][targetIt]
        CTerr = crosstalk.coeffErr[sourceIt][targetIt]
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
        ax.scatter(fluxSet[sourceAmp],
                   fullSet[target][sourceAmp],
                   s=.1)
        # Residuals are indexed by this panel's `target`.
        ax.scatter(fluxResidualSet[sourceAmp],
                   residualSet[target][sourceAmp],
                   s=.1, c='r')
        ax.axline((0.0, CTmean), slope=0.0, c='cyan')
        ax.axline((0.0, CTerr), slope=0.0, c='purple')

        ax.axline((0.0, 0.0), slope=0.0, dashes=(1, 4), c='indigo')
        ax.set_title(f"Measurements: {sourceAmp} -> {target}")
    plt.show()

In [None]:
# Example: amplifier C14 as the source, with three target amplifiers.
do_a_plot("C14", ["C13", "C15", "C04"])

In [None]:
# Exploratory (disabled): cumulative distributions of the C14 -> C13 ratios,
# split into bins of source flux, followed by a ratio-vs-flux scatter.
if False:
    source = 'C14'
    target = 'C13'
    F = np.array(fluxSet[source])
    Q = np.array(fullSet[target][source])
    # Open flux intervals: below 40000, then 10000-wide bins up to 120000.
    fluxBins = [(-np.inf, 40000)] + [(lo, lo + 10000) for lo in range(40000, 120000, 10000)]
    for lo, hi in fluxBins:
        selected = Q[np.logical_and(F > lo, F < hi)]
        # An empty bin has no distribution to draw.
        if selected.size == 0:
            continue
        plt.hist(selected, bins=100, cumulative=True, histtype='step',
                 density=True, label=f"{lo} < F < {hi}")

    plt.axhline(0.5, dashes=(1, 4), c='indigo')
    plt.axvline(0.0, dashes=(1, 4), c='indigo')
    plt.legend()
    plt.show()

    plt.scatter(F, Q)
    plt.show()

In [None]:
# Exploratory (disabled): ratio vs. source flux for one exposure on detector 0,
# for source amplifier C04 into target amplifier C05.
if False:
    query = dict(instrument=cameraName, detector=0,
                 exposure=2023021500266, collections=genCollection)
    ratioData = butler.get('cpCrosstalkRatio', **query)
    fluxData = butler.get('crosstalkFluxes', **query)

    F = fluxData['RXX_S00']['C04']
    R = ratioData['RXX_S00']['RXX_S00']['C05']['C04']

    plt.scatter(F, R)