In [None]:
from itertools import chain
from pathlib import Path
import random
import shutil
import time

from collections import defaultdict
from cytoolz import first, groupby
from cytoolz.curried import get
from gPhoton.coadd import get_galex_rice_slices, coadd_image_slices
from gPhoton.reference import eclipse_to_paths
from killscreen.monitors import Netstat, Stopwatch
import pandas as pd
import pyarrow as pa
import pyarrow.csv
from pyarrow import parquet
import sh

from s3_fuse.mount_s3 import conditional_unmount, mount_bucket
from s3_fuse.ps1_utils import prune_ps1_catalog, get_ps1_cutouts
from s3_fuse.utilz import print_stats

In [None]:
# 'configuration'

BUCKET_NAME = 'nishapur'
S3_ROOT = '/mnt/s3'

# desired cutout side length in degrees (50 / 3600 deg = 50 arcseconds)
CUTOUT_SIDE_LENGTH = 50 / 3600

# the number of (randomly-selected) targets is set as TARGET_COUNT in the
# target-selection cell further down; it is deliberately not defined here,
# since a value here would just be silently overridden by that cell.

# seed for the random selection of targets / PS1 cells. None (the default)
# draws a fresh selection on every run; set an int to make runs reproducible.
RANDOM_SEED = None
random.seed(RANDOM_SEED)

# which PS1 bands are we looking at? (currently only g and z are staged.)
PS1_BANDS = ("g", "z")
# shall we do GALEX stuff?
DO_GALEX_STUFF = False

# mount the bucket at S3_ROOT using the goofys backend
mount_bucket(
    backend="goofys", remount=True, mount_path=S3_ROOT, bucket=BUCKET_NAME
)

In [None]:
# catalog of PS1 extragalactic extended objects, with explicit assignments
# to PS1 stack-image projection / sky cells and to GALEX eclipse numbers.
# it is copied out of the bucket only if no local copy exists yet; the
# local copy is what gets read.
catalog_fn = "ps1_extragalactic_skycells_eclipses.parquet"
local_catalog_path = Path(catalog_fn)
if not local_catalog_path.exists():
    remote_catalog_path = Path(S3_ROOT, "ps1/metadata", catalog_fn)
    shutil.copy(remote_catalog_path, local_catalog_path)
catalog = parquet.read_table(catalog_fn)

In [None]:
# only a subset of the PS1 stack images was staged for this demo (all of
# them, in all 5 bands, would exceed 80 TB). this CSV lists the (randomly
# selected) projection and sky cells that were staged.
test_cell_fn = "ps1_extragalactic_skycells_eclipses_1k_cell_subset.csv"
cell_schema = pa.schema([("proj_cell", pa.uint16()), ("sky_cell", pa.uint8())])
cell_table = pa.csv.read_csv(Path(S3_ROOT, "ps1/metadata", test_cell_fn))
arbitrary_test_cells = cell_table.cast(cell_schema)
small_catalog = prune_ps1_catalog(catalog, arbitrary_test_cells)

# likewise prune on the GALEX side: this table lists, by eclipse number, the
# MIS-like images that actually exist, excluding eclipses whose data is
# currently flagged as bad. only its column '0' is kept.
extant_mislike = pd.read_csv(Path(S3_ROOT, "extant_mislike_eclipses.csv"))['0']

In [None]:
def sample_table(table, k=None, rng=None):
    """
    Return k randomly chosen rows of `table`, or `table` itself if k is None.

    `table` only needs to support len() and .take(indices) (e.g. a pyarrow
    Table). Rows are drawn without replacement. If k exceeds the number of
    rows, every row is returned (in shuffled order) rather than raising.

    Args:
        table: table to sample from.
        k: number of rows to draw; None returns `table` unchanged.
        rng: optional random.Random instance for reproducible sampling;
            defaults to the module-level `random` generator.

    Returns:
        `table` itself if k is None, otherwise `table.take(indices)`.
    """
    if k is None:
        return table
    if rng is None:
        rng = random
    n_rows = len(table)
    # random.sample raises ValueError when k > population size; clip so that
    # asking for more targets/cells than exist simply takes all of them.
    indices = rng.sample(range(n_rows), k=min(k, n_rows))
    return table.take(indices)

In [None]:
# how many objects shall we collect slices for? (785510 are available in this test set)
TARGET_COUNT = 40
# optional: cap the number of PS1 source cells, to measure how denser sampling
# (more targets per cell) affects performance; None keeps every cell (1000 are
# available in this test set). the number of images actually read depends on
# both the number of cells and the number of bands. when GALEX fusion is on,
# this cap also indirectly limits how many GALEX images are touched.
MAX_CELL_COUNT = 5

test_catalog = (
    small_catalog
    if MAX_CELL_COUNT is None
    else prune_ps1_catalog(
        small_catalog, sample_table(arbitrary_test_cells, k=MAX_CELL_COUNT)
    )
)
targets = sample_table(test_catalog, k=TARGET_COUNT).to_pylist()
# distinct (proj_cell, sky_cell) stacks that the sampled targets fall in
stack_of = get(['proj_cell', 'sky_cell'])
ps1_stacks = {stack_of(target) for target in targets}

In [None]:
# PS1: targets are grouped by their (proj_cell, sky_cell) stack and cutouts
# are requested one stack at a time. `watch` / `stat` are handed to the
# helpers, presumably so they can record timing / network use (reported by
# print_stats below) — NOTE(review): confirm against killscreen.monitors.
watch, stat = Stopwatch(silent=True), Netstat()
ps1_groups = groupby(get(['proj_cell', 'sky_cell']), targets)
ps1_cutouts = {}
for stack in ps1_stacks:
    image_targets = ps1_groups[stack]
    cutouts, _ = get_ps1_cutouts(
        image_targets,
        PS1_BANDS,
        CUTOUT_SIDE_LENGTH,
        S3_ROOT,
        stat,
        watch,
        verbose=1
    )
    # merge this stack's cutouts into the running collection (dict union)
    ps1_cutouts |= cutouts

if DO_GALEX_STUFF:
    # build the set of usable eclipses once, instead of scanning the
    # pandas/numpy values for every candidate eclipse
    extant_eclipses = set(extant_mislike.tolist())
    # targets come from to_pylist(), so a missing 'galex' entry is None;
    # treat that as "no eclipses" rather than crashing
    galex_eclipses = {
        e for e in chain.from_iterable(t['galex'] or () for t in targets)
        if e in extant_eclipses
    }
    galex_slices = defaultdict(list)
    systems = {}
    for eclipse in galex_eclipses:
        eclipse_targets = tuple(
            t for t in targets if eclipse in (t['galex'] or ())
        )
        slices, system = get_galex_rice_slices(
            eclipse, eclipse_targets, CUTOUT_SIDE_LENGTH, S3_ROOT, watch, stat
        )
        systems[eclipse] = system
        # collect every eclipse's slice of each object under that object's id
        for obj_id, image_slice in slices.items():
            galex_slices[obj_id].append(image_slice)

    galex_coadds = {}
    n_slices = sum(map(len, galex_slices.values()))
    print(f"...coadding {n_slices} image slices...", end="")
    for obj_id, images in galex_slices.items():
        if len(images) == 0:
            # nothing to coadd for this object: report it and actually skip
            print(f"all GALEX images for {obj_id} are bad, skipping")
            continue
        galex_coadds[obj_id] = coadd_image_slices(images, systems)
    print_stats(watch, stat)
