# Currently 4 methods for removing the isolated clusters:
- old method (slow, but very precise)
- kdtree method (faster, also precise)
- miryam method (faster, seems to lose precision? would have to check)

All three are studied timewise here

## imports

In [4]:
import sys,os,os.path
sys.path.append("../../")
sys.path.append(os.path.expanduser('~/code/eol_hsrl_python'))
os.environ['ICTDIR']='/home/e78368jw/Documents/NEXT_CODE/IC'

#%load_ext autoreload
#%autoreload 2
#%matplotlib notebook

import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['mathtext.fontset'] = 'stix'
rcParams['font.family'] = 'STIXGeneral'
rcParams['figure.figsize'] = [10, 8]
rcParams['font.size'] = 22

import pandas as pd
import numpy  as np
import tables as tb

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as clrs

import IC.invisible_cities.core.core_functions                   as     coref
import IC.invisible_cities.io.dst_io                           as     dstio

from IC.invisible_cities.cities                 import beersheba as beerfun


from IC.invisible_cities.database.load_db       import DataSiPM

from IC.invisible_cities.evm.event_model        import Cluster, Hit
from IC.invisible_cities.types.ic_types         import xy
from IC.invisible_cities.reco.paolina_functions import voxelize_hits

from IC.invisible_cities.evm.event_model        import HitEnergy
from IC.invisible_cities.cities.beersheba          import DeconvolutionMode
from IC.invisible_cities.cities.beersheba          import CutType


from IC.invisible_cities.reco.deconv_functions import deconvolve
from IC.invisible_cities.reco.deconv_functions import deconvolution_input
from IC.invisible_cities.reco.deconv_functions import InterpolationMethod

import IC.invisible_cities.io.mcinfo_io as mcio

from sklearn.neighbors import NearestNeighbors



import matplotlib.cm as cm
from matplotlib.colors import Normalize

from scipy.spatial import cKDTree
import networkx as nx
import time

## basic plotting function

In [5]:
def raw_plotter(q, evt, pitch = 15.55, title = None):
    '''
    just plots the hits, nothing smart
    '''

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    xx = np.arange(q.X.min(), q.X.max() + pitch, pitch)
    yy = np.arange(q.Y.min(), q.Y.max() + pitch, pitch)
    zz = np.sort(q.Z.unique())

    zz_diff = np.diff(zz)
    zz_padded = []
    for i, z in enumerate(zz[:-1]):
        zz_padded.append(z)
        if zz_diff[i] > 4:
            zz_padded.extend(np.arange(z+4, zz[i+1], 4))
    zz_padded.append(zz[-1])
    zz = np.array(zz_padded)

    axes[0].hist2d(q.X, q.Y, bins=[xx, yy], weights=q.Q, cmin=0.0001);
    axes[0].set_xlabel('X (mm)');
    axes[0].set_ylabel('Y (mm)');

    axes[1].hist2d(q.X, q.Z, bins=[xx, zz], weights=q.Q, cmin=0.0001);
    axes[1].set_xlabel('X (mm)');
    axes[1].set_ylabel('Z (mm)');


    axes[2].hist2d(q.Y, q.Z, bins=[yy, zz], weights=q.Q, cmin=0.0001);
    axes[2].set_xlabel('Y (mm)');
    axes[2].set_ylabel('Z (mm)');
    plt.tight_layout()
    if title is None:
        fig.suptitle(f"{evt}")
    else:
        fig.suptitle(f"{title}")
    plt.show(fig)

## THE drop_isolated_clusters functions

In [6]:
#=========================================================================================================================
#============================================    MIRYAMS MAGICAL FUNCTION     ============================================
#=========================================================================================================================

def drop_isolated_miryam(distance = [16., 16., 4.],
                           variables = [],
                           nhit = 3,):
    dist = np.sqrt(3)

    def drop(df: pd.DataFrame) -> pd.DataFrame:
        if len(df) == 0:
            return df

        x   = df.X.values / distance[0]
        y   = df.Y.values / distance[1]
        z   = df.Z.values / distance[2]
        xyz = np.column_stack((x,y,z))

        

        try:
            nbrs = NearestNeighbors(radius=dist, algorithm='ball_tree').fit(xyz)
            neighbors = nbrs.radius_neighbors(xyz, return_distance=False)
            mask = np.array([len(neigh) > nhit for neigh in neighbors])
        except Exception as e:
            print(f"Error in NearestNeighbors: {{e}}")
            return df.iloc[:0]  # fallback: return empty

        pass_df = df.loc[mask].copy()

        if not pass_df.empty and variables:
            with np.errstate(divide='ignore', invalid='ignore'):
                columns = pass_df.loc[:, variables]
                scale = df[variables].sum().values / columns.sum().values
                columns *= scale
                pass_df.loc[:, variables] = columns

        return pass_df

    return drop

#=========================================================================================================================
#============================================    KDTREE MAGICAL FUNCTION      ============================================
#=========================================================================================================================

def drop_isolated_kdtree(distance = [16., 16., 4.],
                         variables = [],
                         nhit = 3,
                         verbose = False):
    
    def drop(df):
        # normalise distances and (x,y,z) array
        x   = df.X.values / distance[0]
        y   = df.Y.values / distance[1]
        z   = df.Z.values / distance[2]
        xyz = np.column_stack((x,y,z))
        
        # normalised, so define distance sqrt(3)
        dist = np.sqrt(3)

        # build KDTree of datapoints, collect pairs within distance
        xyz_tree = cKDTree(xyz)
        pairs    = xyz_tree.query_pairs(r=dist)
        
        # create graph that connects all close pairs between hit positions based on df index
        cluster_graph = nx.Graph()
        cluster_graph.add_nodes_from(range(len(df)))
        cluster_graph.add_edges_from((df.index[i], df.index[j]) for i,j in pairs)

        # Find all clusters within the graph
        clusters = list(nx.connected_components(cluster_graph))

        # collect indices of passing hits (cluster > nhit) within set
        passing_hits = set()
        clstrs = []
        for cluster in clusters:
            if len(cluster) > nhit:
                if verbose:
                    print(f'Cluster size: {len(cluster)}')
                    clstrs.append(len(cluster))
                    print(f'Passing cluster: {cluster}')
                passing_hits |= cluster
        
        if verbose:
            print(clstrs)
            plt.hist(clstrs, bins = 100)
            plt.show()
        # extract mask and apply it
        mask    = df.index.isin(passing_hits)
        pass_df = df.loc[mask, :].copy()

        # reweighting
        with np.errstate(divide='ignore'):
            columns = pass_df.loc[:, variables]
            columns *= np.divide(df.loc[:, variables].sum().values, columns.sum())
            pass_df.loc[:, variables] = columns

        return pass_df

    return drop



In [7]:
q_cut           = 5, #5 for 6mm  - 5 for 6mm
drop_dist       = [16, 16, 4.]
drop_dist_2d    = [16, 16]
nhits           = 3
nhits_low       = 1
nhits_high      = 10


cut_sensors       = beerfun.cut_over_Q   (q_cut, ['E', 'Ec'])
drop_sensors      = beerfun.drop_isolated(drop_dist, ['E', 'Ec'], nhits)
drop_sensors_2D   = beerfun.drop_isolated(drop_dist_2d, ['E', 'Ec'])
drop_sensors_miryam = drop_isolated_miryam(drop_dist, ['E', 'Ec'], nhits)
drop_sensors_3D_kdtree = drop_isolated_kdtree(drop_dist, ['E', 'Ec'], nhits)

### cut and load data

In [8]:
folder_path = 'data'

file_names = [f for f in os.listdir(folder_path) if f.endswith('.h5')]
file_paths = [os.path.join(folder_path, f) for f in file_names]

soph_hdst = dstio.load_dsts(file_paths, 'RECO', 'Events')#CHITS/lowTh

# cut
cut_hdst = cut_sensors(soph_hdst)

In [9]:
print(cut_hdst.event.nunique())

6693


### calculate execution time for each function across differing number of events averaged

In [None]:
def test_function_speed(func, df_full, iterations = 20, verbose = False):
    '''
    function that takes a function and a df and reports the time complexity of
    running the function over N events, N being the number of events.

    Returns array of time taken in seconds, and the final dropped dst
    '''
    time_taken = []

    full_array = []
    unique_events = df_full.event.unique()
    total_events = df_full.event.nunique()
    for n in range(1, len(unique_events) + 1, 1000):
        if verbose:
            print(f'{n}/{total_events} checked')
        n = min(n, len(unique_events))  # dont overflow!
        df_subset = df_full[df_full.event.isin(unique_events[:n])]
        
        execution_times = []
        for _ in range(iterations):
            start_time = time.time()
            dropped_dst = df_subset.groupby('event').apply(func).reset_index(drop=True).copy(deep=True)
            execution_times.append(time.time() - start_time)

        # Calculate the average execution time
        average_time = sum(execution_times) / len(execution_times)
        time_taken.append(average_time)

    return (time_taken, dropped_dst)


: 

In [None]:
time_taken_kdtree, df_kdtree = test_function_speed(drop_sensors_3D_kdtree, cut_hdst, verbose = True)
time_taken_miryam, df_miryam = test_function_speed(drop_sensors_miryam, cut_hdst, verbose = True)
time_taken_old, df_old       = test_function_speed(drop_sensors, cut_hdst, verbose = True)

1/6693 checked
1001/6693 checked
2001/6693 checked
3001/6693 checked
4001/6693 checked
5001/6693 checked
6001/6693 checked
1/6693 checked
1001/6693 checked
2001/6693 checked
3001/6693 checked
4001/6693 checked
5001/6693 checked
6001/6693 checked
1/6693 checked
1001/6693 checked
2001/6693 checked
3001/6693 checked
4001/6693 checked


In [None]:
N = [x for x in range(1, cut_hdst.event.nunique() + 1, 5)]

In [1]:
plt.plot(N, time_taken_miryam, label = "Miryam's method 3D")
plt.plot(N, time_taken_kdtree, label = "KDTree method 3D")
plt.plot(N, time_taken_old, label = "Current method 3D")

plt.xlabel('N')
plt.ylabel('Time (seconds)')
plt.title('Time complexity of differing functions for isolated cluster removal.')
plt.legend()
plt.show()

NameError: name 'plt' is not defined

In [None]:
plt.plot(N, time_taken_miryam, label = "Miryam's method 3D")
plt.plot(N, time_taken_kdtree, label = "KDTree method 3D")
plt.plot(N, time_taken_old, label = "Current method 3D")

plt.xlabel('N')
plt.ylabel('Time (seconds)')
plt.title('Time complexity of differing functions for isolated cluster removal.')
plt.legend()
plt.savefig('methods.pdf')
plt.show()