In [None]:
##generate ROIS from .wav Snippets

In [None]:
#generate snippets from predictions


In [None]:
#ROI generation 

#run this block to generate a CSV of ROIs (regions of interest) for your .wav file snippets
#an example test dataset of snippets can be downloaded in the supplementary materials
#if running this on your own data, we suggest placing your snippets in the following structure: /audio/sites/yoursitename/snippets
#to populate all rows of the csv you will need your .wav files named in the following convention prefix_filename_score0.0000 where score is the confidence score assigned by your classifier. 
#eg. POWL_20200423_175404_385.00-390.00_score0.4516.wav
#if you want to skip this step you can use our example ROI .csv file in the following step

# Set input and output paths here.
# Both values below are placeholders: replace them with real folder paths before running this cell.
INPUT_PATH = r"path_to_your_snippets"  # Folder containing the .wav snippets - you can use our example_snippets folder
OUTPUT_PATH = r"path_to_your_snippets"   # Folder where the output CSV and PNG files are saved - the CSV is needed in the next step

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from maad import sound, features as maad_features, rois
from maad.util import power2dB, plot2d, format_features, overlay_rois, rand_cmap
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN

# Create the output directory (and any missing parents) if it doesn't exist yet.
# exist_ok=True makes re-running the cell harmless when the folder is already there.
os.makedirs(OUTPUT_PATH, exist_ok=True)

def process_file(file_path, bg_smooth_coef=0.1, smooth_std=0.8, bin_std=1.5, bin_per=0.2, min_roi=30):
    """Extract ROI shape and centroid features from a single audio file.

    The file is loaded, low-pass filtered (fcut=1000), converted to a dB
    spectrogram, background-subtracted, smoothed and thresholded into a mask
    from which regions of interest (ROIs) are selected.  Shape and centroid
    features are then computed on the dB spectrogram for each ROI.

    Args:
        file_path: Path of the audio file to analyse.
        bg_smooth_coef: ``smooth_coef`` passed to ``sound.remove_background``.
        smooth_std: ``std`` passed to ``sound.smooth``.
        bin_std: ``bin_std`` passed to ``rois.create_mask``.
        bin_per: ``bin_per`` passed to ``rois.create_mask``.
        min_roi: ``min_roi`` passed to ``rois.select_rois``.

    Returns:
        ``(features_df, Sxx_db, ext, df_rois)`` on success, or
        ``(None, None, None, None)`` when no ROI is found, a feature table is
        empty, or any exception is raised (the error and traceback are
        printed rather than propagated).
    """
    failure = (None, None, None, None)
    try:
        signal, sample_rate = sound.load(file_path)
        low_passed = sound.select_bandwidth(signal, sample_rate, fcut=1000, forder=3, ftype='lowpass')

        # Shift the dB spectrogram by db_max after limiting it with db_range.
        db_max = 70
        Sxx, tn, fn, ext = sound.spectrogram(low_passed, sample_rate, nperseg=2048, noverlap=1024)
        Sxx_db = power2dB(Sxx, db_range=db_max) + db_max

        # Background removal -> smoothing -> binarisation -> ROI selection.
        no_background, _, _ = sound.remove_background(Sxx_db, smooth_coef=bg_smooth_coef)
        smoothed = sound.smooth(no_background, std=smooth_std)
        mask = rois.create_mask(im=smoothed, mode_bin='relative', bin_std=bin_std, bin_per=bin_per)
        _, df_rois = rois.select_rois(mask, min_roi=min_roi, max_roi=None)

        if df_rois.empty:
            print(f"No ROIs found in file {file_path}")
            return failure

        df_rois = format_features(df_rois, tn, fn)

        df_shape, _ = maad_features.shape_features(Sxx_db, resolution='low', rois=df_rois)
        df_centroid = maad_features.centroid_features(Sxx_db, df_rois)

        if df_shape.empty or df_centroid.empty:
            print(f"Empty shape or centroid features for file {file_path}")
            return failure

        # Look up the frequency at each rounded centroid row and express it
        # as a fraction of the last value of the frequency axis.
        centroid_rows = np.round(df_centroid.centroid_y).astype(int)
        df_centroid['centroid_freq'] = fn[centroid_rows] / fn[-1]

        features_df = df_shape.join(df_centroid, lsuffix='_shape', rsuffix='_centroid')
        return features_df, Sxx_db, ext, df_rois
    except Exception as e:
        # Best-effort batch processing: report the failure and let the caller skip this file.
        print(f"Error processing file {file_path}: {str(e)}")
        import traceback
        traceback.print_exc()
        return failure

def main():
    """Extract ROI features from every snippet, cluster them and save the results.

    Reads all ``.wav`` files (any letter case) from ``INPUT_PATH`` in sorted
    order, runs ``process_file`` on each, and writes to ``OUTPUT_PATH``:

    * ``all_features.csv`` - the concatenated per-ROI features,
    * ``all_features_with_clusters.csv`` - the same table plus a ``cluster`` column,
    * ``clustering_results.png`` - a scatter plot of the t-SNE embedding,
    * ``<file>_clustered_spectrogram.png`` for the first few processed files.

    If no features are extracted nothing is written.  If fewer than three
    ROIs are found, only ``all_features.csv`` is written and the clustering
    step is skipped.
    """
    folder_path = INPUT_PATH
    output_path = OUTPUT_PATH

    print(f"Input path: {folder_path}")
    print(f"Output path: {output_path}")

    # Only this many spectrograms are plotted at the end, so only this many
    # are kept in memory while looping over the files.
    num_spectrograms = 10

    all_features = []
    file_data = []  # (filename, Sxx_db, ext, df_rois) for spectrogram generation

    # Sorted for a reproducible processing order; lower() so that ".WAV" is accepted too.
    wav_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.wav'))
    print(f"Found {len(wav_files)} .wav files in the folder.")

    for filename in wav_files:
        file_path = os.path.join(folder_path, filename)
        print(f"Processing file: {filename}")

        features, Sxx_db, ext, df_rois = process_file(
            file_path,
            bg_smooth_coef=0.4,
            smooth_std=0.8,
            bin_std=1.5,
            bin_per=0.2,
            min_roi=20,
        )

        if features is not None:
            print(f"Successfully processed {filename}. Shape: {features.shape}")
            features['file_name'] = filename
            all_features.append(features)
            if len(file_data) < num_spectrograms:
                file_data.append((filename, Sxx_db, ext, df_rois))
        else:
            print(f"Skipping file {filename} due to processing error.")

        print("---")

    print(f"Successfully processed {len(all_features)} out of {len(wav_files)} files.")

    if not all_features:
        print("No features were extracted. Please check your input files and processing function.")
        return

    combined_features = pd.concat(all_features, ignore_index=True)
    features_csv_path = os.path.join(output_path, 'all_features.csv')
    combined_features.to_csv(features_csv_path, index=False)
    print(f"Saved features to {features_csv_path}. Shape: {combined_features.shape}")

    # t-SNE requires perplexity < number of samples; with a handful of ROIs
    # the fixed perplexity of 12 would raise, so clamp it (or skip entirely).
    n_rois = len(combined_features)
    if n_rois < 3:
        print(f"Only {n_rois} ROI(s) found; skipping t-SNE/DBSCAN clustering.")
        return
    perplexity = min(12, n_rois - 1)

    # Prepare data for t-SNE and clustering: shape descriptors plus centroid frequency
    X = combined_features.loc[:, combined_features.columns.str.startswith('shp')]
    X = X.join(combined_features.centroid_freq)

    # Perform t-SNE
    tsne = TSNE(n_components=2, perplexity=perplexity, init='pca', verbose=True)
    Y = tsne.fit_transform(X)

    # Perform DBSCAN clustering on the 2-D embedding
    cluster = DBSCAN(eps=5, min_samples=4).fit(Y)
    n_labels = np.unique(cluster.labels_).size
    print('Number of soundtypes found:', n_labels)

    # Add cluster labels to the combined features
    combined_features['cluster'] = cluster.labels_

    # Save the final result with cluster labels
    clustered_csv_path = os.path.join(output_path, 'all_features_with_clusters.csv')
    combined_features.to_csv(clustered_csv_path, index=False)
    print(f"Saved clustered features to {clustered_csv_path}")

    # Visualize clustering results
    plt.figure(figsize=(12, 8))
    scatter = plt.scatter(
        Y[:, 0], Y[:, 1],
        c=cluster.labels_,
        cmap=rand_cmap(n_labels, first_color_black=False),
        alpha=0.8,
    )
    plt.xlabel('t-SNE dimension 1')
    plt.ylabel('t-SNE dimension 2')
    plt.title('Clustering Results for All Files')
    plt.colorbar(scatter)
    clustering_plot_path = os.path.join(output_path, 'clustering_results.png')
    plt.savefig(clustering_plot_path)
    plt.close()
    print(f"Saved clustering visualization to {clustering_plot_path}")

    # Generate spectrograms with labeled ROIs for the first few files
    for filename, Sxx_db, ext, df_rois in file_data:
        # Get cluster labels for this file's ROIs
        file_features = combined_features[combined_features['file_name'] == filename]
        df_rois['label'] = file_features['cluster'].astype(str).values

        fig, ax = plt.subplots(figsize=(10, 6))
        plot2d(Sxx_db, ax=ax, extent=ext, vmin=0, vmax=70)
        overlay_rois(Sxx_db, ax=ax, rois=df_rois, ext=ext)
        plt.title(f"Spectrogram with Clustered ROIs: {filename}")
        plt.tight_layout()
        spectrogram_path = os.path.join(output_path, f"{filename}_clustered_spectrogram.png")
        plt.savefig(spectrogram_path)
        plt.close(fig)
        print(f"Saved clustered spectrogram to {spectrogram_path}")

# Run the full ROI extraction / clustering pipeline when this cell (or script) is executed directly.
if __name__ == "__main__":
    main()

In [None]:
## cluster gui

#run the code block below to open a simple GUI that will cluster your ROIs and allow for inspection 
#if you want to colour the clusters by other variables, you can wrangle data into your csv file as needed - eg site_name or ground_truth
#we have included a sample csv (verified_rois) in the supplementary materials which you can use to test this feature
#make sure to run the window full screen, otherwise some buttons may be missing; this will hopefully be fixed in a future update
#if you don't want to use the gui, a code just to generate the plots is available in the following cell 
#you do not need to change any paths in the code, just run the gui and then use the dropdowns to select your csv in the 'data' section, and the folder where the audio snippets were in the 'audio section' 
#you can also tweak the UMAP and HDBSCAN parameters. Example parameters are available in the paper.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Use non-interactive Agg backend
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from sklearn.preprocessing import StandardScaler
import umap
import hdbscan
from maad import sound
from maad.util import power2dB, plot2d
import threading
import queue
import soundfile as sf
import sounddevice as sd

class ClusterAnalysisGUI:
    def __init__(self, master):
        self.master = master
        self.master.title("Owl Call Cluster Analysis Tool")
        self.master.geometry("1400x900")
        self.master.config(bg="#f0f0f0")
        
        # Data storage
        self.df = None
        self.base_dir = None
        self.clusterer = None
        self.reducer = None
        self.X_scaled = None
        self.embedding = None
        self.clusters = None
        self.progress_queue = queue.Queue()
        self.column_categories = {'numeric': [], 'categorical': []}
        
        # Feature groups
        self.shape_features = [
            'min_y_shape', 'min_x_shape', 'max_y_shape', 'max_x_shape',
            'min_f_shape', 'min_t_shape', 'max_f_shape', 'max_t_shape'
        ]
        self.shp_features = [f'shp_{i:03d}' for i in range(1, 17)]
        self.centroid_features = [
            'min_y_centroid', 'min_x_centroid', 'max_y_centroid', 'max_x_centroid',
            'min_f_centroid', 'min_t_centroid', 'max_f_centroid', 'max_t_centroid',
            'centroid_y', 'centroid_x'
        ]
        self.summary_features = [
            'duration_x', 'bandwidth_y', 'area_xy', 'centroid_freq'
        ]
        self.delta_features = [
            'delta_time', 'delta_freq', 'distance_between'
        ]
        
        # Create main layout
        self.create_main_layout()
        
        # Start monitoring the progress queue
        self.master.after(100, self.check_progress_queue)
    
    def create_main_layout(self):
        """Create the main layout for the application.

        The window is split vertically into a top area (control panel on the
        left, cluster results on the right) and a bottom visualization area,
        with a status bar docked along the bottom edge.
        """
        # Pack the status bar FIRST.  The pack geometry manager hands out space
        # in packing order, so a widget packed after the expanding paned window
        # only gets what is left over - which is nothing once the paned
        # window's requested height fills the window.  Docking the status bar
        # before the paned window keeps it (and the progress bar) visible.
        self.create_status_bar()

        # Main paned window - vertical split
        main_paned = ttk.PanedWindow(self.master, orient=tk.VERTICAL)
        main_paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Create top and bottom frames
        top_frame = ttk.Frame(main_paned)
        bottom_frame = ttk.Frame(main_paned)

        # Add frames to paned window (bottom gets twice the weight of the top)
        main_paned.add(top_frame, weight=1)
        main_paned.add(bottom_frame, weight=2)

        # Top section - horizontal split
        top_paned = ttk.PanedWindow(top_frame, orient=tk.HORIZONTAL)
        top_paned.pack(fill=tk.BOTH, expand=True)

        # Control and results frames
        control_frame = ttk.LabelFrame(top_paned, text="Control Panel")
        results_frame = ttk.LabelFrame(top_paned, text="Cluster Results")

        # Add frames to the horizontal pane
        top_paned.add(control_frame, weight=2)
        top_paned.add(results_frame, weight=1)

        # Populate the three content areas
        self.create_control_panel(control_frame)
        self.create_results_panel(results_frame)
        self.create_viz_panel(bottom_frame)
    
    def create_control_panel(self, parent):
        """Build the control panel: path pickers, action buttons and settings tabs.

        Args:
            parent: Container widget that receives the panel's two columns.
        """
        for column in (0, 1):
            parent.columnconfigure(column, weight=1)

        # Left column holds the file pickers and the main action buttons.
        left_column = ttk.Frame(parent)
        left_column.grid(row=0, column=0, sticky="nw", padx=5, pady=5)

        # Right column holds the settings notebook.
        right_column = ttk.Frame(parent)
        right_column.grid(row=0, column=1, sticky="nsew", padx=5, pady=5)

        # --- Path selection -------------------------------------------------
        paths_box = ttk.LabelFrame(left_column, text="Data Files")
        paths_box.pack(fill=tk.X, padx=5, pady=5)

        self.file_path_var = tk.StringVar()
        self.base_dir_var = tk.StringVar()
        path_rows = (
            ("Data File:", self.file_path_var, self.browse_file),
            ("Audio Directory:", self.base_dir_var, self.browse_dir),
        )
        for row, (caption, variable, on_browse) in enumerate(path_rows):
            ttk.Label(paths_box, text=caption).grid(row=row, column=0, sticky="w", pady=2, padx=5)
            ttk.Entry(paths_box, textvariable=variable, width=30).grid(row=row, column=1, sticky="ew", pady=2)
            ttk.Button(paths_box, text="Browse...", command=on_browse).grid(
                row=row, column=2, sticky="e", pady=2, padx=5)

        # --- Main action buttons ---------------------------------------------
        buttons_box = ttk.LabelFrame(left_column, text="Main Controls")
        buttons_box.pack(fill=tk.X, padx=5, pady=5)

        actions = (
            ("1. Load Data", self.load_data),
            ("2. Run Clustering", self.run_clustering),
            ("3. Update UMAP Plot", self.update_umap_plot),
        )
        for column, (caption, callback) in enumerate(actions):
            ttk.Button(buttons_box, text=caption, command=callback, width=15).grid(
                row=0, column=column, padx=5, pady=10)

        # --- Settings tabs -----------------------------------------------------
        tabs = ttk.Notebook(right_column)
        tabs.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        self.create_umap_settings_tab(tabs)
        self.create_hdbscan_settings_tab(tabs)
        self.create_feature_settings_tab(tabs)
    
    def create_umap_settings_tab(self, notebook):
        """Add the "UMAP Settings" tab and create the variables behind its inputs.

        Args:
            notebook: The ``ttk.Notebook`` that receives the new tab.
        """
        umap_tab = ttk.Frame(notebook, padding=10)
        notebook.add(umap_tab, text="UMAP Settings")

        def place_label(caption, row, column):
            ttk.Label(umap_tab, text=caption).grid(row=row, column=column, sticky="w", pady=2)

        def place_spinbox(variable, row, column, **bounds):
            ttk.Spinbox(umap_tab, textvariable=variable, width=5, **bounds).grid(
                row=row, column=column, pady=2, padx=5)

        # Row 0: n_neighbors | min_dist
        place_label("n_neighbors:", 0, 0)
        self.n_neighbors_var = tk.IntVar(value=15)
        place_spinbox(self.n_neighbors_var, 0, 1, from_=2, to=100)

        place_label("min_dist:", 0, 2)
        self.min_dist_var = tk.DoubleVar(value=0.1)
        place_spinbox(self.min_dist_var, 0, 3, from_=0.0, to=1.0, increment=0.05)

        # Row 1: n_components | metric
        place_label("n_components:", 1, 0)
        self.n_components_var = tk.IntVar(value=2)
        place_spinbox(self.n_components_var, 1, 1, from_=2, to=5)

        place_label("metric:", 1, 2)
        self.metric_var = tk.StringVar(value="euclidean")
        metric_choices = ["euclidean", "manhattan", "chebyshev", "minkowski", "cosine", "correlation"]
        metric_box = ttk.Combobox(umap_tab, textvariable=self.metric_var, values=metric_choices, width=10)
        metric_box.grid(row=1, column=3, pady=2, padx=5)

        # Row 2: spread | local_connectivity
        place_label("spread:", 2, 0)
        self.spread_var = tk.DoubleVar(value=1.0)
        place_spinbox(self.spread_var, 2, 1, from_=0.1, to=5.0, increment=0.1)

        place_label("local_connectivity:", 2, 2)
        self.local_connectivity_var = tk.DoubleVar(value=1.0)
        place_spinbox(self.local_connectivity_var, 2, 3, from_=0.5, to=2.0, increment=0.1)
    
    def create_hdbscan_settings_tab(self, notebook):
        """Add the "HDBSCAN Settings" tab and create the variables behind its inputs.

        Args:
            notebook: The ``ttk.Notebook`` that receives the new tab.
        """
        hdbscan_tab = ttk.Frame(notebook, padding=10)
        notebook.add(hdbscan_tab, text="HDBSCAN Settings")

        def place_label(caption, row, column):
            ttk.Label(hdbscan_tab, text=caption).grid(row=row, column=column, sticky="w", pady=2)

        def place_spinbox(variable, row, column, **bounds):
            ttk.Spinbox(hdbscan_tab, textvariable=variable, width=5, **bounds).grid(
                row=row, column=column, pady=2, padx=5)

        def place_combobox(variable, choices, row, column):
            ttk.Combobox(hdbscan_tab, textvariable=variable, values=choices, width=10).grid(
                row=row, column=column, pady=2, padx=5)

        # Row 0: min_cluster_size | min_samples
        place_label("min_cluster_size:", 0, 0)
        self.min_cluster_size_var = tk.IntVar(value=15)
        place_spinbox(self.min_cluster_size_var, 0, 1, from_=2, to=100)

        place_label("min_samples:", 0, 2)
        self.min_samples_var = tk.IntVar(value=5)
        place_spinbox(self.min_samples_var, 0, 3, from_=1, to=50)

        # Row 1: cluster_selection_epsilon | alpha
        place_label("cluster_selection_epsilon:", 1, 0)
        self.epsilon_var = tk.DoubleVar(value=0.0)
        place_spinbox(self.epsilon_var, 1, 1, from_=0.0, to=1.0, increment=0.05)

        place_label("alpha:", 1, 2)
        self.alpha_var = tk.DoubleVar(value=1.0)
        place_spinbox(self.alpha_var, 1, 3, from_=0.1, to=2.0, increment=0.1)

        # Row 2: cluster_selection_method | metric
        place_label("cluster_selection_method:", 2, 0)
        self.selection_method_var = tk.StringVar(value="eom")
        place_combobox(self.selection_method_var, ["eom", "leaf"], 2, 1)

        place_label("metric:", 2, 2)
        self.hdbscan_metric_var = tk.StringVar(value="euclidean")
        place_combobox(
            self.hdbscan_metric_var,
            ["euclidean", "manhattan", "chebyshev", "minkowski", "cosine", "correlation"],
            2, 3,
        )
    
    def create_feature_settings_tab(self, notebook):
        """Add the "Feature Selection" tab: feature-group toggles and colour selector.

        Args:
            notebook: The ``ttk.Notebook`` that receives the new tab.
        """
        features_tab = ttk.Frame(notebook, padding=10)
        notebook.add(features_tab, text="Feature Selection")

        ttk.Label(features_tab, text="Include Feature Groups:").grid(
            row=0, column=0, columnspan=2, sticky="w", pady=5)

        # (attribute name, checkbox caption, initially ticked?) - laid out two
        # per row, starting on grid row 1.
        toggles = (
            ("use_shape_var", "Basic Shape Features (min/max)", True),
            ("use_shp_var", "Shape Descriptors (shp_001-016)", True),
            ("use_centroid_var", "Centroid Features", True),
            ("use_summary_var", "Summary Metrics (duration, etc.)", True),
            ("use_pairs_var", "Note Pair Features", False),
            ("use_delta_var", "Delta Measurements", False),
        )
        for position, (attr_name, caption, ticked) in enumerate(toggles):
            flag = tk.BooleanVar(value=ticked)
            setattr(self, attr_name, flag)
            grid_row, grid_column = divmod(position, 2)
            ttk.Checkbutton(features_tab, text=caption, variable=flag).grid(
                row=grid_row + 1, column=grid_column, sticky="w", padx=10)

        # Colour selector; only 'cluster' is offered until data has been loaded.
        ttk.Label(features_tab, text="Color UMAP Plot By:").grid(
            row=4, column=0, sticky="w", pady=(15, 2), padx=10)
        self.color_by_var = tk.StringVar(value="cluster")
        self.color_by_combo = ttk.Combobox(
            features_tab, textvariable=self.color_by_var, width=20, state="readonly")
        self.color_by_combo.grid(row=4, column=1, sticky="w", pady=(15, 2), padx=10)
        self.color_by_combo['values'] = ['cluster']
    
    def create_results_panel(self, parent):
        """Build the results panel: cluster picker, sample count and statistics box.

        Args:
            parent: Container widget that receives the panel's grid.
        """
        # Let the statistics text area (row 3, column 0) absorb extra space.
        parent.grid_rowconfigure(3, weight=1)
        parent.grid_columnconfigure(0, weight=1)

        # Row 0: cluster selection (disabled at creation time).
        ttk.Label(parent, text="Select Cluster:").grid(row=0, column=0, sticky="w", pady=5)
        self.cluster_var = tk.StringVar()
        self.cluster_dropdown = ttk.Combobox(
            parent, textvariable=self.cluster_var, state="disabled", width=10)
        self.cluster_dropdown.grid(row=0, column=1, pady=5, padx=5)
        self.cluster_dropdown.bind("<<ComboboxSelected>>", self.on_cluster_selected)

        # Row 1: how many samples to show, plus a refresh button.
        ttk.Label(parent, text="Samples to Display:").grid(row=1, column=0, sticky="w", pady=5)
        self.sample_size_var = tk.IntVar(value=20)
        sample_spinner = ttk.Spinbox(parent, from_=1, to=100, textvariable=self.sample_size_var, width=5)
        sample_spinner.grid(row=1, column=1, pady=5, padx=5)
        refresh_button = ttk.Button(parent, text="Update Samples", command=self.update_samples)
        refresh_button.grid(row=1, column=2, pady=5, padx=5)

        # Rows 2-3: read-only statistics text.
        ttk.Label(parent, text="Cluster Statistics:").grid(row=2, column=0, sticky="w", pady=5)
        self.stats_text = tk.Text(parent, width=40, height=15, wrap=tk.WORD)
        self.stats_text.grid(row=3, column=0, columnspan=3, pady=5, sticky="nsew")
        self.stats_text.config(state=tk.DISABLED)
    
    def create_viz_panel(self, parent):
        """Build the visualization area with "UMAP Plot" and "Spectrograms" tabs.

        Args:
            parent: Container widget that receives the visualization frame.
        """
        outer = ttk.LabelFrame(parent, text="Visualizations")
        outer.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        self.viz_tabs = ttk.Notebook(outer)
        self.viz_tabs.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # ---- Tab 1: UMAP plot ------------------------------------------------
        self.umap_tab = ttk.Frame(self.viz_tabs)
        self.viz_tabs.add(self.umap_tab, text="UMAP Plot")

        # Toolbar row above the plot.
        toolbar = ttk.Frame(self.umap_tab)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        ttk.Button(toolbar, text="Save Plot as PNG", command=self.save_umap_plot, width=15).pack(
            side=tk.RIGHT, padx=5, pady=2)

        # Matplotlib figure embedded in the tab.
        self.umap_fig = Figure(figsize=(10, 10), dpi=100)
        self.umap_canvas = FigureCanvasTkAgg(self.umap_fig, self.umap_tab)
        self.umap_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # ---- Tab 2: spectrograms ---------------------------------------------
        self.spectro_tab = ttk.Frame(self.viz_tabs)
        self.viz_tabs.add(self.spectro_tab, text="Spectrograms")

        # A canvas hosting an inner frame, scrolled vertically.
        self.spectro_canvas = tk.Canvas(self.spectro_tab)
        y_scroll = ttk.Scrollbar(self.spectro_tab, orient="vertical", command=self.spectro_canvas.yview)
        self.spectro_frame = ttk.Frame(self.spectro_canvas)

        self.spectro_canvas.configure(yscrollcommand=y_scroll.set)

        y_scroll.pack(side="right", fill="y")
        self.spectro_canvas.pack(side="left", fill="both", expand=True)
        self.spectro_canvas.create_window((0, 0), window=self.spectro_frame, anchor="nw")

        def _refresh_scrollregion(_event):
            # Keep the scrollable area in sync with the inner frame's size.
            self.spectro_canvas.configure(scrollregion=self.spectro_canvas.bbox("all"))

        self.spectro_frame.bind("<Configure>", _refresh_scrollregion)
    
    def create_status_bar(self):
        """Build the bottom status bar: a status message and a progress bar."""
        bar = ttk.Frame(self.master)
        bar.pack(side=tk.BOTTOM, fill=tk.X)

        # Status text on the left, stretching across the free width.
        self.status_var = tk.StringVar(value="Ready")
        message = ttk.Label(bar, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W)
        message.pack(side=tk.LEFT, fill=tk.X, expand=True)

        # Determinate progress bar on the right, driven by progress_var.
        self.progress_var = tk.IntVar(value=0)
        self.progress_bar = ttk.Progressbar(
            bar,
            variable=self.progress_var,
            mode='determinate',
            length=200,
        )
        self.progress_bar.pack(side=tk.RIGHT, padx=10)
    
    def browse_file(self):
        """Ask the user for a CSV file and store its path in ``file_path_var``.

        Nothing is changed when the dialog returns an empty selection.
        """
        csv_filter = ("CSV files", "*.csv")
        any_filter = ("All files", "*.*")
        chosen = filedialog.askopenfilename(
            title="Select CSV file with ROI data",
            filetypes=(csv_filter, any_filter),
        )
        if not chosen:
            return
        self.file_path_var.set(chosen)
    
    def browse_dir(self):
        """Ask the user for the audio directory and store it in ``base_dir_var``.

        Nothing is changed when the dialog returns an empty selection.
        """
        chosen = filedialog.askdirectory(title="Select base directory for audio files")
        if not chosen:
            return
        self.base_dir_var.set(chosen)
    
    def load_data(self):
        """Load data from CSV without running clustering."""
        file_path = self.file_path_var.get()
        self.base_dir = self.base_dir_var.get()
        
        if not file_path or not os.path.exists(file_path):
            messagebox.showerror("Error", "Please select a valid CSV file.")
            return
        
        if not self.base_dir or not os.path.exists(self.base_dir):
            messagebox.showerror("Error", "Please select a valid audio directory.")
            return
        
        # Start loading in a separate thread
        self.status_var.set("Loading data...")
        self.progress_var.set(0)
        threading.Thread(target=self._load_data_thread, args=(file_path,), daemon=True).start()
    
    def _load_data_thread(self, file_path):
        """Thread function for data loading."""
        try:
            self.progress_queue.put(("status", "Loading data file..."))
            self.df = pd.read_csv(file_path)
            
            self.progress_queue.put(("status", "Analyzing columns..."))
            self.progress_queue.put(("progress", 50))
            
            self._categorize_columns()
            
            self.progress_queue.put(("update_columns", None))
            self.progress_queue.put(("progress", 100))
            self.progress_queue.put(("status", f"Data loaded: {len(self.df)} rows, {len(self.df.columns)} columns"))
            
        except Exception as e:
            self.progress_queue.put(("error", f"Error loading data: {str(e)}"))
    
    def _categorize_columns(self):
        """Categorize columns as numeric or categorical for coloring options."""
        self.column_categories = {
            'numeric': [],
            'categorical': []
        }
        
        for col in self.df.columns:
            # Example logic: 'file_name' and 'site_name' -> categorical
            if col in ['file_name', 'site_name']:
                self.column_categories['categorical'].append(col)
            elif self.df[col].dtype in [np.int64, np.float64]:
                # If it has < 20 unique numeric values, treat as categorical
                if len(self.df[col].unique()) < 20:
                    self.column_categories['categorical'].append(col)
                else:
                    self.column_categories['numeric'].append(col)
            else:
                self.column_categories['categorical'].append(col)
    
    def update_columns_in_ui(self):
        """Update the color_by dropdown with available columns."""
        if self.df is not None:
            all_options = []
            
            # Always include 'cluster' if we have cluster labels
            if hasattr(self, 'clusters') and self.clusters is not None:
                all_options.append('cluster')
            
            # Add everything from numeric/categorical
            all_options += self.column_categories['categorical'] + self.column_categories['numeric']
            
            self.color_by_combo['values'] = all_options
            
            # Default selection
            if 'cluster' in all_options:
                self.color_by_var.set('cluster')
            elif len(all_options) > 0:
                self.color_by_var.set(all_options[0])
    
    def get_selected_features(self):
        """Get selected feature columns based on UI selections."""
        selected_features = []
        
        if self.use_shape_var.get():
            selected_features.extend(self.shape_features)
        
        if self.use_shp_var.get():
            selected_features.extend(self.shp_features)
        
        if self.use_centroid_var.get():
            selected_features.extend(self.centroid_features)
        
        if self.use_summary_var.get():
            selected_features.extend(self.summary_features)
        
        # Add note2 features if selected
        if self.use_pairs_var.get():
            base_features = []
            if self.use_shp_var.get():
                base_features.extend(self.shp_features)
            if self.use_centroid_var.get():
                # Exclude min_/max_ from centroid so we only get 'centroid_y/x' basically
                base_features.extend(
                    [f for f in self.centroid_features if not (f.startswith('min_') or f.startswith('max_'))]
                )
            if self.use_summary_var.get():
                base_features.extend(self.summary_features)
            
            # Add note2_ prefix
            note2_features = [f'note2_{col}' for col in base_features]
            selected_features.extend(note2_features)
        
        # Add delta features if selected
        if self.use_delta_var.get():
            selected_features.extend(self.delta_features)
        
        if self.df is not None:
            # Filter to only features present in DataFrame
            available_features = [col for col in selected_features if col in self.df.columns]
            missing = set(selected_features) - set(available_features)
            if missing:
                print(f"Warning: Some selected features are not in the data: {missing}")
            return available_features
        
        return selected_features
    
    def run_clustering(self):
        """Run UMAP and HDBSCAN clustering."""
        file_path = self.file_path_var.get()
        self.base_dir = self.base_dir_var.get()
        
        if not file_path or not os.path.exists(file_path):
            messagebox.showerror("Error", "Please select a valid CSV file.")
            return
        
        if not self.base_dir or not os.path.exists(self.base_dir):
            messagebox.showerror("Error", "Please select a valid audio directory.")
            return
        
        # Start in a separate thread
        self.status_var.set("Loading data...")
        self.progress_var.set(0)
        
        threading.Thread(target=self.clustering_process, args=(file_path,), daemon=True).start()
    
    def clustering_process(self, file_path):
        """Run the load -> scale -> UMAP -> HDBSCAN pipeline and report via the queue.

        Started by ``run_clustering`` in a separate thread. All feedback
        (status text, progress values, errors and the final UI refresh request)
        is posted to ``self.progress_queue`` instead of touching widgets.

        Args:
            file_path: Path of the feature CSV. It is read only when no
                DataFrame is loaded yet; its directory also receives the
                ``clustered_data_gui.csv`` output file.

        Side effects:
            Sets ``self.X_scaled``, ``self.reducer``, ``self.embedding``,
            ``self.clusterer`` and ``self.clusters``, adds a ``cluster`` column
            to ``self.df`` and writes the DataFrame to disk.
        """
        # NOTE(review): the Tk variables below are read from this worker
        # thread; Tkinter is generally not thread-safe, so consider collecting
        # the parameter values in run_clustering instead - TODO confirm.
        try:
            # Step 1: Load data if not already loaded
            if self.df is None:
                self.progress_queue.put(("status", "Loading data..."))
                self.df = pd.read_csv(file_path)
                self._categorize_columns()

            # Step 2: Get selected features
            self.progress_queue.put(("status", "Preparing features..."))
            self.progress_queue.put(("progress", 10))
            feature_columns = self.get_selected_features()

            if not feature_columns:
                self.progress_queue.put(("error", "No features selected. Please select at least one feature group."))
                return

            missing_columns = [col for col in feature_columns if col not in self.df.columns]
            if missing_columns:
                self.progress_queue.put(("error", f"Missing columns in dataset: {missing_columns}"))
                return

            # Step 3: Prepare data (missing values become 0, then standardize)
            self.progress_queue.put(("status", "Normalizing data..."))
            self.progress_queue.put(("progress", 20))

            X = self.df[feature_columns].fillna(0)

            scaler = StandardScaler()
            self.X_scaled = scaler.fit_transform(X)

            # Step 4: UMAP
            self.progress_queue.put(("status", "Running UMAP..."))
            self.progress_queue.put(("progress", 30))

            umap_params = {
                'n_neighbors': self.n_neighbors_var.get(),
                'min_dist': self.min_dist_var.get(),
                'n_components': self.n_components_var.get(),
                'metric': self.metric_var.get(),
                'spread': self.spread_var.get(),
                'local_connectivity': self.local_connectivity_var.get(),
                'random_state': 42
            }

            self.reducer = umap.UMAP(**umap_params)
            self.embedding = self.reducer.fit_transform(self.X_scaled)

            # Step 5: HDBSCAN on the low-dimensional embedding
            self.progress_queue.put(("status", "Running HDBSCAN..."))
            self.progress_queue.put(("progress", 60))

            hdbscan_params = {
                'min_cluster_size': self.min_cluster_size_var.get(),
                'min_samples': self.min_samples_var.get(),
                'cluster_selection_epsilon': self.epsilon_var.get(),
                'alpha': self.alpha_var.get(),
                'cluster_selection_method': self.selection_method_var.get(),
                'metric': self.hdbscan_metric_var.get(),
                'gen_min_span_tree': True
            }

            self.clusterer = hdbscan.HDBSCAN(**hdbscan_params)
            self.clusters = self.clusterer.fit_predict(self.embedding)

            # Step 6: Update DataFrame and persist it next to the input CSV
            self.progress_queue.put(("status", "Processing results..."))
            self.progress_queue.put(("progress", 80))

            self.df['cluster'] = self.clusters
            out_file = os.path.join(os.path.dirname(file_path), "clustered_data_gui.csv")
            self.df.to_csv(out_file, index=False)

            # Step 7: UI updates. Label -1 marks noise points, so it must not
            # be counted as a cluster in the summary.
            labels = np.unique(self.clusters)
            n_clusters = int(np.sum(labels != -1))
            n_noise = int(np.sum(self.clusters == -1))

            self.progress_queue.put(("update_ui", None))
            self.progress_queue.put(("progress", 100))
            self.progress_queue.put((
                "status",
                f"Clustering complete. {n_clusters} clusters found ({n_noise} noise points)."
            ))

        except Exception as e:
            # Top-level boundary of the worker thread: surface any failure in the UI
            self.progress_queue.put(("error", f"Error during clustering: {str(e)}"))
    
    def check_progress_queue(self):
        """Drain pending worker messages, apply them to the UI, then reschedule.

        Each queue item is a ``(message, data)`` pair. Recognized messages are
        ``status``, ``progress``, ``error``, ``update_ui``, ``update_columns``
        and ``add_spectrogram``; anything else is ignored.

        The next poll is scheduled in a ``finally`` clause so that an exception
        raised by one handler cannot stop polling. The exception still
        propagates, and messages left in the queue are handled on the next poll.
        """
        try:
            while True:
                try:
                    message, data = self.progress_queue.get_nowait()
                except queue.Empty:
                    break

                if message == "status":
                    self.status_var.set(data)
                elif message == "progress":
                    self.progress_var.set(data)
                elif message == "error":
                    messagebox.showerror("Error", data)
                    self.status_var.set("Ready")
                elif message == "update_ui":
                    self.update_ui_after_clustering()
                elif message == "update_columns":
                    self.update_columns_in_ui()
                elif message == "add_spectrogram":
                    self.add_spectrogram_to_ui(*data)
        finally:
            # Poll again in 100 ms, whatever happened above
            self.master.after(100, self.check_progress_queue)
    
    def update_ui_after_clustering(self):
        """Refresh the cluster dropdown, column choices, plot and stats panel."""
        labels = sorted(np.unique(self.clusters))

        # Offer every cluster label in the dropdown and pre-select the first
        dropdown = self.cluster_dropdown
        dropdown['values'] = labels
        dropdown['state'] = 'readonly'
        if labels:
            dropdown.current(0)

        # Column choices (so 'cluster' can be used for coloring), then the
        # UMAP plot, then the statistics text - in that order
        self.update_columns_in_ui()
        self.update_umap_plot()
        self.update_cluster_stats()
    
    def update_umap_plot(self):
        """Redraw the UMAP scatter plot, colored by the column chosen in the UI.

        Coloring modes (taken from ``self.color_by_var``):
          * ``'cluster'`` - the labels in ``self.clusters``, with a colorbar.
          * a column listed in ``self.column_categories['numeric']`` - the raw
            values, with a colorbar.
          * any other DataFrame column - treated as categorical: every distinct
            value gets an integer code and its own legend entry.

        An information dialog is shown, and the current plot is left untouched,
        when there is no embedding, no clustering result, or the requested
        column is not available.
        """
        if self.embedding is None:
            messagebox.showinfo("Information", "Please run clustering first or load precomputed data.")
            return

        color_by = self.color_by_var.get()
        plot_title = 'UMAP Embedding'
        legend_handles = None
        show_colorbar = False

        # Resolve the coloring first so that a failed lookup does not blank
        # the figure that is currently displayed.
        if color_by == 'cluster':
            if self.clusters is None:
                messagebox.showinfo("Information", "Clusters not available. Run clustering first.")
                return
            color_data = self.clusters
            cmap = 'viridis'
            plot_title = 'UMAP Embedding with HDBSCAN Clustering'
            show_colorbar = True
        elif self.df is not None and color_by in self.df.columns:
            color_data = self.df[color_by].values
            if color_by in self.column_categories['numeric']:
                cmap = 'plasma'
                plot_title = f'UMAP Embedding Colored by {color_by} (Numeric)'
                show_colorbar = True
            else:
                # For categorical data, map each distinct value to an integer code
                unique_vals = sorted(self.df[color_by].unique())
                val_map = {val: i for i, val in enumerate(unique_vals)}
                color_data = np.array([val_map[v] for v in color_data])

                cmap = 'tab20' if len(unique_vals) <= 20 else 'tab20b'
                plot_title = f'UMAP Embedding Colored by {color_by} (Categorical)'

                # Legend colors use the same code -> [0, 1] scaling the scatter
                # applies. plt.get_cmap replaces plt.cm.get_cmap, which newer
                # Matplotlib releases no longer provide.
                palette = plt.get_cmap(cmap)
                span = len(unique_vals) - 1 if len(unique_vals) > 1 else 1
                legend_handles = [
                    plt.Line2D(
                        [0], [0], marker='o', color='w',
                        markerfacecolor=palette(idx / span), markersize=8, label=str(val)
                    )
                    for val, idx in val_map.items()
                ]
        else:
            messagebox.showinfo("Information", f"Column {color_by} not found in data.")
            return

        self.umap_fig.clear()
        ax = self.umap_fig.add_subplot(111)
        ax.set_facecolor('#f8f8f8')

        # Scatter
        scatter = ax.scatter(
            self.embedding[:, 0], self.embedding[:, 1],
            c=color_data, cmap=cmap, alpha=0.8, s=15, edgecolors='none'
        )

        if legend_handles is not None:
            ax.legend(handles=legend_handles, title=color_by, bbox_to_anchor=(1.05, 1), loc='upper left')

        # For numeric data or cluster, add a colorbar. It is attached through
        # the embedded figure itself rather than pyplot's "current figure",
        # which may be a different figure.
        if show_colorbar:
            cbar = self.umap_fig.colorbar(scatter, ax=ax)
            cbar.ax.tick_params(labelsize=9)

        ax.grid(True, linestyle='--', alpha=0.3, color='#cccccc')
        ax.set_title(plot_title, fontsize=16, pad=15)
        ax.set_xlabel('UMAP 1', fontsize=12)
        ax.set_ylabel('UMAP 2', fontsize=12)
        ax.tick_params(axis='both', which='major', labelsize=9)
        ax.set_aspect('equal')

        for spine in ax.spines.values():
            spine.set_edgecolor('#aaaaaa')
            spine.set_linewidth(0.5)

        self.umap_fig.tight_layout()
        self.umap_canvas.draw()
    
    def save_umap_plot(self):
        """Ask for a destination and write the current UMAP figure to it.

        Shows an information dialog when there is no embedding yet. Nothing is
        written when the save dialog returns no path. A failure while saving is
        reported in an error dialog.
        """
        if self.embedding is None:
            messagebox.showinfo("Information", "No UMAP plot available to save.")
            return

        png_filter = ("PNG files", "*.png")
        any_filter = ("All files", "*.*")
        target = filedialog.asksaveasfilename(
            title="Save UMAP Plot",
            defaultextension=".png",
            filetypes=(png_filter, any_filter)
        )
        if not target:
            return

        try:
            self.umap_fig.savefig(target, dpi=300, bbox_inches='tight')
            self.status_var.set(f"Plot saved to {target}")
        except Exception as e:
            messagebox.showerror("Error", f"Could not save plot: {str(e)}")
    
    def update_cluster_stats(self):
        """Refresh the statistics panel for the cluster picked in the dropdown.

        Does nothing when no cluster is selected. Otherwise the panel lists the
        cluster's ROI count and its share of all rows, a per-site breakdown
        when a ``site_name`` column exists, and the mean of a few key features
        that are present in the data.
        """
        choice = self.cluster_var.get()
        if not choice:
            return

        label = int(choice)
        members = self.df[self.df['cluster'] == label]
        n_members = len(members)

        lines = [
            f"Cluster: {label}",
            f"Number of ROIs: {n_members}",
            f"Percentage: {n_members / len(self.df) * 100:.2f}%"
        ]

        if 'site_name' in self.df.columns:
            per_site = members['site_name'].value_counts()
            lines.append("\nSites:")
            lines.extend(
                f"  {site}: {count} ({count / n_members*100:.1f}%)"
                for site, count in per_site.items()
            )

        lines.append("\nAverage Feature Values:")
        lines.extend(
            f"  {feature}: {members[feature].mean():.4f}"
            for feature in ('duration_x', 'bandwidth_y', 'centroid_freq', 'area_xy')
            if feature in members.columns
        )

        # The text widget is switched to NORMAL for the rewrite and back to
        # DISABLED afterwards
        panel = self.stats_text
        panel.config(state=tk.NORMAL)
        panel.delete(1.0, tk.END)
        panel.insert(tk.END, "\n".join(lines))
        panel.config(state=tk.DISABLED)
    
    def on_cluster_selected(self, event):
        """React to a new dropdown choice: refresh the stats, then the samples.

        ``event`` is accepted for the callback signature but not used.
        """
        self.update_cluster_stats()  # statistics panel first
        self.update_samples()        # then the spectrogram samples
    
    def update_samples(self):
        """Clear the spectrogram area and reload samples for the chosen cluster.

        Does nothing when no cluster is selected. The samples themselves are
        gathered by ``load_spectrograms`` on a daemon thread.
        """
        choice = self.cluster_var.get()
        if not choice:
            return

        how_many = self.sample_size_var.get()
        cluster_id = int(choice)

        # Remove every widget currently shown in the spectrogram frame
        for child in self.spectro_frame.winfo_children():
            child.destroy()

        loader = threading.Thread(
            target=self.load_spectrograms,
            args=(cluster_id, how_many),
            daemon=True
        )
        loader.start()
    
    def load_spectrograms(self, cluster, sample_size):
        """Resolve audio files for a cluster's samples and queue them for display.

        Rows of ``self.df`` whose ``cluster`` equals ``cluster`` are used; when
        there are more than ``sample_size`` of them, a fixed-seed random sample
        of that size is taken. For each row the audio file is looked up under
        ``<base_dir>/<site_name>/snippets``, ``<base_dir>/snippets`` and
        ``<base_dir>`` (first existing path wins, ``None`` when none exists) and
        an ``add_spectrogram`` message is queued. Any exception is reported as
        an ``error`` message.
        """
        post = self.progress_queue.put
        try:
            post(("status", f"Loading samples for cluster {cluster}..."))
            post(("progress", 0))

            members = self.df[self.df['cluster'] == cluster]
            if len(members) > sample_size:
                members = members.sample(sample_size, random_state=42)

            n_total = len(members)

            for position, (_, row) in enumerate(members.iterrows(), start=1):
                post(("progress", int(position / n_total * 100)))

                file_name = row['file_name']
                site_name = row.get('site_name', 'unknown')

                # Candidate locations, most specific first
                candidates = [
                    os.path.join(self.base_dir, site_name, 'snippets', file_name),
                    os.path.join(self.base_dir, 'snippets', file_name),
                    os.path.join(self.base_dir, file_name)
                ]
                resolved = next((p for p in candidates if os.path.exists(p)), None)

                # Queued even when the file is missing (resolved is None then)
                post(("add_spectrogram", (resolved, row, position - 1)))

            post(("status", f"Loaded {n_total} samples for cluster {cluster}"))
            post(("progress", 100))

        except Exception as e:
            post(("error", f"Error loading spectrograms: {str(e)}"))
    
    def add_spectrogram_to_ui(self, file_path, row, index):
        """
        Add one spectrogram cell (labels, figure, button) to the spectrogram frame.
        - file_path: resolved path of the audio file, or None if it was not found.
        - row: data row for the sample; its file_name and site_name are shown.
        - index: position in the sample list; cells are laid out two per grid row.
        """
        grid_row, grid_col = divmod(index, 2)
        cell = ttk.Frame(self.spectro_frame, padding=5)
        cell.grid(row=grid_row, column=grid_col, padx=5, pady=5, sticky="nsew")

        # Two caption lines above the figure
        captions = (
            f"File: {row.get('file_name', 'Unknown File')}",
            f"Site: {row.get('site_name', 'Unknown Site')}",
        )
        for line_no, caption in enumerate(captions):
            ttk.Label(cell, text=caption).grid(row=line_no, column=0, sticky="w")

        # Embed the spectrogram figure
        canvas = FigureCanvasTkAgg(self.create_spectrogram(file_path, row), cell)
        canvas.get_tk_widget().grid(row=2, column=0)

        # A working play button when the file exists, otherwise a disabled one
        if file_path and os.path.exists(file_path):
            button = ttk.Button(cell, text="Play Audio",
                                command=lambda fp=file_path: self.play_audio(fp))
        else:
            button = ttk.Button(cell, text="File Not Found", state="disabled")
        button.grid(row=3, column=0, pady=5)
    
    def create_spectrogram(self, file_path, row):
        """Build a Figure showing the file's spectrogram with its ROI outlined.

        Returns a 4x3 inch figure. When ``file_path`` is None the figure only
        contains an "Audio file not found" message; when anything raises, the
        figure contains the error text instead. The first note is outlined in
        yellow when its four ``*_shape`` bounds are present in ``row``. With
        ``self.use_pairs_var`` enabled, the second note is outlined in red when
        its ``note2_*`` bounds are present and its start time is not missing.
        """
        def message_figure(text, **text_options):
            # Blank figure carrying only a centered message
            msg_fig = Figure(figsize=(4, 3), dpi=100)
            msg_ax = msg_fig.add_subplot(111)
            msg_ax.text(0.5, 0.5, text, ha='center', va='center', **text_options)
            msg_ax.axis('off')
            return msg_fig

        def outline(keys, colour):
            # Unfilled rectangle spanning (min_t, min_f) .. (max_t, max_f)
            t_min, f_min, t_max, f_max = (row[k] for k in keys)
            return plt.Rectangle(
                (t_min, f_min), t_max - t_min, f_max - f_min,
                fill=False, edgecolor=colour, linewidth=2
            )

        try:
            if file_path is None:
                return message_figure("Audio file not found", fontsize=10)

            # Load the audio and compute its spectrogram
            s, fs = sound.load(file_path)
            Sxx, tn, fn, ext = sound.spectrogram(s, fs, nperseg=1024, noverlap=512)
            # NOTE(review): the +96 offset presumably shifts the dB values into
            # the 0-70 display range used below - TODO confirm
            Sxx_db = power2dB(Sxx) + 96

            fig = Figure(figsize=(4, 3), dpi=100)
            ax = fig.add_subplot(111)
            plot2d(Sxx_db, ax=ax, extent=ext, vmin=0, vmax=70)

            # First note
            note1_keys = ['min_t_shape', 'min_f_shape', 'max_t_shape', 'max_f_shape']
            if all(k in row for k in note1_keys):
                ax.add_patch(outline(note1_keys, 'yellow'))

            # Second note, only when pair mode is enabled
            if self.use_pairs_var.get():
                note2_keys = ['note2_min_t_shape', 'note2_min_f_shape', 'note2_max_t_shape', 'note2_max_f_shape']
                if all(k in row for k in note2_keys) and pd.notna(row['note2_min_t_shape']):
                    ax.add_patch(outline(note2_keys, 'red'))

            ax.set_title(f"Cluster {row.get('cluster', 'Unknown')}", fontsize=10)
            fig.tight_layout()
            return fig

        except Exception as e:
            return message_figure(f"Error: {str(e)}", wrap=True)
    
    def play_audio(self, file_path):
        """Read the audio file and start playback, reporting it in the status bar.

        The file is read with ``sf.read`` and played with ``sd.play``. Any
        failure is shown in an error dialog.
        """
        try:
            samples, rate = sf.read(file_path)
            sd.play(samples, rate)
            clip_name = os.path.basename(file_path)
            self.status_var.set(f"Playing audio: {clip_name}")
        except Exception as e:
            messagebox.showerror("Error", f"Cannot play audio: {str(e)}")

def start_gui():
    """Create the Tk root window, attach the clustering GUI and run the event loop."""
    root = tk.Tk()
    app = ClusterAnalysisGUI(root)  # bound to a name for the lifetime of the loop

    # Closing the window destroys the root widget
    root.protocol("WM_DELETE_WINDOW", root.destroy)
    root.mainloop()


if __name__ == "__main__":
    start_gui()


In [None]:
#create plots without gui

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import hdbscan
import matplotlib.pyplot as plt
import seaborn as sns
import umap
from tqdm import tqdm
from matplotlib.colors import ListedColormap
import sys

# Tunable parameters: edit these to try other configurations
UMAP_PARAMS = dict(
    n_neighbors=20,
    min_dist=0.1,
    n_components=2,
)

HDBSCAN_PARAMS = dict(
    min_cluster_size=20,
    min_samples=20,  # Set to None if you want to use the default
)

# Build a discrete colormap and a value -> integer code lookup for a column
def create_color_map(values):
    """Return ``(colormap, value_to_int)`` for the distinct values of a column.

    Values are compared as strings so that columns with mixed types can be
    handled. Colors are taken from ``tab20`` at evenly spaced positions, one
    per distinct value, and ``value_to_int`` maps each string value to its
    position in the sorted list of distinct values.
    """
    labels = np.unique(values.astype(str))
    palette = plt.cm.tab20(np.linspace(0, 1, len(labels)))
    codes = dict(zip(labels, range(len(labels))))
    return ListedColormap(palette), codes

# Load the CSV file; the script cannot continue without it
csv_path = r"C:\Users\calla\Dropbox\2024\powerfulowl\paper_important_csv_and_figs\clustered_data_custom_allroiverified.csv"
try:
    df = pd.read_csv(csv_path)
except Exception as e:
    print(f"Error loading CSV file: {e}")
    print("Please check the file path and ensure the file exists and is accessible.")
    sys.exit(1)
else:
    print(f"Successfully loaded CSV file. Shape: {df.shape}")

# Select relevant columns for clustering, grouped by kind
shape_bounds = [
    'min_y_shape', 'min_x_shape', 'max_y_shape', 'max_x_shape',
    'min_f_shape', 'min_t_shape', 'max_f_shape', 'max_t_shape',
]
shape_coefficients = [f'shp_{i:03d}' for i in range(1, 17)]  # shp_001 .. shp_016
centroid_bounds = [
    'min_y_centroid', 'min_x_centroid', 'max_y_centroid', 'max_x_centroid',
    'min_f_centroid', 'min_t_centroid', 'max_f_centroid', 'max_t_centroid',
]
summary_columns = [
    'centroid_y', 'centroid_x', 'duration_x', 'bandwidth_y', 'area_xy', 'centroid_freq',
]
relevant_columns = shape_bounds + shape_coefficients + centroid_bounds + summary_columns

# Stop early when the CSV lacks any of the required columns
present_columns = set(df.columns)
missing_columns = [col for col in relevant_columns if col not in present_columns]
if missing_columns:
    print(f"Error: The following columns are missing from the DataFrame: {missing_columns}")
    print("Available columns:", df.columns.tolist())
    sys.exit(1)

# Prepare the data for clustering.
# Missing values are replaced with 0 before scaling, matching what the GUI
# pipeline does, so that NaNs do not reach UMAP.
X = df[relevant_columns].copy()
n_missing_values = int(X.isna().sum().sum())
if n_missing_values:
    print(f"Warning: filling {n_missing_values} missing feature values with 0 before scaling.")
    X = X.fillna(0)

# Normalize the data (zero mean, unit variance per feature)
scaler = StandardScaler()
X_normalized = scaler.fit_transform(X)

# Apply UMAP (fixed random_state for reproducible embeddings)
print("Applying UMAP...")
reducer = umap.UMAP(**UMAP_PARAMS, random_state=42)
embedding = reducer.fit_transform(X_normalized)

# Apply HDBSCAN on the UMAP embedding
print("Applying HDBSCAN...")
clusterer = hdbscan.HDBSCAN(**HDBSCAN_PARAMS)
clusters = clusterer.fit_predict(embedding)

# Print cluster distribution information (label -1 marks noise points)
unique_clusters = np.unique(clusters)
print(f"\nFound {len(unique_clusters)} unique cluster labels: {unique_clusters}")
print(f"Number of noise points (cluster -1): {np.sum(clusters == -1)}")
cluster_counts = pd.Series(clusters).value_counts().sort_index()
print("Cluster distribution:")
print(cluster_counts)

# Let the user pick which column drives the point colors
print("\nAvailable columns for coloring:")
print(df.columns.tolist())
color_column = input("Enter the name of the column to use for coloring: ")

if color_column not in df.columns:
    print(f"Error: Column '{color_column}' not found in the DataFrame.")
    sys.exit(1)

# Colormap plus one integer code per row, looked up by the value's string form
color_map, value_to_int = create_color_map(df[color_column])
color_values = [value_to_int[label] for label in map(str, df[color_column])]

# Visualize results
print("Generating enhanced visualization...")

# Set the style for the plot. Newer Matplotlib releases name the bundled
# seaborn styles 'seaborn-v0_8-*', so use whichever variant this
# installation provides instead of failing on an unknown style name.
for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break
sns.set_style("white")
sns.set_context("paper", font_scale=1.4)  # Increased font scale for better readability

# Create the figure and axis objects
fig, ax = plt.subplots(figsize=(18, 12))  # Increased figure size for better visibility

# Create the scatter plot
scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=color_values, cmap=color_map,
                     alpha=0.8, s=40, edgecolors='none')  # Slightly larger markers

# Customize the plot
ax.set_title(f'UMAP Embedding Colored by {color_column}', fontsize=20, fontweight='bold')
ax.set_xlabel('UMAP 1', fontsize=16)
ax.set_ylabel('UMAP 2', fontsize=16)
ax.tick_params(axis='both', which='major', labelsize=14)  # Larger tick labels

# Add a colorbar legend with larger text
cbar = plt.colorbar(scatter, ax=ax, aspect=40, pad=0.02)
cbar.set_label(color_column, fontsize=16, fontweight='bold')
cbar.ax.tick_params(labelsize=14)

# If the column is categorical or has mixed types, label the colorbar with the
# original values instead of their integer codes
if df[color_column].dtype == 'object' or df[color_column].nunique() < 10 or df[color_column].dtype == 'float':
    cbar.set_ticks(list(value_to_int.values()))
    cbar.set_ticklabels(list(value_to_int.keys()))

# Add a text box with clustering parameters (larger font)
params_text = f"UMAP: n_neighbors={UMAP_PARAMS['n_neighbors']}, min_dist={UMAP_PARAMS['min_dist']}\n" \
              f"HDBSCAN: min_cluster_size={HDBSCAN_PARAMS['min_cluster_size']}, " \
              f"min_samples={HDBSCAN_PARAMS['min_samples']}"
ax.text(0.05, 0.95, params_text, transform=ax.transAxes, fontsize=14,
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                                           pad=0.6))

# Add grid lines for better readability
ax.grid(True, linestyle='--', alpha=0.7)

# Improve the layout
plt.tight_layout()

# Save the figure with higher resolution; the message reports the file that
# was actually written
label_figure_path = 'cluster2_label.png'
fig.savefig(label_figure_path, dpi=400, bbox_inches='tight')
print(f"Figure saved as '{label_figure_path}'")

# Now create a second visualization showing the HDBSCAN clusters

# Create a custom colormap that makes noise points grey. Labels are sorted, so
# when noise (-1) is present it is the smallest label and takes the first color.
# plt.get_cmap is used because newer Matplotlib no longer has plt.cm.get_cmap.
cluster_labels = np.unique(clusters)
cluster_cmap = plt.get_cmap('tab20', len(cluster_labels))
cluster_colors = cluster_cmap(np.linspace(0, 1, len(cluster_labels)))
if -1 in cluster_labels:
    cluster_colors[0] = [0.7, 0.7, 0.7, 1.0]  # Grey for noise points (-1)
custom_cmap = ListedColormap(cluster_colors)

# Plot with custom colormap (a single figure; no throwaway duplicate)
plt.figure(figsize=(18, 12))
scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
                      c=clusters, cmap=custom_cmap,
                      alpha=0.8, s=40, edgecolors='none')

plt.title('UMAP Embedding with HDBSCAN Clusters', fontsize=20, fontweight='bold')
plt.xlabel('UMAP 1', fontsize=16)
plt.ylabel('UMAP 2', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=14)
plt.grid(True, linestyle='--', alpha=0.7)

cbar = plt.colorbar(scatter, aspect=40, pad=0.02)
cbar.set_label('Cluster', fontsize=16, fontweight='bold')
cbar.ax.tick_params(labelsize=14)

# Add a text box with clustering parameters
plt.text(0.05, 0.95, params_text, transform=plt.gca().transAxes, fontsize=14,
         verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, pad=0.6))

plt.tight_layout()

# The message reports the file that was actually written
cluster_figure_path = 'cluster2_hdbscan.png'
plt.savefig(cluster_figure_path, dpi=400, bbox_inches='tight')
print(f"Figure saved as '{cluster_figure_path}'")

# Show the plot (only if running interactively)
plt.show()

# For every real cluster (noise label -1 is skipped), print the five features
# whose cluster mean lies furthest above the overall mean, relative to it
print("\nAnalyzing top features for each cluster...")
overall_mean = X.mean()
real_clusters = sorted(label for label in np.unique(clusters) if label != -1)
for cluster in tqdm(real_clusters):
    cluster_data = X[clusters == cluster]
    if len(cluster_data) == 0:
        continue  # nothing to summarize
    cluster_mean = cluster_data.mean()
    relative_shift = (cluster_mean - overall_mean) / overall_mean
    # Infinite or undefined ratios (e.g. a zero overall mean) count as 0
    feature_importance = relative_shift.replace([np.inf, -np.inf], np.nan).fillna(0)
    top_features = feature_importance.nlargest(5)
    print(f"\nCluster {cluster} (size: {len(cluster_data)}):")
    print(top_features)

# Attach the raw HDBSCAN labels to the dataframe
df['HDBSCAN_cluster'] = clusters

# Readable label per row: 'Noise' for -1, otherwise 'Cluster <n>'
df['cluster_label'] = [
    'Noise' if label == -1 else f'Cluster {label}'
    for label in df['HDBSCAN_cluster']
]

# Write the labelled table to disk
output_filename = 'clustered_data_custom.csv'
df.to_csv(output_filename, index=False)
print(f"\nResults saved to '{output_filename}'")

# Row count per readable label
cluster_summary = df['cluster_label'].value_counts().sort_index()
print("\nCluster summary:")
print(cluster_summary)

print("\nClustering completed. You can modify the UMAP_PARAMS and HDBSCAN_PARAMS at the top of the script to try different configurations.")