# PKS Structural Feature Correlation Analysis

This notebook provides a **comprehensive correlation analysis** between all structural features in the PKS dataset.

## Goals
1. Identify strongly correlated features (for dimensionality reduction)
2. Discover unexpected relationships between structural properties
3. Find features that best explain structural organization
4. Generate hypotheses about structure-function relationships

## Data Sources
- Module macroproperties (size, surface area, volume, compactness)
- Domain macroproperties
- Inter-domain distances
- Domain centroids and orientations


In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
# Suppress every warning for the rest of the session.
# NOTE(review): this filter is global, so deprecation and runtime warnings raised
# by later cells are hidden as well — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Show up to 50 columns when a DataFrame is displayed.
pd.set_option('display.max_columns', 50)
# Apply the 'seaborn-v0_8-whitegrid' matplotlib style sheet to all figures below.
plt.style.use('seaborn-v0_8-whitegrid')

# Confirmation that the setup cell ran to completion.
print("Libraries loaded successfully")


---
## 1. Load and Prepare Data


In [None]:
# Read the four CSV exports analysed in this notebook.
# The domain table uses its first column as the index; the two larger tables are
# read with low_memory=False.
domain_mp = pd.read_csv('domain_macroproperties.csv', index_col=0)
module_mp = pd.read_csv('MP_PKS.csv')
combined_df = pd.read_csv('MP_IA_IDO_combined.csv', low_memory=False)
ido_df = pd.read_csv('IDO_out.csv', low_memory=False)

# Report the (rows, columns) shape of each table as a quick sanity check.
for table_label, table in (
    ("Domain macroproperties", domain_mp),
    ("Module macroproperties", module_mp),
    ("Combined data", combined_df),
    ("IDO data", ido_df),
):
    print(f"{table_label}: {table.shape}")


In [None]:
# Identify numeric columns for correlation analysis
def get_numeric_cols(df, exclude_patterns=('Unnamed', 'filename', 'zernike')):
    """Return the numeric column labels of ``df`` that match no excluded pattern.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose columns are inspected.
    exclude_patterns : iterable of str, optional
        Substrings that disqualify a column. Matching is case-insensitive and
        done against the string form of the column label, so non-string labels
        (e.g. integer column names) are handled instead of raising.

    Returns
    -------
    list
        Labels of the numeric columns, in their original order, whose label
        contains none of the patterns.
    """
    # Lower-case the patterns once rather than once per column.
    lowered_patterns = [str(pattern).lower() for pattern in exclude_patterns]
    numeric_labels = df.select_dtypes(include=[np.number]).columns.tolist()
    return [
        label
        for label in numeric_labels
        if not any(pattern in str(label).lower() for pattern in lowered_patterns)
    ]

# Numeric feature names of the module table and of the combined table, both
# filtered with get_numeric_cols' default exclusion patterns.
module_numeric, combined_numeric = (
    get_numeric_cols(table) for table in (module_mp, combined_df)
)

print(f"Module numeric columns: {len(module_numeric)}")
print(f"Combined numeric columns: {len(combined_numeric)}")


---
## 2. Module Macroproperties Correlations


In [None]:
# Candidate macroproperty columns; only those present in the module table are used.
struct_cols = ['n_ca_atoms', 'n_heavy_atoms', 'radius_of_gyration_ca', 'sasa', 'ses_area', 'vdw_volume']
available_struct = [name for name in struct_cols if name in module_mp.columns]

# Pairwise correlations between the available macroproperties.
corr_module = module_mp[available_struct].corr()

# Heatmap: cells strictly above the diagonal are masked, so each pair appears once.
fig, ax = plt.subplots(figsize=(10, 8))
mask = np.triu(np.ones_like(corr_module, dtype=bool), k=1)
sns.heatmap(
    corr_module,
    mask=mask,
    annot=True,
    fmt='.3f',
    cmap='RdBu_r',
    center=0,
    square=True,
    linewidths=0.5,
    ax=ax,
)
ax.set_title('Module Macroproperties Correlation Matrix', fontsize=14)
plt.tight_layout()
plt.show()

# List every unordered pair whose absolute correlation exceeds 0.9.
print("\nStrong correlations (|r| > 0.9):")
for position, col1 in enumerate(available_struct):
    for col2 in available_struct[position + 1:]:
        r = corr_module.loc[col1, col2]
        if abs(r) > 0.9:
            print(f"  {col1} ↔ {col2}: r = {r:.4f}")


In [None]:
# Scatter plot matrix for module properties.
# scatter_matrix builds its own figure when no `ax` is passed, so no figure is
# opened beforehand: a separate plt.figure() call would only leave an empty
# figure in the cell output.
scatter_axes = pd.plotting.scatter_matrix(
    module_mp[available_struct],
    figsize=(14, 14),
    diagonal='hist',
    alpha=0.3,
    s=10,
)

# Title and layout are applied explicitly to the figure that owns the scatter axes.
fig = scatter_axes[0, 0].figure
fig.suptitle('Module Macroproperties: Pairwise Relationships', fontsize=14, y=1.02)
fig.tight_layout()
plt.show()


---
## 3. Inter-Domain Distance Correlations

Analyze which domain-domain distances are correlated (suggesting structural constraints).


In [None]:
# Collect every IDO column whose name begins with 'dist.'.
distance_cols = []
for column_name in ido_df.columns:
    if column_name.startswith('dist.'):
        distance_cols.append(column_name)
print(f"Total distance columns: {len(distance_cols)}")

# Focus on key catalytic domain distances (no linkers)
def is_catalytic_distance(col):
    """Tell whether a distance column pairs only catalytic (non-linker) domains.

    The ``dist.`` text is removed from ``col`` and the remainder is split on
    ``__`` into domain names. A name is rejected only when it is longer than
    three characters and ends with ``'L'``; the column qualifies when no name
    is rejected.
    """
    domain_names = col.replace('dist.', '').split('__')
    for name in domain_names:
        if len(name) > 3 and name.endswith('L'):
            return False
    return True

# Distance columns accepted by is_catalytic_distance.
catalytic_dist_cols = list(filter(is_catalytic_distance, distance_cols))
print(f"Catalytic domain distances: {len(catalytic_dist_cols)}")

# Hand-picked domain pairs; only those present among the distance columns are kept.
key_distances = [
    'dist.KS__AT',
    'dist.KS__ACP',
    'dist.AT__ACP',
    'dist.KS__KR',
    'dist.KR__ACP',
    'dist.DH__KR',
    'dist.DH__ACP',
    'dist.ER__KR',
    'dist.KS__DH',
    'dist.AT__KR',
    'dist.AT__DH',
]
available_key = [name for name in key_distances if name in distance_cols]
print(f"\nKey distances available: {len(available_key)}")


In [None]:
# Distance correlation matrix (only run when at least four key distances are available)
if len(available_key) >= 4:
    # Pairwise correlations, kept under the raw column names so the listing
    # below reads the very values shown in the heatmap instead of recomputing them.
    dist_corr_raw = ido_df[available_key].corr()

    # Readable labels: 'dist.KS__AT' -> 'KS↔AT'
    dist_label = {c: c.replace('dist.', '').replace('__', '↔') for c in available_key}
    clean_labels = [dist_label[c] for c in available_key]
    dist_corr = dist_corr_raw.rename(index=dist_label, columns=dist_label)

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(dist_corr, annot=True, cmap='RdBu_r', center=0, ax=ax,
                fmt='.2f', square=True, linewidths=0.5,
                annot_kws={'size': 9})
    ax.set_title('Inter-Domain Distance Correlations\n(Higher correlation = Structurally coupled)', fontsize=14)
    plt.tight_layout()
    plt.show()

    # Find strongest correlations: each unordered pair is visited once.
    # NaN correlations never satisfy the comparison and are skipped.
    print("\nStrongest distance correlations (|r| > 0.7):")
    for position, col1 in enumerate(available_key):
        for col2 in available_key[position + 1:]:
            r = dist_corr_raw.loc[col1, col2]
            if abs(r) > 0.7:
                print(f"  {dist_label[col1]} ↔ {dist_label[col2]}: r = {r:.3f}")


---
## 4. Cross-Feature Correlations

Correlate macroproperties with inter-domain distances.


In [None]:
# Cross-correlate macro properties with distances
# Use combined_df which has both
macro_cols = [c for c in available_struct if c in combined_df.columns]
dist_cols_combined = [c for c in available_key if c in combined_df.columns]

if macro_cols and dist_cols_combined:
    # One corr() call over both column groups, then keep only the
    # macroproperty (rows) × distance (columns) block.
    cross_corr_raw = (
        combined_df[macro_cols + dist_cols_combined]
        .corr()
        .loc[macro_cols, dist_cols_combined]
        .astype(float)
    )

    # Readable column labels for display; the raw-named matrix is kept for the
    # lookups below so no label translation is needed there.
    cross_label = {c: c.replace('dist.', '').replace('__', '↔') for c in dist_cols_combined}
    clean_dist_labels = [cross_label[c] for c in dist_cols_combined]
    cross_corr = cross_corr_raw.rename(columns=cross_label)

    # Heatmap
    fig, ax = plt.subplots(figsize=(14, 6))
    sns.heatmap(cross_corr, annot=True, cmap='RdBu_r', center=0, ax=ax,
                fmt='.2f', linewidths=0.5)
    ax.set_title('Macroproperties vs Inter-Domain Distances\n(Correlation Matrix)', fontsize=14)
    ax.set_xlabel('Inter-Domain Distance')
    ax.set_ylabel('Macroproperty')
    plt.tight_layout()
    plt.show()

    # Find strongest cross-correlations (NaN values fail the comparison and are skipped)
    print("\nStrongest cross-correlations (|r| > 0.3):")
    for mc in macro_cols:
        for dc in dist_cols_combined:
            r = cross_corr_raw.loc[mc, dc]
            if abs(r) > 0.3:
                print(f"  {mc} vs {cross_label[dc]}: r = {r:.3f}")


---
## 5. Hierarchical Clustering of Features

Cluster features by correlation to identify groups of related measurements.


In [None]:
# Select top features for clustering
# Combine macroproperties and key distances
all_features = macro_cols + dist_cols_combined
available_for_cluster = [c for c in all_features if c in combined_df.columns]

# Keep rows in which at least half of the features are present, then keep the
# features present in at least half of the surviving rows. (n + 1) // 2 equals
# ceil(n / 2): the same cut-off as n * 0.5 for integer counts, expressed as an int.
cluster_data = combined_df[available_for_cluster].dropna(thresh=(len(available_for_cluster) + 1) // 2)
cluster_data = cluster_data.dropna(axis=1, thresh=(len(cluster_data) + 1) // 2)

print(f"Features for clustering: {cluster_data.shape[1]}")
print(f"Samples: {cluster_data.shape[0]}")

# Compute correlation matrix
corr_matrix = cluster_data.corr()

# A feature whose self-correlation is NaN (for example a constant column) has no
# usable correlations at all; leaving it in would make the linkage step fail.
defined = np.diag(corr_matrix.notna().to_numpy())
if not defined.all():
    print(f"Skipping features with undefined correlations: {corr_matrix.columns[~defined].tolist()}")
    corr_matrix = corr_matrix.loc[defined, defined]

if corr_matrix.shape[0] < 2:
    print("Need at least two features with defined correlations to cluster; skipping plots.")
else:
    # Convert correlation to distance (1 - |corr|). The arithmetic is done on a
    # fresh NumPy array so the diagonal can be written safely; writing through
    # DataFrame.values fails when pandas hands out read-only views.
    dist_values = 1 - np.abs(corr_matrix.to_numpy(dtype=float))
    # Pairs with no overlapping observations: treat as uncorrelated (distance 1).
    dist_values = np.nan_to_num(dist_values, nan=1.0)
    # Guard against tiny negative values caused by floating-point rounding.
    dist_values = np.clip(dist_values, 0.0, None)
    np.fill_diagonal(dist_values, 0)
    dist_matrix = pd.DataFrame(dist_values, index=corr_matrix.index, columns=corr_matrix.columns)

    # Average-linkage clustering on the condensed (upper-triangle) distances
    condensed = squareform(dist_values)
    linkage = hierarchy.linkage(condensed, method='average')

    fig, axes = plt.subplots(1, 2, figsize=(18, 8))

    # Left panel: dendrogram
    ax = axes[0]
    clean_labels = [c.replace('dist.', '').replace('__', '↔') for c in corr_matrix.columns]
    dendro = hierarchy.dendrogram(linkage, labels=clean_labels, ax=ax, leaf_rotation=90)
    ax.set_ylabel('Distance (1 - |correlation|)')
    ax.set_title('Feature Clustering by Correlation')

    # Right panel: correlation matrix reordered to follow the dendrogram leaves
    ax = axes[1]
    order = dendro['leaves']
    ordered_corr = corr_matrix.iloc[order, order].copy()
    ordered_corr.index = [clean_labels[i] for i in order]
    ordered_corr.columns = [clean_labels[i] for i in order]

    sns.heatmap(ordered_corr, cmap='RdBu_r', center=0, ax=ax, square=True,
                xticklabels=True, yticklabels=True)
    ax.set_title('Reordered Correlation Matrix')

    plt.tight_layout()
    plt.show()


---
## 6. Domain Centroid Coordinate Correlations

Analyze relationships between domain spatial positions.


In [None]:
# Centroid coordinate correlations
centroid_cols = [c for c in ido_df.columns if c.startswith('centroid.')]
print(f"Centroid columns: {len(centroid_cols)}")

# Extract key domain centroids (X, Y, Z for main catalytic domains)
key_domains = ['KS', 'AT', 'ACP', 'KR', 'DH', 'ER']
centroid_key = []
for dom in key_domains:
    for axis in ['X', 'Y', 'Z']:
        col = f'centroid.{dom}.{axis}'
        if col in centroid_cols:
            centroid_key.append(col)

print(f"Key centroid coordinates: {len(centroid_key)}")

# Plot only when at least nine key coordinates exist and at least three of them
# are X coordinates.
if len(centroid_key) >= 9:
    x_cols = [c for c in centroid_key if c.endswith('.X')]
    if len(x_cols) >= 3:
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # One heatmap per coordinate axis: how the domains' positions along
        # that axis co-vary across the rows of the IDO table.
        for ax, axis in zip(axes, ['X', 'Y', 'Z']):
            axis_cols = [c for c in centroid_key if c.endswith(f'.{axis}')]
            if not axis_cols:
                # Nothing to correlate on this axis; an empty matrix would
                # make the heatmap call fail, so blank the panel instead.
                ax.set_axis_off()
                continue
            axis_corr = ido_df[axis_cols].corr()

            # Clean labels: 'centroid.KS.X' -> 'KS'
            clean = [c.replace('centroid.', '').replace(f'.{axis}', '') for c in axis_cols]
            axis_corr.index = clean
            axis_corr.columns = clean

            sns.heatmap(axis_corr, annot=True, cmap='RdBu_r', center=0, ax=ax,
                        fmt='.2f', square=True)
            ax.set_title(f'{axis}-Coordinate Correlations')

        plt.suptitle('Domain Centroid Coordinate Correlations', fontsize=14, y=1.02)
        plt.tight_layout()
        plt.show()


---
## 7. Rotation Matrix Analysis

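Analyze correlations among the nine rotation-matrix elements (`R.00`–`R.22`) that describe each module's orientation relative to the reference, and summarize the diagonal and off-diagonal elements separately.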


In [None]:
# Columns holding rotation-matrix elements (names starting with 'R.')
rotation_cols = [c for c in ido_df.columns if c.startswith('R.')]
print(f"Rotation matrix columns: {len(rotation_cols)}")

# Proceed only when exactly nine element columns are found
# (R.00, R.01, R.02, R.10, R.11, R.12, R.20, R.21, R.22).
if len(rotation_cols) == 9:
    rot_corr = ido_df[rotation_cols].corr()

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        rot_corr,
        annot=True,
        fmt='.2f',
        cmap='RdBu_r',
        center=0,
        square=True,
        ax=ax,
    )
    ax.set_title('Rotation Matrix Element Correlations\n(Reference → Module orientation)', fontsize=14)
    plt.tight_layout()
    plt.show()

    # Diagonal vs off-diagonal elements (diagonal = axis preservation)
    diag = ['R.00', 'R.11', 'R.22']
    off_diag = ['R.01', 'R.02', 'R.10', 'R.12', 'R.20', 'R.21']

    print("\nRotation Matrix Statistics:")
    print("-"*50)
    # Same mean/std report for both groups; elements missing from the table are skipped.
    for heading, element_cols in (
        ("Diagonal elements (axis preservation):", diag),
        ("\nOff-diagonal elements (axis mixing):", off_diag),
    ):
        print(heading)
        for col in element_cols:
            if col not in ido_df.columns:
                continue
            vals = ido_df[col].dropna()
            print(f"  {col}: mean={vals.mean():.3f}, std={vals.std():.3f}")


---
## 8. Full Correlation Network

Identify the most connected features in the dataset, i.e. those that are strongly correlated (|r| > 0.6) with the largest number of other features.


In [None]:
# Find features with most strong correlations
threshold = 0.6

# Build adjacency from correlation
all_numeric = combined_df.select_dtypes(include=[np.number])
# Remove columns with too many NaN (keep those with fewer than 50% missing values)
# NOTE(review): unlike get_numeric_cols, this keeps 'Unnamed'/'filename'/'zernike'
# columns — confirm whether they should take part in the analysis.
valid_cols = all_numeric.columns[all_numeric.isnull().mean() < 0.5]
all_numeric = all_numeric[valid_cols]

# Sample for large correlation matrix
if len(all_numeric.columns) > 100:
    # Select mix of macro, distance, and centroid columns: up to three per prefix
    selected_cols = []
    for prefix in ['n_', 'radius', 'sasa', 'ses_', 'vdw_', 'dist.KS', 'dist.AT', 'dist.ACP',
                   'centroid.KS', 'centroid.AT', 'centroid.ACP']:
        matching = [c for c in all_numeric.columns if c.startswith(prefix)][:3]
        selected_cols.extend(matching)
    # De-duplicate while preserving first-seen order. A plain set() would make
    # the column order (and hence tie-breaking in the rankings below) vary
    # from run to run.
    selected_cols = list(dict.fromkeys(selected_cols))[:40]
    analysis_df = all_numeric[selected_cols]
else:
    analysis_df = all_numeric

print(f"Analyzing {len(analysis_df.columns)} features")

# Correlation matrix
full_corr = analysis_df.corr()

# Count strong correlations per feature. The diagonal entry is subtracted
# explicitly rather than assuming it is always counted: a feature whose
# self-correlation is NaN contributes nothing there, and a blanket "- 1"
# would report -1 for it.
strong_mask = (full_corr.abs() > threshold).to_numpy()
strong_per_feature = strong_mask.sum(axis=0) - strong_mask.diagonal()
strong_corr_counts = {
    col: int(count) for col, count in zip(full_corr.columns, strong_per_feature)
}

# Top connected features
top_connected = sorted(strong_corr_counts.items(), key=lambda x: x[1], reverse=True)[:15]

print(f"\nMost Connected Features (|r| > {threshold}):")
print("-"*60)
for feat, count in top_connected:
    clean_feat = feat.replace('dist.', '').replace('__', '↔').replace('centroid.', 'c.')
    print(f"  {clean_feat:40s}: {count:3d} strong correlations")


---
## 9. Summary and Key Insights


In [None]:
# Summary statistics
HIGH_CORR_THRESHOLD = 0.9  # |r| above which two features are reported as highly correlated

print("="*70)
print("CORRELATION ANALYSIS SUMMARY")
print("="*70)

# The header states |r| because the test below uses the absolute value,
# so strong negative correlations are included as well.
print(f"\n1. HIGHLY CORRELATED FEATURE PAIRS (|r| > {HIGH_CORR_THRESHOLD}):")
print("-"*50)

# Read each unordered pair (i < j) from the correlation matrix computed for
# analysis_df instead of recomputing every pair one at a time.
feature_names = list(full_corr.columns)
corr_values = full_corr.to_numpy()
high_corr_pairs = [
    (feature_names[i], feature_names[j], corr_values[i, j])
    for i in range(len(feature_names))
    for j in range(i + 1, len(feature_names))
    if abs(corr_values[i, j]) > HIGH_CORR_THRESHOLD  # NaN fails the test and is skipped
]

# Strongest pairs first; only the top ten are printed
high_corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)
for col1, col2, r in high_corr_pairs[:10]:
    clean1 = col1.replace('dist.', '').replace('__', '↔')[:25]
    clean2 = col2.replace('dist.', '').replace('__', '↔')[:25]
    print(f"  {clean1:25s} ↔ {clean2:25s}: r={r:.3f}")

print(f"\n  Total pairs with |r| > {HIGH_CORR_THRESHOLD}: {len(high_corr_pairs)}")

print("\n2. REDUNDANT FEATURES (Consider removing for ML):")
print("-"*50)
if len(high_corr_pairs) > 0:
    # The second member of every highly correlated pair is marked redundant
    # (the first is kept). Order of first appearance is preserved so the
    # printed list is the same on every run.
    redundant_ordered = list(dict.fromkeys(col2 for _, col2, _ in high_corr_pairs))
    redundant = set(redundant_ordered)
    print(f"  {len(redundant)} features are highly correlated with others")
    for feat in redundant_ordered[:10]:
        print(f"    - {feat[:50]}")

# Fixed interpretation text; these lines are not derived from the values computed above.
print("\n3. INTERPRETATIONS:")
print("-"*50)
print("  • Size metrics (n_atoms, volume, SASA) are highly correlated")
print("  • Adjacent domain distances show correlation (structural constraints)")
print("  • Centroid coordinates on same axis indicate linear arrangement")
print("  • Rotation matrix structure reveals preferred orientations")
