In [70]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

from sotodlib import core
import sotodlib.io.load as io_load

from moby2.analysis import socompat
socompat.register_loaders()

import tools

#from IPython.core.interactiveshell import InteractiveShell
#InteractiveShell.ast_node_interactivity = "all"

In [71]:
from sotodlib.core import FlagManager

import sotodlib.flags as flags
import sotodlib.sim_flags as sim_flags

import sotodlib.tod_ops.filters as filters

from sotodlib.tod_ops import fourier_filter, rfft, detrend_data

In [72]:
import tools
import importlib
importlib.reload(tools)
from tools import in_range

In [73]:
def test_glitch(tod, detail=False):
    found_glitches = flags.get_glitch_flags(tod, signal=signal_name, overwrite=True).ranges
    true_glitches = tod.flags.true_glitches
    
    found_glitches_ranges2D = np.vstack([r.mask() for r in found_glitches])
    true_glitches_ranges2D = np.vstack(r.mask() for r in true_glitches)
    
    result = found_glitches_ranges2D*true_glitches_ranges2D
            
    false_sum = 0
    for det in range(tod.dets.count):
        false_count = 0
        r = found_glitches[det].ranges()
        
        for index in r:
            s = np.sum(found_glitches[det].mask()[index[0]:index[1]] * true_glitches[det].mask()[index[0]:index[1]])
            if s == 0:
                false_count = false_count + 1
                false_sum = false_sum + 1
                
        if detail == True:
            true = tod.flags.true_glitches[det]
            
    
            found = flags.get_glitch_flags(tod, signal=signal_name, overwrite=True).ranges[det]
        
            results = true.mask()*found.mask()
            print('det:', det)
            print('true glitches:%s %s' %(np.sum(true.mask()),true.ranges()))
            #print('\n')
            print('found ranges:%s %s' %(int(found.ranges().size/2),found.ranges()))
            #print('\n')
            print('found true glitches:', np.sum(results))
            print('found false ranges:', false_count)
            print('detection rate:', np.sum(results)/np.sum(true.mask()))
            print('true positive rate:', (int(found.ranges().size/2)-false_count)/int(found.ranges().size/2))
            print('\n')
            
    #print('false sum:', false_sum)
    #print('\n')
    
    true_sum = true_glitches_ranges2D.sum()
    found_sum = sum([mr.ranges().size/2 for mr in found_glitches])
    found_true_sum = result.sum()
    
    detection_rate = found_true_sum/true_sum
    
    if found_sum == 0:
        true_positive_rate = 0
    else:
        true_positive_rate = (found_sum - false_sum)/found_sum
    false_positive_rate = 1 - true_positive_rate
    
    return detection_rate, true_positive_rate, false_positive_rate

In [74]:
# get noise
def get_noise(tod, f_white=10.0):
    """Estimate the median white-noise level of ``tod``.

    Side effects: recomputes ``tod.flags.turnarounds`` (an existing entry is
    removed first) and prints it together with the scan rate.

    Parameters
    ----------
    tod : AxisManager
    f_white : float
        Only frequency bins strictly above this value are used as the white
        noise band (default 10, the previously hard-coded cut; units are
        those of the ``freqs`` returned by ``rfft`` -- presumably Hz).

    Returns
    -------
    float
        Median over detectors of the per-detector white-noise level,
        multiplied by 1e6.
    """
    if 'turnarounds' in tod.flags:
        tod.flags.move('turnarounds', None)

    flags.get_turnaround_flags(tod, merge=True, name='turnarounds')
    print(tod.flags.turnarounds)

    tmsk = tod.flags.turnarounds.mask()
    tsamp = np.median(np.diff(tod.timestamps))

    # Scan speed estimated from the azimuth outside the turnarounds.
    scan_rate = np.median(np.abs(np.diff(tod.boresight.az[~tmsk]))) / tsamp
    print('The scan rate is {} deg / s'.format(round(np.degrees(scan_rate), 3)))

    ffts, freqs = rfft(tod)
    # NOTE(review): assumes rfft applies a Hann window over samps.count
    # samples, which is what this normalisation corresponds to -- confirm.
    norm_fact = (1.0 / tsamp) * np.sum(np.abs(np.hanning(tod.samps.count))**2)
    fmsk = freqs > f_white
    det_white_noise = 1e6 * np.median(np.sqrt(np.abs(ffts[:, fmsk])**2 / norm_fact), axis=1)

    return np.median(det_white_noise)

In [80]:
# Index ranges of each band in the obsdb listing:
#   LF: 0-95   MFF: 95-190   MFS: 190-285   UHF: 285-380
obs_num1, obs_num2 = 285, 295            # observation slice (UHF)
det_set_num1, det_set_num2 = 6, 7        # detector-set slice per observation
det_num1, det_num2 = 0, 10               # detector slice per detector set

obs_num, det_set_num, det_num = (
    obs_num2 - obs_num1,
    det_set_num2 - det_set_num1,
    det_num2 - det_num1,
)

# Glitch-injection settings.
n_glitch = 2
heights = np.arange(1, 10, 1)            # alternative sweep of absolute heights
S_Ns = np.arange(10, 51, 1)              # signal-to-noise ratios to sweep

signal_name = 'bad_signal'               # field that will hold signal + injected glitches

# Simulation context: NERSC path, and the simons1 location.
# NOTE(review): simons_todsims points at a directory, not a context.yaml -- confirm before switching.
nersc_todsims = '/global/project/projectdirs/sobs/todsims/pipe-s0001/v4/context.yaml'
simons_todsims = '/mnt/so1/shared/todsims/'

context = core.Context(nersc_todsims)
my_obs_list = context.obsdb.get()[obs_num1:obs_num2]

# Summary of the workload.
print('obs:', obs_num)
print('det_set:', det_set_num)
print('tod:', obs_num * det_set_num)
print('arrays:', obs_num * det_set_num * det_num)
for my_obs in my_obs_list:
    print(my_obs['obs_id'])

obs: 10
det_set: 1
tod: 10
arrays: 100
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-0_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-1_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-2_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-0_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-1_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-2_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-0_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-1_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-2_UHF
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-11-0_UHF


In [67]:
# For every selected observation / detector set: trim the TOD, estimate its
# white-noise level, then for each S/N in S_Ns inject glitches of height
# S/N * noise and score the glitch finder with test_glitch.
True_Positive_Rate_Matrix = []
False_Positive_Rate_Matrix = []
Detection_Rate_Matrix = []

for my_obs in my_obs_list:
    print('\n')
    det_set_list = context.obsfiledb.get_detsets(my_obs['obs_id'])[det_set_num1:det_set_num2]
    print(my_obs['obs_id'])

    for det_set in det_set_list:
        print(det_set)
        dets = context.obsfiledb.get_dets(det_set)
        tod = context.get_obs(my_obs, dets=dets[det_num1:det_num2])

        # Trim 10 timestamp units (presumably seconds) from each end.
        tsamp = np.median(np.diff(tod.timestamps))
        n_cut = int(10 // tsamp)
        tod.restrict('samps', (n_cut, tod.samps.count - n_cut))

        noise = get_noise(tod)

        True_Positive_Rate = []
        False_Positive_Rate = []
        Detection_Rate = []

        for S_N in S_Ns:
            # get_noise returns the noise multiplied by 1e6; undo that here.
            height = 1e-6 * S_N * noise
            if 'badness' in tod:
                tod.move('badness', None)
            # overwrite is a real bool now. It was previously the string 'False',
            # which is truthy, so the call already behaved as overwrite=True;
            # repeated injections on the same TOD rely on that.
            sim_flags.add_random_glitches(
                tod,
                params={'n_glitch': n_glitch, 'sig_n_glitch': 0, 'h_glitch': height},
                signal='badness',
                overwrite=True,
            )

            if signal_name in tod:
                tod.move(signal_name, None)
            # Store signal + glitches as the field the finder will search.
            tod.wrap(signal_name, tod.signal + tod.badness, [(0, tod.dets), (1, tod.samps)])

            detection_rate, true_positive_rate, false_positive_rate = test_glitch(tod, detail=False)
            True_Positive_Rate.append(true_positive_rate)
            False_Positive_Rate.append(false_positive_rate)
            Detection_Rate.append(detection_rate)

        True_Positive_Rate_Matrix.append(True_Positive_Rate)
        False_Positive_Rate_Matrix.append(False_Positive_Rate)
        Detection_Rate_Matrix.append(Detection_Rate)

True_Positive_Rate_Matrix = np.array(True_Positive_Rate_Matrix)
False_Positive_Rate_Matrix = np.array(False_Positive_Rate_Matrix)
Detection_Rate_Matrix = np.array(Detection_Rate_Matrix)
print('True_Positive_Rate_Matrix:\n', True_Positive_Rate_Matrix)
print('Detection_Rate_Matrix:\n', Detection_Rate_Matrix)

# Rows are TODs, columns are S/N values: average over TODs (axis=0).
True_Positive_Rate = True_Positive_Rate_Matrix.mean(axis=0)
False_Positive_Rate = False_Positive_Rate_Matrix.mean(axis=0)
Detection_Rate = Detection_Rate_Matrix.mean(axis=0)



CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-0_UHF
UHF2_wafer_00
Ranges(n=283558:rngs=77)
The scan rate is 1.229 deg / s






CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-1_UHF
UHF2_wafer_00
Ranges(n=283558:rngs=77)
The scan rate is 1.229 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-2_UHF
UHF2_wafer_00
Ranges(n=283558:rngs=77)
The scan rate is 1.229 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-0_UHF
UHF2_wafer_00
Ranges(n=310758:rngs=83)
The scan rate is 1.391 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-1_UHF
UHF2_wafer_00
Ranges(n=310758:rngs=83)
The scan rate is 1.391 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-1-2_UHF
UHF2_wafer_00
Ranges(n=310758:rngs=83)
The scan rate is 1.391 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-0_UHF
UHF2_wafer_00
Ranges(n=337958:rngs=89)
The scan rate is 1.505 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-1_UHF
UHF2_wafer_00
Ranges(n=337958:rngs=89)
The scan rate is 1.505 deg / s


CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-10-2_UHF
UHF2_wafer_03
Ranges(n=337958:rngs=89)
The 

In [68]:
import os

import pandas as pd

# Write the per-TOD rate matrices (rows: TODs, columns: S/N values) to CSV.
# The file name uses `det_set`, i.e. the last detector set processed by the sweep.
# For a sweep over absolute heights, pass header=heights instead of S_Ns.
output_dir = './output/test'
os.makedirs(output_dir, exist_ok=True)  # to_csv does not create missing directories

for prefix, matrix in (('TPR', True_Positive_Rate_Matrix), ('DR', Detection_Rate_Matrix)):
    df = pd.DataFrame(matrix)
    df.to_csv(os.path.join(output_dir, '%s_%s.csv' % (prefix, det_set)), index=False, header=S_Ns)

In [8]:
# Single-TOD sanity check: load one TOD, inject glitches and run the glitch
# finder on the glitched signal.
my_obs = my_obs_list[0]
det_set_list = context.obsfiledb.get_detsets(my_obs['obs_id'])[det_set_num1:det_set_num2]
det_set = det_set_list[0]
dets = context.obsfiledb.get_dets(det_set)
tod = context.get_obs(my_obs, dets=dets[det_num1:det_num2])

n_glitch = 2
height = 5  # absolute glitch height; the sweep loop instead uses 1e-6 * S_N * noise
signal_name = 'bad_signal'

# Trim before injecting (same order as the sweep loop) so that injected
# glitches land in the retained samples.
tsamp = np.median(np.diff(tod.timestamps))
n_cut = int(10 // tsamp)
tod.restrict('samps', (n_cut, tod.samps.count - n_cut))

noise = get_noise(tod)
print('noise:', noise)

# Drop a previous injection if present, then always inject. The injection
# must not be nested under the `if`: a freshly loaded TOD has no 'badness'.
if 'badness' in tod:
    tod.move('badness', None)
sim_flags.add_random_glitches(
    tod,
    params={'n_glitch': n_glitch, 'sig_n_glitch': 0, 'h_glitch': height},
    signal='badness',
    overwrite=True,
)

# Likewise, always (re)build the glitched signal.
if signal_name in tod:
    tod.move(signal_name, None)
tod.wrap(signal_name, tod.signal + tod.badness, [(0, tod.dets), (1, tod.samps)])

# Search the glitched signal rather than the default 'signal' field.
Found_glitches = flags.get_glitch_flags(tod, signal=signal_name, overwrite=True)

Ranges(n=50038:rngs=77)
The scan rate is 1.229 deg / s
noise: 256.37555
