In [2]:
import os
# https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html

import sys
sys.path.append('/Users/sophiapchung/anaconda3/lib/python3.10/site-packages')
import pandas as pd
import numpy as np
from sklearn.cluster import AgglomerativeClustering

In [11]:
class Cluster:
    """Lightweight holder for the boxes that belong to one cluster.

    ``boxes`` is stored by reference exactly as passed in: no copy,
    validation, or reshaping is performed.
    """

    def __init__(self, boxes):
        # Exposed as a plain public attribute.
        self.boxes = boxes

# Box Combination

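The goal of this task is to merge bounding boxes that are close to each other. We use Agglomerative Clustering on each box's coordinates (begin/end time and low/high frequency) to group nearby boxes. We then collapse every group into a single box using one of two methods:

- **union**: the smallest box that encloses every box in the group.
- **intersection**: the region shared by every box in the group.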

In [3]:
def box_combo(file_path, method, threshold=0.5, output_dir='combos'):
    """
    Cluster nearby bounding boxes and merge each cluster into a single box.

    Boxes are read from a tab-separated predictions file. They are clustered
    with scikit-learn's AgglomerativeClustering on their four coordinates, and
    each cluster is collapsed into one box. The result is also written
    (tab-separated, no index) to ``<output_dir>/<prefix>_<method>_box_combo.txt``,
    where ``<prefix>`` is the part of the input file name before the first ``_``.

    :param file_path: predictions file path containing bounding boxes. It must
        have the columns 'Begin Time (s)', 'End Time (s)', 'Low Freq (Hz)' and
        'High Freq (Hz)'.
    :param method: "union" (smallest box enclosing every box in the cluster) or
        "intersection" (region shared by every box in the cluster).
    :param threshold: AgglomerativeClustering ``distance_threshold``, i.e. the
        linkage distance at or above which clusters are not merged.
    :param output_dir: directory the combined boxes are written to, created if
        missing. Defaults to 'combos' relative to the working directory.
    :return: dataframe containing combined boxes, one row per cluster, ordered
        by cluster label.
    :raises ValueError: if ``method`` is not "union" or "intersection".
    """
    columns = ['Begin Time (s)', 'End Time (s)', 'Low Freq (Hz)', 'High Freq (Hz)']
    # Per-column reducer for each method, in the same order as `columns`.
    # Union widens the box (earliest begin, latest end, lowest low, highest high);
    # intersection narrows it.
    reducers = {
        "union": ['min', 'max', 'min', 'max'],
        "intersection": ['max', 'min', 'max', 'min'],
    }
    # Validate before doing any I/O. Previously an unknown method left `box`
    # unbound and failed with UnboundLocalError mid-loop.
    if method not in reducers:
        raise ValueError(f'method must be "union" or "intersection", got {method!r}')

    # Read in the data and keep only the box coordinates
    data = pd.read_csv(file_path, delimiter='\t')
    X = data[columns].astype(float)

    # AgglomerativeClustering requires at least two samples. Zero or one box is
    # trivially a single (possibly empty) cluster.
    if len(X) < 2:
        labels = np.zeros(len(X), dtype=int)
    else:
        # With the default (ward) linkage, distance_threshold is the linkage
        # distance at or above which clusters will not be merged.
        labels = AgglomerativeClustering(
            n_clusters=None, distance_threshold=threshold
        ).fit(X).labels_

    # Collapse every cluster into one box in a single pass. groupby sorts by
    # label, so rows come out in cluster-label order 0..k-1.
    combined = (
        X.groupby(labels, sort=True)
        .agg(dict(zip(columns, reducers[method])))
        .reset_index(drop=True)
    )
    # NOTE: for "intersection", a cluster whose boxes do not all overlap yields a
    # degenerate box (begin > end and/or low > high); such rows are kept as-is.
    # TODO: optionally write 'intersection' and 'union' results to subfolders.

    # Save the combined boxes to a file
    os.makedirs(output_dir, exist_ok=True)
    prefix = os.path.basename(file_path).split("_")[0]
    combined.to_csv(
        os.path.join(output_dir, f'{prefix}_{method}_box_combo.txt'),
        sep='\t',
        index=False,
    )
    return combined

## Test Cases

In [17]:
# Example: merge the boxes of one predictions file by intersection;
# the returned frame is displayed as the cell's output.
file_path = 'predictions/6805.230206163827_predictions.txt'
intersection_combo = box_combo(file_path, method="intersection")
intersection_combo

FileExistsError: [Errno 17] File exists: 'intersection'

In [51]:
# Same file, merged by union; the returned frame is displayed.
union_combo = box_combo(file_path, method="union")
union_combo

Cluster 0:
      Begin Time (s)  End Time (s)  Low Freq (Hz)  High Freq (Hz)
52             415.3        415.55      13.179572      164.744646
237            415.3        415.55      13.179572      164.744646
423            415.3        415.75      13.179572      164.744646
613            415.3        415.55      13.179572      164.744646
808            415.3        415.55      13.179572      164.744646
997            415.3        415.55      13.179572      164.744646
1187           415.3        415.55      13.179572      164.744646
1547           415.3        415.55      13.179572      164.744646
1735           415.3        415.55      13.179572      164.744646


Cluster 1:
      Begin Time (s)  End Time (s)  Low Freq (Hz)  High Freq (Hz)
110          1303.60        1303.8      13.179572      164.744646
293          1303.65        1303.8      13.179572      164.744646
483          1303.60        1303.8      13.179572      164.744646
681          1303.70        1303.8      13.179572   

Unnamed: 0,Begin Time (s),End Time (s),Low Freq (Hz),High Freq (Hz)
0,415.30,415.75,13.179572,164.744646
1,1303.60,1303.80,13.179572,164.744646
2,1314.55,1314.75,13.179572,151.565074
3,1532.25,1532.40,19.769357,164.744646
4,972.60,972.75,59.308072,210.873147
...,...,...,...,...
448,1669.90,1670.00,13.179572,158.154860
449,1148.15,1148.20,32.948929,144.975288
450,375.05,375.10,13.179572,164.744646
451,724.10,724.20,6.589786,151.565074
