In [None]:
from itertools import combinations
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import astropy.units as u
from astropy.table import Table
from astropy.coordinates import SkyCoord


from exod.post_processing.crossmatch import CrossMatch
from exod.post_processing.crossmatch import crossmatch_vizier

In [None]:
# Read the regions table that the rest of the notebook analyses.
regions_csv = '../data/results_combined/30_4_2024/df_regions.csv'
df_regions = pd.read_csv(regions_csv)
# Uncomment to prototype on a random 100-row sample instead of the full table:
# df_regions = df_regions.sample(100)

In [None]:
# Display the loaded regions table as this cell's output.
df_regions

In [None]:
def split_subsets(df_regions, subsets=None):
    """
    Split the regions table into one DataFrame per run configuration.

    A row belongs to a subset when its ``runid`` string contains the subset
    key as a literal substring.

    Parameters
    ----------
    df_regions : pandas.DataFrame
        Table with a string ``runid`` column.
    subsets : iterable of str, optional
        Subset keys to extract. Defaults to the nine built-in keys below.
        ``calc_subset_stats`` splits each key on ``'_'`` into three fields,
        so custom keys should keep that three-field form.

    Returns
    -------
    dict of {str: pandas.DataFrame}
        Maps each subset key to the matching rows, in the order of ``subsets``.
        Subsets with no matching rows map to an empty DataFrame.
    """
    if subsets is None:
        subsets = ['5_0.2_2.0',
                   '5_2.0_12.0',
                   '5_0.2_12.0',
                   '50_0.2_2.0',
                   '50_2.0_12.0',
                   '50_0.2_12.0',
                   '200_0.2_2.0',
                   '200_2.0_12.0',
                   '200_0.2_12.0']

    runid = df_regions['runid']
    dfs = {}
    for s in subsets:
        # regex=False: the keys contain '.', which as a regex would match any
        # character (e.g. '0.2' would also match '012').
        # na=False: rows with a missing runid are excluded instead of making
        # the boolean mask invalid.
        mask = runid.str.contains(s, regex=False, na=False)
        dfs[s] = df_regions[mask]
    return dfs


def calc_subset_stats(dfs):
    """
    Build a one-row-per-subset summary table.

    Parameters
    ----------
    dfs : dict of {str: pandas.DataFrame}
        Maps a subset key of the form ``'<t_bin>_<E_lo>_<E_hi>'`` to a
        DataFrame with ``runid`` and ``intensity_mean`` columns.

    Returns
    -------
    pandas.DataFrame
        Columns: ``subset``, ``t_bin``, ``E_lo``, ``E_hi`` (the key and its
        three fields, kept as strings), ``n_regions`` (row count),
        ``n_obsids`` (number of distinct ``runid`` values), ``reg/obs``
        (their ratio, NaN for an empty subset), and the mean / standard
        deviation of ``intensity_mean``.
    """
    all_res = []
    for k, df in dfs.items():
        t_bin, E_lo, E_hi = k.split('_')
        n_regions = len(df)
        # Same count as len(value_counts()): distinct non-null runid values.
        n_obsids = df['runid'].nunique()
        res = {'subset'      : k,
               't_bin'       : t_bin,
               'E_lo'        : E_lo,
               'E_hi'        : E_hi,
               'n_regions'   : n_regions,
               'n_obsids'    : n_obsids,
               # An empty subset has no runids; report NaN rather than
               # raising ZeroDivisionError.
               'reg/obs'     : n_regions / n_obsids if n_obsids else float('nan'),
               'mean counts' : df['intensity_mean'].mean(),
               'std counts' : df['intensity_mean'].std(),
               }
        all_res.append(res)

    df_region_subset = pd.DataFrame(all_res)
    return df_region_subset
    

In [None]:
# Partition the regions by subset key, then summarise each partition.
dfs = split_subsets(df_regions)
df_region_subset = calc_subset_stats(dfs=dfs)

In [None]:
# Shade the summary table with a diverging colour map.
gradient_cmap = sns.diverging_palette(125, 365, as_cmap=True)
df_region_subset.style.background_gradient(cmap=gradient_cmap)

In [None]:
# Line style encodes the first field of the subset key (t_bin);
# colour encodes the remaining '<E_lo>_<E_hi>' part.
linestyles = {'5_'   : 'solid',
              '50_'  : 'dotted',
              '200_' : 'dashed'}
colors = {'0.2_2.0'  : 'red',
          '2.0_12.0' : 'blue',
          '0.2_12.0' : 'black'}

# Shared bin edges, computed once for every histogram.
bins = np.linspace(0, 500, 50)

fig, ax = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
for k, df in dfs.items():
    # Parse the key rather than substring-searching it, so each subset maps
    # to exactly one line style and one colour.
    t_bin, band = k.split('_', 1)
    hist_kwargs = dict(bins=bins, label=k, histtype='step', lw=1.5,
                       ls=linestyles[f'{t_bin}_'], color=colors[band])
    ax[0].hist(df['intensity_mean'], **hist_kwargs)
    ax[1].hist(df['intensity_mean'], density=True, **hist_kwargs)
ax[0].set_title('Distribution of intensity_mean per subset')
ax[0].set_ylabel('Number')
# density=True normalises each histogram to unit area (counts divided by
# total and by bin width), so the axis shows a density, not a fraction.
ax[1].set_ylabel('Probability density')
ax[1].set_xlabel('Intensity Mean (Counts)')
for a in ax:
    a.legend()
    a.grid()
plt.subplots_adjust(hspace=0)
plt.show()
    

In [None]:
def crossmatch_subsets(dfs):
    """
    Placeholder for a subset cross-match routine.

    Not implemented yet: ``dfs`` is ignored and a fixed string is returned.
    """
    placeholder = 'something bruh'
    return placeholder

In [None]:
# Preview the sky coordinates built from the rows currently bound to `df`.
SkyCoord(
    ra=df['ra_deg'],
    dec=df['dec_deg'],
    unit='deg',
    frame='fk5',
    equinox='J2000',
)

In [None]:
# Display whatever `df` is currently bound to. This cell does not define it;
# other cells bind it as a loop variable (histogram cell) or as the
# cross-match result table.
df

In [None]:
# Display the most recent match table.
# NOTE(review): `tab_cmatch` is only assigned inside the cross-match loop in a
# later cell, so this cell fails on a fresh top-to-bottom run.
tab_cmatch

In [None]:
# Show, for every column after the first, how often each value occurs.
# A bare expression inside a `for` loop produces no cell output, so the
# result has to be printed explicitly.
for col in df.columns[1:]:
    print(df[col].value_counts())

In [None]:
# Largest separation at which a region's nearest neighbour in another subset
# still counts as a match.
# NOTE(review): `max_sep` was used below without being defined in any visible
# cell, which raises NameError on a fresh kernel. 20 arcsec is a placeholder;
# confirm the intended matching radius.
max_sep = 20 * u.arcsec

# Build each subset's coordinates once rather than once per (k1, k2) pair.
coords = {k: SkyCoord(ra=d['ra_deg'], dec=d['dec_deg'], unit='deg', frame='fk5', equinox='J2000')
          for k, d in dfs.items()}

for k1, sc1 in coords.items():
    if len(sc1) == 0:
        # Nothing to match, and the fractions below would divide by zero.
        print(f'{k1}: no regions, skipping')
        continue

    # First column: row positions within subset k1. Every other column holds,
    # per k1 row, the position of the matched row in subset k2, or -1 if the
    # nearest k2 region is max_sep or further away.
    res = {k1: np.arange(len(sc1))}
    for k2, sc2 in coords.items():
        if k1 == k2:
            continue

        print(f'{k1} ({len(sc1)}) {k2:<12} ({len(sc2)})')

        if len(sc2) == 0:
            # match_to_catalog_sky rejects a length-0 catalogue; record no matches.
            res[k2] = np.full(len(sc1), -1)
            continue

        # Nearest neighbour in sc2 for every coordinate in sc1.
        cmatch = sc1.match_to_catalog_sky(sc2)
        tab_cmatch = Table(cmatch, names=['idx', 'sep2d', 'dist3d'])
        tab_cmatch['sep2d'] = tab_cmatch['sep2d'].to(u.arcsec)

        res[k2] = np.where(tab_cmatch['sep2d'] < max_sep, tab_cmatch['idx'], -1)

    df = pd.DataFrame(res)
    print('=' * 50)
    print(f'{k1} Results:')
    for col in df.columns[1:]:
        count = (df[col] > -1).sum()
        print(f'{col:<12} : {count:<5} / {len(df)} ({count/len(df):.2f})')
    print('=' * 50)