In [1]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import os
import sys
import time
from multiprocessing import Process, Pool
import warnings
# Silence all warnings for the session.
# NOTE(review): this also hides deprecation warnings from numpy/matplotlib/sklearn.
warnings.simplefilter('ignore')

# Add project module paths to the front of sys.path so the local packages
# (epn_mining, PsrPopPy) can be imported.
module_paths = ['..',
                '../..',
                '../../extern/PsrPopPy',
               ]
for module_path in module_paths:
    # Check and insert the *same* absolute form. Previously the absolute path
    # was checked but the relative one inserted, so the membership test never
    # matched and every re-run of this cell added duplicate entries.
    abs_module_path = os.path.abspath(module_path)
    if abs_module_path not in sys.path:
        sys.path.insert(0, abs_module_path)
    
    
# For convenience
import numpy as np
import pickle

import pandas as pd

from epn_mining.main import load_states, save, load
from epn_mining.preparation import epn
from epn_mining.preparation.pulsar import Population, Observation, Pulsar, Model, Component
from epn_mining.topology import topology
from epn_mining.analysis.stats import (
    centroid as compute_centroid,
    profile_as_distribution,
    evaluate_DPGMM,
    convert_x_to_phase
)
from epn_mining.analysis import stats
from epn_mining.analysis.distance import (check_bound, check_min_max, check_neg)

from epn_mining.preparation.signal import (
    shift_max_to_center, 
    shift_centroid_to_center,
    rotate,
    best_alignment
)

from epn_mining.analysis import plotting

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits import axes_grid1
from matplotlib import rc
from matplotlib import style, collections as mc, colors, cm

# Reset matplotlib rcParams to the default style (one call is enough).
style.use('default')

import json
from sklearn import mixture
from scipy.stats import norm
from joblib import parallel_backend
from tqdm import tqdm

# from dtaidistance import dtw

from dtw import dtw

import astropy.coordinates as coord
import astropy.units as u
from astropy.coordinates import SkyCoord

import copy

# Set session variables
verbose = True
state_store = True

# Plot colours as RGBA tuples; ``blue`` is a translucent variant of ``blue_full``.
pink = (230/255, 29/255, 95/255, 1)
blue = (47/255, 161/255, 214/255, 0.2)
blue_full = (47/255, 161/255, 214/255, 1)

# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
# 3.9; ``pyplot.get_cmap`` returns the same registered colormap.
cmap = plt.get_cmap('cubehelix').reversed()

Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [None]:
# Build the EPN metadata table from scratch (alternative to loading the saved
# state in the next cell): Stokes I only, two reference codes excluded.
reference = None
stokes = 'I'  # 'IQUV' would request full polarisation
exclude_references = ['gl97', 'mhq97']
input_type = 'json'  # NOTE(review): currently not forwarded to load_epn_metadata
verbose = True

epn_metadata = epn.load_epn_metadata(
    base_path='../../www.epta.eu.org/epndb/json',
    reference=reference,
    exclude_references=exclude_references,
    stokes=stokes,
    verbose=verbose,
)

In [None]:
# Load the previously saved EPN metadata state. This overwrites the table built
# in the cell above, so run one or the other.
state_prefix='paper_stokes_I'
epn_metadata = load('epn_metadata', state_prefix=state_prefix)

In [None]:
# B-name of PSR J0437-4715 from the metadata table (first matching row).
epn_metadata.loc[epn_metadata['jname'] == 'J0437-4715', 'bname'].values[0]

In [None]:
# Display the metadata table.
epn_metadata

In [None]:
# Load every profile listed in epn_metadata into a Population, with
# normalisation and resizing enabled and shifting / baseline removal disabled.
normalize, shift, resize, remove_baseline = True, False, True, False

population, epn_metadata = epn.load_epn_data(
    epn_metadata,
    shift=shift,
    normalize=normalize,
    remove_baseline=remove_baseline,
    resize=resize,
    verbose=verbose,
)

In [None]:
# Number of pulsars currently held in the population.
population.as_array().size
# for pulsar in population.as_array():
    # print (pulsar.jname, pulsar.observations.keys())

In [None]:
# Keep only pulsars observed in every frequency bin of interest, and drop any
# whose S/N is below ``min_snr`` in at least one of those bins.
min_snr = 20
freqs_to_include = [2, 3, 4, 5]
ref_to_fix = ['wcl+99']  # only used by the commented-out Stokes V fix below

for pulsar in population.as_array():
    available = pulsar.observations.keys()
    if not all(f in available for f in freqs_to_include):
        del population.pulsars[pulsar.jname]
        continue
    if any(pulsar.observations[f].snr < min_snr for f in freqs_to_include):
        if pulsar.jname in population.pulsars.keys():
            del population.pulsars[pulsar.jname]
                    
# Add meta to population
# for pulsar in population.as_array():
#     name_cond = epn_metadata['jname'] == pulsar.jname
#     pulsar.period = epn_metadata.loc[name_cond, 'P0'.lower()].values[0]
#     pulsar.period_derivative = epn_metadata.loc[name_cond, 'P1'.lower()].values[0]
#     pulsar.spindown_energy = epn_metadata.loc[name_cond, 'EDOT'.lower()].values[0]
#     pulsar.bsurf = epn_metadata.loc[name_cond, 'BSURF'.lower()].values[0]
#     pulsar.w10 = epn_metadata.loc[name_cond, 'W10'.lower()].values[0]
#     pulsar.raj = epn_metadata.loc[name_cond, 'RAJ'.lower()].values[0]
#     pulsar.decj = epn_metadata.loc[name_cond, 'DECJ'.lower()].values[0]
#     pulsar.gl = epn_metadata.loc[name_cond, 'GL'.lower()].values[0]
#     pulsar.gb = epn_metadata.loc[name_cond, 'GB'.lower()].values[0]

# Invert Stokes V where necessary
# for pulsar in population.as_array():
#     for f in freqs_to_include:
#         if pulsar.observations[f].epn_reference_code in ref_to_fix:
#             pulsar.observations[f].stokes_V = -pulsar.observations[f].stokes_V
#         if pulsar.jname in ['J0332+5434', 'J1239+2453'] and f == 4 and pulsar.observations[f].epn_reference_code == 'hx97b':
#             pulsar.observations[f].stokes_V = -pulsar.observations[f].stokes_V
#         if pulsar.jname == 'J0826+2637' and f == 5 and pulsar.observations[f].epn_reference_code == 'hx97b':
#             pulsar.observations[f].stokes_V = -pulsar.observations[f].stokes_V
            
# Align J1857+0943 freq 5 to other bins
def fix(pulsar, f_ref, f):
    """Phase-align the Stokes I profile of bin ``f`` onto that of bin ``f_ref``.

    The shift from ``best_alignment`` (reference profile first) is applied to
    ``pulsar.observations[f].stokes_I`` with ``rotate``. The centroid and FWHM
    of that observation are then recomputed.

    Only Stokes I is rotated. The Q/U/V/L and position-angle arrays of the
    observation are not shifted.

    Returns the same ``pulsar`` object, which is modified in place.
    """
    reference_profile = pulsar.observations[f_ref].stokes_I
    target = pulsar.observations[f]
    offset = best_alignment(reference_profile, target.stokes_I)
    target.stokes_I = rotate(target.stokes_I, offset)
    target.set_centroid()
    target.set_fwhm()
    return pulsar


# jname = 'J1857+0943'
# pulsar = population.pulsars[jname]
# ref = 4
# f = 5
# pulsar = fix(pulsar, ref, f)

# jname = 'J1803-2137'
# for f in [2, 5]:
#     pulsar = population.pulsars[jname]
#     fix(pulsar, ref, f)

In [None]:
# print (population.as_array().size)
# for pulsar in population.as_array():
#     for f in freqs_to_include:
#         try:
#             print (pulsar.observations[f].model_agd)
#         except:
#             print ('nope.')
#     print ()

# Find which pulsars still need fitting

In [None]:
# state_prefix = 'paper'
# epn_metadata = load('epn_metadata', state_prefix=state_prefix, folder='../states/')
# Load an earlier saved population (presumably holding a first AGD fitting run)
# from ../states/; the next cell copies its pulsars into ``population``.
# NOTE(review): uses whatever ``state_prefix`` is currently bound in the kernel
# ('paper_stokes_I' when cells run in order) — confirm this is intended.
_population = load('population_agd_firstgo_0_001__1__5', state_prefix=state_prefix, folder='../states/')

In [None]:
# Reuse fits from the earlier saved population: pulsars present there are
# replaced by a deep copy of that version; the remaining names are collected
# in ``to_fit`` as still needing a fit.
to_fit = []
for p in population.pulsars.keys():
    # Test membership on the mapping directly (O(1)) instead of rebuilding
    # ``list(_population.pulsars.keys())`` on every iteration (O(n) each).
    if p in _population.pulsars:
        population.pulsars[p] = copy.deepcopy(_population.pulsars[p])
    else:
        to_fit.append(p)

to_fit

In [None]:
# Save current state with previous fits
# NOTE(review): overwrites the 'population' state stored under ``state_prefix``.
save('population', population, state_prefix=state_prefix)

In [None]:
# Gausspy related (to be incorporated to epn_mining later)

import gausspy.gp as gp

from epn_mining.analysis.stats import robust_statistics, median_of_medians, median_of_stdevs, snr
from epn_mining.utils.io import state_full_location, set_state_name

def gaussian(amp, fwhm, mean):
    """Return a callable Gaussian parameterised by amplitude, FWHM and mean.

    Uses FWHM = 2*sqrt(2 ln 2)*sigma, i.e.
    f(x) = amp * exp(-4 ln2 (x - mean)^2 / fwhm^2). The returned function
    accepts scalars or numpy arrays.
    """
    def _profile(x):
        offset = x - mean
        exponent = -4. * np.log(2) * offset**2 / fwhm**2
        return amp * np.exp(exponent)
    return _profile
    
def unravel(list):
    """Flatten one level of nesting (a sequence of sequences) into a 1-D numpy array."""
    # NOTE: the parameter name shadows the builtin ``list`` but is kept for callers.
    flat = []
    for inner in list:
        flat.extend(inner)
    return np.array(flat)

def _set(x, y, y_err, data: dict = None):
    """Append one (x, y, y_err) triple to a gausspy-style data dictionary.

    Parameters
    ----------
    x, y, y_err
        Values appended to the ``'x_values'``, ``'data_list'`` and ``'errors'``
        lists respectively.
    data : dict, optional
        Dictionary to extend. It is updated in place and also returned. When
        omitted, a new dictionary is created for this call.

    Returns
    -------
    dict
        ``data`` with the three lists extended by one entry each.
    """
    # The previous default ``data={}`` was a single dict shared by all calls
    # and mutated below, so entries leaked between unrelated calls.
    if data is None:
        data = {}
    data['data_list'] = data.get('data_list', []) + [y]
    data['x_values'] = data.get('x_values', []) + [x]
    data['errors'] = data.get('errors', []) + [y_err]
    return data

def set_n_save_data(obs:Observation, data=None, variable='data', state_prefix='',
                    verbose=False):
    """Store an observation's Stokes I profile as gausspy input and save it.

    Appends ``obs.phase`` / ``obs.stokes_I`` to ``data``, with a constant
    per-channel error equal to ``median_of_stdevs(obs.stokes_I)``. The dict is
    then written with ``save(variable, ...)``.

    Parameters
    ----------
    obs : Observation
        Observation providing ``phase`` and ``stokes_I``.
    data : dict, optional
        Existing gausspy data dict to extend. When omitted, a new dict is used,
        so the saved file holds only this observation.
    variable, state_prefix, verbose
        Forwarded to ``save``.
    """
    # The previous mutable default ``data={}`` was shared between calls and
    # extended in place, so each call re-saved every earlier observation too
    # (and presumably the batch decomposition then fitted all of them).
    if data is None:
        data = {}
    data = _set(
        data=data,
        x=obs.phase,
        y=obs.stokes_I,
        y_err=np.ones(obs.phase.size) * median_of_stdevs(obs.stokes_I),
    )

    save(variable, data, state_prefix=state_prefix, verbose=verbose)

def autonomous_gaussian_decomposition(state_prefix='',
                                      variable='observation',
                                      alpha1=1.,
                                      alpha2=None,
                                      snr_thresh=5.,
                                      return_data=True,
                                      train=False,
                                      verbose=False):
    """Run gausspy's Autonomous Gaussian Decomposition on a saved data file.

    Parameters
    ----------
    state_prefix, variable
        Locate the input file via ``state_full_location(state_prefix, variable)``.
    alpha1 : float
        First smoothing parameter.
    alpha2 : float or None
        Second smoothing parameter. ``None`` selects single-phase AGD; any
        value selects two-phase AGD.
    snr_thresh : float
        Used for both entries of gausspy's ``SNR_thresh`` setting.
    return_data, train, verbose
        Currently unused; kept for interface compatibility.

    Returns
    -------
    The result of ``GaussianDecomposer.batch_decomposition``.
    """
    two_phase = alpha2 is not None

    decomposer = gp.GaussianDecomposer()
    decomposer.set('phase', 'two' if two_phase else 'one')
    decomposer.set('SNR_thresh', [snr_thresh, snr_thresh])
    decomposer.set('alpha1', alpha1)
    if two_phase:
        decomposer.set('alpha2', alpha2)

    input_path = state_full_location(state_prefix, variable)
    return decomposer.batch_decomposition(science_data_path=input_path)

def components_arrays(obs):
    """Evaluate each AGD-fitted Gaussian component of ``obs`` on ``obs.phase``.

    Reads ``means_fit``, ``fwhms_fit`` and ``amplitudes_fit`` from
    ``obs.model_agd``, flattening each with ``unravel``, and returns one array
    per component (pairs beyond the shortest list are ignored, as with ``zip``).
    """
    fit = obs.model_agd
    means = unravel(fit['means_fit'])
    widths = unravel(fit['fwhms_fit'])
    amplitudes = unravel(fit['amplitudes_fit'])

    components = []
    for mu, width, amp in zip(means, widths, amplitudes):
        components.append(gaussian(amp, width, mu)(obs.phase))
    return components

def model_array(obs):
    """Sum the AGD Gaussian components of ``obs`` and min-max scale to [0, 1].

    Parameters
    ----------
    obs : Observation
        Needs ``phase`` and an AGD result in ``model_agd`` (as read by
        ``components_arrays``).

    Returns
    -------
    numpy.ndarray
        The summed model on ``obs.phase``, rescaled so its minimum is 0 and its
        maximum is 1. If no component was fitted, or the summed model is
        constant, an all-zero array is returned instead of failing.
    """
    components = components_arrays(obs)
    if not components:
        # Previously ``model`` was never bound here and the function raised
        # UnboundLocalError for observations with zero fitted components.
        return np.zeros(np.shape(obs.phase))

    model = components[0]
    for component in components[1:]:
        model = model + component

    low = model.min()
    span = model.max() - low
    if span == 0:
        # A constant model would otherwise give 0/0 -> NaN everywhere.
        return model - low
    return (model - low) / span


In [None]:
# Run gausspy AGD on every frequency bin of the pulsars that still need a fit.
#
# NOTE(review): ``population_agd`` is only loaded further down this notebook,
# so this cell depends on kernel state. ``to_fit`` (computed above) appears to
# hold the same list — confirm and switch to it for Restart & Run All.

freqs_to_include = [2, 3, 4, 5]
verbose=False

# True: choose alpha1/alpha2 per observation by grid search on the best-fit
# reduced chi^2. False: use fixed ``alpha1``/``alpha2`` (uncomment them below).
compute_grid = True

# Fitted alpha1, alpha2
# alpha1, alpha2 = -0.4, 0.73
snr_thresh=5.
# Candidate alpha2 values scanned by each pass of the grid search.
alpha2_candidates = np.arange(0.001, 1., .05)


def _agd_alpha_grid(alpha1, variable):
    """Decompose the saved observation once per candidate alpha2.

    Returns a dict mapping ``f"{alpha2:.3f}"`` to the decomposition result.
    """
    return {
        f"{alpha:.3f}": autonomous_gaussian_decomposition(state_prefix=state_prefix,
                                                          variable=variable,
                                                          alpha1=alpha1,
                                                          alpha2=alpha,
                                                          snr_thresh=snr_thresh,
                                                          verbose=False)
        for alpha in tqdm(alpha2_candidates)
    }


def _best_alpha_key(alpha_grid):
    """Return the grid key whose fit has the lowest best-fit reduced chi^2."""
    keys = list(alpha_grid.keys())
    rchi2 = [alpha_grid[a]['best_fit_rchi2'][0][0] for a in keys]
    return keys[np.argmin(rchi2)]


for jname in tqdm(population_agd.fitted_later):
    pulsar = population.pulsars[jname]
    for f in freqs_to_include:
        tqdm.write(f"{jname}: frequency bin {f}")
        obs = pulsar.observations[f]
        variable = 'observation'
        # Pass a fresh dict so the saved gausspy input holds only this profile
        # (the helper's default dict is shared between calls).
        set_n_save_data(obs,
                        data={},
                        variable=variable,
                        state_prefix=state_prefix,
                        verbose=verbose)

        if not compute_grid:
            obs.model_agd = autonomous_gaussian_decomposition(state_prefix=state_prefix,
                                                              variable=variable,
                                                              alpha1=alpha1,
                                                              alpha2=alpha2,
                                                              snr_thresh=snr_thresh,
                                                              verbose=False)
            obs.model_agd['alpha1'] = alpha1
            obs.model_agd['alpha2'] = alpha2
        else:
            # Pass 1: alpha1 fixed at 0.7, scan alpha2.
            # NOTE(review): the best *alpha2* of pass 1 is reused as alpha1 for
            # pass 2 (kept from the original) — confirm this is intended.
            alpha1 = float(_best_alpha_key(_agd_alpha_grid(0.7, variable)))
            # Pass 2: alpha1 fixed at that value, scan alpha2, keep the best fit.
            alpha_grid = _agd_alpha_grid(alpha1, variable)
            best_alpha2 = _best_alpha_key(alpha_grid)
            obs.model_agd = alpha_grid[best_alpha2]
            # Record the selected alphas, as the fixed-alpha branch does.
            obs.model_agd['alpha1'] = alpha1
            obs.model_agd['alpha2'] = float(best_alpha2)

save('population_agd', population, state_prefix=state_prefix)

In [None]:
# NOTE(review): ``.s`` looks like an unfinished attribute name (``.snr``?);
# unless Observation defines ``s`` this raises AttributeError.
population.pulsars['J1713+0747'].observations[2].s

In [None]:
# DPGMM model fitting of Stokes I for every pulsar and frequency bin. Every bin
# receives, as ``fwhm``, the FWHM of the pulsar's highest-S/N bin.
for pulsar in tqdm(population.as_array()):
    snrs = np.array([[pulsar.observations[f].snr, f] for f in freqs_to_include])
    descending = np.argsort(snrs[:, 0])[::-1]
    freqs_snr_sorted = snrs[descending, 1]
    brightest = freqs_snr_sorted[0]
    for f in freqs_to_include:
        # The reference FWHM is read per bin (as before), in case set_model
        # with override=True updates the brightest observation.
        pulsar.observations[f].set_model(
            'stokes_I',
            alpha=10**4,
            threshold=True,
            n_components=30,
            cut=False,
            scale=True,
            fwhm=pulsar.observations[brightest].fwhm,
            override=True,
        )

save('population', population, state_prefix=state_prefix)

In [2]:
# Reload the saved population under the 'paper' prefix, replacing the
# in-memory one.
state_prefix='paper'
population = load('population', state_prefix=state_prefix)

In [None]:
# Save the population back under the current state prefix.
save('population', population, state_prefix=state_prefix)

In [None]:
with_interpulse = ['J1705-1906', 'J1825-0935', 'J1857+0943', 'J1932+1059', 'J1939+2134', 'J0534+2200']

to_be_refitted = ['J2145-0750']

# Plot individual pulsars. For each frequency bin, the top row shows
# ``obs.model`` (blue) over the observed Stokes I (black, dotted); the bottom
# row shows sqrt((model - I)^2), i.e. the absolute residual.

blue_full = (47/255, 161/255, 214/255, 1)
freqs_to_include = [2,3,4,5]
figure_dir = 'images/all_freq_bins_stokes_I'
# savefig does not create missing directories.
os.makedirs(figure_dir, exist_ok=True)

# for pulsar in tqdm(population.as_array()):
# for jname in with_interpulse:
for jname in to_be_refitted:
    pulsar = population.pulsars[jname]

    fig, _ax = plt.subplots(2, 4, figsize=(15, 5), sharex=True)

    for i, f in enumerate(freqs_to_include):
        if f not in pulsar.observations:
            continue
        obs = pulsar.observations[f]

        ax = _ax[0, i]
        ax.plot(obs.phase, obs.model, color=blue_full)
        ax.plot(obs.phase, obs.stokes_I, color='black', linestyle=':', zorder=1000)
        ax.set_title(f"{f} {obs.frequency:.0f}MHz ({obs.epn_reference_code}) S/N:{obs.snr:.0f}")

        ax = _ax[1, i]
        ax.plot(obs.phase, np.sqrt((obs.model - obs.stokes_I)**2), color='black')

    # Labels are loop-invariant, so set them once per figure. Raw strings avoid
    # the invalid ``\p`` escape of the former f'$\phi$' literal.
    _ax[0, 0].set_ylabel('$I$ (arb. unit)')
    _ax[1, 0].set_ylabel(r"$\sqrt{(x-\overline{x})^2}$")
    for ax in _ax[1, :]:
        ax.set_xlabel(r'$\phi$')

    plt.suptitle(f"{pulsar.jname} {pulsar.bname}", fontsize='x-large')
    plt.tight_layout()
    plt.savefig(f'{figure_dir}/{pulsar.jname}.pdf')

In [None]:
# Plot individual pulsars after the AGD re-fit. The top row shows the
# normalised AGD model over the observed Stokes I; the bottom row shows
# model - data, annotated with the best-fit reduced chi^2.
# NOTE(review): relies on ``population_agd`` already being loaded in the kernel,
# and writes to the same PDF paths as the plotting cell above.

blue_full = (47/255, 161/255, 214/255, 1)
figure_dir = 'images/all_freq_bins_stokes_I'
# savefig does not create missing directories.
os.makedirs(figure_dir, exist_ok=True)

for jname in tqdm(population_agd.fitted_later):
    pulsar = population.pulsars[jname]

    fig, _ax = plt.subplots(2, 4, figsize=(15, 5))

    for i, f in enumerate(freqs_to_include):
        if f not in pulsar.observations:
            continue
        obs = pulsar.observations[f]
        # model_array only reads obs, so no defensive deepcopy is needed, and
        # the model is evaluated once rather than twice.
        model = model_array(obs)

        ax = _ax[0, i]
        ax.plot(obs.phase, model, color=blue_full)
        ax.plot(obs.phase, obs.stokes_I, color='black', linestyle=':', zorder=1000)
        ax.set_title(f"{f} {obs.frequency:.0f}MHz ({obs.epn_reference_code}) S/N:{obs.snr:.0f}")

        ax = _ax[1, i]
        ax.plot(obs.phase, model - obs.stokes_I, color='black')
        ax.annotate(f"{obs.model_agd['best_fit_rchi2'][0][0]:.2f}", [0.1, 0.9], xycoords='axes fraction')

    plt.suptitle(f"{pulsar.jname} {pulsar.bname}", fontsize='x-large')
    plt.tight_layout()
    plt.savefig(f'{figure_dir}/{pulsar.jname}.pdf')

In [15]:
# save('epn_metadata', epn_metadata, state_prefix=state_prefix)
# print (state_prefix)
# Persist the in-memory population under the current state prefix.
save('population', population, state_prefix=state_prefix)

In [None]:
def fix(pulsar, f_ref, f):
    """Phase-align the Stokes I profile of bin ``f`` onto that of bin ``f_ref``.

    NOTE: same body as the ``fix`` defined in the population-filtering cell;
    it is redefined here so this cell can run on its own.

    The shift from ``best_alignment`` (reference profile first) is applied to
    ``pulsar.observations[f].stokes_I`` with ``rotate``. The centroid and FWHM
    of that observation are then recomputed. Only Stokes I is rotated.

    Returns the same ``pulsar`` object, which is modified in place.
    """
    reference_profile = pulsar.observations[f_ref].stokes_I
    target = pulsar.observations[f]
    offset = best_alignment(reference_profile, target.stokes_I)
    target.stokes_I = rotate(target.stokes_I, offset)
    target.set_centroid()
    target.set_fwhm()
    return pulsar


# Re-align frequency bin 2 of J2145-0750 onto bin 4. ``fix`` mutates the pulsar
# in place and returns it, so the reassignment is only for readability.
jname = 'J2145-0750'
population.pulsars[jname] = fix(population.pulsars[jname], 4, 2)

In [None]:
# Reload the saved population; unsaved in-memory changes are discarded.
population = load('population', state_prefix=state_prefix)

In [None]:
# Record on the population which pulsars had no earlier fit (``to_fit``).
population.fitted_later = to_fit

In [None]:
# Display the list of pulsars fitted in the later run.
population.fitted_later

In [None]:
# Best-fit reduced chi^2 of the AGD model: one row per pulsar, one column per
# frequency bin in ``freqs_to_include``.
def _rchi2_row(psr):
    return [psr.observations[f].model_agd['best_fit_rchi2'][0][0] for f in freqs_to_include]

chi_squares = np.array([_rchi2_row(psr) for psr in population.as_array()])

In [None]:
# Spot check: best-fit reduced chi^2 of the first pulsar's bin-2 AGD fit.
population.as_array()[0].observations[2].model_agd['best_fit_rchi2'][0][0]

In [None]:
# One log-scale panel per frequency bin (shared axes): reduced chi^2 per pulsar.
fig, ax = plt.subplots(4, 1, sharey=True, sharex=True)
for i, panel in enumerate(ax):
    panel.plot(chi_squares.T[i])
    panel.set_yscale('log')

In [None]:
# Same figure, with the chi^2 table recomputed from the current population.
fig, ax = plt.subplots(4, 1, sharey=True, sharex=True)
chi_s = np.array([
    [obs.model_agd['best_fit_rchi2'][0][0]
     for obs in (psr.observations[f] for f in freqs_to_include)]
    for psr in population.as_array()
])
for i in range(4):
    column = chi_s[:, i]
    ax[i].plot(column)
    ax[i].set_yscale('log')

In [None]:
# Per pulsar: summed difference between the normalised AGD models of bins 2 and 3.
for pulsar in population.as_array():
    model_bin2 = model_array(pulsar.observations[2])
    model_bin3 = model_array(pulsar.observations[3])
    print(np.sum(model_bin2 - model_bin3))

In [None]:
# NOTE(review): ``obs`` is whichever observation an earlier loop left bound in
# the kernel; this cell fails on a fresh kernel.
obs.model_agd['means_fit']

In [None]:
# Reload both the AGD-fitted population and the main population ('paper' prefix).
state_prefix='paper'
population_agd = load('population_agd', state_prefix=state_prefix)
population = load('population', state_prefix=state_prefix)

In [None]:
# For the first re-fitted pulsar: one panel per frequency bin showing the
# observed Stokes I (dashed) and the normalised model from ``population``, plus
# a final panel overlaying all bin models.
p = population_agd.fitted_later[0]

fig, ax = plt.subplots(1, len(freqs_to_include) + 1, figsize=(10, 3))

for i, f in enumerate(freqs_to_include):
    obs = population_agd.pulsars[p].observations[f]
    ax[i].plot(obs.phase, obs.stokes_I, c='orange', linestyle='--', zorder=1000)
    ax[i].plot(obs.phase, model_array(population.pulsars[p].observations[f]))

# Plot each model against its own phase grid; previously the phase of the last
# bin from the loop above was reused for every bin.
for f in freqs_to_include:
    model_obs = population.pulsars[p].observations[f]
    ax[-1].plot(model_obs.phase, model_array(model_obs), label=f)
ax[-1].legend()
    

In [None]:
# Raw AGD fit result for bin 2 of the first re-fitted pulsar.
p = population_agd.fitted_later[0]
population_agd.pulsars[p].observations[2].model_agd

In [None]:
# Display all pulsars in the population.
population.as_array()

In [None]:
# Display the pulsars recorded as fitted in the later AGD run.
population_agd.fitted_later

In [None]:
# Reload the AGD-fitted population under the current state prefix.
population_agd = load('population_agd', state_prefix=state_prefix)

In [12]:
# Load the saved EPN metadata and Rankin's morphological classification table;
# missing Class/Code entries are replaced by the string 'N/A'.
epn_metadata = load('epn_metadata', state_prefix=state_prefix)

df_rankin = pd.read_csv('../rankin-classification.csv').fillna({'Class': 'N/A', 'Code': 'N/A'})

In [13]:
# Attach the B-name (from the EPN metadata) and Rankin's morphological class
# and code to every pulsar. The lookups are independent: previously a pulsar
# missing from epn_metadata raised IndexError before the Rankin lookup ran,
# so it was labelled 'N/A' even when Rankin classified it.
for p in population.as_array():
    bnames = epn_metadata.loc[epn_metadata['jname'] == p.jname, 'bname'].values
    if len(bnames) > 0:
        p.bname = bnames[0]

    rankin = df_rankin.loc[df_rankin['JNAME'] == p.jname]
    if len(rankin) > 0:
        p.morphological_class = rankin['Class'].values[0]
        p.morphological_code = rankin['Code'].values[0]
    else:
        p.morphological_class = 'N/A'
        p.morphological_code = 'N/A'


In [14]:
# Distinct Rankin morphological classes ('N/A' marks missing entries).
df_rankin['Class'].unique()

array(['Conal single', 'Core single', 'Conal double', 'Triple',
       'Multiple', 'N/A'], dtype=object)

In [None]:
# Display the Rankin classification table.
df_rankin

In [None]:
# Persist the population now that B-names and Rankin classes are attached.
# (Was a bare ``save()``; every other call in this notebook passes the state
# name, the object and the state prefix.)
save('population', population, state_prefix=state_prefix)