In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import os
import re
import sys
import warnings
warnings.simplefilter('ignore')

from IPython.display import display

# Make local packages importable. Store the *absolute* path: that is what the
# membership check compares against, so re-running this cell stays idempotent
# (appending the relative '.' never matched and was re-added on every run).
module_paths = ['.']
for module_path in module_paths:
    abs_module_path = os.path.abspath(module_path)
    if abs_module_path not in sys.path:
        sys.path.append(abs_module_path)
    
import matplotlib.pyplot as plt

from ipywidgets import interact
import ipywidgets as widgets

from psrpy.spectra import Spectra
# from psrpy.rfifind import rfifind

import pandas as pd

from time_domain_astronomy_sandbox.backend import Backend
from time_domain_astronomy_sandbox.observation import Observation
from time_domain_astronomy_sandbox.rfim import RFIm

import numpy as np
from scipy.fftpack import fft, ifft
import scipy.ndimage.filters as filters

from blimpy import Waterfall
from astropy.time import Time, TimeDelta

# Directory scanned (recursively) for local copies of the .fil snippets.
input_filterbanks_repository = '../data/filterbanks/R3/'

# Row position in df_R3 of the burst to load; the "Main" cell moves it before
# loading, so -1 is only the starting point.
current_input_id = -1

# Prepare burst metadata
(detection parameters, observation folder names, and local and ARTS file paths)

In [None]:
def run_fast_scandir(folder, ext, substrings=()):
    """Recursively collect sub-directories and matching files below `folder`.

    Args:
        folder: directory to scan.
        ext: collection of lower-case extensions to keep, e.g. ['.fil']; the file's
            extension is lower-cased before the membership test.
        substrings: every one of these must occur in the file name without its
            extension; empty means "no filter".

    Returns:
        (subfolders, files): lists of paths for all sub-directories (at any depth)
        and all matching files.
    """
    subfolders, files = [], []

    # Context manager closes the directory handle even if iteration fails.
    with os.scandir(folder) as entries:
        for entry in entries:
            if entry.is_dir():
                # Was `subfolders.appaend(...)`: AttributeError on any sub-directory.
                subfolders.append(entry.path)
            elif entry.is_file():
                stem, suffix = os.path.splitext(entry.name)
                if suffix.lower() in ext and all(s in stem for s in substrings):
                    files.append(entry.path)

    # Iterate over a snapshot: `subfolders` grows while we recurse.
    for subfolder in list(subfolders):
        sub_dirs, sub_files = run_fast_scandir(subfolder, ext, substrings)
        subfolders.extend(sub_dirs)
        files.extend(sub_files)
    return subfolders, files

In [None]:
# Channel indices to zap: 189-325, 822-832 and 1534-1535 (inclusive ranges).
zapped_channels = np.r_[189:326, 822:833, 1534:1536]

# Per-burst detection metadata: entry i of every list below describes the same burst.
detection_folders = ['2020-03-22-10:03:39.R3', '2020-03-23-11:05:38.R3', '2020-03-23-11:05:38.R3', '2020-03-23-11:05:38.R3', '2020-05-09-11:45:55.R3', '2020-05-10-09:41:43.R3', '2020-05-11-07:36:22.R3', '2020-05-11-07:36:22.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-10:42:26.R3', '2020-05-11-14:40:00.R3', '2020-05-11-14:40:00.R3', '2020-05-11-14:40:00.R3', '2020-05-12-08:36:35.R3', '2020-05-27-03:20:38.R3', '2020-05-27-03:20:38.R3', '2020-05-27-03:20:38.R3', '2020-05-27-03:20:38.R3', '2020-05-27-07:21:12.R3', '2020-05-27-07:21:12.R3', '2020-05-27-10:52:06.R3', '2020-05-27-13:37:55.R3', '2020-05-27-13:37:55.R3', '2020-05-28-03:45:00.R3', '2020-05-28-05:13:48.R3', '2020-05-28-05:13:48.R3', '2020-05-28-08:19:28.R3', '2020-05-28-08:19:28.R3', ]
# 'YYYY-MM-DD-hh:mm:ss.R3' -> ISO 'YYYY-MM-DDThh:mm:ss.0'
observation_datetimes = ['%s%s%s' % (d[:10], 'T', d[11:].replace('.R3', '.0')) for d in detection_folders]
detection_times = [4590.8, 4354.47, 7599.3, 9402.4, 9363.27, 9234, 3610.84, 6219.06, 4.15, 1780.7, 2317.14, 2688.72, 7153.32, 8913, 8959.9, 1495.37, 1889.1, 4387.14, 2216.21, 1612.21, 4810.98, 8867.2, 11658.7, 3592.32, 5082.87, 2616.07, 3437.12, 4390.28, 141.992, 2063.73, 4728.98, 1220.04, 3843.53, ]
detection_dm = [348, 349, 348, 348, 348.2, 349, 348.2, 350.8, 352, 347.61, 349.2, 350.2, 354.45, 348, 350, 349.4, 348.06, 348.24, 350.3, 352.05, 349.49, 348.71, 349.2, 352.05, 354.2, 348.8, 351.8, 359.34, 348.4, 350.4, 349.56, 349.89, 349.09, ]
detection_downsampling = [5.0, 5.0, 25.0, 5.0, 5.0, 10.0, 10.0, 25.0, 25.0, 100.0, 25.0, 10.0, 50.0, 1., 5.0, 10.0, 25.0, 10.0, 25.0, 50.0, 5.0, 5.0, 10.0, 50.0, 50.0, 5.0, 25.0, 250.0, 5.0, 25.0, 10.0, 5.0, 5.0]
detection_snr = [11.5, 12.7, 13.4, 13.4, 13.6, 8.88, 16.38, 29.89, 13.86, 17.79, 10.12, 11.02, 38.61, 14, 12.5, 11.47, 58.14, 25.56, 31.5, 12.05, 12.67, 20.23, 20.95, 21.45, 19.48, 20.98, 20.87, 9.03, 25.71, 36.54, 29.9, 29.2, 20.92]

assert len({len(detection_folders), len(detection_times), len(detection_dm),
            len(detection_downsampling), len(detection_snr)}) == 1, \
    'burst metadata lists must all have the same length'

# Detection epoch = observation start + offset in seconds. Add the offset to the
# astropy Time *before* taking .mjd: adding a TimeDelta to the bare float MJD array
# makes astropy treat those floats as a (day-valued) TimeDelta and return a
# TimeDelta instead of plain MJD floats.
observation_start = Time(observation_datetimes, format='isot')
detection_mjd = (observation_start + TimeDelta(detection_times, format='sec')).mjd
detection_files = []
detection_filenames = []

# Every local .fil copy below the repository (sub-directories included).
local_files = run_fast_scandir(input_filterbanks_repository, ['.fil'])[1]

# Snippet filterbank paths as stored on the ARTS cluster; their basenames are
# compared with the basenames of `local_files`.
arts_files = [
    '/tank/data/FRBs/R3/20200527/2020-05-27-10:52:06.R3/snippet/all/CB00_10.0sec_dm0_t02616_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-03:20:38.R3/snippet/all/CB00_10.0sec_dm0_t01612_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-03:20:38.R3/snippet/all/CB00_10.0sec_dm0_t011658_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-03:20:38.R3/snippet/all/CB00_10.0sec_dm0_t04810_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-03:20:38.R3/snippet/all/CB00_10.0sec_dm0_t08867_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-13:37:55.R3/snippet/all/CB00_10.0sec_dm0_t03437_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-13:37:55.R3/snippet/all/CB00_10.0sec_dm0_t04390_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-07:21:12.R3/snippet/all/CB00_10.0sec_dm0_t05082_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200527/2020-05-27-07:21:12.R3/snippet/all/CB00_10.0sec_dm0_t03592_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200512/snippet/all/CB00_10.0sec_dm0_t02216_sb35_tab00.fil',
    '/tank/data/FRBs/R3/20200322/snippet/all/CB00_10.0sec_dm0_t04590_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200529/2020-05-29-03:20:12.R3/snippet/CB00_10.0sec_dm0_t01815_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t02317_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t03610_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t01495_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t01780_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t02688_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t06219_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t01889_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t08913_sb35_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t07153_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t04387_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t08959_sb35_tab00.fil',
    '/tank/data/FRBs/R3/20200511/snippet/all/CB00_10.0sec_dm0_t04_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200509/snippet/all/CB00_10.0sec_dm0_t09363_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200528/2020-05-28-03:45:00.R3/snippet/all/CB00_10.0sec_dm0_t0141_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200528/2020-05-28-08:19:28.R3/snippet/all/CB00_10.0sec_dm0_t03843_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200528/2020-05-28-08:19:28.R3/snippet/all/CB00_10.0sec_dm0_t01220_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200528/2020-05-28-05:13:48.R3/snippet/all/CB00_10.0sec_dm0_t04729_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200528/2020-05-28-05:13:48.R3/snippet/all/CB00_10.0sec_dm0_t02063_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200510/snippet/all/CB00_10.0sec_dm0_t09234_sb-1_tab00.fil',
    '/tank/data/FRBs/R3/20200323/snippet/CB00_10.0sec_dm0_t04354_sb35_tab00.fil',
    '/tank/data/FRBs/R3/20200323/snippet/CB00_10.0sec_dm0_t09402_sb35_tab00.fil',
    '/tank/data/FRBs/R3/20200323/snippet/CB00_10.0sec_dm0_t07599_sb35_tab00.fil',
]

# Construct detection_files and detection_filenames: for every burst, find the ARTS
# snippet with the same date and detection time, then the local copy of that file
# ('' when there is none).
# Snippet names encode the integer detection time as `_t0<seconds>_sb`
# (e.g. `_t04_sb` for 4.15 s, `_t011658_sb` for 11658.7 s).
arts_time_pattern = re.compile(r'_t(\d+)_sb')

# Basename -> first local path with that basename (built once, not per candidate).
local_by_name = {}
for local_path in local_files:
    local_by_name.setdefault(os.path.basename(local_path), local_path)

for folder, t in zip(detection_folders, detection_times):
    date = "".join(folder.split('-')[:3])
    matched_path = ''
    for arts_path in arts_files:
        time_match = arts_time_pattern.search(os.path.basename(arts_path))
        if time_match is None or date not in arts_path:
            continue
        # Compare times numerically, not as substrings: `'4' in path` could pick the
        # 1495 s snippet for the 4.15 s burst. Allow < 1 s because names are sometimes
        # rounded (t04729 for 4728.98 s) and sometimes truncated (t0141 for 141.992 s).
        if abs(int(time_match.group(1)) - t) >= 1:
            continue
        matched_path = local_by_name.get(os.path.basename(arts_path), '')
        if matched_path:
            break
    detection_files.append(matched_path)
    detection_filenames.append(os.path.basename(matched_path).split('.fil')[0])

# One row per burst, ordered from the lowest to the highest detection S/N.
burst_table = {
    'detection_folder': detection_folders,
    'observation_datetimes': observation_datetimes,
    'detection_time': detection_times,
    'detection_mjd': detection_mjd,
    'detection_dm': detection_dm,
    'detection_downsampling': detection_downsampling,
    'detection_snr': detection_snr,
    'filename': detection_filenames,
    'file_location': detection_files,
}
df_R3 = pd.DataFrame(burst_table).sort_values('detection_snr', ascending=True)

# df_R3.to_csv('arts_R3.csv', index=False)

# Functions
## TODO: these helpers still need reorganizing (e.g. moving them into a module)

In [None]:
def read_filterbank(filename:str,
                    t_res:float = Backend().sampling_time, 
                    f_channels:list = Backend().frequencies[::-1],
                    output_type:str='spectra'):
    """Load a SIGPROC filterbank with blimpy and wrap it for analysis.

    Args:
        filename: path to the `.fil` file.
        t_res: sampling time, passed to `Spectra` and used for the observation length.
        f_channels: channel frequencies passed to `Spectra`.
            NOTE: both defaults are evaluated once, when this cell is run.
        output_type: 'spectra' -> psrpy `Spectra`; 'observation' -> sandbox `Observation`.

    Raises:
        ValueError: for any other `output_type` (previously `None` was returned silently).
    """
    # blimpy's Waterfall.data is (time, IF, channel): keep IF 0, transpose to
    # (channel, time) and reverse the channel order.
    data = Waterfall(filename).data[:,0,:].T[::-1, :]

    if output_type == 'spectra':
        return Spectra(f_channels,
                       t_res,
                       data)
    if output_type == 'observation':
        # `data` is a bare ndarray: it has no `.dt`, and `data.data` is its raw memory
        # buffer, so the length comes from the sample count and `t_res`.
        return Observation(backend=Backend(),
                           length=data.shape[1] * t_res,
                           window=data)
    raise ValueError("output_type must be 'spectra' or 'observation', got %r" % (output_type,))
    
def zoom_around_peak(spectra: 'Spectra', 
                     t_zoom:float = 1.):
    """Return about `t_zoom` seconds of `spectra.data` centred on its brightest sample.

    The peak is the argmax of the data summed over axis 0. A window that would start
    before sample 0 is shortened instead of wrapping to the end of the array.

    Args:
        spectra: object with a 2-D `.data` array and a sampling time `.dt`.
        t_zoom: window length in the same units as `spectra.dt`.
    """
    data = spectra.data
    peak_ind = np.argmax(data.sum(axis=0))
    n_samp = int(np.round(t_zoom / spectra.dt))
    samp_start = int(peak_ind - 0.5 * n_samp)
    if samp_start < 0:
        # A negative slice start would select from the end of the array.
        n_samp += samp_start
        samp_start = 0
    # Previously returned `data.data[...]` with `data` undefined (NameError).
    return data[:, samp_start:samp_start + n_samp]

def get_dm_trials(estimated_dm:float = 349.2,
                  dm_step:float = 0.1,
                  dm_range:int = 5):
    """Trial DMs from `estimated_dm - dm_range` to `estimated_dm + dm_range`.

    The upper bound passed to np.arange is padded by half a step so that floating-point
    rounding does not drop the last trial.
    """
    lowest = estimated_dm - dm_range
    beyond_highest = estimated_dm + dm_range + .5 * dm_step
    return np.arange(lowest, beyond_highest, dm_step)
   
def correct_bandpass(spectra: 'Spectra | np.ndarray'):
    """Subtract each row's mean over axis 1 (per-channel mean over time).

    Accepts either an object with a 2-D `.data` array or the 2-D ndarray itself.
    Callers pass `spectra.data`; reading `.data` from an ndarray returns its raw
    memory buffer, not the array, so ndarrays are now used directly.

    Returns:
        New array of the same shape; the input is not modified.
    """
    data = spectra if isinstance(spectra, np.ndarray) else spectra.data
    return data - np.mean(data, axis=1, keepdims=True)

def downsample_freq(spectra: 'Spectra | np.ndarray', 
                    factor:int = 2):
    """Average groups of `factor` adjacent rows (channels) of a 2-D array.

    Accepts either an object with a 2-D `.data` array or the 2-D ndarray itself.
    The number of rows must be divisible by `factor` (numpy's reshape raises
    ValueError otherwise).

    Returns:
        Array of shape (n_rows // factor, n_cols).
    """
    data = spectra if isinstance(spectra, np.ndarray) else spectra.data
    n_chan, n_samp = data.shape
    # The old 4-D reshape with a size-1 time factor and `.mean(-1)` was a no-op.
    return data.reshape(n_chan // factor, factor, n_samp).mean(axis=1)
    
def crop(spectra:Spectra, 
         t_zoom:float = 0.25,
         around_peak=True):
    """Cut a window of about `t_zoom` (in units of `spectra.dt`) out of `spectra.data`.

    The window is centred on the argmax of the median over axis 0 when `around_peak`
    is true, otherwise on the middle column. A window that would start before
    column 0 is shortened rather than shifted.
    """
    width = int(np.round(t_zoom / spectra.dt))

    if around_peak:
        centre = np.argmax(np.median(spectra.data, axis=0))
    else:
        centre = spectra.data.shape[1] // 2
    first = int(np.round(centre - 0.5 * width))

    if first < 0:
        width, first = width + first, 0

    return spectra.data[:, first:first + width]

def to_snr(spectra: 'Spectra | np.ndarray', axis=1):
    """Standardise to zero mean and unit variance along `axis`, ignoring NaNs.

    Accepts either an object with a 2-D `.data` array or the ndarray itself.
    Non-finite results (e.g. rows with zero variance) are replaced by the
    NaN-ignoring median of the standardised data.

    Args:
        axis: axis along which mean and variance are taken. Previously only
            axis=1 worked, because the statistics were broadcast with `[:, None]`.

    Returns:
        New array; the input is not modified.
    """
    data = spectra if isinstance(spectra, np.ndarray) else spectra.data
    data = data - np.nanmean(data, axis=axis, keepdims=True)
    data = data / np.sqrt(np.nanvar(data, axis=axis, keepdims=True))
    data[~np.isfinite(data)] = np.nanmedian(data)
    return data

def acf(x):
    """Autocorrelation of every row of a 2-D array, computed with the FFT.

    Rows are zero-padded to the next power of two >= 2*n - 1 (n = row length), which
    makes the FFT's circular correlation equal to the linear one. The old code rounded
    the exponent *down*, so the padding could be too short and lags wrapped around.

    Returns:
        Complex array of shape (n_rows, padded_length), fftshift-ed so zero lag sits
        at column padded_length // 2.
    """
    n_fft = 2 ** int(np.ceil(np.log2(x.shape[1] * 2 - 1)))
    fftx = np.fft.fft(x, n=n_fft, axis=1)
    ret = np.fft.ifft(fftx * np.conjugate(fftx), axis=1)
    return np.fft.fftshift(ret, axes=1)

def subband(data, sub_factor, dim='freq'):
    """Sum groups of `sub_factor` adjacent rows (dim='freq') or columns (any other dim).

    NaNs count as zero. The size of the summed axis must be divisible by `sub_factor`.
    """
    n_rows, n_cols = data.shape
    if dim == 'freq':
        # (rows, cols) -> (rows / sub_factor, sub_factor, cols)
        grouped = data.reshape(-1, sub_factor, n_cols)
    else:
        # A Fortran-order reshape puts consecutive columns on axis 1:
        # (rows, cols) -> (rows, sub_factor, cols / sub_factor)
        grouped = data.reshape(n_rows, sub_factor, -1, order='f')
    return np.nansum(grouped, axis=1)

def get_yticks(freq_id_low, freq_id_high):
    """Nine evenly spaced y-tick positions from half a channel below `freq_id_low`
    to half a channel above `freq_id_high`."""
    bottom_edge = freq_id_low - 0.5
    top_edge = freq_id_high + 0.5
    return np.linspace(bottom_edge, top_edge, num=9)

def get_yticklabels(f_channels, freq_id_low, freq_id_high):
    """Nine frequency labels (rounded to 0.1) for the y ticks.

    `f_channels` is expected to be already sliced to the selected channels, so
    entry 0 corresponds to `freq_id_low`. Labels run from half a channel width below
    entry 0 to half a width above entry `freq_id_high - freq_id_low - 1`, where the
    width is the median channel spacing.
    """
    half_width = np.median(np.diff(f_channels)) / 2.
    lowest = f_channels[0] - half_width
    highest = f_channels[(freq_id_high - freq_id_low) - 1] + half_width
    return np.round(np.linspace(lowest, highest, 9), 1)

def dedisperse_waterfall(wfall, DM, freq, dt, ref_freq="top"):
    """
    Dedisperse a waterfall to `DM` (after D. Michilli).

    Row i of `wfall` is circularly shifted (np.roll) by
    round(k_DM * DM * (f_ref**-2 - freq[i]**-2) / dt) samples, k_DM = 1 / 2.41e-4.

    Args:
        wfall: 2-D array with one row per entry of `freq`.
        DM: dispersion measure.
        freq: float array of channel frequencies (negative powers are taken, so an
            integer array would make numpy raise).
        dt: sampling time.
        ref_freq: "top" -> freq[-1], "center" -> freq[len(freq) // 2],
            "bottom" -> freq[0]; anything else prints a notice and uses "top".
            NOTE(review): "top"/"bottom" assume `freq` is ascending -- confirm.

    Returns:
        New array with the same shape and dtype as `wfall`.
    """
    if ref_freq not in ("top", "center", "bottom"):
        print("`ref_freq` not recognized, using 'top'")
        ref_freq = "top"
    reference_index = {"top": -1, "center": len(freq) // 2, "bottom": 0}[ref_freq]
    reference_frequency = freq[reference_index]

    k_DM = 1. / 2.41e-4
    delays = k_DM * DM * (reference_frequency**-2 - freq**-2) / dt
    shift = delays.round().astype(int)

    dedisp = np.zeros_like(wfall)
    for row in range(len(wfall)):
        dedisp[row] = np.roll(wfall[row], shift[row])
    return dedisp

def get_xticks(t0, t1):
    """Five evenly spaced x-tick positions from sample `t0` to sample `t1` inclusive."""
    return np.linspace(t0, t1, num=5)

def get_xticklabels(t0, t1, dt=None):
    """Five time labels in ms ('%.1f') for the ticks from `get_xticks(t0, t1)`.

    Labels run from (t0 - t1/2)*dt to (t1/2 - t0)*dt, converted to ms. They now
    *increase* left to right like the tick positions; previously the sign was
    inverted, so the leftmost tick showed a positive time. For t0 = 0, zero is the
    middle of the data.

    Args:
        t0, t1: first/last sample index of the displayed window.
        dt: sampling time in seconds; defaults to the global `spectra.dt`
            (the original, implicit behaviour).
    """
    if dt is None:
        dt = spectra.dt
    half_span_ms = (t1 / 2 - t0) * dt * 1000
    return ['%.1f' % label for label in np.linspace(-half_span_ms, half_span_ms, 5)]

def get_cohenrent_spectrum(waterfall):
    """Get the coherent spectrum of the waterfall.

    Every row is Fourier transformed (scipy.fftpack.fft along the last axis) and
    normalised to unit amplitude. Bins with zero amplitude stay zero. The normalised
    rows are then summed over axis 0.
    """
    transformed = fft(waterfall)
    magnitude = np.abs(transformed)
    safe_magnitude = np.where(magnitude == 0, 1, magnitude)
    return (transformed / safe_magnitude).sum(axis=0)

def get_coherent_power(waterfall):
    """Get the coherent power of the waterfall: squared magnitude of every bin of
    `get_cohenrent_spectrum(waterfall)`."""
    return np.abs(get_cohenrent_spectrum(waterfall)) ** 2

def prep_power(freq_id_low = 0,
               freq_id_high = None,
               t0 = 0,
               t1 = None,
               verbose=False):
    """Coherent power as a function of fluctuation-frequency bin and trial DM.

    For every DM in the global `dm_trials`, the selected window of the global
    `spectra` is dedispersed and its coherent power computed. Only the first
    round(n_samples / 2) bins are kept.

    Returns:
        (power_vs_dm, d_power_vs_dm): arrays of shape (n_bins, dm_trials.size); the
        second is the first weighted by (bin index)**2. `power_vs_dm` is also stored
        as a global.
    """
    global power_vs_dm

    if verbose:
        print ('Computing coherent power vs DM...')
        print ()

    waterfall, f_channels, _, _ = initialize_observation(freq_id_low=freq_id_low,
                                                         freq_id_high=freq_id_high,
                                                         t0=t0,
                                                         t1=t1)

    n_bins = int(np.round(waterfall.shape[1] / 2))
    power_vs_dm = np.zeros([n_bins, dm_trials.size])
    for column, trial_dm in enumerate(dm_trials):
        dedispersed = dedisperse_waterfall(waterfall, trial_dm, f_channels, spectra.dt)
        power_vs_dm[:, column] = get_coherent_power(dedispersed)[:n_bins]

    bin_index = np.arange(0, n_bins)
    return power_vs_dm, power_vs_dm * bin_index[:, np.newaxis]**2

def poly_max(x, y, Err):
    """
    D. Michilli's Polynomial fit

    Fit a polynomial to (x, y) and locate its maximum inside [min(x), max(x)].

    The polynomial degree is the numerical rank of the Vandermonde matrix of `y`.

    Args:
        x: sample positions (trial DMs when called from plot_coherent_power).
        y: values to fit.
        Err: uncertainty scale of `y`; used for the peak-position error and `Fac`.

    Returns:
        (Best, delta_x, p, Fac):
        Best    -- x of the highest real local maximum of the fit within the x range
                   (0. when there is none),
        delta_x -- sqrt(|2 * Err / p''(Best)|) (0. when there is no maximum),
        p       -- polynomial coefficients from np.polyfit (highest power first),
        Fac     -- std(y) / Err.
    """
    # Degree of the fit = rank of vander(y).
    n = np.linalg.matrix_rank(np.vander(y))
    p = np.polyfit(x, y, n)
    Fac = np.std(y) / Err

    dp      = np.polyder(p)
    ddp     = np.polyder(dp)
    # Stationary points of the fit and the curvature there.
    cands   = np.roots(dp)
    r_cands = np.polyval(ddp, cands)
    # Keep real stationary points inside the x range with negative curvature (maxima).
    # NOTE(review): r_cands may be complex; `< 0` then relies on numpy's complex ordering.
    first_cut = cands[(cands.imag==0) & 
                      (cands.real>=min(x)) & 
                      (cands.real<=max(x)) & 
                      (r_cands<0)]
    
    if first_cut.size > 0:
        # Highest maximum; its error from the curvature of the fit.
        Value     = np.polyval(p, first_cut)
        Best      = first_cut[Value.argmax()]
        delta_x   = np.sqrt(np.abs(2 * Err / np.polyval(ddp, Best)))
    else:
        Best    = 0.
        delta_x = 0.

    return float(np.real(Best)), delta_x, p , Fac

def plot_coherent_power(power_vs_dm, 
                        d_power_vs_dm, 
                        f_channels,
                        nchan,
                        estimated_dm,
                        delta_dm, 
                        t0, 
                        t1, 
                        fluct_id_low, 
                        fluct_id_high,
                        ax_power, 
                        ax_power_prof, 
                        ax_power_res,  
                        cmap='viridis'):
    """Plot coherent power: fluctuation freq. vs DM

    Draws three panels:
      * ax_power_prof -- the DM curve (weighted power summed over the selected
        fluctuation rows) with a polynomial fit around its peak and its S/N;
      * ax_power_res  -- the fit residuals rescaled to [0, 1];
      * ax_power      -- the power map for the selected fluctuation rows vs trial DM.

    Reads the globals `dm_trials` (x axis) and `spectra` (sampling time).
    `f_channels`, `estimated_dm`, `delta_dm`, `t0` and `t1` are accepted for interface
    compatibility but not used.

    Returns:
        (dm, dm_std, snr): fitted peak position in units of `dm_trials`, its
        uncertainty, and the S/N of the DM-curve maximum.
    """
    dm_curve = d_power_vs_dm[fluct_id_low : fluct_id_high].sum(axis=0)

    # S/N of the curve maximum under the original noise model.
    _max   = dm_curve.max()
    _mean  = nchan               # Base on Gamma(2,)
    _std   = _mean / np.sqrt(2)  # Base on Gamma(2,)
    m_fact = np.sum(np.arange(fluct_id_low, fluct_id_high)**2)
    s_fact = np.sum(np.arange(fluct_id_low, fluct_id_high)**4)**0.5
    d_mean = _mean * m_fact
    d_std  = _std  * s_fact
    snr    = (_max - d_mean) / d_std

    # Fit window: up to 10 trials around the peak, shifted to stay inside the curve.
    # np.arange(peak - 5, peak + 5) wrapped to the far end near index 0 and raised
    # IndexError near the last trial.
    _peak    = dm_curve.argmax()
    n_trials = dm_curve.size
    start    = int(np.clip(_peak - 5, 0, max(n_trials - 10, 0)))
    _range   = np.arange(start, min(start + 10, n_trials))
    y = dm_curve[_range]
    x = dm_trials[_range]
    returns_poly = poly_max(x, y, d_std)
    fit = np.polyval(returns_poly[2], x)

    # Profile
    X, Y = dm_trials, dm_curve
    ax_power_prof.plot(X, Y, linewidth=3, clip_on=False)
    ax_power_prof.plot(x, 
                       fit, 
                       color='orange', 
                       linewidth=3, 
                       zorder=2, 
                       clip_on=False)
    ax_power_prof.set_xlim([X.min(), X.max()])
    ax_power_prof.set_ylim([Y.min(), Y.max()])
    ax_power_prof.ticklabel_format(useOffset=False)
    ax_power_prof.text(0.1, 0.8, 
                       'S/N=%.2f' % (snr), 
                       horizontalalignment='center',
                       verticalalignment='center', 
                       transform=ax_power_prof.transAxes)

    # Residuals rescaled to [0, 1]; left unscaled when all equal (avoids 0/0).
    res = y - fit
    res -= res.min()
    res_span = res.max()
    if res_span > 0:
        res /= res_span

    ax_power_res.plot(x, res, 'x', linewidth=2, clip_on=False)
    ax_power_res.set_ylim([np.min(res) - np.std(res) / 2, 
                           np.max(res) + np.std(res) / 2])
    ax_power_res.set_ylabel(r'$\Delta$')
    # Booleans, not 'on'/'off': current Matplotlib no longer interprets those strings,
    # and a non-empty string such as 'off' is truthy.
    ax_power_res.tick_params(axis='both', 
                             labelbottom=False, 
                             labelleft=False, 
                             direction='in', 
                             left=False, 
                             top=True)
    ax_power_res.ticklabel_format(useOffset=False)

    # Power vs DM map; y axis converted from fluctuation-bin index to ms^-1.
    FT_len = power_vs_dm.shape[0]
    indx2Ang = 1. / (2 * FT_len * spectra.dt * 1000)
    extent = [np.min(X), np.max(X), fluct_id_low * indx2Ang, fluct_id_high * indx2Ang]

    ax_power.imshow(power_vs_dm[fluct_id_low : fluct_id_high], 
                    origin='lower', 
                    aspect='auto', 
                    cmap=cmap, 
                    extent=extent, 
                    interpolation='nearest')
    ax_power.tick_params(axis='both', 
                         direction='in', 
                         right=True, 
                         top=True)

    dm = returns_poly[0]
    dm_std = returns_poly[1]

    return dm, dm_std, snr
    

def plot_waterfall(waterfall, 
                   f_channels, 
                   t0, 
                   t1, 
                   freq_id_low, 
                   freq_id_high, 
                   ax_waterfall, 
                   ax_t_snr, 
                   ax_power_prof,
                   ax_power_res,
                   delta_dm, 
                   dm_std,
                   cmap='viridis'
                  ):
    """Draw the waterfall image and its summed time profile, then redisplay `fig`.

    The image covers samples [t0, t1] and channels [freq_id_low, freq_id_high], with
    tick labels in ms and MHz. The profile panel is annotated with
    DM = spectra.dm + delta_dm +/- dm_std. The profile, DM-profile and residual axes
    are hidden. Reads the globals `spectra` and `fig`.
    """
    image_extent = (t0 - 0.5,
                    t1 + 0.5,
                    freq_id_low - 0.5,
                    freq_id_high + 0.5)
    image = ax_waterfall.imshow(waterfall,
                                origin='lower',
                                aspect='auto',
                                cmap=cmap,
                                interpolation='nearest',
                                extent=image_extent)

    # Physical units instead of sample / channel indices on both axes.
    ax_waterfall.set_xticks(get_xticks(t0, t1))
    ax_waterfall.set_xticklabels(get_xticklabels(t0, t1))
    ax_waterfall.set_yticks(get_yticks(freq_id_low, freq_id_high))
    ax_waterfall.set_yticklabels(get_yticklabels(f_channels, freq_id_low, freq_id_high))

    image.autoscale()

    # Time profile: NaN-ignoring sum over axis 0.
    profile = np.nansum(waterfall, axis=0)
    ax_t_snr.plot(profile, '-', linewidth=2)
    ax_t_snr.set_ylim([profile.min() - 1, profile.max() + 1])
    ax_t_snr.set_xlim([0, profile.size])
    dm_label = r'DM=%.2f $\pm$ %.2f pc/cm$^3$' % (spectra.dm + delta_dm, dm_std)
    ax_t_snr.text(0.1, 0.8,
                  dm_label,
                  horizontalalignment='center',
                  verticalalignment='center',
                  transform=ax_t_snr.transAxes)

    for hidden_ax in (ax_power_prof, ax_power_res, ax_t_snr):
        hidden_ax.axis('off')

    fig.canvas.draw()
    display(fig)
    
def initialize_observation(freq_id_low = 0,
                           freq_id_high = None,
                           t0 = 0,
                           t1 = None):
    """Cut the selected row/column window out of the global `spectra`.

    `freq_id_high` and `t1` default to the full extent of `spectra.data`.

    Returns:
        (waterfall, f_channels, freq_id_high, t1) with the defaults resolved.
    """
    freq_id_high = spectra.data.shape[0] if freq_id_high is None else freq_id_high
    t1 = spectra.data.shape[1] if t1 is None else t1

    channels = slice(int(freq_id_low), int(freq_id_high))
    samples = slice(int(t0), int(t1))
    waterfall = spectra.data[channels, samples]
    f_channels = spectra.freqs[channels, ...]

    return waterfall, f_channels, freq_id_high, t1

def set_layout():
    """Create the five panels of the interactive figure on the globals `fig` / `gs`.

    The figure is cleared first. In current Matplotlib `fig.add_subplot` always
    creates a *new* Axes, so without clearing, every slider update stacked another
    set of axes on top of the old ones; the per-axes `.clear()` calls could not
    prevent that.

    Returns:
        (ax_t_snr, ax_waterfall, ax_power_prof, ax_power, ax_power_res)
    """
    fig.clf()

    # Fluctuation vs dDM
    ax_power_prof = fig.add_subplot(gs[0:4, 0:3])
    ax_power_res = fig.add_subplot(gs[4:5, 0:3])
    ax_power = fig.add_subplot(gs[5:, 0:3])
    ax_power.set_xlabel(r'$\Delta$DM (pc/cc)')
    ax_power.set_ylabel(r'Fluctuation frequency (ms$^{-1}$)')

    # Waterfall
    ax_t_snr = fig.add_subplot(gs[0:4, 3:])
    ax_waterfall = fig.add_subplot(gs[5:, 3:])
    ax_waterfall.set_xlabel('Time (ms)')
    ax_waterfall.set_ylabel('Frequency (MHz)')

#     # ACF
#     ax_acf_prof = fig.add_subplot(gs[0:4, 6:])
#     ax_acf_prof.clear()

#     ax_acf = fig.add_subplot(gs[5:, 6:])
#     ax_acf.clear()
#     ax_acf.set_xlabel('Time (ms)')
#     ax_acf.set_ylabel('Frequency (MHz)')

    return ax_t_snr, ax_waterfall, ax_power_prof, ax_power, ax_power_res

def select_frequency_range(fluct_id_low = 0,
                           fluct_id_high = 30,
                           freq_id_low = 0,
                           freq_id_high = None,
                           t0 = 0,
                           t1 = None,
                           ds_freq = 1,
                           ds_time = 1,
                           delta_dm = 0,
                           smooth = 0):
    """Select a frequency range from the waterfall 2D array.

    Interactive callback. It redraws the DM-power panels for fluctuation rows
    [fluct_id_low, fluct_id_high), smoothed with a Gaussian of width `smooth`. It then
    dedisperses channels [freq_id_low, freq_id_high) x samples [t0, t1) to
    `delta_dm` plus the fitted DM, subbands by `ds_freq` / `ds_time`, and draws the
    waterfall. Sets the globals `struct_opt_dm`, `struct_opt_dm_err` and
    `sub_waterfall`.
    """
    global struct_opt_dm, struct_opt_dm_err, sub_waterfall

    ax_t_snr, ax_waterfall, ax_power_prof, ax_power, ax_power_res = set_layout()

    waterfall, f_channels, freq_id_high, t1 = initialize_observation(
        freq_id_low=freq_id_low,
        freq_id_high=freq_id_high,
        t0=t0,
        t1=t1,
    )

    smoothed_power = filters.gaussian_filter(power_vs_dm, smooth)
    smoothed_d_power = filters.gaussian_filter(d_power_vs_dm, smooth)
    dm, dm_std, snr = plot_coherent_power(smoothed_power,
                                          smoothed_d_power,
                                          f_channels,
                                          waterfall.shape[0],
                                          spectra.dm,
                                          delta_dm,
                                          t0,
                                          t1,
                                          fluct_id_low,
                                          fluct_id_high,
                                          ax_power,
                                          ax_power_prof,
                                          ax_power_res)

    best_dm = delta_dm + dm
    y_bottom, y_top = ax_power.get_ylim()
    ax_power.vlines(best_dm, y_bottom, y_top, alpha=0.7, color='red')

    struct_opt_dm, struct_opt_dm_err = best_dm, dm_std

    dedispersed = dedisperse_waterfall(waterfall, best_dm, f_channels, spectra.dt)
    freq_subbanded = subband(dedispersed, ds_freq, dim='freq')
    sub_waterfall = subband(freq_subbanded, ds_time, dim='time')

    plot_waterfall(sub_waterfall,
                   f_channels,
                   t0,
                   t1,
                   freq_id_low,
                   freq_id_high,
                   ax_waterfall,
                   ax_t_snr,
                   ax_power_prof,
                   ax_power_res,
                   best_dm,
                   dm_std)
    
def prep_data(file, estimated_dm, downsampling, around_peak=True, verbose=False):
    """Prepare data for analysis and plotting

    Pipeline: read filterbank -> RFIm dm0clean -> dedisperse at `estimated_dm` ->
    subtract per-row means -> RFIm tdsc_amber -> crop (around the peak, or the
    centre) -> S/N normalisation. `spectra.data` is replaced at every step.

    Returns:
        (spectra, dm_trials): the processed `Spectra`, and trial DM offsets from -10
        to +10 in steps of 0.1.
    """
    if verbose:
        print ('Preprocessing data...')
        print ()

    t_res = Backend().sampling_time
    f_channels = Backend().frequencies
    dm_trials = get_dm_trials(estimated_dm=0, dm_step=0.1, dm_range=10)

    spectra = read_filterbank(file, t_res=t_res, f_channels=f_channels)

    spectra.data = RFIm().dm0clean(spectra.data)
    spectra.dedisperse(dm=estimated_dm)
    # Disabled steps:
    # spectra.downsample(factor=downsampling)
    # spectra.subband(spectra.data.shape[0]//4)
    spectra.data = correct_bandpass(spectra.data)
    spectra.data = RFIm().tdsc_amber(spectra.data)
    # spectra.data = RFIm().fdsc_amber(spectra.data)

    # Crop length depends on the detection downsampling.
    # NOTE(review): the 0.015 branch is unreachable for real numbers
    # (downsampling >= 25 implies > 1); confirm the intended thresholds.
    if downsampling < 25:
        t_zoom = 0.05
    elif downsampling > 1:
        t_zoom = 0.1
    else:
        t_zoom = 0.015
    spectra.data = crop(spectra, t_zoom=t_zoom, around_peak=around_peak)

    spectra.data = to_snr(spectra.data)

    return spectra, dm_trials
    
def initialize(input_filename, estimated_dm, downsampling, verbose=False):
    """Load and preprocess one burst file for the interactive panels.

    Clears the current matplotlib figure (best effort), then runs `prep_data`
    cropping around the peak. If that raises IndexError, it retries cropping
    around the centre of the data.

    Returns:
        (spectra, dm_trials, input_filename)
    """
    if verbose:
        print ('Loading data... %s' % (input_filename))
        print ()

    try:
        plt.clf()
    except Exception:
        # Best effort only; narrowed from a bare `except:` so Ctrl-C is not swallowed.
        pass

    try:
        spectra, dm_trials = prep_data(input_filename, 
                                       estimated_dm, 
                                       downsampling, 
                                       around_peak=True, 
                                       verbose=verbose)
    except IndexError:
        # The message used the undefined name `file`, so this fallback itself
        # crashed with NameError before it could retry.
        print ('Error with %s' % input_filename)
        print ()
        spectra, dm_trials = prep_data(input_filename, 
                                       estimated_dm, 
                                       downsampling, 
                                       around_peak=False, 
                                       verbose=verbose)

    return spectra, dm_trials, input_filename
        

# Load input data

In [None]:
# Main
verbose = True

# Go from one file to the other in the list in ascending (True) or descending (False) order
# `None` will reload the same file.
incr = False

n_bursts = len(df_R3)
step = 1 if incr else (0 if incr is None else -1)
# Keep the position inside [0, n_bursts) so repeated runs never index past the table.
current_input_id = (current_input_id + step) % n_bursts

# Skip bursts without a local filterbank, possibly several in a row. Previously at
# most one was skipped, and with incr=None an empty path was still loaded.
skip_step = step if step else 1
for _attempt in range(n_bursts):
    if df_R3.iloc[current_input_id]['file_location'] != '':
        break
    current_input_id = (current_input_id + skip_step) % n_bursts
else:
    raise FileNotFoundError('No burst in df_R3 has a local filterbank file')

# Printed after skipping, so it shows the burst that is actually loaded.
current_burst = df_R3.iloc[current_input_id]
print ('Current burst:')
print (current_burst)
print ()

# Global variables for interaction in next panel
spectra, dm_trials, filename = initialize(current_burst['file_location'], 
                                          current_burst['detection_dm'],
                                          current_burst['detection_downsampling'],
                                          verbose=verbose)
detection_mjd = current_burst['detection_mjd']
power_vs_dm, d_power_vs_dm = prep_power(verbose=verbose)


# Interactive visualisation

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10, 7))
gs = fig.add_gridspec(15, 6)


def _int_slider(low, high, value):
    """Unit-step IntSlider that only reports its value once the user releases it."""
    return widgets.IntSlider(min=low,
                             max=high,
                             step=1,
                             value=value,
                             continuous_update=False)


n_fluct = power_vs_dm.shape[0]
n_chan, n_samp = spectra.data.shape[0], spectra.data.shape[1]

interact(
    select_frequency_range,
    fluct_id_low=_int_slider(0, n_fluct, 0),
    fluct_id_high=_int_slider(0, n_fluct, n_fluct),
    freq_id_low=_int_slider(0, n_chan, 0),
    freq_id_high=_int_slider(1, n_chan, n_chan),
    t0=_int_slider(0, n_samp, 0),
    t1=_int_slider(1, n_samp, n_samp),
    ds_freq=_int_slider(1, 32, 1),
    ds_time=_int_slider(1, 32, 1),
    delta_dm=widgets.FloatSlider(min=-10,
                                 max=10,
                                 step=0.01,
                                 value=0,
                                 continuous_update=False),
    smooth=_int_slider(0, 4, 0),
)

# Buttons that save the current figure and the arrays behind it.
button = widgets.Button(description="Save figure")
button_raw = widgets.Button(description="Save raw figure")
display(button, button_raw)

def save_dm_to_df():
    """Record the DM last selected in the interactive panel on the current burst.

    Writes `struct_opt_dm` and `struct_opt_dm_err` to the df_R3 row(s) whose
    `detection_mjd` equals the global `detection_mjd`, then writes the table to
    arts_r3.csv.
    NOTE(review): the commented export elsewhere uses 'arts_R3.csv' -- confirm the name.
    """
    rows = df_R3['detection_mjd'] == detection_mjd
    df_R3.loc[rows, 'struct_opt_dm'] = struct_opt_dm
    df_R3.loc[rows, 'struct_opt_dm_err'] = struct_opt_dm_err

    df_R3.to_csv('arts_r3.csv', index=False)

def check_dir(folder):
    """Create `folder`, including missing parents, unless it already exists.

    `exist_ok=True` replaces the separate exists() check, which removes the race
    between checking for the directory and creating it.
    """
    os.makedirs(folder, exist_ok=True)
    
def _save_outputs(out_dir):
    """Save the figure and its arrays under `out_dir`, then record the DM in df_R3.

    Writes <out_dir>/<name>.png (300 dpi), <out_dir>/data/<name>_waterfall.npy and
    <out_dir>/data/<name>_fluctuation.npy. <name> is the basename of the global
    `filename` without '.fil'.
    NOTE(review): despite the .npy extension these are *text* files written by
    np.savetxt (read them with np.loadtxt, not np.load). Kept as-is for compatibility
    with files already on disk.
    """
    check_dir(out_dir)
    check_dir(os.path.join(out_dir, 'data'))

    name = os.path.basename(filename).split('.fil')[0]
    png_path = "%s/%s.png" % (out_dir, name)
    print ("saved to %s" % png_path)
    fig.savefig(png_path, dpi=300)

    with open("%s/data/%s_waterfall.npy" % (out_dir, name), 'wb') as f:
        np.savetxt(f, sub_waterfall)
    with open("%s/data/%s_fluctuation.npy" % (out_dir, name), 'wb') as f:
        np.savetxt(f, power_vs_dm)

    save_dm_to_df()


def save_figure(b):
    """Button callback: save outputs to images/manual_opt."""
    _save_outputs('images/manual_opt')


def save_figure_raw(b):
    """Button callback: save outputs to images/manual_opt/raw."""
    _save_outputs('images/manual_opt/raw')

# Connect each save button to its callback.
for _button, _callback in ((button, save_figure), (button_raw, save_figure_raw)):
    _button.on_click(_callback)


# Dev zone 

In [None]:
# Dev scratch: IPython help for ndarray.reshape (`?` is IPython-only syntax; remove when done).
spectra.data.reshape?

In [None]:
# Dev check: the Fortran-order reshape used by `subband(..., dim='time')` sums
# adjacent time samples in pairs here; expected output [[3, 7], [11, 15]].
arr = np.array([[1, 2, 3, 4],
                [5, 6, 7, 8]])

paired_sums = np.nansum(arr.reshape(2, 2, -1, order='f'), axis=1)
paired_sums
