In [1]:
import numpy as np
import astropy
import astropy.io.fits as fits
from astropy.io.fits import getdata
from astropy.table import Table
import healpy as hp
from astropy.coordinates import SkyCoord
from astropy import units as u
import os
import parsl
from parsl import python_app, bash_app
import matplotlib.pyplot as plt

# Clear any Parsl state left over from a previous run of this cell before
# loading again, so the cell can be re-executed in the same kernel.
parsl.clear()
# No config object is passed, so Parsl's defaults apply; the returned
# DataFlowKernel is what appears as this cell's output.
parsl.load()

<parsl.dataflow.dflow.DataFlowKernel at 0x7f69c8072ca0>

In [2]:
def join_cat(data, label_coluns, t_format, outfile):
    """Write a 2-D array to a FITS binary table, one table column per label.

    Parameters
    ----------
    data : 2-D array-like
        Indexed as ``data[:, k]``; column ``k`` supplies the values stored
        under ``label_coluns[k]``.
    label_coluns : sequence of str
        Names of the output table columns. Any number of columns is
        supported (the table is no longer limited to exactly eight).
    t_format : sequence
        Per-column format, forwarded unchanged to
        ``fits.Column(format=...)``; ``t_format[k]`` belongs to
        ``label_coluns[k]``.
    outfile : str
        Destination path. An existing file is overwritten.
    """
    # One fits.Column per label, in label order.
    columns = [
        fits.Column(name=name, format=t_format[k], array=data[:, k])
        for k, name in enumerate(label_coluns)
    ]
    tbhdu = fits.BinTableHDU.from_columns(fits.ColDefs(columns))
    tbhdu.writeto(outfile, overwrite=True)
    

@python_app
def clean_cat(input_cat, max_sep_arcsec, label_columns, in_path):
    """Write the sources of one catalog file that have no close neighbour.

    Reads ``str(input_cat) + '.dat'`` (whitespace-separated, one source per
    row, with RA in column 1 and Dec in column 2, both in degrees), keeps
    every row whose nearest *other* source is farther away than
    ``max_sep_arcsec``, and writes the surviving rows to
    ``in_path + '/' + str(input_cat) + '_clean.dat'``.

    Parameters
    ----------
    input_cat : str or int
        Base name of the input file; ``'.dat'`` is appended.
    max_sep_arcsec : float
        Minimum separation, in arcseconds, to the nearest neighbour for a
        source to be kept.
    label_columns : sequence
        Only its length is used: the first ``len(label_columns)`` values of
        each row are written.
    in_path : str
        Directory that receives the ``*_clean.dat`` output file.

    Returns
    -------
    None
    """
    # NOTE(review): the input is read relative to the working directory while
    # the output goes to ``in_path`` -- confirm that this asymmetry is wanted.
    in_file = str(input_cat) + '.dat'
    out_file = in_path + '/' + str(input_cat) + '_clean.dat'
    n_cols = len(label_columns)

    # Load the table once, row-major; ndmin=2 keeps a one-row file 2-D so
    # rows and columns can be indexed uniformly.
    data = np.loadtxt(in_file, ndmin=2)

    if data.size == 0:
        kept_rows = []
    elif data.shape[0] == 1:
        # A single source has no neighbour at all, so it is isolated.
        kept_rows = data
    else:
        coords = SkyCoord(ra=data[:, 1] * u.degree, dec=data[:, 2] * u.degree)
        # Matching a catalog against itself: nthneighbor=2 skips the trivial
        # zero-distance match of every source with itself.
        _, d2d, _ = coords.match_to_catalog_sky(coords, nthneighbor=2)
        kept_rows = data[d2d > max_sep_arcsec * u.arcsec]

    with open(out_file, "w") as f:
        for row in kept_rows:
            print(' '.join([str(row[j]) for j in range(n_cols)]), file=f)

@python_app
def split_files(HPX_un, data_sort, label_columns, HPX_sort, in_path):
    """Append each catalog row to the text file of the pixel it falls in.

    For every pixel ``p`` in ``HPX_un``, the rows ``data_sort[k]`` with
    ``HPX_sort[k] == p`` are appended, in row order, to
    ``in_path + '/' + str(p) + '.dat'`` as space-separated values.

    Parameters
    ----------
    HPX_un : iterable
        Pixel identifiers to write files for.
    data_sort : indexable
        Catalog rows; ``data_sort[k][c]`` is the value of column ``c`` in
        row ``k``.
    label_columns : sequence
        Only its length is used: the first ``len(label_columns)`` values of
        each row are written.
    HPX_sort : iterable
        Pixel identifier of each row of ``data_sort`` (same order).
    in_path : str
        Directory that receives the per-pixel ``.dat`` files.

    Returns
    -------
    HPX_un, unchanged.

    Notes
    -----
    Files are opened in append mode, so rows accumulate if the same path is
    written more than once (for example when the notebook is re-run without
    removing earlier output).
    """
    n_cols = len(label_columns)

    # Group row indices by pixel in a single pass instead of rescanning
    # HPX_sort once per unique pixel.
    rows_by_pixel = {}
    for row_idx, pix in enumerate(HPX_sort):
        rows_by_pixel.setdefault(pix, []).append(row_idx)

    n_pix = len(HPX_un)
    for i, pix in enumerate(HPX_un):
        # Progress as a fraction of pixels processed.
        print(float(i) / n_pix)
        row_ids = rows_by_pixel.get(pix)
        if not row_ids:
            # No rows for this pixel: do not create an empty file.
            continue
        with open(in_path + '/' + str(pix) + '.dat', "a") as f:
            for row_idx in row_ids:
                print(' '.join([str(data_sort[row_idx][c]) for c in range(n_cols)]), file=f)
    return HPX_un

def read_input_cat(file_name, ra_str, dec_str, nside):
    """Read a FITS catalog and order its rows by nested HEALPix pixel.

    Parameters
    ----------
    file_name : str
        Path of the FITS catalog.
    ra_str, dec_str : str
        Names of the columns passed to ``hp.ang2pix`` as longitude and
        latitude (``lonlat=True``).
    nside : int
        HEALPix ``nside`` used to compute the pixel of every row.

    Returns
    -------
    HPX_un : ndarray
        Sorted unique pixel values.
    data_sort : record array
        Catalog rows reordered by increasing pixel value.
    label_columns : list of str
        Column names as reported by ``Table.read``.
    HPX_sort : list
        Pixel value of every row of ``data_sort``.
    t_format : list
        ``info.dtype`` of each column, in ``label_columns`` order.
    HPX_idx : ndarray
        For every row of ``data_sort``, the index of its pixel in ``HPX_un``.
    HPX_counts : ndarray
        Number of rows falling in each pixel of ``HPX_un``.
    """
    data = getdata(file_name)
    t = Table.read(file_name)
    label_columns = t.colnames
    t_format = [t[name].info.dtype for name in label_columns]

    HPX = hp.ang2pix(nside, data[ra_str], data[dec_str], nest=True, lonlat=True)

    HPX_idx_sort = np.argsort(HPX)
    # Kept as a list of pixel values, matching what callers receive today.
    HPX_sort = list(HPX[HPX_idx_sort])
    data_sort = data[HPX_idx_sort]

    # One pass yields the unique pixels, the row -> pixel index map and the
    # per-pixel occupancy.
    HPX_un, HPX_idx, HPX_counts = np.unique(
        HPX_sort, return_inverse=True, return_counts=True
    )

    return HPX_un, data_sort, label_columns, HPX_sort, t_format, HPX_idx, HPX_counts

In [None]:
# Load the mock catalog and sort its rows by nested HEALPix pixel (nside = 2**17).
HPX_un, data_sort, label_columns, HPX_sort, t_format, HPX_idx, HPX_counts = \
read_input_cat('results/des_mockcat_for_detection.fits', 'ra', 'dec', 2**17)

# Pixels that contain exactly one source.
HPX_single_star_pix = [j for i,j in enumerate(HPX_un) if HPX_counts[i] < 2]

# HPX_idx maps every row of data_sort to its pixel's position in
# HPX_un / HPX_counts, so HPX_counts[HPX_idx] is the occupancy of each row's
# pixel. This avoids a list-membership scan for every row.
is_single = HPX_counts[HPX_idx] < 2
data_clean = np.array([data_sort[i] for i in np.flatnonzero(is_single).tolist()])
print(data_clean, len(data_sort['ra']))

# Overlay the sources kept in data_clean (red) on the full catalog (blue).
fig, ax = plt.subplots()
ax.scatter(data_sort['ra'], data_sort['dec'], c='b', s=10, label='all sources')
if len(data_clean):
    # data_clean is a plain 2-D array, so locate ra/dec by column position.
    col_names = [name.lower() for name in label_columns]
    ra_col = col_names.index('ra')
    dec_col = col_names.index('dec')
    ax.scatter(data_clean[:, ra_col], data_clean[:, dec_col], c='r', s=3,
               label='alone in their pixel')
ax.set_xlabel('ra')
ax.set_ylabel('dec')
ax.set_title('Sources alone in their HEALPix pixel')
ax.legend()
plt.show()

In [None]:
# Write the filtered rows in data_clean to 'cat_clean.fits', using the column
# names and formats returned by read_input_cat for the input catalog.
join_cat(data_clean, label_columns, t_format, 'cat_clean.fits')

In [None]:
#TODO: make a few plots showing the results