In [16]:
# This notebook is for writing a script that iteratively batches a dataset, applies HDBSCAN to each batch,
# pools the anomalies, and then repeats the process on the pooled anomalies.

# Differences from batch_hdbscan.ipynb:
#
#  -  This notebook must keep track of non-anomalous clusters as well as anomalies,
#       and we must find a decent way to match them to each other across batches.
#
#  -  This notebook must eliminate the day-long process in retrieve_anomalous_hits.ipynb.
#       This will be done by assigning each hit an index that is carried along with it as
#       hits get batched and shuffled.
#
#  -  At the end, we should have:
#       (a) a list of anomalous hits, ideally no more than 10% of the initial dataset, which
#           can be passed through FindEvent for interesting results
#       (b) a list of RFI classes, which can be plotted in PDF form to demonstrate the efficacy
#           of the clustering

import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
from scipy.stats import skew, kurtosis, norm, mode
from sklearn.cluster import DBSCAN, HDBSCAN
from sklearn.preprocessing import quantile_transform
import psutil
import shutil
import os
%matplotlib inline

In [10]:
def preprocess(data):
    """Build the HDBSCAN feature matrix and an unscaled copy from a raw hit-parameter table.

    Parameters
    ----------
    data : ndarray
        2-D array whose first 12 columns are read, in order, as: freqs, drifts, snrs, skews,
        kurts, sarles, corrs, tbws, tskews, tstds, fstds, sigbws.

    Returns
    -------
    scaled : ndarray, shape (n_rows, 12)
        Features transformed for clustering (see the per-column comments below).
    unscaled : ndarray, shape (n_rows, 12)
        The same 12 columns as given, except tbws and sigbws, which are multiplied by 1e6
        (labelled "units of Hz").
    """

    def _minmax(x):
        # Shift so the minimum is 0, then divide by the resulting maximum -> values in [0, 1].
        shifted = x - np.min(x)
        return shifted / np.max(shifted)

    (freqs, drifts, snrs, skews, kurts, sarles,
     corrs, tbws, tskews, tstds, fstds, sigbws) = (data[:, k] for k in range(12))

    # Map drift rates onto a standard normal through their empirical quantiles, then drop the sign.
    normal_drifts = quantile_transform(drifts.reshape(-1, 1), n_quantiles=100000,
                                       output_distribution='normal', subsample=100000).ravel()
    abs_drifts = np.abs(normal_drifts)

    log_tstds = np.log10(tstds)
    log_fstds = np.log10(fstds)

    scaled_cols = [
        np.argsort(np.argsort(freqs)) / len(freqs),  # frequency rank divided by n -> [0, 1)
        abs_drifts / np.max(abs_drifts),
        _minmax(np.log10(snrs)),
        _minmax(skews),
        _minmax(np.log10(kurts)),
        sarles,  # used as-is
        corrs,   # used as-is
        _minmax(np.log10(tbws * 1e6)),
        _minmax(tskews),
        # NOTE(review): unlike the other log features, log10(tstds) and log10(fstds) are only
        # shifted to start at 0 and are NOT divided by their range, so they are not confined to
        # [0, 1] and can dominate HDBSCAN's distances. Confirm this is intentional.
        log_tstds - np.min(log_tstds),
        log_fstds - np.min(log_fstds),
        sigbws / np.max(sigbws),
    ]  ### PRE-PROCESSED FOR HDBSCAN

    unscaled_cols = [
        freqs, drifts, snrs, skews, kurts, sarles, corrs,
        tbws * 1e6,  # units of Hz
        tskews, tstds, fstds,
        sigbws * 1e6,  # units of Hz
    ]

    return np.transpose(np.array(scaled_cols)), np.transpose(np.array(unscaled_cols))

def batch_hdbscan(batch_arr_scaled, nmincluster, nminsamples, eps, n_jobs=10):
    """Cluster one batch of scaled hit features with sklearn's HDBSCAN.

    Parameters
    ----------
    batch_arr_scaled : ndarray, shape (n_points, n_features)
        Pre-processed features for the batch (the caller may append injected signals as
        extra rows).
    nmincluster : int
        ``min_cluster_size`` passed to HDBSCAN.
    nminsamples : int
        ``min_samples`` passed to HDBSCAN.
    eps : float
        ``cluster_selection_epsilon`` passed to HDBSCAN.
    n_jobs : int, optional
        Number of parallel jobs for HDBSCAN (default 10, as before).

    Returns
    -------
    labels_list : ndarray, shape (n_points,)
        Cluster label for each row; -1 marks noise in sklearn's HDBSCAN.
    centroids : ndarray, shape (n_clusters, n_features)
        Centroid of each non-noise cluster; row k belongs to label k.
    """
    hdb = HDBSCAN(
        min_cluster_size=nmincluster,
        min_samples=nminsamples,
        cluster_selection_epsilon=eps,
        leaf_size=40,
        n_jobs=n_jobs,
        cluster_selection_method='eom',
        # Without this, sklearn never computes centroids_ and accessing it raises AttributeError.
        store_centers='centroid')

    hdb.fit(batch_arr_scaled)

    return hdb.labels_, hdb.centroids_

In [7]:
# Load the hit feature table and the .dat path each hit came from, then make working copies of
# every referenced .dat file. The iterative-clustering cell below deletes hits from these copies
# in place, so re-running this cell restores them to their original contents.
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only use with trusted files.
# NOTE(review): copies are named by basename only; two source .dat files sharing a basename in
#               different directories would overwrite each other here — verify this cannot happen.
DAT_DIR = '/datax/scratch/benjb/C23_L_dats_iterative/'

hit_params = np.load('/datax/scratch/benjb/C23_L_unique_param_array.npy', allow_pickle=True)
hit_dats = np.load('/datax/scratch/benjb/C23_L_unique_dat_list.npy', allow_pickle=True)[:,1]

# shutil.copy into a directory that does not exist fails, so create it first.
os.makedirs(DAT_DIR, exist_ok=True)
for dat in np.unique(hit_dats):
    shutil.copy(dat, DAT_DIR)

In [None]:
# Iterative HDBSCAN pruning.
#
# Each round:
#   1. shuffle the surviving ("anomalous") hits and split them into batches of ~BATCH_TARGET hits;
#   2. cluster every batch together with the injected (stg) signals;
#   3. a hit is non-anomalous if its cluster is not noise (-1) and contains no injected signal;
#      record its global index + cluster centroid and delete it from the working .dat copies;
#   4. the remaining hits (noise + clusters shared with an injected signal) go to the next round.
# Rounds stop once at most TARGET_FRACTION of the original hits remain, or when a round removes nothing.
#
# NOTE: step 3 edits the .dat copies in DAT_DIR in place. Re-run the copy cell above to restore
#       them before re-running this cell.

DAT_DIR = '/datax/scratch/benjb/C23_L_dats_iterative/'
TARGET_FRACTION = 0.1   # stop once at most this fraction of the original hits remains
BATCH_TARGET = 5000     # approximate number of hits per HDBSCAN batch
NMINCLUSTER = 4
NMINSAMPLES = 2
EPS = 0.17
MIN_FREE_BYTES = 32e9   # stop editing .dat files if available memory drops to this level
SHUFFLE_SEED = 0        # fixed so batch membership (and hence the result) is reproducible


def _delete_hits_from_dats(freq_dat_pairs, dat_dir=DAT_DIR, min_free_bytes=MIN_FREE_BYTES):
    """Remove hits from the working .dat copies in `dat_dir`, rewriting each file only once.

    Parameters
    ----------
    freq_dat_pairs : iterable of (freq, dat_path)
        Frequency of each hit to remove and the .dat path it came from. Only the basename of
        `dat_path` is used; the file edited is `dat_dir/<basename>`.
    dat_dir : str
        Directory holding the working copies.
    min_free_bytes : float
        Stop (leaving the remaining files untouched) once available memory is at or below this.

    Returns
    -------
    int
        Number of .dat files that were rewritten.
    """
    freqs_by_dat = {}
    for freq, dat in freq_dat_pairs:
        dat_path = os.path.join(dat_dir, os.path.basename(dat))
        freqs_by_dat.setdefault(dat_path, []).append(str(freq))

    n_hits_total = sum(len(f) for f in freqs_by_dat.values())
    print(f'Removing {n_hits_total} hits from {len(freqs_by_dat)} dat files')

    n_done = 0
    for dat_path, freq_strs in freqs_by_dat.items():
        bytes_available = psutil.virtual_memory()[1]
        if bytes_available <= min_free_bytes:
            print(f'Memory dangerously low: {bytes_available} bytes remaining. Breaking ...')
            break

        with open(dat_path) as f:
            lines = f.readlines()
        # NOTE(review): a line is dropped when str(freq) appears anywhere in it (same substring
        # match as before). This assumes the .dat file prints frequencies exactly as Python's
        # str(float) does, and it can also drop a line that merely contains that text (e.g. a
        # longer frequency with the same leading digits) — verify against the .dat format.
        kept = [line for line in lines if not any(s in line for s in freq_strs)]
        with open(dat_path, 'w') as f:
            f.writelines(kept)
        n_done += 1

    return n_done


stg_params = np.load('/datax/scratch/benjb/C23_L_full_injected_params.npy', allow_pickle=True)

n_hits = len(hit_params)
assert len(hit_dats) == n_hits, 'hit_params and hit_dats must describe the same hits'

# Real and injected hits are preprocessed together so the quantile transform and min/max scaling
# are shared. They are split apart again below so the injected hits can be re-added to every batch.
full_params = np.vstack((hit_params, stg_params))
full_params_scaled, full_params_unscaled = preprocess(full_params)

# Slice by n_hits rather than a hard-coded injected-signal count (also safe if there are none).
hit_params_scaled_0 = full_params_scaled[:n_hits]
print(f'Size hit_params_scaled = {len(hit_params_scaled_0)}')
hit_params_unscaled_0 = full_params_unscaled[:n_hits]
print(f'Size hit_params_unscaled = {len(hit_params_unscaled_0)}')
stg_params_scaled = full_params_scaled[n_hits:]
stg_params_unscaled = full_params_unscaled[n_hits:]

hit_idxs = np.arange(n_hits)
round_idxs = np.copy(hit_idxs)  # global idxs of hits still considered anomalous
dataset_sizes = [n_hits]        # number of surviving hits at the start of each round
non_anom_centroids = []         # [global hit idx, centroid of its cluster] per non-anomalous hit
rng = np.random.default_rng(SHUFFLE_SEED)

while len(round_idxs) > TARGET_FRACTION * dataset_sizes[0]:

    print(f'{len(round_idxs)} hits in dataset ({100*len(round_idxs)/dataset_sizes[0]}% of original).')

    # Shuffle and batch the global indices themselves, so every feature/dat lookup stays aligned.
    # max(1, ...) keeps one batch once fewer than BATCH_TARGET hits remain.
    n_batches = max(1, len(round_idxs) // BATCH_TARGET)
    idx_batches = np.array_split(rng.permutation(round_idxs), n_batches)

    anom_idx_batches = []
    hits_to_delete = []  # (freq, dat path) of every non-anomalous hit found this round

    for idx_batch in idx_batches:
        n_batch = len(idx_batch)
        batch_scaled = np.vstack((hit_params_scaled_0[idx_batch], stg_params_scaled))
        labels_list, centroids = batch_hdbscan(batch_scaled, NMINCLUSTER, NMINSAMPLES, EPS)

        # The injected signals were appended after the real hits.
        hit_labels = labels_list[:n_batch]
        inj_labels = labels_list[n_batch:]

        # Non-anomalous clusters: contain no injected signal and are not noise (-1).
        # setdiff1d keeps an integer dtype even when the result is empty.
        non_anom_labels = np.setdiff1d(hit_labels, np.append(inj_labels, -1))
        mask = np.isin(hit_labels, non_anom_labels)
        non_anom_idxs = idx_batch[mask]

        # centroids rows are indexed by cluster label, so look each hit's centroid up by its label.
        for idx, label in zip(non_anom_idxs, hit_labels[mask]):
            non_anom_centroids.append([idx, centroids[label]])

        # Column 0 of the unscaled params is the hit frequency used to find its line in the .dat.
        hits_to_delete.extend(zip(hit_params_unscaled_0[non_anom_idxs, 0], hit_dats[non_anom_idxs]))
        anom_idx_batches.append(idx_batch[~mask])

    # Delete once per round, after all batches, so each .dat file is rewritten only once.
    _delete_hits_from_dats(hits_to_delete)

    new_round_idxs = np.sort(np.concatenate(anom_idx_batches))
    if len(new_round_idxs) == len(round_idxs):
        # Nothing was removed; looping again on the same hits would never terminate.
        print('No hits were classified as non-anomalous this round; stopping.')
        break
    round_idxs = new_round_idxs
    dataset_sizes.append(len(round_idxs))

print(f'{len(round_idxs)} anomalous hits remain.')

In [15]:
# Quick check of the np.isin / ~mask semantics used to split hits into non-anomalous / anomalous:
# mask[k] is True when b[k] appears in a, and ~mask selects the complement.
a = [2, 3, 5, 7, 11]
b = list(range(1, 13))

mask = np.isin(b, a)
print(mask, ~mask, sep='\n')

[False  True  True False  True False  True False False False  True False]
[ True False False  True False  True False  True  True  True False  True]
