## **DP1 DIASource Analysis**

## 1. Initialize data & visualization

To be run on Rubin Science Platform

In [None]:
from lsst.rsp import get_tap_service
from lsst.rsp.service import get_siav2_service
from lsst.rsp.utils import get_pyvo_auth

import lsst.afw.display as afwDisplay
from lsst.afw.image import ExposureF
from lsst.afw.math import Warper, WarperConfig
from lsst.afw.fits import MemFileManager
import lsst.geom as geom

from pyvo.dal.adhoc import DatalinkResults, SodaQuery

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Rectangle
from astropy import units as u
from astropy.table import Table, vstack
import corner
from glob import glob
from astropy.coordinates import SkyCoord

import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import random

# TAP catalogue service handle; passed to fetch_site_data() to submit ADQL jobs.
service = get_tap_service("tap")
assert service is not None

# SIAv2 image service handle for "dp1"; used by fetch_images_for_row() to
# search for science, template and difference images.
sia_service = get_siav2_service("dp1")
assert sia_service is not None

# Select matplotlib as the default afwDisplay backend.
afwDisplay.setDefaultBackend('matplotlib')

Load the DP1 DIASource data for a single field or for all three fields merged together. Set `FETCH_FROM_SERVER = False` to read previously saved `.fits` files instead of querying the TAP service, which takes a long time.

In [None]:
################################################################################
#                                CONFIGURATION                                 #
################################################################################

# Where the catalogue comes from:
#   True  -> query the TAP service (each site's result is also written to <site>.fits)
#   False -> read the previously saved <site>.fits files from the working directory
FETCH_FROM_SERVER = False    # True: Query TAP service | False: Load from .fits files
# True stacks every entry of ALL_SITES (adding a 'site' column); False uses SITE only.
LOAD_ALL_SITES = True       # True: Merge all sites | False: Single site only
# Field used when LOAD_ALL_SITES is False; must be a name set_coords() recognises.
SITE = 'ecdfs'              # Options: 'ecdfs', 'galactic', 'ecliptic'
# Fields merged when LOAD_ALL_SITES is True.
ALL_SITES = ['ecdfs', 'galactic', 'ecliptic']

################################################################################


def set_coords(site):
    """Return the (RA, Dec) centre of the cone search for a DP1 field.

    Parameters
    ----------
    site : str
        Field name: 'ecdfs', 'galactic' or 'ecliptic'.

    Returns
    -------
    tuple of float
        ``(ra, dec)`` as inserted into the ICRS circle of the ADQL query
        (ADQL circles are specified in degrees).

    Raises
    ------
    ValueError
        If ``site`` is not one of the known fields.  Previously an unknown
        name returned ``None``, which only surfaced later as an opaque
        tuple-unpacking ``TypeError`` in the caller.
    """
    centers = {
        'ecdfs': (53.16, -28.10),
        'galactic': (95.0, -25.0),
        'ecliptic': (37.98, 7.015),
    }
    try:
        return centers[site]
    except KeyError:
        raise ValueError(
            f"Unknown site {site!r}; expected one of {sorted(centers)}"
        ) from None

def get_title():
    """Return the human-readable field label used in plot titles.

    Reads the module-level ``LOAD_ALL_SITES`` and ``SITE`` settings.  When all
    sites are merged the label is 'All 3 Sites'; otherwise it is the display
    name of ``SITE``, or ``None`` if ``SITE`` is not a known field.
    """
    if LOAD_ALL_SITES:
        return 'All 3 Sites'

    display_names = {
        'ecdfs': 'ECDFS',
        'galactic': 'Low Galactic Latitude Field',
        'ecliptic': 'Low Ecliptic Latitude Field',
    }
    return display_names.get(SITE)

def fetch_site_data(site, service):
    """Query dp1.DiaSource around one field and cache the result on disk.

    Parameters
    ----------
    site : str
        Field name understood by ``set_coords``; also used as the stem of the
        output file ``<site>.fits``.
    service
        TAP service object; must provide ``submit_job(query)``.

    Returns
    -------
    table
        The query result converted with ``to_table()``.  The same table is
        written to ``<site>.fits`` (overwriting any existing file).

    Raises
    ------
    Whatever ``job.raise_if_error()`` raises when the job ends in ERROR, and
    ``AssertionError`` if the job finishes in any phase other than COMPLETED.
    """
    ra_cen, dec_cen = set_coords(site)

    # Cone search of radius 1.0 around the field centre, ordered by source id.
    adql_template = """SELECT apFlux, apFlux_flag, apFlux_flag_apertureTruncated, apFluxErr, 
        band, bboxSize, centroid_flag, coord_dec, coord_ra, 
        dec, decErr, detector, diaObjectId, diaSourceId, 
        dipoleAngle, dipoleChi2, dipoleFitAttempted, dipoleFluxDiff, dipoleFluxDiffErr, 
        dipoleLength, dipoleMeanFlux, dipoleMeanFluxErr, dipoleNdata, 
        extendedness, forced_PsfFlux_flag, forced_PsfFlux_flag_edge, 
        forced_PsfFlux_flag_noGoodPixels, isDipole, 
        ixx, ixxPSF, ixy, ixyPSF, iyy, iyyPSF, 
        midpointMjdTai, parentDiaSourceId, 
        pixelFlags, pixelFlags_bad, pixelFlags_cr, pixelFlags_crCenter, 
        pixelFlags_edge, pixelFlags_injected, pixelFlags_injected_template, 
        pixelFlags_injected_templateCenter, pixelFlags_injectedCenter, 
        pixelFlags_interpolated, pixelFlags_interpolatedCenter, 
        pixelFlags_nodata, pixelFlags_nodataCenter, pixelFlags_offimage, 
        pixelFlags_saturated, pixelFlags_saturatedCenter, 
        pixelFlags_streak, pixelFlags_streakCenter, 
        pixelFlags_suspect, pixelFlags_suspectCenter, 
        psfChi2, psfFlux, psfFlux_flag, psfFlux_flag_edge, 
        psfFlux_flag_noGoodPixels, psfFluxErr, psfNdata, 
        ra, ra_dec_Cov, raErr, reliability, 
        scienceFlux, scienceFluxErr, 
        shape_flag, shape_flag_no_pixels, shape_flag_not_contained, 
        shape_flag_parent_source, snr, ssObjectId, 
        trail_flag_edge, trailAngle, trailDec, trailFlux, trailLength, trailRa, 
        visit, x, xErr, y, yErr
        FROM dp1.DiaSource
        WHERE CONTAINS(POINT('ICRS', ra, dec),
        CIRCLE('ICRS', {}, {}, 1.0)) = 1
        ORDER BY diaSourceId ASC"""
    query = adql_template.format(ra_cen, dec_cen)

    print(f"Querying {site}...")
    print(query)

    # Submit as an asynchronous job and block until it reaches a final phase.
    job = service.submit_job(query)
    job.run()
    job.wait(phases=['COMPLETED', 'ERROR'])
    # job.phase is read afresh at each use rather than cached in a local.
    print(f'Job phase is {job.phase}')

    if job.phase == 'ERROR':
        job.raise_if_error()
    assert job.phase == 'COMPLETED'

    table = job.fetch_result().to_table()
    n_rows = len(table)
    print(f"Retrieved {n_rows} rows with {len(table.colnames)} columns")

    # Cache on disk so later runs can skip the query (FETCH_FROM_SERVER = False).
    fits_path = f'{site}.fits'
    table.write(fits_path, format='fits', overwrite=True)
    print(f"Saved {n_rows} rows to {fits_path}")

    return table


# Build `results`: either one field, or every field in ALL_SITES stacked into a
# single table with a 'site' column recording where each row came from.
if LOAD_ALL_SITES:
    all_results = []
    for s in ALL_SITES:
        if FETCH_FROM_SERVER:
            # Query the TAP service (this also writes <site>.fits).
            result = fetch_site_data(s, service)
            result['site'] = s
            all_results.append(result)
        else:
            # Re-use the cached file from an earlier query.
            data = Table.read(f'{s}.fits')
            print(f"Loaded {len(data)} rows from {s}.fits")
            data['site'] = s
            all_results.append(data)
    results = vstack(all_results)
    print(f"Merged {len(results)} total rows from all sites")
elif FETCH_FROM_SERVER:
    results = fetch_site_data(SITE, service)
else:
    results = Table.read(f'{SITE}.fits')
    print(f"Loaded {len(results)} rows from {SITE}.fits")

Generated columns

In [None]:
def add_engineered_features(table):
    """Return a reordered copy of ``table`` with derived columns appended.

    ``diaSourceId`` is moved to the first column, then these columns are added:

    * ``flux_ext``            - apFlux / psfFlux
    * ``ellip_ext``           - source ellipticity minus PSF ellipticity, each
                                computed as sqrt((ixx-iyy)^2 + 4 ixy^2) / (ixx+iyy)
    * ``i_ext``               - (ixx + iyy) / (ixxPSF + iyyPSF)
    * ``template_flux``       - scienceFlux - psfFlux
    * ``temp_sci_flux_ratio`` - template_flux / scienceFlux
    * ``psf_fwhm``            - PSF moment determinant ** (1/4) * 2.35482 * 5

    The input must contain ``diaSourceId`` (``ValueError`` otherwise) and the
    flux and moment columns named above.
    """
    # Indexing with a list of names builds a new table in that column order.
    column_order = list(table.colnames)
    column_order.remove('diaSourceId')
    out = table[['diaSourceId', *column_order]]

    ixx, iyy, ixy = out['ixx'], out['iyy'], out['ixy']
    ixx_psf, iyy_psf, ixy_psf = out['ixxPSF'], out['iyyPSF'], out['ixyPSF']
    source_trace = ixx + iyy
    psf_trace = ixx_psf + iyy_psf

    # Engineered features
    out['flux_ext'] = out['apFlux'] / out['psfFlux']

    source_ellip = np.sqrt((ixx - iyy)**2 + 4 * ixy**2) / source_trace
    psf_ellip = np.sqrt((ixx_psf - iyy_psf)**2 + 4 * ixy_psf**2) / psf_trace
    out['ellip_ext'] = source_ellip - psf_ellip

    out['i_ext'] = source_trace / psf_trace

    out['template_flux'] = out['scienceFlux'] - out['psfFlux']
    out['temp_sci_flux_ratio'] = out['template_flux'] / out['scienceFlux']

    # For FWHM circle on plot (converted to pixels).
    # 2.35482 is the Gaussian sigma-to-FWHM factor; the trailing 5 is presumably
    # pixels per arcsecond for this camera - TODO confirm.
    out['psf_fwhm'] = (ixx_psf * iyy_psf - ixy_psf**2)**(1/4) * 2.35482 * 5

    return out

# Rebind `results` to the copy that has diaSourceId first and the derived
# columns (flux_ext, ellip_ext, i_ext, template_flux, ...) appended.
results = add_engineered_features(results)

DP1 data summary plots

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import Rectangle

# When True, a zoomed inset is drawn inside the sources-per-visit histogram.
SHOW_INSET = False

# Hex colour used for each band's bars and histogram segments.
band_colors = {
    'u': '#0c71ff',
    'g': '#49be61',
    'r': '#c61c00',
    'i': '#ffc200',
    'z': '#f341a2',
    'y': '#5d0000'
}

# Left-to-right order in which bands are displayed; bands missing from the data
# are dropped when the plots are built.
BAND_ORDER = ['u', 'g', 'r', 'i', 'z', 'y']

# Per-band summary statistics, gathered in a single pass over the bands:
#   unique_visits_per_band[band]     -> number of distinct visits
#   sources_per_band[band]           -> number of DIA sources
#   sources_per_visit_by_band[band]  -> array of source counts, one per visit
unique_visits_per_band = {}
sources_per_band = {}
sources_per_visit_by_band = {}

for band_name in np.unique(results['band']):
    band_mask = results['band'] == band_name
    band_data = results[band_mask]

    # One np.unique call returns the distinct visits and how many sources each
    # has, replacing a boolean mask over the whole band for every visit.
    unique_visits, sources_per_visit = np.unique(band_data['visit'], return_counts=True)

    unique_visits_per_band[band_name] = len(unique_visits)
    sources_per_band[band_name] = np.sum(band_mask)
    sources_per_visit_by_band[band_name] = np.asarray(sources_per_visit)

total_visits = sum(unique_visits_per_band.values())
total_sources = sum(sources_per_band.values())

# Summary figure: two bar charts on the top row (visits per band, sources per
# band) and a stacked histogram of sources per visit spanning the bottom row.
fig = plt.figure(figsize=(14, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

fig.suptitle(f'DP1 Information, {get_title()}', fontsize=16, fontweight='bold')

# Only plot bands that actually occur in the data, in canonical display order.
bands = [b for b in BAND_ORDER if b in unique_visits_per_band]
colors = [band_colors[b] for b in bands]

# --- Top-left: number of visits per band -------------------------------------
ax1 = fig.add_subplot(gs[0, 0])
visits = [unique_visits_per_band[b] for b in bands]
bars1 = ax1.bar(bands, visits, color=colors, alpha=0.7, edgecolor='black')
ax1.set_xlabel('Band', fontsize=12)
ax1.set_ylabel('Number of Visits', fontsize=12)
ax1.set_title('Number of Visits per Band', fontsize=13, fontweight='bold')
# NOTE(review): fixed upper limit; bars taller than 250 would be clipped.
ax1.set_ylim(0, 250)
ax1.grid(axis='y', alpha=0.3)

# Write each bar's value just above it.
for bar in bars1:
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
            f'{int(height)}', ha='center', va='bottom', fontsize=10)

# Grand total in the upper-right corner (axes-fraction coordinates).
ax1.text(0.98, 0.98, f'Total: {total_visits}', transform=ax1.transAxes,
        fontsize=11, verticalalignment='top', horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.7))

# --- Top-right: number of DIA sources per band --------------------------------
ax2 = fig.add_subplot(gs[0, 1])
sources = [sources_per_band[b] for b in bands]
bars2 = ax2.bar(bands, sources, color=colors, alpha=0.7, edgecolor='black')
ax2.set_xlabel('Band', fontsize=12)
ax2.set_ylabel('Number of DIA Sources', fontsize=12)
ax2.set_title('DIA Sources per Band', fontsize=13, fontweight='bold')
# NOTE(review): fixed upper limit; bars taller than 300000 would be clipped.
ax2.set_ylim(0, 300000)
ax2.grid(axis='y', alpha=0.3)

for bar in bars2:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
            f'{int(height)}', ha='center', va='bottom', fontsize=10)

ax2.text(0.98, 0.98, f'Total: {total_sources}', transform=ax2.transAxes,
        fontsize=11, verticalalignment='top', horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.7))

# --- Bottom row: distribution of sources per visit, stacked by band -----------
ax3 = fig.add_subplot(gs[1, :])
all_data = [sources_per_visit_by_band[b] for b in bands]

# Mean and standard deviation of sources-per-visit for each band (text box below).
stats_text = []
for band in bands:
    data = sources_per_visit_by_band[band]
    mean_val = np.mean(data)
    std_val = np.std(data)
    stats_text.append(f'{band}: μ={mean_val:.1f}, σ={std_val:.1f}')

# Fixed-width bins from 0 up to (at least) the largest per-visit count.
bin_width = 100
all_combined = np.concatenate(all_data)
bins = np.arange(0, np.max(all_combined) + bin_width, bin_width)

# `bins` is rebound to the edges returned by hist (reused by the inset below).
n, bins, patches = ax3.hist(all_data, bins=bins, label=bands, 
                             color=colors, alpha=0.7, 
                             edgecolor='black', stacked=True)

ax3.set_xlabel('DIA Sources per Visit', fontsize=12)
ax3.set_ylabel('Frequency', fontsize=12)
ax3.set_title('Distribution of DIA Sources per Visit', 
              fontsize=13, fontweight='bold')
# NOTE(review): fixed view limits; visits with more than 5000 sources, or bins
# with more than 250 visits, fall outside the visible range.
ax3.set_xlim(0, 5000)
ax3.set_ylim(0, 250)
ax3.legend(loc='upper right', fontsize=10)
ax3.grid(axis='y', alpha=0.3)

stats_str = '\n'.join(stats_text)
ax3.text(0.98, 0.4, stats_str, transform=ax3.transAxes,
        fontsize=10, verticalalignment='center', horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Optional zoom on the 500-5000 sources/visit tail of the same histogram.
if SHOW_INSET:
    axins = inset_axes(ax3, width="35%", height="45%", loc='upper center',
                       bbox_to_anchor=(0, 0, 1, 1),
                       bbox_transform=ax3.transAxes)
    
    axins.hist(all_data, bins=bins, label=bands, 
               color=colors, alpha=0.7, 
               edgecolor='black', stacked=True)
    
    axins.set_xlim(500, 5000)
    axins.set_ylim(0, 50)
    
    axins.set_xlabel('Sources/Visit', fontsize=9)
    axins.set_ylabel('Frequency', fontsize=9)
    axins.tick_params(labelsize=8)
    axins.grid(axis='y', alpha=0.3)
    
    # Dashed red outline on the main axes marking the region shown in the inset.
    rect = Rectangle((500, 0), 4500, 50, linewidth=1.5, 
                    edgecolor='red', facecolor='none', linestyle='--')
    ax3.add_patch(rect)

plt.show()

## 2. Data filtering & additional calculations

Dictionary of all thresholds

In [None]:
# Cut values used by the filtering cells below.  Every scalar entry is a lower
# bound (rows are kept when the column is strictly greater than the value);
# 'flux_caps' holds per-band upper bounds on template_flux.
thresholds = {
    'snr': 5,                       # quality cut on the 'snr' column
    'flux_ext': 0.35,               # extendedness cut on apFlux / psfFlux
    'i_ext': 0.5,                   # extendedness cut on the source/PSF moment ratio
    'ellip_ext': 0.2,               # extendedness cut on the ellipticity difference
    'temp_sci_flux_ratio': 0.85,    # "moving object" cut on template_flux / scienceFlux
    # Keep rows whose template_flux is below the cap for their band.
    # NOTE(review): the origin and units of these values are not visible here -
    # presumably the same flux units as scienceFlux; confirm.
    'flux_caps': {
        'u': 88644.7,
        'g': 118074.2,
        'r': 166872.5,
        'i': 203090.9,
        'z': 257254.0,
        'y': 264794.0,
    }
}

Quality filters

In [None]:
# Stage 1: signal-to-noise cut.
filtered = results.copy()
sel_snr = filtered['snr'] > thresholds['snr']
filtered = filtered[sel_snr]

# Stage 2: drop any source carrying one of these measurement or pixel flags.
sel_no_flag = ~filtered['apFlux_flag']
for flag_name in (
    'psfFlux_flag',
    'pixelFlags_cr',
    'pixelFlags_bad',
    'pixelFlags_nodata',
    'pixelFlags_interpolated',
    'pixelFlags_saturated',
    'pixelFlags_suspect',
):
    sel_no_flag &= ~filtered[flag_name]
filtered = filtered[sel_no_flag]

# Stage 3: PSF, aperture and science fluxes must all share the same
# (non-zero) sign.
all_positive = (
    (filtered['psfFlux'] > 0)
    & (filtered['apFlux'] > 0)
    & (filtered['scienceFlux'] > 0)
)
all_negative = (
    (filtered['psfFlux'] < 0)
    & (filtered['apFlux'] < 0)
    & (filtered['scienceFlux'] < 0)
)
same_sign = all_positive | all_negative
filtered = filtered[same_sign]

print(len(filtered))

Extendedness (flux, moment, and ellipticity), moving-object, and flux-threshold filters

In [None]:
# Start from the quality-filtered table and apply three successive cuts.
results_s = filtered

## extendedness filter: all three extendedness metrics must exceed their threshold
ext_filter = (
    (results_s['flux_ext'] > thresholds['flux_ext'])
    & (results_s['i_ext'] > thresholds['i_ext'])
    & (results_s['ellip_ext'] > thresholds['ellip_ext'])
)
results_s = results_s[ext_filter]

## moving object filter: template flux must be a large fraction of the science flux
moving_obj_filter = (
    results_s['temp_sci_flux_ratio'] > thresholds['temp_sci_flux_ratio']
)
results_s = results_s[moving_obj_filter]

## flux threshold filter: template flux must be below the cap for the row's band
flux_caps_by_band = thresholds['flux_caps']
threshold_array = np.array(
    [flux_caps_by_band[band] for band in results_s['band']]
)
flux_cap = results_s['template_flux'] < threshold_array
results_s = results_s[flux_cap]

print(len(results_s))

Plot: ellipticity difference metric vs. two calculated extendedness metrics

In [None]:
# Moment extendedness against log flux extendedness for the quality-filtered
# sources, colour-coded by ellipticity difference.
fig, ax = plt.subplots(figsize=(7, 5))

x_data = np.log10(filtered['flux_ext'])
y_data = filtered['i_ext']
# NOTE: despite its name, snr_data holds the ellipticity difference (colour axis).
snr_data = filtered['ellip_ext']

scatter = ax.scatter(
    x_data, y_data,
    c=snr_data, cmap='viridis_r', vmin=0, vmax=1,
    s=1, alpha=0.5, label=None,
)

ax.set_title(f'Ellipticity Difference, {get_title()}', fontsize=14)
ax.set_xlabel('log10(apFlux/psfFlux)', fontsize=12)
ax.set_ylabel('(ixx+iyy) / (ixxPSF+iyyPSF)', fontsize=12)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 5)
ax.grid(True, alpha=0.3)

cbar = fig.colorbar(scatter, ax=ax, label='ellipticity')

fig.tight_layout()
plt.show()

Plot: sources that pass extendedness filter

In [None]:
# Same plane as the previous plot: quality-filtered sources in blue, with the
# sources that survive every cut (results_s) over-plotted in red.
fig, ax = plt.subplots(figsize=(6, 5))

x_data1 = np.log10(filtered['flux_ext'])
y_data1 = filtered['i_ext']

x_data2 = np.log10(results_s['flux_ext'])
y_data2 = results_s['i_ext']

# Drawn in this order so the red points sit on top of the blue ones.
ax.scatter(x_data1, y_data1,
           c='blue', s=1, alpha=0.5, label='All sources after filter')
ax.scatter(x_data2, y_data2,
           c='red', s=2, alpha=1, label='Extended sources')

ax.set_title(f'Extendedness filter, {get_title()}', fontsize=14)
ax.set_xlabel('log10(apFlux/psfFlux)', fontsize=12)
ax.set_ylabel('(ixx+iyy) / (ixxPSF+iyyPSF)', fontsize=12)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 5)
ax.grid(True, alpha=0.3)
# markerscale enlarges the tiny scatter markers in the legend only.
ax.legend(markerscale=5)

fig.tight_layout()
plt.show()

Corner plot: extendedness metrics

In [None]:
# Corner plot of the three engineered extendedness metrics plus the catalogue's
# own 'extendedness' column, for the quality-filtered sources.
# The four lists below are parallel: one entry per plotted column.
columns_corner = ['flux_ext', 'ellip_ext', 'i_ext', 'extendedness']
axes_scale = ['log', 'linear', 'linear', 'linear']
ranges = [(0.1, 10.0), (0, 1), (0, 4), (0, 1)]
data_array = [filtered[col] for col in columns_corner]
labels = [
    'Flux Ext.',
    'Ellip. Diff.',
    'Moment Ext.',
    'Extendedness'
]

# data_array is (n_columns, n_sources); transpose so each row is one source.
fig = corner.corner(np.array(data_array).T, 
                    labels=labels,
                    axes_scale=axes_scale,
                    range=ranges,
                    fill_contours=True, 
                    smooth=0.7, 
                    show_titles=False, 
                    color='grey',
                    plot_datapoints=True,
                    plot_contours=True,
                    plot_density=True,
                    bins=15)

# Overlay a faint scatter of the raw points on every lower-triangle panel
# (row i, column j with j < i).
axes = np.array(fig.axes).reshape((len(columns_corner), len(columns_corner)))
for i in range(len(columns_corner)):
    for j in range(i):
        ax = axes[i, j]
        ax.scatter(data_array[j], data_array[i], s=2, alpha=0.03, color='black', rasterized=True)

fig.suptitle(f'Extendedness Parameter Correlations, {get_title()}', fontsize=18, fontweight='bold', y=0.99)

# Text box moved up and left with padding
# NOTE(review): the box describes Flux Ext. as log(ap/psf), but the plotted
# column is the plain ratio shown on a log-scaled axis - confirm the wording.
equation_text = (
    r'$\mathrm{Flux\ Ext.} = \log(\mathrm{Aperture\ flux} / \mathrm{PSF\ flux})$' + '\n\n' +
    r'$\mathrm{Moment\ Ext.} = \frac{I_{xx} + I_{yy}}{I_{xx}^{\mathrm{PSF}} + I_{yy}^{\mathrm{PSF}}}$' + '\n\n' +
    r'$\mathrm{Ellip.\ Diff.} = \frac{\sqrt{(I_{xx}-I_{yy})^2+4I_{xy}^2}}{I_{xx}+I_{yy}} - '
    r'\frac{\sqrt{(I_{xx}^{\mathrm{PSF}}-I_{yy}^{\mathrm{PSF}})^2+4I_{xy}^{\mathrm{PSF}\ 2}}}'
    r'{I_{xx}^{\mathrm{PSF}}+I_{yy}^{\mathrm{PSF}}}$'
)

# Placed in figure coordinates, in the empty upper-right half of the corner grid.
fig.text(0.55, 0.82, equation_text, fontsize=13,
         bbox=dict(boxstyle='round,pad=0.8', facecolor='white', edgecolor='black', alpha=0.9),
         verticalalignment='top')

plt.show()

View sources that pass all filters

In [None]:
# Bare expression: as the last statement of a notebook cell this displays the
# table of sources that passed every filter.
results_s

## 3. Gallery view of sources passing all filters

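This produces a gallery of science, template, and difference image cutouts for a random sample of DIA sources from the `results_s` table. Choose the number of rows and columns for the gallery. It also pulls a Legacy Survey cutout for each source shown in the gallery.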

In [None]:
# ==============================================================================
# CONFIGURATION
# The gallery samples LAYOUT_COLS * LAYOUT_ROWS sources; each grid cell holds
# three panels (science, warped template, difference).
# ==============================================================================
LAYOUT_COLS = 2  # Number of columns
LAYOUT_ROWS = 5  # Number of rows
# ==============================================================================

from matplotlib.patches import Circle
from astropy.visualization import ZScaleInterval
import requests
from io import BytesIO
from PIL import Image

def get_cutout_with_retry(dl_result, spherePoint, session, fov, max_retries=3):
    """Download a circular image cutout, retrying rate-limited requests.

    Parameters
    ----------
    dl_result
        Datalink result exposing the "cutout-sync-exposure" ad-hoc service.
    spherePoint
        Cutout centre; must provide ``getRa()`` / ``getDec()`` returning
        angles with ``asDegrees()``.
    session
        Authenticated session handed to the SODA query.
    fov : float
        Cutout radius in degrees.
    max_retries : int, optional
        Total number of download attempts (not additional retries); must be
        at least 1.

    Returns
    -------
    ExposureF
        The cutout, built in memory from the downloaded bytes.

    Raises
    ------
    ValueError
        If ``max_retries`` is less than 1.  Without this check the loop would
        never run and the function would silently return ``None``.
    Exception
        Any download error whose message does not contain '429', or the last
        error once all attempts are used up, is re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    sq = SodaQuery.from_resource(dl_result,
                                 dl_result.get_adhocservice_by_id("cutout-sync-exposure"),
                                 session=session)
    sphereRadius = fov * u.deg
    sq.circle = (spherePoint.getRa().asDegrees() * u.deg,
                 spherePoint.getDec().asDegrees() * u.deg,
                 sphereRadius)

    for attempt in range(max_retries):
        try:
            cutout_bytes = sq.execute_stream().read()
            sq.raise_if_error()
            mem = MemFileManager(len(cutout_bytes))
            mem.setData(cutout_bytes, len(cutout_bytes))
            return ExposureF(mem)
        except Exception as e:
            # Only HTTP 429 (rate limit) style failures are worth retrying, and
            # only while attempts remain; everything else propagates.
            is_rate_limited = '429' in str(e)
            if not is_rate_limited or attempt >= max_retries - 1:
                raise
            # Exponential backoff (1 s, 2 s, 4 s, ...) plus up to 1 s of jitter.
            wait_time = (2 ** attempt) + random.uniform(0, 1)
            time.sleep(wait_time)

def fetch_images_for_row(row, sia_service, get_pyvo_auth, fov=0.003):
    """Fetch science, template and difference cutouts for one DIA source.

    Parameters
    ----------
    row
        Catalogue row; must provide ra, dec, visit, band, diaSourceId, snr,
        extendedness, flux_ext, ellip_ext, i_ext, template_flux, scienceFlux,
        psfFlux, apFlux and psf_fwhm.
    sia_service
        Image service providing ``search(pos=..., calib_level=...)``.
    get_pyvo_auth : callable
        Called with no arguments each time an authenticated session is needed.
    fov : float, optional
        Cutout radius in degrees, forwarded to ``get_cutout_with_retry``.

    Returns
    -------
    tuple
        ``(result, None)`` on success, where ``result`` is a dict holding the
        catalogue values, the ``sci`` / ``warped_ref`` / ``diff`` exposures and
        per-stage timings in seconds; ``(None, message)`` on any failure.
        Exceptions are never propagated - they are reported via ``message``.
    """
    try:
        row_start = time.time()

        # Read every catalogue value up front, before any network request.
        catalog_keys = (
            'ra', 'dec', 'visit', 'band', 'diaSourceId', 'snr', 'extendedness',
            'flux_ext', 'ellip_ext', 'i_ext', 'template_flux', 'scienceFlux',
            'psfFlux', 'apFlux', 'psf_fwhm',
        )
        values = {key: row[key] for key in catalog_keys}
        ra, dec = values['ra'], values['dec']
        visit, band = values['visit'], values['band']

        spherePoint = geom.SpherePoint(ra*geom.degrees, dec*geom.degrees)
        circle = (ra, dec, 0.0001)

        # --- image search ----------------------------------------------------
        search_start = time.time()
        lvl2_table = sia_service.search(pos=circle, calib_level=2).to_table()
        is_science = (
            (lvl2_table['dataproduct_subtype'] == 'lsst.visit_image')
            & (lvl2_table['lsst_visit'] == visit)
        )
        sci_table = lvl2_table[is_science]
        if len(sci_table) == 0:
            return None, "No science images found"

        lvl3_table = sia_service.search(pos=circle, calib_level=3).to_table()
        search_time = time.time() - search_start

        # The template is matched on band, the difference image on visit.
        is_template = (
            (lvl3_table['dataproduct_subtype'] == 'lsst.template_coadd')
            & (lvl3_table['lsst_band'] == band)
        )
        ref_table = lvl3_table[is_template]
        if len(ref_table) == 0:
            return None, "No template images found"

        is_difference = (
            (lvl3_table['dataproduct_subtype'] == 'lsst.difference_image')
            & (lvl3_table['lsst_visit'] == visit)
        )
        diff_table = lvl3_table[is_difference]
        if len(diff_table) == 0:
            return None, "No difference images found"

        # --- datalink resolution (first match of each kind) -------------------
        datalink_start = time.time()
        dl_result_sci, dl_result_ref, dl_result_diff = (
            DatalinkResults.from_result_url(matches['access_url'][0],
                                            session=get_pyvo_auth())
            for matches in (sci_table, ref_table, diff_table)
        )
        datalink_time = time.time() - datalink_start

        # --- cutout downloads, spaced out to limit the request rate -----------
        download_start = time.time()
        sci = get_cutout_with_retry(dl_result_sci, spherePoint, get_pyvo_auth(), fov)
        time.sleep(0.3)
        ref = get_cutout_with_retry(dl_result_ref, spherePoint, get_pyvo_auth(), fov)
        time.sleep(0.3)
        diff = get_cutout_with_retry(dl_result_diff, spherePoint, get_pyvo_auth(), fov)
        download_time = time.time() - download_start

        # --- warp the template onto the science cutout's WCS and bounding box -
        warp_start = time.time()
        warper = Warper.fromConfig(WarperConfig())
        warped_ref = warper.warpExposure(sci.getWcs(), ref, destBBox=sci.getBBox())
        warp_time = time.time() - warp_start

        total_time = time.time() - row_start

        identity_keys = ('visit', 'band', 'diaSourceId', 'ra', 'dec')
        measurement_keys = (
            'snr', 'extendedness', 'flux_ext', 'ellip_ext', 'i_ext',
            'template_flux', 'scienceFlux', 'psfFlux', 'apFlux', 'psf_fwhm',
        )
        return {
            **{key: values[key] for key in identity_keys},
            'sci': sci,
            **{key: values[key] for key in measurement_keys},
            'warped_ref': warped_ref,
            'diff': diff,
            'search_time': search_time,
            'datalink_time': datalink_time,
            'download_time': download_time,
            'warp_time': warp_time,
            'total_time': total_time,
        }, None

    except Exception as e:
        # Best effort: one failed source must not abort the whole gallery.
        return None, str(e)

# Randomly pick (without replacement) as many sources as the layout can show,
# or all of them if fewer passed the filters.
n_images = LAYOUT_COLS * LAYOUT_ROWS
sample_indices = np.random.choice(len(results_s), size=min(n_images, len(results_s)), replace=False)
sampled_rows = [results_s[i] for i in sample_indices]

print(f"\n{'='*60}")
print(f"Starting gallery creation at {datetime.now().strftime('%H:%M:%S')}")
print(f"Layout: {LAYOUT_COLS} columns x {LAYOUT_ROWS} rows = {n_images} images")
print(f"{'='*60}\n")
start_time = time.time()

# Fetch image sets two at a time; results are collected in completion order,
# so the gallery order need not match the sampling order.
gallery_results = []
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(fetch_images_for_row, row, sia_service, get_pyvo_auth): idx 
               for idx, row in enumerate(sampled_rows)}
    
    for future in as_completed(futures):
        idx = futures[future]
        # fetch_images_for_row reports failures as (None, message) instead of raising.
        result, error = future.result()
        
        if result:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] ✓ {len(gallery_results)+1}/{n_images}: visit={result['visit']}, diaSourceId={result['diaSourceId']}, time={result['total_time']:.1f}s")
            gallery_results.append(result)
        else:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] ✗ Failed: {error}")
        
        if len(gallery_results) >= n_images:
            break

fetch_time = time.time() - start_time
print(f"\nAll images fetched in {fetch_time:.1f}s ({fetch_time/60:.1f} min)")
# Guard the average: if every fetch failed there is nothing to divide by, and
# the cell used to stop here with ZeroDivisionError.
if gallery_results:
    print(f"Average per image set: {fetch_time/len(gallery_results):.1f}s\n")
else:
    print("Average per image set: n/a (no image sets were retrieved)\n")

actual_n_images = len(gallery_results)
if actual_n_images < n_images:
    print(f"\n⚠️  Warning: Only {actual_n_images} images available, but layout is {LAYOUT_COLS}x{LAYOUT_ROWS}={n_images}")
    print(f"Remaining {n_images - actual_n_images} cells will be left empty/white\n")

# Draw the gallery: one grid cell per source, each split into three panels
# (science | warped template | difference) with annotation boxes.
print("Creating plots...")
plot_start = time.time()

# Each cell is three square-ish panels wide, hence 9 x 3 inches per cell.
fig = plt.figure(figsize=(LAYOUT_COLS * 9, LAYOUT_ROWS * 3))
gs = GridSpec(LAYOUT_ROWS, LAYOUT_COLS, figure=fig, 
              hspace=0.02, wspace=0.02,
              left=0.01, right=0.99, top=0.99, bottom=0.01)

for idx in range(n_images):
    # Fill the grid row by row.
    gallery_row = idx // LAYOUT_COLS
    gallery_col = idx % LAYOUT_COLS
    
    gs_sub = gs[gallery_row, gallery_col].subgridspec(1, 3, wspace=0.01)
    
    ax1 = fig.add_subplot(gs_sub[0, 0])
    ax2 = fig.add_subplot(gs_sub[0, 1])
    ax3 = fig.add_subplot(gs_sub[0, 2])
    
    if idx < len(gallery_results):
        result = gallery_results[idx]
        
        # Each panel gets its own zscale stretch.
        for ax, img in [(ax1, result['sci']), (ax2, result['warped_ref']), (ax3, result['diff'])]:
            interval = ZScaleInterval()
            vmin, vmax = interval.get_limits(img.image.array)
            
            ax.imshow(img.image.array, cmap='gray', origin='lower', 
                      vmin=vmin, vmax=vmax, aspect='equal', interpolation='nearest')
            ax.set_axis_off()
            # NOTE(review): re-applies the axes' current position; presumably
            # meant to pin the layout - confirm it is needed.
            ax.set_position(ax.get_position())
            ax.margins(0, 0)
            ax.set_xlim(0, img.image.array.shape[1])
            ax.set_ylim(0, img.image.array.shape[0])
        
        # Science panel: visit, SNR and science flux (top-left), band (bottom-left).
        ax1.text(0.02, 0.98, f'Visit: {result["visit"]}\nSNR: {result["snr"]:.3f}\nSci Flux: {result["scienceFlux"]:.1f}', 
                 transform=ax1.transAxes, ha='left', va='top',
                 fontsize=10, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
        
        ax1.text(0.02, 0.02, f'{result["band"]}', 
                 transform=ax1.transAxes, ha='left', va='bottom',
                 fontsize=16, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

        # Template panel: coordinates (top-left), template flux and its share of
        # the science flux (bottom-left).
        ax2.text(0.02, 0.98, f'RA: {result["ra"]}\nDEC: {result["dec"]}', 
                 transform=ax2.transAxes, ha='left', va='top',
                 fontsize=10, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

        ax2.text(0.02, 0.02, f'Template Flux: {result["template_flux"]:.1f} ({(result["template_flux"]/result["scienceFlux"]*100):.0f}% sci flux)', 
                 transform=ax2.transAxes, ha='left', va='bottom',
                 fontsize=10, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
        
        # Difference panel: source id and fluxes (top-right), extendedness
        # metrics (bottom-right).
        ax3.text(0.98, 0.98, f'DIASourceID: {result["diaSourceId"]}\nPSF Flux: {result["psfFlux"]:.1f} | Ap Flux: {result["apFlux"]:.1f}', 
                 transform=ax3.transAxes, ha='right', va='top',
                 fontsize=10, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

        ax3.text(0.98, 0.02, f'Ext.: {result["extendedness"]:.3f}\nLog Flux Ext.: {np.log10(result["flux_ext"]):.3f}\nEllip. Diff.: {result["ellip_ext"]:.3f}\nMoment Ext.: {result["i_ext"]:.3f}', 
                 transform=ax3.transAxes, ha='right', va='bottom',
                 fontsize=10, color='white',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
        
        # PSF size marker: a circle whose diameter is psf_fwhm, placed 1.5 FWHM
        # in from the lower-left corner of the difference panel (data coordinates).
        psf_fwhm_pixels = result['psf_fwhm']
        circle_x = psf_fwhm_pixels * 1.5
        circle_y = psf_fwhm_pixels * 1.5
        
        circle = Circle((circle_x, circle_y), psf_fwhm_pixels / 2, 
                       fill=False, edgecolor='cyan', linewidth=2, alpha=0.8)
        ax3.add_patch(circle)
        
        # Crosshair spanning the circle's diameter.
        cross_length = psf_fwhm_pixels / 2
        ax3.plot([circle_x - cross_length, circle_x + cross_length], 
                [circle_y, circle_y], 
                color='cyan', linewidth=2, alpha=0.8)
        ax3.plot([circle_x, circle_x], 
                [circle_y - cross_length, circle_y + cross_length], 
                color='cyan', linewidth=2, alpha=0.8)
        
    else:
        # No image set for this cell: leave the three panels blank.
        for ax in [ax1, ax2, ax3]:
            ax.set_axis_off()

plot_time = time.time() - plot_start
print(f"Plotting complete in {plot_time:.1f}s")

# Total covers fetching and plotting (measured from start_time, set before the fetch).
total_time = time.time() - start_time
print(f"\n{'='*60}")
print(f"Complete at {datetime.now().strftime('%H:%M:%S')}")
print(f"Total time: {total_time:.1f}s ({total_time/60:.1f} minutes)")
print(f"Created {len(gallery_results)} image sets")
print(f"{'='*60}\n")

plt.show()

# ==============================================================================
# LEGACY SURVEY IMAGE ARRAY
# A second gallery with one Legacy Survey cutout per source in gallery_results,
# laid out on the same LAYOUT_ROWS x LAYOUT_COLS grid.
# ==============================================================================
print("\n" + "="*60)
print("Creating Legacy Survey cutout gallery...")
print("="*60 + "\n")

def fetch_legacy_survey_cutout(ra, dec, pixscale=0.15, timeout=10, max_retries=3):
    """Download a Legacy Survey (ls-dr9) JPEG cutout as an image array.

    Parameters
    ----------
    ra, dec
        Cutout centre, inserted verbatim into the request URL.
    pixscale : float, optional
        Value of the ``pixscale`` URL parameter.
    timeout : float, optional
        Per-request timeout in seconds.
    max_retries : int, optional
        Total number of attempts.

    Returns
    -------
    tuple
        ``(array, None)`` on success; ``(None, message)`` once every attempt
        has failed, or ``(None, "Max retries exceeded")`` if no attempt was
        made at all (``max_retries`` < 1).  Never raises for request errors.
    """
    url = f"https://www.legacysurvey.org/viewer/cutout.jpg?ra={ra}&dec={dec}&layer=ls-dr9&pixscale={pixscale}"
    final_attempt = max_retries - 1

    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            return np.array(img), None
        except Exception as e:
            if attempt >= final_attempt:
                # Out of attempts: report the last error to the caller.
                return None, str(e)
            # Linearly growing pause (1 s, 2 s, ...) before the next attempt.
            wait_time = 1.0 + attempt
            print(f"    ⚠ Retry {attempt + 1}/{max_retries - 1} after {wait_time:.1f}s...")
            time.sleep(wait_time)

    return None, "Max retries exceeded"

print("Downloading Legacy Survey cutouts...")
legacy_start = time.time()

# legacy_images stays index-aligned with gallery_results; a failed download
# leaves None in its slot.
legacy_images = []
n_gallery = len(gallery_results)

for idx, result in enumerate(gallery_results):
    img, error = fetch_legacy_survey_cutout(result['ra'], result['dec'])
    legacy_images.append(img)
    if img is None:
        print(f"✗ {idx+1}/{n_gallery}: Failed to download - {error}")
    else:
        print(f"✓ {idx+1}/{n_gallery}: Downloaded cutout for DIASourceID {result['diaSourceId']}")
    # Short pause between requests.
    time.sleep(0.1)

legacy_fetch_time = time.time() - legacy_start
print(f"\nLegacy Survey cutouts downloaded in {legacy_fetch_time:.1f}s\n")

# One square panel per grid cell, in the same order as the DIA gallery.
fig_legacy = plt.figure(figsize=(LAYOUT_COLS * 3, LAYOUT_ROWS * 3))
gs_legacy = GridSpec(LAYOUT_ROWS, LAYOUT_COLS, figure=fig_legacy,
                     hspace=0.02, wspace=0.02,
                     left=0.01, right=0.99, top=0.99, bottom=0.01)

for idx in range(n_images):
    gallery_row, gallery_col = divmod(idx, LAYOUT_COLS)

    ax = fig_legacy.add_subplot(gs_legacy[gallery_row, gallery_col])
    # Every panel is frameless, whether or not a cutout is available for it.
    ax.set_axis_off()

    has_cutout = idx < len(legacy_images) and legacy_images[idx] is not None
    if has_cutout:
        ax.imshow(legacy_images[idx], origin='upper', aspect='equal', interpolation='nearest')
        ax.margins(0, 0)

# Plotting time is whatever has elapsed since the downloads finished.
legacy_plot_time = time.time() - legacy_start - legacy_fetch_time
print(f"Legacy Survey plotting complete in {legacy_plot_time:.1f}s")

total_legacy_time = time.time() - legacy_start
print(f"Total Legacy Survey gallery time: {total_legacy_time:.1f}s\n")

plt.show()

print("="*60)
print("Both galleries complete")
print("="*60)

## 4. Injected Source Analysis

Pull Shenming's injected sources & identify LAGN DIASources from injection coordinates

In [None]:
def _merge_tables(files, label, **read_kwargs):
    """List ``files``, read each with ``Table.read`` and return them stacked."""
    print(f"Merging {label} files:")
    for f in files:
        print(f"  {f}")
    merged = vstack([Table.read(f, **read_kwargs) for f in files])
    print(f"Total {label} rows: {len(merged)}\n")
    return merged


# Injection coordinates.
inj_radec_files = sorted(glob('/home/sfu/shared_lagn_injection/v0.4/inj_radec*.fits'))
inj_radec = _merge_tables(inj_radec_files, 'inj_radec')

# DIA sources from the run without injection.
no_inj_sources_files = sorted(glob('/home/sfu/shared_lagn_injection/v0.4/diaSources_tab_*.csv'))
no_inj_sources = _merge_tables(no_inj_sources_files, 'no_inj_sources', format='csv')

# DIA sources from the run with injection.
inj_sources_files = sorted(glob('/home/sfu/shared_lagn_injection/v0.4/injected_diaSources_tab_*.csv'))
inj_sources = _merge_tables(inj_sources_files, 'inj_sources', format='csv')

def lagn_lookup(sources, lookup, max_sep=0.002):
    """Flag sources that lie within ``max_sep`` degrees of a lookup position.

    Each source is matched to its nearest neighbour in ``lookup`` on the sky
    and flagged when that neighbour is closer than ``max_sep``.

    Parameters
    ----------
    sources : table-like
        Must provide ``'ra'`` and ``'dec'`` entries, interpreted as degrees.
    lookup : table-like
        Reference positions; must provide ``'ra'`` and ``'dec'`` entries,
        interpreted as degrees.
    max_sep : float, optional
        Maximum on-sky separation, in degrees, for a source to be flagged.

    Returns
    -------
    numpy.ndarray of bool
        One flag per source; ``True`` where the nearest lookup position is
        closer than ``max_sep``.
    """
    n_sources = len(sources['ra'])
    # With nothing to match (or nothing to match against) no source can be
    # flagged; answer directly rather than attempting a nearest-neighbour
    # search on an empty catalog.
    if n_sources == 0 or len(lookup['ra']) == 0:
        return np.zeros(n_sources, dtype=bool)

    src_coords = SkyCoord(ra=sources['ra'], dec=sources['dec'], unit='deg')
    inj_coords = SkyCoord(ra=lookup['ra'], dec=lookup['dec'], unit='deg')
    _, sep2d, _ = src_coords.match_to_catalog_sky(inj_coords)
    return sep2d.deg < max_sep

# Mark every DIASource from the injected run that matches an injection
# coordinate (default matching radius of lagn_lookup), then append the
# derived feature columns.
inj_sources['lagn'] = lagn_lookup(sources=inj_sources, lookup=inj_radec)
inj_sources = add_engineered_features(inj_sources)

Sky positions of the DIASources detected with and without injection, overplotted with the injection coordinates.

In [None]:
plt.figure(figsize=(10, 8), dpi=300)
ax = plt.gca()

# Overlay the two detection catalogs and the injection positions on the sky.
ax.scatter(inj_sources['ra'], inj_sources['dec'],
           c='red', alpha=0.3, label='Sources, injection')
ax.scatter(no_inj_sources['ra'], no_inj_sources['dec'],
           c='blue', alpha=0.3, label='Sources, no injection')
ax.scatter(inj_radec['ra'], inj_radec['dec'],
           c='green', alpha=0.3, label='Injection coords.')

ax.set_xlabel('RA (degrees)')
ax.set_ylabel('Dec (degrees)')
ax.legend()
ax.grid(True, alpha=0.3)

# Record the final axis limits in module-level names; they are not used
# again within this cell.
xlim, ylim = ax.get_xlim(), ax.get_ylim()

plt.show()

Corner plot of the DP1 DIASources (gray contours) and the injected LAGN (green contours); dashed blue lines and shading mark the selection thresholds.

In [None]:
# Per-column configuration for the corner plot. The four lists below are
# parallel: entry k of each one describes the same column.
columns_corner = [
    'flux_ext',
    'ellip_ext',
    'i_ext',
    'extendedness',
    'psfChi2',
    'trailLength',
    'trailFlux',
    'snr',
    'temp_sci_flux_ratio',
]
axes_scale = ['log', 'linear', 'linear', 'linear', 'log', 'linear', 'log', 'log', 'linear']
ranges = [(0.1, 10.0), (0, 1), (0, 4), (0, 1), (100, 10000), (0, 4), (1000, 100000), (1, 100), (0, 1)]
labels = [
    'Flux Ext.',
    'Ellip. Diff.',
    'Moment Ext.',
    'Extendedness',
    'PSF χ²',
    'Trail Length',
    'Trail Flux',
    'SNR',
    'Temp./Sci. Flux'
]

data_array_all = [results[col] for col in columns_corner]

# Select the LAGN rows once, then pull each column from that subset, rather
# than re-filtering the full table for every column.
lagn_sources = inj_sources[inj_sources['lagn']]
data_array_lagn = [lagn_sources[col] for col in columns_corner]

def _corner_layer(samples, color, fig):
    """Draw one smoothed, filled-contour corner layer of ``samples`` on ``fig``.

    ``samples`` is a list of per-column arrays; it is transposed so that each
    row passed to ``corner.corner`` is one source. The keyword dictionaries
    are built fresh on every call so no state is shared between layers.
    """
    return corner.corner(
        np.array(samples).T,
        labels=labels,
        axes_scale=axes_scale,
        range=ranges,
        fill_contours=True,
        smooth=0.7,
        show_titles=False,
        color=color,
        plot_datapoints=False,
        plot_contours=True,
        plot_density=True,
        bins=20,
        fig=fig,
        label_kwargs=dict(fontsize=12),
        max_n_ticks=3,
        hist_kwargs=dict(density=True),
    )


# Base layer: all DIASources in grey on a new 15x15 figure.
fig = _corner_layer(data_array_all, 'grey', plt.figure(figsize=(15, 15)))

# Second layer: injected LAGN in green, drawn onto the same figure.
fig = _corner_layer(data_array_lagn, 'green', fig)

# View the figure's axes as a square (row, column) grid with one row and one
# column per plotted quantity.
n_corner = len(columns_corner)
axes = np.array(fig.axes).reshape((n_corner, n_corner))

# Compact, unrotated tick labels on every panel.
for ax in fig.axes:
    ax.tick_params(axis='both', which='major', labelsize=9, pad=2)
    plt.setp(ax.get_xticklabels(), rotation=0)

# Columns of the corner plot that have a selection threshold, mapped to
# their position in columns_corner.
col_indices = {col: columns_corner.index(col) for col in thresholds if col in columns_corner}

# Walk the lower triangle of the corner grid (diagonal included). In each
# panel the x quantity is columns_corner[j] and the y quantity is
# columns_corner[i]. Thresholds are drawn as dashed lines, and the region
# from the threshold up to the upper plotting range is shaded.
for i, col_y in enumerate(columns_corner):
    for j, col_x in enumerate(columns_corner[:i + 1]):
        ax = axes[i, j]

        # Vertical marker for the x quantity; applies to the 2-D panels and
        # to the 1-D histograms on the diagonal alike.
        if col_x in col_indices:
            threshold_x = thresholds[col_x]
            ax.axvline(threshold_x, color='blue', linestyle='--', linewidth=1.5, alpha=0.6, zorder=5)
            ax.axvspan(threshold_x, ranges[j][1], alpha=0.1, color='blue', zorder=1)

        # Horizontal marker for the y quantity; skipped on the diagonal,
        # where the y axis is the histogram height rather than a data column.
        if i != j and col_y in col_indices:
            threshold_y = thresholds[col_y]
            ax.axhline(threshold_y, color='blue', linestyle='--', linewidth=1.5, alpha=0.6, zorder=5)
            ax.axhspan(threshold_y, ranges[i][1], alpha=0.1, color='blue', zorder=1)

# Definitions of the three engineered shape features, one mathtext line
# each, shown in a boxed annotation on the figure.
_flux_ext_eq = r'$\mathrm{Flux\ Ext.} = \log(\mathrm{Aperture\ flux} / \mathrm{PSF\ flux})$'
_moment_ext_eq = (
    r'$\mathrm{Moment\ Ext.} = \frac{I_{xx} + I_{yy}}{I_{xx}^{\mathrm{PSF}} + I_{yy}^{\mathrm{PSF}}}$'
)
_ellip_diff_eq = (
    r'$\mathrm{Ellip.\ Diff.} = \frac{\sqrt{(I_{xx}-I_{yy})^2+4I_{xy}^2}}{I_{xx}+I_{yy}} - '
    r'\frac{\sqrt{(I_{xx}^{\mathrm{PSF}}-I_{yy}^{\mathrm{PSF}})^2+4I_{xy}^{\mathrm{PSF}\ 2}}}'
    r'{I_{xx}^{\mathrm{PSF}}+I_{yy}^{\mathrm{PSF}}}$'
)

# A blank line separates consecutive equations.
equation_text = '\n\n'.join([_flux_ext_eq, _moment_ext_eq, _ellip_diff_eq])

fig.text(
    0.55,
    0.82,
    equation_text,
    fontsize=13,
    bbox=dict(boxstyle='round,pad=0.8', facecolor='white', edgecolor='black', alpha=0.9),
    verticalalignment='top',
)

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
# Proxy artists for the legend: one filled patch per contour population and
# a dashed line for the threshold markers.
legend_elements = [
    Patch(facecolor=face, edgecolor=edge, label=text)
    for face, edge, text in (
        ('grey', 'black', 'All DP1 DIASources'),
        ('green', 'darkgreen', 'Injected LAGN'),
    )
]
legend_elements.append(
    Line2D([0], [0], color='blue', linestyle='--', linewidth=1.5, alpha=0.6,
           label='Selection Thresholds')
)

fig.legend(handles=legend_elements, loc='upper right', fontsize=15, framealpha=0.9)

plt.show()