## Spatio-Temporal Clustering Benchmark for Collective Animal Behavior

In this notebook we benchmark various spatio-temporal (ST) clustering algorithms. 
We further test their ability to operate on large amounts of data using their `st_fit_frame_split` method. We also implement a grid search for spatio-temporal data.

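Clustering algorithms covered (OPTICS, Spectral Clustering and Affinity Propagation are currently commented out in the fit loop below):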
- ST DBSCAN
- ST Agglomerative
- ST KMeans
- ST Optics
- ST Spectral Clustering
- ST Affinity Propagation
- ST BIRCH
- ST HDBSCAN

### 00 Some helpful functions

In [1]:
# control execution time of functions
import functools
import threading

# Time budget in seconds for one timed clustering run (used as @timeout(TIMER)).
TIMER = 120
# Multiplier applied to TIMER for the grid-search functions, which are
# decorated with @timeout(TIMER*PERMUT).
PERMUT = 12

class TimeoutError(Exception):
    """Raised by the ``timeout`` decorator when a call exceeds its time budget.

    NOTE: this deliberately keeps the original name, which shadows the
    built-in ``TimeoutError`` (an ``OSError`` subclass) inside this module.
    """

class InterruptableThread(threading.Thread):
    """Worker thread that runs ``func(*args, **kwargs)`` and records the outcome.

    Despite its name the thread cannot be stopped from outside; the caller
    simply stops waiting for it. The return value is exposed via ``result``
    and an exception raised by ``func`` via ``exception``.
    """

    def __init__(self, func, *args, **kwargs):
        threading.Thread.__init__(self)
        # Daemon thread: a call that outlives its timeout keeps running in the
        # background but must not block interpreter shutdown.
        self.daemon = True
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._result = None
        self._exception = None

    def run(self):
        # Capture the exception instead of letting it die with the thread, so
        # the waiting caller can re-raise it.
        try:
            self._result = self._func(*self._args, **self._kwargs)
        except Exception as exc:
            self._exception = exc

    @property
    def result(self):
        """Return value of ``func``, or None if it has not finished or it raised."""
        return self._result

    @property
    def exception(self):
        """Exception raised by ``func``, or None if it has not raised."""
        return self._exception


class timeout(object):
    """Decorator factory limiting how long the caller waits for a function.

    Parameters
    ----------
    sec: int or float
        Maximum number of seconds to wait for the decorated call.

    The decorated function runs in a separate thread. If it finishes in time,
    its return value is returned and any exception it raised is re-raised in
    the calling thread. Otherwise ``TimeoutError`` is raised; the worker
    thread is not stopped and keeps running in the background.
    """

    def __init__(self, sec):
        self._sec = sec

    def __call__(self, f):
        @functools.wraps(f)
        def wrapped_f(*args, **kwargs):
            it = InterruptableThread(f, *args, **kwargs)
            it.start()
            it.join(self._sec)
            if it.is_alive():
                raise TimeoutError('execution expired')
            if it.exception is not None:
                # Surface the real failure instead of silently returning None.
                raise it.exception
            return it.result
        return wrapped_f

In [2]:
def make_generator(parameters):
    """Helper function for the grid searches. Lazily enumerates a parameter grid.

    Parameters
    ----------
    parameters: dict
        Maps each parameter name to an iterable of candidate values.

    Yields
    ------
    dict
        One new dictionary per combination, mapping every parameter name to a
        single value. The first key of ``parameters`` varies slowest. An empty
        ``parameters`` yields exactly one empty dictionary.
    """
    if not parameters:
        yield dict()
        return

    names = list(parameters)
    head, tail = names[0], names[1:]
    remaining = {name: parameters[name] for name in tail}
    for value in parameters[head]:
        # Each recursive call creates its own dictionaries, so extending the
        # yielded one in place never leaks into another combination.
        for combination in make_generator(remaining):
            combination[head] = value
            yield combination
                
def st_silhouette_score(X, labels, eps1=0.05, eps2=10, metric='euclidean'):
    """Helper function for the grid searches. Silhouette score on a time-filtered distance matrix.

    Parameters
    ----------
    X: numpy array
        2-D array; column 0 is used for the temporal distance and the
        remaining columns for the spatial distance.
    labels: array-like
        Cluster label per row of X.
    eps1: float
        Pairs outside the temporal window get the fixed distance ``2 * eps1``.
    eps2: float
        Temporal window: the spatial distance of a pair is kept only if its
        temporal distance is at most eps2.
    metric: str
        Metric passed to ``pdist`` for both the temporal and spatial distances.

    Returns
    -------
    The value of ``silhouette_score`` computed on the precomputed distances.
    """
    n_samples, _ = X.shape
    temporal = pdist(X[:, 0].reshape(n_samples, 1), metric=metric)
    spatial = pdist(X[:, 1:], metric=metric)

    # keep the spatial distance inside the temporal window, 2 * eps1 otherwise
    inside_window = temporal <= eps2
    condensed = np.where(inside_window, spatial, 2 * eps1)

    square = squareform(condensed)
    return silhouette_score(square, labels, metric='precomputed')

@timeout(TIMER*PERMUT)
def st_grid_search(estimator, split, X, param_dict, metric, y=None, frame_size=None, frame_overlap=None):
    """
    Grid Search of hyperparameters for spatial-temporal clustering algorithms

    Parameters
    ----------
    estimator: class
        ST clustering algorithm. Instantiated as ``estimator(**param)``; must
        provide ``st_fit`` / ``st_fit_frame_split`` and a ``labels`` attribute.
    split: boolean
        Flag to indicate whether whole X should be loaded in RAM or processed in smaller chunks.
    X: numpy array
        Data on which grid search is performed
    param_dict: dict
        Dictionary with parameters to be optimized as keys and value range of grid search as value.
    metric: str
        The metric to evaluate the clustering quality: 'silhouette' or 'ami'.
        For 'silhouette' every combination must contain 'eps1' and 'eps2'.
    y: numpy array
        Optional. Ground-truth labels; required when metric is 'ami'.
    frame_size: int
        Optional. If split is True, indicate how large the chunks should be.
    frame_overlap: int
        Optional. If split is True, passed to ``st_fit_frame_split`` together with frame_size.

    Returns
    -------
    param_opt
        Hyperparameter combination with the highest score. If no combination
        could be scored, the first combination tried; None if the grid is empty.

    Raises
    ------
    ValueError
        If metric is not supported, or metric is 'ami' and y is None.
        (The function runs through the ``timeout`` decorator.)
    """
    # Fail before any fitting instead of hitting an unbound `score` later.
    if metric not in ('silhouette', 'ami'):
        raise ValueError("unknown metric {!r}; expected 'silhouette' or 'ami'".format(metric))
    if metric == 'ami' and y is None:
        raise ValueError("metric 'ami' requires ground-truth labels in y")

    param_opt = None
    # Both metrics can be negative, so start below every possible score;
    # starting at 0 would ignore the best combination when all scores are <= 0.
    s_max = float('-inf')
    for param in make_generator(param_dict):
        clust = estimator(**param)
        if split:
            clust.st_fit_frame_split(X, frame_size, frame_overlap)
        else:
            clust.st_fit(X)

        # fallback so that a combination is returned even if nothing can be scored
        if param_opt is None:
            param_opt = param

        # different performance evaluation metrics
        if metric == 'silhouette':
            try:
                score = st_silhouette_score(X=X, labels=clust.labels, eps1=param['eps1'], eps2=param['eps2'], metric='euclidean')
            except (TypeError, ValueError):
                # the score is undefined for this labelling; skip the combination
                continue
        else:
            score = adjusted_mutual_info_score(y, clust.labels)

        # store parameter combination if it outperforms given the metric
        if score > s_max:
            s_max = score
            param_opt = param
    return param_opt

@timeout(TIMER*PERMUT)
def traj_grid_search(estimator, X, param_dict, metric):
    """
    Grid Search of hyperparameters for clustering algorithms fitted with ``st_fit``

    Parameters
    ----------
    estimator: class
        Clustering algorithm. Instantiated as ``estimator(**param)``; must
        provide ``st_fit`` and a ``labels`` attribute, plus ``true_labels``
        when metric is 'ami'.
    X: numpy array
        Data on which grid search is performed
    param_dict: dict
        Dictionary with parameters to be optimized as keys and value range of grid search as value.
    metric: str
        The metric to evaluate the clustering quality: 'silhouette' or 'ami'.
        For 'silhouette' every combination must contain 'eps1' and 'eps2'.

    Returns
    -------
    param_opt
        Hyperparameter combination with the highest score. A combination only
        replaces the built-in default
        ``{'detect_radius':40, 'similarity_threshold':0.5}`` if its score is
        strictly greater than 0.

    Raises
    ------
    ValueError
        If metric is not supported.
        (The function runs through the ``timeout`` decorator.)
    """
    # Fail before any fitting instead of hitting an unbound `score` later.
    if metric not in ('silhouette', 'ami'):
        raise ValueError("unknown metric {!r}; expected 'silhouette' or 'ami'".format(metric))

    # default returned when no combination reaches a positive score
    param_opt = {'detect_radius':40, 'similarity_threshold':0.5}
    s_max = 0
    for param in make_generator(param_dict):
        clust = estimator(**param)
        clust.st_fit(X)

        # different performance evaluation metrics
        if metric == 'silhouette':
            try:
                score = st_silhouette_score(X=X, labels=clust.labels, eps1=param['eps1'], eps2=param['eps2'], metric='euclidean')
            except (TypeError, ValueError):
                # the score is undefined for this labelling; skip the combination
                continue
        else:
            # ground truth is read from the fitted estimator itself
            score = adjusted_mutual_info_score(clust.true_labels, clust.labels)

        # store parameter combination if it outperforms given the metric
        if score > s_max:
            s_max = score
            param_opt = param
    return param_opt

In [3]:
class Test(object):
    """Timed benchmark runs; each method returns ``(ami, runtime)``.

    Every method is wrapped with ``@timeout(TIMER)``. ``runtime`` is the
    wall-clock duration of the fit call only, measured with ``time.time()``.
    """

    # use this function with st clusterers
    @timeout(TIMER)  # set seconds for timeout
    def frame_split_cluster(self, algorithm, data, frame_size, frame_overlap):
        """Fit ``algorithm`` via ``st_fit_frame_split`` and score it.

        The AMI is computed against the module-level global ``labels``, which
        must be set before this method is called.
        """
        began = time.time()
        algorithm.st_fit_frame_split(data, frame_size, frame_overlap)
        elapsed = time.time() - began
        return adjusted_mutual_info_score(labels, algorithm.labels), elapsed

    # use this with trajectory clustering
    @timeout(TIMER)
    def traj_cluster(self, algorithm, data):
        """Fit ``algorithm`` via ``st_fit`` and score it against ``algorithm.true_labels``."""
        began = time.time()
        algorithm.st_fit(data)
        elapsed = time.time() - began
        return adjusted_mutual_info_score(algorithm.true_labels, algorithm.labels), elapsed

    # use this with dbscan2
    @timeout(TIMER)
    def cluster(self, algorithm, data):
        """Fit ``algorithm`` via ``st_fit`` and score it.

        The AMI is computed against the module-level global ``labels``, which
        must be set before this method is called.
        """
        began = time.time()
        algorithm.st_fit(data)
        elapsed = time.time() - began
        return adjusted_mutual_info_score(labels, algorithm.labels), elapsed

### 01 Setup

In [4]:
import numpy as np
import pandas as pd
import os
import time
import logging
import json
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import silhouette_score, adjusted_mutual_info_score
from st_clustering import ST_DBSCAN, ST_Agglomerative, ST_KMeans, ST_OPTICS, ST_SpectralClustering, ST_AffinityPropagation, ST_BIRCH, ST_HDBSCAN

In [5]:
# Configure logging of results
logging.basicConfig(level=logging.INFO, filename='cluster_results.log', filemode='a', format='%(asctime)s - %(message)s')

### 02 Find Hyperparameters

Procedure: 
1. Extensive search on small dataset. 
2. Define a smaller search space around optimal values from step 1.
3. Search on larger dataset for values from step 2.
4. Check whether the optimal hyperparameters from step 3 are extreme values (min or max of the search space). If they are not, the search space is properly defined; if they are, widen the search space and repeat step 3.

### 03 Fit

Loop to iterate over datasets

* For each dataset
    * Extract name
    * For each clusterer
         * find hyperparameters via grid search
         * cluster and write results to the log file

In [6]:
# Benchmark driver: for every dataset file in PATH, grid-search each enabled
# clusterer on the AMI metric, re-fit it with the best parameters and write
# AMI and runtime to the log.
PATH = 'test_files'
dataset_files = os.listdir(PATH)

FRAME_SIZE = 100
FRAME_OVERLAP = 10

t = Test()

# Flags are set to False once a method raised a TimeoutError, so that the
# method is skipped for the remaining datasets instead of timing out again.
not_timed_out_dbscan = True
not_timed_out_agglo = True
not_timed_out_kmeans = True
# not_timed_out_optics = True
# not_timed_out_spectral = True
# not_timed_out_affinity = True
not_timed_out_birch = True
not_timed_out_hdbscan = True

for ds in dataset_files:
    # read data and min-max scale both coordinates to [0, 1]
    filename = ds
    df = pd.read_csv(os.path.join(PATH, filename))
    df['x'] = (df['x'] - df['x'].min()) / (df['x'].max() - df['x'].min())
    df['y'] = (df['y'] - df['y'].min()) / (df['y'].max() - df['y'].min())
    # transform to numpy array
    data = df.loc[:, ['frame','x','y']].values
    # `labels` is also read as a global by Test.frame_split_cluster / Test.cluster
    labels = df['cid'].to_numpy()

    # get number of clusters; the label -1 is not counted as a cluster
    # (presumably it marks noise -- confirm against the dataset format)
    unique_labels = np.unique(labels)
    if -1 in unique_labels:
        n_cluster = len(unique_labels) - 1
    else:
        n_cluster = len(unique_labels)

    # grid search
    param_dict_dbscan = {'eps1': [0.02, 0.03, 0.04, 0.05],'eps2':[5, 25, 50, 100], 'min_samples': [2]}
    param_dict_agglo = {'eps2':[5, 25, 50, 100], 'n_clusters': [n_cluster]}
    param_dict_kmeans = {'eps2':[5, 25, 50, 100], 'n_clusters': [n_cluster]}
    # param_dict_optics = {'eps2':[5, 15, 25], 'min_cluster_size': [2], 'min_samples' : [2]}
    # param_dict_spectral = {'eps2':[5, 15, 25], 'n_clusters': [n_cluster]}
    # param_dict_affinity = {'eps2':[5, 15, 25]}
    param_dict_birch = {'eps2':[5, 25, 50, 100], 'threshold': [0.3, 0.5], 'n_clusters': [n_cluster]}
    param_dict_hdbscan = {'eps2':[5, 25, 50, 100], 'min_cluster_size': [n_cluster], 'min_samples': [2]}
    ## ---


    # cluster and write results to log file
    # `except Exception` (not a bare except) keeps KeyboardInterrupt and
    # SystemExit working, so a long benchmark run can still be aborted.

    ## ST_DBSCAN
    if not_timed_out_dbscan:
        try:
            opt_param_dbscan = st_grid_search(estimator=ST_DBSCAN, split=True, X=data, param_dict=param_dict_dbscan, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
            dbscan = ST_DBSCAN(**opt_param_dbscan)
            ami, runtime = t.frame_split_cluster(dbscan, data, FRAME_SIZE, FRAME_OVERLAP)
            logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, dbscan, ami, runtime))
        except TimeoutError:
            logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'DBSCAN'))
            not_timed_out_dbscan = False
        except Exception:
            logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'DBSCAN'))

    ## ST_Agglomerative
    if not_timed_out_agglo:
        try:
            opt_param_agglo = st_grid_search(estimator=ST_Agglomerative, split=True, X=data, param_dict=param_dict_agglo, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
            agglo = ST_Agglomerative(**opt_param_agglo)
            ami, runtime = t.frame_split_cluster(agglo, data, FRAME_SIZE, FRAME_OVERLAP)
            logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, agglo, ami, runtime))
        except TimeoutError:
            logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'Agglomerative'))
            not_timed_out_agglo = False
        except Exception:
            logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'Agglomerative'))

    ## K-MEANS
    if not_timed_out_kmeans:
        try:
            opt_param_kmeans = st_grid_search(estimator=ST_KMeans, split=True, X=data, param_dict=param_dict_kmeans, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
            kmeans = ST_KMeans(**opt_param_kmeans)
            ami, runtime = t.frame_split_cluster(kmeans, data, FRAME_SIZE, FRAME_OVERLAP)
            logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, kmeans, ami, runtime))
        except TimeoutError:
            logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'KMeans'))
            not_timed_out_kmeans = False
        except Exception:
            logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'KMeans'))

    ## BIRCH
    if not_timed_out_birch:
        try:
            opt_param_birch = st_grid_search(estimator=ST_BIRCH, split=True, X=data, param_dict=param_dict_birch, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
            birch = ST_BIRCH(**opt_param_birch)
            ami, runtime = t.frame_split_cluster(birch, data, FRAME_SIZE, FRAME_OVERLAP)
            logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, birch, ami, runtime))
        except TimeoutError:
            logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'BIRCH'))
            not_timed_out_birch = False
        except Exception:
            logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'BIRCH'))

    ## HDBSCAN
    if not_timed_out_hdbscan:
        try:
            # search the HDBSCAN grid (the BIRCH grid was passed here by mistake)
            opt_param_hdbscan = st_grid_search(estimator=ST_HDBSCAN, split=True, X=data, param_dict=param_dict_hdbscan, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
            hdbscan = ST_HDBSCAN(**opt_param_hdbscan)
            ami, runtime = t.frame_split_cluster(hdbscan, data, FRAME_SIZE, FRAME_OVERLAP)
            logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, hdbscan, ami, runtime))
        except TimeoutError:
            logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'HDBSCAN'))
            not_timed_out_hdbscan = False
        except Exception:
            logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'HDBSCAN'))
        
    ## OPTICS 
    # if not_timed_out_optics:
    #     try:
    #         opt_param_optics = st_grid_search(estimator=ST_OPTICS, split=True, X=data, param_dict=param_dict_optics, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
    #         optics = ST_OPTICS(**opt_param_optics)
    #         ami, runtime = t.frame_split_cluster(optics, data, FRAME_SIZE, FRAME_OVERLAP)
    #         logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, optics, ami, runtime))
    #     except TimeoutError:
    #         logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'OPTICS'))
    #         not_timed_out_optics = False
    #         pass
    #     except:
    #         logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'OPTICS'))
    #         pass
    
    ## SPECTRAL 
    # if not_timed_out_spectral:
    #     try:
    #         opt_param_spectral = st_grid_search(estimator=ST_SpectralClustering, split=True, X=data, param_dict=param_dict_spectral, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
    #         spectral = ST_SpectralClustering(**opt_param_spectral)
    #         ami, runtime = t.frame_split_cluster(spectral, data, FRAME_SIZE,FRAME_OVERLAP)
    #         logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, spectral, ami, runtime))
    #     except TimeoutError:
    #         logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'Spectral'))
    #         not_timed_out_spectral = False
    #         pass
    #     except:
    #         logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'Spectral'))
    #         pass
        
    ## AFFINITY 
    # if not_timed_out_affinity:
    #     try:
    #         opt_param_affinity = st_grid_search(estimator=ST_AffinityPropagation, split=True, X=data, param_dict=param_dict_affinity, metric='ami', y=labels, frame_size=FRAME_SIZE, frame_overlap=FRAME_OVERLAP)
    #         affinity = ST_AffinityPropagation(**opt_param_affinity)
    #         ami, runtime = t.frame_split_cluster(affinity, data, FRAME_SIZE, FRAME_OVERLAP)
    #         logging.info('dataset: {}, method: {}, ami: {}, execution time: {}'.format(filename, affinity,  ami, runtime))
    #     except TimeoutError:
    #         logging.info('TimeoutError! dataset: {}, method: {}'.format(filename, 'Affinity'))
    #         not_timed_out_affinity = False
    #         pass
    #     except:
    #         logging.info('ComputationalError! dataset: {}, method: {}'.format(filename, 'Affinity'))
    #         pass