Set up the Spark session. For production, switch to something like 26G of executor memory, 8 executors, and 4 cores per executor.

In [1]:
%%configure -f
{"name": "arik-load-logrss", "executorMemory": "26G", "numExecutors": 8, "executorCores": 2,
 "conf": {"spark.yarn.appMasterEnv.PYSPARK_PYTHON":"python3"}}

In [2]:
import sys
import os
import subprocess
from io import BytesIO
from gzip import GzipFile
from pyspark.sql import Row
import glob
import time

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
37,application_1586890731024_0036,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
def _ensure_astropy():
    """Make sure ``astropy`` is importable in the current Python process.

    These helpers run inside Spark tasks, where the worker's Python
    environment may not have astropy. An actual import is attempted first;
    pip is invoked only when that import fails, so a worker that already has
    the package does not pay for a pip subprocess.

    Raises:
        subprocess.CalledProcessError: if ``pip install astropy`` fails.
    """
    try:
        import astropy  # noqa: F401 -- imported only to probe availability
        return
    except ImportError:
        pass

    import importlib

    # Argument list instead of a shell string: no quoting problems if the
    # interpreter path contains spaces, and no shell involved.
    subprocess.check_output(
        [sys.executable, '-m', 'pip', 'install', 'astropy'],
        stderr=subprocess.STDOUT)
    # The import system caches directory listings; refresh them so the package
    # installed a moment ago can be found by the import that follows.
    importlib.invalidate_caches()


def get_fits_module():
    """Return the ``astropy.io.fits`` module, installing astropy if needed.

    Returns:
        module: ``astropy.io.fits``.

    Raises:
        subprocess.CalledProcessError: if astropy is missing and pip fails.
    """
    _ensure_astropy()
    from astropy.io import fits
    return fits


def get_wcs_module():
    """Return the ``astropy.wcs`` module, installing astropy if needed.

    Returns:
        module: ``astropy.wcs``.

    Raises:
        subprocess.CalledProcessError: if astropy is missing and pip fails.
    """
    _ensure_astropy()
    from astropy import wcs
    return wcs

def getfitslocal(path):
    """Open the FITS file at ``path`` and return the opened object.

    The file is opened via the module returned by ``get_fits_module()``. The
    returned object is not closed here; the caller owns it.
    """
    return get_fits_module().open(path)

# Spark dataframe cannot do numpy types
def typeconv(i):
    """Convert a value to a plain Python object when it offers ``.item()``.

    Spark DataFrames cannot hold numpy scalar types, so values that expose an
    ``item()`` method are converted through it. Anything that has no such
    method, or whose ``item()`` call fails, is returned unchanged.

    Args:
        i: any value.

    Returns:
        ``i.item()`` when that call succeeds, otherwise ``i`` itself.
    """
    try:
        return i.item()
    except Exception:
        # Deliberately broad (missing attribute, multi-element array, ...),
        # but unlike a bare ``except`` this lets KeyboardInterrupt and
        # SystemExit propagate.
        return i

def createSpaxelSpectrumRows(fits):
    """Build one ``Row`` per non-empty spaxel of an opened cube.

    Scans the X, Y pixel space of the ``FLUX`` cube (indexed
    ``[wavelength, y, x]``) and emits a row for every spaxel whose spectrum
    has at least one non-zero sample.

    Args:
        fits: opened FITS object indexable by extension name. It must provide
            ``FLUX`` (with header), ``IVAR``, ``MASK``, ``DISP``, ``PREDISP``,
            ``WAVE``, ``SPECRES``, ``SPECRESD``, ``PRESPECRES``,
            ``PRESPECRESD`` and ``GIMG``/``RIMG``/``IIMG``/``ZIMG``.

    Returns:
        list of ``Row``: each row holds ``RA``/``DEC`` from the FLUX header
        WCS, the source pixel indices ``_SRC_X``/``_SRC_Y``, selected header
        keywords, the per-spaxel slice of every cube extension, the spectral
        axis arrays, and the broadband image values. The spectral-axis lists
        are identical for every spaxel and are shared between the returned
        rows; treat them as read-only.
    """
    wcs = get_wcs_module()
    flux_hdu = fits['FLUX']
    header = flux_hdu.header
    flux = flux_hdu.data
    wc = wcs.WCS(header)
    _, n_y, n_x = flux.shape

    # metadata, could be more relevant items, but for obsinfo I think we need to refer to
    # another table anyway, so we need uniquely identifying items here, and might also just
    # keep a particular set of handy things (e.g. frequently queried)
    # These values do not depend on the spaxel, so convert them once.
    header_meta = {}
    for h in ['PLATEID', 'IFUDSGN', 'DESIGNID', 'MANGAID', 'MJDRED', 'DATE-OBS']:
        header_meta[h] = typeconv(header[h])

    # Spectral axis units: the same for every spaxel, so convert them to
    # lists once instead of once per spaxel.
    spectral_axes = {}
    for unit in ['WAVE', 'SPECRES', 'SPECRESD', 'PRESPECRES', 'PRESPECRESD']:
        spectral_axes[unit] = fits[unit].data.tolist()

    # Look the per-spaxel arrays up once rather than inside the pixel loop.
    cubes = [(unit, fits[unit].data)
             for unit in ['FLUX', 'IVAR', 'MASK', 'DISP', 'PREDISP']]
    images = [(band, fits[band + 'IMG'].data)
              for band in ['G', 'R', 'I', 'Z']]

    rows = []
    for x in range(n_x):
        for y in range(n_y):
            # Empty means "every sample is zero". Comparing the sum with zero
            # would also drop real spectra whose samples cancel out.
            if not flux[:, y, x].any():
                continue
            ra, dec, _ = wc.all_pix2world([[x, y, 0]], 0)[0]
            row = {
                'RA': typeconv(ra),
                'DEC': typeconv(dec),
                '_SRC_X': x,
                '_SRC_Y': y,
            }
            # Field insertion order is kept the same as before: header
            # metadata, cube units, spectral axis units, broadband images.
            row.update(header_meta)
            # Cube units
            for unit, cube in cubes:
                row[unit] = cube[:, y, x].tolist()
            row.update(spectral_axes)
            # broadband images, one value per spaxel
            for band, image in images:
                row[band] = image[y, x].tolist()
            rows.append(Row(**row))
    return rows

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Read local fits to parquet
Here we scan all directories that have a `stack` subdirectory and identify all LOGCUBE files within them. Each file will be given its own partition, so first we scan for the total dataset size, including the number of files, the sum of the file sizes, and the maximum file size (this determines how much memory tasks need).

In [4]:
# Root directory that is scanned for per-entry 'stack' subdirectories.
base_dir = '/sciserver/vc/manga/vc/sas/dr15/manga/spectro/redux/v2_4_3/'


def stack_dir(x):
    """Return the path of the 'stack' subdirectory for entry ``x`` of base_dir."""
    return base_dir + x + '/stack'


# Keep only the entries of base_dir that actually contain a stack directory.
rss_dirs = []
for entry in os.listdir(base_dir):
    candidate = stack_dir(entry)
    if os.path.isdir(candidate):
        rss_dirs.append(candidate)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
def _list_logcubes(directory):
    """Return the gzipped LOGCUBE FITS paths found directly inside ``directory``."""
    return glob.glob(directory + '/*-LOGCUBE.fits.gz')


# Distribute the directory list and expand each directory into its cube files.
files_rdd = sc.parallelize(rss_dirs).flatMap(_list_logcubes)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
def _file_size_stats(path):
    """Return ``(size_mb, 1, size_mb)`` for one file.

    The three slots are the seeds for (total size, file count, max size), so
    tuples from different files can be merged pairwise.
    """
    size_mb = os.stat(path).st_size / 1024 / 1024  # stat once, reuse twice
    return (size_mb, 1, size_mb)


def _merge_size_stats(a, b):
    """Merge two (total, count, max) tuples into one."""
    return (a[0] + b[0], a[1] + b[1], max(a[2], b[2]))


# files_stats is (total size in MB, number of files, largest file in MB).
files_stats = files_rdd.map(_file_size_stats).reduce(_merge_size_stats)
# The third slot is reduced with max(), so it is the largest file, not the
# average; label it accordingly.
print('N files: {1}. Total Size: {0:0.0f}MB, max size: {2:0.0f}MB'.format(*files_stats))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

N files: 4857. Total Size: 597114MB, average size: 214

In [7]:
# Use as many partitions as there are input files (files_stats[1] is the count).
n_part = int(files_stats[1])
files_repartitioned = files_rdd.repartition(n_part)
# Open every file, then expand each opened cube into its spaxel rows.
opened_cubes = files_repartitioned.map(getfitslocal)
table_data = opened_cubes.flatMap(createSpaxelSpectrumRows)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### WARNING
This is a big job. In this data format the size blows up because some repeated columns (for example, the array columns) do not benefit much from run-length encoding (RLE). It ends up creating about 1.3TB of parquet files, which are much larger in memory, and because of the columnar format this data needs to be held in the RAM of the tasks. If there are problems, the most likely cause is that not enough memory is allocated.

In [8]:
# Destination for the parquet output; existing contents are overwritten.
hdfs_dir = 'hdfs:///manga/arik-test/dr15/v2_4_3/logcube'

# Time the DataFrame creation plus the parquet write, in seconds.
t = time.time()
table = spark.createDataFrame(table_data)
parquet_writer = table.write.mode('overwrite')
parquet_writer.parquet(hdfs_dir)
elapsed = time.time() - t
print(elapsed)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

22478.775051355362