In [None]:
import glob
import multiprocessing
import os
import shutil
import sys
import time
import urllib
import urllib.error
import urllib.request
import warnings
from collections import namedtuple

import astropy
import astropy.io.fits
import numpy as np
import pandas as pd
from astropy import units as u
from astropy import wcs
from astropy.nddata import Cutout2D
from matplotlib import pyplot as plt




In [None]:
# my home-written modules
import image_helpers

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules (e.g. image_helpers) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide matplotlib defaults: doubled (2 x 80) dpi both on screen and
# for saved figures, a half-size (10, 6) figure, and a white figure background.
plt.rcParams.update({
    'savefig.dpi': 80*2,
    'figure.dpi': 80*2,
    'figure.figsize': np.array((10,6))*.5,
    'figure.facecolor': "white",
})

In [None]:
# Directory holding the input catalog (`matched_galaxies.csv`); it is also
# passed to `image_helpers.get_cutout_filename` when cutouts are written.
data_dir = "data"

# Read in data

In [None]:
# Load the matched-galaxy catalog and index it by SpecObjID, so that a single
# galaxy can later be looked up with `df.loc[galaxy_id]`.
df = pd.read_csv(
    os.path.join(data_dir, "matched_galaxies.csv")
).set_index("SpecObjID")
print(df.shape)
df.head()

# Figure out which fields we need to download

In [None]:
# Key identifying one imaging frame by (run, camcol, field).
# `rerun` is intentionally not part of the key.
Identifier = namedtuple("Identifier", ("run", "camcol", "field"))



In [None]:
# Map each frame (run, camcol, field) to the set of galaxy ids (the index
# values of `df`) located in that frame.
#
# The columns are iterated directly rather than through `df.iterrows()`:
# `iterrows` packs every row into a single-dtype Series, which can silently
# turn the integer run/camcol/field values into floats when the frame also
# holds float columns. Float keys would break the `{run:>06d}`-style
# formatting applied to these identifiers when file paths are built, so the
# values are cast to plain ints here.
frame_id_to_galaxy_ids = {}

for galaxy_id, run, camcol, field in zip(
    df.index, df["run"], df["camcol"], df["field"]
):
    identifier = Identifier(int(run), int(camcol), int(field))
    frame_id_to_galaxy_ids.setdefault(identifier, set()).add(galaxy_id)

In [None]:
# Inspect the frame -> galaxy-id mapping (this displays the entire dict).
frame_id_to_galaxy_ids

In [None]:
# Number of galaxies in each frame, in the mapping's iteration order.
len_list = [len(galaxy_ids) for galaxy_ids in frame_id_to_galaxy_ids.values()]

In [None]:
# Number of distinct frames (one entry of `len_list` per frame identifier).
len(len_list)

In [None]:
# Total number of galaxy ids summed over all frames.
sum(len_list)

In [None]:
# Per-frame galaxy counts (this displays the entire list).
len_list

# Get galaxy images via Globus
More overhead to set up, but better for bulk transfers.

You'll need to sign up for Globus, install it on your local machine (both the desktop application and the command-line interface), and make sure you are logged in to the CLI.

Also, you'll need to replace the destination endpoint below with your own personal endpoint.

In [None]:
# Globus endpoint IDs for the transfer.
# NOTE(review): the source is presumably the endpoint that serves the paths
# described by `globus_path` — confirm.
source_endpoint = "db57ddf2-6d04-11e5-ba46-22000b92c6ec"
# Personal endpoint that receives the files; replace with your own.
destination_endpoint = "9278d8fe-e7b4-11e8-8c9c-0a1d4c5c824a"


In [None]:
# Store files here after downloading and while making cutouts,
# but then transfer them to the long-term storage at `raw_data_dir`.
# This location is machine-specific, so it can be overridden through the
# SHORT_TERM_RAW_DATA_DIR environment variable; the default is the
# original hard-coded path.
short_term_raw_data_dir = os.environ.get(
    "SHORT_TERM_RAW_DATA_DIR", "/Users/egentry/test_globus/"
)

# Path template of one frame file on the source endpoint. It is filled in with
# `band` plus the integer `run`, `camcol` and `field` (run is zero-padded to
# 6 digits and field to 4 digits in the file name).
globus_path = "/uufs/chpc.utah.edu/common/home/sdss/dr14/eboss/photoObj/frames/301/{run}/{camcol}/frame-{band}-{run:>06d}-{camcol}-{field:>04d}.fits.bz2"

# Maximum number of images listed in a single transfer batch file.
images_per_batch = 10000

In [None]:
# Write the Globus batch file lists: one `file_list_NNN` per batch, each line
# holding "SOURCE_PATH DEST_PATH" for one frame image that still needs to be
# transferred. Frames already present in long-term storage, and frames that
# are missing on the remote server, are skipped.
assert(os.path.exists(short_term_raw_data_dir)) # make sure the drive is mounted
assert(os.path.exists(image_helpers.raw_data_dir)) # make sure the drive is mounted

filename_format = os.path.join(short_term_raw_data_dir, "file_list_{:>03d}")
dirname_format = os.path.join(short_term_raw_data_dir, "{}", "")

# HTTP location of the same frames; only used to check that a frame exists
# before it is added to a transfer list.
url_base = "http://data.sdss.org/sas/dr14/eboss/photoObj/frames/301/{run}/{camcol}/frame-{band}-{run:>06d}-{camcol}-{field:>04d}.fits.bz2"


def _remote_frame_exists(http_url):
    """Return True if a HEAD request for `http_url` succeeds, False on a 404.

    Any other HTTP error is re-raised unchanged.
    """
    request = urllib.request.Request(http_url, method="HEAD")
    try:
        # Only the status matters, so release the response immediately.
        with urllib.request.urlopen(request):
            return True
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return False
        raise


def _open_batch_file(batch_number):
    """Open (truncating) the file list for `batch_number` and write its header line."""
    batch_file = open(filename_format.format(batch_number), "w")
    print("# SOURCE_PATH DEST_PATH", file=batch_file)
    return batch_file


batch_counter = 0
image_counter = 0
f = _open_batch_file(batch_counter)
dirname = dirname_format.format(batch_counter)

try:
    for key in frame_id_to_galaxy_ids:
        for band in image_helpers.bands:

            remote_path = globus_path.format(
                band=band, **key._asdict(),
            )
            filename = os.path.split(remote_path)[-1]
            # Skip frames that are already in long-term storage.
            if os.path.exists(os.path.join(image_helpers.raw_data_dir, filename)):
                continue

            # filter out files which don't actually exist (but give a warning)
            http_url = url_base.format(run=key.run, camcol=key.camcol, field=key.field, band=band,)
            if not _remote_frame_exists(http_url):
                warnings.warn("Missing remote image file: {}-{}-{}-{}".format(
                    key.run, key.camcol, key.field, band,
                    ))
                continue

            if image_counter == images_per_batch:
                # The current batch is full: start the next file list and
                # destination directory before recording this image.
                f.close()
                batch_counter += 1
                image_counter = 0
                f = _open_batch_file(batch_counter)
                dirname = dirname_format.format(batch_counter)

            local_path = os.path.join(dirname, filename)

            print(remote_path, local_path, file=f)

            image_counter += 1
finally:
    # Close the current file list even if a request fails part-way through.
    f.close()

### Globus CLI calls

It's probably best to run these manually, rather than programmatically, since I don't want to have to worry about waiting until each particular batch is complete:

```
BATCH="000" && globus transfer  db57ddf2-6d04-11e5-ba46-22000b92c6ec 9278d8fe-e7b4-11e8-8c9c-0a1d4c5c824a --preserve-mtime --label=batch_${BATCH} --batch < file_list_${BATCH}
```

(If you do want to wait, you could try [`globus task wait`](https://docs.globus.org/cli/reference/task_wait/), or see if that exists in the Python SDK.)

Remember 10k images is about 30 GB.

## Now get cutouts of a batch

In [None]:
# If True, each raw frame is copied into `image_helpers.raw_data_dir` once its
# cutouts have been written.
copy_raw_image_after_cutout = True
# If True, the raw frame is removed from the short-term batch directory after
# that copy (only when `copy_raw_image_after_cutout` is also True), and a
# batch directory that ends up empty is removed as well.
delete_old_raw_image = True

In [None]:
def create_cutout(filename):
    """Write one cutout per galaxy contained in the raw frame `filename`.

    The frame's run, camcol and filter are read from its FITS header; the
    field number is parsed from the file name (the last `-`-separated token
    before the first `.`). For every galaxy id that `frame_id_to_galaxy_ids`
    lists for that frame, a cutout at the galaxy's (ra, dec) from `df` is
    produced by `image_helpers.get_cutout` and written to the path returned
    by `image_helpers.get_cutout_filename`, overwriting any existing file.

    If `copy_raw_image_after_cutout` is set, the raw frame is then copied to
    `image_helpers.raw_data_dir`; if `delete_old_raw_image` is also set, the
    original file is removed afterwards.

    Parameters
    ----------
    filename : str
        Path of a downloaded `*.fits.bz2` frame file.

    Raises
    ------
    RuntimeError
        Re-raised from `image_helpers.get_cutout`, after printing the
        offending galaxy id and image file name.
    KeyError
        If the frame is not a key of `frame_id_to_galaxy_ids`.
    """
    basename = os.path.split(filename)[-1]
    field = int(basename.split(".")[0].split("-")[-1])

    # Keep the FITS file open only while cutouts are made, so the handle is
    # released before the raw file is copied or deleted.
    with astropy.io.fits.open(filename) as hdu_list:
        hdu = hdu_list[0]
        run = hdu.header["RUN"]
        camcol = hdu.header["CAMCOL"]
        band = hdu.header["FILTER"]

        identifier = Identifier(run, camcol, field)

        for galaxy_id in frame_id_to_galaxy_ids[identifier]:
            cutout_filename = image_helpers.get_cutout_filename(
                galaxy_id, band, data_dir=data_dir
            )
            galaxy = df.loc[galaxy_id]

            try:
                cutout_hdu = image_helpers.get_cutout(hdu, galaxy.ra, galaxy.dec)
            except RuntimeError:
                print("Problematic galaxy id = {}".format(galaxy_id))
                print("Problematic image = {}".format(filename))
                raise

            parent = os.path.split(cutout_filename)[0]
            if parent:
                # exist_ok avoids a race between pool workers that create
                # the same output directory at the same time.
                os.makedirs(parent, exist_ok=True)
            cutout_hdu.writeto(cutout_filename, overwrite=True)

    if copy_raw_image_after_cutout:
        shutil.copy2(filename,
                     os.path.join(image_helpers.raw_data_dir, basename),
                    )

        if delete_old_raw_image:
            os.remove(filename)


# Names of the batch directories (under `short_term_raw_data_dir`) to process.
batches = [
    "33", "34", "35",
]

for batch in batches:
    print("starting batch", batch, flush=True)

    batch_raw_data_dir = os.path.join(short_term_raw_data_dir,
                                      batch)
    if not os.path.exists(batch_raw_data_dir):
        print("stopping because batch {} doesn't exist".format(batch))
        break

    filenames = glob.glob(os.path.join(batch_raw_data_dir,
                                       "*.fits.bz2"))

    if len(filenames) != images_per_batch:
        # note: the very last batch will have fewer files,
        # and you'll need to add the proper `if` statement here
        print("stopping because batch {} doesn't contain enough files".format(
        batch))
        break

    # NOTE(review): `create_cutout` is defined in the notebook itself, so the
    # pool presumably relies on workers inheriting it (fork start method) —
    # confirm on platforms that default to spawn.
    with multiprocessing.Pool() as pool:
        pool.map(create_cutout, filenames)

    if delete_old_raw_image:
        # os.listdir also sees hidden files, so a directory that still holds
        # e.g. dotfiles is left in place instead of making os.rmdir fail.
        if len(os.listdir(batch_raw_data_dir)) == 0:
            os.rmdir(batch_raw_data_dir)