```text
	NAME	DM [pc cm**-3]	S/N (AMBER)	RA	DEC	POINTING NAME	CB Found
							
1	FRB 190709	663.4	15.4	01h36m06.7s	+31d51m22.8s	3C48drift2732	10
2	FRB 190903	663.8	10.8	01h32m43.2805s	+33d04m48.9206s		4
3	FRB 190925	956.7	12.9	01h41m49.s	+30d59m24.4s		7
4	FRB 191020	465	13.2	20h30m52 	+61d58m47s	T2029+6307	5
5	FRB 191108	587	60	01h33m57.35s	+31d45m38s	FRBfield	21
6	FRB 191109	531	13	20:33:51	61:46:30	FRB191020	18
```

In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import pandas as pd
import numpy as np
from numpy.fft import rfft, irfft
from scipy.signal import find_peaks
import copy
import matplotlib.pyplot as plt

import os

from filter import *

load_data = True

In [None]:
def fix(trig):
    """Replace NaN entries of ``trig`` in place by linear interpolation.

    ``nan_helper`` (imported from ``filter``; not visible here) is assumed to
    return a boolean NaN mask plus a function mapping a mask to integer
    positions -- TODO confirm against ``filter``.

    Returns the same array object that was passed in.
    """
    is_nan, to_index = nan_helper(trig)
    valid = ~is_nan
    trig[is_nan] = np.interp(to_index(is_nan), to_index(valid), trig[valid])
    return trig

def normalize(trig, z=False):
    """Rescale ``trig`` for plotting or comparison.

    Parameters
    ----------
    trig : array_like
        Values to rescale.
    z : bool, optional
        If True return the z-score ``(trig - mean) / std``; otherwise
        min-max scale into ``[0, 1]``.

    Returns
    -------
    numpy.ndarray
        Rescaled float array. A constant input has zero std / zero range, so
        it is returned as all zeros instead of the all-NaN result (plus
        RuntimeWarning) a plain division would give.
    """
    trig = np.asarray(trig, dtype=float)
    if z:
        centred = trig - trig.mean()
        scale = trig.std()
    else:
        centred = trig - trig.min()
        scale = trig.max() - trig.min()
    if scale == 0:
        return np.zeros_like(centred)
    return centred / scale
    
def remove_dc(trig):
    """Return ``trig`` minus its mean, i.e. with the zero-frequency term removed."""
    dc_level = np.mean(trig)
    return trig - dc_level

def convolve(trig):
    """Return the circular autocorrelation of ``trig`` scaled by ``1/sqrt(N)``.

    Despite the name, this is an autocorrelation (inverse FFT of the power
    spectrum), not a convolution with a second signal. With ``norm="ortho"``
    on both transforms, element ``k`` equals
    ``sum_i trig[i] * trig[(i + k) % N] / sqrt(N)`` along the last axis.

    The output length is forced to ``N``. Without ``n=``, ``irfft`` returns
    ``2 * (len(spectrum) - 1)`` samples, which is ``N - 1`` for odd ``N``
    (the 71-entry trigger profiles in this notebook are odd). That would
    silently drop the last lag and shift every later normalisation.
    """
    def abs2(x):
        # |x|**2 without the square root np.abs would take
        return x.real**2 + x.imag**2

    n = np.shape(trig)[-1]
    return irfft(abs2(rfft(trig, norm="ortho")), n=n, norm="ortho")

def get_period(trig, plot=False):
    """Find the dominant spatial frequency of a profile via the FFT.

    The profile is treated as spanning unit length (sample spacing ``1/N``),
    so FFT bin ``k`` corresponds to exactly ``k`` cycles across the ``N``
    samples. The first return value is therefore a cycle count (a frequency);
    the matching period in samples would be ``N / k``.

    Parameters
    ----------
    trig : array_like
        1-D profile. It is copied and never modified.
    plot : bool, optional
        If True, plot the profile and its one-sided amplitude spectrum.

    Returns
    -------
    period : int
        Bin (= cycles per profile) of the largest amplitude among bins
        ``0 .. N//2 - 1`` after the mean is removed.
    xf : numpy.ndarray
        Frequency axis ``0, 1, ..., N//2 - 1`` in cycles per profile.
    amplitude : numpy.ndarray
        One-sided amplitude spectrum ``2/N * |FFT|`` for those bins.

    Raises
    ------
    ValueError
        If ``trig`` has fewer than 2 samples (no bins to search).
    """
    y = np.array(trig, dtype=float)  # private copy
    N = y.shape[0]
    if N // 2 == 0:
        raise ValueError("get_period needs at least 2 samples")
    T = 1.0 / N  # sample spacing: the profile spans unit length

    # True bin frequencies are k / (N*T) = k. Note that
    # np.linspace(0, 1/(2T), N//2) would spread N//2 points over [0, N/2]
    # with step (N/2)/(N//2 - 1) != 1, mislabelling every bin (and, after
    # int(), reporting the wrong one).
    x = np.arange(N) * T
    xf = np.arange(N // 2, dtype=float)

    yf = np.fft.fft(y - y.mean())
    amplitude = 2.0 / N * np.abs(yf[:N // 2])
    period = int(np.argmax(amplitude))

    if plot:
        plt.figure()
        plt.plot(x, y)
        plt.figure()
        plt.plot(xf, amplitude)
        plt.axvline(period)

    return period, xf, amplitude

def get_period2(trig, plot=False):
    """Estimate the period (in samples) of ``trig`` from its autocorrelation.

    Two passes:

    1. Autocorrelate the DC-removed signal and take the strongest lag in
       ``[1, len//4)`` as a multiple of the period. Truncate the signal to a
       whole number of those lags.
    2. Autocorrelate the truncated signal and walk forward through its lags:
       first to the first point at or below 20% of its range, then to the
       next point at or above 70%, then to the next point at or below 20%.
       Return the lag of the maximum between that crest and trough.

    Assumes at least ~4 periods fit (lags beyond ``len//4`` are never
    searched).

    Parameters
    ----------
    trig : numpy.ndarray
        1-D signal.
    plot : bool, optional
        If True, plot both normalised autocorrelations.

    Returns
    -------
    numpy integer
        Estimated period in samples.

    Raises
    ------
    ValueError
        From ``np.argmax`` when a searched lag window is empty (signal too
        short).
    """
    # Remove the DC component (as proposed by Nils Werner) so the
    # autocorrelation reflects the oscillation, not the mean level.
    L = trig - np.mean(trig)

    self_convolved = convolve(L)
    self_convolved = self_convolved / self_convolved[0]

    if plot:
        plt.figure()
        plt.plot(self_convolved)

    # Strongest lag, assuming at least 4 periods. The slice starts at lag 1,
    # so add 1 to turn the slice position back into a lag. Without the +1
    # the estimate is one lag short, and is 0 (ZeroDivisionError below)
    # whenever lag 1 is the maximum.
    period_multiple = np.argmax(self_convolved[1:len(L) // 4]) + 1
    # Keep a whole number of those lags so the circular autocorrelation is
    # not distorted by a partial cycle wrapping around.
    Ltrunk = L[0:(len(L) // period_multiple) * period_multiple]

    self_convolved = convolve(Ltrunk)
    self_convolved = self_convolved / self_convolved[0]

    if plot:
        plt.figure()
        plt.plot(self_convolved)

    # Trough/crest thresholds relative to the range over lags [1, limit).
    limit = len(Ltrunk) // 4
    fmax = np.max(self_convolved[1:limit])
    fmin = np.min(self_convolved[1:limit])
    low = fmin + 0.2 * (fmax - fmin)
    high = fmin + 0.7 * (fmax - fmin)

    # First trough ...
    xstartmin = 1
    while self_convolved[xstartmin] > low and xstartmin < limit:
        xstartmin += 1

    # ... then the following crest ...
    xstartmax = xstartmin
    while self_convolved[xstartmax] < high and xstartmax < limit:
        xstartmax += 1

    # ... then the next trough; the period is the peak lag in between.
    xstartmin = xstartmax
    while self_convolved[xstartmin] > low and xstartmin < limit:
        xstartmin += 1

    period = np.argmax(self_convolved[xstartmax:xstartmin]) + xstartmax

    return period
   
def pool_triggers(beams, sigs, times, dms, n_beams=71, fill_sig=8):
    """Build per-event beam profiles from grouped trigger candidates.

    For each event ``i`` (the i-th entry of the four parallel sequences)
    detected in more than one beam:

    - scatter its S/N and DM values into arrays indexed by beam number;
    - interpolate over NaN entries with ``fix``;
    - run ``get_period`` on the S/N profile.

    Parameters
    ----------
    beams, sigs, times, dms : sequences of equal length
        Per-event beam indices, S/N values, times and DMs. Presumably the
        output of ``load_trigger_file``; TODO confirm the layout there.
    n_beams : int, optional
        Length of each profile; default 71 (the value previously hard-coded).
    fill_sig : float, optional
        S/N assigned to beams without a detection; default 8. Missing DMs
        are 0.

    Returns
    -------
    dict
        Maps event index ``i`` to a dict with keys ``'trigger'``, ``'time'``
        (first time of the event), ``'dm'`` and ``'period'``. ``'period'``
        holds the ``get_period`` result, or None if it raised. Single-beam
        events are omitted.
    """
    triggers = {}

    for i, (trig_beams, trig_sigs, trig_times, trig_dms) in enumerate(
        zip(beams, sigs, times, dms)
    ):
        # A cross-beam profile needs detections in several beams.
        if len(trig_sigs) <= 1:
            continue

        trig = np.full(n_beams, fill_sig, dtype=float)
        dm = np.zeros(n_beams)
        for j, sig in enumerate(trig_sigs):
            trig[trig_beams[j]] = sig
            dm[trig_beams[j]] = trig_dms[j]
        # Interpolate once, after every beam is filled. Interpolating inside
        # the fill loop would anchor NaNs on beams not yet assigned.
        trig = fix(trig)
        dm = fix(dm)

        try:
            period = get_period(trig)
        except Exception:
            # Best effort: keep the event even if no period can be found.
            period = None

        triggers[i] = {
            'trigger': trig,
            'time': trig_times[0],
            'dm': dm,
            'period': period,
        }

    return triggers

def plot_all(triggers, filename='sb_fft', outpath='images/'):
    """Plot pooled triggers, their FFT spectra and a histogram of the peaks.

    Draws a 1x3 figure:

    0. each trigger's beam profile, min-max normalised and offset vertically;
    1. the matching amplitude spectra, with each ``get_period`` peak marked
       and coloured by the DM of the highest-S/N beam;
    2. a horizontal histogram of those peak values.

    The figure is saved as ``<outpath>/<filename>.pdf`` and ``.png`` (300 dpi).

    Parameters
    ----------
    triggers : dict
        Output of ``pool_triggers``. Entries whose ``'period'`` is None
        (period search failed) are skipped instead of crashing the unpack.
    filename : str, optional
        Base name of the saved files.
    outpath : str, optional
        Output directory.
    """
    fig, ax = plt.subplots(1, 3, figsize=(10, 10))

    scat_xs, scat_ys, scat_zs = [], [], []
    offset = 0

    for i in triggers.keys():
        if triggers[i]['period'] is None:
            continue
        trig = triggers[i]['trigger']
        period, x, y = triggers[i]['period']
        y_norm = normalize(y)

        ax[0].plot(normalize(trig) + offset, c='black')
        ax[1].plot(x, y_norm + offset, c='black')

        # Mark the spectrum at the frequency bin closest to the reported peak.
        # Looking the bin up by value keeps the marker on the curve. Indexing
        # with period - 1 points one bin off and wraps to the last bin when
        # period == 0.
        peak_bin = int(np.argmin(np.abs(np.asarray(x) - period)))
        scat_xs.append(period)
        scat_ys.append(y_norm[peak_bin] + offset)
        # DM recorded at the beam with the highest S/N
        scat_zs.append(triggers[i]['dm'][np.argmax(trig)])

        offset += 2

    ax[0].set_ylabel('Trigger #')
    ax[0].set_xlabel('SB #')
    ax[0].set_title('Trigger')
    ax[1].set_xlabel('Freq')
    ax[1].set_yticks([])
    ax[1].set_title('fft(Trigger)')

    print (scat_zs)
    scat = ax[1].scatter(scat_xs, scat_ys, c=scat_zs, cmap='viridis')
    cbar = fig.colorbar(scat, ax=ax[1])
    cbar.ax.set_ylabel('DM (pc/cc)')

    ax[2].set_title('P(period)')
    periods = np.asarray(scat_xs)
    # NOTE(review): purpose of the +1 offset is not evident here -- confirm.
    ax[2].hist(periods + 1, 71, density=True, orientation='horizontal')
    ax[2].set_ylabel('Period')
    ax[2].set_xlabel('Density')

    fig.tight_layout()
    fig.savefig(os.path.join(outpath, filename + '.pdf'))
    fig.savefig(os.path.join(outpath, filename + '.png'), dpi=300)
    

# Pool and plot every trigger file in base_path. Files are visited in sorted
# order because os.listdir order is arbitrary, and ``triggers`` is left bound
# to the LAST file's result, which the exploratory cells below operate on.
base_path = '../data/trigger/'
for filename in sorted(os.listdir(base_path)):
    print (filename)
    beams, sigs, times, dms = load_trigger_file(filename = os.path.join(base_path, filename),
                                                verbose = False,
                                                read_data = True,
                                                read_beam=True,
                                                replace = False)

    triggers = pool_triggers(beams, sigs, times, dms)
    plot_all(triggers, filename.split('.trigger')[0])
    print ()

In [None]:
# Scatter each pooled trigger's dominant FFT frequency against its index.
# triggers[i]['period'] is the (period, xf, amplitude) tuple from get_period,
# or None if that failed. Plot only element 0: scattering the whole tuple
# fails on mismatched x/y sizes, and a bare except hid that failure, leaving
# the plot and labels empty.
fig, ax = plt.subplots(1,1)

for i, trig_info in triggers.items():
    if trig_info['period'] is None:
        continue
    ax.scatter(i, trig_info['period'][0])

ax.set_xlabel('Trigger #')
ax.set_ylabel('Period')

In [None]:
# Sanity check on a synthetic comb: value 8 at every 5th of 71 samples,
# zero elsewhere; plot it and print get_period's dominant bin.
a = np.zeros(71)
a[::5] = 8
print(a)
plt.figure()
plt.plot(np.arange(a.shape[0]), a)
print(get_period(a)[0])

np.fft.rfftfreq?


In [None]:
# Shape of the synthetic comb array.
np.shape(a)

In [None]:
# Compare the two-sided (fftfreq) and one-sided (rfftfreq) frequency grids
# for a copy of ``trig`` sampled at unit rate, then show its rfft.
signal = copy.deepcopy(trig)
n = signal.size
sample_rate = 1
fourier = np.fft.rfft(signal)

freq = np.fft.fftfreq(n, d=1./sample_rate)
print(freq)

freq = np.fft.rfftfreq(n, d=1./sample_rate)
print(freq)

print(fourier)

In [None]:
# Number of samples in ``signal``.
np.size(signal)

In [None]:
# Round-trip a tiny signal through fft/ifft. Both results are complex, and
# plotting a complex array keeps only the real part (with a ComplexWarning).
# So plot the real and imaginary parts explicitly.
x = np.array([0.0, 1.0, 0.0, -1.0, 0.0])
plt.plot(x)

y = np.fft.fft(x)
plt.figure()
plt.plot(y.real, label='real')
plt.plot(y.imag, label='imag')
plt.legend()

yinv = np.fft.ifft(y)
plt.figure()
plt.plot(yinv.real, label='real')
plt.plot(yinv.imag, label='imag')
plt.legend()

In [None]:
# Plot the current ``trig`` profile against its sample index.
plt.plot(np.arange(len(trig)), trig)

In [None]:
# FFT peak (bin, frequency axis, amplitudes) of the current ``trig``.
get_period(trig, plot=False)

In [None]:
# Smallest value in ``trig``.
np.min(trig)

In [None]:
# Largest value in ``trig``.
np.max(trig)

In [None]:
# Binarise ``trig`` in place: every sample above its minimum becomes 10.
above_floor = trig > trig.min()
trig[above_floor] = 10

In [None]:
# Mark trigger 2's local maxima, then build a comb with value 10 at those
# peaks on a flat baseline of 8 and look at its FFT.
profile = triggers[2]['trigger']
peak_idx = find_peaks(profile)[0]

plt.plot(profile)
plt.scatter(peak_idx,
            profile[peak_idx],
            c='red')

trig = np.full_like(profile, 8)
trig[peak_idx] = 10

plt.figure()
plt.plot(trig)
get_period(trig, plot=True)

In [None]:
# Indices of the local maxima in trigger 2's beam profile.
peak_indices = find_peaks(triggers[2]['trigger'])[0]
peak_indices

In [None]:
# Autocorrelation of trigger 2 (via ``convolve``) and its strongest non-zero
# lag. Lag 0 (the trivial self-match) is skipped, so the plot's x axis starts
# at lag 1. The argmax of the [1:] slice needs +1 to become a lag again.
L = copy.deepcopy(triggers[2]['trigger'])

convolved = convolve(L)
plt.plot(np.arange(1, len(convolved)), convolved[1:])
print(np.argmax(convolved[1:len(L)]) + 1)

In [None]:
from scipy.fft import fft
import matplotlib.pyplot as plt

# Toy FFT demo with two tones (50 and 80 cycles per unit of x).
# NOTE(review): with T = 1 the Nyquist frequency 1/(2T) is 0.5, so both tones
# alias, and np.linspace(0, N*T, N) has spacing N*T/(N-1) rather than T. The
# commented-out T = 1/800 is the setting that resolves the tones.

# Number of sample points
N = 600
# sample spacing
# T = 1.0 / 800.0
T = 1

x = np.linspace(0.0, N*T, N)
# FFT bin k sits at frequency k / (N*T); a linspace over [0, 1/(2T)] with
# N//2 points would use a slightly different step and mislabel the bins.
xf = np.arange(N // 2) / (N * T)

y = np.sin(50.0 * 2.0*np.pi*x) + 0.5 * np.sin(80.0 * 2.0*np.pi*x)
yf = fft(y)

plt.plot(x, y)
plt.figure()
plt.plot(xf, 2.0/N * np.abs(yf[0:N//2]))

# Half-wave rectify (zero the negative lobes); remove the mean before the
# transform so bin 0 does not dominate.
y = np.sin(50.0 * 2.0*np.pi*x) + 0.5 * np.sin(80.0 * 2.0*np.pi*x)
y[y < 0] = 0
yf = fft(y - y.mean())
amp = 2.0/N * np.abs(yf[0:N//2])
# Peak of the one-sided amplitude spectrum. np.argmax on the raw complex yf
# does not rank bins by magnitude, and it searches all N bins, so it can
# return an index past the end of xf (length N//2).
peak = np.argmax(amp)

plt.figure()
plt.plot(x, y)
plt.figure()
plt.plot(xf, amp)
plt.axvline(xf[peak], c='red')
print(xf[peak])

In [None]:
from scipy.fft import fft
import matplotlib.pyplot as plt

# Step-by-step version of get_period on trigger 2.
# Number of sample points
N = triggers[2]['trigger'].shape[0]
# sample spacing: the profile spans unit length
T = 1.0 / N

y = copy.deepcopy(triggers[2]['trigger'])
yf = fft(y - y.mean())
# With spacing 1/N, FFT bin k is exactly k cycles per profile. A
# linspace(0, 1/(2T), N//2) axis would stretch the bins by (N/2)/(N//2 - 1)
# and mislabel the peak.
xf = np.arange(N // 2, dtype=float)

plt.figure()
plt.plot(y)
plt.figure()

yyy = 2.0/N * np.abs(yf[0:N//2])

plt.plot(xf, yyy)
# argmax is a bin index; since xf[k] == k it is also the frequency, and it
# indexes yyy directly.
period = int(np.argmax(yyy))
plt.scatter(xf[period], yyy[period])
plt.axvline(xf[period])
print(period)

In [None]:
# One-sided amplitude spectrum of ``yf``, plotted against bin index.
one_sided = yf[0:N//2]
plt.plot(2.0/N * np.abs(one_sided))

In [None]:
# Bin index of the largest amplitude in the one-sided spectrum of ``yf``.
np.argmax(2.0 / N * np.abs(yf[:N // 2]))

In [None]:
# Number of beams in trigger 2's profile.
len(triggers[2]['trigger'])

In [None]:
# FFT peak (bin, frequency axis, amplitudes) of trigger 0's profile.
get_period(trig=triggers[0]['trigger'], plot=False)

In [None]:
# Beam profile of trigger 0.
first_profile = triggers[0]['trigger']
plt.plot(first_profile)

In [None]:
# Distinct DM values stored across the beams of each pooled trigger.
for key, info in triggers.items():
    print(np.unique(info['dm']))