In [None]:
%%configure -f
{"name": "arik", "executorMemory": "8G", "numExecutors": 4, "executorCores": 2,
 "conf": {"spark.yarn.appMasterEnv.PYSPARK_PYTHON":"python3"}}

In [None]:
import sys
import subprocess
from io import BytesIO
from gzip import GzipFile
from pyspark.sql import Row

# Install astropy into the same interpreter that runs this notebook code.
# The command is passed as an argument list (no shell), so it still works when
# sys.executable contains spaces or shell metacharacters.
# NOTE(review): the version is unpinned -- consider pinning for reproducibility.
# NOTE(review): getfits() below runs inside Spark tasks via mapValues; those
# processes presumably need astropy installed as well -- confirm cluster setup.
stdout = subprocess.check_output(
    [sys.executable, '-m', 'pip', 'install', 'astropy'],
    stderr=subprocess.STDOUT,
).decode('utf-8')

from astropy.io import fits

In [None]:
def getfits(compressed_data):
    """Open an in-memory FITS file, gunzipping it first when it is compressed.

    Parameters
    ----------
    compressed_data : bytes
        Raw file contents (the value half of the ``sc.binaryFiles`` records this
        is mapped over). Input that starts with the gzip magic bytes is
        decompressed. Any other input is assumed to already be an uncompressed
        FITS file.

    Returns
    -------
    astropy.io.fits.HDUList
        The opened HDU list, backed by an in-memory buffer.
    """
    if compressed_data[:2] == b'\x1f\x8b':
        # Close the decompressor deterministically instead of relying on GC.
        with GzipFile(fileobj=BytesIO(compressed_data)) as gz:
            decomp = gz.read()
    else:
        decomp = compressed_data
    return fits.open(BytesIO(decomp))

# Spark DataFrames cannot hold numpy types, so FITS table values are converted
# to native Python objects before they are placed in a Row.
def typeconv(i):
    """Return ``i`` converted from a numpy value to a native Python object.

    * numpy scalars and size-1 arrays -> Python scalar via ``i.item()``.
    * multi-element arrays, for which ``.item()`` raises ``ValueError``
      -> (nested) Python list via ``i.tolist()``.
    * objects without an ``item`` method (already native) -> ``i`` unchanged.

    Only these expected failures are handled. Anything else (including
    ``KeyboardInterrupt``) propagates instead of being silently swallowed.
    """
    try:
        return i.item()
    except AttributeError:
        # No .item(): not a numpy object, nothing to convert.
        return i
    except ValueError:
        # .item() only accepts single-element arrays; convert the whole array.
        tolist = getattr(i, 'tolist', None)
        return tolist() if callable(tolist) else i

def headerDict(fits):
    """Return the header of the first HDU (``fits[0]``) as a plain ``dict``.

    The (keyword, value) pairs come from ``header.items()``. If a keyword
    appears more than once, the last occurrence wins.
    """
    primary_header = fits[0].header
    return {keyword: value for keyword, value in primary_header.items()}

def createFiberExposureRows(fits):
    """Flatten one FITS HDU list into one Spark ``Row`` per (exposure, fiber).

    Each row combines, in this field order:
      * every keyword of the first HDU's header (via ``headerDict``);
      * the OBSINFO table row for that exposure, values converted with
        ``typeconv`` (these overwrite same-named header keywords);
      * ``EXPOSURE_INDEX`` and ``FIBER_INDEX``;
      * that row's FLUX, XPOS, YPOS, IVAR, MASK and DISP arrays;
      * the full WAVE, SPECRES and SPECRESD arrays (identical in every row);
      * ``XPOS_MEAN`` / ``YPOS_MEAN``, the mean of that row's XPOS / YPOS.

    Parameters
    ----------
    fits : astropy.io.fits.HDUList
        Must contain OBSINFO, FLUX, XPOS, YPOS, IVAR, MASK, DISP, WAVE,
        SPECRES and SPECRESD HDUs. Per-fiber HDUs are indexed as
        ``exp * n_fibers + fiber``, i.e. the code assumes they are stacked
        exposure-major -- TODO confirm against the data model.

    Returns
    -------
    list of pyspark.sql.Row
        ``n_exposures * n_fibers`` rows; empty if OBSINFO has no rows.
    """
    obsinfo = fits['OBSINFO']
    obsinfo_columns = [col.name for col in obsinfo.columns]
    n_exposures = len(obsinfo.data)
    if n_exposures == 0:
        # Nothing to emit; also avoids a ZeroDivisionError computing n_fibers.
        return []
    # Integer floor division (same result as int(a / b) for these counts).
    n_fibers = len(fits['FLUX'].data) // n_exposures

    # Loop-invariant work, computed once instead of once per (exposure, fiber):
    # the header dict, the HDU lookups by name, and the shared arrays.
    header = headerDict(fits)
    per_row_hdus = {unit: fits[unit].data
                    for unit in ('FLUX', 'XPOS', 'YPOS', 'IVAR', 'MASK', 'DISP')}
    shared = {unit: fits[unit].data.tolist()
              for unit in ('WAVE', 'SPECRES', 'SPECRESD')}

    rows = []
    for exp in range(n_exposures):
        expinfo = dict(zip(obsinfo_columns, [typeconv(v) for v in obsinfo.data[exp]]))
        for fiber in range(n_fibers):
            ind = exp * n_fibers + fiber
            # Fresh dict per row, seeded from the (shared) header.
            row = dict(header)
            row.update(expinfo)
            row['EXPOSURE_INDEX'] = exp
            row['FIBER_INDEX'] = fiber
            # Per fiber/exposure data
            for unit, hdu_data in per_row_hdus.items():
                row[unit] = hdu_data[ind].tolist()
            # Arrays that apply to all spectra
            row.update(shared)
            # For convenience, store the mean fiber positions
            row['XPOS_MEAN'] = per_row_hdus['XPOS'][ind].mean().tolist()
            row['YPOS_MEAN'] = per_row_hdus['YPOS'][ind].mean().tolist()
            rows.append(Row(**row))
    return rows

In [None]:
# One (path, HDUList) pair per file in the HDFS directory; each file's raw
# bytes are opened by getfits.
manga_fits = (
    sc.binaryFiles('hdfs:///manga/fits/')
    .mapValues(getfits)
)

In [None]:
# Drop the file path and expand each HDUList into its per-(exposure, fiber) Rows.
tabledata = manga_fits.values().flatMap(createFiberExposureRows)

In [None]:
# Build the DataFrame from the Row RDD. No schema is given, so Spark infers it
# from the rows.
# NOTE(review): rows from files whose headers carry different keywords will
# have different fields; confirm inference handles that for this dataset.
df = spark.createDataFrame(data=tabledata)

In [None]:
# Persist the flattened table as Parquet using the session's current codec.
# mode='overwrite' keeps the cell re-runnable: the default save mode fails once
# the path exists, which breaks Restart & Run All.
# NOTE: df has not been cached at this point, so this write evaluates the whole
# FITS-parsing pipeline.
df.write.parquet('hdfs:///manga/arik-test/flux', mode='overwrite')

In [None]:
# Same table, gzip-compressed. The codec is passed to this one write instead of
# SET-ing spark.sql.parquet.compression.codec. The SET would change the codec for
# every later (or re-run) parquet write in the session, making their output
# depend on cell execution order.
# mode='overwrite' keeps the cell re-runnable once the path exists.
df.write.parquet('hdfs:///manga/arik-test/flux-gz', mode='overwrite', compression='gzip')

In [None]:
# Mark the DataFrame for caching (it is reused by the queries below) and
# register it for SQL, including %%sql cells, under the name "flux".
df.cache()
df.createOrReplaceTempView('flux')

In [None]:
%%sql
-- Mean FLUX at each wavelength over all fiber rows of MANGAID '1-115062',
-- first exposure (EXPOSURE_INDEX = 0) only.
WITH spectra AS (
    -- One row per (fiber, wavelength): zip the per-row WAVE and FLUX arrays.
    SELECT explode_outer(arrays_zip(WAVE, FLUX)) AS SPEC
    FROM flux
    WHERE MANGAID = '1-115062'
      AND EXPOSURE_INDEX = 0
)
SELECT SPEC.WAVE AS wl,
       avg(SPEC.FLUX) AS ifu_avg_flux
FROM spectra
GROUP BY SPEC.WAVE
ORDER BY SPEC.WAVE ASC

In [None]:
from matplotlib import pyplot as plt

# Position of every fiber row, coloured by the peak value of its FLUX array.
# Ascending order means the highest-FLUXMN points are drawn last (on top).
# NOTE(review): XPOS_MEAN/YPOS_MEAN are added directly to IFURA/IFUDEC. If the
# fiber positions are offsets in different units from IFURA/IFUDEC (e.g. arcsec
# vs degrees), or RA offsets need a cos(DEC) factor, these coordinates are
# wrong -- confirm units.
query = '''
SELECT XPOS_MEAN+IFURA AS RA, YPOS_MEAN+IFUDEC AS DEC, array_max(FLUX) as FLUXMN
FROM flux ORDER BY FLUXMN ASC
'''
res = spark.sql(query).toPandas()

fig, ax = plt.subplots(1, 1)
# Draw on the Axes created above rather than through the pyplot state machine.
ax.scatter(res['RA'], res['DEC'], c=res['FLUXMN'], alpha=0.05,
           marker='H', cmap='hsv', s=100)
ax.set_xlabel('RA')
ax.set_ylabel('DEC')
ax.set_title('Fiber positions coloured by max FLUX')

In [None]:
%matplot plt