In [None]:
import os
import pickle
import re
import shutil
import time
import warnings
from itertools import chain
from pathlib import Path

import killscreen.shortcuts as ks
from cytoolz.curried import get, get_in, groupby, keyfilter, merge
from killscreen import subutils
from killscreen.aws import ec2
from killscreen.monitors import make_monitors
from more_itertools import distribute
from pyarrow import parquet

# hacky; can remove if we decide to add an install script or put this in the repo root.
# `_dh` is IPython's directory history, and its first entry is where the kernel started.
# depending on the IPython version that entry may be a str or a Path, so normalize it
# with Path() before asking for the parent directory.
os.chdir(Path(globals()['_dh'][0]).parent)

from subset.science.handlers import (
    filter_ps1_catalog, sample_ps1_catalog, get_corresponding_images
)
from subset.utilz.generic import parse_topline
from subset.utilz.mount_s3 import mount_bucket

# suppress irrelevant warnings from matplotlib: ignore any warning whose message
# begins with "More than 20 figures"
warnings.filterwarnings("ignore", message="More than 20 figures")

# IPython magic selecting matplotlib's 'notebook' backend for figures in this notebook
%matplotlib notebook

## configuration

In [None]:
# username on worker-node instances
UNAME = "ubuntu"
# path on the local filesystem we'll use to collect cutouts from worker nodes
DUMP_PATH = '/home/ubuntu/.slice_test/'
os.makedirs(DUMP_PATH, exist_ok=True)
# bucket (meta)data is staged in
BUCKET="nishapur"
# where, on the local filesystem, we'll mount that bucket. created the same way as
# DUMP_PATH: makedirs(exist_ok=True) is a no-op if the directory is already there
# and avoids the check-then-create race of exists() + mkdir().
S3_ROOT = '/home/ubuntu/s3'
os.makedirs(S3_ROOT, exist_ok=True)
# name of launch template (not included); assumes that Name tag == template name
LAUNCH_TEMPLATE = 'fornax-slice'
# mount bucket to fetch metadata
mount_bucket(backend="goofys", mount_path=S3_ROOT, bucket=BUCKET)
# the source catalog: all mean objects from 1000 PS1 sky cells, randomly selected
# from "extragalactic" cells that overlap the viewports of GALEX visits, then
# filtered to the "best" objects (qualityFlag bit 0b100000) with valid photometry
# in both g and z bands (this filter leaves roughly 3% of total sources).
# other similarly-formatted catalog files can be used.
CATALOG_FN = "ps1_eg_eclipses_subset_best_gz_coregistered.parquet"
local_catalog = Path(CATALOG_FN)
if not local_catalog.exists():
    # no copy in the working directory yet, so fetch one from the mounted bucket
    staged_catalog = Path(S3_ROOT, "ps1/metadata", CATALOG_FN)
    shutil.copy(staged_catalog, local_catalog)
# load the working-directory copy into a pandas DataFrame
catalog = parquet.read_table(CATALOG_FN).to_pandas()
# cutout dimensions as (ra, dec), in degrees; treated as side lengths of a rectangle.
# 60 / 3600 degrees is 60 arcseconds per side.
CUT_SHAPE = (60 / 3600, 60 / 3600)
# restrict to sources bright in both g and z? set to None for no cutoff.
# (passed to filter_ps1_catalog in the target-selection cell)
MAG_CUTOFF = 20
# restrict to only sources flagged as extended / not extended?
# "extended", "point", or None for no restriction
EXTENSION_TYPE = "extended"
# restrict to only sources with a valid stack detection? (probably a good idea)
STACK_ONLY = True
# how many targets shall we randomly select?
TARGET_COUNT = 50
# optional parameter -- restrict the total number of PS1 source cells to test the
# performance effects of denser sampling. 1000 total cells are available in this test set.
# note that the total number of images accessed is number of cells * number of bands.
MAX_CELL_COUNT = 80

In [None]:
# initialize a killscreen Cluster. first, look up any instances whose name matches
# the launch template...
descriptions = ec2.ls_instances(name=LAUNCH_TEMPLATE)
# ...then build the cluster either from a new fleet request, if none were found...
if len(descriptions) == 0:
    cluster = ec2.Cluster.launch(
        count=4,
        template=LAUNCH_TEMPLATE,
        uname=UNAME,
        # 'private' because we'll be talking to them from inside AWS
        use_private_ip=True
    )
# ...or from the instances that already exist.
else:
    cluster = ec2.Cluster.from_descriptions(
        descriptions, uname=UNAME, use_private_ip=True
    )
    # make sure the existing instances are up before we try to talk to them
    cluster.start()
    for instance in cluster.instances:
        instance.wait_until_running()
    cluster.add_public_keys()
    print("\n".join([str(i) for i in cluster.instances]))

In [None]:
# freshen these instances in case we've made code changes
def git_update(*repo_names):
    """
    Build a single shell command that refreshes each named git checkout.

    For every name in `repo_names`, the generated snippet changes into that
    directory, runs `git clean -d -fx`, starts `git pull`, and changes back
    to the home directory. The per-repo snippets are then combined by
    `ks.chain` using its "and" mode, and the result is returned.

    NOTE(review): the single `&` backgrounds `git pull` instead of chaining
    on success (`&&`) -- confirm that this is intentional.
    """
    per_repo_commands = []
    for repo in repo_names:
        per_repo_commands.append(
            f"cd {repo}; git clean -d -fx; git pull & cd ~"
        )
    return ks.chain(per_repo_commands, "and")
# compose the update command for the repositories checked out on the workers...
update = git_update("fornax-s3-subsets", "killscreen", "gphoton_working")
# ...and issue it through the cluster as a background job.
# NOTE(review): nothing visible in this notebook waits on `updaters`, so the pulls
# may still be in progress when later cells launch work -- confirm that's acceptable.
updaters = cluster.command(update, _bg=True)
# find script / interpreter on remote instances. the conda environment is looked
# up on the first instance only and the resulting path is reused for every instance.
env = cluster.instances[0].conda_env("fornax_section")
python = f"{env}/bin/python"
endpoint = "/home/ubuntu/fornax-s3-subsets/subset/ps1_cutout_endpoint.py"

## target selection
the next cell picks a random sample of targets that satisfy the parameters defined above.
you can run it again to 'reroll' and pick a new set of targets.

In [None]:
# every catalog source that satisfies the configured selection criteria
candidate_sources = filter_ps1_catalog(catalog, MAG_CUTOFF, EXTENSION_TYPE, STACK_ONLY)
# a random draw from those candidates, limited by target and cell counts
targets = sample_ps1_catalog(candidate_sources, TARGET_COUNT, MAX_CELL_COUNT)
# keep only the fields needed to define a cutout, and attach the requested
# cut shape to each target definition
interesting_fields = ('obj_id', 'proj_cell', 'sky_cell', 'ra', 'dec')
targets = [
    {
        **{field: value for field, value in t.items() if field in interesting_fields},
        'ra_x': CUT_SHAPE[0],
        'dec_x': CUT_SHAPE[1],
    }
    for t in targets
]
# list the ps1 stack images these sources lie within
# (so that we can easily initialize each relevant image only once)
ps1_stacks, _ = get_corresponding_images(targets)
# bucket the targets by (proj_cell, sky_cell) so that everything within a single
# cell / image ends up on the same instance
target_groups = groupby(get(['proj_cell', 'sky_cell']), targets)
# simple balancing heuristic given that constraint: order the groups from
# largest to smallest, then deal them out across the instances in turn
groups = sorted(target_groups.values(), key=len, reverse=True)
instance_count = len(cluster.instances)
work_chunks = [
    tuple(chain.from_iterable(dealt_groups))
    for dealt_groups in distribute(instance_count, groups)
]

In [None]:
# simple process join function
def wait_on(processes, polling_delay=0.1):
    """
    Block until no member of `processes` reports that it is alive.

    On each pass, every object in `processes` is asked for `is_alive()`. If
    any of them answers truthily, sleep for `polling_delay` seconds and poll
    again; otherwise return. `processes` is re-read on every pass, so items
    added to it while waiting are polled too. Returns None.
    """
    while True:
        liveness = [proc.is_alive() for proc in processes]
        if not any(liveness):
            return
        time.sleep(polling_delay)

# handles for the background file transfers, one added per finished remote process
getters = []
def grab_when_done(process, *_):
    """
    Completion callback: fetch a finished process's output files from its host.

    Prints the host's ip, then asks the host to `get` everything matching
    `DUMP_PATH*` into the local `DUMP_PATH` as a background job, and records
    the returned handle in the module-level `getters` list so it can be
    waited on later. Any extra positional arguments are ignored.
    """
    host = process.host
    print(f"{host.ip} done; getting files")
    getters.append(host.get(f"{DUMP_PATH}*", DUMP_PATH, _bg=True))
        
# NOTE(review): `roundstring` is called below but was never imported or defined
# anywhere in this notebook, so on a fresh kernel this cell raised NameError --
# and only after the remote jobs had already run. importing it here, before
# anything is launched, makes the cell fail fast instead. it is presumably
# killscreen's string-rounding helper; confirm the module path against the
# installed killscreen version.
from killscreen.utilities import roundstring

# delete everything local so as to avoid confusion
subutils.run(f"rm {DUMP_PATH}/* &")

# set up some basic benchmarking...
stat, note = make_monitors(silent=True)
# ...and initiate the remote processes: one background job per instance, each
# handed its own chunk of targets as a quoted command-line argument.
# grab_when_done is registered as the completion callback for each job.
remote_processes = []
for chunk, instance in zip(work_chunks, cluster.instances):
    command = f"{python} {endpoint} '{chunk}'"
    viewer = instance.command(
        command, _bg=True, _viewer=True, _done=grab_when_done
    )
    remote_processes.append(viewer)
# first wait for the remote jobs themselves...
wait_on(remote_processes)
note(roundstring(f"remote processes completed,{stat(simple_cpu=True)}"), True)
# ...then for the file transfers that grab_when_done queued up in `getters`
wait_on(getters)
note(roundstring(f"cleaned up files from remotes,{stat(simple_cpu=True)}"), True)

# everything that came back from the remotes (pickled cutouts and csv logs)
retrieved_dumps = os.listdir(DUMP_PATH)

cutfiles = tuple(filter(lambda f: f.endswith("pkl"), retrieved_dumps))
# NOTE(review): the reported count is computed from the request (2 per target),
# not measured from the retrieved files -- presumably one cut per band; confirm.
note(roundstring(f"got {len(targets) * 2} cuts,{stat(total=True, simple_cpu=True)}"), True)
log = note(None, eject=True)
rate, weight = parse_topline(log)
print(f"{rate} cutouts/s, {weight} MB / cutout (local only)")

# cleanup cached arrays on remotes
deletions = cluster.command(f"rm {DUMP_PATH}/* &", _bg=True)

In [None]:
# should you like: examine logs from remotes...
import pandas as pd

log_columns = ["timestamp", "event", "duration", "volume", "cpu", "host"]
remote_logs = []
for logfile in filter(lambda f: f.endswith("csv"), retrieved_dumps):
    remote_log = pd.read_csv(Path(DUMP_PATH, logfile))
    # the host is recovered from an "ip_<a>_<b>_<c>_<d>_" fragment of the filename;
    # fall back to the bare filename stem rather than crashing on a log file
    # that doesn't follow that pattern
    host_match = re.search(r"(?<=ip_)(\d+_){4}", logfile)
    if host_match is not None:
        remote_log["host"] = host_match.group(0)[:-1]
    else:
        remote_log["host"] = Path(logfile).stem
    remote_logs.append(remote_log)
if remote_logs:
    logs = pd.concat(remote_logs)
    logs.columns = log_columns
else:
    # pd.concat raises ValueError on an empty list; show an empty table instead
    # when no csv logs were retrieved
    logs = pd.DataFrame(columns=log_columns)
logs.sort_values(by=["host", "timestamp"])

In [None]:
# ...or your winnings: unpickle every retrieved cutout file.
# (pickle.load can execute arbitrary code from the file it reads, so only do this
# with dumps you trust -- here, ones fetched from our own worker nodes.)
cuts = []
for cutfile in cutfiles:
    with Path(DUMP_PATH, cutfile).open("rb") as stream:
        cuts.append(pickle.load(stream))
# flatten the per-file collections and pull out element 0 of each cut's 'arrays' entry
arrays = tuple(get_in(['arrays', 0], cut) for cut in chain.from_iterable(cuts))

In [None]:
from random import choice
import matplotlib.pyplot as plt
import numpy as np

# build a 3x3 mosaic of axes (keyed 0-8) and immediately close the pyplot-managed
# figure -- presumably so that it is rendered only once, by the bare `fig` at the
# end of this cell
fig, grid = plt.subplot_mosaic(np.arange(9).reshape(3,3))
plt.close()
for panel in grid.values():
    panel.set_axis_off()

# fill each panel with a randomly-chosen cutout (repeats are possible), with
# values clipped to the array's 1st-99th percentile range
for panel in grid.values():
    cutout = choice(arrays)
    low, high = np.percentile(cutout, (1, 99))
    panel.imshow(np.clip(cutout, low, high), cmap='autumn')

fig

In [None]:
# destroy the cluster if you are done with it.
# NOTE(review): presumably this terminates (rather than merely stops) the EC2
# instances, so make sure everything you want has been retrieved first -- confirm.
cluster.terminate()