# Emulating realistically bad-for-cosmology SN Ia samples from PLAsTiCC data

_Alex I. Malz (GCCL@RUB)_

In [None]:
import collections
import gzip
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle as pkl

# fixed seed; used as the default `random_state` when drawing subsamples below
rando = 42

In [None]:
import proclam
from proclam.metrics.util import *
from proclam.metrics.util import RateMatrix

classes we care about

| `true_target`=`type` | `code` |
| -------------------- | ------ |
| 90 | SNIa |
| 67 | SNIa-91bg |
| 52 | SNIax |
| 42 | SNII |
| 62 | SNIbc |
| 95 | SLSN-I |
| 88 | AGN |

In [None]:
# PLAsTiCC `true_target` code -> human-readable class name.
# Insertion order matters downstream (it fixes iteration order over classes),
# so TDE and KN stay at the end.
maybe_sn_classes = {
    90: 'SNIa',
    67: 'SNIa-91bg',
    52: 'SNIax',
    42: 'SNII',
    62: 'SNIbc',
    95: 'SLSN-I',
    88: 'AGN',
    15: 'TDE',
    64: 'KN',
}

# the class treated as "positive" throughout
sel_class = 90

## gather all available lightcurves

In [None]:
# base directory that is prepended to the CSV paths read below
datapath = '/media/RESSPECT/data/PLAsTiCC/'

other than intermediate data products, work in `/media/RESSPECT/data/PLAsTiCC/for_metrics/`

In [None]:
# read the test-set metadata and switch to the short column names used below
all_lcs = pd.read_csv(
    datapath+'PLAsTiCC_zenodo/plasticc_test_metadata.csv'
).rename(
    columns={"object_id": "id", "true_z": "redshift", "true_target": "code"}
)

In [None]:
# per-field summary of the true catalog: the catalog itself, its size,
# and the number of objects of each class code
field_info0 = {}
for field in ['ddf', 'wfd']:
    in_field = all_lcs['ddf_bool'] == (field == 'ddf')
    summary = {'true_cat': all_lcs.loc[in_field][['id', 'redshift', 'code']]}
    summary['n_tot_cat'] = len(summary['true_cat'])
    summary['n_each_cat'] = dict(summary['true_cat'].groupby('code').count()['id'])
    field_info0[field] = summary
    print(summary['n_each_cat'])

In [None]:
# n_class_pos = 3000

def gen_nums(field_info, field, n_class_pos=3000, sel_class=90, classmap=maybe_sn_classes):
    """Rescale the per-class counts of one field to a target number of `sel_class` objects.

    Reads `field_info[field]['n_each_cat']` and writes the keys 'pos', 'glob',
    'neg', 'class_all' and 'class_glob' into `field_info[field]` IN PLACE.

    Parameters
    ----------
    field_info : dict
        Mapping field name -> dict holding at least 'n_each_cat' (class code -> count).
    field : str
        Key of the field to process.
    n_class_pos : int
        Desired number of `sel_class` objects after rescaling.
    sel_class : int
        Class code treated as positive.
    classmap : dict
        Its keys are the class codes that are counted.

    Returns
    -------
    dict
        Shallow copy of `field_info[field]` after the update.
    """
    entry = field_info[field]
    counts = entry['n_each_cat']

    entry['pos'] = counts[sel_class]
    entry['glob'] = sum(counts[classid] for classid in classmap)
    entry['neg'] = entry['glob'] - entry['pos']

    # total over all classes such that the positive class lands on n_class_pos
    n_class_glob = n_class_pos * entry['glob'] / entry['pos']
    scaled = {}
    for classid in classmap:
        scaled[classid] = int(round(n_class_glob * counts[classid] / entry['glob']))
    entry['class_all'] = scaled
    entry['class_glob'] = sum(scaled.values())
    print(entry['class_all'])

    return entry.copy()

In [None]:
# one rescaled bookkeeping dict per target number of positives (1e3, 3e3, 6e3)
field_defs = [int(1e3), int(3e3), int(6e3)]
field_info1, field_info3, field_info6 = {}, {}, {}
for field in ['ddf', 'wfd']:
    for info, n_pos in zip((field_info1, field_info3, field_info6), field_defs):
        info[field] = gen_nums(field_info0, field, n_class_pos=n_pos)

TODO: sample from those that survived but use ratios from the overall population. SALT2 failures may be class dependent/independent and cadence dependent/independent, and we can't disentangle these (because this freebie classification criterion will not always be true with SALT3 etc., the detection ratios will depend on cadence). From our perspective contamination can't be worse; this is the worst case for a single contaminant.

In [None]:
# objects that survived the SALT2 fit, one table per field
lc_fit_filter_ddf, lc_fit_filter_wfd = (
    pd.read_csv(datapath+relpath)
    for relpath in ('for_metrics/ddf/samples/all_objs_survived_SALT2_DDF.csv',
                    'for_metrics/wfd/samples/all_objs_survived_SALT2_WFD.csv')
)

In [None]:
# attach true redshift and class code to the surviving ids of each field
all_maybe_sn_ddf, all_maybe_sn_wfd = (
    pd.merge(survivors['id'], all_lcs[['id', 'redshift', 'code']], on=['id'])
    for survivors in (lc_fit_filter_ddf, lc_fit_filter_wfd)
)

TODO: plot "confusion matrix" based on SALT2 fit success

In [None]:
# record, for every bookkeeping dict and field, how many surviving objects
# of each class of interest are available (0 when a class has none)
all_maybe_sn = {'ddf': all_maybe_sn_ddf, 'wfd': all_maybe_sn_wfd}
for info in [field_info1, field_info3, field_info6]:
    for field in ['ddf', 'wfd']:
        count_surv = dict(all_maybe_sn[field].groupby('code').count()['id'])
        available = {}
        for classid in maybe_sn_classes:
            available[classid] = count_surv.get(classid, 0)
        info[field]['class_avail'] = available
        print(info[field]['class_avail'])

~~a priori all samples will be 3000 "classified SN Ia"~~ TODO: try with `n_class_pos`=3e4, 3e5 for the three dummy cases

## subsample the classes to make new samples

TODO: maybe investigate redshift distribution of sample classified as Ia?

~~TODO: make a table of all test cases~~

In [None]:
# output root and the per-field directory where sample CSVs are written
savepath = '/media/RESSPECT/data/PLAsTiCC/for_metrics/'
savepaths = {}
for field in ['ddf', 'wfd']:
    savepaths[field] = ''.join([savepath, field, '/samples/'])

### get sample ids matching a confusion matrix

To calculate the true/false positive/negative rates along the way to making the subsamples, we need a notion of negatives that would never end up in the cosmology sample.
Let's use the DDF type ratios to figure out how many objects will be classified as negative for our samples of 3000 positive classifications.

`n_class_all` contains the number of objects in the true population, and the confusion matrix tells us how many will end up being classified as positive or negative

save outputs as `id,redshift,type,code,orig_sample=test,queryable=True`

In [None]:
def subsample_cat(cm, cm_indices, ntot, cat, surv=None, rando=rando,
                  pos_key=90, where_to_save=None, save_neg=True, force=False):
    """Draw a "classified positive" and a "classified negative" sample from a catalog.

    The confusion matrix is normalised so each row sums to one; the row of
    `pos_key`, multiplied by `ntot[pos_key]` and rounded, gives the number of
    objects of every class that should end up in the positive sample.

    Parameters
    ----------
    cm : numpy.ndarray
        Square confusion matrix, indexed through `cm_indices`.
    cm_indices : dict
        Class code -> row/column index into `cm`.
    ntot : dict
        Class code -> number of objects of that class in the population.
    cat : pandas.DataFrame
        Full catalog; must have 'id' and 'code' columns. Negatives are drawn
        from it WITH replacement.
    surv : pandas.DataFrame, optional
        Objects eligible for the positive sample (drawn WITHOUT replacement).
        Defaults to a copy of `cat`.
    rando : int, optional
        Seed passed as `random_state` to every draw.
    pos_key : int
        Class code of the positive class.
    where_to_save : str, optional
        If truthy, the positive sample is written to `where_to_save + '.csv'`
        with columns id, redshift, type, code, orig_sample, queryable.
    save_neg, force : bool
        Currently unused; kept so existing calls stay valid.

    Returns
    -------
    (pos_ids, neg_ids) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        If the positive sample cannot be topped up because too few unused
        `pos_key` objects survive.
    """
    if surv is None:
        surv = cat.copy()

    # normalize each row of the confusion matrix to unit sum
    pcm = (cm.T / np.sum(cm, axis=1)).T
    # requested number of objects of each class in the positive sample
    pos_row = pcm[cm_indices[pos_key]] * ntot[pos_key]
    pos_row = [int(round(i)) for i in pos_row]
    n_target = int(round(np.sum(pos_row)))

    # collect per-class draws and concatenate once
    # (DataFrame.append no longer exists in pandas >= 2.0)
    pos_parts, neg_parts = [], []
    for typeid in cm_indices.keys():
        surv_of_type = surv[surv['code'] == typeid]
        n_crit = len(surv_of_type)

        n_pos = pos_row[cm_indices[typeid]]
        if n_pos > ntot[typeid]:
            print(f'cannot draw {n_pos} {typeid} from existing {ntot[typeid]} in {pos_row}')
            n_pos = ntot[typeid]
        if n_pos > n_crit:
            print(f'cannot draw {n_pos} {typeid} from surviving {n_crit} in {pos_row}')
            n_pos = n_crit

        n_neg = ntot[typeid] - n_pos
        if n_neg > ntot[typeid]:
            print(f'cannot draw {n_neg} {typeid} from negative {ntot[typeid]}')
            n_neg = ntot[typeid]

        print((n_pos, n_neg))

        pos = surv_of_type.sample(n=n_pos, random_state=rando)
        neg = cat[cat['code'] == typeid].sample(n=n_neg, random_state=rando, replace=True)
        if len(pos) > 0:
            pos_parts.append(pos)
        if len(neg) > 0:
            neg_parts.append(neg)

    pos_ids = pd.concat(pos_parts) if pos_parts else pd.DataFrame(columns=cat.columns)
    neg_ids = pd.concat(neg_parts) if neg_parts else pd.DataFrame(columns=cat.columns)

    # clamping above (or rounding) can leave the positive sample off target;
    # correct the difference using objects of the positive class only
    n_err = n_target - len(pos_ids)
    print('err=' + str(n_err))
    if n_err > 0:
        # exclude EVERY object already selected, not just the last class drawn,
        # otherwise the same seed re-draws objects that are already in the sample
        unused = surv[(surv['code'] == pos_key) & (~surv['id'].isin(pos_ids['id']))]
        if n_err > len(unused):
            raise ValueError(
                f'need {n_err} more {pos_key} to reach {n_target} positives '
                f'but only {len(unused)} unused ones survive')
        bonus = unused.sample(n=n_err, random_state=rando)
        pos_parts.append(bonus)
        pos_ids = pd.concat(pos_parts)
    if n_err < 0:
        # drop the surplus (-n_err rows), seeded like every other draw
        surplus = pos_ids[pos_ids['code'] == pos_key].sample(n=-n_err, random_state=rando)
        pos_ids = pos_ids.drop(surplus.index)
    assert(len(pos_ids) == n_target)

    if where_to_save:
        pos_ids['orig_sample'] = 'test'
        pos_ids['queryable'] = True
        pos_ids['type'] = None
        pos_ids[['id','redshift','type','code','orig_sample','queryable']].to_csv(where_to_save+'.csv', index=False)
    return pos_ids, neg_ids

### realistic classifier

start from fiducial contamination rates from ~~a real (awful) confusion matrix at `/media/RESSPECT/data/PLAsTiCC/for_metrics/confusion_matrices`~~ Avocado

~~These were just the test set lightcurves for classes (67, 88, 42(minus 7?), 90(minus 11?), 52, 62, 64, 95, 15) from ddf-only~~

~~figure out classes in confusion matrix by comparing number of ddf test set-only lightcurves~~

In [None]:
# load the fiducial confusion matrix from a text file (path relative to the
# working directory), discarding its last 6 columns
fid_cm = np.loadtxt('confusion_matrix_no_galactic_kb.txt')[:, :-6]
# with open(savepath+'confusion_matrices/confusion_matrix.npy', 'rb') as confmat:
#     fid_cm = np.load(confmat)
# quick-look image of the matrix entries
plt.imshow(fid_cm)
plt.colorbar()

In [None]:
# cm_classes = [67, 88, 42, 90, 52, 62, 64, 95, 15]
# class codes in the order of the confusion-matrix rows/columns
cm_classes = [90, 67, 52, 42, 62, 95, 15, 64, 88]
# class code -> position in the confusion matrix, keyed in maybe_sn_classes order
cm_indices = dict.fromkeys(maybe_sn_classes)
for classid in cm_indices:
    cm_indices[classid] = cm_classes.index(classid)

fiducial sample corresponding to input confusion matrix

In [None]:
# For each target size and each field, draw the sample implied by the fiducial
# confusion matrix and write it to <savepaths[field]>fiducial<n_class_pos>.csv
for i, field_info in enumerate([field_info1, field_info3, field_info6]):
    for field in ['ddf', 'wfd']:
#         print(field_info[field]['class_all'])
        fiducial = subsample_cat(fid_cm, cm_indices, field_info[field]['class_all'], field_info[field]['true_cat'], surv=all_maybe_sn[field],
                             where_to_save=savepaths[field]+'fiducial'+str(field_defs[i]))
        # subsample_cat returns a (pos_ids, neg_ids) tuple; report the positive-sample size
        if fiducial is not None:
            print(len(fiducial[0]))
#     print(len(fiducial[0][fiducial[0]['code'] == sel_class])+len(fiducial[1][fiducial[1]['code'] == sel_class]))
#     print((fiducial[0]['code'].value_counts(), fiducial[1]['code'].value_counts()))

### 100% SNIa sample

In [None]:
# identity confusion matrix: every class is sent only to itself
perf_cm = np.identity(len(cm_indices.keys()))
# one sample per target size and field, written to <savepaths[field]>perfect<n_class_pos>.csv
for i, field_info in enumerate([field_info1, field_info3, field_info6]):
    for field in ['ddf', 'wfd']:
        perfect = subsample_cat(perf_cm, cm_indices, field_info[field]['class_all'], field_info[field]['true_cat'], surv=all_maybe_sn[field],
                            where_to_save=savepaths[field]+'perfect'+str(field_defs[i]))#, rando=999)

### random/guessing/uncertain classifier

In [None]:
# Build a "guessing" confusion matrix per target size and field, then draw the
# matching sample and write it to <savepaths[field]>random<n_class_pos>.csv
for i, field_info in enumerate([field_info1, field_info3, field_info6]):
    for field in ['ddf', 'wfd']:
        # start from a flat matrix
        rand_cm = np.ones((len(cm_indices.keys()), len(cm_indices.keys()))) / len(cm_indices.keys())**2
        # scale the columns by the catalog counts, taken in confusion-matrix index order
        rand_cm *= np.array([field_info[field]['n_each_cat'][key] for (key, val) in 
                         sorted(cm_indices.items(), key=lambda x: x[1])])
        # transpose and renormalise so that all entries sum to one
        rand_cm = rand_cm.T / np.sum(rand_cm)
        # scale the columns of the transposed matrix by the rescaled class counts
        rand_cm *= np.array([field_info[field]['class_all'][key] for (key, val) in 
                         sorted(cm_indices.items(), key=lambda x: x[1])])
        # transpose back and renormalise; entry (i, j) is then proportional to
        # class_all[i] * n_each_cat[j]
        rand_cm = rand_cm.T / np.sum(rand_cm)
        guesser = subsample_cat(rand_cm, cm_indices, field_info[field]['class_all'], field_info[field]['true_cat'], surv=all_maybe_sn[field],
                            where_to_save=savepaths[field]+'random'+str(field_defs[i]))

## evaluate classification metrics on the subsamples

better to do it along the way to making the subsamples, especially important for non-extreme subsamples filling the space of classification metric values

first get rates using `proclam` functionality

In [None]:
def cat_to_rate(pos_ids, neg_ids, pos_key=sel_class):
    """Turn a positive/negative sample pair into binary classification rates.

    Parameters
    ----------
    pos_ids : pandas.DataFrame
        Objects classified as positive; must have a 'code' column.
    neg_ids : pandas.DataFrame
        Objects classified as negative; must have a 'code' column.
    pos_key : int
        Class code whose members count as truly positive.

    Returns
    -------
    proclam.util.RateMatrix
        Built from the first element of every field returned by `cm_to_rate`.

    Notes
    -----
    The inputs are no longer modified: the 'classed' flag is added to copies
    rather than written into the caller's DataFrames.
    """
    whole_samp = pd.concat((pos_ids.assign(classed=True),
                            neg_ids.assign(classed=False)))
    # Build the truth flag in a single assignment. The previous chained form
    # (`whole_samp['truth'][mask] = ...`) is not guaranteed to write back and
    # does nothing under pandas Copy-on-Write. Object dtype is kept so that
    # `det_to_cm` receives the same kind of array as before.
    whole_samp['truth'] = (whole_samp['code'] == pos_key).astype(object)
    bin_cm = det_to_cm(whole_samp['classed'].to_numpy(), whole_samp['truth'].to_numpy())
    rawrate = cm_to_rate(bin_cm)._asdict()
    # keep only the first entry of each rate
    # NOTE(review): presumably index 0 is the row for the positive class — confirm against proclam
    rel_to_sel = {key: rawrate[key][0] for key in rawrate.keys()}
    rate = proclam.util.RateMatrix(**rel_to_sel)
    return rate

### calculate all the metrics!

TODO: put some version of this into `proclam` at some point!

In [None]:
class det_mets(RateMatrix):
    """Binary classification metrics derived from a `RateMatrix`.

    Every metric is stored as an instance attribute at construction time from
    the TP/FP/FN/TN counts and the TPR/TNR rates read from `self`.
    """
    def __init__(self, **rates):
        """
        Call like `thing = det_mets(**rates._asdict())`

        NOTE(review): `rates` is not used here; the fields are presumably set
        by the base class's `__new__` — confirm against proclam.
        """
#         self.rates = rates#.asdict()
        self._get_tots()
        self._from_rates()
        self._sn_mets()
        self._translate()

    def _get_tots(self):
        """Marginal totals of the 2x2 table."""
        self.CP = self.TP + self.FN  # condition positive
        self.CN = self.TN + self.FP  # condition negative
        self.T = self.TP + self.TN   # correctly classified
        self.F = self.FP + self.FN   # incorrectly classified
        self.P = self.TP + self.FP   # predicted positive
        self.N = self.TN + self.FN   # predicted negative

    def _from_rates(self):
        """Predictive values, prevalence threshold and threat score, then the derived metrics."""
        self.PPV = self.TP / (self.TP + self.FP)
        self.NPV = self.TN / (self.TN + self.FN)
        self.PT = (np.sqrt(self.TPR * (1. - self.TNR)) + self.TNR - 1.) / (self.TPR + self.TNR - 1.)
        self.TS = self.TP / (self.TP + self.FN + self.FP)
        self._derived()

    def _derived(self):
        """Metrics that combine the quantities computed above."""
        self.ACC = (self.TP + self.TN) / (self.CP + self.CN)
        # a stray trailing comma previously made BA a 1-tuple instead of a number
        self.BA = (self.TPR + self.TNR) / 2
        self.F1S = 2. * self.PPV * self.TPR / (self.PPV + self.TPR)
        self.MCC = (self.TP * self.TN - self.FP * self.FN) / (np.sqrt(self.P * self.CP * self.CN * self.N))
        self.FM = np.sqrt(self.PPV * self.TPR)
        self.BM = self.TPR + self.TNR - 1.
        self.MK = self.PPV + self.NPV - 1.

    def _translate(self):
        """Expose the metrics under their common long names."""
        self.positive = self.CP
        self.negative = self.CN
        self.sensitivity = self.TPR
        self.recall = self.TPR
        self.specificity = self.TNR
        self.selectivity = self.TNR
        self.precision = self.PPV
        self.FDR = 1. - self.PPV
        self.FOR = 1. - self.NPV
        self.CSI = self.TS
        self.accuracy = self.ACC
        self.f1_score = self.F1S
        self.informedness = self.BM
        self.deltaP = self.MK

    def _sn_mets(self):
        """Compute the efficiency and purity attributes."""
        self.get_efficiency()
        self.get_purity()

    def get_efficiency(self):
        """Store and return TP / CP."""
        self.efficiency = self.TP / self.CP
        return self.efficiency

    def get_purity(self):
        """Store and return TP / P."""
        self.purity = self.TP / self.P
        return self.purity

    def get_fom(self, penalty):
        """Return efficiency times a pseudo-purity in which FP is weighted by `penalty`.

        Also stores the pseudo-purity as `self.pseudo_purity`.
        """
        self.pseudo_purity = self.TP / (self.TP + penalty * self.FP)
        return self.pseudo_purity * self.efficiency

demonstrate on the archetypes (broken at the moment due to flagging when not enough of contaminant in "closed universe")

In [None]:
# for field in ['ddf', 'wfd']:
#     print(field)
#     for cm in [cm_perfect, cm_almost, cm_noisy, cm_uncertain]:
#         pos, neg = subsample_cat(cm, cm_indices, field_info[field]['class_all'], field_info[field]['true_cat'])
#         rates = cat_to_rate(pos, neg)
#         mets = det_mets(**rates._asdict())
#         print(f'purity:{mets.purity}, efficiency:{mets.efficiency}, fom1:{mets.get_fom(1.)}, fom3:{mets.get_fom(3.)}')

## next, make samples corresponding to metric values

original plan was to have these samples:
- 100% Ia
- Ia/Ibc
- - 50/50
- - 75/25
- - 90/10
- - 95/5
- - 98/2
- Ia/II
- Ia/91bg
- Ia/Iax
- AGN
- TDE 
- KN

In [None]:
# percentage of SN Ia in each two-class mix, and the complementary contaminant share
ia_percents = np.array([50, 68, 75, 90, 95, 98, 99])
mix_percents = 100 - ia_percents
# every class of interest except the positive one
contaminants = dict(maybe_sn_classes)
del contaminants[sel_class]
# per-field directory where the metric pickles are written
metpaths = dict((field, savepath+field+'/metrics/') for field in ['ddf', 'wfd'])

assume symmetry in 2-class mix

In [None]:
# binary_ia_mets = {}
# Table with one row per (field, contaminant, percent) sample: bookkeeping
# columns, headline metrics, then one column per public attribute name found
# in RateMatrix.__dict__
cols = ['field', 'contaminant', 'percent', 'inloc', 'name', 'f1', 'purity', 'efficiency', 'fom1', 'fom3'] + [key for key in RateMatrix.__dict__.keys() if key[0] != '_']
directory = pd.DataFrame(columns=cols)
# only the 3000-positive configuration is used for the two-class mixes
field_info = field_info3
n_class_pos = field_defs[1]
for field in ['ddf', 'wfd']:
    for key, val in contaminants.items():
        # two-class layout: positive class at index 0, contaminant at index 1
        subset_indices = {sel_class: 0, key: 1}
        # largest whole percentage of n_class_pos that this contaminant's rescaled count can supply
        crit = math.floor(field_info[field]['class_all'][key] / n_class_pos * 100)
        print(f'cannot have more than {crit} percent of {val}')
        for i, perc in enumerate(mix_percents):
            print(f'seeking {perc * n_class_pos / 100} of '+str(field_info[field]['class_all'][key]))
            # cap the requested share at what is available; capped requests repeat
            # the same sample and the repeated rows are dropped after the loops
            if perc > crit:
#                 perc = 100 - math.floor(field_info[field]['class_all'][key] / n_class_pos * 100.)
                perc = crit
#             else:
#                 sampfn = savepaths[field]+str(ia_percents[i])+str(maybe_sn_classes[sel_class])+str(perc)+val
#                 cm = np.array([[ia_percents[i], perc], [perc, ia_percents[i]]])
            # file stem, e.g. '<100-perc>SNIa<perc><contaminant name>'
            fn = str(100 - perc)+str(maybe_sn_classes[sel_class])+str(perc)+val
            sampfn = savepaths[field]+fn
            # symmetric 2x2 confusion matrix for the mix
            cm = np.array([[100 - perc, perc], [perc, 100 - perc]])
            pos, neg = subsample_cat(cm, subset_indices, field_info[field]['class_all'], field_info[field]['true_cat'], surv=all_maybe_sn[field], 
                                     where_to_save=sampfn)#where_to_save=None)#
            rates = cat_to_rate(pos, neg)
            mets = det_mets(**rates._asdict())
            
            metfn = metpaths[field]+fn#f'{100-perc}_{sel_class}_{perc}_{key}'
            prelim = [mets.f1_score, mets.purity, mets.efficiency, mets.get_fom(1.), mets.get_fom(3.)]
            print(f'{metfn} = F1:{mets.f1_score}, purity:{mets.purity}, efficiency:{mets.efficiency}, fom1:{mets.get_fom(1.)}, fom3:{mets.get_fom(3.)}')
            # persist the raw rates next to the sample
            with open(metfn+'.pkl', 'wb') as metfile:
                pkl.dump(rates._asdict(), metfile)
                print('success for '+metfn)
            # append one row to the directory table
            thisloc = len(directory)
            directory.loc[thisloc] = [field, val, perc, metfn, fn] + prelim + [rates._asdict()[key] for key in rates._asdict().keys()]
# remove the rows repeated by capped requests, then write the index of all samples
directory = directory.drop_duplicates(ignore_index=True)
directory.to_csv(savepath+'directory.csv', index=False)

NEXT: plot redshift distribution of all samples

TODO: consider nontrivial mixes

## create new confusion matrices to tune output sample rates

consider `proclam` classifier archetypes for inspiration

In [None]:
# M_classes = len(cm_indices)

# # 'Uncertain' --> 'Random'
# cm_uncertain = np.ones((M_classes, M_classes))

# # 'Perfect'
# cm_perfect = np.eye(M_classes) + 1.e-8

# # 'Almost'
# cm_almost = np.eye(M_classes) + 0.1 * np.ones((M_classes, M_classes))

# # 'Noisy'
# cm_noisy = np.eye(M_classes) + 0.5 * np.ones((M_classes, M_classes))

# # # 'Tunnel Vision'
# # cm = np.ones((M_classes, M_classes))
# # cm = cm * np.asarray(0.1)[np.newaxis, np.newaxis]
# # cm[:, chosen] = cm[:, chosen] / M_classes
# # cm[chosen][chosen] += M_classes

# # # 'Cruise Control'
# # cm = np.eye(M_classes) + 1.e-8
# # cm[:] = cm[chosen]

# # # 'Subsuming'
# # cm = np.eye(M_classes) + 0.1 * np.ones((M_classes, M_classes))
# # cm[chosen] = cm[chosen-1]

# # # 'Mutually Subsuming'
# # cm = np.eye(M_classes) + 0.1 * np.ones((M_classes, M_classes))
# # cm[chosen][chosen+1] = cm[chosen][chosen]
# # cm[chosen+1][chosen] = cm[chosen+1][chosen+1]

In [None]:
# # 'Mutually Subsuming'
# target = cm_indices[sel_class]
# contaminant = cm_indices[62]
# half_ibc_cm = np.eye(M_classes) + 0.1 * np.ones((M_classes, M_classes))
# half_ibc_cm[target][contaminant] = half_ibc_cm[target][target]
# half_ibc_cm[contaminant][target] = half_ibc_cm[contaminant][contaminant]
# # plt.imshow(half_ibc_cm)
# # plt.colorbar()

make new confusion matrices as mixtures of existing ones

In [None]:
# def mix_arr(inarrs, weights=None):
#     narrs = len(inarrs)
#     if weights is None:
#         weights = np.ones_like((1, narrs))
#     arrs = inarrs / np.sum(np.sum(inarrs, axis=-1), axis=-1)[:, np.newaxis, np.newaxis]
#     normwts = weights / np.sum(weights)
#     outarr = np.sum(arrs * normwts[:, np.newaxis, np.newaxis], axis=0)
#     return outarr

In [None]:
# new_cm = mix_arr(np.array([cm_uncertain, cm_perfect]))
# plt.imshow(new_cm)
# plt.colorbar()