In [18]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

from sotodlib import core
import sotodlib.io.load as io_load

from moby2.analysis import socompat
socompat.register_loaders()

import tools

#from IPython.core.interactiveshell import InteractiveShell
#InteractiveShell.ast_node_interactivity = "all"

In [19]:
from sotodlib.core import FlagManager

import sotodlib.flags as flags
import sotodlib.sim_flags as sim_flags

import sotodlib.tod_ops.filters as filters

from sotodlib.tod_ops import fourier_filter, rfft, detrend_data

In [20]:
import tools
import importlib
importlib.reload(tools)
from tools import in_range

In [21]:
def test_glitch(tod, detail=False):
    found_glitches = flags.get_glitch_flags(tod, signal=signal_name, overwrite=True).ranges
    true_glitches = tod.flags.true_glitches
    
    found_glitches_ranges2D = np.vstack([r.mask() for r in found_glitches])
    true_glitches_ranges2D = np.vstack(r.mask() for r in true_glitches)
    
    result = found_glitches_ranges2D*true_glitches_ranges2D
            
    false_sum = 0
    for det in range(tod.dets.count):
        false_count = 0
        r = found_glitches[det].ranges()
        
        for index in r:
            s = np.sum(found_glitches[det].mask()[index[0]:index[1]] * true_glitches[det].mask()[index[0]:index[1]])
            if s == 0:
                false_count = false_count + 1
                false_sum = false_sum + 1
                
        if detail == True:
            true = tod.flags.true_glitches[det]
            
    
            found = flags.get_glitch_flags(tod, signal=signal_name, overwrite=True).ranges[det]
        
            results = true.mask()*found.mask()
            print('det:', det)
            print('true glitches:%s %s' %(np.sum(true.mask()),true.ranges()))
            #print('\n')
            print('found ranges:%s %s' %(int(found.ranges().size/2),found.ranges()))
            #print('\n')
            print('found true glitches:', np.sum(results))
            print('found false ranges:', false_count)
            print('detection rate:', np.sum(results)/np.sum(true.mask()))
            print('true positive rate:', (int(found.ranges().size/2)-false_count)/int(found.ranges().size/2))
            print('\n')
            
    #print('false sum:', false_sum)
    #print('\n')
    
    true_sum = true_glitches_ranges2D.sum()
    found_sum = sum([mr.ranges().size/2 for mr in found_glitches])
    found_true_sum = result.sum()
    
    detection_rate = found_true_sum/true_sum
    
    
    true_positive_rate = (found_sum - false_sum)/found_sum
    false_positive_rate = 1 - true_positive_rate
    
    return detection_rate, true_positive_rate, false_positive_rate

In [22]:
# LF:0,95  MFF:95,190 MFS:190, 285  UHF:285,380
obs_num1, obs_num2 = 0, 1
obs_num = obs_num2 - obs_num1

det_set_num1, det_set_num2 = 0, 1
det_set_num = det_set_num2 - det_set_num1

det_num1, det_num2 = 0, 1
det_num = det_num2 - det_num1

n_glitch = 2
heights = np.arange(1, 10, 1)

s_n = np.arange(10, 51, 1)

signal_name = 'bad_signal'

# Choose observation
nersc_todsims = '/global/project/projectdirs/sobs/todsims/pipe-s0001/v4/context.yaml'
context = core.Context(nersc_todsims)
my_obs_list = context.obsdb.get()[obs_num1:obs_num2]

print('tod:', obs_num * det_set_num)
for my_obs in my_obs_list:
    print(my_obs['obs_id'])
print(context.obsfiledb.get_detsets(my_obs['obs_id'])[det_set_num1:det_set_num2])

tod: 1
CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-0_LF
['LF1_tube_LT6']


In [6]:
my_obs = my_obs_list[0]
det_set_list = context.obsfiledb.get_detsets(my_obs['obs_id'])[det_set_num1:det_set_num2]
det_set = det_set_list[0]
dets = context.obsfiledb.get_dets(det_set)
tod = context.get_obs(my_obs, dets=dets[det_num1:det_num2])
tod
# add glitch 
n_glitch = 2
height = 5
signal_name = 'bad_signal'

if 'badness' in tod:
    tod.move('badness', None)
    # add glitches
    sim_flags.add_random_glitches(tod, params={'n_glitch':n_glitch, 'sig_n_glitch' : 0,'h_glitch':height}, signal='badness', overwrite='False')

if signal_name in tod:
    tod.move(signal_name, None)
    # add the glitches to the tod
    tod.wrap(signal_name, tod.signal+tod.badness, [(0, tod.dets), (1, tod.samps)])
    
Found_glitches = flags.get_glitch_flags(tod, overwrite=True)

In [17]:
Found_glitches.ranges[0].ranges()

array([[    0,   202],
       [51037, 51238]], dtype=int32)

In [7]:
True_Positive_Rate_Matrix = []
Detection_Rate_Matrix= []

for my_obs in my_obs_list:
    print('\n')
    det_set_list = context.obsfiledb.get_detsets(my_obs['obs_id'])[det_set_num1:det_set_num2]
    print(my_obs['obs_id'])
    
    for det_set in det_set_list:
        print(det_set)
        dets = context.obsfiledb.get_dets(det_set)
        tod = context.get_obs(my_obs, dets=dets[det_num1:det_num2])
       # Trim tod
        tsamp = np.median(np.diff(tod.timestamps))
        n_cut = int(10//tsamp)
        #print('Trimming in time...')
        tod.restrict('samps', (n_cut, tod.samps.count-n_cut))
        
        True_Positive_Rate = []
        False_Positive_Rate = []
        Detection_Rate = []


        for height in heights:
            #print('height:', height)
            if 'badness' in tod:
                tod.move('badness', None)
            # add glitches
            sim_flags.add_random_glitches(tod, params={'n_glitch':n_glitch, 'sig_n_glitch' : 0,'h_glitch':height}, signal='badness', overwrite='False')

            if signal_name in tod:
                tod.move(signal_name, None)
            # add the glitches to the tod
            tod.wrap(signal_name, tod.signal+tod.badness, [(0, tod.dets), (1, tod.samps)])
        
            # method 
            detection_rate, true_positive_rate, false_positive_rate = test_glitch(tod, detail=False)
            True_Positive_Rate.append(true_positive_rate)
            Detection_Rate.append(detection_rate)
        
        True_Positive_Rate_Matrix.append(True_Positive_Rate)
        Detection_Rate_Matrix.append(Detection_Rate)

True_Positive_Rate_Matrix = np.array(True_Positive_Rate_Matrix)
Detection_Rate_Matrix = np.array(Detection_Rate_Matrix)
print('True_Positive_Rate_Matrix:\n', True_Positive_Rate_Matrix)
print('Detection_Rate_Matrix:\n', Detection_Rate_Matrix)

True_Positive_Rate = True_Positive_Rate_Matrix.mean(axis=0)
Detection_Rate = Detection_Rate_Matrix.mean(axis=0)



CES-Atacama-LAT-Tier1DEC-035..-045_RA+040..+050-0-0_LF
LF1_tube_LT6
True_Positive_Rate_Matrix:
 [[0.5        0.5        0.5        0.33333333 0.66666667 0.66666667
  0.66666667 0.66666667 0.66666667]]
Detection_Rate_Matrix:
 [[1. 1. 1. 1. 1. 1. 1. 1. 1.]]




In [80]:
import pandas as pd
df = pd.DataFrame(True_Positive_Rate_Matrix)
df.to_csv('./output/TPR_%s.csv' %(det_set), index=False, header=heights)

df = pd.DataFrame(Detection_Rate_Matrix)
df.to_csv('./output/DR_%s.csv' %(det_set), index=False, header=heights)