In [115]:
import numpy as np
import scipy.fft as sp_fft

class OFtrigger:
    """Frequency-domain optimum (matched) filter built on the real FFT.

    The template and every trace are transformed with ``scipy.fft.rfft`` and
    scaled by ``1 / sampling_frequency``.  All frequency sums therefore run
    over the rfft (non-negative-frequency) bins only, and the zero-frequency
    bin is always given zero weight.

    NOTE: per the original author's comment, the chi-square returned by
    :meth:`fit` / :meth:`fit_with_shift` is expected to be 1/4 of the
    "original" one because only the rfft half-spectrum is summed (presumably
    "original" means a full-FFT implementation this class replaces -- TODO
    confirm).  That scale is deliberately left unchanged.

    Parameters
    ----------
    template : 1-D array-like
        Pulse template; its length fixes the expected trace length.
    noise_psd : 1-D array-like
        One noise-PSD value per rfft bin, i.e. ``len(template) // 2 + 1``
        entries.
    sampling_frequency : float
        Sampling frequency used to scale the FFTs.

    Raises
    ------
    ValueError
        If ``noise_psd`` does not have one entry per rfft bin of the template.
    """

    def __init__(self, template, noise_psd, sampling_frequency):
        self._sampling_frequency = sampling_frequency
        self._length = len(template)
        self.set_template(template)
        self.set_noise_psd(noise_psd)

    def set_template(self, template):
        """Store a new template, refresh its rfft and rebuild the kernel."""
        self._template = template
        # Keep the cached length in step with the template actually in use.
        self._length = len(template)
        self._template_fft = sp_fft.rfft(template) / self._sampling_frequency
        self._update_kernel_fft()

    def set_noise_psd(self, noise_psd):
        """Store a private copy of the noise PSD and rebuild the kernel.

        The inverse PSD is zero in bin 0 and ``1 / (psd + 1e-30)`` in every
        other bin; the tiny offset guards against division by zero.
        """
        psd = np.array(noise_psd)  # private copy; also accepts plain sequences
        if not np.issubdtype(psd.dtype, np.inexact):
            # Integer/bool input would otherwise truncate 1/psd to 0 below.
            psd = psd.astype(float)
        self._noise_psd = psd

        inv_psd = np.zeros_like(psd)
        # [1:] already includes the last (Nyquist) bin for even lengths, so no
        # separate even-length branch is needed.  Bin 0 stays at zero.
        inv_psd[1:] = 1.0 / (psd[1:] + 1e-30)
        self._inv_psd = inv_psd

        self._update_kernel_fft()

    def _update_kernel_fft(self):
        """Rebuild the filter kernel once both template FFT and 1/PSD exist."""
        template_fft = getattr(self, '_template_fft', None)
        inv_psd = getattr(self, '_inv_psd', None)
        if template_fft is None or inv_psd is None:
            # Called from the first setter during construction: nothing to do yet.
            return
        if inv_psd.shape != template_fft.shape:
            raise ValueError(
                f"noise_psd has shape {inv_psd.shape}, but the rfft of a "
                f"length-{self._length} template has shape {template_fft.shape}; "
                "pass one PSD value per rfft bin"
            )
        self._kernel_fft = template_fft.conjugate() * inv_psd
        self._kernel_normalization = np.real(
            np.sum(self._kernel_fft * template_fft)
        ) * self._sampling_frequency / self._length

    def _trace_fft(self, trace):
        """Return the scaled rfft of ``trace`` after checking its length."""
        trace = np.asarray(trace)
        # For even N a trace of length N + 1 has the same number of rfft bins,
        # so a wrong length would otherwise be accepted silently.
        if trace.ndim == 0 or trace.shape[-1] != self._length:
            raise ValueError(
                f"trace must have {self._length} samples (the template length), "
                f"got shape {trace.shape}"
            )
        return sp_fft.rfft(trace) / self._sampling_frequency

    def _chisq0(self, trace_fft):
        """PSD-weighted power of the trace, summed over the rfft bins."""
        return np.real(
            np.vdot(trace_fft, trace_fft * self._inv_psd)
        ) * self._sampling_frequency / self._length

    def fit(self, trace):
        """Fit the template amplitude with no time shift.

        Parameters
        ----------
        trace : 1-D array-like with the same length as the template.

        Returns
        -------
        (amp0, chisq) : amplitude estimate and chi-square divided by
            ``length - 2``.  See the class docstring for the chi-square scale.

        Raises
        ------
        ValueError
            If the trace length differs from the template length.
        """
        trace_fft = self._trace_fft(trace)
        trace_filtered = self._kernel_fft * trace_fft
        amp0 = np.real(np.sum(trace_filtered)) * self._sampling_frequency / (
            self._length * self._kernel_normalization
        )

        chisq0 = self._chisq0(trace_fft)
        chisq = (chisq0 - amp0**2 * self._kernel_normalization) / (self._length - 2)

        return amp0, chisq

    def fit_with_shift(self, trace, allowed_shift_range=None):
        """Fit amplitude and circular time shift by scanning all shifts.

        Parameters
        ----------
        trace : 1-D array-like with the same length as the template.
        allowed_shift_range : (min_shift, max_shift) in samples, inclusive, or
            None to scan every shift.  Bounds are taken modulo the trace
            length, so negative values select shifts from the end.

        Returns
        -------
        (amp, chisq, t0) : amplitude and chi-square (divided by
            ``length - 3``) at the shift with the smallest chi-square, and
            that shift in samples, wrapped into the signed range (a trace
            equal to ``np.roll(template, d)`` gives ``t0 == d``).

        Raises
        ------
        ValueError
            If the trace length differs from the template length.
        """
        trace_fft = self._trace_fft(trace)
        trace_filtered = trace_fft * self._kernel_fft / self._kernel_normalization

        # A(t0) is the inverse FFT of the filtered signal.  n= must be given:
        # without it irfft returns 2*(bins-1) samples, i.e. N-1 for odd N.
        trace_filtered_td = (
            sp_fft.irfft(trace_filtered, n=self._length) * self._sampling_frequency
        )

        # chi^2_0 does not depend on the shift.
        chisq0 = self._chisq0(trace_fft)

        # irfft counts interior bins twice while the normalization sums each
        # rfft bin once, hence the factor 0.5.
        # NOTE(review): for even lengths irfft counts the Nyquist bin only once,
        # so the zero-shift amplitude can differ marginally from fit() -- confirm
        # whether that is intended.
        amp_series = trace_filtered_td * 0.5
        # chi^2(t0) = chi^2_0 - A(t0)^2 * norm
        chisq_series = chisq0 - amp_series**2 * self._kernel_normalization

        if allowed_shift_range is None:
            ind = np.arange(len(chisq_series))
        else:
            start = (self._length + allowed_shift_range[0]) % self._length
            stop = (allowed_shift_range[1] + 1) % self._length
            if start < stop:
                ind = np.arange(start, stop)
            else:
                # Window wraps around the end of the array.
                ind = np.concatenate(
                    (np.arange(start, self._length), np.arange(0, stop))
                )

        best_ind = ind[np.argmin(chisq_series[ind])]
        amp = amp_series[best_ind]
        chisq = chisq_series[best_ind] / (self._length - 3)
        # Upper half of the indices represents negative shifts.  Identical to
        # `best_ind < N // 2` for even N; for odd N it keeps index (N-1)/2 positive.
        if best_ind <= (self._length - 1) // 2:
            t0 = best_ind
        else:
            t0 = best_ind - self._length

        return amp, chisq, t0


In [116]:
import numpy as np
import time
from OF_trigger import OptimumFilter  # Replace with actual module path


# Sampling frequency handed to both filter implementations
# (presumably in Hz -- TODO confirm).
sampling_frequency = 3906250

# Read the trace array inside a context manager so the .npz file handle is
# released even if the 'data' key lookup fails.
# NOTE(review): absolute, user-specific path -- consider a configurable data dir.
with np.load("/ceph/dwong/delight/Ka_traces_1.npz") as data:
    loaded_traces = data['data']

template = np.load("../templates/template_K_alpha_tight.npy")
noise_psd = np.load("../templates/noise_psd_from_MMC.npy")

# Ensure correct shape: every trace (axis 1) must have the template's length.
assert len(template) == loaded_traces.shape[1], "Trace length must match template length"

# Benchmark function using fit_with_shift
def benchmark_filter_with_shift(FilterClass, template, psd, sampling_frequency, traces):
    """Time ``fit_with_shift`` over every trace for one filter implementation.

    Parameters
    ----------
    FilterClass : callable
        Called as ``FilterClass(template, psd, sampling_frequency)``; the
        result must provide ``fit_with_shift(trace)``.
    template, psd, sampling_frequency
        Forwarded unchanged to ``FilterClass``.
    traces : array-like with ``.shape``
        Traces indexed along the first axis.

    Returns
    -------
    (elapsed, results) : seconds spent in the fitting loop (filter
        construction is excluded) and the list of per-trace return values.
    """
    of = FilterClass(template, psd, sampling_frequency)
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # returned by time.time(), so the interval cannot be skewed by clock changes.
    start = time.perf_counter()
    results = [of.fit_with_shift(traces[i]) for i in range(traces.shape[0])]
    elapsed = time.perf_counter() - start
    return elapsed, results

# Run benchmarks
# Same template, PSD, sampling frequency and traces for both classes:
# the imported OptimumFilter first, then the OFtrigger class defined above.
t_old, results_old = benchmark_filter_with_shift(
    FilterClass=OptimumFilter,
    template=template,
    psd=noise_psd,
    sampling_frequency=sampling_frequency,
    traces=loaded_traces,
)
t_new, results_new = benchmark_filter_with_shift(
    FilterClass=OFtrigger,
    template=template,
    psd=noise_psd,
    sampling_frequency=sampling_frequency,
    traces=loaded_traces,
)

# Report both timings and their ratio (old / new).
print(f"OptimumFilter time (fit_with_shift): {t_old:.4f}s")
print(f"OFtrigger time (fit_with_shift): {t_new:.4f}s")
print(f"Speedup: {t_old / t_new:.2f}x")

OptimumFilter time (fit_with_shift): 5.0937s
OFtrigger time (fit_with_shift): 3.4464s
Speedup: 1.48x


In [117]:
# Display entries 10-19 of the OptimumFilter run; each entry is whatever its
# fit_with_shift returned (presumably (amp, chisq, t0) -- TODO confirm).
results_old[10:20] 

[(10414.42191904015, 1.0308331060424099, 0),
 (10405.665579890228, 0.9785663927509168, 0),
 (10416.386487652875, 1.0108577524644262, 0),
 (10411.395402202568, 0.9891110814783786, 0),
 (10395.465460254756, 1.0255180864330105, 0),
 (10414.140305288478, 0.9932680430274764, 0),
 (10411.829155000516, 1.0325721542698436, 0),
 (10416.330759369552, 1.0022067390324305, 0),
 (10415.4236747321, 1.0033962530497103, 0),
 (10416.520771292313, 1.0048969021311756, 0)]

In [118]:
# Display entries 10-19 of the OFtrigger run: (amp, chisq, t0) tuples as
# returned by OFtrigger.fit_with_shift, for comparison with results_old.
results_new[10:20]

[(10414.42078932861, 0.25784410784467565, 0),
 (10405.664952202407, 0.24499471535573347, 0),
 (10416.385715569613, 0.2529626691084899, 0),
 (10411.394485897516, 0.24745454503782974, 0),
 (10395.464678790739, 0.2566201568327477, 0),
 (10414.139542269872, 0.24857066459220437, 0),
 (10411.828556084582, 0.25852220220527955, 0),
 (10416.330060062335, 0.2508490466919634, 0),
 (10415.422497087706, 0.25098617979834575, 0),
 (10416.519818148405, 0.25138876899615253, 0)]