In [None]:
# # Example: Clustering MD Trajectory Data with KMeans

# This notebook demonstrates how to use KMeans clustering to identify distinct conformational states from molecular dynamics simulation data. It uses functions from the `md_analysis_tools` library.

# **Workflow:**
# 1. Import necessary libraries.
# 2. Load pre-processed data suitable for clustering. This could be:
#     *   Coordinates from specific atoms.
#     *   Principal component projections (output from PCA analysis).
#     *   Key distances or dihedral angles.
#     *   *(This example will assume we are using PC projections from the PCA example)*
# 3. (Optional) Determine a suitable number of clusters (k) using the Elbow method.
# 4. Perform KMeans clustering using `perform_kmeans`.
# 5. Visualize the clusters (e.g., on a PCA plot).
# 6. Find representative frames closest to each cluster centroid using `find_closest_frames_to_centroids`.
# 7. (Optional) Save the representative structures.


In [None]:
# Import necessary libraries
import os
import sys  # needed for file=sys.stderr in the error messages below

import matplotlib.pyplot as plt
import MDAnalysis as mda
import numpy as np
import pandas as pd
import seaborn as sns # For potentially nicer plotting
# Import KMeans directly for the elbow-method scan
from sklearn.cluster import KMeans

import md_analysis_tools # Our custom library
# Optional: Yellowbrick's KElbowVisualizer for the elbow plot (`pip install yellowbrick`).
# Opt in with USE_YELLOWBRICK = True. _YELLOWBRICK_AVAILABLE is only True when the
# opt-in is set AND the import succeeds, so the notebook still runs without yellowbrick.
USE_YELLOWBRICK = False
_YELLOWBRICK_AVAILABLE = False
if USE_YELLOWBRICK:
    try:
        from yellowbrick.cluster import KElbowVisualizer
        _YELLOWBRICK_AVAILABLE = True
    except ImportError:
        print("yellowbrick is not installed; only the plain elbow plot will be shown.")

# Configure plotting style (optional).
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'. Try the new
# name first, then the old one, and keep matplotlib's defaults if neither exists
# (plt.style.use raises OSError for an unknown style name).
for _style_name in ('seaborn-v0_8-poster', 'seaborn-poster'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break


In [None]:
# ## 1. Load Data for Clustering

# Clustering is often performed on reduced-dimensionality data rather than raw coordinates to focus on significant motions and reduce noise. A common choice is to use the projections onto the first few principal components (PCs) obtained from PCA.

# **Assumption:** This notebook assumes you have already run PCA (like in `Example_PCA_Analysis.ipynb`) and saved the projections, or you can calculate them here.

# **ACTION:**
#   - Specify the number of PCs to use for clustering.
#   - Provide the path to the PCA projection data *OR* uncomment and adapt the PCA calculation section if you need to run it first.
#   - Define topology/trajectory paths if calculating PCA here *or* if needed later for saving structures.


In [None]:
# --- User Input ---
pca_components_to_use = 3 # How many PCs to use for clustering (e.g., 2 or 3)
projection_file = None    # Set to path if loading pre-calculated projections, e.g., "pca_analysis_output/pca_projections.csv"

# --- OR --- Calculate PCA projections here if needed ---
# If projection_file is None (or the path does not exist), PCA is calculated from the files below.
# Requires topology and trajectory files.
topology_file = "placeholder.prmtop" # <-- Needs path if calculating PCA or saving structures
trajectory_file = "placeholder.dcd"   # <-- Needs path if calculating PCA or saving structures
pca_selection = "name CA and protein" # Selection used for PCA
align_selection = "name CA and protein" # Selection used for alignment (or None)
output_dir = "clustering_analysis_output" # Directory for saving results

# Create output directory
os.makedirs(output_dir, exist_ok=True)


def load_projection_csv(path, n_components):
    """Load the first `n_components` PC columns from a projection CSV file.

    Parameters
    ----------
    path : str
        CSV file whose PC columns are named PC1, PC2, ...
    n_components : int
        Number of leading PC columns to return.

    Returns
    -------
    numpy.ndarray or None
        Float array of shape (n_rows, n_components). Returns None, after printing
        an error to stderr, if a required column is missing or the data contains NaNs.

    Raises
    ------
    Exceptions from `pd.read_csv` (e.g. unreadable file) or from the float
    conversion (non-numeric columns) propagate to the caller.
    """
    df = pd.read_csv(path)
    pc_cols = [f'PC{i+1}' for i in range(n_components)]
    if not all(col in df.columns for col in pc_cols):
        print(f"Error: Projection file exists but doesn't contain required columns {pc_cols}.", file=sys.stderr)
        return None
    values = df[pc_cols].to_numpy(dtype=float)
    # KMeans cannot handle missing values; fail here with a clear message instead.
    if np.isnan(values).any():
        print("Error: Projection file contains missing (NaN) values.", file=sys.stderr)
        return None
    print(f"Using PC columns: {pc_cols}")
    return values


# --- Load/Calculate Data ---
data_for_clustering = None
pca_result_obj = None # To store PCA object if calculated here

if projection_file and not os.path.exists(projection_file):
    # Surface a mistyped path instead of silently falling through to the PCA branch.
    print(f"Warning: projection file not found: {projection_file}. Trying to calculate PCA instead.", file=sys.stderr)

if projection_file and os.path.exists(projection_file):
    print(f"Loading pre-calculated projections from: {projection_file}")
    try:
        data_for_clustering = load_projection_csv(projection_file, pca_components_to_use)
    except Exception as e:
        print(f"Error loading projection file: {e}", file=sys.stderr)

elif topology_file and trajectory_file and os.path.exists(topology_file) and os.path.exists(trajectory_file):
    print("Projection file not found or specified. Calculating PCA projections now...")
    try:
        u = mda.Universe(topology_file, trajectory_file)
        print(f"Universe loaded with {len(u.trajectory)} frames.")

        pca_result_obj = md_analysis_tools.perform_cartesian_pca(
            universe=u,
            select=pca_selection,
            align=(align_selection is not None),
            align_select=align_selection,
            n_components=pca_components_to_use
        )

        if pca_result_obj:
            pca_atoms = u.select_atoms(pca_selection)
            # NOTE(review): presumably an MDAnalysis PCA object. Whether transform() projects
            # aligned or raw coordinates depends on how perform_cartesian_pca handled
            # alignment (e.g. aligned the trajectory in memory or not) - confirm in md_analysis_tools.
            data_for_clustering = pca_result_obj.transform(pca_atoms, n_components=pca_components_to_use)
            print("PCA calculation and transformation complete.")
            # Save projections so later runs can set projection_file and skip the PCA step
            df_projections = pd.DataFrame(data_for_clustering, columns=[f'PC{i+1}' for i in range(pca_components_to_use)])
            df_projections.to_csv(os.path.join(output_dir, "pca_projections_calculated.csv"), index=False)
        else:
            print("PCA calculation failed during data loading.", file=sys.stderr)

    except Exception as e:
        print(f"Error loading Universe or running PCA: {e}", file=sys.stderr)
else:
    print("Error: Cannot proceed without either a projection file or valid topology/trajectory files.", file=sys.stderr)

# --- Verify Data ---
if data_for_clustering is not None:
    print(f"\nData shape for clustering: {data_for_clustering.shape}") # Should be (n_frames, n_components_to_use)
else:
    print("\nClustering cannot proceed due to data loading/calculation errors.", file=sys.stderr)


In [None]:
# ## 2. Determine Optimal Number of Clusters (k) - Optional

# The Elbow method is a common heuristic for choosing the number of clusters. It runs KMeans for a range of `k` values and plots the Sum of Squared Errors (SSE), also called "inertia", against `k`. SSE always decreases as `k` grows. The "elbow" is the point beyond which adding more clusters gives only small further reductions, which suggests a reasonable trade-off between model simplicity and fit.

# We can use `scikit-learn`'s `KMeans` inertia attribute or the `yellowbrick` library for visualization.


In [None]:
# --- Elbow Method ---
# Largest k to scan. KMeans requires at least as many samples as clusters, so the
# scan is capped at the number of frames (matters only for very short trajectories).
max_k_to_test = 10

if data_for_clustering is not None:
    n_samples = len(data_for_clustering)
    k_range = range(1, min(max_k_to_test, n_samples) + 1) # k = 1 .. min(10, n_frames)

    print("\nCalculating SSE for Elbow method...")
    sse = []
    for k in k_range:
        kmeans_test = KMeans(
            n_clusters=k,
            init='k-means++', # Common initialization method
            n_init=10,        # Run multiple times with different seeds
            random_state=42   # For reproducibility
        )
        kmeans_test.fit(data_for_clustering)
        sse.append(kmeans_test.inertia_) # inertia_ is the SSE

    # Plot SSE vs k
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(k_range, sse, marker='o', linestyle='--')
    ax.set_title('Elbow Method for Optimal k')
    ax.set_xlabel('Number of Clusters (k)')
    ax.set_ylabel('Sum of Squared Errors (SSE)')
    ax.set_xticks(list(k_range))
    ax.grid(True, linestyle=':')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "kmeans_elbow_plot.png"), dpi=300)
    plt.show()

    # --- Optional: Yellowbrick visualization ---
    if _YELLOWBRICK_AVAILABLE:
        print("\nUsing Yellowbrick KElbowVisualizer...")
        try:
            model = KMeans(init='k-means++', n_init=10, random_state=42)
            visualizer = KElbowVisualizer(model, k=k_range)
            visualizer.fit(data_for_clustering)
            visualizer.show(outpath=os.path.join(output_dir, "kmeans_elbow_yellowbrick.png"))
            # The visualizer might suggest an optimal k based on the elbow or other metrics.
            print(f"(Yellowbrick suggested k: {visualizer.elbow_value_})") # May be None
        except Exception as e:
            print(f"Yellowbrick visualization failed: {e}", file=sys.stderr)
else:
    print("Skipping Elbow method due to missing data.")

# --- User Decision ---
# Based on the plot, choose a value for k where the SSE decrease starts to level off.
optimal_k = 3 # <-- REPLACE with your chosen k value based on the elbow plot

print(f"\nSelected optimal number of clusters (k): {optimal_k}")



In [None]:
# ## 3. Perform KMeans Clustering

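# Now we use the chosen `optimal_k` and run the `perform_kmeans` function from our library on the prepared data (e.g., PCA projections). We can also choose whether to standardize the features first. For PCA projections this is usually *not* recommended. The PCs are deliberately ordered by variance (PC1 captures the most), and standardizing them would give low-variance components the same weight as the dominant motions. Scaling is more useful when clustering mixed features with different units, such as distances combined with dihedral angles.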


In [None]:
# --- Run KMeans ---
kmeans_model = None
labels = None
centroids = None
data_used = None # Store the data actually used (potentially scaled)

if data_for_clustering is None:
    print("Skipping KMeans clustering due to missing data.", file=sys.stderr)
elif not 1 <= optimal_k <= len(data_for_clustering):
    # KMeans needs at least one cluster and no more clusters than samples.
    print(f"Cannot run KMeans with k={optimal_k}: need 1 <= k <= number of frames ({len(data_for_clustering)}).", file=sys.stderr)
else:
    print(f"\nRunning KMeans with k={optimal_k}...")
    kmeans_result_tuple = md_analysis_tools.perform_kmeans(
        data=data_for_clustering,
        n_clusters=optimal_k,
        scale_data=False, # Set to True if features have very different scales (usually not for PCs)
        random_state=42,  # For reproducibility
        # Additional scikit-learn KMeans kwargs (e.g. init='k-means++', n_init=10) may be passed here
    )

    if kmeans_result_tuple:
        kmeans_model, labels, centroids, data_used = kmeans_result_tuple
        print(f"KMeans finished. Found {len(centroids)} centroids.")
        # Count members in each cluster. bincount with minlength also lists clusters
        # that ended up with 0 frames (a Counter would omit them).
        # Assumes labels are 0..k-1 integers, as scikit-learn produces - confirm for perform_kmeans.
        cluster_counts = np.bincount(np.asarray(labels, dtype=int), minlength=len(centroids))
        print("Cluster populations:")
        for cluster_id, count in enumerate(cluster_counts):
            print(f"  Cluster {cluster_id}: {count} frames")
    else:
        print("KMeans clustering failed.", file=sys.stderr)


In [None]:
# ## 4. Visualize Clusters

# If clustering was performed on 2D or 3D data (e.g., PC1 vs PC2, or PC1 vs PC2 vs PC3), we can visualize the data points colored by their assigned cluster label. We also plot the cluster centroids.


In [None]:
# --- Plot Clusters ---
if labels is not None and data_used is not None and centroids is not None:
    n_dims = data_used.shape[1]

    if n_dims >= 2:
        print("\nVisualizing clusters on PC1 vs PC2...")
        fig, ax = plt.subplots(figsize=(8, 8))
        points = ax.scatter(data_used[:, 0], data_used[:, 1], c=labels, cmap='viridis', s=10, alpha=0.5)

        # Plot centroids
        centroid_marker = ax.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, c='red', edgecolor='black', label='Centroids')

        # A single "Data Points" legend entry would be drawn in only one cluster's color.
        # legend_elements(num=None) returns one handle per unique label value, in sorted
        # order, which matches np.unique(labels).
        cluster_handles, _ = points.legend_elements(num=None)
        cluster_names = [f'Cluster {cid}' for cid in np.unique(labels)]
        ax.legend(list(cluster_handles) + [centroid_marker], cluster_names + ['Centroids'])

        ax.set_title(f'KMeans Clustering Results (k={optimal_k}) on PCA Projections')
        ax.set_xlabel('Principal Component 1')
        ax.set_ylabel('Principal Component 2')
        ax.grid(True, linestyle=':')
        # Equal aspect keeps PC1 and PC2 on the same scale so cluster shapes are not distorted
        ax.set_aspect('equal', adjustable='box')
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "kmeans_clusters_pc1_pc2.png"), dpi=300)
        plt.show()

        if n_dims >= 3:
            # Optional: Add 3D plot (PC1 vs PC2 vs PC3)
            print("\nVisualizing clusters on PC1 vs PC2 vs PC3...")
            fig3d = plt.figure(figsize=(10, 10))
            ax3d = fig3d.add_subplot(111, projection='3d')
            ax3d.scatter(data_used[:, 0], data_used[:, 1], data_used[:, 2], c=labels, cmap='viridis', s=10, alpha=0.3)
            ax3d.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], marker='X', s=250, c='red', edgecolor='black', label='Centroids')
            ax3d.set_title(f'KMeans Clusters (k={optimal_k}) on PC1-3')
            ax3d.set_xlabel('PC1')
            ax3d.set_ylabel('PC2')
            ax3d.set_zlabel('PC3')
            ax3d.legend()
            fig3d.tight_layout()
            fig3d.savefig(os.path.join(output_dir, "kmeans_clusters_pc1_pc2_pc3.png"), dpi=300)
            plt.show()

    else:
        print("Visualization skipped: Need at least 2 dimensions for scatter plot.")

else:
    print("Skipping cluster visualization due to previous errors.", file=sys.stderr)



In [None]:
# ## 5. Find Representative Frames

# For each cluster, we want to find the actual frame from the original trajectory that is closest to the cluster centroid (in the space used for clustering, e.g., PCA space). This gives us a representative structure for each conformational state.


In [None]:
# --- Find Representatives ---
representative_indices = None

if labels is not None and data_used is not None and centroids is not None:
    print("\nFinding representative frames closest to centroids...")
    representative_indices = md_analysis_tools.find_closest_frames_to_centroids(
        data=data_used, # Use the data that was actually clustered (potentially scaled)
        labels=labels,
        centroids=centroids
    )

    # Explicit None/length check: a bare `if representative_indices:` raises ValueError
    # if the helper returns a NumPy array with more than one element.
    if representative_indices is not None and len(representative_indices) > 0:
        print("\nRepresentative frame indices (0-based) for each cluster:")
        # Position in the result is taken as the cluster id; None marks a cluster without a representative.
        for cluster_id, frame_idx in enumerate(representative_indices):
            if frame_idx is not None:
                print(f"  Cluster {cluster_id}: Frame {frame_idx}")
            else:
                print(f"  Cluster {cluster_id}: No representative found (empty cluster?)")

        # Optional: Re-plot clusters highlighting representatives
        if data_used.shape[1] >= 2:
            valid_indices = [idx for idx in representative_indices if idx is not None]
            if valid_indices:
                rep_data = data_used[valid_indices]
                fig, ax = plt.subplots(figsize=(8, 8))
                ax.scatter(data_used[:, 0], data_used[:, 1], c=labels, cmap='viridis', s=10, alpha=0.3)
                ax.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=200, c='red', edgecolor='black', label='Centroids')
                # Highlight representatives
                ax.scatter(rep_data[:, 0], rep_data[:, 1], marker='o', s=150, c='yellow', edgecolor='black', label='Representatives')
                ax.set_title(f'Clusters with Representatives (k={optimal_k})')
                ax.set_xlabel('PC1')
                ax.set_ylabel('PC2')
                ax.legend()
                ax.grid(True, linestyle=':')
                ax.set_aspect('equal', adjustable='box')
                fig.tight_layout()
                fig.savefig(os.path.join(output_dir, "kmeans_clusters_representatives.png"), dpi=300)
                plt.show()

    else:
        print("Failed to find representative frames.", file=sys.stderr)

else:
    print("Skipping finding representatives due to previous errors.", file=sys.stderr)



In [None]:
# ## 6. (Optional) Save Representative Structures

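# If representative frame indices were found, we can load the original trajectory and save these specific frames as PDB files. Note that the indices are row positions in the clustered data. They correspond to trajectory frames only if that data has exactly one row per trajectory frame, in order (for example, no stride was used when computing the projections).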


In [None]:
# --- Save Structures ---
# Explicit None/length check (works for both lists and NumPy arrays; a bare
# truthiness test raises ValueError for multi-element arrays).
if representative_indices is not None and len(representative_indices) > 0:
    # NOTE(review): frame_idx is a row index into the clustered data; it equals the
    # trajectory frame index only if that data has one row per frame, in order - confirm.
    try:
        # Reuse the Universe from the data-loading cell if it exists; otherwise reload it.
        # At notebook top level, globals() is the namespace where that cell defined `u`.
        if not isinstance(globals().get('u'), mda.Universe):
            print("\nReloading Universe to save structures...")
            if os.path.exists(topology_file) and os.path.exists(trajectory_file):
                u = mda.Universe(topology_file, trajectory_file)
            else:
                raise FileNotFoundError("Topology/Trajectory files not found for saving structures.")

        print("\nSaving representative structures...")
        # Save every atom in the system
        all_atoms = u.atoms

        for cluster_id, frame_idx in enumerate(representative_indices):
            if frame_idx is not None:
                try:
                    # Indexing the trajectory moves the Universe to that frame
                    u.trajectory[frame_idx]

                    # Define output filename
                    pdb_filename = os.path.join(output_dir, f"representative_cluster_{cluster_id}_frame_{frame_idx}.pdb")

                    # Write the PDB file
                    all_atoms.write(pdb_filename)
                    print(f"  Saved: {pdb_filename}")
                except IndexError:
                    print(f"  Error: Frame index {frame_idx} out of bounds for trajectory. Cannot save structure for cluster {cluster_id}.", file=sys.stderr)
                except Exception as e:
                    print(f"  Error saving structure for cluster {cluster_id} (frame {frame_idx}): {e}", file=sys.stderr)
            else:
                print(f"  Skipping save for empty cluster {cluster_id}.")

    except Exception as e:
        print(f"Error during structure saving: {e}", file=sys.stderr)

else:
    print("Skipping structure saving as no representative indices were found.", file=sys.stderr)


In [None]:
# ## Conclusion

# This notebook demonstrated clustering MD trajectory data (using PCA projections) with KMeans via the `md_analysis_tools` library. We used the Elbow method as a heuristic to help choose the number of clusters and visualized the resulting clusters. We then identified the frame closest to each cluster centroid and saved these frames as representative PDB structures.

