In [None]:
%matplotlib widget

import matplotlib
matplotlib.rcParams['axes.formatter.useoffset'] = False
import matplotlib.gridspec as gspec
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from scipy.optimize import newton
from functools import partial
from glob import glob
from os import path
from scipy.special import erf
from time import time
from typing import Dict

In [None]:
# Colours cycled through when plotting successive clusters
colours = ["darkred", "dodgerblue", "forestgreen", "darkorange", "black", "purple", "grey", "gold", "crimson"]

# Observing band, specified in Hz and converted to MHz
bandwidth_hz = 856000000.0
centre_freq_hz = 1284000000.0
bw_mhz = bandwidth_hz / 1e+06
cfreq_mhz = centre_freq_hz / 1e+06

half_bw_mhz = bw_mhz / 2
ftop = cfreq_mhz + half_bw_mhz
fbottom = cfreq_mhz - half_bw_mhz

# Dispersion delay across the whole band per unit DM, in s:
# multiply by a DM value to get the delay between the band edges
inv_sq_bottom = 1.0 / (fbottom * fbottom)
inv_sq_top = 1.0 / (ftop * ftop)
disp_const = 4.15e+03 * (inv_sq_bottom - inv_sq_top)

print(f"Bandwidth: {bw_mhz}MHz")
print(f"Centre frequency: {cfreq_mhz}MHz")
print(f"Bottom frequency: {fbottom}MHz")
print(f"Top frequency: {ftop}MHz")

# Pulsar whose candidate files are analysed below
pulsars = ["j0835-4510", "j0901-4046", "j1326-6408"]
pulsar = pulsars[0]

spccl_pattern = path.join("pulsars", pulsar, "tpn-0-*/2022*/beam*/Plots/used*")
spccl_files = sorted(glob(spccl_pattern))
print(len(spccl_files))

In [None]:
# Column positions read from the whitespace-separated SPCCL files, and the names given to them
columns_used = [1, 2, 3, 4, 5, 7, 8]
columns_names = ["MJD", "DM", "WIDTH", "SNR", "BEAM", "RA", "DEC"]

# Fail early with a clear message instead of pd.concat's "No objects to concatenate"
if not spccl_files:
    raise FileNotFoundError(f"No SPCCL files found for {pulsar} - check the glob pattern and data directory")

# Raw string for the regex: "\s" in a plain string literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The first line of each file is skipped (skiprows=1).
csv_part = partial(pd.read_csv, header=None, skiprows=1, delimiter=r"\s+", usecols=columns_used, names=columns_names)
test_data_full = pd.concat(map(csv_part, spccl_files)).sort_values(by=["MJD"], ignore_index=True)

In [None]:
# Close all figures left open by previous runs (pyplot keeps them alive until closed);
# the original loop closed at most 20 of them
plt.close("all")

fig, ax = plt.subplots(figsize=(10, 6))
# Marker size follows the candidate WIDTH, colour follows its SNR
sc = ax.scatter(test_data_full["MJD"], test_data_full["DM"], s=test_data_full["WIDTH"], c=test_data_full["SNR"])
ax.set_xlabel("MJD")
ax.set_ylabel("DM")
ax.set_title(f"All candidates for {pulsar} (marker size: WIDTH)")
fig.colorbar(sc, ax=ax, label="SNR")

In [None]:
## Full clustering test
# Only candidates above this DM are clustered
min_dm = 100.0
test_data_partial = test_data_full[test_data_full["DM"] > min_dm]
full_data = test_data_partial.to_numpy()
# Prepend each row's position in full_data; the clustering loop uses it to write results back
full_data = np.insert(full_data, 0, np.arange(full_data.shape[0]), axis=1)

# S/N threshold that determines how far in DM a candidate's box extends
sigma_limit = 6.0
# Number of trial DM offsets evaluated per candidate
n_dm_trials = 1024
delta_dms = np.zeros(full_data.shape[0])

for idx in np.arange(delta_dms.shape[0]):
    dm, width, snr = full_data[idx, 2], full_data[idx, 3], full_data[idx, 4]

    # Trial DM offsets spanning [-DM/2, DM/2]
    delta_dm = np.linspace(0, dm, n_dm_trials) - dm / 2
    # DM-smearing relation (Cordes & McLaughlin 2003 form): zeta uses the bandwidth in MHz,
    # the width (in ms) and the centre frequency in GHz; sigma_smear is the S/N expected
    # when the candidate is dedispersed with a DM offset of delta_dm
    zeta = 6.91e-03 * delta_dm * bw_mhz / width / (cfreq_mhz / 1000.0)**3
    sigma_smear = snr * np.sqrt(np.pi) / 2 / zeta * erf(zeta)

    # sigma_smear is symmetric in delta_dm and the grid starts at -DM/2, so the first
    # offset that stays above the limit gives the largest DM tolerance for this candidate.
    # Candidates whose S/N never reaches the limit get zero tolerance instead of
    # crashing with an IndexError.
    above_limit = np.flatnonzero(sigma_smear >= sigma_limit)
    delta_dms[idx] = np.abs(delta_dm[above_limit[0]]) if above_limit.size else 0.0

# Resulting column layout:
#  0 row position, 1 MJD, 2 DM, 3 WIDTH, 4 SNR, 5 DM tolerance (delta_dms),
#  6 clustered flag (starts False), 7 BEAM, 8 RA, 9 DEC, 10 cluster id (starts at -1)
full_data = np.insert(full_data, 5, delta_dms, axis=1)
full_data = np.insert(full_data, 6, np.zeros(full_data.shape[0]).astype(bool), axis=1)
full_data = np.append(full_data, np.zeros((full_data.shape[0], 1)).astype(int) - 1, axis=1)
print(full_data.shape)
# Guard against an empty selection (no candidates above min_dm)
if full_data.shape[0] > 0:
    print(full_data[0, :])
full_data = full_data.astype(object)

In [None]:
# Only unclustered candidates at most width_pad_s seconds after the oldest unclustered
# candidate are considered when building each cluster
width_pad_s = 2.0
width_pad = width_pad_s / 86400.0  # seconds -> days, to match the MJD column

# Clustering modes - they define when candidates A and B end up in the same cluster.
# A candidate's "box" spans +/- its DM tolerance in DM and +/- half its width in time.
#   strict  - A has to have B in its box AND B has to have A in its box
#   relaxed - A has to have B in its box, but B does not have to have A in its box
#   weak    - the boxes of A and B only have to overlap; neither candidate has to
#             lie inside the other's box
cluster_modes = ["strict", "relaxed", "weak"]
cluster_mode = "weak"
assert cluster_mode in cluster_modes, f"Unknown clustering mode: {cluster_mode}"

def _half_width_days(width_ms):
    """Return half of a width (or sum of widths) given in ms, expressed in days."""
    return width_ms / 2.0 / 1000.0 / 86400


def cluster_point(point, cluster_data, mode=None):
    """Flood-fill the cluster that contains ``point``, marking its members in place.

    Starting from ``point``, every not-yet-clustered row of ``cluster_data`` that is a
    neighbour (according to ``mode``) of a cluster member is added to the cluster,
    repeatedly, until no new neighbours are found.

    Parameters
    ----------
    point : array-like
        Seed candidate row, with the same column layout as ``cluster_data``.
    cluster_data : np.ndarray
        2D array of candidate rows, modified in place. Columns used:
        1 = MJD (days), 2 = DM, 3 = width (ms; converted to days via /1000/86400),
        5 = DM tolerance, 6 = clustered flag.
    mode : str, optional
        "strict", "relaxed" or "weak". Defaults to the global ``cluster_mode``.

    Returns
    -------
    None
        Membership is reported by setting column 6 of ``cluster_data`` to True.

    Raises
    ------
    ValueError
        If ``mode`` is not a recognised clustering mode.
    """
    if mode is None:
        mode = cluster_mode
    if mode not in ("strict", "relaxed", "weak"):
        raise ValueError(f"Unknown clustering mode: {mode!r}")

    # Explicit stack instead of recursion: a long chain of neighbouring candidates
    # would otherwise exceed Python's recursion limit (RecursionError). Every row is
    # pushed exactly once (when it gets marked), so the marked set is the same as
    # with the recursive traversal.
    to_visit = [point]
    while to_visit:
        current = to_visit.pop()
        dm_sep = np.abs(cluster_data[:, 2] - current[2])
        time_sep = np.abs(cluster_data[:, 1] - current[1])

        if mode == "relaxed":
            # B must lie inside A's box
            mask_neighbour = np.logical_and(dm_sep <= current[5],
                                            time_sep <= _half_width_days(current[3]))
        elif mode == "strict":
            # A and B must each lie inside the other's box
            dm_mask = np.logical_and(dm_sep <= current[5], dm_sep <= cluster_data[:, 5])
            time_mask = np.logical_and(time_sep <= _half_width_days(current[3]),
                                       time_sep <= _half_width_days(cluster_data[:, 3]))
            mask_neighbour = np.logical_and(dm_mask, time_mask)
        else:  # "weak": the two boxes only have to overlap
            dm_mask = dm_sep <= (cluster_data[:, 5] + current[5])
            time_mask = time_sep <= _half_width_days(cluster_data[:, 3] + current[3])
            mask_neighbour = np.logical_and(dm_mask, time_mask)

        # Only candidates not yet assigned to the cluster can be added
        unclustered = np.logical_not(cluster_data[:, 6].astype(bool))
        mask = np.logical_and(mask_neighbour, unclustered).astype(bool)
        if not mask.any():
            continue

        cluster_data[mask, 6] = True
        to_visit.extend(cluster_data[mask])
    
# Figure: all clusters in MJD-DM space on top, and the DM-SNR distribution of the
# first n_dist_rows * n_dist_cols clusters in a grid of small panels below
n_dist_rows, n_dist_cols = 8, 8

fig = plt.figure(figsize=(8, 16), tight_layout=True)

gs_main = gspec.GridSpec(2, 1, figure=fig)

ax_clusters = fig.add_subplot(gs_main[0, 0])
ax_clusters.set_xlabel("MJD")
ax_clusters.set_ylabel("DM")
ax_clusters.set_title(f"Clusters ({cluster_mode} mode)")

gs_dist = gspec.GridSpecFromSubplotSpec(n_dist_rows, n_dist_cols, subplot_spec=gs_main[1, 0])
axes_dist = gs_dist.subplots()

iteration = 0
start_time = time()

while True:
    # Column 6 is the clustered flag. Rows are sorted by MJD, so the first
    # unclustered row is the oldest remaining candidate.
    unclustered_mask = np.logical_not(full_data[:, 6].astype(bool))
    if not unclustered_mask.any():
        break

    seed_idx = np.argmax(unclustered_mask)
    oldest_mjd = full_data[seed_idx, 1]
    oldest_dm = full_data[seed_idx, 2]

    # Only unclustered candidates within width_pad of the oldest one can join its cluster
    cluster_data_mask = np.logical_and(full_data[:, 1] <= (oldest_mjd + width_pad), unclustered_mask)
    cluster_data = full_data[cluster_data_mask]

    first_point = cluster_data[0]
    cluster_point(first_point, cluster_data)

    clustered = cluster_data[cluster_data[:, 6].astype(bool)]
    if clustered.shape[0] == 0:
        # Guarantee progress: if the seed could not even match itself (e.g. a NaN width
        # or DM tolerance), make it a single-candidate cluster instead of looping forever
        clustered = cluster_data[:1]

    colour = colours[iteration % len(colours)]
    ax_clusters.scatter(clustered[:, 1], clustered[:, 2], c=colour, s=30, alpha=0.5)
    ax_clusters.plot(oldest_mjd, oldest_dm, marker="x", color=colour)

    dist_row, dist_col = divmod(iteration, n_dist_cols)
    if dist_row < n_dist_rows:
        # NOTE: the DM-SNR analysis will not work in its current form - candidates now
        # come from many beams, and the SNR depends on the beam the source was detected in
        axes_dist[dist_row, dist_col].scatter(clustered[:, 2], clustered[:, 4], c=colour)

    # Column 0 holds each row's position in full_data, so results can be written back
    clustered_indices = clustered[:, 0].astype(int)
    full_data[clustered_indices, 6] = 1
    full_data[clustered_indices, 10] = iteration

    iteration = iteration + 1

    if (iteration % 100 == 0):
        print(iteration, full_data[full_data[:, 6] == False].shape)

end_time = time()

print(f"Clustering took {iteration} iterations")
print(f"Clustering took {end_time - start_time}s")