In [None]:
from collections import defaultdict
from itertools import chain
from pathlib import Path
import shutil
import re

from cytoolz import groupby
from cytoolz.curried import get
from gPhoton.pretty import print_stats
from killscreen.monitors import Netstat, Stopwatch
import pandas as pd
import pyarrow as pa
import pyarrow.csv
from pyarrow import parquet

from s3_fuse.mount_s3 import mount_bucket
from s3_fuse.ps1_utils import prune_ps1_catalog, get_ps1_cutouts
from s3_fuse.ps1_utils import ps1_stack_path, request_ps1_cutout
from s3_fuse.utilz import make_loaders, sample_table

In [None]:
# 'configuration'

# name of the S3 bucket to mount, and the local path it is mounted at
# (both are passed straight to mount_bucket below and in the benchmark cells)
BUCKET = 'nishapur'
S3_ROOT = '/mnt/s3'

# mount BUCKET at S3_ROOT using the goofys backend. the benchmark loops
# further down call this again with remount=True before each loader run.
mount_bucket(
    backend="goofys", remount=False, mount_path=S3_ROOT, bucket=BUCKET
)

# desired cutout side length in degrees (60 / 3600 degrees, i.e. 60 arcseconds)
CUTOUT_SIDE_LENGTH = 60 / 3600

# which PS1 bands are we looking at? (currently only g and z are staged.)
PS1_BANDS = ("g", "z")
# shall we do GALEX stuff? (gates the GALEX slicing / coadding branch of
# the benchmark cells)
DO_GALEX_STUFF = False

# select loaders -- options are "astropy", "fitsio", "greedy_astropy", "greedy_fitsio"
# NOTE: because all the files this particular notebook is looking
# at are RICE-compressed, there is unlikely to be much difference
# between astropy and greedy_astropy -- astropy does not support
# loading individual tiles from a tile-compressed FITS file.
# LOADERS is iterated with .items() as (loader_name, loader) pairs
# by the benchmark cells.
LOADERS = make_loaders("greedy_fitsio", "fitsio",)

def cleanup_loader(loader_name, temp_path="/dev/shm/slicetemp"):
    """
    Remove the scratch directory associated with "greedy" loaders.

    Does nothing unless the substring "greedy" appears in `loader_name`.
    A missing or partially-removable directory is ignored rather than raised
    (`ignore_errors=True`), so this is safe to call after every run.

    :param loader_name: loader name, e.g. "fitsio" or "greedy_fitsio"
    :param temp_path: directory tree to delete. The default is presumably
        where the greedy loaders stage their temporary copies -- confirm
        against s3_fuse before changing it.
    :return: None
    """
    if "greedy" in loader_name:
        shutil.rmtree(temp_path, ignore_errors=True)
        
def parse_topline(log):
    """
    Summarize a run from the last entry of its log.

    The last value of `log` is split on commas into exactly three fields.
    The first integer in field one is taken as the cutout count; the first
    number in fields two and three are taken as elapsed seconds and
    transferred megabytes respectively (units as implied by how the caller
    prints the result: "cutouts/s" and "MB / cutout").

    :param log: dict whose last value is the comma-separated summary line
    :return: (cutouts per second, megabytes per cutout), each rounded to
        two decimal places
    :raises ValueError: if the last line does not split into three fields
    :raises AttributeError: if a field contains no number
    :raises ZeroDivisionError: if the parsed duration or cutout count is 0
    """
    # integer or decimal. the previous pattern (\d+\.?\d+) required at least
    # two digits, so a field like " 4 s" failed to match at all.
    number = r"\d+(?:\.\d+)?"
    total = next(reversed(log.values()))
    summary, duration, volume = total.split(",")
    cut_count = int(re.search(r"\d+", summary).group())
    seconds = float(re.search(number, duration).group())
    megabytes = float(re.search(number, volume).group())
    rate = cut_count / seconds
    weight = megabytes / cut_count
    return round(rate, 2), round(weight, 2)

In [None]:
# PS1 extragalactic extended-object catalog. each object carries an explicit
# PS1 stack image projection / sky cell assignment and its GALEX eclipse
# numbers. a local copy is fetched from the bucket only if one is not
# already present in the working directory.
catalog_fn = "ps1_extragalactic_skycells_eclipses.parquet"
local_catalog = Path(catalog_fn)
if not local_catalog.exists():
    shutil.copy(Path(S3_ROOT, "ps1/metadata", catalog_fn), local_catalog)
catalog = parquet.read_table(catalog_fn)

# only a subset of the PS1 stack images is staged for this demo (all of
# them at all 5 bands would be > 80 TB). this file lists the (randomly
# selected) projection and sky cells that were staged.
test_cell_fn = "ps1_extragalactic_skycells_eclipses_1k_cell_subset.csv"
cell_schema = pa.schema([("proj_cell", pa.uint16()), ("sky_cell", pa.uint8())])
staged_cells_raw = pa.csv.read_csv(Path(S3_ROOT, "ps1/metadata", test_cell_fn))
arbitrary_test_cells = staged_cells_raw.cast(cell_schema)
small_catalog = prune_ps1_catalog(catalog, arbitrary_test_cells)

# GALEX pruning: table of actually-existing MIS-like images by eclipse
# number, excluding eclipses with data currently flagged as bad
mislike_frame = pd.read_csv(Path(S3_ROOT, "extant_mislike_eclipses.csv"))
extant_mislike = mislike_frame['0']

In [None]:
# how many objects shall we collect slices for? (785510 are available in this test set)
TARGET_COUNT = 800
# optional parameter -- restrict the total number of PS1 source cells to test the
# performance effects of denser sampling (1000 total PS1 cells are available in this test set).
# note that the number of actual images accessed is a factor of both the number of cells
# and the number of bands under consideration.
# if GALEX fusion is taking place, this will also indirectly
# restrict the number of GALEX images.
# set to None to sample targets from every staged cell.
MAX_CELL_COUNT = 4
# NOTE(review): no seed is passed to sample_table here, so the selected
# cells / targets presumably differ from run to run -- confirm whether
# sample_table supports seeding if reproducible benchmarks are wanted.
if MAX_CELL_COUNT is not None:
    test_catalog = prune_ps1_catalog(
        small_catalog, sample_table(arbitrary_test_cells, k=MAX_CELL_COUNT)
    )
else:
    test_catalog = small_catalog
# one dict per sampled catalog row
targets = sample_table(test_catalog, k=TARGET_COUNT).to_pylist()
# unique (proj_cell, sky_cell) pairs among the sampled targets
ps1_stacks = set(map(get(['proj_cell', 'sky_cell']), targets))
# build the lookup set once: membership tests against the raw array
# rescan the whole array for every candidate eclipse
extant_eclipse_set = set(extant_mislike.tolist())
# eclipses referenced by at least one target that also exist as MIS-like images
galex_eclipses = {
    e for e in chain.from_iterable(map(get('galex'), targets))
    if e in extant_eclipse_set
}

In [None]:
# per-loader performance-tuning parameters, keyed by loader name; the
# "default" entry is used for any loader without an entry of its own.
# image_chunksize: how many images shall we initialize at once?
# image_threads: how many threads shall we init with in parallel? (None to disable.)
# cut_threads: how many threads shall we cut with in parallel? (None to disable.)
TUNING = {
    "fitsio": dict(image_chunksize=40, image_threads=4, cut_threads=4),
    "greedy_fitsio": dict(image_chunksize=40, image_threads=4, cut_threads=None),
    "default": dict(image_chunksize=40, image_threads=4, cut_threads=4),
}

In [None]:
# benchmark every configured loader in turn, keeping each loader's log
logs = {}
for loader_name, loader in LOADERS.items():
    print(f"----testing {loader_name}----")
    # remount bucket to avoid "cheating"
    mount_bucket(
        backend="goofys", remount=True, mount_path=S3_ROOT, bucket=BUCKET
    )
    # loader-specific tuning if we have it, otherwise the shared defaults
    if loader_name in TUNING:
        tuning_params = TUNING[loader_name]
    else:
        tuning_params = TUNING["default"]
    cuts, logs[loader_name] = get_ps1_cutouts(
        ps1_stacks,
        loader,
        targets,
        CUTOUT_SIDE_LENGTH,
        f"{S3_ROOT}/ps1",
        PS1_BANDS,
        verbose=2,
        **tuning_params,
    )
    cleanup_loader(loader_name)
    # headline throughput figures from the last line of this loader's log
    rate, weight = parse_topline(logs[loader_name])
    print(f"{rate} cutouts/s, {weight} MB / cutout")

In [None]:
%%time
watch = Stopwatch()
req_cutouts = {}
for target in targets[:10]:
    req_cutouts[target['obj_id']] = request_ps1_cutout(
        ps1_stack_path(target['proj_cell'], target['sky_cell'], band),
        target['ra'],
        target['dec'],
        CUTOUT_SIDE_LENGTH * 3600,
        "fits"
    )

In [None]:
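# TODO(review): unfinished note ("to be made into a ...") -- presumably the benchmark loop in the next cell is to be made into a reusable function; confirm.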

In [None]:
# per-stack benchmark: for each loader, cut every PS1 stack's targets one
# stack at a time, then (optionally) slice and coadd the matching GALEX images.
# NOTE(review): this get_ps1_cutouts call passes (targets, loader, bands,
# side length, root) and unpacks three return values, whereas the earlier
# benchmark cell passes (stacks, loader, targets, side length, root, bands,
# ...) and unpacks two. one of the two is presumably stale -- confirm
# against s3_fuse.ps1_utils.
# NOTE(review): `get_galex_rice_slices`, `coadd_image_slices` and `stat` are
# not defined or imported in the visible part of this notebook, and `watch`
# is only bound in a scratch cell; the GALEX branch cannot run as-is.

# targets grouped by their (proj_cell, sky_cell) stack
ps1_groups = groupby(get(['proj_cell', 'sky_cell']), targets)
# NOTE: these accumulate across loaders; later loaders overwrite equal keys
ps1_cutouts = {}
log = {}
for loader_name, loader in LOADERS.items():
    # remount bucket to avoid 'cheating'
    mount_bucket(
        backend="goofys", remount=True, mount_path=S3_ROOT, bucket=BUCKET
    )
    print(f"\n--------testing {loader_name}--------\n")
    outer_stat = print_stats(Stopwatch(silent=True), Netstat())
    for stack in ps1_stacks:
        image_targets = ps1_groups[stack]
        cutouts, _, stack_log = get_ps1_cutouts(
            image_targets,
            loader,
            PS1_BANDS,
            CUTOUT_SIDE_LENGTH,
            f"{S3_ROOT}/ps1",
            verbose=1
        )
        ps1_cutouts |= cutouts
        log |= stack_log
    print(
        f"acquired {len(targets) * len(PS1_BANDS)} cutouts from "
        f"{len(ps1_stacks) * len(PS1_BANDS)} images,{outer_stat()}"
    )
    if DO_GALEX_STUFF is True:
        # obj_id -> list of slices, one per eclipse that covers the object
        galex_slices = defaultdict(list)
        systems = {}
        for eclipse in galex_eclipses:
            eclipse_targets = tuple(filter(lambda t: eclipse in t['galex'], targets))
            slices, system = get_galex_rice_slices(
                eclipse, eclipse_targets, CUTOUT_SIDE_LENGTH, S3_ROOT, watch, stat
            )
            systems[eclipse] = system
            for k, v in slices.items():
                galex_slices[k].append(v)
        print(f"acquired GALEX cutouts,{outer_stat()}")
        galex_coadds = {}
        slice_count = sum(map(len, galex_slices.values()))
        print(f"...coadding {slice_count} image slices...", end="")
        for obj_id, images in galex_slices.items():
            if len(images) == 0:
                # the f-prefix was missing, so the literal "{obj_id}" was
                # printed; and despite "skipping", the coadd still ran
                print(f"all GALEX images for {obj_id} are bad, skipping")
                continue
            galex_coadds[obj_id] = coadd_image_slices(images, systems)
        print(f"coadded GALEX cutouts,{outer_stat()}")
        


In [None]:
# NOTE(review): this cell duplicates the preceding per-stack benchmark cell;
# keep only one of the two once the benchmark is settled.
# per-stack benchmark: for each loader, cut every PS1 stack's targets one
# stack at a time, then (optionally) slice and coadd the matching GALEX images.
# NOTE(review): this get_ps1_cutouts call passes (targets, loader, bands,
# side length, root) and unpacks three return values, whereas the first
# benchmark cell passes (stacks, loader, targets, side length, root, bands,
# ...) and unpacks two. one of the two is presumably stale -- confirm
# against s3_fuse.ps1_utils.
# NOTE(review): `get_galex_rice_slices`, `coadd_image_slices` and `stat` are
# not defined or imported in the visible part of this notebook, and `watch`
# is only bound in a scratch cell; the GALEX branch cannot run as-is.

# targets grouped by their (proj_cell, sky_cell) stack
ps1_groups = groupby(get(['proj_cell', 'sky_cell']), targets)
# NOTE: these accumulate across loaders; later loaders overwrite equal keys
ps1_cutouts = {}
log = {}
for loader_name, loader in LOADERS.items():
    # remount bucket to avoid 'cheating'
    mount_bucket(
        backend="goofys", remount=True, mount_path=S3_ROOT, bucket=BUCKET
    )
    print(f"\n--------testing {loader_name}--------\n")
    outer_stat = print_stats(Stopwatch(silent=True), Netstat())
    for stack in ps1_stacks:
        image_targets = ps1_groups[stack]
        cutouts, _, stack_log = get_ps1_cutouts(
            image_targets,
            loader,
            PS1_BANDS,
            CUTOUT_SIDE_LENGTH,
            f"{S3_ROOT}/ps1",
            verbose=1
        )
        ps1_cutouts |= cutouts
        log |= stack_log
    print(
        f"acquired {len(targets) * len(PS1_BANDS)} cutouts from "
        f"{len(ps1_stacks) * len(PS1_BANDS)} images,{outer_stat()}"
    )
    if DO_GALEX_STUFF is True:
        # obj_id -> list of slices, one per eclipse that covers the object
        galex_slices = defaultdict(list)
        systems = {}
        for eclipse in galex_eclipses:
            eclipse_targets = tuple(filter(lambda t: eclipse in t['galex'], targets))
            slices, system = get_galex_rice_slices(
                eclipse, eclipse_targets, CUTOUT_SIDE_LENGTH, S3_ROOT, watch, stat
            )
            systems[eclipse] = system
            for k, v in slices.items():
                galex_slices[k].append(v)
        print(f"acquired GALEX cutouts,{outer_stat()}")
        galex_coadds = {}
        slice_count = sum(map(len, galex_slices.values()))
        print(f"...coadding {slice_count} image slices...", end="")
        for obj_id, images in galex_slices.items():
            if len(images) == 0:
                # the f-prefix was missing, so the literal "{obj_id}" was
                # printed; and despite "skipping", the coadd still ran
                print(f"all GALEX images for {obj_id} are bad, skipping")
                continue
            galex_coadds[obj_id] = coadd_image_slices(images, systems)
        print(f"coadded GALEX cutouts,{outer_stat()}")
        


In [None]:
from s3_fuse.ps1_utils import request_ps1_cutout, request_ps1_filenames, ps1_stack_path

In [None]:
# interactive IPython source lookup (`??`) for request_ps1_cutout; a
# development aid only -- it binds nothing that other cells use.
request_ps1_cutout??

In [None]:
# NOTE(review): `eclipse_targets` is only assigned inside the
# `if DO_GALEX_STUFF is True` branch of the benchmark cells, so with
# DO_GALEX_STUFF = False this cell raises NameError on a fresh kernel.
eclipse_targets

In [None]:
# NOTE(review): repeats the previous cell. `eclipse_targets` is only assigned
# inside the `if DO_GALEX_STUFF is True` branch of the benchmark cells, so
# with DO_GALEX_STUFF = False this raises NameError on a fresh kernel.
eclipse_targets

In [None]:
# request a cutout for every target in every band.
# NOTE(review): this cell was left unfinished (`filename = ` had no
# right-hand side and the request call had only a path argument). it is
# completed here to mirror the other request_ps1_cutout call in this
# notebook -- confirm this is the intended call. `cutout` holds only the
# most recent result.
for band in PS1_BANDS:
    for target in targets:
        filename = ps1_stack_path(target['proj_cell'], target['sky_cell'], band)
        cutout = request_ps1_cutout(
            filename,
            target['ra'],
            target['dec'],
            # side length converted from degrees to arcseconds
            CUTOUT_SIDE_LENGTH * 3600,
            "fits"
        )