In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import os
import sys

import astropy
from astropy import wcs
from astropy.nddata import Cutout2D
from astropy import units as u

from collections import namedtuple


import glob


In [None]:
# my home-written modules
import image_helpers

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Plot defaults for the whole notebook: 160 dpi on screen and when saving,
# 5 x 3 inch figures, white figure background.
plt.rcParams.update({
    'figure.dpi': 80*2,
    'savefig.dpi': 80*2,
    'figure.figsize': np.array((10, 6)) * .5,
    'figure.facecolor': "white",
})

In [None]:
# Root data directory (defined in the local image_helpers module); every
# input/output path in this notebook is built relative to it.
data_dir = image_helpers.data_dir


# Load data

In [None]:
# Matched-galaxy catalogue, indexed by its SpecObjID column so galaxies can be
# looked up by the same integer id used for the image directories.
df = (
    pd.read_csv(os.path.join(data_dir, "matched_galaxies.csv"))
    .set_index("SpecObjID")
)
print(df.shape)
df.head()

In [None]:
dirnames = glob.glob(os.path.join(data_dir, "images", "cutout", "*"))
ids_with_images = [int(os.path.split(dirname)[-1]) 
             for dirname in dirnames]

filename_format = os.path.join(data_dir, "images", "cutout", "{0}", "{0}-*.fits")
has_5_bands = lambda id: len(glob.glob(filename_format.format(id))) == 5

ids_with_images = [i for i in ids_with_images
                   if has_5_bands(i)]

In [None]:
# Sanity check: the glob pattern produced for the first galaxy id.
filename_format.format(
    ids_with_images[0]
)

In [None]:
# Peek at a few ids instead of dumping the entire list into the notebook output
# (the total count is shown in the next cell).
ids_with_images[:10]

In [None]:
# Number of galaxies that have all 5 band cutouts (i.e. passed has_5_bands).
len(ids_with_images)

In [None]:
def load_as_array(galaxy_id):
    """Load the primary-HDU image of every band cutout for one galaxy.

    Despite the name, this returns a *list* with one entry per band, in the
    order of ``image_helpers.bands``; callers wrap it in ``np.array`` to stack
    the bands along a new leading axis.

    Parameters
    ----------
    galaxy_id : int
        Galaxy id passed to ``image_helpers.get_cutout_filename`` to locate
        each band's FITS file.

    Returns
    -------
    list
        ``hdul[0].data`` for each band (``None`` if a primary HDU has no data).
    """
    # Import explicitly: `import astropy` on its own does not guarantee that
    # the `astropy.io.fits` submodule has been loaded.
    from astropy.io import fits

    data_list = []
    for band in image_helpers.bands:
        cutout_filename = image_helpers.get_cutout_filename(galaxy_id, band)
        # The context manager closes each file instead of leaking one handle
        # per band per call; memmap=False reads the pixels into memory so the
        # returned arrays do not depend on the now-closed file.
        with fits.open(cutout_filename, memmap=False) as hdul:
            data_list.append(hdul[0].data)
    return data_list
        


In [None]:
# One example galaxy to spot-check loading; its bands are stacked along axis 0.
galaxy_id = 75094093037830144
a = np.asarray(load_as_array(galaxy_id=galaxy_id))

In [None]:
# Catalogue run / camcol / field values for the example galaxy
# (single .loc lookup instead of chained indexing).
df.loc[[galaxy_id], ["run", "camcol", "field"]]

In [None]:
def sdss_stretch(data, u_a=np.exp(6.), u_b=0.05):
    """
    Apply an arcsinh intensity stretch element-wise.

    adapted from: https://hsc-gitlab.mtk.nao.ac.jp/snippets/23#L215

    Computes ``arcsinh(u_a * (data - u_b)) / arcsinh(u_a) + u_b``. The input is
    left unmodified; a new array is returned (integer input is promoted to
    float).

    u_a and u_b depend on the dataset and are tuned by hand.

    Parameters
    ----------
    data : array_like
        Pixel values of any shape (e.g. a stack of band images).
    u_a : float
        Softening parameter. arcsinh is roughly linear for small arguments
        and roughly logarithmic for large ones, so u_a shifts where the
        transition between the linear and logarithmic behaviour happens.
    u_b : float
        Bias: subtracted before the stretch and added back afterwards.

    Returns
    -------
    numpy.ndarray (or numpy scalar for scalar input)

    My rules of thumb are:
    1) Using the image, choose a sky value, u_b such that if you
       went any higher, you'd start to lose ~more galaxy pixels
       than background pixels.
    2) Using the histogram, choose a softening parameter, u_a,
       such that your two populations of pixels (background and target)
       have some overlap around 0, but not too much. Basically, just
       make it representative of your uncertainty whether the pixels
       at 0 are foreground or background.
    """
    # Subtract out of place: an in-place `-=` would overwrite the caller's
    # array and raises a casting error on integer arrays.
    shifted = np.asarray(data) - u_b
    stretched = np.arcsinh(u_a * shifted) / np.arcsinh(u_a)
    return stretched + u_b



In [None]:
import ipywidgets


# The galaxy slider is capped at index 50 and at the last valid index, so it
# can never select a galaxy that does not exist.
@ipywidgets.interact(ith_galaxy=ipywidgets.IntSlider(min=0, max=min(50, len(ids_with_images) - 1)),
                     u_a=ipywidgets.FloatSlider(min=1, max=10, value=6),
                     u_b=ipywidgets.FloatSlider(min=-.1, max=.1,
                                                step=.01, value=0.05),
                     show_hist=ipywidgets.Checkbox())
def tmp(ith_galaxy, u_a, u_b, show_hist):
    """Interactively preview the arcsinh-stretched colour composite of one galaxy.

    Parameters
    ----------
    ith_galaxy : int
        Position of the galaxy in ``sorted(ids_with_images)``.
    u_a : float
        Natural log of the softening parameter given to ``sdss_stretch``
        (the slider works in log space; ``np.exp(u_a)`` is what is used).
    u_b : float
        Bias passed straight to ``sdss_stretch``.
    show_hist : bool
        If True, also plot per-band histograms of the stretched, unclipped
        pixel values in a new figure.
    """
    galaxy_id = sorted(ids_with_images)[ith_galaxy]
    stretched = sdss_stretch(np.array(load_as_array(galaxy_id)), np.exp(u_a), u_b)

    # imshow expects float RGB values in [0, 1]; clip into a new array so
    # `stretched` keeps its full range for the histograms below.
    clipped = np.clip(stretched, 0, 1)
    # swapaxes(0, 2) moves the band axis last for imshow, but it also swaps the
    # two image axes, so the displayed image is transposed.
    # (3, 2, 1) picks the bands shown as R, G, B -- presumably i, r, g if
    # image_helpers.bands is ugriz; TODO confirm.
    plt.imshow(clipped.swapaxes(0, 2)[:, :, (3, 2, 1)])

    if show_hist:
        plt.figure()
        for band, band_pixels in zip(image_helpers.bands, stretched):
            plt.hist(band_pixels.ravel(), histtype="step", label=band)
        plt.xlabel("stretched pixel value")
        plt.ylabel("number of pixels")
        plt.legend()


# Scale images, combine bands, save to `.npy`

In [None]:
# Destination for the stretched, band-stacked images: one `.npy` per galaxy id.
processed_filename_format = os.path.join(
    data_dir,
    "images",
    "processed",
    "{galaxy_id}.npy"
)

processed_dir = os.path.dirname(processed_filename_format)
# exist_ok makes this cell idempotent and avoids the check-then-create race.
os.makedirs(processed_dir, exist_ok=True)

In [None]:
# Print progress roughly every 10% of the way through; max(1, ...) avoids a
# ZeroDivisionError when there are fewer than 10 galaxies.
progress_every = max(1, len(ids_with_images) // 10)
# Required size (pixels) of every dimension of every band cutout.
expected_size = 95

for i, galaxy_id in enumerate(ids_with_images):
    if i % progress_every == 0:
        print("i = {:>6d}".format(i), flush=True)
    output_filename = processed_filename_format.format(galaxy_id=galaxy_id)
    if os.path.exists(output_filename):
        # Already produced by an earlier run; skipping keeps re-runs cheap.
        continue

    img = load_as_array(galaxy_id)
    if not all(dim == expected_size for band in img for dim in band.shape):
        # image not the correct shape - at least one band must have been near an edge
        continue
    img = sdss_stretch(np.array(img))
    np.save(output_filename, img)
    
