In [None]:
# Which version of the Stack am I using?
# IPython shell escape: list the EUPS products that are currently set up and keep the lsst_distrib line.
!eups list -s | grep lsst_distrib

In [None]:
import numpy as np
import json
import os
from astropy.table import Table
from astropy import units as u
import matplotlib.pyplot as plt
%matplotlib widget

from lsst.afw.table import SimpleCatalog, GroupView
import lsst.verify
import lsst.daf.butler as dafButler
from lsst.validate.drp.calcsrd.tex import (correlation_function_ellipticity_from_matches,
                                           select_bin_from_corr)
from metric_pipeline_utils.filtermatches import filterMatches

In [None]:
# Tracts / bands / metrics to compare.
# The commented-out lines list the full set; the active lines restrict the run to one tract and band.
#tract_array = [9813, 9615, 9697]
#band_array = ['g', 'r', 'i']
#metric_array = ['TE1', 'TE2']
tract_array = [9813]
band_array = ['i']
metric_array = ['TE1', 'TE2']

In [None]:
def getGen2Measurement(band, tract, metric,
                       rerun_dir='/datasets/hsc/repo/rerun/RC/w_2020_34/DM-26441'):
    """Load one validate_drp (Gen2) metric measurement from its matchedVisit JSON job file.

    Parameters
    ----------
    band : `str`
        Single-letter band name (e.g. 'i'); upper-cased to form the HSC filter name ``HSC-I``.
    tract : `int`
        Tract number.
    metric : `str`
        Metric name without the ``validate_drp.`` prefix (e.g. 'TE1').
    rerun_dir : `str`, optional
        Root of the Gen2 rerun holding the ``validateDrp`` outputs. The default is the
        w_2020_34 RC2 rerun this notebook was written against.

    Returns
    -------
    measurement
        The ``validate_drp.<metric>`` entry of the deserialized job's measurements. It
        exposes ``.quantity`` and ``.extras`` (as used by the plotting code).
    """
    filter_name = 'HSC-%s' % band.upper()
    infile = os.path.join(rerun_dir, 'validateDrp', 'matchedVisitMetrics', str(tract),
                          filter_name, 'matchedVisit_%s.json' % filter_name)
    with open(infile) as f:
        job = lsst.verify.Job.deserialize(**json.load(f))
    return job.measurements['validate_drp.%s' % metric]

In [None]:
%%time
# This takes several minutes to run
gen2_measurement_dict = {}
for tract in tract_array:
    for band in band_array:
        for metric in metric_array:
            measurement = getGen2Measurement(band, tract, metric)
            print(band, tract, metric, measurement)
            gen2_measurement_dict['%s_%s_%s'%(band, tract, metric)] = measurement

In [None]:
#gen2_measurement['i_9813_TE1'].extras['radius'].quantity
#gen2_measurement['i_9813_TE1'].extras['xip'].quantity

In [None]:
repo = '/project/hsc/gen3repo/rc2w34_ssw36/'
config = os.path.join(repo, 'butler.yaml')
try:
    butler_gen3 = dafButler.Butler(config=config)
except ValueError as e:
    # Report which repo failed, then re-raise: every later cell needs `butler_gen3`, so
    # continuing silently would only resurface as a confusing NameError further down.
    print('Could not construct a Gen3 Butler from %s: %s' % (config, e))
    raise

In [None]:
def getGen3Measurement(butler, band, tract, metric):
    """Fetch the Gen3 ``metricvalue_validate_drp_<metric>`` dataset for one tract/band.

    The dataset is read from the per-tract collection ``kbechtol/svv_<tract>_gri_matched``
    with an HSC / ``hsc_rings_v1`` data ID, where ``band`` is the ``abstract_filter`` value.
    Returns whatever ``butler.get`` returns; callers read ``.quantity`` from it.
    """
    dataset_type = 'metricvalue_validate_drp_%s'%(metric)
    data_id = dict(tract=tract, abstract_filter=band, instrument='HSC', skymap='hsc_rings_v1')
    return butler.get(dataset_type,
                      collections='kbechtol/svv_%s_gri_matched'%(tract),
                      dataId=data_id)

In [None]:
# Fetch every metric in metric_array for each configured tract and band from the Gen3 repo.
gen3_measurement_dict = {}
for tract in tract_array:
    for band in band_array:
        for metric in metric_array:
            measurement_key = '%s_%s_%s'%(band, tract, metric)
            measurement = getGen3Measurement(butler_gen3, band, tract, metric)
            print(measurement)
            gen3_measurement_dict[measurement_key] = measurement

In [None]:
def getGen3MatchedCat(butler, band, tract, applyFilter=True):
    """Fetch the Gen3 ``matchedCatalogTract`` dataset for one tract/band.

    Parameters
    ----------
    butler
        Gen3 butler to read from.
    band : `str`
        Single-letter band, used as the ``abstract_filter`` data ID value.
    tract : `int`
        Tract number; also selects the ``kbechtol/svv_<tract>_gri_matched`` collection.
    applyFilter : `bool`, optional
        If True (default), return ``filterMatches(catalog)``. Otherwise return the
        unfiltered catalog grouped with ``GroupView.build``.
    """
    collection = 'kbechtol/svv_%s_gri_matched'%(tract)
    dataid = {'tract':tract, 'abstract_filter':band, 'instrument':'HSC', 'skymap':'hsc_rings_v1'}
    # Read through the butler that was passed in (previously the global `butler_gen3` was used
    # and this argument was silently ignored).
    matchedCatalog = butler.get('matchedCatalogTract', collections=collection, dataId=dataid)
    if applyFilter:
        return filterMatches(matchedCatalog)
    return GroupView.build(matchedCatalog)

In [None]:
def getGen2MatchedCat(band, tract, applyFilter=True):
    """Read a validate_drp (Gen2) matched catalog from FITS for one tract/band.

    Returns ``filterMatches(catalog)`` when ``applyFilter`` is true. Otherwise returns the
    unfiltered catalog grouped with ``GroupView.build``.
    """
    # Files under the ``no_ext_calib`` directory. An earlier version read
    # /project/jcarlin/matched_cats/RC2_tract%s_HSC-%s_matched_cat_validateDrp.fits instead.
    infile = '/project/jcarlin/matched_cats/no_ext_calib/RC2_tract%s_HSC-%s_matched_cat_validateDrp.fits'%(tract, band.upper())
    matchedCatalog = SimpleCatalog.readFits(infile)
    builder = filterMatches if applyFilter else GroupView.build
    return builder(matchedCatalog)

In [None]:
%%time
# This takes ~10 minutes to run
gen3_filtered_cat_dict = {}
for tract in tract_array:
    for band in band_array:
        print('%s_%s'%(band, tract))
        #collection = 'kbechtol/svv_%s_gri_matched'%(tract)
        #dataid = {'tract':tract, 'abstract_filter':band, 'instrument':'HSC', 'skymap':'hsc_rings_v1'}
        #matchedCatalog = butler_gen3.get('matchedCatalogTract', collections=collection, dataId=dataid)
        #print(len(matchedCatalog))
        #filteredCat = filterMatches(matchedCatalog, snrMin=50)
        #print(len(filteredCat))
        filteredCat = getGen3MatchedCat(butler_gen3, band, tract, applyFilter=True)
        gen3_filtered_cat_dict['%s_%s'%(band, tract)] = filteredCat

In [None]:
%%time
# This takes ~10 minutes to run
gen2_filtered_cat_dict = {}
for tract in tract_array:
    for band in band_array:
        print('%s_%s'%(band, tract))
        filteredCat = getGen2MatchedCat(band, tract, applyFilter=True)
        gen2_filtered_cat_dict['%s_%s'%(band, tract)] = filteredCat

In [None]:
%%time
# This takes ~10 minutes to run
gen3_matched_cat_dict = {}
for tract in tract_array:
    for band in band_array:
        print('%s_%s'%(band, tract))
        matchedCat = getGen3MatchedCat(butler_gen3, band, tract, applyFilter=False)
        gen3_matched_cat_dict['%s_%s'%(band, tract)] = matchedCat

In [None]:
%%time
# This takes ~10 minutes to run
gen2_matched_cat_dict = {}
for tract in tract_array:
    for band in band_array:
        print('%s_%s'%(band, tract))
        matchedCat = getGen2MatchedCat(band, tract, applyFilter=False)
        gen2_matched_cat_dict['%s_%s'%(band, tract)] = matchedCat

In [None]:
def plotXi(radius, xip, xip_err, color, label=None):
    """Draw a correlation function so that both signs show up on log axes.

    Positive values are drawn as filled markers. Negative values are mirrored to
    ``-xip`` and drawn as hollow markers with a dashed style. Each point carries its
    error bar only in the series where it appears. ``radius`` must provide ``.value``
    (e.g. an astropy Quantity). ``label`` is attached to the positive series only.
    """
    shared = dict(marker='o', c='none', mec=color, ecolor=color, barsabove=True)
    plt.errorbar(radius.value, xip, yerr=np.where(xip > 0, xip_err, 0),
                 mfc=color, label=label, **shared)
    plt.errorbar(radius.value, -1 * xip, yerr=np.where(xip < 0, xip_err, 0),
                 ls='--', mfc='none', **shared)

In [None]:
def plotCorrelationFunctionComparison(butler, band, tract, 
                                      gen2_filtered_cat=None, gen3_filtered_cat=None, 
                                      gen2_measurement=None, gen3_measurement=None):
    """Overplot Gen2 (validate_drp) and Gen3 residual-ellipticity correlation functions.

    Parameters
    ----------
    butler
        Unused; kept for call compatibility.
    band : `str`
        Band name used in the dictionary keys (e.g. 'i').
    tract : `int`
        Tract number used in the dictionary keys.
    gen2_filtered_cat, gen3_filtered_cat : `dict`, optional
        Filtered matched catalogs keyed by ``'<band>_<tract>'``. The correlation function is
        recomputed from each with ``correlation_function_ellipticity_from_matches``.
    gen2_measurement, gen3_measurement : `dict`, optional
        Measurements keyed by ``'<band>_<tract>_<metric>'`` for TE1 and TE2. For Gen2, the
        TE1 ``radius``/``xip``/``xip_err`` extras are plotted too.

    Empty or missing dictionaries are skipped. TE1 values are drawn as horizontal lines up
    to r = 1 arcmin and TE2 values from 5 to 100 arcmin; the r < 1 and r > 5 ranges are
    shaded.
    """
    cat_key = '%s_%s'%(band, tract)
    te1_key = '%s_%s_TE1'%(band, tract)
    te2_key = '%s_%s_TE2'%(band, tract)

    plt.figure(figsize=(8,6))

    if gen2_filtered_cat:
        gen2_radius, gen2_xip, gen2_xip_err = correlation_function_ellipticity_from_matches(gen2_filtered_cat[cat_key])
        # Shift Gen2 points 5% left in radius so they do not hide the Gen3 points.
        plotXi(gen2_radius * 0.95, gen2_xip, gen2_xip_err, color='black', label='validate_drp FilteredCatalog')

    if gen3_filtered_cat:
        gen3_radius, gen3_xip, gen3_xip_err = correlation_function_ellipticity_from_matches(gen3_filtered_cat[cat_key])
        plotXi(gen3_radius, gen3_xip, gen3_xip_err, color='red', label='new_framework FilteredCatalog')

    if gen2_measurement:
        gen2_te1 = gen2_measurement[te1_key]
        # Shift 5% right for the same reason.
        plotXi(gen2_te1.extras['radius'].quantity * 1.05,
               gen2_te1.extras['xip'].quantity,
               gen2_te1.extras['xip_err'].quantity, color='blue', label='validate_drp extras')

    # Switch to log axes *before* reading the autoscaled x-limits: limits computed on a linear
    # axis (with margins) can have a non-positive lower bound, which a log axis cannot use.
    plt.xscale('log')
    plt.yscale('log')
    plt.ylim(1.e-8, 1.e-3)
    xlim = plt.xlim()

    if gen2_measurement:
        plt.hlines(gen2_measurement[te1_key].quantity.value, 
                   color='blue', xmin=xlim[0], xmax=1.)
        plt.hlines(gen2_measurement[te2_key].quantity.value, 
                   color='blue', xmin=5., xmax=100., label='validate_drp Measurement')
    if gen3_measurement:
        plt.hlines(gen3_measurement[te1_key].quantity.value, 
                   color='red', xmin=xlim[0], xmax=1., ls='--')
        plt.hlines(gen3_measurement[te2_key].quantity.value, 
                   color='red', xmin=5., xmax=100., ls='--', label='new framework Measurement')

    plt.axvspan(5., xlim[-1], color='0.9')
    plt.axvspan(xlim[0], 1., color='0.9')
    plt.xlim(xlim)

    plt.title('tract: %s, band: %s'%(tract, band))
    plt.xlabel('r (arcmin)')
    plt.ylabel('Residual Ellipticity Correlation')
    plt.legend(loc='upper right')

In [None]:
def plot2(butler, band, tract, gen2_measurement=None, gen3_measurement=None):
    """Plot the Gen3 residual-ellipticity correlation function for one tract/band.

    The ``matchedCatalogTract`` dataset is fetched from ``butler`` and filtered with
    ``filterMatches(..., snrMin=50)``, and the lengths of both catalogs are printed.
    Optional TE1/TE2 measurements (dicts keyed by ``'<band>_<tract>_<metric>'``) are
    overlaid as horizontal lines over their radius ranges.
    """
    collection = 'kbechtol/svv_%s_gri_matched'%(tract)
    dataid = {'tract':tract, 'abstract_filter':band, 'instrument':'HSC', 'skymap':'hsc_rings_v1'}
    # Keep the dataset returned by butler.get (it used to be discarded and an undefined
    # `matched_catalog_tract_refs` was read instead).
    matchedCatalog = butler.get('matchedCatalogTract', collections=collection, dataId=dataid)

    filteredCat = filterMatches(matchedCatalog, snrMin=50)

    print(len(matchedCatalog))
    print(len(filteredCat))

    radius, xip, xip_err = correlation_function_ellipticity_from_matches(filteredCat)

    plt.figure()
    plotXi(radius, xip, xip_err, color='black')

    # Log axes first, then read the limits (linear autoscaling can give a non-positive lower bound).
    plt.xscale('log')
    plt.yscale('log')
    plt.ylim(1.e-8, 1.e-3)
    xlim = plt.xlim()

    te1_key = '%s_%s_TE1'%(band, tract)
    te2_key = '%s_%s_TE2'%(band, tract)
    if gen2_measurement:
        plt.hlines(gen2_measurement[te1_key].quantity.value, 
                   color='blue', xmin=xlim[0], xmax=1.)
        plt.hlines(gen2_measurement[te2_key].quantity.value, 
                   color='blue', xmin=5., xmax=100., label='validate_drp')
    if gen3_measurement:
        plt.hlines(gen3_measurement[te1_key].quantity.value, 
                   color='red', xmin=xlim[0], xmax=1., ls='--')
        plt.hlines(gen3_measurement[te2_key].quantity.value, 
                   color='red', xmin=5., xmax=100., ls='--', label='new framework')

    plt.axvspan(5., xlim[-1], color='0.9')
    plt.axvspan(xlim[0], 1., color='0.9')
    plt.xlim(xlim)

    plt.title('tract: %s, band: %s'%(tract, band))
    plt.xlabel('r (arcmin)')
    plt.ylabel('Residual Ellipticity Correlation')
    plt.legend(loc='upper right')

In [None]:
# One correlation-function comparison figure per (tract, band).
for tract, band in [(t, b) for t in tract_array for b in band_array]:
    plotCorrelationFunctionComparison(
        butler_gen3, band, tract,
        gen2_filtered_cat=gen2_filtered_cat_dict,
        gen3_filtered_cat=gen3_filtered_cat_dict,
        gen2_measurement=gen2_measurement_dict,
        gen3_measurement=gen3_measurement_dict,
    )

In [None]:
def plotRaDec(cat, **kwargs):
    """Scatter the per-group median ``coord_ra`` against median ``coord_dec``.

    ``cat`` must provide ``aggregate(function, field)``, as GroupView does. Coordinates are
    plotted in the catalog's stored units, with no conversion. Extra keyword arguments go
    to ``plt.scatter``.
    """
    median_ra, median_dec = (cat.aggregate(np.median, column) for column in ('coord_ra', 'coord_dec'))
    plt.scatter(median_ra, median_dec, **kwargs)

In [None]:
# Scratch exploration of the unfiltered Gen2 matched catalog for the (band, tract) values
# left over from the last loop above.
cat = gen2_matched_cat_dict['%s_%s'%(band, tract)]

In [None]:
# Saturation flags of the records in the first group.
cat.groups[0]['base_PixelFlags_flag_saturated']

In [None]:
#cat = gen2_filtered_cat_dict['%s_%s'%(band, tract)]
cat = gen2_matched_cat_dict['%s_%s'%(band, tract)]
#ra = cat.apply(np.median, 'coord_ra')

In [None]:
# Median RA of each group (one value per group).
ra_agg = cat.aggregate(np.median, 'coord_ra')

In [None]:
print(cat.count)
print(ra_agg.shape)

In [None]:
plt.figure()
# NOTE(review): plots the median RAs against the same array reversed; looks like a quick
# sanity plot only. Confirm intent before relying on it.
plt.scatter(ra_agg, ra_agg[::-1])

In [None]:
#dir(gen2_matchedCat)

In [None]:
#gen2_matchedCat.count / len(gen2_matchedCat)
# `gen2_matchedCat` is only a local inside plotMatchedCatalogComparison, so look it up
# explicitly here; otherwise this cell raises NameError on a fresh kernel.
gen2_matchedCat = gen2_matched_cat_dict['%s_%s'%(band, tract)]
gen2_matchedCat

In [None]:
def plotMatchedCatalogComparison(butler, band, tract, 
                                 gen2_matched_cat=None, gen3_matched_cat=None,
                                 gen2_filtered_cat=None, gen3_filtered_cat=None):
    """Overlay per-group median RA/Dec of up to four catalogs for one tract/band.

    Each catalog argument is a dict keyed by ``'<band>_<tract>'``; empty or missing dicts
    are skipped. Layers are drawn in a fixed order: Gen2 matched, Gen3 matched,
    Gen2 filtered ('+'), Gen3 filtered ('x'). ``butler`` is unused.
    """
    cat_key = '%s_%s'%(band, tract)
    plt.figure(figsize=(6,6))

    layers = [
        (gen2_matched_cat, {'label': 'validate_drp MatchedCatalog'}),
        (gen3_matched_cat, {'label': 'new_framework MatchedCatalog'}),
        (gen2_filtered_cat, {'marker': '+', 'label': 'validate_drp FilteredCatalog'}),
        (gen3_filtered_cat, {'marker': 'x', 'label': 'new_framework FilteredCatalog'}),
    ]
    for cat_dict, scatter_kwargs in layers:
        if cat_dict:
            plotRaDec(cat_dict[cat_key], **scatter_kwargs)

    plt.title('tract: %s, band: %s'%(tract, band))
    plt.xlabel('RA')
    plt.ylabel('Dec')
    plt.legend(loc='upper right')

In [None]:
for tract in tract_array:
    for band in band_array:
        plotMatchedCatalogComparison(butler_gen3, band, tract, 
                                     gen3_matched_cat=gen3_matched_cat_dict,
                                     gen3_filtered_cat=gen3_filtered_cat_dict)
        #plotMatchedCatalogComparison(butler_gen3, band, tract, 
        #                             gen2_matched_cat=gen2_matched_cat_dict, gen3_matched_cat=gen3_matched_cat_dict,
        #                             gen2_filtered_cat=gen2_filtered_cat_dict, gen3_filtered_cat=gen3_filtered_cat_dict)

In [None]:
for tract in tract_array:
    for band in band_array:
        plotMatchedCatalogComparison(butler_gen3, band, tract, 
                                     gen2_filtered_cat=gen2_filtered_cat_dict,
                                     gen3_filtered_cat=gen3_filtered_cat_dict)

In [None]:
# Filter parameters, so each cut can be applied and inspected on its own below.
# (isPrimary, psfStars, photoCalibStars and astromCalibStars are not used by these filters.)
snrMin = 50.0
snrMax = np.inf  # np.Inf was an alias removed in NumPy 2.0
extended = False
doFlags = True
nMatchesRequired = 2
isPrimary = True
psfStars = False
photoCalibStars = False
astromCalibStars = False
# afw Keys belong to the schema they were looked up in. These filters are applied to the Gen3
# matched catalog below, so take the key from that catalog's schema rather than from whichever
# `matchedCat` happens to be in scope (previously the last Gen2 catalog built above).
magKey = gen3_matched_cat_dict['%s_%s'%(band, tract)].schema.find('slot_PsfFlux_mag').key

def nMatchFilter(cat):
    """Keep a group with at least ``nMatchesRequired`` records, all with finite PSF magnitudes."""
    if len(cat) < nMatchesRequired:
        return False
    return np.isfinite(cat.get(magKey)).all()

def snrFilter(cat):
    """Keep a group whose median finite ``base_PsfFlux_snr`` lies in [snrMin, snrMax]."""
    # If no value is finite the median is NaN, both comparisons are False, and the group is
    # rejected; this implicitly requires a non-NaN PSF S/N.
    snr = cat.get('base_PsfFlux_snr')
    ok0, = np.where(np.isfinite(snr))
    medianSnr = np.median(snr[ok0])
    return snrMin <= medianSnr and medianSnr <= snrMax

def ptsrcFilter(cat):
    """Star/galaxy cut on ``base_ClassificationExtendedness_value`` (1 = extended, 0 = point-like).

    With ``extended`` False, a group is kept if its minimum extendedness is < 0.9, i.e. it is
    point-like in at least ONE record. With ``extended`` True, it is kept only if every record
    is > 0.9 (extended in all records).
    """
    # NOTE(review): the previous comment said point sources must be "not extended" in *ALL*
    # visits, but np.min(ext) < 0.9 keeps objects that are point-like in ANY visit. Behavior
    # is unchanged here; confirm which is intended when comparing against filterMatches.
    ext = cat.get('base_ClassificationExtendedness_value')
    if extended:
        return np.min(ext) > 0.9
    else:
        return np.min(ext) < 0.9

def flagFilter(cat):
    """Reject a group if any record has the saturated, cr, bad or edge pixel flag set."""
    if doFlags:
        flag_sat = cat.get("base_PixelFlags_flag_saturated")
        flag_cr = cat.get("base_PixelFlags_flag_cr")
        flag_bad = cat.get("base_PixelFlags_flag_bad")
        flag_edge = cat.get("base_PixelFlags_flag_edge")
        # np.any with no axis reduces over all four flags *and* all records, giving the single
        # bool per group that `where` expects.
        return np.logical_not(np.any([flag_sat, flag_cr, flag_bad, flag_edge]))
    else:
        return True

def fullFilter(cat):
    """All four cuts combined."""
    return nMatchFilter(cat) and snrFilter(cat) and ptsrcFilter(cat) and flagFilter(cat)

def comboFilter(cat):
    """Point-source and flag cuts only (SNR + PTSRC + FLAG and SNR + FLAG were also tried)."""
    return ptsrcFilter(cat) and flagFilter(cat)

In [None]:
matchedCat_nMatch = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(nMatchFilter)

In [None]:
matchedCat_snr = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(snrFilter)

In [None]:
matchedCat_ptsrc = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(ptsrcFilter)

In [None]:
matchedCat_flag = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(flagFilter)

In [None]:
matchedCat_full = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(fullFilter)

In [None]:
matchedCat_combo = gen3_matched_cat_dict['%s_%s'%(band, tract)].where(comboFilter)

In [None]:
print(gen3_matched_cat_dict['%s_%s'%(band, tract)].count, len(gen3_matched_cat_dict['%s_%s'%(band, tract)]))
print(matchedCat_nMatch.count, len(matchedCat_nMatch))
print(matchedCat_snr.count, len(matchedCat_snr))
print(matchedCat_ptsrc.count, len(matchedCat_ptsrc))
print(matchedCat_flag.count, len(matchedCat_flag))
print(matchedCat_full.count, len(matchedCat_full))
print(matchedCat_combo.count, len(matchedCat_combo))

In [None]:
plt.figure()
# Single-cut views examined earlier (drawn with s=2):
#   nMatch: very uniform; snr: slight bias to the periphery;
#   ptsrc: slight bias to the inner region; flag: slight bias to the inner region.
point_style = dict(marker='.', edgecolor='none', s=10)
plotRaDec(matchedCat_combo, **point_style)
plotRaDec(matchedCat_full, **point_style)

In [None]:
# The filter functions (nMatchFilter, snrFilter, ptsrcFilter, flagFilter) are defined once, in
# the filter-definition cell above. An identical second copy used to live here and silently
# shadowed the first; edit and re-run that cell instead of redefining the filters.

In [None]:
# Inspect the first group that survived all cuts (fullFilter).
my_group = matchedCat_full.groups[0]
my_group

# Debugging aids: evaluate each individual filter on this group.
#print(nMatchFilter(my_group),
#      snrFilter(my_group),
#      ptsrcFilter(my_group),
#      flagFilter(my_group))

#flag_sat = my_group.get("base_PixelFlags_flag_saturated")
#flag_cr = my_group.get("base_PixelFlags_flag_cr")
#flag_bad = my_group.get("base_PixelFlags_flag_bad")
#flag_edge = my_group.get("base_PixelFlags_flag_edge")
# Note: with axis=0 this gives a per-record mask, not the single bool flagFilter returns.
#np.logical_not(np.any([flag_sat, flag_cr, flag_bad, flag_edge], axis=0))

In [None]:
# When using SNR + PTSRC: no strong peripheral bias
# When using SNR + PTSRC + FLAG: strong peripheral bias
# When using SNR + FLAG: somewhat peripheral bias
# When using PTSRC + FLAG: TODO (result not yet recorded)

In [None]:
from astropy.coordinates import SkyCoord
import pandas as pd

def match(lon_1, lat_1, lon_2, lat_2, sep, unique=True):
    """Positionally match catalog 2 onto catalog 1.

    Parameters
    ----------
    lon_1, lat_1 : array-like
        RA/Dec of catalog 1, in decimal degrees.
    lon_2, lat_2 : array-like
        RA/Dec of catalog 2, in decimal degrees.
    sep : `float`
        Maximum allowed separation, in decimal degrees.
    unique : `bool`, optional
        Every catalog 2 object is paired with its nearest catalog 1 object, so several
        catalog 2 objects can share the same catalog 1 partner. If True (default), keep only
        the closest catalog 2 object for each catalog 1 object, making the pairing one-to-one.

    Returns
    -------
    idx_1 : `numpy.ndarray`
        Indices into catalog 1.
    idx_2 : `numpy.ndarray`
        Indices into catalog 2, aligned element-by-element with ``idx_1``.
    d2d : `numpy.ndarray`
        Angular separation of each pair, in decimal degrees.
    """
    c_1 = SkyCoord(ra=lon_1*u.degree, dec=lat_1*u.degree)
    c_2 = SkyCoord(ra=lon_2*u.degree, dec=lat_2*u.degree)

    # Nearest catalog 1 neighbour for every catalog 2 object. For coordinates without
    # distances this is the same neighbour match_to_catalog_3d would find.
    idx_1, d2d, _ = c_2.match_to_catalog_sky(c_1)
    sep_constraint = d2d < sep * u.deg
    idx_2 = np.nonzero(sep_constraint)[0]
    idx_1 = idx_1[sep_constraint]
    d2d = d2d[sep_constraint].deg

    if unique:
        # For each catalog 1 index keep the pair with the smallest separation. The frame has
        # a default RangeIndex, so idxmin yields positions into the arrays above.
        df = pd.DataFrame({'index': idx_1, 'sep': d2d})
        idx = df.groupby('index')['sep'].idxmin().values.astype(int)
        idx_1, idx_2, d2d = idx_1[idx], idx_2[idx], d2d[idx]

    return idx_1, idx_2, d2d

In [None]:
# Per-group (per-object) aggregates of the filtered Gen2 and Gen3 catalogs for the current
# band/tract. Coordinates are converted from radians to degrees for `match`.
gen2_cat = gen2_filtered_cat_dict['%s_%s'%(band, tract)]
gen3_cat = gen3_filtered_cat_dict['%s_%s'%(band, tract)]
gen2_ra = np.degrees(gen2_cat.aggregate(np.median, 'coord_ra'))
gen2_dec = np.degrees(gen2_cat.aggregate(np.median, 'coord_dec'))
gen3_ra = np.degrees(gen3_cat.aggregate(np.median, 'coord_ra'))
gen3_dec = np.degrees(gen3_cat.aggregate(np.median, 'coord_dec'))
gen2_e1 = gen2_cat.aggregate(np.median, 'e1')
gen3_e1 = gen3_cat.aggregate(np.median, 'e1')
gen2_psf_e1 = gen2_cat.aggregate(np.median, 'psf_e1')
gen3_psf_e1 = gen3_cat.aggregate(np.median, 'psf_e1')
gen2_e2 = gen2_cat.aggregate(np.median, 'e2')
gen3_e2 = gen3_cat.aggregate(np.median, 'e2')
gen2_psf_e2 = gen2_cat.aggregate(np.median, 'psf_e2')
gen3_psf_e2 = gen3_cat.aggregate(np.median, 'psf_e2')
# Object IDs are integers. A float median stored in aggregate's default float output cannot
# represent large int64 IDs exactly, which would break the ID comparisons in the align filters.
# The records of a group presumably share one 'object' ID, so np.min returns it exactly.
gen2_object = gen2_cat.aggregate(np.min, 'object', dtype=np.int64)
gen3_object = gen3_cat.aggregate(np.min, 'object', dtype=np.int64)

In [None]:
np.sum(gen2_psf_e2[gen2_index] != gen3_psf_e2[gen3_index]) / len(gen3_index)
np.sum(gen2_e2[gen2_index] != gen3_e2[gen3_index]) / len(gen3_index)

In [None]:
gen2_index, gen3_index, angsep = match(gen2_ra, gen2_dec, gen3_ra, gen3_dec, 1/3600)

In [None]:
plt.figure()
#plt.scatter(gen2_e1[gen2_index], gen3_e1[gen3_index])
#plt.scatter(gen2_psf_e1[gen2_index], gen3_psf_e1[gen3_index])
plt.scatter(gen2_psf_e2[gen2_index], gen3_psf_e2[gen3_index])

In [None]:
np.max(np.fabs(gen2_ra[gen2_index] - gen3_ra[gen3_index]))

In [None]:
def gen2AlignFilter(cat):
    """True if every record of the group has an 'object' ID among the matched Gen2 objects.

    Reads the notebook globals ``gen2_object`` and ``gen2_index`` from the matching cells.
    """
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
    return np.all(np.isin(cat.get('object'), gen2_object[gen2_index]))

def gen3AlignFilter(cat):
    """True if every record of the group has an 'object' ID among the matched Gen3 objects.

    Reads the notebook globals ``gen3_object`` and ``gen3_index`` from the matching cells.
    """
    return np.all(np.isin(cat.get('object'), gen3_object[gen3_index]))

In [None]:
# Keep only groups whose objects were positionally matched between Gen2 and Gen3.
aligned_key = '%s_%s'%(band, tract)
gen2_aligned_cat = gen2_filtered_cat_dict[aligned_key].where(gen2AlignFilter)
gen3_aligned_cat = gen3_filtered_cat_dict[aligned_key].where(gen3AlignFilter)

In [None]:
plt.figure()
# Sort both aligned catalogs' median RAs and plot the Gen2 - Gen3 difference against the
# *sorted* Gen2 RA, so each x value lines up with its own difference (x was previously
# unsorted). Assumes both aligned catalogs contain the same number of groups.
gen2_aligned_ra = np.sort(gen2_aligned_cat.aggregate(np.median, 'coord_ra'))
gen3_aligned_ra = np.sort(gen3_aligned_cat.aggregate(np.median, 'coord_ra'))
plt.xlabel('median coord_ra (Gen2, sorted)')
plt.ylabel('Gen2 - Gen3 median coord_ra')
plt.scatter(gen2_aligned_ra, gen2_aligned_ra - gen3_aligned_ra)

In [None]:
plt.figure()

# Recompute the correlation function from each aligned catalog. Gen2 points are shifted 5%
# left in radius so they remain visible next to the Gen3 points.
gen2_radius, gen2_xip, gen2_xip_err = correlation_function_ellipticity_from_matches(gen2_aligned_cat)
gen3_radius, gen3_xip, gen3_xip_err = correlation_function_ellipticity_from_matches(gen3_aligned_cat)
plotXi(gen2_radius * 0.95, gen2_xip, gen2_xip_err, color='black', label='validate_drp AlignedCatalog')
plotXi(gen3_radius, gen3_xip, gen3_xip_err, color='red', label='new_framework AlignedCatalog')

for set_scale in (plt.xscale, plt.yscale):
    set_scale('log')
plt.ylim(1.e-8, 1.e-3)

plt.title('tract: %s, band: %s'%(tract, band))
plt.xlabel('r (arcmin)')
plt.ylabel('Residual Ellipticity Correlation')
plt.legend(loc='upper right')

In [None]:

plt.figure(figsize=(6,6))
for aligned_cat, aligned_label, aligned_marker in (
        (gen2_aligned_cat, 'validate_drp AlignedCatalog', '+'),
        (gen3_aligned_cat, 'new_framework AlignedCatalog', 'x')):
    plotRaDec(aligned_cat, label=aligned_label, marker=aligned_marker)
plt.title('tract: %s, band: %s'%(tract, band))
plt.xlabel('RA')
plt.ylabel('Dec')
plt.legend(loc='upper right')

### Testing Filtering

In [None]:
# Raw Gen3 matched catalog for tract 9813, r band, used by the filtering tests below.
tract, band = 9813, 'r'
collection = 'kbechtol/svv_%s_gri_matched'%(tract)
dataid = dict(tract=tract, abstract_filter=band, instrument='HSC', skymap='hsc_rings_v1')
matchedCatalog = butler_gen3.get('matchedCatalogTract', collections=collection, dataId=dataid)
print(len(matchedCatalog))

In [None]:
matchedCat.aggregate(np.median, 'base_PsfFlux_snr').shape

In [None]:
matchedCat = GroupView.build(matchedCatalog)
print(matchedCat.count)
print(len(matchedCat.groups))

In [None]:
# Default filterMatches cuts (an explicit snrMin=50 was also tried).
filteredCat = filterMatches(matchedCatalog)
print(filteredCat.count, len(filteredCat.groups), sep='\n')
#print(filteredCat.counts)

In [None]:
# PSF S/N column of the unfiltered grouped catalog.
matchedCat.get('base_PsfFlux_snr')

In [None]:
# Same for the filtered catalog, to see what the cuts removed.
filteredCat.get('base_PsfFlux_snr')

In [None]:
#matchedCatalog.columns.schema

In [None]:
# Print both columns (first 100 records). Only a cell's last expression is displayed, so the
# S/N column was previously evaluated but never shown.
print(matchedCatalog['base_PsfFlux_snr'][0:100])
print(matchedCatalog['base_ClassificationExtendedness_value'][0:100])

In [None]:
plt.figure()
plt.xscale('log')
plt.yscale('log')
bins = np.logspace(-1, 4, 100)
# Per-record PSF S/N histograms, split by extendedness (0 = point-like, 1 = extended).
record_snr = matchedCatalog['base_PsfFlux_snr']
record_ext = matchedCatalog['base_ClassificationExtendedness_value']
plt.hist(record_snr[record_ext == 0], bins=bins, histtype='step')
plt.hist(record_snr[record_ext == 1], bins=bins, histtype='step')
#plt.xlim(0.1, 1.e4)

In [None]:
def aggMean(x):
    """Mean of the finite entries of ``x``; NaN if none are finite.

    Intended as an aggregate function (e.g. for ``GroupView.aggregate``) that ignores
    NaN/inf values. The previous body had an unbalanced parenthesis and no ``return``.
    NOTE(review): that body called np.median; the name says mean. Confirm which is wanted.
    """
    values = np.asarray(x, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return np.nan
    return np.mean(finite)

In [None]:
# Scratch check of np.where: returns the indices where the condition holds.
x = np.arange(10)
np.where(x > 5)[0]

In [None]:
# Per-group aggregates: median PSF S/N and mean extendedness. If extendedness is 0/1 per
# record, the mean is presumably the fraction of records classified extended; verify.
base_PsfFlux_snr = matchedCat.aggregate(np.median, 'base_PsfFlux_snr')
base_ClassificationExtendedness_value = matchedCat.aggregate(np.mean, 'base_ClassificationExtendedness_value')

In [None]:
base_ClassificationExtendedness_value[0:100]

In [None]:
plt.figure()
plt.yscale('log')
plt.xscale('log')
bins = np.logspace(-1, 4, 100)
# Split groups at a mean extendedness of 0.5: mostly extended first, then mostly point-like.
mostly_extended = base_ClassificationExtendedness_value > 0.5
mostly_pointlike = base_ClassificationExtendedness_value < 0.5
plt.hist(base_PsfFlux_snr[mostly_extended], bins=bins, histtype='step')
plt.hist(base_PsfFlux_snr[mostly_pointlike], bins=bins, histtype='step')

### Testing plotting below

In [None]:
#tract = 9813
tract = 9615
band = 'r'
collection = 'kbechtol/svv_%s_gri_matched'%(tract)
dataid = {'tract':tract, 'abstract_filter':band, 'instrument':'HSC', 'skymap':'hsc_rings_v1'}
# Keep the dataset returned by butler.get. Previously it was discarded and an undefined
# `matched_catalog_tract_refs` was read instead.
matchedCatalog = butler_gen3.get('matchedCatalogTract', collections=collection, dataId=dataid)

filteredCat = filterMatches(matchedCatalog, snrMin=50)

print(len(matchedCatalog))
print(len(filteredCat))

In [None]:
filteredCat.count

In [None]:
# GroupView is already imported in the first imports cell; this re-import is redundant but harmless.
from lsst.afw.table import GroupView
matchedCat = GroupView.build(matchedCatalog)

In [None]:
# Compare the raw catalog length with the grouped view's `.count`.
print(len(matchedCatalog))
print(matchedCat.count)

In [None]:
radius, xip, xip_err = correlation_function_ellipticity_from_matches(filteredCat)

In [None]:
plt.figure()


plotXi(radius, xip, xip_err, color='red')
    
if gen2_measurement:
    plotXi(gen2_measurement['%s_%s'%(band, tract)].extras['radius'].quantity,
           gen2_measurement['%s_%s'%(band, tract)].extras['xi'].quantity, 
           gen2_measurement['%s_%s'%(band, tract)].extras['xi_err'].quantity, color='blue')

#plt.errorbar(radius.value, xip, yerr=np.where(xip > 0, xip_err, 0), 
#                marker='o', c='none', mfc='black', mec='black', ecolor='black', barsabove=True)
#plt.errorbar(radius.value, -1 * xip, yerr=np.where(xip < 0, xip_err, 0), 
#                 marker='o', c='none', ls='--', mfc='none', mec='black', ecolor='black', barsabove=True)

xlim = plt.xlim()
plt.xscale('log')
plt.yscale('log')
plt.ylim(1.e-8, 1.e-3)

key = '%s_%s_%s'%(band, tract, metric)
if gen2_measurement:
    plt.hlines(gen2_measurement['%s_%s_%s'%(band, tract, 'TE1')].quantity.value, 
               color='blue', xmin=0., xmax=1.)
    plt.hlines(gen2_measurement['%s_%s_%s'%(band, tract, 'TE2')].quantity.value, 
               color='blue', xmin=5., xmax=100., label='validate_drp')
if gen3_measurement:
    plt.hlines(gen3_measurement['%s_%s_%s'%(band, tract, 'TE1')].quantity.value, 
               color='red', xmin=0., xmax=1., ls='--')
    plt.hlines(gen3_measurement['%s_%s_%s'%(band, tract, 'TE2')].quantity.value, 
               color='red', xmin=5., xmax=100., ls='--', label='new framework')
    
plt.axvspan(5., xlim[-1], color='0.9')
plt.axvspan(xlim[0], 1., color='0.9')
plt.xlim(xlim)

plt.title('tract: %s, band: %s'%(tract, band))
plt.xlabel('r (arcmin)')
plt.ylabel('Residual Ellipticity Correlation')
plt.legend(loc='upper right')

# Comparison of Matched Catalogs

In [None]:
import fitsio

In [None]:
infile = '/project/jcarlin/matched_cats/RC2_tract9813_HSC-I_matched_cat_validateDrp.fits'
f = fitsio.FITS(infile)

In [None]:
data = f[1].read()

In [None]:
data['coord_ra']

In [None]:
from lsst.afw.table import MultiMatch, BaseCatalog, Catalog, SimpleCatalog, GroupView
#lsst.afw.table.readFits('/project/jcarlin/matched_cats/RC2_tract9813_HSC-I_matched_cat_validateDrp.fits')

In [None]:
# Re-read the same FITS file (`infile` from the fitsio cell) as an afw BaseCatalog to inspect its schema.
data = BaseCatalog.readFits(infile)

In [None]:
data.schema.getNames()

In [None]:
help(MultiMatch.makeRecord)

In [None]:
# NOTE(review): exploratory call; it appears to pass the catalog's record class rather than a
# record. Check the signature printed by help() above before relying on it.
mmatch = MultiMatch.makeRecord(data.Record, 'id', 'object')

In [None]:
dir(data)

In [None]:
data.Record

In [None]:
# Exploratory: construct a MultiMatch from this schema with an empty data-ID mapping.
MultiMatch(data.schema, {}, idField='id')

In [None]:
from lsst.afw.table import SimpleCatalog, GroupView

#infile = '/project/jcarlin/matched_cats/RC2_tract9813_HSC-I_matched_cat_validateDrp.fits'
infile = '/project/jcarlin/matched_cats/no_ext_calib/RC2_tract9813_HSC-I_matched_cat_validateDrp.fits'
matchedCatalog = SimpleCatalog.readFits(infile)
matchedCat = GroupView.build(matchedCatalog)

In [None]:
filteredCat = filterMatches(matchedCatalog, snrMin=50)

In [None]:
print(len(matchedCat))
print(len(filteredCat))

# CODE SCRAPS
---

# Which version of the Stack am I using?
!eups list -s | grep lsst_distrib

Note that to access the `metric-pipeline-tasks` modules, you can set up the repo in `~/notebooks/.user_setups`, for example:

`setup -k -r ~/repos/metric-pipeline-tasks`

In [None]:
import numpy as np
import json
import os
from astropy.table import Table
from astropy import units as u
import matplotlib.pyplot as plt
%matplotlib widget

import lsst.verify
import lsst.daf.butler as dafButler
from lsst.validate.drp.calcsrd.tex import (correlation_function_ellipticity_from_matches,
                                           select_bin_from_corr)
from metric_pipeline_utils.filtermatches import filterMatches

## validate_drp

In [None]:
# Output from a run of validate_drp: one JSON job file per HSC filter for tract 9813.
_validate_drp_dir = '/datasets/hsc/repo/rerun/RC/w_2020_34/DM-26441/validateDrp/matchedVisitMetrics/9813'
validate_job_g9813 = _validate_drp_dir + '/HSC-G/matchedVisit_HSC-G.json'
validate_job_r9813 = _validate_drp_dir + '/HSC-R/matchedVisit_HSC-R.json'
validate_job_i9813 = _validate_drp_dir + '/HSC-I/matchedVisit_HSC-I.json'

In [None]:
# Read in the .json saved by each validate_drp run:
def _load_validate_job(path):
    """Deserialize one validate_drp JSON job file into an ``lsst.verify.Job``."""
    with open(path) as f:
        return lsst.verify.Job.deserialize(**json.load(f))

job_g9813 = _load_validate_job(validate_job_g9813)
job_r9813 = _load_validate_job(validate_job_r9813)
job_i9813 = _load_validate_job(validate_job_i9813)

In [None]:
# Show a metric report in the notebook (use "spec_tags" to specify design, stretch, or minimum req level):
job_g9813.report(spec_tags=['design']).show()

In [None]:
# Extract the measurements from the JSON:
# (`.json` is presumably a list of per-metric dicts; items are indexed and read with
# 'metric'/'value'/'unit' keys further down.)
meas_g9813 = job_g9813.measurements.json
meas_r9813 = job_r9813.measurements.json
meas_i9813 = job_i9813.measurements.json

In [None]:
meas_r9813[0]

In [None]:
# TE1 and TE2 values from the g-band job.
print(job_g9813.measurements['validate_drp.TE1'].quantity)
print(job_g9813.measurements['validate_drp.TE2'].quantity)

In [None]:
def getGen2Measurement(band, tract, metric,
                       rerun_dir='/datasets/hsc/repo/rerun/RC/w_2020_34/DM-26441'):
    """Read one validate_drp metric value from a Gen2 ``matchedVisit`` JSON job.

    Parameters
    ----------
    band : `str`
        Band name (e.g. ``'r'``); upper-cased to form the ``HSC-<BAND>`` filter name.
    tract : `int`
        Tract number.
    metric : `str`
        Metric name without the ``validate_drp.`` prefix (e.g. ``'TE1'``).
    rerun_dir : `str`, optional
        Gen2 rerun directory that contains ``validateDrp/matchedVisitMetrics``.
        Defaults to the w_2020_34 (DM-26441) RC rerun.

    Returns
    -------
    quantity : `astropy.units.Quantity`
        The ``quantity`` of the ``validate_drp.<metric>`` measurement.
    """
    filter_name = 'HSC-%s' % band.upper()
    infile = os.path.join(rerun_dir, 'validateDrp', 'matchedVisitMetrics', str(tract),
                          filter_name, 'matchedVisit_%s.json' % filter_name)
    # The whole job is deserialized even though only one metric is returned.
    with open(infile) as f:
        job = lsst.verify.Job.deserialize(**json.load(f))
    return job.measurements['validate_drp.%s' % metric].quantity

In [None]:
# Load the raw (not deserialized) validate_drp JSON for one tract/band so its
# structure can be inspected directly as the dict ``j``.
tract = 9813
band = 'r'
infile = os.path.join('/datasets/hsc/repo/rerun/RC/w_2020_34/DM-26441/validateDrp/matchedVisitMetrics',
                      str(tract), 'HSC-%s' % band.upper(), 'matchedVisit_HSC-%s.json' % band.upper())
with open(infile) as f:
    j = json.load(f)

In [None]:
#j['measurements']

In [None]:
# Collect the Gen2 (validate_drp) TE1/TE2 measurements for every tract and band,
# keyed like 'r_9813_TE1'. This takes several minutes to run: each call reads and
# deserializes a whole validate_drp JSON job.
tract_array = [9813, 9615, 9697]
band_array = ['g', 'r', 'i']
metric_array = ['TE1', 'TE2']
gen2_measurement = dict()

for tract in tract_array:
    for band in band_array:
        for metric in metric_array:
            measurement = getGen2Measurement(band, tract, metric)
            print(band, tract, metric, measurement)
            gen2_key = '_'.join([band, str(tract), metric])
            gen2_measurement[gen2_key] = measurement

In [None]:
# Display the Gen2 measurements, keyed '<band>_<tract>_<metric>'.
gen2_measurement

In [None]:
# Compile each band's measurements into parallel name / value / unit lists.

def _unpackMeasurements(meas_json):
    """Split measurement records into parallel lists.

    Parameters
    ----------
    meas_json : iterable of `dict`
        Records with ``'metric'``, ``'value'`` and ``'unit'`` keys, as produced by
        ``job.measurements.json``.

    Returns
    -------
    names, vals, units : `list`
        The ``'metric'``, ``'value'`` and ``'unit'`` entries, in record order.
    """
    names, vals, units = [], [], []
    for met in meas_json:
        names.append(met['metric'])
        vals.append(met['value'])
        units.append(met['unit'])
    return names, vals, units


gen2_names_g9813, gen2_vals_g9813, gen2_units_g9813 = _unpackMeasurements(meas_g9813)
gen2_names_r9813, gen2_vals_r9813, gen2_units_r9813 = _unpackMeasurements(meas_r9813)
gen2_names_i9813, gen2_vals_i9813, gen2_units_i9813 = _unpackMeasurements(meas_i9813)

In [None]:
# Make the lists into astropy tables.
# The value column is 'f8' (float64). 'f2' would be float16: it overflows above ~65504,
# and values below ~6e-5 (TE1/TE2 are plotted in the 1e-7 to 1e-4 range) become
# subnormal or zero, losing most or all of their precision.
tab_gen2_g9813 = Table([gen2_names_g9813, gen2_vals_g9813, gen2_units_g9813],
                       names=['metric', 'value', 'units'], dtype=(str, 'f8', str))
tab_gen2_r9813 = Table([gen2_names_r9813, gen2_vals_r9813, gen2_units_r9813],
                       names=['metric', 'value', 'units'], dtype=(str, 'f8', str))
tab_gen2_i9813 = Table([gen2_names_i9813, gen2_vals_i9813, gen2_units_i9813],
                       names=['metric', 'value', 'units'], dtype=(str, 'f8', str))

In [None]:
# Display the r-band (tract 9813) Gen2 metric table.
tab_gen2_r9813

In [None]:
# Print every metric name in the r-band table, one per line.
# NOTE(review): this leaves `name` bound to the last metric name in the kernel namespace.
for name in tab_gen2_r9813['metric']: print(name)

In [None]:
# Row indices of metrics whose name contains 'TE1' (substring match via np.char.find).
np.where(np.char.find(tab_gen2_r9813['metric'], 'TE1') >= 0)

In [None]:
def getMetricValue(metric_name, table=None):
    """Return the values of every row whose metric name contains ``metric_name``.

    Parameters
    ----------
    metric_name : `str`
        Substring to look for in the ``'metric'`` column (e.g. ``'TE1'``).
    table : table-like, optional
        Anything indexable by ``'metric'`` and ``'value'`` that returns array-likes
        (e.g. an `astropy.table.Table`). Defaults to ``tab_gen2_r9813``.

    Returns
    -------
    values : `numpy.ndarray`
        Values of the matching rows, in table order (empty if nothing matches).
    """
    if table is None:
        table = tab_gen2_r9813
    # Substring match, so 'TE1' finds 'validate_drp.TE1'.
    matches = np.char.find(np.asarray(table['metric'], dtype=str), metric_name) >= 0
    return np.asarray(table['value'])[matches]

## New framework (Gen3 butler)

In [None]:
# Open the Gen3 butler configured by `butler.yaml` inside `repo`.
repo = '/project/hsc/gen3repo/rc2w34_ssw36/'
config = os.path.join(repo, 'butler.yaml')
try:
    butler_gen3 = dafButler.Butler(config=config)
except ValueError as err:
    # Only ValueError is caught, and it is printed rather than raised.
    print(err)

In [None]:
# Shorthand for the Gen3 registry used by the queries below.
registry = butler_gen3.registry

In [None]:
# List the collections in the repo, one per line.
for collection_name in registry.queryCollections():
    print(collection_name)

In [None]:
# List every dataset type registered in the repo (registered does not mean any
# datasets of that type exist) and keep them in `dstypes`.
dstypes = list(registry.queryDatasetTypes())
for dstype in dstypes:
    print(dstype)

In [None]:
# Gen3 collection searched by the dataset queries below
# (the name suggests tract 9813 in g, r, i; confirm).
collection = 'kbechtol/svv_9813_gri_matched'

In [None]:
def getMetricValueDatasetTypes(butler):
    """List the dataset types in ``butler``'s registry stored as ``MetricValue``.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        Butler whose registry is queried.

    Returns
    -------
    d_types : `list`
        Matching dataset types, in registry query order.
    """
    return [dataset_type for dataset_type in butler.registry.queryDatasetTypes()
            if dataset_type.storageClass.name == 'MetricValue']

In [None]:
# Show which dataset types in this repo are stored as MetricValue.
getMetricValueDatasetTypes(butler_gen3)

In [None]:
band = 'r'

# Every query below searches the same collection and band.
ref_query_kwargs = dict(collections=collection, abstract_filter=band)
matched_catalog_refs = list(registry.queryDatasets('matchedCatalog', **ref_query_kwargs))
matched_catalog_tract_refs = list(registry.queryDatasets('matchedCatalogTract', **ref_query_kwargs))
te1_refs = list(registry.queryDatasets('metricvalue_validate_drp_TE1', **ref_query_kwargs))
summary_te1_refs = list(registry.queryDatasets('metricvalue_summary_validate_drp_TE1', **ref_query_kwargs))

In [None]:
# Display the TE1 metric-value dataset refs found above.
te1_refs

In [None]:
# Read each TE1 metric value found above and keep only its quantity.
te1_measurements = []
for te1_ref in te1_refs:
    te1_measurements.append(butler_gen3.getDirect(te1_ref).quantity)
te1_measurements

In [None]:
def getGen3Measurement(band, tract, metric, butler=None, collection=None,
                       instrument='HSC', skymap='hsc_rings_v1'):
    """Fetch one Gen3 ``metricvalue_validate_drp_<metric>`` value for a tract and band.

    Parameters
    ----------
    band : `str`
        Band name, used as the ``abstract_filter`` data ID value (e.g. ``'r'``).
    tract : `int`
        Tract number.
    metric : `str`
        Metric name (e.g. ``'TE1'``).
    butler : `lsst.daf.butler.Butler`, optional
        Butler to read from; defaults to the notebook's ``butler_gen3``.
    collection : `str`, optional
        Collection to search; defaults to ``'kbechtol/svv_<tract>_gri_matched'``.
    instrument, skymap : `str`, optional
        Remaining data ID values; default to ``'HSC'`` and ``'hsc_rings_v1'``.

    Returns
    -------
    quantity
        The ``quantity`` attribute of the stored metric value.
    """
    if butler is None:
        butler = butler_gen3
    if collection is None:
        # Collections are named per tract.
        collection = 'kbechtol/svv_%s_gri_matched' % (tract)
    dataid = {'tract': tract, 'abstract_filter': band,
              'instrument': instrument, 'skymap': skymap}
    metric_value = butler.get('metricvalue_validate_drp_%s' % (metric),
                              collections=collection, dataId=dataid)
    return metric_value.quantity

In [None]:
# Spot-check a single Gen3 measurement. The other tracts used in this notebook are
# 9813 and 9697. `band` is whatever value an earlier cell last assigned.
tract = 9615
metric = 'TE1'
measurement = getGen3Measurement(band, tract, metric)
print(measurement)

In [None]:
# Collect the Gen3 TE1/TE2 measurements for every tract and band, keyed like
# 'r_9813_TE1' to match the Gen2 dictionary.
tract_array = [9813, 9615, 9697]
band_array = ['g', 'r', 'i']
metric_array = ['TE1', 'TE2']
gen3_measurement = dict()

for tract in tract_array:
    for band in band_array:
        for metric in metric_array:
            measurement = getGen3Measurement(band, tract, metric)
            print(measurement)
            gen3_key = '_'.join([band, str(tract), metric])
            gen3_measurement[gen3_key] = measurement

In [None]:
# Display the Gen3 measurements, keyed '<band>_<tract>_<metric>'.
gen3_measurement

In [None]:
""""
te1_measurements = [butler_gen3.getDirect(ref).quantity for ref in te1_refs]
for ref, measurement in zip(te1_refs, te1_measurements):
    print('Tract %3i:%10.2e %s'%(ref.dataId['tract'],
                                 measurement.value, 
                                 measurement.unit))

te1_measurements = u.Quantity(te1_measurements)
"""

te1_measurement = butler_gen3.getDirect(te1_refs[0]).quantity
summary_te1_measurement = butler_gen3.getDirect(summary_te1_refs[0]).quantity
    
plt.figure()
#plt.hist(te1_measurements.value[~np.isnan(te1_measurements)])
#plt.axvline(te1_measurements.value[~np.isnan(te1_measurements)])
plt.axvline(te1_measurement.value, c='red')
plt.axvline(summary_te1_measurement.value, c='blue')
plt.axvline(job_r9813.measurements['validate_drp.TE1'].quantity, c='black')
plt.xlabel('TE1 (%s)'%(te1_measurement.unit))
plt.ylabel('Counts')
plt.xlim(1.e-7, 1.e-4)
plt.xscale('log')

In [None]:
# Load the first per-tract matched catalog found for this collection and band.
matchedCatalog = butler_gen3.getDirect(matched_catalog_tract_refs[0])

In [None]:
# Length of the matched catalog, printed in scientific notation.
print(format(len(matchedCatalog), '.2e'))

In [None]:
# Display the matched catalog.
matchedCatalog

In [None]:
# Filter the matches with snrMin=50 (presumably a signal-to-noise floor; confirm in filterMatches).
filteredCat = filterMatches(matchedCatalog, snrMin=50)

In [None]:
# `count` of the filtered catalog (presumably the number of matched groups kept; confirm).
print(filteredCat.count)

In [None]:
# Exploration: list the attributes available on the filtered catalog.
dir(filteredCat)

In [None]:
# NOTE(review): the commented-out guard below (re-filter with default settings, then
# return NaN when filteredCat.count <= 50) looks copied from the metric task; it is not applied here.
#filteredCat = filterMatches(matchedCatalog)
#nMinTEx = 50
#if filteredCat.count <= nMinTEx:
#    return Struct(measurement=Measurement(metric_name, np.nan*u.Unit('')))

# Ellipticity correlation function of the filtered matches: bin radii (`radius`, a
# Quantity), the correlation values (`xip`, presumably xi_+) and their errors (`xip_err`).
radius, xip, xip_err = correlation_function_ellipticity_from_matches(filteredCat)

In [None]:
# Correlation values per radius bin.
xip

In [None]:
# Uncertainty on each correlation value.
xip_err

In [None]:
# Signed significance of each bin (value / error).
xip / xip_err

In [None]:
# |xip / xip_err| against radius: how significant each correlation bin is.
fig, ax = plt.subplots()
ax.scatter(radius, np.fabs(xip / xip_err))
ax.set_xscale('log')

In [None]:
# Gen3 TE2 value to go with `te1_measurement`. It uses the same collection and the same
# band ('r') as the TE1 refs and is read the same way; the plot below needs it.
te2_refs = list(registry.queryDatasets('metricvalue_validate_drp_TE2', collections=collection,
                                       abstract_filter='r'))
te2_measurement = butler_gen3.getDirect(te2_refs[0]).quantity
# Gen2 validate_drp values, converted to the Gen3 units so both share the y axis.
gen2_te1_value = job_r9813.measurements['validate_drp.TE1'].quantity.to_value(te1_measurement.unit)
gen2_te2_value = job_r9813.measurements['validate_drp.TE2'].quantity.to_value(te2_measurement.unit)

fig, ax = plt.subplots()
# |xip| on a log axis. Positive bins are drawn as filled markers; negative bins are
# sign-flipped and drawn as open markers. Each bin's error bar is drawn only on the marker shown.
ax.errorbar(radius.value, xip, yerr=np.where(xip > 0, xip_err, 0),
            marker='o', c='none', mfc='black', mec='black', ecolor='black', barsabove=True)
ax.errorbar(radius.value, -1 * xip, yerr=np.where(xip < 0, xip_err, 0),
            marker='o', c='none', ls='--', mfc='none', mec='black', ecolor='black', barsabove=True)
# TE1 drawn over radius 0-1 and TE2 over 5-100: Gen3 in red, Gen2 in black.
ax.hlines(te1_measurement.value, color='red', xmin=0., xmax=1.)
ax.hlines(gen2_te1_value, color='black', xmin=0., xmax=1.)
ax.hlines(te2_measurement.value, color='red', xmin=5., xmax=100.)
ax.hlines(gen2_te2_value, color='black', xmin=5., xmax=100.)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('radius (%s)' % radius.unit)
ax.set_ylabel(r'$|\xi_+|$')
xlim = ax.get_xlim()
ax.set_ylim(1.e-8, 1.e-3)
# Shade the radius ranges covered by the TE1 (<= 1) and TE2 (>= 5) lines, then
# restore the x limits that axvspan may have widened.
ax.axvspan(5., ax.get_xlim()[-1], color='0.9')
ax.axvspan(ax.get_xlim()[0], 1., color='0.9')
ax.set_xlim(xlim)