## Notebook to plot the results of the fitting routine tests and compare them with the calwf3 (WFC3) results for the same tests
### Limited to the cases with background

In [None]:
import matplotlib.pyplot as plt
import pickle, bz2, glob
import numpy as np
from scipy.interpolate import interp1d
import pandas as pd

%matplotlib notebook

#### Setup cell

In [None]:
# Directory holding the saved simulation inputs/outputs
dirsave = '/user/gennaro/Functional_work/Up_the_ramp_myfork/Simulations_results/'

# glob.glob returns paths in arbitrary (filesystem-dependent) order. The test to
# analyse is later picked by its position in `testnames`, so sort the paths to
# make that selection reproducible across machines and runs.
files = sorted(glob.glob(dirsave+'*BKG*in.pbz2'))

# Input files are named 'Test_<testname>_in.pbz2': strip the 5-character prefix
# and the 8-character suffix to recover <testname>.
testnames = [path.split('/')[-1][5:-8] for path in files]

# Show each name next to its index, i.e. the value to use when selecting a test
for idx, name in enumerate(testnames):
    print(idx, name)

In [None]:
# Position, within `testnames`, of the test analysed in the rest of the notebook
n_ord = 8

In [None]:
# Select the test and build the names of the three files that belong to it:
# the fit outputs, the simulation inputs and the JSON table read as calwf3 results.
testname = testnames[n_ord]
test_prefix = dirsave + 'Test_' + testname

outputs_file   = test_prefix + '_out.pbz2'
inputs_file    = test_prefix + '_in.pbz2'
calwf3out_file = test_prefix + '_in.JSON'

#### Restore the saved files and prepare all variables

In [None]:
# NOTE: unpickling can execute arbitrary code; only restore files that were
# written by your own simulation runs.

# --- Fit outputs: one entry per fitted ramp realisation ---
with bz2.BZ2File(outputs_file, 'rb') as f:
    dictoload = pickle.load(f)

(goodints_l, counter_l, error_l, crloops_counter_l,
 outerate_l, gof_stat_l, gof_pval_l) = [
    dictoload[key] for key in ('goodints_l', 'counter_l', 'error_l', 'crloops_counter_l',
                               'outerate_l', 'gof_stat_l', 'gof_pval_l')]

# --- Simulation inputs ---
with bz2.BZ2File(inputs_file, 'rb') as f:
    dictoload = pickle.load(f)

(meas_l, myfluxes, myramps, myCRrates,
 mybgs, CRdict_l, extra_bg_l) = [
    dictoload[key] for key in ('meas_l', 'myfluxes', 'myramps', 'myCRrates',
                               'mybgs', 'CRdict_l', 'extra_bg_l')]

# The dictionary is no longer needed once its content has been unpacked
del dictoload

# Regroup the flat per-fit lists by ramp.
# Flat entry number `idx` is stored as repetition (idx % ntest) of ramp (idx // ntest).

ntest = len(meas_l)//len(myramps)
per_flux_shape = (ntest, len(myfluxes))

# Per-ramp containers: one boolean row per repetition with one entry per
# group-to-group interval, plus plain nested lists for the non-array items
gi_list   = [np.empty((ntest, ramp.group_times.size-1), dtype=np.bool_) for ramp in myramps]
CR_list   = [[[] for _ in range(ntest)] for ramp in myramps]
meas_list = [[[] for _ in range(ntest)] for ramp in myramps]
ebg_list  = [[[] for _ in range(ntest)] for ramp in myramps]

# Scalar fit results, indexed as [repetition, ramp]
counter         = np.empty(per_flux_shape, dtype=np.int_)
error           = np.empty(per_flux_shape, dtype=np.int_)
outerate        = np.empty(per_flux_shape)
crloops_counter = np.empty(per_flux_shape, dtype=np.int_)
gof_stat        = np.empty(per_flux_shape)
gof_pval        = np.empty(per_flux_shape)

per_fit_items = zip(goodints_l, CRdict_l, meas_l, counter_l, error_l,
                    crloops_counter_l, outerate_l, gof_stat_l, gof_pval_l, extra_bg_l)

for flat_idx, (goodints, crdict, meas, cnt, err, crloops, rate, stat, pval, extra_bg) in enumerate(per_fit_items):

    j, k = divmod(flat_idx, ntest)

    counter[k,j]         = cnt
    error[k,j]           = err
    crloops_counter[k,j] = crloops
    outerate[k,j]        = rate
    gof_stat[k,j]        = stat
    gof_pval[k,j]        = pval

    gi_list[j][k,:] = goodints
    CR_list[j][k]   = crdict
    meas_list[j][k] = meas
    ebg_list[j][k]  = extra_bg
    


#### Restore calwf3 results

In [None]:
# The calwf3 table is wrapped in a one-element list so that it can be indexed
# per ramp, like the containers used for the non-calwf3 experiments.
calwf3_table = pd.read_json(calwf3out_file)
calwf3df = [calwf3_table]

# Quick look at the first rows
calwf3_table.head()

#### Regroup all the simulated CRhits and check whether they have been detected

In [None]:
# One entry per ramp in each of these lists
allCRtimes, allCRcounts = [], []
allCRdetect, allCRdetect_C3 = [], []
ramps_with_CRs = []

for j,(myflux,myramp) in enumerate(zip(myfluxes,myramps)):

    ramp_CRs = CR_list[j]
    has_CR = np.zeros(len(ramp_CRs), dtype=np.bool_)
    hit_times, hit_counts = [], []
    found_new, found_C3 = [], []

    for i, crdict in enumerate(ramp_CRs):
        # Repetitions without simulated hits carry None instead of a dictionary
        if crdict is None:
            continue

        has_CR[i] = True
        hit_times.extend(crdict['times'])
        hit_counts.extend(crdict['counts'])

        for t_hit in crdict['times']:
            # Last group read at or before the hit: this is also the index of
            # the group-to-group interval in which the hit falls
            interval = np.nonzero(myramp.group_times <= t_hit)[0][-1]

            # A hit counts as detected when its interval was flagged as not good.
            # The explicit comparison with False is kept on purpose.
            found_new.append(True if gi_list[j][i,interval] == False else False)
            found_C3.append(True if calwf3df[j].loc[i,'TRUTH_VALUE'][interval] == False else False)

    allCRtimes.append(np.asarray(hit_times))
    allCRcounts.append(np.asarray(hit_counts))
    allCRdetect.append(np.asarray(found_new))
    allCRdetect_C3.append(np.asarray(found_C3))

    ramps_with_CRs.append(has_CR)


#### Look for false positives in the detected CRhits

In [None]:
def _false_positive_idx(flagged_idx, CRdict, group_times):
    """Return the flagged intervals that do not contain any simulated CR hit.

    Parameters
    ----------
    flagged_idx : sequence of int
        Indices of the intervals flagged as CR hits. Interval ``k`` spans
        ``group_times[k]`` to ``group_times[k+1]``.
    CRdict : dict or None
        Simulated hits of this ramp (the hit times are read from
        ``CRdict['times']``), or None when no hit was simulated.
    group_times : array-like
        Group times of the ramp.

    Returns
    -------
    numpy.ndarray of int
        The subset of `flagged_idx` for which every simulated hit lies strictly
        before the interval start or strictly after the interval end.
    """
    if CRdict is None:
        # No hit was simulated at all: every flagged interval is a false positive
        return np.asarray(flagged_idx, dtype=np.intp)

    # asarray makes the comparisons below work for plain lists as well as arrays
    cr_times = np.asarray(CRdict['times'])
    spurious = [idx for idx in flagged_idx
                if np.all((cr_times < group_times[idx]) | (cr_times > group_times[idx+1]))]
    return np.asarray(spurious, dtype=np.intp)


CR_false_positives = []
CR_false_positives_C3 = []

for j,(myflux,myramp) in enumerate(zip(myfluxes,myramps)):
    CR_false_positives_p = np.zeros_like(gi_list[j],dtype=np.bool_)
    CR_false_positives_p_C3 = np.zeros_like(gi_list[j],dtype=np.bool_)
    print('*************')
    print(myflux,myramp.ngroups)
    for i in range(ntest):
        CRdict = CR_list[j][i]

        # Intervals rejected by the new fitting routine
        detected_CR_idx = np.nonzero(~gi_list[j][i,:])[0]
        CR_false_positives_p[i,_false_positive_idx(detected_CR_idx,CRdict,myramp.group_times)] = True

        # Intervals rejected by calwf3. The flags are coerced to booleans first:
        # on integer 0/1 flags `~` would be a bitwise NOT (~1 == -2, ~0 == -1)
        # and every interval would be reported as detected.
        C3_good = np.asarray(calwf3df[j].loc[i,'TRUTH_VALUE'],dtype=np.bool_)
        detected_CR_idx_C3 = np.nonzero(~C3_good)[0]
        CR_false_positives_p_C3[i,_false_positive_idx(detected_CR_idx_C3,CRdict,myramp.group_times)] = True

    print('FP_new',np.sum(CR_false_positives_p))
    CR_false_positives.append(CR_false_positives_p)
    print('FP_C3',np.sum(CR_false_positives_p_C3))
    CR_false_positives_C3.append(CR_false_positives_p_C3)
        

### Diagnostic plots

#### Global diagnostics

In [None]:
# Base style first, then the notebook-specific overrides applied in one go
plt.style.use('bmh')
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 14,
    'axes.labelsize': 12,
    'axes.labelweight': 'normal',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 10,
    'figure.titlesize': 13,
})

In [None]:
lw=1.2
colT = 'orange'
colW = '#A60628'
colN = '#348ABD'

# Output locations of the figures and of the per-ramp summary tables
figs_dir = '/user/gennaro/Functional_work/WFC3/ISRs/Up_the_ramp_fitting/Figs/'
tables_dir = '/user/gennaro/Functional_work/WFC3/ISRs/Up_the_ramp_fitting/Table_data/'

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever one is available
_trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz


def _histogram_mode(values, nbins, value_range):
    """Return the centre of the most populated bin of a histogram of `values`.

    The histogram has `nbins` equal-width bins spanning `value_range`
    (a (min, max) pair); values outside that range are ignored.
    """
    counts, edges = np.histogram(values, bins=nbins, range=value_range)
    peak = np.argmax(counts)
    return 0.5*(edges[peak]+edges[peak+1])


for j,(myflux,myramp) in enumerate(zip(myfluxes,myramps)):

    # ------------------------------------------------------------------
    # Summary statistics of the fitted rates (this work vs calwf3)
    # ------------------------------------------------------------------
    BM0 = (error[:,j] == 0)   # fits that ended with error code 0 ("good fits")
    good_rates = outerate[BM0,j]
    mn = np.mean(good_rates)
    md = np.median(good_rates)
    st = np.std(good_rates)

    C3_out = calwf3df[j]['FLT_PIXEL_VALUE'].values
    mn_C3 = np.mean(C3_out)
    md_C3 = np.median(C3_out)
    st_C3 = np.std(C3_out)

    vall = np.hstack([good_rates,C3_out])

    # The mode is searched for within the interquartile range of all the outputs
    mode_rng = np.percentile(vall, (25, 75))
    print('MD rng:',mode_rng)
    mode_bins=100

    fig = plt.figure(figsize=(12,3.5))

    # Histogram limits: clip the extreme tails of the combined distribution
    minbin,maxbin = np.percentile(vall, (.15,99.85))
    print('H range: ',minbin,maxbin)

    nbh = 50

    # ------------------------------------------------------------------
    # Panel 1: distribution of the output rates
    # ------------------------------------------------------------------
    ax1 = fig.add_subplot(1,3,1)
    hist_edges = np.linspace(minbin,maxbin,nbh)
    ax1.hist(good_rates,bins=hist_edges,histtype='step',label='This Work',linewidth=lw)
    ax1.hist(C3_out,bins=hist_edges,histtype='step',label='calwf3',linewidth=lw)

    mode = _histogram_mode(good_rates, mode_bins, mode_rng)
    mode_C3 = _histogram_mode(C3_out, mode_bins, mode_rng)

    has_bg = mybgs[j] is not None
    if has_bg:
        mean_bg_er = mybgs[j]['mean_bg_er']
        ax1.axvline(myflux+mean_bg_er,color=colT,label='Truth+mean background')
    else:
        mean_bg_er = None
        a_rms = None
        ax1.axvline(myflux,color=colT,label='Truth')

    ax1.set_title('Fit output',fontsize=13)
    ax1.set_xlabel('e/s')
    ax1.legend()

    # ------------------------------------------------------------------
    # Panel 2: distribution of the goodness-of-fit p-values
    # ------------------------------------------------------------------
    ax2 = fig.add_subplot(1,3,2)
    pthr = 0.1
    finite_pval = np.isfinite(gof_pval[:,j])
    ax2.hist(gof_pval[finite_pval,j],bins=np.linspace(0,1,100),linewidth=lw,histtype='step')

    ax2.set_title('Goodness of fit',fontsize=13)
    ax2.set_xlabel('p-val')
    ax2.set_yscale('log')
    ax2.set_ylim(0.9,7500)
    # No legend here: this panel holds a single, unlabelled histogram

    BMpt = gof_pval[:,j] > pthr

    # Axes whose upper y-limit is stretched, each with its own stretch factor
    scaled_axes = [(ax1,1.5),(ax2,1.1)]

    # ------------------------------------------------------------------
    # Panel 3 (only with extra background): background rate vs time
    # ------------------------------------------------------------------
    if has_bg:
        ax3 = fig.add_subplot(1,3,3)
        bg_times = mybgs[j]['times']
        bg_electron_rate = mybgs[j]['vbg_er']

        bg_int = interp1d(bg_times,bg_electron_rate,'quadratic')
        varbg = bg_int(myramp.read_times)

        # Rescale the interpolated curve so that its time average over the
        # reads equals mean_bg_er, then measure its rms around that mean
        dt = myramp.read_times[-1]-myramp.read_times[0]
        t_avg = _trapezoid(varbg,myramp.read_times) / dt
        varbg = varbg/t_avg * mean_bg_er
        a_rms = np.sqrt(_trapezoid(np.square(varbg-mean_bg_er),myramp.read_times)/dt)

        ax3.set_title('Mean Background rate (e/s)={:7.4f} \n rms/mean ={:7.4f}'.format(mean_bg_er,a_rms/mean_bg_er),fontsize=13)
        ax3.set_xlabel('s')
        ax3.plot(myramp.read_times,varbg);
        ax3.axhline(mean_bg_er,linestyle='--')

        scaled_axes.append((ax3,1))

    for ax,fac in scaled_axes:
        # set_facecolor replaces set_axis_bgcolor, which was removed from matplotlib
        ax.set_facecolor('#FFFFFF')
        ax.set_ylim(ax.get_ylim()[0],fac*ax.get_ylim()[1])

    fig.tight_layout(rect=[0,0.03,1,.98])
    fig.savefig(figs_dir+'Test_'+testname+'_diagnostic_plots_'+str(j)+'.pdf')

    # ------------------------------------------------------------------
    # Noise budget
    # ------------------------------------------------------------------
    meas_j = meas_list[j]
    gain = meas_j[0].gain

    # Final noiseless counts, and final noisy counts with the cumulative CR
    # counts and the first noisy read subtracted
    actual_counts = np.array([m.noiseless_counts[-1] for m in meas_j],dtype=float)
    noisy_counts = np.array([m.noisy_counts[-1] - m.cum_CR_counts[-1] - m.noisy_counts[0] for m in meas_j],dtype=float)

    exptime = myramp.group_times[-1] - myramp.group_times[0]
    mean_signal = myflux * exptime
    poi_err = np.sqrt(mean_signal)
    act_poi_err = np.std(actual_counts)*gain
    noisy_counts_err = np.std(noisy_counts)*gain

    eff_RON = meas_j[0].effRON_e
    eff_qerr = np.sqrt(gain*myramp.nframes/12) * eff_RON/meas_j[0].RON_e

    # The three terms are combined in quadrature
    tot_noise = np.sqrt(np.sum(np.square(np.array([poi_err,eff_RON,eff_qerr]))))

    ds = {'exptime':exptime,
          'myflux':myflux,
          'mean_signal':mean_signal,
          'poi_err':poi_err,
          'eff_RON':eff_RON,
          'eff_qerr':eff_qerr,
          'tot_noise':tot_noise,
          'act_poi_err':act_poi_err,
          'noisy_counts_err':noisy_counts_err,
          'mn':mn,
          'md':md,
          'mode':mode,
          'st':st,
          'mn_C3':mn_C3,
          'md_C3':md_C3,
          'mode_C3':mode_C3,
          'st_C3':st_C3,
          'FP_CR':np.sum(CR_false_positives[j]),
          'FP_CR_C3':np.sum(CR_false_positives_C3[j]),
          'mean_bg_er':mean_bg_er,
          'a_rms':a_rms
         }

    # Use a dedicated name for the file handle so it does not shadow the figure
    with bz2.BZ2File(tables_dir+'Test_'+testname+'_tabledata_'+str(j)+'.pbz2', 'w') as fh:
        pickle.dump(ds,fh)

    # ------------------------------------------------------------------
    # Text report
    # ------------------------------------------------------------------
    # NOTE(review): relative error and bias below are computed with respect to
    # myflux alone, while the "truth" line in the plot is myflux+mean_bg_er when
    # a background is present -- confirm that this is intended.
    print('######################')
    print('Input flux (e/s):',myflux)
    print('Exposure time',exptime)
    print('Number of groups / frames / skips: {} / {} / {}'.format(myramp.ngroups,myramp.nframes,myramp.nskips))
    print('Number of tests:',ntest)
    print('Fraction of good fits:',100.*BM0.sum()/ntest,'%')
    print(' ')
    print('Pure Poisson / eff. RON / eff. quantization error / total relative error [% w.r.t. mean signal]: {:6.3f}, {:6.3f}, {:6.3f}, {:6.3f}'.format(
        100.*poi_err/mean_signal, 100.*eff_RON/mean_signal, 100.*eff_qerr/mean_signal, 100*tot_noise/mean_signal))
    print('Pure Poisson / eff. RON / eff. quantization error / total error [e/s]: {:7.4f}, {:7.4f}, {:7.4f}, {:7.4f}'.format(
        poi_err/exptime, eff_RON/exptime, eff_qerr/exptime, tot_noise/exptime))
    print('Standard deviation from "noiseless" counts (e/s): {:7.4f}'.format(act_poi_err/exptime))
    print(' ')
    print('Output mean (e/s): {:7.4f} [new] / {:7.4f} [calwf3]'.format(mn,mn_C3))
    print('Output median (e/s): {:7.4f} [new] / {:7.4f} [calwf3]'.format(md,md_C3))
    print('Output mode (e/s): {:7.4f} [new] / {:7.4f} [calwf3]'.format(mode,mode_C3))
    print('Output standard deviation (e/s): {:7.4f} [new] / {:7.4f} [calwf3]'.format(st,st_C3))
    print('Rel. % error: : {:5.2f} [new] / {:5.2f} [calwf3]'.format(100.*st/myflux,100.*st_C3/myflux))
    print('Rel. % bias:: {:5.2f} [new] / {:5.2f} [calwf3]'.format(100.*(mn-myflux)/myflux,100*(mn_C3-myflux)/myflux))
    print(' ')
    print('Fraction of good fits at {:6.3f} confidence: {} %'.format(pthr,100.*np.sum(BMpt)/len(BMpt)))
    print(' ')
    print('CR hits / detected',len(allCRcounts[j]),'/',np.sum(allCRdetect[j]))
    print('CR - false positives',np.sum(CR_false_positives[j]))
    