In [1]:
# -*- coding: utf-8 -*-
import os
import sys
from concurrent.futures import as_completed
from time import sleep

import astropy
import astropy.io.fits as fits
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io.fits import getdata
from astropy.table import Table
import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import parsl
from parsl import python_app, bash_app
from tqdm import tqdm

# Clear any Parsl DataFlowKernel still loaded from a previous run of this cell
# (so the cell can be re-executed), then load a fresh one; the @python_app
# tasks submitted later in the notebook are scheduled through it.
# NOTE(review): parsl.load() is called without a Config, so Parsl's default
# configuration is used — confirm that is the intended executor.
parsl.clear()
parsl.load()


<parsl.dataflow.dflow.DataFlowKernel at 0x7f43ede3a4f0>

In [2]:
def join_cats(ipix_cats, output_file, ra_str, dec_str):
    """Write a single FITS catalog holding the rows of all ``ipix_cats``.

    The column names and dtypes of the first catalog define the output
    schema; every catalog in the list is expected to share it.

    Parameters
    ----------
    ipix_cats : list
        List of ipix FITS files catalogs (at least one).
    output_file : str
        File name with all data from ipix cats. Overwritten if it exists.
    ra_str : str
        Label of RA coordinate. Currently unused; kept for interface
        compatibility.
    dec_str : str
        Label of DEC coordinate. Currently unused; kept for interface
        compatibility.

    Raises
    ------
    ValueError
        If ``ipix_cats`` is empty.
    """
    # len() rather than truthiness so numpy arrays of names are accepted too.
    if len(ipix_cats) == 0:
        raise ValueError('join_cats needs at least one input catalog.')

    t = Table.read(ipix_cats[0])
    label_columns = t.colnames
    t_format = [t[name].info.dtype for name in label_columns]

    # Read everything first and concatenate once: growing the array inside
    # the loop re-copies all previously read rows on every iteration.
    chunks = [getdata(cat) for cat in ipix_cats]
    all_data = chunks[0] if len(chunks) == 1 else np.concatenate(chunks)

    cols = fits.ColDefs([
        fits.Column(name=name, format=fmt, array=all_data[name])
        for name, fmt in zip(label_columns, t_format)
    ])
    tbhdu = fits.BinTableHDU.from_columns(cols)
    tbhdu.writeto(output_file, overwrite=True)



def split_files(in_file, ra_str, dec_str, nside, path):
    """Split a catalog into small catalogs, one per occupied HEALPix pixel.

    Each row is assigned to the NESTED HEALPix pixel, at resolution
    ``nside``, that contains its (RA, DEC). Every occupied pixel gets its own
    FITS file named ``<ipix>.fits`` inside ``path``.

    Parameters
    ----------
    in_file : str
        Input FITS catalog.
    ra_str : str
        Label of RA coordinate (passed to ``hp.ang2pix`` with ``lonlat=True``).
    dec_str : str
        Label of DEC coordinate.
    nside : int
        Nside to split small catalogs.
    path : str
        Output directory; created if missing. A trailing separator is
        optional. An empty string writes into the current directory.

    Returns
    -------
    list of str
        Name and path of output files, in increasing pixel order.
    """
    # os.makedirs instead of os.system('mkdir -p ...'): no shell is spawned,
    # so paths with spaces or shell metacharacters are handled safely.
    if path:
        os.makedirs(path, exist_ok=True)

    data = getdata(in_file)
    t = Table.read(in_file)
    label_columns = t.colnames
    t_format = [t[name].info.dtype for name in label_columns]

    HPX = hp.ang2pix(nside, data[ra_str],
                     data[dec_str], nest=True, lonlat=True)
    HPX_un = np.unique(HPX)

    out_files = []
    for pix in HPX_un:
        data_ = data[HPX == pix]
        cols = fits.ColDefs([
            fits.Column(name=name, format=fmt, array=data_[name])
            for name, fmt in zip(label_columns, t_format)
        ])
        # os.path.join keeps the file inside `path` even without a trailing
        # slash (plain concatenation would write 'dir123.fits' next to it).
        out_file = os.path.join(path, str(pix) + '.fits')
        fits.BinTableHDU.from_columns(cols).writeto(out_file, overwrite=True)
        out_files.append(out_file)

    return out_files


@python_app
def clean_input_cat(file_name, ra_str, dec_str, nside):
    """Remove every star that shares a HEALPix pixel (at ``nside``) with another.

    This simulates a feature of real catalogs built from SExtractor
    detections, where objects very close to each other are detected as a
    single object. The effect is especially significant in stellar clusters:
    crowding in the images merges the cluster's center into one object,
    while many stars are still resolved in the periphery.
    The function counts the stars in each pixel with nside (usually > 2 ** 15)
    and removes ALL of the stars that share a pixel.

    Parameters
    ----------
    file_name : str
        Name of input file.
    ra_str : str
        Label for RA coordinate.
    dec_str : str
        Label for DEC coordinate.
    nside : int
        Nside to be populated.

    Returns
    -------
    str
        Path of the written catalog: ``file_name`` with its extension replaced
        by ``_clean.fits``. Surviving rows are ordered by pixel index.
    """
    # Parsl apps should import what they use inside the function, so the app
    # also works on executors that run it in a separate worker process.
    import os
    import numpy as np
    import healpy as hp
    import astropy.io.fits as fits
    from astropy.io.fits import getdata
    from astropy.table import Table

    # splitext strips only the final extension; split('.')[0] would cut the
    # path at the first dot anywhere in it (e.g. a '../data/' directory).
    output_file = os.path.splitext(file_name)[0] + '_clean.fits'

    data = getdata(file_name)
    t = Table.read(file_name)
    label_columns = t.colnames
    t_format = [t[name].info.dtype for name in label_columns]

    HPX = hp.ang2pix(nside, data[ra_str],
                     data[dec_str], nest=True, lonlat=True)

    # Occupancy of the pixel each star falls in, computed per row in one
    # vectorised step (no per-row list membership test).
    _, HPX_inverse, HPX_counts = np.unique(
        HPX, return_inverse=True, return_counts=True)
    keep = np.flatnonzero(HPX_counts[HPX_inverse] == 1)
    # Order survivors by pixel index; pixels of kept rows are unique, so
    # this order is deterministic.
    keep = keep[np.argsort(HPX[keep])]

    # Row selection keeps each column's own dtype and also works when no
    # star survives (empty table instead of an IndexError).
    data_clean = data[keep]

    cols = fits.ColDefs([
        fits.Column(name=name, format=fmt, array=data_clean[name])
        for name, fmt in zip(label_columns, t_format)
    ])
    tbhdu = fits.BinTableHDU.from_columns(cols)
    tbhdu.writeto(output_file, overwrite=True)
    return output_file


In [3]:
# Input mock catalog and the names of its coordinate columns.
input_cat = 'results/des_mockcat_for_detection.fits'
ra_str = 'ra'
dec_str = 'dec'

# Directory and HEALPix resolution used to split the input into sub-catalogs.
path_split = 'results/hpx_cats/'
nside = 64

# File written by the final join of the cleaned sub-catalogs.
final_cat = 'cat_clean.fits'

ipix_cats = split_files(input_cat, ra_str, dec_str, nside, path_split)


In [4]:
# Pixel resolution clean_input_cat uses to find stars too close to be resolved.
nside_clean = 2 ** 17

# Submit one Parsl cleaning task per sub-catalog (the progress bar is optional).
futures = list()
with tqdm(total=len(ipix_cats), file=sys.stdout) as pbar:
    pbar.set_description("Submit Parsls Tasks")
    for cat in ipix_cats:
        futures.append(clean_input_cat(cat, ra_str, dec_str, nside_clean))
        pbar.update()

# Wait for all tasks. as_completed blocks until the next future finishes, so
# no polling/sleep is needed. A future that raised also counts as "done", so
# each one's exception is inspected: a failed cleaning step is reported here
# rather than only showing up later when the cleaned files are joined.
cat_of_future = dict(zip(futures, ipix_cats))
failed = []
print("Tasks Done:")
with tqdm(total=len(futures), file=sys.stdout) as pbar2:
    for fut in as_completed(futures):
        exc = fut.exception()
        if exc is not None:
            failed.append((cat_of_future[fut], exc))
        pbar2.update()

if failed:
    first_cat, first_exc = failed[0]
    raise RuntimeError(
        '{} of {} cleaning tasks failed; first failure on {}: {!r}'.format(
            len(failed), len(futures), first_cat, first_exc)) from first_exc


Submit Parsls Tasks: 100%|██████████| 135/135 [00:00<00:00, 744.19it/s]
Tasks Done:
 21%|██▏       | 29/135 [00:21<01:20,  1.32it/s]


KeyboardInterrupt: 

In [5]:
# Cleaned sub-catalog names: each input name with its extension replaced by
# '_clean.fits'. os.path.splitext strips only the final extension, whereas
# split('.')[0] would truncate at the first dot anywhere in the path.
ipix_clean_cats = [os.path.splitext(cat)[0] + '_clean.fits' for cat in ipix_cats]
join_cats(ipix_clean_cats, final_cat, ra_str, dec_str)

In [None]:
# Visual check of the cleaning: one postage stamp centred on each sub-catalog's
# pixel, drawn only when the stamp holds more than `min_stars_in_stamp`
# original stars. Blue circles: original stars; red dots: stars kept.
n_stamp_cols = 4
half_size_plot = 0.01  # half-width of each stamp, in the degrees of pix2ang(lonlat=True)
min_stars_in_stamp = 10

ipix = [int(os.path.splitext(os.path.basename(cat))[0]) for cat in ipix_cats]
ra_cen, dec_cen = hp.pix2ang(nside, ipix, nest=True, lonlat=True)

# First pass: select the stamps worth drawing, so the figure can be sized to
# exactly that many panels (instead of len(ipix) rows of 4 panels each).
stamps = []
for i, (orig_cat, clean_cat) in enumerate(zip(ipix_cats, ipix_clean_cats)):
    orig = fits.getdata(orig_cat)
    ra_orig = orig[ra_str]
    dec_orig = orig[dec_str]
    in_stamp = ((ra_orig < ra_cen[i] + half_size_plot) & (ra_orig > ra_cen[i] - half_size_plot) &
                (dec_orig < dec_cen[i] + half_size_plot) & (dec_orig > dec_cen[i] - half_size_plot))
    if np.count_nonzero(in_stamp) > min_stars_in_stamp:
        stamps.append((i, ra_orig, dec_orig, clean_cat))

n_rows = max(1, (len(stamps) + n_stamp_cols - 1) // n_stamp_cols)
# squeeze=False keeps `axes` 2-D even when a single row is needed.
fig, axes = plt.subplots(n_rows, n_stamp_cols, figsize=(18, 4 * n_rows), squeeze=False)
for ax in axes.flat[len(stamps):]:
    ax.axis('off')  # hide panels left empty in the last row

for ax, (i, ra_orig, dec_orig, clean_cat) in zip(axes.flat, stamps):
    clean = fits.getdata(clean_cat)
    ax.scatter(ra_orig, dec_orig, edgecolor='b', color='None', s=20)
    ax.scatter(clean[ra_str], clean[dec_str], color='r', s=2)
    # RA axis is inverted (increasing to the left).
    ax.set_xlim([ra_cen[i] + half_size_plot, ra_cen[i] - half_size_plot])
    ax.set_ylim([dec_cen[i] - half_size_plot, dec_cen[i] + half_size_plot])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(str(ipix[i]), x=0.5, y=0.6, fontsize=8)

plt.suptitle('Blue: original, Red: filtered stars; Each poststamp has {:.2f} arcmin'.format(2. * 60. * half_size_plot))
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
