Investigate how well the final report unit works.
- Retrieve T2ElasticcReport output (possibly together with redshift and ndet).
- Select one of the elasticc submission channels.
- Find out which simulation classes this corresponds to.
- For the channels corresponding to these classes:
- Study the fraction of all the different models. Possibly also do this as a function of
number of detections and/or redshift.

In [None]:
import pymongo
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from extcats import CatalogQuery
from scipy.stats import binned_statistic

In [None]:
# Connect to MongoDB using the client's default connection settings.
client = pymongo.MongoClient()

In [None]:
# Database holding the ELAsTiCC report-validation run.
db = client['ElasticcReportValidation']

In [None]:
# Collection with the T2 unit results.
col = db['t2']

In [None]:
# Get the same z information as XGB.
def get_zinfo(t2_res):
    '''
    Extract the same redshift information as what is used by elasticc units.

    Parameters
    ----------
    t2_res : dict
        One body entry of a T2ElasticcRedshiftSampler result. Must contain
        'z_source' and 'z_samples'; 'host_sep' is also read when there are
        three redshift samples.

    Returns
    -------
    dict
        Keys 'z', 'z_err', 'host_sep'. With three samples, z is the middle
        sample and z_err its distance to the first sample. With four samples
        starting at 0.01 (treated as hostless) all three values are None.

    Raises
    ------
    ValueError
        If 'z_source' is not a known host-galaxy/default source, or if
        'z_samples' has any other layout. The message includes the record.
    '''
    # For some reason we trained xgb using z, zerr and host_sep
    known_sources = ('HOSTGAL2_ZQUANT', 'HOSTGAL_ZQUANT', 'HOSTGAL_ZSPEC', 'default')
    if t2_res['z_source'] not in known_sources:
        raise ValueError(f"Unexpected z_source {t2_res['z_source']!r} in {t2_res}")

    z_samples = t2_res['z_samples']

    # This was the sampling used for job_training
    if len(z_samples) == 3:
        return {
            'z': z_samples[1],
            'z_err': z_samples[1] - z_samples[0],
            'host_sep': t2_res['host_sep'],
        }

    if len(z_samples) == 4 and z_samples[0] == 0.01:
        print('hostless')
        return {'z': None, 'z_err': None, 'host_sep': None}

    raise ValueError(f"Unexpected z_samples layout {z_samples!r} in {t2_res}")


In [None]:
# Accumulator for one summary dict per T2 document.
reports = list()

In [None]:
# Build one record (stock, channel, redshift info) per redshift-sampler document.
for t2info in col.find({'unit':'T2ElasticcRedshiftSampler'}):
    # Skip failed runs first: a document without a non-empty 'body' has no
    # result to read (an empty list would otherwise make body[-1] raise
    # IndexError).
    if not t2info.get('body'):
        print('Failed run - enough to care about?')
        print(t2info)
        continue

    b = {
        'stock': t2info['stock'],
        # NOTE(review): 'channel' is indexed like a list and only its first
        # entry is kept -- confirm documents carry a single channel.
        'channel': t2info['channel'][0],
    }
    # Redshift summary taken from the last entry of the result body.
    b.update(get_zinfo(t2info['body'][-1]))

    reports.append(b)

In [None]:
# One row per collected record; columns come from the record keys.
df_rep = pd.DataFrame(reports)

In [None]:
# (rows, columns) of the report table.
(len(df_rep.index), len(df_rep.columns))

In [None]:
# Distinct channels, in order of first appearance.
pd.unique(df_rep['channel'])

In [None]:
# Map channel names to the model class used in the rest of the notebook.
# Several channels share one class (e.g. the uLens, SNII and SNIbc variants).
modelmap = {
    'agn': 'AGN',
    'cart': 'CART',
    'ulenssinglepylima': 'uLens',
    'slsnihost': 'SLSN',
    'dwarfnova': 'dwarf-nova',
    'ulensbinary': 'uLens',
    'slsninohost': 'SLSN',
    'eb': 'EB',
    'snia91bg': 'SNIa91bg',
    'ulenssinglegenlens': 'uLens',
    'sniasalt2': 'SNIa',
    'ilot': 'ILOT',
    'sniax': 'SNIax',
    'tde': 'TDE',
    'knb19': 'KN',
    'sniibhostxtv19': 'SNII',
    'knk17': 'KN',
    'snibhostxtv19': 'SNIbc',
    'mdwarfflare': 'Mdwarf-flare',
    'pisn': 'PISN',
    'sniitemplates': 'SNII',
    'rrl': 'RRL',
    'snibtemplates': 'SNIbc',
    'snicblhostxtv19': 'SNIbc',
    'snichostxtv19': 'SNIbc',
    'snictemplates': 'SNIbc',
    'sniihostxtv19': 'SNII',
    'sniinhostxtv19': 'SNII',
    'sniinmf': 'SNII',
    'sniinmosfit': 'SNII',
    'dsct': 'DSC',
    'cepheid': 'Cepheid',
}

In [None]:
# Translate each channel into its model class. Channels missing from
# ``modelmap`` become NaN, and NaN never compares equal to anything, so those
# rows would silently match no model in the per-model selections below;
# report them explicitly instead.
df_rep['model'] = df_rep['channel'].map(modelmap)
unmapped_channels = list(df_rep.loc[df_rep['model'].isna(), 'channel'].unique())
if unmapped_channels:
    print('Channels without a model mapping:', unmapped_channels)

In [None]:
# Display the per-object table (stock, channel, z, z_err, host_sep, model).
df_rep

First question: for each model, how many objects are there, and what fraction of them have redshifts?

In [None]:
# Per model: number of objects and how many of them have a positive redshift.
for model in df_rep['model'].unique():
    z = df_rep.loc[df_rep['model'] == model, 'z']
    is_z = z.gt(0)
    print(model, len(z), is_z.sum())

In [None]:
# One redshift histogram per model (20 automatic bins); models without any
# positive redshift are reported and skipped.
for model in df_rep['model'].unique():
    z = df_rep.loc[df_rep['model'] == model, 'z']
    is_z = z.gt(0)
    if not is_z.any():
        print('... no redshifts, no need to plot.')
        continue

    plt.figure()
    plt.title(model)
    plt.hist(z[is_z], bins=20)
    plt.xlabel('z')
    plt.show()

In [None]:
# Common redshift grid: 36 values starting at -0.0901 with spacing 0.1.
bins = np.arange(start=-0.0901, stop=3.5, step=0.1)

In [None]:
# Inspect the redshift grid used for all histograms and KDE evaluations.
bins

In [None]:
# Same per-model histograms, now on the shared ``bins`` grid.
for model in df_rep['model'].unique():
    z = df_rep.loc[df_rep['model'] == model, 'z']
    is_z = z.gt(0)
    if not is_z.any():
        print('... no redshifts, no need to plot.')
        continue

    plt.figure()
    plt.title(model)
    plt.hist(z[is_z], bins=bins)
    plt.xlabel('z')
    plt.show()

In [None]:
# Per-model redshift histograms on ``bins``, normalised to unit area.
for model in df_rep['model'].unique():
    z = df_rep.loc[df_rep['model'] == model, 'z']
    is_z = z.gt(0)
    if not is_z.any():
        print('... no redshifts, no need to plot.')
        continue

    plt.figure()
    plt.title(model)
    plt.hist(z[is_z], bins=bins, density=True, stacked=True)
    plt.xlabel('z')
    plt.show()

In [None]:
plt.figure()

# Overlay the redshift histograms of all models on the shared ``bins`` grid.
# NOTE: unlike the per-model cells, models without redshifts are NOT skipped
# (the skip was disabled), so they still get a -- presumably empty -- legend
# entry. The full ``z`` column, not just z > 0, is passed to hist.
for model in df_rep['model'].unique():
    z = df_rep.loc[df_rep['model'] == model, 'z']
    is_z = z.gt(0)
    if not is_z.any():
        print('... no redshifts, no need to plot.')

    plt.hist(z, bins=bins, density=False, stacked=False, label=model)

plt.legend(loc='best')
plt.xlabel('z')
plt.show()

Next stage: try to evaluate one of the models.

In [None]:
from sklearn.neighbors import KernelDensity

In [None]:
# Redshifts of the AGN model and a mask of the positive ones.
z = df_rep.loc[df_rep['model'] == 'AGN', 'z']
is_z = z.gt(0)

In [None]:
# Gaussian kernel density estimator with a fixed bandwidth of 0.15 in z,
# refitted per model below.
kde = KernelDensity(kernel='gaussian', bandwidth=0.15)

In [None]:
# Per-model redshift densities for the non-recurring models: fit ``kde`` to
# the positive redshifts and evaluate it at every value of ``bins``.
# NOTE: the density is sampled AT the grid values (index i <-> z = bins[i]),
# not at bin centres.
nonrecurring_models = ['CART', 'uLens', 'SLSN', 'dwarf-nova', 'SNIa91bg', 'SNIa', 'ILOT',
                       'SNIax', 'TDE', 'KN', 'SNII', 'SNIbc', 'Mdwarf-flare', 'PISN']
modprobs = {}
for model in df_rep['model'].unique():
    z = df_rep['z'][df_rep['model'] == model]
    is_z = (z > 0)

    # Only attempt for non-recurring?
    if model not in nonrecurring_models:
        continue

    if not is_z.any():
        print('... no redshifts, no need to plot.')
        # No redshift information: put all the mass on the first grid point,
        # scaled by 1/spacing so it integrates to ~1 over one grid step.
        modprobs[model] = np.zeros(len(bins))
        modprobs[model][0] = 1. / (bins[1] - bins[0])
        continue

    # KernelDensity expects a 2-D (n_samples, n_features) array. Indexing a
    # pandas Series with [:, None] is not supported by recent pandas, so
    # convert to a NumPy array explicitly before adding the feature axis.
    kde.fit(z[is_z].to_numpy()[:, None])
    logprob = kde.score_samples(bins[:, None])
    density = np.exp(logprob)
    modprobs[model] = density

    plt.figure()
    plt.title(model)
    plt.fill_between(bins, density, alpha=0.5)
    plt.hist(z[is_z], bins=bins, density=True, stacked=False)
    plt.xlabel('z')
    plt.show()

In [None]:
# Inspect the per-model densities evaluated on the ``bins`` grid.
modprobs

In [None]:
# Table of densities: one column per model, one row per grid point.
df_mod = pd.DataFrame(modprobs)

In [None]:
# (grid points, models)
(len(df_mod.index), len(df_mod.columns))

In [None]:
# Inspect one grid row of the un-normalised densities. ``zbin`` is only
# assigned in a later cell, so fall back to the first row instead of raising
# NameError when the notebook is run top to bottom.
df_mod.iloc[globals().get('zbin', 0)]

In [None]:
# Total density summed over all models at each grid point.
tprob = df_mod.sum(axis='columns').to_numpy(copy=True)

In [None]:
# Total density (summed over models) at each grid point.
tprob

In [None]:
# Divide every grid point (a column after transposing) by its total over models.
df_foo = df_mod.T / tprob

In [None]:
# Back to (grid points x models); each row now sums to 1 across models.
df_mod = df_foo.T

In [None]:
# Ok, lets take a best fit redshift
z = 1.5
# Which bin? Derive the origin and spacing from ``bins`` itself (the KDE grid)
# instead of a hand-copied offset, and clamp to the grid so a redshift beyond
# the last grid point does not raise IndexError in the .iloc lookup.
zbin = int(np.clip(np.floor((z - bins[0]) / (bins[1] - bins[0])), 0, len(bins) - 1))
print(zbin)

In [None]:
# Normalised model fractions at the chosen grid point.
df_mod.iloc[zbin, :]

In [None]:
# Normalised table: each row (grid point) sums to 1 across models.
df_mod

In [None]:
# Scale fractions by 1000 (per mille, despite the variable name).
df_procent = df_mod.mul(1000)

In [None]:
# Integer per-mille values; the int cast truncates toward zero.
df_zmap = df_procent.astype(dtype=int)

In [None]:
# Dump as {model: {grid index: per mille}} for copying into code.
print(df_zmap.to_dict(orient='dict'))

In [None]:
# Hard-coded table, presumably pasted from the printed ``df_zmap.to_dict()``:
# for each model, the truncated per-mille share of the row-normalised KDE
# density at each grid index i (i.e. at z = bins[i], i = 0..35).
# uLens, dwarf-nova and Mdwarf-flare put all their weight (304) at index 0,
# matching the "no redshifts" branch that places all mass on the first point.
zmap = {'CART': {0: 9, 1: 124, 2: 135, 3: 143, 4: 143, 5: 127, 6: 97, 7: 62, 8: 33, 9: 16, 10: 8, 11: 6, 12: 6, 13: 5, 14: 5, 15: 6, 16: 7, 17: 6, 18: 3, 19: 1, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'uLens': {0: 304, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SLSN': {0: 0, 1: 1, 2: 2, 3: 3, 4: 5, 5: 11, 6: 20, 7: 35, 8: 55, 9: 80, 10: 112, 11: 153, 12: 215, 13: 307, 14: 421, 15: 536, 16: 645, 17: 746, 18: 827, 19: 879, 20: 911, 21: 936, 22: 958, 23: 976, 24: 988, 25: 995, 26: 998, 27: 999, 28: 999, 29: 999, 30: 999, 31: 999, 32: 999, 33: 999, 34: 999, 35: 999}, 'dwarf-nova': {0: 304, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SNIa91bg': {0: 5, 1: 78, 2: 100, 3: 127, 4: 151, 5: 156, 6: 132, 7: 88, 8: 45, 9: 17, 10: 6, 11: 3, 12: 2, 13: 2, 14: 1, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SNIa': {0: 1, 1: 19, 2: 26, 3: 39, 4: 60, 5: 92, 6: 134, 7: 181, 8: 217, 9: 225, 10: 200, 11: 150, 12: 93, 13: 48, 14: 23, 15: 11, 16: 5, 17: 2, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'ILOT': {0: 24, 1: 257, 2: 223, 3: 173, 4: 114, 5: 59, 6: 23, 7: 6, 8: 1, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SNIax': {0: 4, 1: 64, 2: 77, 3: 95, 4: 116, 5: 136, 6: 145, 7: 137, 8: 111, 9: 77, 10: 46, 11: 
23, 12: 9, 13: 3, 14: 1, 15: 1, 16: 1, 17: 2, 18: 1, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'TDE': {0: 3, 1: 46, 2: 56, 3: 69, 4: 84, 5: 100, 6: 113, 7: 123, 8: 133, 9: 141, 10: 145, 11: 142, 12: 130, 13: 108, 14: 83, 15: 63, 16: 50, 17: 43, 18: 38, 19: 34, 20: 30, 21: 25, 22: 20, 23: 14, 24: 9, 25: 4, 26: 1, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'KN': {0: 29, 1: 293, 2: 234, 3: 166, 4: 98, 5: 45, 6: 15, 7: 3, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SNII': {0: 3, 1: 47, 2: 59, 3: 75, 4: 94, 5: 112, 6: 126, 7: 131, 8: 127, 9: 117, 10: 105, 11: 95, 12: 86, 13: 77, 14: 70, 15: 66, 16: 63, 17: 59, 18: 48, 19: 33, 20: 21, 21: 11, 22: 5, 23: 2, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'SNIbc': {0: 4, 1: 60, 2: 73, 3: 89, 4: 105, 5: 119, 6: 127, 7: 128, 8: 121, 9: 107, 10: 90, 11: 72, 12: 53, 13: 36, 14: 25, 15: 18, 16: 13, 17: 8, 18: 3, 19: 1, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'Mdwarf-flare': {0: 304, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}, 'PISN': {0: 0, 1: 7, 2: 10, 3: 16, 4: 24, 5: 38, 6: 62, 7: 100, 8: 152, 9: 215, 10: 284, 11: 352, 12: 401, 13: 409, 14: 368, 15: 296, 16: 211, 17: 131, 18: 76, 19: 49, 20: 36, 21: 26, 22: 15, 23: 6, 24: 1, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0}}

Side study - out of all runs without an answer, what are they?

In [None]:
# Rebuild a table from the hard-coded map (grid index rows x model columns).
df_test = pd.DataFrame(zmap)

In [None]:
# Per-mille table rebuilt from ``zmap`` (rows: grid index, columns: models).
df_test

In [None]:
# Ok, lets take a best fit redshift
z = 0.2
# Which bin? Derive the origin and spacing from ``bins`` itself (the grid the
# zmap indices refer to), and clamp to the grid so a redshift beyond the last
# grid point does not raise IndexError/KeyError in the lookups below.
zbin = int(np.clip(np.floor((z - bins[0]) / (bins[1] - bins[0])), 0, len(bins) - 1))
print(zbin)

In [None]:
# Model fractions (0-1) at the chosen grid point.
df_test.iloc[zbin].div(1000)

In [None]:
# Same fractions, straight from the plain-dict ``zmap``.
for mod, per_mille in zmap.items():
    print(mod, per_mille[zbin] / 1000)

We now investigate a rate prior, based on BTS:
https://arxiv.org/pdf/2009.01242.pdf

We use the bright-end (<18.5 mag) relative rates. We are not sure whether this makes sense, but shouldn't this brightness-limited sample have the same composition as what we expect at any magnitude cut?

In [None]:
# <18.5 cut: BTS counts per model class
bts = {
    'SNIa': 875,
    'uLens': 4,
    'SLSN': 19,
    'dwarf-nova': 4,
    'SNIa91bg': 0,
    'ILOT': 4,
    'SNIax': 0,
    'TDE': 5,
    'KN': 2,
    'SNII': 218,
    'SNIbc': 76,
    'Mdwarf-flare': 4,
    'PISN': 2,
    'CART': 4,
}

In [None]:
import os

# Grabbed a late list of all SNIa (incl. subtypes), < 18.5 from bts.
# The location can be overridden with the BTS_CSV environment variable so the
# notebook is not tied to one user's home directory; the default is unchanged.
df_bts = pd.read_csv(os.environ.get('BTS_CSV', '/home/jnordin/tmp/btstmp.csv'))

In [None]:
# Number of entries (non-null ZTFID) per BTS classification type.
counts = df_bts.groupby('type')['ZTFID'].count()

In [None]:
# Allow peculiars to be either, and add 91T to normal, and ignore SC.
# Subtypes absent from this BTS export count as zero (reported below) rather
# than raising KeyError.
ia_types = ('SN Ia', 'SN Ia-91T', 'SN Ia-91bg', 'SN Ia-pec', 'SN Iax')
missing_types = [t for t in ia_types if t not in counts.index]
if missing_types:
    print('Types not present in BTS list:', missing_types)
cnorm = counts.get('SN Ia', 0) + counts.get('SN Ia-91T', 0)
cbg = counts.get('SN Ia-91bg', 0) + counts.get('SN Ia-pec', 0)
cx = counts.get('SN Iax', 0) + counts.get('SN Ia-pec', 0)
counts

In [None]:
# Subtype counts relative to normal (incl. 91T) SNe Ia.
frac_91bg, frac_x = cbg / cnorm, cx / cnorm

In [None]:
# Show both relative fractions on one line.
print(frac_91bg, frac_x, sep=' ')

We can thus _estimate_ that the _observed_ rates of both subtypes are around 1% (relative to normal SNe Ia).

In [None]:
# Move 1% of the SNIa count into each of 91bg and Iax, keeping 98% as normal
# SNIa (petty, but why not?). All three use the SNIa value from before this
# cell, since the mapping is built before update() runs.
# NOTE: re-running this cell compounds the split.
bts.update({
    'SNIa91bg': bts['SNIa'] * 0.01,
    'SNIax': bts['SNIa'] * 0.01,
    'SNIa': bts['SNIa'] * 0.98,
})

In [None]:
# Normalize: total count over all classes
n = sum(bts.values())
print(n)

In [None]:
# Relative rate of each class (sums to 1).
nbts = {k: v / n for k, v in bts.items()}

In [None]:
# To not make things impossible, we add a lower floor at 1%
fbts = {k: (v if v > 0.01 else 0.01) for k, v in nbts.items()}

In [None]:
# Rates after applying the 1% floor (no longer summing to exactly 1).
fbts

In [None]:
# Normalize again (last time...): rescale the floored rates to sum to 1.
n = sum(fbts.values())
print(n)
nbts = {k: v / n for k, v in fbts.items()}

In [None]:
# Create scaled int version (per mille). Renormalisation pushes the floored
# 1% entries below 0.01 and the int() cast truncates them (to 9, or lower when
# many classes are floored); lift every entry to at least 10 so the floor
# survives. The previous "+1" only worked when truncation landed exactly on 9.
ibts = {k: max(int(v * 1000), 10) for k, v in nbts.items()}

In [None]:
# Check the total of the integer rates (close to, but not exactly, 1000).
sum(ibts.values())

In [None]:
# Dump the integer prior for copying into code.
print(str(ibts))