In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import seaborn as sns
import matplotlib.pyplot as plt
from astropy.table import Table
import copy
from astroduet.models import Simulations
from matplotlib import rcParams
import numpy as np
rcParams.update({'font.size': 18})

In [None]:
# Load the Monte Carlo results and attach the quality-flag columns used by the
# rest of the notebook.
table = Table.read('monte_carlo.csv')

# Model names are round-tripped through a plain list; the renaming step and the
# export of the corrected file are kept disabled, as before.
models = list(table['model'])
# models = [m + '0' if (m.startswith('rsg') or m.startswith('ysg')) else m for m in models]
table['model'] = models
# table.write('monte_carlo_corr.csv')

# 'best_chisq': both D1 and D2 chi-squares are below the 1e32 cut.
table['best_chisq'] = (table['D1_chisq'] < 1e32) & (table['D2_chisq'] < 1e32)
# 'good_chisq': only the D1 chi-square is required to be below the cut.
good_chisq = table['D1_chisq'] < 1e32
table['good_chisq'] = good_chisq

# A fit is valid when its D1 chi-square is positive and ngood exceeds 5.
# Disabled alternative that also required a positive D2 chi-square:
# table['valid_fit'] = (table['D1_chisq'] > 0) & (table['D2_chisq'] > 0) & (table['ngood'] > 5)
table['valid_fit'] = (table['ngood'] > 5) & (table['D1_chisq'] > 0)

# Rejected rows fail at least one of the two flags above.
table['rejected'] = ~table['good_chisq'] | ~table['valid_fit']

# Drop the '.dat' suffix from the fitted model names.
table['fit_model'] = [name.replace('.dat', '') for name in table['fit_model']]

faint_galaxy = (table['magnitude'] == 0) | (table['magnitude'] > 27)

# Where a background galaxy is present, relabel it with its integer magnitude.
has_galaxy = table['galaxy'] != "none"
table['galaxy'][has_galaxy] = [str(mag) for mag in table['magnitude'][has_galaxy].astype(int)]

table.sort('galaxy')
# table = table[~table['rejected']]
# table = table[good]

In [None]:
# Display the prepared table (the last expression of a cell is rendered).
table

In [None]:
# Models that still have at least one non-rejected row.
# Only the last expression of a cell is displayed, so the bare `set(...)`
# expression was silently discarded; print it explicitly instead.
print(sorted(set(table['model'][~table['rejected']])))
table

In [None]:
# Pair plot of the fit statistics for accepted simulations that include a
# background galaxy, coloured by the fitted model.
# seaborn.pairplot builds its own figure, so the former plt.figure() call only
# produced an extra empty figure in the cell output and has been dropped.
accepted_with_galaxy = ~table['rejected'] & (table['galaxy'] != 'none')
sns.pairplot(table[accepted_with_galaxy].to_pandas(),
             hue='fit_model',
             vars='D1_chisq,D2_chisq,distance,magnitude'.split(','));

In [None]:
# Pair plot of the fit statistics for accepted simulations of the 'bsg20'
# model, coloured by the fitted model.
# seaborn.pairplot builds its own figure, so the former plt.figure() call only
# produced an extra empty figure in the cell output and has been dropped.
accepted_bsg20 = ~table['rejected'] & (table['model'] == 'bsg20')
# NOTE(review): `size` in plot_kws is forwarded to the off-diagonal plotting
# function; confirm it has the intended marker-size effect in the installed
# seaborn version.
sns.pairplot(table[accepted_bsg20].to_pandas(),
             hue='fit_model',
             vars='D1_chisq,D2_chisq,distance'.split(','),
             plot_kws=dict(edgecolor=None, size=0.5, alpha=0.5));

In [None]:
# For every simulated model: print the fraction of galaxy-free simulations
# whose `ngood` exceeds the threshold, then show ngood vs. distance coloured
# by background galaxy.
threshold_ngood = 100
for model in sorted(set(table['model'])):
    print(model)
    good = table['model'] == model
    table_filt = table[good]
    good_galaxy = table_filt['galaxy'] == 'none'
    n_no_galaxy = np.count_nonzero(good_galaxy)
    if n_no_galaxy > 0:
        print(np.count_nonzero(table_filt[good_galaxy]['ngood'] > threshold_ngood) / n_no_galaxy)
    else:
        # Previously this divided by zero when a model had no galaxy-free rows.
        print('no simulations without a background galaxy')
    # seaborn.pairplot creates its own figure; a preceding plt.figure() would
    # only add an empty figure per model to the output.
    sns.pairplot(table_filt.to_pandas(), hue='galaxy',
                 vars='ngood,distance'.split(','),
                 plot_kws=dict(edgecolor=None, size=0.5, alpha=0.5));

In [None]:
from astropy.table import Table, QTable
from tqdm import tqdm
import numpy as np
import re
# Captures the digits that follow 'sg' in a model name (e.g. 'rsg400' -> '400').
radius_re = re.compile(r'sg([0-9]+)[^0-9]*')


def _radius_from_name(name):
    """Return the number following 'sg' in a model name, as a float."""
    match = radius_re.search(str(name))
    if match is None:
        raise ValueError(f"Cannot extract a radius from model name {name!r}")
    return float(match.group(1))


def get_radius(newtable, all_models=None, all_fit_models=None):
    """Add 'radius' and 'fit_radius' columns derived from the model names.

    Parameters
    ----------
    newtable : table-like
        Must provide 'model' and 'fit_model' columns. It is modified in place.
    all_models : iterable of str, optional
        Model names whose rows get a 'radius' value. Defaults to the distinct
        values of ``newtable['model']``.
    all_fit_models : iterable of str, optional
        Fitted model names whose rows get a 'fit_radius' value. Defaults to the
        distinct values of ``newtable['fit_model']``.

    Returns
    -------
    The same ``newtable`` object, with the two float columns added. Rows whose
    name is not in the corresponding list keep the initial value 0.

    Raises
    ------
    ValueError
        If a listed name does not contain 'sg' followed by digits.
    """
    # The defaults come from the table that was passed in, not from a
    # notebook-level global, so the function works on any table.
    if all_models is None:
        all_models = sorted(set(newtable['model']))
    if all_fit_models is None:
        all_fit_models = sorted(set(newtable['fit_model']))

    newtable['radius'] = 0.
    newtable['fit_radius'] = 0.

    # Each radius is parsed from the very name used to select the rows, so the
    # name/radius pairing cannot get out of step with a caller-supplied list.
    for m in all_models:
        good = newtable['model'] == m
        newtable['radius'][good] = _radius_from_name(m)

    for m in all_fit_models:
        good = newtable['fit_model'] == m
        newtable['fit_radius'][good] = _radius_from_name(m)

    return newtable


def rearrange_table(table_filtered, quantity_to_compare, group_by='distance',
                    all_models=None,
                    all_fit_models=None, calculate_radius=False):
    """Pivot the per-fit rows into one row per group and flag correct fits.

    Parameters
    ----------
    table_filtered : astropy Table
        One row per (simulation, fitted model) pair.
    quantity_to_compare : str
        Name of the column whose values are compared between fitted models
        (the smaller value wins).
    group_by : str
        Column used to group rows into individual simulations.
    all_models : iterable of str, optional
        Models the candidate fit is compared against. Defaults to the distinct
        values of ``table_filtered['model']``.
    all_fit_models : iterable of str, optional
        Fitted models to pivot into columns. Defaults to the distinct values of
        ``table_filtered['fit_model']``.
    calculate_radius : bool
        If True, pass the pivoted table through ``get_radius``.

    Returns
    -------
    QTable with the first-row values of the standard columns for every group,
    one column per fitted model holding ``quantity_to_compare``, and an integer
    'correct_model' column that is 1 where the fitted model with the smallest
    value equals the row's 'model'.
    """
    # The None defaults used to reach tqdm()/iteration unchanged and crash.
    if all_fit_models is None:
        all_fit_models = sorted(set(table_filtered['fit_model']))
    if all_models is None:
        all_models = sorted(set(table_filtered['model']))

    t = table_filtered.group_by(group_by)
    # 'galaxy' was listed twice; each column is selected once.
    standard_quantities = 'model,galaxy,final_resolution,distance,ngood,rejected'.split(',')

    newtable = QTable(t[standard_quantities].groups.aggregate(lambda arr: arr[0]))

    for model_fit in tqdm(all_fit_models):
        values = [sub[quantity_to_compare][sub['fit_model'] == model_fit][0] for sub in t.groups]
        newtable[model_fit] = values

    if calculate_radius:
        # NOTE(review): get_radius reads a 'fit_model' column, which is not
        # among the columns kept above - confirm this option is still wanted.
        newtable = get_radius(newtable)

    best_fit_is_correct = np.zeros(len(newtable), dtype=int)
    for model_fit in tqdm(all_fit_models):
        this_model_is_best = np.ones(len(newtable), dtype=bool)
        for m in all_models:
            if m == model_fit or m not in newtable.colnames:
                continue
            # Accumulate with AND: the candidate must beat every other model,
            # not just the last one in the list.
            this_model_is_best &= np.asarray(newtable[model_fit] < newtable[m])
        best_fit_is_correct += this_model_is_best & np.asarray(model_fit == newtable['model'])

    newtable['correct_model'] = best_fit_is_correct
    return newtable

def measure_radius_error(table_filtered, group_by='distance',
                    all_models=None, quantity_to_compare='D1_chisq',
                    all_fit_models=None, calculate_radius=False):
    """Keep, for every group, the row with the smallest ``quantity_to_compare``.

    A deep copy of ``table_filtered`` is passed through ``get_radius``; rows
    whose ``quantity_to_compare`` is not positive are discarded; the columns
    'radius_err' (fit_radius - radius), 'radius_err_rel' (radius_err / radius)
    and 'quantity_to_compare' (the column name, repeated on every row) are
    added. The rows are then grouped by ``group_by`` and one row per group is
    returned: the one where ``quantity_to_compare`` is minimal.

    ``all_models``, ``all_fit_models`` and ``calculate_radius`` are accepted
    but not used by this function.
    """
    with_radius = get_radius(copy.deepcopy(table_filtered))
    with_radius = with_radius[with_radius[quantity_to_compare] > 0]
    with_radius['radius_err'] = with_radius['fit_radius'] - with_radius['radius']
    with_radius['radius_err_rel'] = with_radius['radius_err'] / with_radius['radius']
    with_radius['quantity_to_compare'] = quantity_to_compare

    by_group = with_radius.group_by(group_by)
    # Start from one placeholder row per group (its first row) ...
    best_rows = by_group.groups.aggregate(lambda column: column[0])
    # ... and overwrite each placeholder with the group's minimum-statistic row.
    for index, group in enumerate(by_group.groups):
        best_rows[index] = group[np.argmin(group[quantity_to_compare])]

    return best_rows

In [None]:
# Best-fit tables for each statistic, computed on a copy of the full table
# (rejected rows are kept; the filtered alternative is left disabled).
# table_filtered = table[~table['rejected']]
table_filtered = copy.deepcopy(table)
# rearrange_table(table_filtered, 'D1_chisq')

newtable_D1 = measure_radius_error(table_filtered)
newtable_D2 = measure_radius_error(table_filtered, quantity_to_compare='D2_chisq')
newtable_ratio = measure_radius_error(table_filtered, quantity_to_compare='ratio_chisq')
newtable_D1_nofit = measure_radius_error(table_filtered, quantity_to_compare='D1_chisq_nofit')
newtable_D2_nofit = measure_radius_error(table_filtered, quantity_to_compare='D2_chisq_nofit')
newtable_ratio_nofit = measure_radius_error(table_filtered, quantity_to_compare='ratio_chisq_nofit')

# Sort every result table in place by background-galaxy label.
for _best_fit_table in (newtable_D1, newtable_D2, newtable_ratio,
                        newtable_D1_nofit, newtable_D2_nofit, newtable_ratio_nofit):
    _best_fit_table.sort('galaxy')



In [None]:
def plot_ngood_vs_distance(table):
    """Scatter ``ngood`` (as % of the per-model maximum) against distance.

    ``table`` must carry a 'quantity_to_compare' column naming the statistic
    column. The companion "factor" column name is obtained by removing
    '_chisq' and '_nofit' from that name. One figure is drawn per model, using
    only rows with no background galaxy, a positive statistic and a factor
    different from 1.0. When some rows fall between 88% and 92%, the median of
    their distances is printed and marked with a vertical line.

    Models with no usable rows (or whose maximum ``ngood`` is not positive)
    are reported and skipped instead of raising on the empty ``max()`` or
    dividing by zero.
    """
    grouped_model = table.group_by('model')
    quantity_to_compare = list(set(table['quantity_to_compare']))[0]
    factor = quantity_to_compare.replace('_chisq', '').replace('_nofit', '')
    print(quantity_to_compare)
    for table_group in grouped_model.groups:
        model = list(set(table_group['model']))[0]
        good_galaxy = ((table_group['galaxy'] == 'none')
                       & (table_group[quantity_to_compare] > 0)
                       & (table_group[factor] != 1.0))
        table_filt = table_group[good_galaxy]
        if len(table_filt) == 0 or table_filt['ngood'].max() <= 0:
            print(model, 'no usable simulations, skipping')
            continue

        plt.figure()
        plt.title(model)
        percent = table_filt['ngood'] / table_filt['ngood'].max() * 100

        # Rows close to 90% of the maximum ngood.
        good = (percent > 88)&(percent < 92)
        if np.any(good):
            distance_90 = np.median(table_filt["distance"][good])
            print(model, f'90% values: {distance_90} Mpc')
            plt.axvline(distance_90)
        plt.scatter(table_filt['distance'], percent)
        plt.ylabel('ngood (%)')
        plt.xlabel('distance')

plot_ngood_vs_distance(newtable_D1)

In [None]:
# Same ngood-vs-distance plots for the table built from 'ratio_chisq_nofit'.
plot_ngood_vs_distance(newtable_ratio_nofit)

In [None]:
def plot_factor_vs_distance(table):
    """Scatter the "factor" column against distance, one figure per model.

    The statistic column name is read from ``table['quantity_to_compare']``;
    the factor column name is that name with '_chisq' removed. Only rows with
    no background galaxy, a positive statistic and a factor different from 1.0
    are plotted.
    """
    statistic = next(iter(set(table['quantity_to_compare'])))
    factor = statistic.replace('_chisq', '')

    for per_model in table.group_by('model').groups:
        no_galaxy = per_model['galaxy'] == 'none'
        positive_statistic = per_model[statistic] > 0
        factor_not_unity = per_model[factor] != 1.0
        keep = no_galaxy & positive_statistic & factor_not_unity

        plt.figure()
        plt.title(next(iter(set(per_model['model']))))
        selected = per_model[keep]
        plt.scatter(selected['distance'], selected[factor])
        plt.ylabel(factor)
        plt.xlabel('distance')

plot_factor_vs_distance(newtable_D1)

In [None]:
# Same factor-vs-distance plots for the table built from 'D2_chisq'.
plot_factor_vs_distance(newtable_D2)

In [None]:
def plot_chisq_vs_distance(table):
    """Scatter the comparison statistic against distance, one figure per model.

    The statistic column name is read from ``table['quantity_to_compare']``;
    the factor column name is that name with '_chisq' removed. Only rows with
    no background galaxy, a positive statistic and a factor different from 1.0
    are plotted.
    """
    statistic = next(iter(set(table['quantity_to_compare'])))
    factor = statistic.replace('_chisq', '')

    for per_model in table.group_by('model').groups:
        no_galaxy = per_model['galaxy'] == 'none'
        positive_statistic = per_model[statistic] > 0
        factor_not_unity = per_model[factor] != 1.0
        selected = per_model[no_galaxy & positive_statistic & factor_not_unity]

        plt.figure()
        plt.title(next(iter(set(per_model['model']))))
        plt.scatter(selected['distance'], selected[statistic], zorder=10)
        plt.ylabel(statistic)
        plt.xlabel('distance')



In [None]:
# Statistic-vs-distance plots for the table built from 'D1_chisq'.
plot_chisq_vs_distance(newtable_D1)

In [None]:
import matplotlib

def plot_radius_err_vs_galaxy_mag(table, name=""):
    """Plot the mean absolute radius error against distance, per model.

    One figure is produced per value of ``table['model']`` and saved as
    ``name + model + '_radius_err_mag.png'``. Inside a figure there is one
    colour per 'galaxy' value. For each galaxy value the distance range is
    split into 9 equal bins; a bin contributes a point when it holds more than
    two non-rejected rows. Points whose mean absolute 'radius_err' is below
    half of the model radius are drawn with black edges and appear in the
    legend; the others are drawn fainter.

    Parameters
    ----------
    table : astropy Table
        Needs the columns 'model', 'galaxy', 'distance', 'rejected', 'radius'
        and 'radius_err'.
    name : str
        Prefix for the figure titles and output file names.
    """
    grouped_model = table.group_by('model')
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib
    # releases no longer provide.
    cmap = plt.get_cmap('nipy_spectral')

    for table_filt in grouped_model.groups:
        plt.figure(figsize=(15,5))
        assert len(set(table_filt['model'])) == 1, f"{set(table_filt['model'])} is wrong"
        title = list(set(table_filt['model']))[0]
        real_radius = table_filt['radius'][0]

        plt.title(name + title)
        grouped_gal = table_filt.group_by('galaxy')
        all_galaxy_vals = list(set(grouped_gal['galaxy']))

        normalize = matplotlib.colors.Normalize(vmin=0, vmax=len(all_galaxy_vals))
        colors = [cmap(normalize(value)) for value in range(len(all_galaxy_vals))]
        for i, filt in enumerate(grouped_gal.groups):
            distances = np.linspace(filt['distance'].min(), filt['distance'].max(), 10)
            label = list(set(filt['galaxy']))[0]
            last_bin = len(distances) - 2
            dist = []
            rad_err = []
            good_values = []
            for i_bin, distance_intvs in enumerate(zip(distances[:-1], distances[1:])):
                good_distance = filt['distance'] >= distance_intvs[0]
                if i_bin == last_bin:
                    # The last bin is closed on the right; otherwise the rows
                    # at the maximum distance would never be counted.
                    good_distance = good_distance & (filt['distance'] <= distance_intvs[1])
                else:
                    good_distance = good_distance & (filt['distance'] < distance_intvs[1])
                if not np.any(good_distance):
                    continue
                good = filt[good_distance]
                valid = good[~good['rejected']]
                # Require more than two accepted rows for a meaningful mean.
                if not len(valid) > 2:
                    continue

                radius_err_abs = np.mean(np.abs(valid['radius_err']))

                rad_err.append(radius_err_abs)
                dist.append(np.mean(distance_intvs))
                good_values.append(bool(radius_err_abs < real_radius / 2))
            if len(dist) == 0:
                continue
            dist = np.array(dist)
            rad_err = np.array(rad_err)
            good_values = np.array(good_values, dtype=bool)
            plt.scatter(dist[good_values], rad_err[good_values], color=colors[i],
                        label=label, alpha=0.7, edgecolors='k', zorder=10)
            plt.scatter(dist[~good_values], rad_err[~good_values], color=colors[i],
                        alpha=0.4, edgecolors='none')
        plt.xlabel("Distance (Mpc)")
        # Raw string: "\o" is not a valid escape sequence in a normal literal.
        plt.ylabel(r"Radius error ($R_{\odot}$)")
        # Only draw a legend when at least one labelled series was plotted.
        if plt.gca().get_legend_handles_labels()[0]:
            plt.legend(title="Bkg galaxy mag")
        plt.yscale('symlog')
        plt.savefig(name + title + '_radius_err_mag.png')

plot_radius_err_vs_galaxy_mag(newtable_D1, name="D1_")

In [None]:
# Radius-error plots for the 'D2_chisq' table; output files are prefixed "D2_".
plot_radius_err_vs_galaxy_mag(newtable_D2, name="D2_")

In [None]:
# Radius-error plots for the 'ratio_chisq' table; output files are prefixed "ratio_".
plot_radius_err_vs_galaxy_mag(newtable_ratio, name="ratio_")

In [None]:
# Radius-error plots for the 'D1_chisq_nofit' table; output files are prefixed "D1_nofit_".
plot_radius_err_vs_galaxy_mag(newtable_D1_nofit, name="D1_nofit_")

In [None]:
# Radius-error plots for the 'D2_chisq_nofit' table; output files are prefixed "D2_nofit_".
plot_radius_err_vs_galaxy_mag(newtable_D2_nofit, name="D2_nofit_")

In [None]:
# Radius-error plots for the 'ratio_chisq_nofit' table; output files are prefixed "ratio_nofit_".
plot_radius_err_vs_galaxy_mag(newtable_ratio_nofit, name="ratio_nofit_")

In [None]:
# plt.figure(figsize=(15, 10))

# colors = ['b', '#6666aa', '#aaaaff', 'r', '#aa6666', '#ffaaaa']
# for i, m in enumerate(all_models):
#     filt = newtable[newtable['model'] == m]
#     label = m
#     print(m)
#     distances = np.arange(50, 700, 50)
#     for distance_intvs in zip(distances[:-1], distances[1:]):
#         good_distance = (filt['distance'] >= distance_intvs[0])&(filt['distance'] < distance_intvs[1])
#         if not np.any(good_distance):
#             continue
#         print(f"   Distance: {distance_intvs[0]} to {distance_intvs[1]}")
#         good = filt[good_distance]
#         not_rejected = ~good['rejected']
#         if not np.any(not_rejected):
#             print('All rejected')
#             continue
#         valid = good[not_rejected]
        
#         correct_model = valid['correct_model']
#         if not np.count_nonzero(correct_model) > 2:
#             print("Not enough correct fits")
#             continue
#         correct = valid[correct_model]
#         valid_ratio = len(valid)/len(good)*100
#         print(f"      Valid points: {valid_ratio:.0f}%")
#         correct_ratio = len(correct) / len(valid)*100.0
#         print(f"      Correct model fit: {correct_ratio:.0f}%")
#         plt.scatter(np.mean(distance_intvs), valid_ratio, c=colors[i], label=label)
#         label=None
        
#     plt.xlabel("Distance (Mpc)")
#     plt.ylabel("% Valid simulations")
#     plt.legend()
        
#     print()

In [None]:
# Mean absolute radius error vs. distance, one colour per simulated model.
#
# TODO(review): this cell used to start with the incomplete statement
# "table_filtered_R600 =", which made the whole cell a SyntaxError. The
# intended selection is not recoverable from the notebook, so it is omitted.
all_models_in_table = list(set(table_filtered['model']))

# NOTE(review): the cell referred to `newtable` and `all_models`, neither of
# which is defined anywhere in the notebook. newtable_D1 (which carries the
# 'radius_err' column) and the models present in table_filtered are assumed.
radius_table = newtable_D1

plt.figure(figsize=(15, 10))

colors = ['b', '#6666aa', '#aaaaff', 'r', '#aa6666', '#ffaaaa']
for i, m in enumerate(sorted(all_models_in_table)):
    filt = radius_table[radius_table['model'] == m]
    if len(filt) == 0:
        # Nothing to bin for this model (min()/max() would raise on no rows).
        continue
    label = m
    # Cycle through the palette so more than six models do not overflow it.
    color = colors[i % len(colors)]
    distances = np.linspace(filt['distance'].min(), filt['distance'].max(), 10)
    last_bin = len(distances) - 2
    for i_bin, distance_intvs in enumerate(zip(distances[:-1], distances[1:])):
        good_distance = filt['distance'] >= distance_intvs[0]
        if i_bin == last_bin:
            # Close the last bin on the right so the farthest rows are kept.
            good_distance = good_distance & (filt['distance'] <= distance_intvs[1])
        else:
            good_distance = good_distance & (filt['distance'] < distance_intvs[1])
        if not np.any(good_distance):
            continue
        good = filt[good_distance]
        not_rejected = ~good['rejected']
        if not np.any(not_rejected):
            print('All rejected')
            continue
        valid = good[not_rejected]

        radius_err_abs = np.mean(np.abs(valid['radius_err']))

        plt.scatter(np.mean(distance_intvs), radius_err_abs, color=color, label=label)
        # Only the first point of each model gets a legend entry.
        label = None

plt.xlabel("Distance (Mpc)")
plt.ylabel("Radius error")
if plt.gca().get_legend_handles_labels()[0]:
    plt.legend()

In [None]:
# from sklearn.preprocessing import OneHotEncoder
# from sklearn.feature_selection import RFE
# from sklearn.linear_model import LogisticRegression
# ohe = OneHotEncoder(sparse=False)
# target = table['model'] == table['fit_model']
# table_transformed = ohe.fit_transform(table.to_pandas())


In [None]:
# # load the iris datasets
# dataset = table.pandas()
# # create a base classifier used to evaluate a subset of attributes
# model = LogisticRegression()
# # create the RFE model and select 3 attributes
# rfe = RFE(model, 3)
# rfe = rfe.fit(dataset.data, dataset.target)
# # summarize the selection of the attributes
# print(rfe.support_)
# print(rfe.ranking_)