In [1]:
import os

# Run from the repository root so the relative "data/..." paths used in later cells resolve.
# The root is resolved only once per kernel, so re-running this cell does not climb another
# three directories (the original relative chdir was not idempotent).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", "..", ".."))
os.chdir(PROJECT_ROOT)
print(os.getcwd())

/Users/titonka/FAIRIS


In [2]:
import logging
import math
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans, Birch
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

In [3]:
# Scatter plot of cluster assignments over the x, y plane
def plot_clusters(xy_list, cluster_labels, name):
    """
    Draw every datapoint at its (x, y) position, coloured by its cluster label, and write the figure to disk.

    Args:
    - xy_list (list of tuples): (x, y) position of each datapoint.
    - cluster_labels (list of int): Cluster label of each datapoint; must match xy_list in length.
    - name (str): Output filename handed to plt.savefig (e.g., "clusters_plot.png").

    Raises:
    - ValueError: If xy_list and cluster_labels have different lengths.
    """
    xs = np.array([px for px, _ in xy_list])
    ys = np.array([py for _, py in xy_list])

    point_count = len(xs)
    label_count = len(cluster_labels)
    if point_count != label_count:
        raise ValueError(f"Mismatch: {point_count} coordinates and {label_count} cluster labels.")

    plt.figure(figsize=(8, 6))
    points = plt.scatter(xs, ys, c=cluster_labels, cmap='rainbow', alpha=0.7)
    # Colour bar maps colours back to cluster ids.
    plt.colorbar(points, label='Cluster')

    distinct_clusters = len(set(cluster_labels))
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    plt.title(f'Clustering with {distinct_clusters} Clusters on X-Y Plane')

    plt.savefig(name)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close()


def plot_clusters_by_subplots(xy_list, theta_list, cluster_labels, name, n_clusters=8,
                              cols=5, panel_size=6, arrow_length=0.5, xlim=(-3, 3), ylim=(-3, 3)):
    """
    Plot each cluster in its own subplot and save the figure. Each subplot shows the member
    datapoints at their (x, y) positions plus one arrow per datapoint pointing along its heading.

    Args:
    - xy_list (list of tuples): The original x, y coordinates for each datapoint.
    - theta_list (list of floats): The orientation (theta) in degrees for each datapoint.
    - cluster_labels (list or numpy array of int): The cluster label for each datapoint.
    - name (str): The filename to save the plot.
    - n_clusters (int): Number of clusters to plot. Panels are drawn for ids 0..n_clusters-1;
      datapoints with other labels are not shown.
    - cols (int): Number of subplot columns. Default 5.
    - panel_size (float): Width and height of each subplot, in inches. Default 6.
    - arrow_length (float): Length of each heading arrow, in data units. Default 0.5.
    - xlim, ylim (tuple of float): Axis limits applied to every subplot. Default (-3, 3).

    Raises:
    - ValueError: If xy_list, theta_list and cluster_labels do not all have the same length.
    """
    # np.asarray makes the element-wise comparison below work for plain lists too.
    # `list == int` is a single False, which would select no points at all.
    labels = np.asarray(cluster_labels)
    x_coords = np.array([x for x, y in xy_list])
    y_coords = np.array([y for x, y in xy_list])
    thetas = np.asarray(theta_list, dtype=float)

    if not (len(x_coords) == len(thetas) == len(labels)):
        raise ValueError(
            f"Mismatch: {len(x_coords)} coordinates, {len(thetas)} thetas and {len(labels)} cluster labels."
        )

    # At least one row so plt.subplots never receives 0 rows.
    rows = max(1, int(math.ceil(n_clusters / cols)))
    figsize = (cols * panel_size, rows * panel_size)

    # squeeze=False always yields a 2-D array of Axes (even for a 1x1 grid), so ravel() is safe.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    # Heading vectors for every point, computed once; theta is in degrees.
    dx_all = arrow_length * np.cos(np.radians(thetas))
    dy_all = arrow_length * np.sin(np.radians(thetas))

    for cluster_id in range(n_clusters):
        ax = axes[cluster_id]
        in_cluster = labels == cluster_id

        cluster_x = x_coords[in_cluster]
        cluster_y = y_coords[in_cluster]

        ax.scatter(cluster_x, cluster_y, c=f'C{cluster_id}', alpha=0.7, label='Points')
        ax.quiver(cluster_x, cluster_y, dx_all[in_cluster], dy_all[in_cluster],
                  angles='xy', scale_units='xy', scale=1,
                  color='black', alpha=0.7, label='Vectors')

        ax.set_title(f'Cluster {cluster_id}')
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend()

    # Hide grid cells left over when n_clusters is not a multiple of cols.
    for unused_ax in axes[n_clusters:]:
        unused_ax.axis('off')

    fig.tight_layout()
    fig.savefig(name)
    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)

def format_data_for_clustering(data):
    """
    Split ``data.observations`` into the parallel lists used by the clustering and plotting helpers.

    Args:
    - data: Object exposing an ``observations`` iterable. Each observation must provide
      ``multimodal_feature_vector``, ``cnn_feature_vector``, ``x``, ``y`` and ``theta``.

    Returns:
    - tuple: (multimodal_feature_vectors, cnn_feature_vectors, xy_list, theta_list).
      These are four lists in observation order; xy_list holds (x, y) tuples.
    """
    # Materialize once so a one-shot iterator is consumed a single time.
    observations = list(data.observations)

    multimodal = [obs.multimodal_feature_vector for obs in observations]
    cnn = [obs.cnn_feature_vector for obs in observations]
    positions = [(obs.x, obs.y) for obs in observations]
    headings = [obs.theta for obs in observations]

    return multimodal, cnn, positions, headings


def cluster_with_kmeans_and_save_centers(features_list, n_clusters, centers_save_path, random_state=42):
    """
    Perform KMeans clustering, compute each cluster's maximum point-to-center distance,
    and pickle the centers and max distances as a list of [center, max_distance] pairs.

    Args:
    - features_list (list of numpy arrays): The feature vectors extracted from the images.
    - n_clusters (int): The number of clusters to form.
    - centers_save_path (str): Path the pickled [[center_as_list, max_distance], ...] list is written to.
    - random_state (int): Seed passed to KMeans. Default 42.

    Returns:
    - cluster_labels (numpy array of int): The cluster label for each datapoint.
    """
    features_array = np.array(features_list)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    cluster_labels = kmeans.fit_predict(features_array)
    cluster_centers = kmeans.cluster_centers_

    # One [center, max_distance] entry per cluster, in cluster-id order.
    cluster_data = []
    for cluster_index in range(n_clusters):
        cluster_points = features_array[cluster_labels == cluster_index]
        center = cluster_centers[cluster_index]

        if len(cluster_points) == 0:
            # KMeans can return fewer distinct clusters than requested (e.g. duplicate points).
            # max() of an empty array would raise, so record a zero radius.
            max_distance = 0.0
        else:
            distances = cdist(cluster_points, [center], metric='euclidean').ravel()
            max_distance = float(distances.max())

        cluster_data.append([center.tolist(), max_distance])

    with open(centers_save_path, 'wb') as f:
        pickle.dump(cluster_data, f)

    return cluster_labels

In [5]:
# Configure the root logger so the INFO-level progress messages from the clustering helpers are
# printed with a timestamp. (basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

def cluster_with_gmm_and_save_centers(features_list, n_clusters, centers_save_path, use_pca=False, n_components=100,
                                      max_iter=50, random_state=42):
    """
    Fit a diagonal-covariance Gaussian Mixture Model (GMM) and assign each datapoint to its most
    likely component. Then pickle one [mean, max_distance] pair per component.
    Optionally reduces dimensionality with PCA before fitting.

    Args:
        features_list (list of numpy arrays): The feature vectors extracted from the images.
        n_clusters (int): The number of mixture components (clusters) to fit.
        centers_save_path (str): Path the pickled [[mean_as_list, max_distance], ...] list is written to.
        use_pca (bool): Whether to apply PCA for dimensionality reduction. Default is False.
        n_components (int): Number of components to keep if PCA is used. Default is 100.
        max_iter (int): Maximum number of EM iterations for the GMM. Default is 50.
        random_state (int): Seed for the GMM initialisation. Default is 42.

    Returns:
        numpy array of int: Component index of each datapoint (highest posterior, from ``gmm.predict``).

    Raises:
        ValueError: If the features contain NaN or infinite values.
        Errors from array conversion, PCA, GMM fitting or saving are logged and re-raised unchanged.

    Note:
        With use_pca=True the saved means are in PCA space, but the fitted PCA model is not saved.
        New observations therefore cannot be projected into that space from the pickle alone.
    """
    start_time = time.perf_counter()
    logging.info("Starting GMM clustering with %d data points and %d clusters", len(features_list), n_clusters)

    try:
        features_array = np.array(features_list, dtype=np.float32)
    except Exception as e:
        logging.error("Error converting features_list to array: %s", e)
        raise

    # Validate outside the conversion try-block so this is not mislabelled as a conversion error.
    if not np.isfinite(features_array).all():
        logging.error("Feature array contains NaN or infinite values")
        raise ValueError("Feature array contains NaN or infinite values")

    logging.info("Feature array shape: %s", features_array.shape)

    if use_pca:
        logging.info("Applying PCA to reduce dimensionality to %d components", n_components)
        try:
            pca = PCA(n_components=n_components)
            features_array = pca.fit_transform(features_array)
            logging.info("PCA reduced shape: %s", features_array.shape)
        except Exception as e:
            logging.error("Error during PCA: %s", e)
            raise

    try:
        gmm = GaussianMixture(
            n_components=n_clusters,
            # Per-feature variances only; a 'full' covariance would need a d x d matrix per component.
            covariance_type='diag',
            max_iter=max_iter,
            random_state=random_state,
            verbose=1,
            verbose_interval=10,
        )
        gmm.fit(features_array)
        cluster_labels = gmm.predict(features_array)
        logging.info("GMM fitting completed in %.2f seconds", time.perf_counter() - start_time)
    except Exception as e:
        logging.error("Error during GMM fitting: %s", e)
        raise

    # Component means play the role of cluster centroids.
    cluster_means = gmm.means_

    cluster_data = []
    for cluster_index in range(n_clusters):
        cluster_points = features_array[cluster_labels == cluster_index]
        mean = cluster_means[cluster_index]

        if len(cluster_points) == 0:
            # A component can end up with no hard assignments; record a zero radius for it.
            max_distance = 0.0
        else:
            try:
                distances = cdist(cluster_points, [mean], metric='euclidean').ravel()
                max_distance = float(distances.max())
            except Exception as e:
                # Best effort: a failure on one cluster should not abort the whole run.
                logging.error("Error calculating distances for cluster %d: %s", cluster_index, e)
                max_distance = 0.0

        cluster_data.append([mean.tolist(), max_distance])
        logging.info("Cluster %d: %d points, max distance %.2f", cluster_index, len(cluster_points), max_distance)

    try:
        with open(centers_save_path, 'wb') as f:
            pickle.dump(cluster_data, f)
        logging.info("Cluster data saved to %s in %.2f seconds", centers_save_path, time.perf_counter() - start_time)
    except Exception as e:
        logging.error("Error saving cluster data: %s", e)
        raise

    return cluster_labels

In [7]:
data_dir = 'data/VisualPlaceCellData/'
maze_files = ['LM4','LM6','LM8','LMO8','LM8_addition','LMO8_remove']
maze_index = 2
maze_name = maze_files[maze_index]

clusters_dir = os.path.join(data_dir, 'VisualPlaceCellClusters')
figures_dir = os.path.join('data', 'figures', 'Clustering')
# Create output directories up front so a long GMM run does not fail only at save time.
os.makedirs(clusters_dir, exist_ok=True)
os.makedirs(figures_dir, exist_ok=True)

# NOTE: pickle.load can execute arbitrary code; only load trusted, locally generated training data.
with open(os.path.join(data_dir, maze_name + '_Training'), 'rb') as training_file:
    visual_place_cell_data = pickle.load(training_file)

multimodal_feature_vectors, cnn_feature_vectors, xy_list, theta_list = format_data_for_clustering(visual_place_cell_data)

# Cluster counts to sweep (a list, unlike the int `n_clusters` parameter of the helpers).
n_clusters = [100, 200, 400]
for n_cluster in n_clusters:
    centers_save_path = os.path.join(clusters_dir, f"multimodal_gmm_{n_cluster}_{maze_name}")
    cluster_labels = cluster_with_gmm_and_save_centers(multimodal_feature_vectors, n_cluster, centers_save_path)

    # Maze name in the filename so figures for different mazes do not overwrite each other.
    figure_path = os.path.join(figures_dir, f"multimodal_gmm_{n_cluster}clusters_{maze_name}.png")
    plot_clusters_by_subplots(xy_list, theta_list, cluster_labels, figure_path, n_clusters=n_cluster)

    # To cluster the CNN-only features instead, pass cnn_feature_vectors with a "cnn_gmm_" prefix.


2025-07-16 13:29:12,460 - INFO - Starting GMM clustering with 8000 data points and 100 clusters
2025-07-16 13:29:12,668 - INFO - Feature array shape: (8000, 7374)


Initialization 0
  Iteration 10
  Iteration 20
Initialization converged.


2025-07-16 13:29:51,291 - INFO - GMM fitting completed in 38.83 seconds
2025-07-16 13:29:51,298 - INFO - Cluster 0: 140 points, max distance 2778.69
2025-07-16 13:29:51,306 - INFO - Cluster 1: 111 points, max distance 3234.50
2025-07-16 13:29:51,309 - INFO - Cluster 2: 88 points, max distance 1976.21
2025-07-16 13:29:51,316 - INFO - Cluster 3: 175 points, max distance 2692.34
2025-07-16 13:29:51,318 - INFO - Cluster 4: 45 points, max distance 2634.81
2025-07-16 13:29:51,322 - INFO - Cluster 5: 75 points, max distance 1712.18
2025-07-16 13:29:51,328 - INFO - Cluster 6: 99 points, max distance 3113.56
2025-07-16 13:29:51,331 - INFO - Cluster 7: 86 points, max distance 2406.19
2025-07-16 13:29:51,333 - INFO - Cluster 8: 51 points, max distance 2661.21
2025-07-16 13:29:51,340 - INFO - Cluster 9: 124 points, max distance 2306.92
2025-07-16 13:29:51,346 - INFO - Cluster 10: 49 points, max distance 2457.20
2025-07-16 13:29:51,348 - INFO - Cluster 11: 43 points, max distance 2693.22
2025-07-16

Initialization 0
  Iteration 10
Initialization converged.


2025-07-16 13:30:53,477 - INFO - GMM fitting completed in 51.07 seconds
2025-07-16 13:30:53,478 - INFO - Cluster 0: 18 points, max distance 2597.37
2025-07-16 13:30:53,481 - INFO - Cluster 1: 45 points, max distance 2680.78
2025-07-16 13:30:53,483 - INFO - Cluster 2: 71 points, max distance 2688.72
2025-07-16 13:30:53,489 - INFO - Cluster 3: 151 points, max distance 1498.77
2025-07-16 13:30:53,490 - INFO - Cluster 4: 24 points, max distance 2384.49
2025-07-16 13:30:53,494 - INFO - Cluster 5: 60 points, max distance 2410.86
2025-07-16 13:30:53,495 - INFO - Cluster 6: 38 points, max distance 2091.81
2025-07-16 13:30:53,497 - INFO - Cluster 7: 29 points, max distance 2253.07
2025-07-16 13:30:53,500 - INFO - Cluster 8: 24 points, max distance 1999.57
2025-07-16 13:30:53,502 - INFO - Cluster 9: 16 points, max distance 2057.10
2025-07-16 13:30:53,504 - INFO - Cluster 10: 39 points, max distance 2549.67
2025-07-16 13:30:53,506 - INFO - Cluster 11: 32 points, max distance 2455.07
2025-07-16 13

Initialization 0
  Iteration 10
Initialization converged.


2025-07-16 13:32:50,023 - INFO - GMM fitting completed in 93.82 seconds
2025-07-16 13:32:50,025 - INFO - Cluster 0: 9 points, max distance 1742.55
2025-07-16 13:32:50,027 - INFO - Cluster 1: 22 points, max distance 2664.34
2025-07-16 13:32:50,029 - INFO - Cluster 2: 19 points, max distance 2019.70
2025-07-16 13:32:50,032 - INFO - Cluster 3: 73 points, max distance 1248.18
2025-07-16 13:32:50,034 - INFO - Cluster 4: 16 points, max distance 2234.19
2025-07-16 13:32:50,035 - INFO - Cluster 5: 11 points, max distance 2467.72
2025-07-16 13:32:50,037 - INFO - Cluster 6: 34 points, max distance 1947.09
2025-07-16 13:32:50,039 - INFO - Cluster 7: 18 points, max distance 2505.16
2025-07-16 13:32:50,043 - INFO - Cluster 8: 22 points, max distance 2088.19
2025-07-16 13:32:50,044 - INFO - Cluster 9: 12 points, max distance 1999.94
2025-07-16 13:32:50,046 - INFO - Cluster 10: 17 points, max distance 2085.01
2025-07-16 13:32:50,048 - INFO - Cluster 11: 11 points, max distance 2324.17
2025-07-16 13:3