# In [None]:
import scanpy as sc
import ot
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import MDS
import pandas as pd
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import matplotlib.gridspec as gridspec

# Load the processed scRNA-seq h5ad data, then keep only the cell type(s) of
# interest (here: the dopaminergic lineage).
adata_all = sc.read_h5ad("all.h5ad")
lineage_mask = adata_all.obs['BroadCellType'].isin(['DAN', 'IDN'])
adata = adata_all[lineage_mask].copy()



# Step 1.1: Compute quality control metrics if not already available.
# The quality weights defined further down read adata.obs['n_counts'], so this
# step makes sure that column exists.
try:
    if 'n_counts' not in adata.obs:
        sc.pp.calculate_qc_metrics(adata, inplace=True)
        if 'n_counts' not in adata.obs:
            if 'total_counts' in adata.obs:
                # calculate_qc_metrics stores the per-cell sum under
                # 'total_counts'; expose it under the name used below.
                adata.obs['n_counts'] = adata.obs['total_counts'].to_numpy()
            else:
                # Row sums of a sparse matrix come back as an (n, 1) matrix;
                # flatten to a 1-D array so it fits into an obs column.
                adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
except Exception as e:
    # Deliberately best-effort: without 'n_counts' the quality weights fall
    # back to uniform weights.
    print(f"Warning: Failed to compute QC metrics: {e}. Using uniform weights for quality.")

# Step 1.2: Make sure UMAP coordinates are present in adata.obsm['X_umap']
if 'X_umap' in adata.obsm:
    print("Using existing UMAP embeddings")
else:
    print("Computing UMAP embeddings...")
    # The neighbor graph below is built on the PCA representation,
    # so compute PCA first when it is missing.
    if 'X_pca' not in adata.obsm:
        sc.tl.pca(adata, n_comps=50)
    # Neighbor graph on the PCA space
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_pca')
    # Two-dimensional UMAP embedding
    sc.tl.umap(adata, min_dist=0.5, spread=1.0, n_components=2)

# --- Step 2: Define analysis groups ---
# Every mutation group maps to the label of its mutant cells and the label of
# the matching isogenic control (the mutant label prefixed with 'iso').
mutation_groups = {
    gene: {'mutant': gene, 'isogenic': f'iso{gene}'}
    for gene in ('LRRK2', 'GBA1', 'SNCA')
}

# Label of the healthy-control population
hc_label = 'HC'

# --- Step 3: Helper Functions ---
# Step 3.1: Quality-Based Weights (using log-transformed UMI counts)
def get_quality_weights(adata_subset):
    """Return one weight per cell, proportional to log1p of obs['n_counts'].

    Non-finite log scores are replaced by zero before normalising. Uniform
    weights (1 / number of cells) are returned instead when the 'n_counts'
    column is missing or when the scores do not add up to a positive total.
    """
    n_cells = adata_subset.shape[0]

    if 'n_counts' in adata_subset.obs:
        # The log transform reduces the skewness of the raw counts
        log_counts = np.log1p(adata_subset.obs['n_counts'].values)
        scores = np.where(np.isfinite(log_counts), log_counts, 0)
        score_sum = scores.sum()
        if score_sum > 0:
            return scores / score_sum
        return np.ones(n_cells) / n_cells

    print("Warning: 'n_counts' not found in adata_subset.obs, using uniform weights for quality")
    return np.ones(n_cells) / n_cells

# Step 3.2: Cluster-Based Weights (using UMAP embeddings)
def get_cluster_weights(adata_subset, resolution=0.5):
    """Return per-cell weights inversely proportional to Leiden cluster size.

    The cells are clustered on their ``obsm['X_umap']`` coordinates in a
    throw-away AnnData object, each cell receives 1 / (size of its cluster),
    and the weights are normalised to sum to 1.

    ``resolution`` is forwarded to ``sc.tl.leiden``.
    """
    # Cluster on a scratch object so adata_subset itself is left untouched
    scratch = sc.AnnData(X=adata_subset.obsm['X_umap'])
    sc.pp.neighbors(scratch, n_neighbors=15, use_rep='X')
    sc.tl.leiden(scratch, resolution=resolution)

    # counts[inverse] is, for every cell, the size of the cluster it belongs to
    cluster_labels = scratch.obs['leiden'].to_numpy()
    _, inverse, counts = np.unique(cluster_labels, return_inverse=True, return_counts=True)
    inverse_size = 1.0 / counts[inverse]

    return inverse_size / inverse_size.sum()

# Step 3.3: Combine Weights for Each Population
def compute_non_uniform_weights(adata_subset, resolution=0.5):
    """Return per-cell weights combining quality and cluster-size weights.

    The quality weights and the cluster weights (computed with the given
    ``resolution``) are multiplied element-wise and renormalised to sum to 1.
    If the product does not have a positive total, uniform weights are
    returned instead.
    """
    combined = get_quality_weights(adata_subset) * get_cluster_weights(
        adata_subset, resolution=resolution)
    combined_sum = combined.sum()
    if combined_sum > 0:
        return combined / combined_sum

    n_cells = adata_subset.shape[0]
    return np.ones(n_cells) / n_cells

# Step 3.4: Function to compute OT distances between populations
def compute_ot_distances(X_A, X_B, a_weights, b_weights, reg=0.1):
    """Compute the entropic OT loss and transport plan between two point sets.

    Parameters
    ----------
    X_A, X_B : array-like
        Coordinates of the two populations, passed to ``ot.dist``.
    a_weights, b_weights : array-like
        Weights of the points in ``X_A`` / ``X_B``, passed to ``ot.sinkhorn``
        as the two marginals.
    reg : float
        Entropic regularisation strength passed to ``ot.sinkhorn``.

    Returns
    -------
    ot_distance : float
        Sum of ``ot_plan * M_norm``, where ``M_norm`` is the cost matrix
        divided by its maximum.
    ot_plan : array
        Transport plan returned by ``ot.sinkhorn`` for the normalised cost.
    M : array
        Un-normalised squared-Euclidean cost matrix.

    Raises
    ------
    ValueError
        If either population contains no points.
    """
    if len(X_A) == 0 or len(X_B) == 0:
        raise ValueError("compute_ot_distances requires at least one point in each population")

    # Squared Euclidean cost, requested directly instead of squaring the
    # Euclidean distances afterwards
    M = ot.dist(X_A, X_B, metric='sqeuclidean')

    # Normalize cost matrix for numerical stability
    max_cost = M.max()
    M_norm = M / max_cost if max_cost > 0 else M

    # Check for non-finite values
    if not np.isfinite(M_norm).all():
        print("Warning: Non-finite values in cost matrix")

    # Solve the Sinkhorn problem once; the loss reported by ot.sinkhorn2 is
    # the plan-weighted sum of the cost matrix, so derive it from the plan
    # instead of running the solver a second time.
    ot_plan = ot.sinkhorn(a_weights, b_weights, M_norm, reg)
    ot_distance = float(np.sum(ot_plan * M_norm))

    return ot_distance, ot_plan, M

# --- Step 4: Run Analysis for Each Mutation Type ---
# Dictionary to store results
results = {}

# First, extract healthy control cells. The population is subset once and the
# same view is used for both the UMAP coordinates and the weights.
adata_hc = adata[adata.obs['Mutation'] == hc_label]
X_HC = adata_hc.obsm['X_umap']
n_hc = X_HC.shape[0]
print(f"Number of healthy control cells: {n_hc}")

# Every comparison below needs the healthy controls, so stop with a readable
# message when none are present.
if n_hc == 0:
    raise ValueError(f"No cells found with Mutation == '{hc_label}'")

# Calculate weights for healthy controls
weights_hc = compute_non_uniform_weights(adata_hc)

# Validate the healthy-control weights with the same checks that are applied
# to the mutant and isogenic weights.
if not np.all(weights_hc > 0) or not np.isfinite(weights_hc).all():
    print(f"Warning: Invalid weights in {hc_label}")
if not np.isclose(weights_hc.sum(), 1.0, rtol=1e-5):
    print(f"Warning: Weights in {hc_label} do not sum to 1")

# Process each mutation type: compare mutant, isogenic control and healthy
# control populations pairwise and store everything in `results`.
for mutation_name, labels in mutation_groups.items():
    print(f"\nProcessing {mutation_name} mutation group...")

    mutant_label = labels['mutant']
    iso_label = labels['isogenic']

    # Subset each population once; the same view provides both the UMAP
    # coordinates and the input for the weight computation.
    adata_mutant = adata[adata.obs['Mutation'] == mutant_label]
    adata_iso = adata[adata.obs['Mutation'] == iso_label]
    X_mutant = adata_mutant.obsm['X_umap']
    X_iso = adata_iso.obsm['X_umap']

    n_mutant = X_mutant.shape[0]
    n_iso = X_iso.shape[0]
    print(f"Number of cells: {mutant_label}: {n_mutant}, {iso_label}: {n_iso}")

    # Stop with a readable message instead of failing later inside the
    # clustering / transport code when a population has no cells.
    for name, n_cells in [(mutant_label, n_mutant), (iso_label, n_iso)]:
        if n_cells == 0:
            raise ValueError(f"No cells found with Mutation == '{name}'")

    # Calculate weights for each population
    weights_mutant = compute_non_uniform_weights(adata_mutant)
    weights_iso = compute_non_uniform_weights(adata_iso)

    # Validate weights
    for name, weights in [(mutant_label, weights_mutant), (iso_label, weights_iso)]:
        if not np.all(weights > 0) or not np.isfinite(weights).all():
            print(f"Warning: Invalid weights in {name}")
        if not np.isclose(weights.sum(), 1.0, rtol=1e-5):
            print(f"Warning: Weights in {name} do not sum to 1")

    # Compute OT distances and plans for all pairwise comparisons
    # Mutant vs Isogenic
    cost_mutant_iso, plan_mutant_iso, M_mutant_iso = compute_ot_distances(
        X_mutant, X_iso, weights_mutant, weights_iso)

    # Mutant vs Healthy Control
    cost_mutant_hc, plan_mutant_hc, M_mutant_hc = compute_ot_distances(
        X_mutant, X_HC, weights_mutant, weights_hc)

    # Isogenic vs Healthy Control
    cost_iso_hc, plan_iso_hc, M_iso_hc = compute_ot_distances(
        X_iso, X_HC, weights_iso, weights_hc)

    # The same "<A> vs <B>" keys index distances, plans and cost matrices,
    # so build them once.
    key_mutant_iso = f"{mutant_label} vs {iso_label}"
    key_mutant_hc = f"{mutant_label} vs {hc_label}"
    key_iso_hc = f"{iso_label} vs {hc_label}"

    # Store results
    results[mutation_name] = {
        'mutant_label': mutant_label,
        'iso_label': iso_label,
        'hc_label': hc_label,
        'X_mutant': X_mutant,
        'X_iso': X_iso,
        'X_HC': X_HC,
        'weights_mutant': weights_mutant,
        'weights_iso': weights_iso,
        'weights_hc': weights_hc,
        'distances': {
            key_mutant_iso: cost_mutant_iso,
            key_mutant_hc: cost_mutant_hc,
            key_iso_hc: cost_iso_hc,
        },
        'plans': {
            key_mutant_iso: plan_mutant_iso,
            key_mutant_hc: plan_mutant_hc,
            key_iso_hc: plan_iso_hc,
        },
        'cost_matrices': {
            key_mutant_iso: M_mutant_iso,
            key_mutant_hc: M_mutant_hc,
            key_iso_hc: M_iso_hc,
        },
    }

    # Print distances
    print("\nOptimal Transport Distances:")
    for pair, distance in results[mutation_name]['distances'].items():
        print(f"{pair}: {distance:.4f}")

    # Find the most similar populations (smallest OT distance)
    closest_pair = min(results[mutation_name]['distances'].items(), key=lambda x: x[1])
    print(f"Most similar populations: {closest_pair[0]} with distance {closest_pair[1]:.4f}")

    # Determine if isogenic control is closer to mutant or to HC
    if cost_mutant_iso < cost_iso_hc:
        print(f"{iso_label} cells are more similar to {mutant_label} mutant cells than to healthy controls.")
    else:
        print(f"{iso_label} cells are more similar to healthy controls than to {mutant_label} mutant cells.")

# --- Step 5: Visualizations ---

# 1. Bar plots of OT distances for each mutation group
plt.figure(figsize=(16, 6))

# Three bars per mutation group, placed side by side around the group's tick
bar_width = 0.25
mutation_positions = np.arange(len(mutation_groups))
comparison_offsets = [-bar_width, 0, bar_width]
comparison_types = ['mutant_vs_iso', 'mutant_vs_hc', 'iso_vs_hc']
colors = ['#3274A1', '#E1812C', '#3A923A']  # Blue, Orange, Green

# For every comparison type: the two stored label fields whose values form
# the "<A> vs <B>" key into the distances dictionary
comparison_label_fields = {
    'mutant_vs_iso': ('mutant_label', 'iso_label'),
    'mutant_vs_hc': ('mutant_label', 'hc_label'),
    'iso_vs_hc': ('iso_label', 'hc_label'),
}

# One bar series per comparison type, covering all mutation groups
for comp_type, offset, color in zip(comparison_types, comparison_offsets, colors):
    first_field, second_field = comparison_label_fields[comp_type]
    values = [
        results[name]['distances'][
            f"{results[name][first_field]} vs {results[name][second_field]}"
        ]
        for name in mutation_groups
    ]

    bars = plt.bar(mutation_positions + offset, values, width=bar_width,
                   color=color, label=comp_type.replace('_', ' ').title())

    # Write each value just above its bar
    for bar, val in zip(bars, values):
        plt.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{val:.4f}', ha='center', va='bottom', fontsize=8)

plt.xlabel('Mutation Group')
plt.ylabel('Optimal Transport Distance')
plt.title('Pairwise Distances Between Cell Populations Across Mutation Groups')
plt.xticks(mutation_positions, mutation_groups.keys())
plt.legend()
plt.tight_layout()
plt.savefig('ot_distances_all_mutations.png', dpi=300)
plt.show()

# 2. MDS Visualization for each mutation group
fig = plt.figure(figsize=(15, 5 * len(mutation_groups)))
gs = gridspec.GridSpec(len(mutation_groups), 1)

for idx, (mutation_name, data) in enumerate(results.items()):
    mutant_label = data['mutant_label']
    iso_label = data['iso_label']
    hc_label = data['hc_label']
    populations = [mutant_label, iso_label, hc_label]

    # Symmetric 3x3 matrix of the pairwise OT distances (zero diagonal)
    population_pairs = ((0, 1), (0, 2), (1, 2))
    dist_matrix = np.zeros((3, 3))
    for i, j in population_pairs:
        pair_distance = data['distances'][f"{populations[i]} vs {populations[j]}"]
        dist_matrix[i, j] = pair_distance
        dist_matrix[j, i] = pair_distance

    # Embed the three populations in 2D from the precomputed distances
    mds = MDS(n_components=2, dissimilarity='precomputed', random_state=42)
    positions = mds.fit_transform(dist_matrix)

    # One subplot row per mutation group
    ax = fig.add_subplot(gs[idx])
    colors = ['#E41A1C', '#377EB8', '#4DAF4A']  # Red, Blue, Green
    ax.scatter(positions[:, 0], positions[:, 1], s=300, c=colors)

    # Name each point
    for pop, point in zip(populations, positions):
        ax.annotate(pop, (point[0], point[1]),
                    fontsize=14, ha='center', va='center',
                    color='white', fontweight='bold')

    # Connect every pair of points and print the distance at the midpoint
    for i, j in population_pairs:
        x1, y1 = positions[i]
        x2, y2 = positions[j]
        ax.plot([x1, x2], [y1, y2], 'k--', alpha=0.5)

        dist_value = dist_matrix[i, j]
        mid_x = (x1 + x2) / 2
        mid_y = (y1 + y2) / 2
        ax.text(mid_x, mid_y, f'{dist_value:.4f}',
                fontsize=10, ha='center', va='center',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

    ax.set_title(f'MDS Visualization: {mutation_name} Mutation Group')
    ax.set_xlabel('Dimension 1')
    ax.set_ylabel('Dimension 2')
    ax.grid(alpha=0.3)
    ax.axis('equal')

plt.tight_layout()
plt.savefig('mds_all_mutations.png', dpi=300)
plt.show()

# 3. Cell distribution visualization using UMAP for all populations
# Collect the mutant and isogenic labels of every analysed group, followed by
# the healthy-control label (added once), and highlight them with scanpy.
all_groups = []
for group_result in results.values():
    all_groups += [group_result['mutant_label'], group_result['iso_label']]
all_groups += [hc_label]

sc.pl.umap(
    adata,
    color='Mutation',
    groups=all_groups,
    title='All Cell Populations in UMAP Space',
    save='all_cell_populations_umap.png',
)

# 4. Comparative Analysis - Which Mutation Has Most Dissimilar Isogenic Control?
# (mutation name, mutant-vs-isogenic OT distance) pairs, largest distance first
iso_diff = sorted(
    (
        (group_name, group_result['distances'][
            f"{group_result['mutant_label']} vs {group_result['iso_label']}"
        ])
        for group_name, group_result in results.items()
    ),
    key=lambda item: item[1],
    reverse=True,
)

plt.figure(figsize=(10, 6))
mutations = [name for name, _ in iso_diff]
distances = [value for _, value in iso_diff]
bars = plt.bar(mutations, distances, color='#8856a7')

# Write each value just above its bar
for bar, val in zip(bars, distances):
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 0.01
    plt.text(label_x, label_y, f'{val:.4f}', ha='center', va='bottom')

plt.ylabel('OT Distance')
plt.title('Distance Between Mutant and Isogenic Control Cells')
plt.xticks(rotation=0)
plt.tight_layout()
plt.savefig('mutant_vs_isogenic_comparison.png', dpi=300)
plt.show()

# 5. Summary Table of Results
# One row per mutation group: the three pairwise OT distances (formatted to
# four decimals) and which population the isogenic control is closer to.
summary_data = []
for mutation_name, data in results.items():
    mutant_label = data['mutant_label']
    iso_label = data['iso_label']
    hc_label = data['hc_label']
    pair_distances = data['distances']

    dist_mutant_iso = pair_distances[f"{mutant_label} vs {iso_label}"]
    dist_mutant_hc = pair_distances[f"{mutant_label} vs {hc_label}"]
    dist_iso_hc = pair_distances[f"{iso_label} vs {hc_label}"]

    if dist_mutant_iso < dist_iso_hc:
        iso_closer_to = "Mutant"
    else:
        iso_closer_to = "HC"

    row = {
        'Mutation': mutation_name,
        'Mutant vs Iso': f"{dist_mutant_iso:.4f}",
        'Mutant vs HC': f"{dist_mutant_hc:.4f}",
        'Iso vs HC': f"{dist_iso_hc:.4f}",
        'Iso Closer To': iso_closer_to,
    }
    summary_data.append(row)

summary_df = pd.DataFrame(summary_data)
print("\nSummary of OT Analysis Results:")
print(summary_df.to_string(index=False))

# Save summary to CSV
summary_df.to_csv('ot_analysis_summary.csv', index=False)

# 6. Violin plots of the marginal distance distributions (for one mutation type)
# The first mutation group serves as the example
example_mutation = list(mutation_groups.keys())[0]
data = results[example_mutation]
mutant_label = data['mutant_label']
iso_label = data['iso_label']
hc_label = data['hc_label']

plt.figure(figsize=(12, 6))

# Flatten the stored cost matrix of each comparison into a 1D array
comparison_names = [
    f'{mutant_label} vs {iso_label}',
    f'{mutant_label} vs {hc_label}',
    f'{iso_label} vs {hc_label}',
]
flat_costs = [data['cost_matrices'][pair].flatten() for pair in comparison_names]
distances_mutant_iso, distances_mutant_hc, distances_iso_hc = flat_costs

# Long-format table for seaborn: one row per matrix entry, tagged with the
# comparison it belongs to
df = pd.DataFrame({
    'Distance': np.concatenate(flat_costs),
    'Comparison': np.repeat(comparison_names, [len(values) for values in flat_costs]),
})

# Plot
sns.violinplot(x='Comparison', y='Distance', data=df, palette=['#3274A1', '#E1812C', '#3A923A'])
plt.title(f'Distribution of Pairwise Cell Distances for {example_mutation} Mutation')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(f'{example_mutation}_distance_distributions.png', dpi=300)
plt.show()

# Print overall conclusion
print("\n=== Overall Conclusions ===")
print("Cross-mutation comparison of isogenic control similarity:")
for mutation_name, data in results.items():
    mutant_label = data['mutant_label']
    iso_label = data['iso_label']
    hc_label = data['hc_label']
    pair_distances = data['distances']

    dist_mutant_iso = pair_distances[f"{mutant_label} vs {iso_label}"]
    dist_iso_hc = pair_distances[f"{iso_label} vs {hc_label}"]

    # The isogenic control is "closer" to whichever population has the
    # smaller OT distance to it
    if dist_mutant_iso < dist_iso_hc:
        closer_to = "mutant cells"
    else:
        closer_to = "healthy controls"
    difference = abs(dist_mutant_iso - dist_iso_hc)

    print(f"For {mutation_name}: {iso_label} is more similar to {closer_to} ")
    print(f"  (difference in OT distance: {difference:.4f})")

# Determine which mutation's isogenic control most closely resembles the mutant
mutant_iso_similarities = []
for m in mutation_groups:
    group_result = results[m]
    pair_key = f"{group_result['mutant_label']} vs {group_result['iso_label']}"
    mutant_iso_similarities.append((m, group_result['distances'][pair_key]))

most_similar = min(mutant_iso_similarities, key=lambda x: x[1])
print(f"\nThe isogenic control cells that most closely resemble their respective mutant cells are from the {most_similar[0]} mutation group")
print(f"(OT distance: {most_similar[1]:.4f})")