In [1]:
import os
import sys
sys.path.insert(0, '../gofher')

import itertools
from collections import defaultdict

from sparcfire import get_gofher_params_for_fixed_ref_band, read_sparcfire_galaxy_csv
from galaxy import galaxy, construct_band_pair_key
from sdss import SDSS_BANDS_IN_ORDER, visualize_sdss, create_sdss_csv
from spin_parity import read_spin_parity_galaxies_label_from_csv
from gofher import run_gofher_on_galaxy_with_fixed_gofher_parameters, run_gofher_on_galaxy_with_fixed_center_only, run_gofher_on_galaxy_with_sparcfire_center_inital_guess
from file_helper import write_csv

In [2]:
blur_sdss_fits_folder = "E:\\grad_school\\research\\spin_parity_blurring\\sdss_output"
blur_sdss_folder = "E:\\grad_school\\research\\spin_parity_blurring\\sparcfire_sdss_output"

dark_side_path = "C:\\Users\\school\\Desktop\\github\\spin-parity-catalog\\table_info\\csv_format_of_table\\"

fixed_ref_band_folder = "C:\\Users\\school\\Desktop\\cross_id\\sdss_mosaic_construction\\sdss_ref_band"
csv_base_output_dir = "E:\\grad_school\\research\\spin_parity_blurring\\gofher_sdss_output_stats"

In [3]:
generate_visualization = False
generate_csv = True

In [4]:
#run_types = ['inital_guess'] #,'fixed_center' rs doesn't matter
#run_types = ['sparcfire']
#bulge_disk_rs = [0.125]

run_types = ['sparcfire','inital_guess'] #,'fixed_center' rs doesn't matter
bulge_disk_rs = [0.5,0.25,0.125]
#run_types = ['fixed_center']
#bulge_disk_rs = [0.5]

In [5]:
#blurring params to run on:
sn_list = [8, 16, 32, 64, 128, 256]
psf_list = [4.0,5.6, 8.0, 11.3, 16.0, 22.6, 32.0, 45.2, 64.0, 90.5, 128.0]

datasets = {"train":["table2"],"test":["table4","table5"],"eval":["table3"]}
table_list = sum(datasets.values(), []) #flatten list of list of values

In [6]:
def get_blur_folder(the_sn,the_psf):
    return "psf_{}_background_{}".format(str(the_psf),str(the_sn))

def get_fits_path(the_sn,the_psf):
    return lambda table_name,name,band: os.path.join(blur_sdss_fits_folder,get_blur_folder(the_sn,the_psf),table_name,name,"{}_{}.fits".format(name,band))

def get_galaxy_list(table_name,the_sn,the_psf):
    return os.listdir(os.path.join(blur_sdss_fits_folder,get_blur_folder(the_sn,the_psf),table_name))

def get_dark_side_csv_path(table_name):
    return os.path.join(dark_side_path,"table_{}.csv".format(table_name.strip()[-1]))

def get_sparcfire_galaxy_csv_path(table_name,the_sn,the_psf):
    return os.path.join(blur_sdss_folder,get_blur_folder(the_sn,the_psf),table_name,"G.out","galaxy.csv")

def get_ref_band_path(table_name):
    return os.path.join(fixed_ref_band_folder,"{}.txt".format(table_name))

def get_csv_output_path(dataset_name,the_sn,the_psf,the_r,the_type):
    r_string = str(the_r).replace('.','')
    if the_type != 'fixed_center':
        folder_dir = "{}_r_{}".format(the_type,r_string)
    else:
        folder_dir = the_type
    file_name = "{}_{}.csv".format(get_blur_folder(the_sn,the_psf),dataset_name)
    return os.path.join(csv_base_output_dir,folder_dir,file_name)

def get_visualization_dir_output_path(dataset_name,the_sn,the_psf,the_r,the_type):
    r_string = str(the_r).replace('.','')
    if the_type != 'fixed_center':
        folder_dir = "{}_r_{}".format(the_type,r_string)
    else:
        folder_dir = the_type
    add_folder_name = "{}_{}".format(get_blur_folder(the_sn,the_psf),dataset_name)
    return os.path.join(csv_base_output_dir,folder_dir,add_folder_name)

In [7]:
def get_ref_band_dict_for_table(table_name):
    ref_band_dict = dict()

    first = True
    with open(get_ref_band_path(table_name)) as f:
        for line in f.readlines():
            if first:
                first = False
            else:
                to_parse = line.strip().rsplit(" ",1)
                if len(to_parse) != 2: continue

                ref_band_dict[to_parse[0]] = to_parse[1]
    return ref_band_dict


def get_ref_band_and_dark_side_dicts_for_folders(the_folders):
    ref_band_dict = dict()
    dark_side_dict = dict()

    for folder in the_folders:
        for (k,v) in get_ref_band_dict_for_table(folder).items():
            ref_band_dict[k] = v
        for (k,v) in read_spin_parity_galaxies_label_from_csv(get_dark_side_csv_path(folder)).items():
            dark_side_dict[k] = v
    
    return ref_band_dict, dark_side_dict

def get_galaxy_to_folder_dict(the_folders,the_sn,the_psf):
    galaxy_to_folder_dict = dict()

    for each_folder in the_folders:
        for each_galaxy in get_galaxy_list(each_folder,the_sn,the_psf):
            galaxy_to_folder_dict[each_galaxy] = each_folder

    return galaxy_to_folder_dict

def get_sparcfire_for_folders(the_folders, sn, psf):
    sparcfire_bands = dict()
    for each_folder in the_folders:
        the_pa = get_sparcfire_galaxy_csv_path(each_folder,sn, psf)
        #print(the_pa)
        for (k,v) in read_sparcfire_galaxy_csv(the_pa).items():
            sparcfire_bands[k] = v

    return sparcfire_bands


In [8]:
def create_blur_csv(gals,the_band_pairs,csv_path):
    correct_count = defaultdict(int)
    no_vote_count = defaultdict(int)
    incorrect_count = defaultdict(int)


    """create an csv containing the information from gofher of the given galaxies"""
    #Construct CSV header:
    csv_column_headers = ['name','dark_side_label','pos_side_label','neg_side_label','ref_band','encounted_sersic_error','table_name']
    #per_band_column_headers = ['label','score']
    per_band_column_headers = ['pos_side_mean','pos_side_std','neg_side_mean','neg_side_std','D','P','label','score']

    for band_pair in the_band_pairs:
        band_pair_key = construct_band_pair_key(band_pair[0],band_pair[1])
        csv_column_headers.extend(list(map(lambda x: "{}_{}".format(band_pair_key,x),per_band_column_headers)))

    csv_column_headers.extend(['vote_count','vote_score'])
    
    #Construct CSV rows:
    rows = []
    for gal in gals:
        if not isinstance(gal,galaxy): continue

        the_row = [gal.name,gal.dark_side,gal.pos_side_label,gal.neg_side_label,gal.ref_band,str(gal.encountered_sersic_fit_error),gal.folder]
        for band_pair in the_band_pairs:
            band_pair_key = construct_band_pair_key(band_pair[0],band_pair[1])
            if band_pair_key not in gal.band_pairs: 
                the_row.extend(['MISSING',0])
                continue
            the_band_pair = gal.get_band_pair(band_pair_key)

            the_row.extend([the_band_pair.pos_fit_norm_mean,the_band_pair.pos_fit_norm_std,
                            the_band_pair.neg_fit_norm_mean,the_band_pair.neg_fit_norm_std,
                            the_band_pair.d_stat, the_band_pair.p_value,
                            the_band_pair.classification_label,
                            the_band_pair.classification_score])
            
            #the_row.extend([the_band_pair.classification_label,
            #                the_band_pair.classification_score])
            
            if the_band_pair.classification_score == 1:
                correct_count[band_pair_key+"_score"] += 1
            elif the_band_pair.classification_score == -1:
                incorrect_count[band_pair_key+"_score"] += 1
            else:
                no_vote_count[band_pair_key+"_score"] += 1
        
        the_row.extend([gal.cumulative_classification_vote_count,gal.cumulative_score])
        rows.append(the_row)

        if gal.cumulative_score == 1:
            correct_count["vote_score"] += 1
        elif gal.cumulative_score == -1:
            incorrect_count["vote_score"] += 1
        else:
            no_vote_count["vote_score"] += 1

    the_score_keys = list(filter(lambda x: "score" in x,csv_column_headers))

    correct_count_row = ['']*len(csv_column_headers)
    correct_count_row[0] = 'CORRECT_COUNT'

    no_vote_count_row = ['']*len(csv_column_headers)
    no_vote_count_row[0] = 'NO_VOTE_COUNT'

    incorrect_count_row = ['']*len(csv_column_headers)
    incorrect_count_row[0] = 'INCORRECT_COUNT'

    accuracy_row = ['']*len(csv_column_headers)
    accuracy_row[0] = 'ACCURACY'

    for each_score_key in the_score_keys:
        row_ind = csv_column_headers.index(each_score_key)

        correct = 0
        no_vote = 0
        incorrect = 0
        accuracy = 100.0

        if each_score_key in correct_count:
            correct = correct_count[each_score_key]

        if each_score_key in no_vote_count:
            no_vote = no_vote_count[each_score_key]

        if each_score_key in incorrect_count:
            incorrect = incorrect_count[each_score_key]

        if (correct+incorrect) > 0:
            accuracy = (correct/(correct+incorrect)) * 100.0

        correct_count_row[row_ind] = correct
        no_vote_count_row[row_ind] = no_vote
        incorrect_count_row[row_ind] = incorrect
        accuracy_row[row_ind] = accuracy

    rows.extend([correct_count_row,no_vote_count_row,incorrect_count_row,accuracy_row])
    write_csv(csv_path,csv_column_headers,rows)

In [9]:
def get_visualization_additional_str(the_sn,the_psf,the_r,the_type):
    return " {} r={} (psf={}, sn={})".format(the_type,the_r,the_psf,the_sn)

In [10]:
def run_on_blurred_images(name, fits_path, sparcfire_bands, the_ref_band, table_name, dark_side_label='', bulge_disk_r=1.0, run_type = '', vis_path ='', add_vis_string=''):
    the_gal = galaxy(name,dark_side_label)

    for band in SDSS_BANDS_IN_ORDER:
       if not os.path.exists(fits_path(table_name,name,band)): continue
       the_gal.construct_band(band,fits_path(table_name,name,band))

    the_sparcfire_derived_params = get_gofher_params_for_fixed_ref_band(sparcfire_bands, the_ref_band, bulge_disk_r=bulge_disk_r)
    if the_sparcfire_derived_params == None: return None

    the_gal.ref_band = the_ref_band
    the_band_pairs = list(itertools.combinations(SDSS_BANDS_IN_ORDER, 2))

    the_gal.folder = table_name

    run_type = run_type.strip().lower() 
    if run_type == 'inital_guess':
        the_gal = run_gofher_on_galaxy_with_sparcfire_center_inital_guess(the_gal,the_band_pairs,the_sparcfire_derived_params)
    elif run_type == 'fixed_center':
        the_gal = run_gofher_on_galaxy_with_fixed_center_only(the_gal,the_band_pairs,the_sparcfire_derived_params)
    else:
        the_gal = run_gofher_on_galaxy_with_fixed_gofher_parameters(the_gal,the_band_pairs,the_sparcfire_derived_params)

    if generate_visualization and vis_path != '':
        visualize_sdss(the_gal,vis_path, add_vis_string)
        
    return the_gal #uncomment for csv

In [11]:
def run_all_settings():
    for each_dataset in datasets:
        print(each_dataset)
        ref_band_dict, dark_side_dict = get_ref_band_and_dark_side_dicts_for_folders(list(datasets[each_dataset]))
        j = 1
        for sn in sn_list:
            for psf in psf_list:
                print(" ",j,get_blur_folder(sn,psf))
                galaxy_to_table_dict =  get_galaxy_to_folder_dict(list(datasets[each_dataset]),sn,psf)
                #print(get_sparcfire_for_folders(list(datasets[each_dataset]), sn, psf))
                sparcfire_csv_dict = get_sparcfire_for_folders(list(datasets[each_dataset]), sn, psf)

                #r = 0.125
                #run_type = 'inital_guess'

                for r in bulge_disk_rs:
                    for run_type in run_types:
                        print("     ",r,run_type)
                        csv_output_path = get_csv_output_path(each_dataset,sn,psf,r,run_type)
                        ##if os.path.exists(csv_output_path): 
                        ##    j += 1
                        ##    continue

                        vis_folder = get_visualization_dir_output_path(each_dataset,sn,psf,r,run_type)
                        #if os.path.exists(vis_folder):
                        #    j += 1
                        #    continue
                        if generate_visualization and not os.path.exists(vis_folder):
                            os.makedirs(vis_folder)

                        i = 1
                        the_gals = []
                        for name in galaxy_to_table_dict.keys():
                            print("     ",i,name)
                            if name not in sparcfire_csv_dict: continue
                            if name not in ref_band_dict: continue
                            if name not in galaxy_to_table_dict: continue
                            if name not in dark_side_dict: continue

                            try:
                                vis_path = os.path.join(vis_folder,"{}.png".format(name))
                                #if os.path.exists(vis_path): continue #to skip rerun
                                add_vis_string = get_visualization_additional_str(sn,psf,r,run_type)
                                current_gal = run_on_blurred_images(name,get_fits_path(sn,psf),sparcfire_csv_dict[name],ref_band_dict[name],galaxy_to_table_dict[name],dark_side_dict[name],r,run_type,vis_path,add_vis_string) #uncomment for csv
                                #run_on_blurred_images(name,get_fits_path(sn,psf),sparcfire_csv_dict[name],ref_band_dict[name],galaxy_to_table_dict[name],dark_side_dict[name],r,run_type,vis_path,add_vis_string)

                                if not isinstance(current_gal,galaxy): continue #uncomment for csv
                                the_gals.append(current_gal) #uncomment for csv
                            except Exception as e:
                                print("         ",e)
                                #return
                            i += 1


                        
                        #uncomment for csv
                        if generate_csv:
                            base_dir = os.path.dirname(csv_output_path)
                            if not os.path.exists(base_dir):
                                os.makedirs(base_dir)
                        
                            the_band_pairs = list(itertools.combinations(SDSS_BANDS_IN_ORDER, 2)) #SDSS
                        
                            #create_blur_csv(the_gals,the_band_pairs,csv_output_path)
                            create_sdss_csv(the_gals,the_band_pairs,csv_output_path)
                        
                        j += 1
                        return
                

In [12]:
run_all_settings()

train
  1 psf_4.0_background_8
      0.5 sparcfire
      1 IC1683
      2 IC1755
      3 IC2101


      4 IC5376
      5 MCG-02-02-030
      5 MCG-02-51-004
      5 NGC1035
      6 NGC1056
      7 NGC1084
      8 NGC1093
      9 NGC157
      10 NGC1667
      11 NGC169
      12 NGC2347
      13 NGC2403
      13 NGC2410
      14 NGC2639
      15 NGC2683
          zero-size array to reduction operation minimum which has no identity
      16 NGC2742
      17 NGC2775
          zero-size array to reduction operation minimum which has no identity
      18 NGC2782
      18 NGC2841
      19 NGC2903
      20 NGC3160
      21 NGC3198
          zero-size array to reduction operation minimum which has no identity
      22 NGC3227
      23 NGC3310
      24 NGC3368
      25 NGC3521
          zero-size array to reduction operation minimum which has no identity
      26 NGC3623
          zero-size array to reduction operation minimum which has no identity
      27 NGC3627
          zero-size array to reduction operation minimum which has no identity
      28 NGC3646
      29 NGC3672
      30 NGC367

KeyError: 'g-i'