In [None]:
import yaml
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
import galsim
%matplotlib widget

In [None]:
# Recognised values for the ``wcs`` argument of ``f``.
_WCS_MODES = ('xy', 'approx', 'truth')
# Label of the single figure ``f`` draws into (closed and recreated per call).
_FIG_LABEL = 'observation viewer'


def _make_wcs(mode, hdr, i_obs):
    """Return the galsim WCS selected by ``mode`` for observation ``i_obs``.

    'approx' builds the WCS from the image header ``hdr``; 'truth' reads the
    private per-observation file ``private/{i_obs:03d}.wcs.fits``.
    """
    if mode == 'approx':
        return galsim.GSFitsWCS(header=hdr)
    return galsim.GSFitsWCS(f"private/{i_obs:03d}.wcs.fits")


def _columns(records, *keys):
    """Return one float array per key, gathered from ``records``.

    ``records`` is either a FITS table (indexed by column name) or the list
    of dicts parsed from the YAML submission; an empty list gives empty arrays.
    """
    if isinstance(records, (list, tuple)):
        return tuple(np.array([rec[key] for rec in records], dtype=float) for key in keys)
    return tuple(np.asarray(records[key], dtype=float) for key in keys)


def _sky_to_pixel(world, ra, dec):
    """Convert RA/Dec in radians to pixel (x, y) arrays using ``world``.

    Empty input short-circuits to empty arrays so observations without any
    stars or satellites still plot.
    """
    ra = np.atleast_1d(np.asarray(ra, dtype=float))
    dec = np.atleast_1d(np.asarray(dec, dtype=float))
    if ra.size == 0:
        return np.empty(0), np.empty(0)
    return world.radecToxy(ra, dec, units='rad')


def _load_sources(i_obs, truth, wcs, hdr):
    """Return pixel positions ``(x0, y0, x1, y1, x, y)`` for observation ``i_obs``.

    (x0, y0) and (x1, y1) are the two positions recorded per satellite
    (presumably the ends of its track -- confirm against the data format);
    (x, y) are star positions. All are float arrays in the catalogue's own
    pixel convention (the caller treats them as 1-based FITS coordinates).
    """
    if truth:
        with fits.open("private/truth.fits") as truth_hdul:
            stars = truth_hdul[f'STAR_{i_obs:04d}'].data
            sats = truth_hdul[f'SAT_{i_obs:04d}'].data
    else:
        with open("private/sample_submission.yaml", 'r') as fh:
            # safe_load_all is lazy: materialise it before the file closes.
            docs = list(yaml.safe_load_all(fh))
        doc = docs[i_obs]
        # A missing or null list means "no detections" rather than an error.
        stars = doc.get('stars') or []
        sats = doc.get('sats') or []

    if wcs == 'xy':
        x0, y0, x1, y1 = _columns(sats, 'x0', 'y0', 'x1', 'y1')
        x, y = _columns(stars, 'x', 'y')
    else:
        world = _make_wcs(wcs, hdr, i_obs)
        ra0, dec0, ra1, dec1 = _columns(sats, 'ra0', 'dec0', 'ra1', 'dec1')
        ra, dec = _columns(stars, 'ra', 'dec')
        x0, y0 = _sky_to_pixel(world, ra0, dec0)
        x1, y1 = _sky_to_pixel(world, ra1, dec1)
        x, y = _sky_to_pixel(world, ra, dec)
    return x0, y0, x1, y1, x, y


def f(i_obs=0, truth=False, wcs='xy'):
    """Display observation ``i_obs`` with star and satellite positions overlaid.

    Parameters
    ----------
    i_obs : int
        Observation index; the image is read from ``public/{i_obs:03d}.fits``.
    truth : bool
        If True, overlay positions from ``private/truth.fits``; otherwise from
        document ``i_obs`` of ``private/sample_submission.yaml``.
    wcs : {'xy', 'approx', 'truth'}
        'xy' uses the pixel columns directly; 'approx' converts the RA/Dec
        columns with the WCS in the image header; 'truth' converts them with
        ``private/{i_obs:03d}.wcs.fits``.

    Satellites are drawn as red circles, stars as blue circles.

    Raises
    ------
    ValueError
        If ``wcs`` is not one of the values listed above.
    """
    if wcs not in _WCS_MODES:
        raise ValueError(f"wcs must be one of {_WCS_MODES}, got {wcs!r}")

    with fits.open(f"public/{i_obs:03d}.fits") as hdul:
        img = hdul[0].data
        hdr = hdul[0].header

    x0, y0, x1, y1, x, y = _load_sources(i_obs, truth, wcs, hdr)

    # Each widget callback creates a figure; close the previous one so they
    # do not accumulate under the widget backend.
    plt.close(_FIG_LABEL)
    _, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4), num=_FIG_LABEL)
    im = ax.imshow(img, vmin=0, vmax=100, origin='lower')
    # Pixel i is centred on i, so its edges are at i +/- 0.5. Fixed limits
    # also stop the scatter calls below from autoscaling the view.
    ax.set_xlim(-0.5, img.shape[1] - 0.5)
    ax.set_ylim(-0.5, img.shape[0] - 0.5)
    ax.set_title(f"obs {i_obs:03d} ({'truth' if truth else 'submission'}, wcs={wcs})")
    plt.colorbar(im, ax=ax)

    # Note the -1 to account for difference in FITS, matplotlib array indexing conventions.
    ax.scatter(x0 - 1, y0 - 1, facecolors='none', edgecolors='r', s=100)
    ax.scatter(x1 - 1, y1 - 1, facecolors='none', edgecolors='r', s=100)
    ax.scatter(x - 1, y - 1, facecolors='none', edgecolors='b', s=100)
    plt.show()

In [None]:
# Start from a clean slate: release every figure left over from earlier runs.
plt.close('all')
# interact builds one control per argument of f: the (min, max) tuple becomes
# an integer slider, the bool a checkbox, and the list a dropdown. The trailing
# semicolon suppresses the echoed return value.
interact(
    f,
    i_obs=(0, 1),
    truth=False,
    wcs=['xy', 'approx', 'truth'],
);