### Metabolic Barcoding

In [None]:
import pandas as pd
import os
import scanpy as sc
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler 
import numpy as np

base_dir = r"..\results\high_low_no"
adata = sc.read_h5ad(
    os.path.join(base_dir, "combined_adata_leiden_merged.h5ad")
)
# Rescale every feature (column) by its maximum absolute value after casting
# the expression matrix to float32.
adata.X = MaxAbsScaler().fit_transform(
    adata.X.astype(np.float32)
)

In [None]:
# Map each Leiden cluster id (string keys; presumably 'leiden_merged' stores
# cluster labels as strings) to a cell phenotype. Written phenotype -> clusters
# for readability and inverted into a cluster -> phenotype dict.
phenotype_map = {
    cluster_id: phenotype
    for phenotype, cluster_ids in (
        ("Neurons", ("0", "1", "3", "4", "7", "9")),
        ("Astrocytes", ("2",)),
        ("Oligodendrocytes", ("5",)),
        ("Endothelial cells", ("6", "8")),
    )
    for cluster_id in cluster_ids
}

# Clusters absent from phenotype_map become NaN.
adata.obs['cell_phenotype'] = adata.obs['leiden_merged'].map(phenotype_map)

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

# === Setup ===
# Same results directory as base_dir and the training/evaluation cells.
output_dir = r"..\results\high_low_no"
os.makedirs(output_dir, exist_ok=True)

# Ensure 'cell_phenotype' is categorical
adata.obs['cell_phenotype'] = adata.obs['cell_phenotype'].astype('category')

# Fixed colour per phenotype (matplotlib "tab10" hex codes). Phenotypes not
# listed here are not drawn.
phenotype_colors = {
    "Astrocytes": "#1f77b4",         # blue
    "Oligodendrocytes": "#2ca02c",   # green
    "Neurons": "#d62728",            # red
    "Microglia": "#9467bd",          # purple
    "Endothelial cells": "#e377c2",  # pink
}

# === Plot for each condition separately ===
for condition in ['High', 'Low', 'No']:
    # Only .obs columns are plotted, so slice the DataFrame directly instead of
    # copying the whole AnnData (including .X) for every condition.
    in_condition = adata.obs['condition'] == condition
    plot_df = adata.obs.loc[in_condition, ['x_centroid', 'y_centroid', 'cell_phenotype']].copy()
    plot_df = plot_df.dropna(subset=['cell_phenotype'])
    plot_df['cell_phenotype'] = plot_df['cell_phenotype'].astype(str)
    plot_df['color'] = plot_df['cell_phenotype'].map(phenotype_colors)
    plot_df = plot_df.dropna(subset=['color'])

    # Plot
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(
        plot_df['x_centroid'],
        plot_df['y_centroid'],
        c=plot_df['color'].values,
        s=1,
        linewidth=0,
        alpha=1.0
    )
    ax.set_aspect('equal')
    ax.axis('off')
    # Flip y; presumably centroids are in image (row-down) coordinates.
    ax.invert_yaxis()

    # Legend entries only for phenotypes present in this condition.
    present = set(plot_df['cell_phenotype'])
    handles = [
        mpatches.Patch(color=color, label=label)
        for label, color in phenotype_colors.items()
        if label in present
    ]
    ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=6, title="Cell Phenotype")

    fig.tight_layout()

    # Save, show, then release the figure.
    save_path = os.path.join(output_dir, f"{condition.lower()}_cell_phenotype_spatial_map.png")
    fig.savefig(save_path, dpi=600, bbox_inches='tight', pad_inches=0.0)
    plt.show()
    plt.close(fig)

In [None]:
import numpy as np
from scipy.spatial import Delaunay
from scipy.spatial import cKDTree

# Step 1: Define Graph Construction
def create_graph(adata, distance_threshold=None):
    """
    Build a spatial cell graph from centroid coordinates.

    Parameters:
    - adata: AnnData object with 'x_centroid' and 'y_centroid' columns in .obs
    - distance_threshold: if not None, connect every pair of cells whose
      centroids lie within this distance (cKDTree radius search, 0 included);
      if None, connect cells along the edges of the Delaunay triangulation.

    Returns:
    - G: NetworkX graph with nodes 0..n_obs-1 (row order of adata.obs), each
      carrying a 'pos' attribute (x, y)
    - pos: dict mapping node -> (x, y)
    """
    # Imported here so this function does not rely on a later notebook cell
    # having imported networkx.
    import networkx as nx

    centroids = adata.obs[['x_centroid', 'y_centroid']].values

    G = nx.Graph()
    G.add_nodes_from((i, {"pos": (x, y)}) for i, (x, y) in enumerate(centroids))

    if distance_threshold is not None:
        # `is not None` so that a threshold of 0 is honoured instead of
        # silently switching to Delaunay.
        tree = cKDTree(centroids)
        G.add_edges_from(tree.query_pairs(r=distance_threshold))
    elif len(centroids) >= 3:
        # Delaunay needs at least 3 points in 2-D; with fewer points the
        # graph simply has no edges.
        tri = Delaunay(centroids)
        for a, b, c in tri.simplices:
            G.add_edges_from([(a, b), (b, c), (c, a)])

    pos = nx.get_node_attributes(G, 'pos')
    return G, pos

In [None]:
import networkx as nx
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
from scipy.spatial import cKDTree

def prepare_graph_data_cell_type(msi_adata, distance_threshold=10):
    """
    Create one PyG graph per condition ('High', 'Low', 'No') from lipid features.

    Cells whose 'cell_phenotype' is 'Others' or missing are excluded.
    Nodes are cells:
    - features are the 'mz_*' channels of .X;
    - labels are the encoded 'cell_phenotype';
    - an undirected edge joins every pair of cells whose (x_centroid,
      y_centroid) lie within `distance_threshold` of each other.

    Parameters:
    - msi_adata: AnnData object with 'condition', 'cell_phenotype',
      'x_centroid' and 'y_centroid' in .obs
    - distance_threshold: max centroid distance for graph edges

    Returns:
    - graph_data_dict: {condition: PyG Data}. Conditions without valid cells
      are omitted.
    - label_encoder: LabelEncoder fitted on the valid cells of all conditions,
      so class indices are consistent across graphs.
    """
    graph_data_dict = {}
    phenotypes = msi_adata.obs['cell_phenotype']

    # Missing phenotypes (e.g. clusters absent from the phenotype map) are
    # dropped too. Otherwise NaN would reach LabelEncoder and, depending on
    # the sklearn version, either fail or become an extra class.
    valid_mask = phenotypes.notna() & (phenotypes != 'Others')

    # === Global label fitting on all valid cells of every condition ===
    label_encoder = LabelEncoder()
    label_encoder.fit(phenotypes[valid_mask])

    for condition in ['High', 'Low', 'No']:
        adata_cond = msi_adata[valid_mask & (msi_adata.obs['condition'] == condition)].copy()

        if adata_cond.n_obs == 0:
            continue

        # === Extract lipid features
        lipid_channels = [c for c in adata_cond.var_names if c.startswith("mz_")]
        lipid_X = adata_cond[:, lipid_channels].X
        lipid_X = lipid_X.toarray() if hasattr(lipid_X, 'toarray') else lipid_X
        features = torch.tensor(lipid_X, dtype=torch.float)

        # === Encode labels
        labels = torch.tensor(
            label_encoder.transform(adata_cond.obs['cell_phenotype']),
            dtype=torch.long
        )

        # === Build edges: every cell pair within distance_threshold (i < j)
        centroids = adata_cond.obs[['x_centroid', 'y_centroid']].values
        pairs = cKDTree(centroids).query_pairs(r=distance_threshold, output_type='ndarray')
        # reshape(-1, 2) keeps edge_index at shape (2, 0) when no pair is in
        # range; converting an empty edge list directly would give shape (0,).
        edge_index = torch.as_tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # both directions

        # === Create PyG graph object
        graph_data_dict[f"{condition}"] = Data(x=features, edge_index=edge_index, y=labels)

    return graph_data_dict, label_encoder

In [None]:
import torch
from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, GATConv, GINConv
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Stack of GraphConv layers for node classification.

    Layer widths run in_channels -> hidden_channels (num_layers - 1 times) ->
    out_channels. ReLU follows every layer except the last, and the final
    output goes through log_softmax over dim 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        """
        Parameters:
        - in_channels: number of input node features
        - hidden_channels: width of every hidden layer
        - out_channels: number of output classes
        - num_layers: total number of GraphConv layers (must be >= 2)
        """
        super().__init__()

        assert num_layers >= 2, "Number of layers must be at least 2."

        widths = [in_channels, *([hidden_channels] * (num_layers - 1)), out_channels]
        # Kept as `convs` (entries 0..num_layers-1, in this order) so existing
        # state dicts still load.
        self.convs = torch.nn.ModuleList(
            [GraphConv(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, edge_index, **kwargs):
        """Return per-node log-probabilities; extra keyword arguments are ignored."""
        *hidden_convs, output_conv = self.convs
        for conv in hidden_convs:
            x = F.relu(conv(x, edge_index))
        return F.log_softmax(output_conv(x, edge_index), dim=1)

# Checkpoint writer: model + optimizer state and the epoch number.
def save_model(model, optimizer, epoch, path="gnn_model.pth"):
    """Write model/optimizer state dicts and `epoch` to `path` with torch.save."""
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")

# Load model function
def load_model(path="gnn_model.pth", model=None, optimizer=None, map_location=None):
    """
    Restore model (and optionally optimizer) state from a save_model checkpoint.

    Parameters:
    - path: checkpoint file
    - model: module to load 'model_state_dict' into. Defaults to the
      module-level ``model`` (the original behaviour, which read that global).
    - optimizer: optimizer to load 'optimizer_state_dict' into. Defaults to the
      module-level ``optimizer`` when one exists; otherwise optimizer state is
      not restored.
    - map_location: forwarded to torch.load (e.g. "cpu" to load a GPU
      checkpoint on a CPU-only machine)

    Returns:
    - the epoch number stored in the checkpoint

    Raises:
    - NameError if no model is passed and no global ``model`` exists
    """
    if model is None:
        model = globals().get("model")
        if model is None:
            raise NameError("load_model: no model passed and no global 'model' defined")
    if optimizer is None:
        optimizer = globals().get("optimizer")

    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Model loaded from {path}, starting at epoch {epoch}")
    return epoch

In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold

# === Reproducibility ===
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# === Prepare single-graph data ===
# One graph per condition, merged into a single Batch. CV folds are drawn
# over individual cells (nodes) of this combined graph.
graph_data_dict, label_encoder = prepare_graph_data_cell_type(adata, distance_threshold=10)  # uses 'cell_phenotype'
graphs = list(graph_data_dict.values())
full_batch = Batch.from_data_list(graphs)
y_all = full_batch.y.cpu().numpy()
num_nodes = len(y_all)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_dir = r"..\results\high_low_no"
os.makedirs(save_dir, exist_ok=True)

# === Hyperparameters (evaluation/explanation cells must use the same architecture) ===
k = 3
hidden_channels = 64
num_layers = 5
learning_rate = 0.001
max_epochs = 2000
patience = 200
val_fraction = 0.15


def _masked_accuracy(log_probs, y, mask):
    """Return the fraction of `mask`ed nodes whose argmax prediction equals `y` (0.0 for an empty mask)."""
    total = int(mask.sum().item())
    if total == 0:
        return 0.0
    return (log_probs[mask].argmax(dim=1) == y[mask]).sum().item() / total


# === K-fold setup ===
skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
results = []

# === Cross-validation loop ===
for fold, (train_val_idx, test_idx) in enumerate(skf.split(np.zeros(num_nodes), y_all)):
    print(f"\n[Fold {fold+1}/{k}]")

    # Random (non-stratified) validation subset carved out of the training part.
    val_split = int(val_fraction * len(train_val_idx))
    np.random.shuffle(train_val_idx)
    val_idx = train_val_idx[:val_split]
    train_idx = train_val_idx[val_split:]

    # Assign masks
    full_batch.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    full_batch.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    full_batch.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    full_batch.train_mask[train_idx] = True
    full_batch.val_mask[val_idx] = True
    full_batch.test_mask[test_idx] = True

    # Save batch (with this fold's masks) for the evaluation/explanation cells.
    torch.save(full_batch.cpu(), f"{save_dir}/fold_{fold+1}_graph.pt")

    # Transductive node classification on one graph: train on it directly
    # rather than re-collating the Batch through a DataLoader.
    data = full_batch.to(device)

    # === Model setup ===
    model = GNN(
        in_channels=data.num_node_features,
        hidden_channels=hidden_channels,
        out_channels=len(label_encoder.classes_),
        num_layers=num_layers
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # GNN.forward already returns log-probabilities. CrossEntropyLoss re-applies
    # log_softmax, which leaves log-probabilities unchanged, so this equals NLLLoss.
    criterion = torch.nn.CrossEntropyLoss()
    ckpt_path = f"{save_dir}/cell_type_model_fold_{fold+1}_best.pth"

    best_val_loss = float('inf')
    best_state = None
    epochs_no_improve = 0

    # === Training loop ===
    for epoch in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # === Validation and early stopping (test nodes are not consulted) ===
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            val_loss = criterion(out[data.val_mask], data.y[data.val_mask]).item()
            val_acc = _masked_accuracy(out, data.y, data.val_mask)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # In-memory copy so the test accuracy below uses exactly the
            # weights written to the checkpoint.
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
            }, ckpt_path)
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if best_state is None:
        # Validation loss never improved on inf (e.g. NaN from an empty
        # validation set): keep the final weights so a checkpoint exists.
        print("Warning: no validation improvement recorded; saving the final model.")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
        }, ckpt_path)
    else:
        model.load_state_dict(best_state)

    # === Test accuracy of the best (checkpointed) model ===
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        test_acc = _masked_accuracy(out, data.y, data.test_mask)
    print(f"Fold {fold+1} test accuracy (best checkpoint): {test_acc:.4f}")

    results.append({
        'fold': fold + 1,
        'best_val_loss': float(best_val_loss),
        'test_accuracy': test_acc
    })

# === Save CV results ===
results_df = pd.DataFrame(results)
results_df.to_csv(f"{save_dir}/cell_type_crossval_results.csv", index=False)
print("Cross-validation completed!")

In [None]:
import torch
import os
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import label_binarize
from torch_geometric.loader import DataLoader
import seaborn as sns
import matplotlib.pyplot as plt

# === Parameters ===
save_dir = r"..\results\high_low_no"
os.makedirs(save_dir, exist_ok=True)

# Architecture/CV settings; must match the training cell so checkpoints load.
hidden_channels, num_layers, num_folds = 64, 5, 3

# Rebuilt here for the label encoder (class order) and the input feature count.
graph_data_dict, label_encoder = prepare_graph_data_cell_type(adata, distance_threshold=10)
graphs = list(graph_data_dict.values())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_names = label_encoder.classes_
num_classes = len(target_names)

# Plot colour per cell group (matplotlib "tab10" hex codes).
cell_group_color_dict = {
    "Astrocytes": "#1f77b4",         # blue
    "Oligodendrocytes": "#2ca02c",   # green
    "Neurons": "#d62728",            # red
    "Microglia": "#9467bd",          # purple
    "Endothelial cells": "#e377c2",  # pink
}

# === Evaluation Function ===
def evaluate_gnn(model, loader, num_classes, device=None):
    """
    Run `model` over every batch in `loader` and collect its test-node outputs.

    Parameters:
    - model: module whose forward(x, edge_index) returns logits or
      log-probabilities (softmax is applied here to obtain probabilities)
    - loader: iterable of graph batches exposing x, edge_index, y, test_mask
      and .to(device)
    - num_classes: unused; kept for backward compatibility
    - device: device to run on. Defaults to the device of the model's first
      parameter (CPU for a parameter-free model), so the function no longer
      depends on a global `device`.

    Returns:
    - (preds, labels, probs) as NumPy arrays restricted to test_mask nodes;
      probs has one column per model output
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_probs, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index)
            prob = torch.softmax(out, dim=1)
            pred = out.argmax(dim=1)

            mask = batch.test_mask
            all_probs.append(prob[mask].cpu().numpy())
            all_preds.append(pred[mask].cpu().numpy())
            all_labels.append(batch.y[mask].cpu().numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels), np.concatenate(all_probs)

# === Run K-Fold Evaluation ===
all_auc_dicts = []
all_acc_dicts = []
all_f1_dicts = []
# Explicit class ids keep target_names aligned with the report even when a
# class has no test cells in some fold.
class_ids = np.arange(num_classes)

for fold_id in range(1, num_folds + 1):
    print(f"\n=== Evaluating Fold {fold_id} ===")

    model_path = f"{save_dir}/cell_type_model_fold_{fold_id}_best.pth"
    graph_path = f"{save_dir}/fold_{fold_id}_graph.pt"

    if not os.path.exists(model_path) or not os.path.exists(graph_path):
        print(f"Skipping fold {fold_id}: missing model or graph.")
        continue

    model = GNN(
        in_channels=graphs[0].num_node_features,
        hidden_channels=hidden_channels,
        out_channels=num_classes,
        num_layers=num_layers
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])

    # The fold graph is a pickled PyG Batch written by the training cell.
    # torch >= 2.6 defaults to weights_only=True, which refuses such objects,
    # so opt out explicitly (do this only for files you produced yourself).
    graph = torch.load(graph_path, map_location="cpu", weights_only=False)
    loader = DataLoader([graph], batch_size=1, shuffle=False)

    preds, labels, probs = evaluate_gnn(model, loader, num_classes)

    report = classification_report(
        labels, preds, labels=class_ids, target_names=target_names,
        output_dict=True, zero_division=0,
    )
    # Per-class accuracy = fraction of a class's cells that are classified
    # correctly, i.e. recall. F1 is kept in its own table.
    accs = {'fold_id': fold_id}
    f1s = {'fold_id': fold_id}
    for cname in target_names:
        accs[cname] = report[cname]['recall']
        f1s[cname] = report[cname]['f1-score']
    all_acc_dicts.append(accs)
    all_f1_dicts.append(f1s)

    labels_bin = label_binarize(labels, classes=class_ids)
    aucs = {'fold_id': fold_id}
    for i, cname in enumerate(target_names):
        try:
            # Undefined (ValueError) when only one class is present for this column.
            auc = roc_auc_score(labels_bin[:, i], probs[:, i])
        except ValueError:
            auc = np.nan
        aucs[cname] = auc
    all_auc_dicts.append(aucs)

    # === Print Metrics ===
    print(classification_report(labels, preds, labels=class_ids, target_names=target_names, zero_division=0))
    print("Per-Class AUCs:")
    for cname in target_names:
        val = aucs[cname]
        print(f"  {cname}: {val:.4f}" if not np.isnan(val) else f"  {cname}: N/A")

# === Save Metrics ===
acc_df = pd.DataFrame(all_acc_dicts)
auc_df = pd.DataFrame(all_auc_dicts)
f1_df = pd.DataFrame(all_f1_dicts)

acc_df.to_csv(f"{save_dir}/cell_type_per_class_accuracy.csv", index=False)
auc_df.to_csv(f"{save_dir}/cell_type_per_class_auc.csv", index=False)
f1_df.to_csv(f"{save_dir}/cell_type_per_class_f1.csv", index=False)
print("\n✓ Saved per-class accuracy (recall), F1 and AUC CSVs.")

# === Prepare for Boxplots ===
def melt_with_group(df, metric_name):
    """
    Reshape a wide per-fold metric table into long form for plotting.

    Output columns are fold_id, Class, <metric_name>, Group and Color:
    - Group is the first key of cell_group_color_dict that occurs as a
      substring of the class name ("Others" if none does);
    - Color is that group's colour ("gray" when the group has none).
    """
    def _group_of(class_name):
        for group in cell_group_color_dict:
            if group in class_name:
                return group
        return "Others"

    long_df = df.melt(id_vars="fold_id", var_name="Class", value_name=metric_name)
    long_df["Group"] = long_df["Class"].map(_group_of)
    long_df["Color"] = long_df["Group"].map(cell_group_color_dict).fillna("gray")
    return long_df

# Long-format tables (one row per fold x class) with group/colour columns.
acc_melt, auc_melt = (
    melt_with_group(acc_df, metric_name="Accuracy"),
    melt_with_group(auc_df, metric_name="AUC"),
)

# === Plotting Function ===
def plot_metric(df, metric, filename):
    """
    Draw one box per class showing the across-fold distribution of `metric`.

    `df` is a long table from melt_with_group. The figure is saved as
    <save_dir>/<filename>.png at 600 dpi and then shown.
    """
    # Later rows win for duplicate classes, as with Series.to_dict().
    class_palette = dict(zip(df["Class"], df["Color"]))
    label_size = 16

    plt.figure(figsize=(12, 6))
    # NOTE(review): newer seaborn versions warn about `palette` without `hue`.
    sns.boxplot(data=df, x="Class", y=metric, palette=class_palette)

    plt.title(f"{metric} per Cell Type", fontsize=18)
    plt.ylabel(metric, fontsize=label_size)
    plt.xlabel("Cell Type", fontsize=label_size)
    plt.xticks(rotation=45, ha='right', fontsize=label_size)
    plt.yticks(fontsize=label_size)
    plt.tight_layout()
    plt.savefig(f"{save_dir}/{filename}.png", dpi=600)
    plt.show()

def plot_metric_with_errorbars(df, metric, filename):
    """
    Bar plot of the across-fold mean of `metric` per class, with ±1 SD error bars.

    `df` is a long table from melt_with_group (columns Class, <metric>,
    Group, Color). The figure is saved as <save_dir>/<filename>_barplot.png
    at 600 dpi and then shown.
    """
    per_class = df.groupby("Class").agg(
        mean=(metric, "mean"),
        std=(metric, "std"),
        Group=("Group", "first"),
        Color=("Color", "first"),
    ).reset_index()
    # Copy of Class used as hue so seaborn accepts a per-class palette.
    per_class["hue"] = per_class["Class"]
    class_palette = per_class.set_index("Class")["Color"].to_dict()

    plt.figure(figsize=(12, 6))
    ax = sns.barplot(
        data=per_class,
        x="Class",
        y="mean",
        hue="hue",
        palette=class_palette,
        errorbar=None,  # SDs are drawn manually below
    )

    # Assumes bar i sits at x = i, i.e. seaborn keeps the row order of
    # per_class (sorted by Class by groupby).
    for x_pos, (mean_val, std_val) in enumerate(zip(per_class["mean"], per_class["std"])):
        ax.errorbar(
            x_pos, mean_val,
            yerr=std_val,
            fmt='none',
            ecolor='black',
            elinewidth=1.5,
            capsize=4,
        )

    # The dummy hue would otherwise produce a redundant legend.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    tick_size = 14
    plt.title(f"GNN {metric} per Cell Phenotype", fontsize=18)
    plt.ylabel(metric, fontsize=16)
    plt.xlabel("Cell Type", fontsize=16)
    plt.xticks(rotation=45, ha='right', fontsize=tick_size)
    plt.yticks(fontsize=tick_size)
    plt.grid()
    plt.tight_layout()
    plt.savefig(f"{save_dir}/{filename}_barplot.png", dpi=600)
    plt.show()

# === Plot AUC and Accuracy ===
# Box plots first, then mean ± SD bar plots.
plot_metric(acc_melt, metric="Accuracy", filename="cell_type_fold_accuracy_boxplot")
plot_metric(auc_melt, metric="AUC", filename="cell_type_fold_auc_boxplot")
plot_metric_with_errorbars(acc_melt, metric="Accuracy", filename="cell_type_fold_accuracy")
plot_metric_with_errorbars(auc_melt, metric="AUC", filename="cell_type_fold_auc")

In [None]:
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.explain import Explainer, CaptumExplainer
from tqdm import tqdm
import numpy as np
import pandas as pd
import os

# Parameters
n_folds = 3
attr_method = 'IntegratedGradients'  # Captum attribution method name

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lipid channels are the 'mz_*' variables, kept in adata.var_names order.
feature_labels = [name for name in adata.var_names if name.startswith("mz_")]
num_classes = len(label_encoder.classes_)

# Per-fold summed attribution vectors: over all nodes, and over test nodes only.
all_scores = {split: [] for split in ("all", "test")}

def calculate_average_explanations(explainer, batch, mask=None, n_iterations=1):
    """
    Sum the explainer's node attributions over nodes and average over repeats.

    Parameters:
    - explainer: callable (x, edge_index) -> object with a `node_mask` tensor
    - batch: graph with `x` and `edge_index`
    - mask: optional boolean tensor selecting which nodes' rows are summed
    - n_iterations: number of explainer calls to average

    Returns:
    - NumPy array: per-feature attribution summed over the selected nodes,
      averaged across the `n_iterations` runs
    """
    keep_rows = None if mask is None else mask.cpu().numpy()
    per_run_totals = []
    for _ in range(n_iterations):
        attributions = explainer(batch.x, batch.edge_index).node_mask.cpu().detach().numpy()
        if keep_rows is not None:
            attributions = attributions[keep_rows]
        per_run_totals.append(attributions.sum(axis=0))
    return np.mean(per_run_totals, axis=0)

# === Main Loop Over Folds ===
# Reset the collectors so re-running this cell does not append to stale
# results. Also record which fold produced each entry, so skipped folds do
# not shift the Fold_ID column written below.
all_scores = {"all": [], "test": []}
score_fold_ids = {"all": [], "test": []}


def _abs_score_frame(score_rows, fold_ids):
    """Build a table of |summed attribution| per feature (columns = feature_labels) plus a Fold_ID column."""
    values = np.abs(np.asarray(score_rows)).reshape(len(score_rows), len(feature_labels))
    frame = pd.DataFrame(values, columns=feature_labels)
    frame["Fold_ID"] = fold_ids
    return frame


for fold_id in range(1, n_folds + 1):
    print(f"\n=== Processing Fold {fold_id} ===")

    # Load fold graph: a pickled PyG Batch written by the training cell.
    # torch >= 2.6 defaults to weights_only=True, which cannot unpickle it;
    # only opt out like this for trusted, locally produced files.
    graph_path = f"{save_dir}/fold_{fold_id}_graph.pt"
    if not os.path.exists(graph_path):
        print(f"Graph not found: {graph_path}")
        continue
    batch = torch.load(graph_path, map_location="cpu", weights_only=False)
    loader = DataLoader([batch], batch_size=1, shuffle=False)

    # Load model (architecture must match the training cell)
    model_path = f"{save_dir}/cell_type_model_fold_{fold_id}_best.pth"
    if not os.path.exists(model_path):
        print(f"Model not found: {model_path}")
        continue

    model = GNN(
        in_channels=batch.num_node_features,
        hidden_channels=64,
        out_channels=num_classes,
        num_layers=5
    ).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Initialize explainer
    explainer = Explainer(
        model=model,
        algorithm=CaptumExplainer(attr_method),
        explanation_type="model",
        node_mask_type="attributes",
        edge_mask_type=None,
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            # GNN.forward ends in log_softmax, so its outputs are log-probabilities.
            return_type='log_probs',
        ),
    )

    for batch_data in tqdm(loader, desc=f"Fold {fold_id} Explanation"):
        batch_data = batch_data.to(device)

        avg_all = calculate_average_explanations(explainer, batch_data, mask=None, n_iterations=1)
        all_scores["all"].append(avg_all)
        score_fold_ids["all"].append(fold_id)

        if batch_data.test_mask.any():
            avg_test = calculate_average_explanations(explainer, batch_data, mask=batch_data.test_mask, n_iterations=1)
            all_scores["test"].append(avg_test)
            score_fold_ids["test"].append(fold_id)

# Save scores for all nodes
scores_df_all = _abs_score_frame(all_scores["all"], score_fold_ids["all"])
scores_df_all.to_csv(os.path.join(save_dir, f"cell_type_model_{attr_method}_all_scores.csv"), index=False)

# Save scores for test nodes
if all_scores["test"]:
    scores_df_test = _abs_score_frame(all_scores["test"], score_fold_ids["test"])
    scores_df_test.to_csv(os.path.join(save_dir, f"cell_type_model_{attr_method}_test_scores.csv"), index=False)

# Print summary
print("✓ Feature importance scores saved for all and test nodes.")

In [None]:
import pandas as pd
import numpy as np
import os

# === Experiment Setup ===
# Must be the directory the explanation cell wrote its CSVs to
# (..\results\high_low_no, as in the training and evaluation cells).
save_dir = r"..\results\high_low_no"
load_dir = save_dir
attr_method = 'IntegratedGradients'

# === File Paths ===
all_scores_file = os.path.join(load_dir, f"cell_type_model_{attr_method}_all_scores.csv")
test_scores_file = os.path.join(load_dir, f"cell_type_model_{attr_method}_test_scores.csv")

# === Load All-Nodes Scores (required) ===
if not os.path.exists(all_scores_file):
    raise FileNotFoundError(f"All-nodes scores file not found: {all_scores_file}")
scores_df_all = pd.read_csv(all_scores_file)
print("Loaded all-nodes scores CSV.")

# === Load Test-Nodes Scores (Optional) ===
if os.path.exists(test_scores_file):
    scores_df_test = pd.read_csv(test_scores_file)
    print("Loaded test-nodes scores CSV.")
else:
    print("Test-nodes scores file not found. Proceeding with only all-nodes scores.")
    scores_df_test = None

# === Extract Feature Labels (exclude Fold_ID) ===
feature_labels = [col for col in scores_df_all.columns if col != "Fold_ID"]

# === Per-feature mean across folds; Fold_ID is dropped afterwards ===
average_scores_all = scores_df_all[feature_labels].mean(axis=0).values
scores_df_all = scores_df_all.drop(columns=["Fold_ID"])

average_scores_test = None
if scores_df_test is not None:
    average_scores_test = scores_df_test[feature_labels].mean(axis=0).values
    scores_df_test = scores_df_test.drop(columns=["Fold_ID"])

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# Prepare data for boxplot: one row per fold, one column per lipid feature.
# Test-node scores are preferred; fall back to all-node scores when the
# previous cell found no test-node CSV (scores_df_test is None).
source_df = scores_df_test if scores_df_test is not None else scores_df_all
print("Using test-node scores." if scores_df_test is not None else "Using all-node scores.")
scores_array = np.abs(source_df.to_numpy())

# Convert to DataFrame for easier plotting
feature_labels = [c for c in adata.var_names if c.startswith("mz_")]
if scores_array.shape[1] != len(feature_labels):
    raise ValueError("Feature labels and score dimensions do not match!")

scores_df = pd.DataFrame(scores_array, columns=feature_labels)

# Plot boxplot for all features (vertical orientation)
plt.figure(figsize=(14, 6))
sns.boxplot(data=scores_df, orient='v', color='skyblue')
plt.title("Feature Importance Distribution Across Models", fontsize=16)
plt.xlabel("Features", fontsize=12)
plt.ylabel("Importance Score", fontsize=12)

# Rotate x-axis labels for better readability
plt.xticks(rotation=45, fontsize=6, ha='right')
plt.tight_layout()
plt.show()

In [None]:
from typing import Optional, List
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

def visualize_feature_importance(
    all_scores: List[np.ndarray],
    labels: Optional[List[str]] = None,
    top_k: Optional[int] = None,
    save_path: Optional[str] = None,
    individual_plots: bool = False,
    individual_save_dir: Optional[str] = None,
):
    """
    Plot feature importance aggregated over models, optionally also one plot per model.

    The aggregated bar chart shows each feature's mean score across the rows
    of `all_scores`, sorted in descending order and rounded to 3 decimals.

    Save-or-show rules:
    - Aggregated plot: saved as
      "<save_path>/cell_type_<attr_method>_feature_importance_aggregated.png"
      when `save_path` is given (attr_method is read from the global scope);
      shown otherwise.
    - Per-model plots: saved to `individual_save_dir` when it is truthy;
      shown otherwise.

    Parameters:
    - all_scores: per-model arrays of feature importance scores
    - labels: feature names (default Feature_1, Feature_2, ...)
    - top_k: keep only the top_k features in each plot
    - save_path: directory for the aggregated plot
    - individual_plots: also draw one plot per model
    - individual_save_dir: directory for the per-model plots
    """
    def _bar_figure(names, values, color, title, ylabel):
        # Shared layout for the aggregated and per-model bar charts.
        plt.figure(figsize=(14, 8))
        plt.bar(names, values, alpha=0.7, color=color)
        plt.xticks(rotation=45, ha="right")
        plt.title(title)
        plt.xlabel("Feature")
        plt.ylabel(ylabel)
        plt.tight_layout()

    stacked = np.array(all_scores)
    mean_scores = stacked.mean(axis=0)
    std_scores = stacked.std(axis=0)

    if labels is None:
        labels = [f"Feature_{n}" for n in range(1, len(mean_scores) + 1)]

    summary = (
        pd.DataFrame({'mean_score': mean_scores, 'std_score': std_scores}, index=labels)
        .sort_values('mean_score', ascending=False)
        .round(decimals=3)
    )
    if top_k is not None:
        summary = summary.head(top_k)

    _bar_figure(
        summary.index.tolist(),
        summary['mean_score'].tolist(),
        "skyblue",
        "Top Features - Aggregated Importance",
        "Mean Importance",
    )
    if save_path is None:
        plt.show()
    else:
        plt.savefig(f"{save_path}/cell_type_{attr_method}_feature_importance_aggregated.png", bbox_inches="tight")
        print(f"Aggregated feature importance plot saved to {save_path}")

    if not individual_plots:
        return

    for model_no, model_scores in enumerate(all_scores, start=1):
        ranked = pd.DataFrame({'score': model_scores}, index=labels).sort_values('score', ascending=False)
        if top_k is not None:
            ranked = ranked.head(top_k)

        _bar_figure(
            ranked.index.tolist(),
            ranked['score'].tolist(),
            "lightcoral",
            f"Top Features for Model {model_no}",
            "Importance",
        )
        if individual_save_dir:
            plt.savefig(f"{individual_save_dir}/cell_type_{attr_method}_feature_importance_model_{model_no}.png", bbox_inches="tight")
        else:
            plt.show()

# Example Usage: absolute per-fold scores for test nodes. Falls back to
# all-node scores when no test-node CSV was found (scores_df_test is None).
source_df = scores_df_test if scores_df_test is not None else scores_df_all
scores_list = [np.abs(model_scores) for model_scores in source_df.values]
feature_labels = [c for c in adata.var_names if c.startswith("mz_")]

# Aggregated plot of the top 50 features plus one plot per fold, saved to save_dir
visualize_feature_importance(
    all_scores=scores_list,
    labels=feature_labels,
    top_k=50,
    save_path=save_dir,
    individual_plots=True,
    individual_save_dir=save_dir
)

In [None]:
from typing import Optional, List
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

# ─────────────────────────────────────────────────────────────
# 1. Short m/z label dictionary
# ─────────────────────────────────────────────────────────────
# Built from (m/z, lipid annotation) pairs. Each key is the feature name
# 'mz_<m/z to 4 decimals>'; each value is '<annotation> (<m/z to 2 decimals> m/z)'.
mz_shortnames = {
    f"mz_{mz:.4f}": f"{lipid} ({mz:.2f} m/z)"
    for mz, lipid in (
        (699.45, "PA 34:1"),
        (701.55, "PA 36:1"),
        (719.55, "PE 34:0"),
        (745.55, "PG 34:2"),
        (747.45, "PG 34:1"),
        (748.55, "PE 36:1"),
        (762.55, "PE 38:6"),
        (772.55, "PE 38:1"),
        (790.55, "PS 36:1"),
        (794.55, "PS 36:0"),
        (880.65, "PI 38:4"),
        (888.65, "ST C24:1"),
        (889.65, "ST 42:2"),
        (890.65, "ST C24:0"),
        (905.65, "ST C24:0 (OH)"),
    )
}

# ─────────────────────────────────────────────────────────────
# 2. Main visualization function
# ─────────────────────────────────────────────────────────────
def _unit_interval(values):
    """Min-max scale *values* to [0, 1] for colormap lookup.

    Empty input is returned unchanged. Constant (or all-NaN) input maps to 0.5
    instead of the NaN that ``(v - min) / (max - min)`` would give; NaN colours
    are rendered as the colormap's transparent "bad" colour, i.e. invisible bars.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    lo, hi = np.nanmin(values), np.nanmax(values)
    span = hi - lo
    if not np.isfinite(span) or span == 0:
        return np.full_like(values, 0.5)
    return (values - lo) / span


def _ranked_barh(scores, cmap, show_ranks, xlabel, title):
    """Draw a descending-sorted Series as horizontal bars, rank 1 on top; return (fig, ax)."""
    values = scores.to_numpy(dtype=float)
    n = len(values)
    fig, ax = plt.subplots(figsize=(10, max(6, 0.3 * n)), constrained_layout=True)
    colors = cmap(_unit_interval(values))

    # barh puts the first category at y=0 (the bottom), so the bars are fed in
    # ascending order to place the most important feature at the top.
    ax.barh(scores.index[::-1], values[::-1], color=colors[::-1],
            edgecolor='black', alpha=0.9)

    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_title(title, fontsize=18)
    ax.tick_params(axis='both', labelsize=16)

    if show_ranks and n:
        # Offset labels by 1% of the longest bar so the gap scales with the data
        # and stays inside the default x-margin.
        peak = np.nanmax(np.abs(values))
        pad = 0.01 * peak if np.isfinite(peak) and peak > 0 else 0.01
        for y, val in enumerate(values[::-1]):
            # y=0 is the lowest-scoring bar shown, i.e. rank n; the top bar is rank 1.
            # Labels sit just right of the bar end on the axes background, so use
            # a dark colour (white text there would be invisible).
            ax.text(val + pad, y, f"{n - y}", va='center', color='black', fontsize=16)

    ax.set_ylim(-0.5, n - 0.5)
    return fig, ax


def visualize_feature_importance(
    all_scores: List[np.ndarray],
    labels: Optional[List[str]] = None,
    top_k: Optional[int] = None,
    save_path: Optional[str] = None,
    individual_plots: bool = False,
    individual_save_dir: Optional[str] = None,
    attr_method: str = "gnnexplainer",
    show_ranks: bool = True,
    color_map: str = "viridis"
):
    """
    Visualize aggregated and per-model feature importance as horizontal bar charts.

    Feature names are replaced by their ``mz_shortnames`` label when one exists,
    sorted by importance (highest at the top, annotated as rank 1) and optionally
    truncated to the ``top_k`` highest-scoring features.

    Args:
        all_scores: One 1-D score array per model; all must have the same length.
        labels: Feature names aligned with the scores; defaults to ``Feature_1..N``.
        top_k: If truthy, plot only the ``top_k`` highest-scoring features.
        save_path: Directory for the aggregated PNG (600 dpi). If falsy, the
            aggregated figure is shown instead of saved.
        individual_plots: Also draw one chart per entry of ``all_scores``.
        individual_save_dir: Directory for the per-model PNGs. If falsy, the
            per-model figures are only shown.
        attr_method: Tag embedded in the output file names.
        show_ranks: Annotate each bar with its rank (1 = most important).
        color_map: Name of the matplotlib colormap used to colour bars by score.
    """
    scores_array = np.array(all_scores)
    mean_scores = scores_array.mean(axis=0)
    std_scores = scores_array.std(axis=0)

    if labels is None:
        labels = [f"Feature_{i+1}" for i in range(len(mean_scores))]

    # Map m/z labels to short names (with m/z in parentheses); unknown names pass through.
    labels = [mz_shortnames.get(lab, lab) for lab in labels]

    df = pd.DataFrame({
        "mean_score": mean_scores,
        "std_score": std_scores
    }, index=labels).sort_values("mean_score", ascending=False)

    if top_k:
        df = df.head(top_k)

    cmap = plt.get_cmap(color_map)

    # === AGGREGATED PLOT === #
    fig, _ = _ranked_barh(df["mean_score"], cmap, show_ranks,
                          xlabel="Mean Importance Score",
                          title=f"Top {len(df)} Lipids")
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(f"{save_path}/cell_type_{attr_method}_feature_importance_aggregated.png",
                    bbox_inches="tight", pad_inches=0.0, dpi=600)
        print(f"[✓] Saved aggregated plot to {save_path}")
    else:
        plt.show()

    # === INDIVIDUAL PLOTS === #
    if not individual_plots:
        return
    # Without a target directory the per-model plots are only displayed
    # (previously os.makedirs(None) raised TypeError).
    if individual_save_dir:
        os.makedirs(individual_save_dir, exist_ok=True)

    for i, scores in enumerate(all_scores):
        ind_df = pd.DataFrame({'score': scores}, index=labels).sort_values('score', ascending=False)
        if top_k:
            ind_df = ind_df.head(top_k)

        fig, _ = _ranked_barh(ind_df["score"], cmap, show_ranks,
                              xlabel="Importance Score",
                              title=f"Top {len(ind_df)} Lipids (Model {i+1})")
        if individual_save_dir:
            fig.savefig(f"{individual_save_dir}/cell_type_{attr_method}_feature_importance_model_{i+1}.png",
                        bbox_inches="tight", pad_inches=0.0, dpi=600)
            plt.show()
            print(f"[✓] Saved model {i+1} plot to {individual_save_dir}")
        else:
            plt.show()

# ─────────────────────────────────────────────────────────────
# 3. Run the function with your inputs
# ─────────────────────────────────────────────────────────────
# Re-plot with the m/z-annotated version: top 15 features, ranks shown,
# aggregated and per-model figures written to save_dir.
visualize_feature_importance(
    scores_list,
    labels=feature_labels,
    top_k=15,
    save_path=save_dir,
    individual_plots=True,
    individual_save_dir=save_dir,
    attr_method=attr_method,
    show_ranks=True,
)

In [None]:
from kneed import KneeLocator
import numpy as np
import matplotlib.pyplot as plt

# Rank features from most to least important according to average_scores_test.
sorted_indices = np.argsort(average_scores_test)[::-1]
sorted_importance = average_scores_test[sorted_indices]
sorted_features = [feature_labels[idx] for idx in sorted_indices]

# Function to find successive knee points on a curve
def find_knee_points(x, y, max_knees=3, curve="convex", direction="decreasing"):
    """Greedily locate up to ``max_knees`` successive knee points of ``y``.

    Each search runs KneeLocator on the points remaining after the previous
    knee (strictly to its right), so later knees lie further along the curve.
    The search stops early when no knee is found or fewer than two points remain.

    Returns:
        List of ``(x_value, y_value)`` tuples, in the order they were found.
    """
    found = []
    xs, ys = x, y
    while len(found) < max_knees:
        k = KneeLocator(range(len(ys)), ys, curve=curve, direction=direction).knee
        if k is None:
            break
        found.append((xs[k], ys[k]))
        # Continue on the tail after this knee.
        xs, ys = xs[k + 1:], ys[k + 1:]
        if len(ys) < 2:
            break
    return found

# Locate up to five successive knees on the descending importance curve.
x_range = range(len(sorted_importance))
knee_points = find_knee_points(x_range, sorted_importance, max_knees=5)

# 0-based rank positions of each knee, and the feature / score found there.
knee_indices = [position for position, _ in knee_points]
knee_features = [sorted_features[position] for position in knee_indices]
knee_importance = [sorted_importance[position] for position in knee_indices]

# Importance against 1-based rank, with a dashed red line at every knee.
plt.figure(figsize=(8, 5))
plt.plot(range(1, len(x_range) + 1), sorted_importance, label="Sorted Importance", marker="o")
for position in knee_indices:
    rank = position + 1
    plt.axvline(rank, color="red", linestyle="--", label=f"Knee Point (k={rank})")
plt.xlabel("Metabolites", fontsize=16)
plt.ylabel("Importance", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.title("Sorted Importances of Metabolites")
plt.legend()
plt.tight_layout()
plt.xlim([1, len(x_range) + 1])
plt.grid()
plt.savefig(f"{save_dir}/sorted_importance_with_cutoff.png", dpi=600)
plt.show()

# Report each knee and the feature sitting at it.
print("Knee points and corresponding features:")
for n, (feature, importance) in enumerate(zip(knee_features, knee_importance), start=1):
    print(f"Knee {n}: {feature} - Importance: {importance:.4f}")

# Persist the full importance ranking for later cells.
np.save(os.path.join(save_dir, 'sorted_features.npy'), sorted_features)

In [None]:
# Hard-coded cut-off: keep the 15 highest-ranked features for barcoding
# (the same count as top_k=15 in the annotated importance plot).
knee_point = 15
top_k_features = sorted_features[0:knee_point]
print(top_k_features)

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.stats import mode

# === Parameters ===
num_folds = 3                   # fold graphs read from {load_dir}/fold_<k>_graph.pt, k = 1..num_folds
num_thresholds = 9              # quantile cut points per feature -> levels 0..num_thresholds
phenotype_to_barcodes = dict()  # phenotype -> list of per-fold barcode rows
phenotype_names_all = set()     # every phenotype seen in any fold

# === Helper Functions ===
def compute_thresholds(data, num_thresholds):
    """Return per-column quantile cut points that split *data* into equal-count bins.

    The cut points are the interior points of an even grid on [0, 1] with
    ``num_thresholds + 2`` points (e.g. 3 thresholds -> 25th/50th/75th percentiles).

    Returns:
        dict mapping each column name of *data* to a 1-D numpy array of
        ``num_thresholds`` cut values.
    """
    probs = np.linspace(0.0, 1.0, num=num_thresholds + 2)[1:-1]
    cut_points = {}
    for column in data.columns:
        cut_points[column] = data[column].quantile(probs).values
    return cut_points

def categorize_values(data, thresholds):
    """Discretise each column of *data* into integer levels.

    A value's level is the number of that column's cut points (from
    ``thresholds[column]``) it strictly exceeds, so it ranges from 0 to
    ``len(thresholds[column])``. NaN values compare False and get level 0.

    Returns:
        Integer array with the same shape as *data*.
    """
    levels = np.zeros(data.shape, dtype=int)
    for col_idx, feature in enumerate(data.columns):
        column = data[feature].values
        levels[:, col_idx] = sum(
            (column > cut).astype(int) for cut in thresholds[feature]
        )
    return levels

def plot_barcode_matrix(matrix, row_labels, col_labels, color_map, filename):
    """Render a level matrix as coloured diamonds, save it to *filename*, then show it.

    Cell (r, c) is drawn at x=c, y=r with colour ``color_map[matrix[r, c]]``;
    a level missing from *color_map* raises KeyError. Saved at 600 dpi.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    for row, col in np.ndindex(matrix.shape[0], matrix.shape[1]):
        ax.scatter(
            col,
            row,
            color=color_map[matrix[row, col]],
            marker='D',
            s=160,
            edgecolors="black",
            linewidth=0.6,
        )

    # Features along x, phenotypes along y.
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha='right', fontsize=14)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=14)
    ax.set_ylim(-0.5, len(row_labels) - 0.5)
    ax.set_xlabel("Metabolite Features", fontsize=16)
    ax.set_ylabel("Cell Phenotypes", fontsize=16)

    # Faint dashed grid and no tick marks.
    ax.grid(True, linestyle="--", linewidth=0.4, alpha=0.3)
    ax.tick_params(axis='both', which='major', length=0)

    plt.tight_layout()
    plt.savefig(filename, dpi=600, bbox_inches='tight', pad_inches=0.0)
    plt.show()

# === Color map setup ===
# One colour per barcode level 0..num_thresholds, taken from the qualitative
# 'tab20' palette (distinct for up to 20 levels).
# plt.get_cmap is used because plt.cm.get_cmap (matplotlib.cm.get_cmap) was
# deprecated in Matplotlib 3.7 and removed in 3.9.
color_cmap = plt.get_cmap('tab20')
color_list = [mcolors.to_hex(color_cmap(i)) for i in range(num_thresholds + 1)]
color_map = {i: color_list[i] for i in range(num_thresholds + 1)}

# === Loop through folds ===
# For each saved fold graph: discretise the training cells' top-k features into
# num_thresholds + 1 quantile levels, encode each phenotype by the level of its
# mean feature vector, record it for the majority vote, and plot the fold matrix.
for fold_id in range(1, num_folds + 1):
    graph_path = f"{load_dir}/fold_{fold_id}_graph.pt"
    if not os.path.exists(graph_path):
        continue  # missing folds simply contribute no votes

    # The .pt file holds a whole graph object (.x, .y, .train_mask), which needs
    # full unpickling: torch >= 2.6 defaults to weights_only=True and rejects it.
    # Only load files written by this pipeline. map_location="cpu" lets graphs
    # saved from a GPU load on CPU-only machines (everything is moved to CPU below).
    graph = torch.load(graph_path, map_location="cpu", weights_only=False)
    train_mask = graph.train_mask.cpu().numpy()

    # Node features as a DataFrame (graph.x columns assumed to follow
    # feature_labels order — confirm against how the graph was built),
    # restricted to training nodes and the selected top-k features.
    X_all = pd.DataFrame(graph.x.cpu().numpy(), columns=feature_labels)
    X_train = X_all.loc[train_mask, top_k_features].reset_index(drop=True)
    y_train = graph.y.cpu().numpy()[train_mask]
    cell_types_train = label_encoder.inverse_transform(y_train)

    # Quantile cut points from this fold's training data only; X_train already
    # contains exactly the top-k feature columns, so no further sub-selection.
    thresholds = compute_thresholds(X_train, num_thresholds)

    # Compute phenotype mean barcodes
    phenotype_means = []
    phenotype_names = []

    for phenotype in np.unique(cell_types_train):
        subset = X_train[cell_types_train == phenotype]
        mean_vals = subset.mean(axis=0).to_frame().T
        barcode_row = categorize_values(mean_vals, thresholds)[0]
        phenotype_names_all.add(phenotype)

        # Store barcode in per-phenotype list (one entry per fold it appears in)
        phenotype_to_barcodes.setdefault(phenotype, []).append(barcode_row)

        phenotype_means.append(barcode_row)
        phenotype_names.append(phenotype)

    barcode_matrix = np.vstack(phenotype_means)

    # Plot fold barcode
    plot_barcode_matrix(
        matrix=barcode_matrix,
        row_labels=phenotype_names,
        col_labels=top_k_features,
        color_map=color_map,
        filename=f"{save_dir}/barcode_matrix_train_fold_{fold_id}.png"
    )

# === Compute final barcodes via majority vote ===
# Per phenotype and feature, the consensus level is the most common level across
# the folds in which that phenotype appeared. scipy.stats.mode returns the
# smallest level on ties (e.g. when every fold disagrees).
final_phenotypes = sorted(phenotype_names_all)
phenotype_barcodes = {}  # phenotype -> consensus barcode
final_matrix = []

for phenotype in final_phenotypes:
    barcode_stack = np.stack(phenotype_to_barcodes[phenotype])  # (folds seen, num_features)
    voted_barcode = mode(barcode_stack, axis=0).mode.flatten()
    phenotype_barcodes[phenotype] = voted_barcode
    final_matrix.append(voted_barcode)

# === Plot final consensus barcode ===
final_matrix_np = np.vstack(final_matrix)
plot_barcode_matrix(
    final_matrix_np,
    final_phenotypes,
    top_k_features,
    color_map,
    f"{save_dir}/barcode_matrix_majority_vote.png",
)

In [None]:
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import matplotlib.pyplot as plt

# Phenotype -> plot colour (same palette as the spatial phenotype maps).
cell_group_color_dict = {
    "Astrocytes": "#1f77b4",         # blue
    "Oligodendrocytes": "#2ca02c",   # green
    "Neurons": "#d62728",            # red
    "Microglia": "#9467bd",          # purple
    "Endothelial cells": "#e377c2",  # pink
}

# One barcode row per phenotype, in phenotype_barcodes order.
phenotype_labels = list(phenotype_barcodes)
barcode_matrix = np.vstack([phenotype_barcodes[name] for name in phenotype_labels])

# Square matrix of pairwise cosine distances between the barcodes.
dist_matrix = squareform(pdist(barcode_matrix, metric='cosine'))

# Annotated heatmap of those distances (displayed only, not saved).
sns.heatmap(
    dist_matrix,
    annot=True,
    xticklabels=phenotype_labels,
    yticklabels=phenotype_labels,
    cmap='viridis',
)
plt.title("Pairwise Cosine Distance Between Phenotype Barcodes")
plt.show()

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Rebuild the barcode matrix (one row per phenotype, dict order).
phenotype_labels = list(phenotype_barcodes.keys())
barcode_matrix = np.vstack(list(phenotype_barcodes.values()))

# Project the barcodes onto their first two principal components.
pca = PCA(n_components=2)
barcode_2d = pca.fit_transform(barcode_matrix)

plt.figure(figsize=(6, 6))
for label, (pc1, pc2) in zip(phenotype_labels, barcode_2d):
    # Phenotypes missing from cell_group_color_dict are drawn in black.
    plt.scatter(pc1, pc2, label=label, color=cell_group_color_dict.get(label, "#000000"), s=100)

plt.legend()
plt.title("PCA of Phenotype Barcodes")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.axis('equal')
plt.grid()
plt.savefig(f"{save_dir}/pca_phenotype_barcodes.png", dpi=600, bbox_inches='tight', pad_inches=0.0)
plt.show()

import scipy.cluster.hierarchy as sch

# Average-linkage hierarchical clustering on condensed cosine distances.
dists = pdist(barcode_matrix, metric='cosine')
linkage = sch.linkage(dists, method='average')

plt.figure(figsize=(6, 4))
# color_threshold=0 puts every link above the threshold, so all are drawn
# in above_threshold_color (black).
sch.dendrogram(
    linkage,
    labels=phenotype_labels,
    color_threshold=0,
    above_threshold_color='black',
)
plt.title("Hierarchical Clustering of Phenotype Barcodes")
plt.ylabel("Cosine Distance")
plt.xticks(rotation=45, ha='right')
plt.savefig(
    f"{save_dir}/hierarchical_clustering_phenotype_barcodes.png",
    dpi=600,
    bbox_inches='tight',
    pad_inches=0.0,
)
plt.show()

In [None]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.stats import mode
from sklearn.cluster import AgglomerativeClustering

# Step 1: Stack barcodes and get labels (one row per phenotype, dict order)
barcode_matrix = np.vstack(list(phenotype_barcodes.values()))
phenotype_labels = list(phenotype_barcodes.keys())

# Step 2: Compute pairwise distances
dist_matrix = squareform(pdist(barcode_matrix, metric='cosine'))  # or 'hamming'

# Step 3: Cluster similar barcodes.
# dist_matrix already holds distances, so the estimator must be told the input
# is precomputed. Otherwise each row is treated as a feature vector and rows are
# compared by Euclidean distance, which makes the 0.01 cosine threshold meaningless.
cluster_params = dict(n_clusters=None, distance_threshold=0.01, linkage='average')
try:
    clustering = AgglomerativeClustering(metric='precomputed', **cluster_params)
except TypeError:
    # scikit-learn < 1.2 names this parameter `affinity`.
    clustering = AgglomerativeClustering(affinity='precomputed', **cluster_params)
cluster_labels = clustering.fit_predict(dist_matrix)

# Step 4: Group phenotypes by cluster ID
group_to_phenos = {}
for pheno, group_id in zip(phenotype_labels, cluster_labels):
    group_to_phenos.setdefault(group_id, []).append(pheno)

# Step 5: Create readable group names ("A/B" for merged phenotypes) and the
# group's barcode: element-wise mean of member barcodes rounded to a level
# (np.round rounds .5 to the nearest even level).
grouped_barcodes = {}
phenotype_to_group = {}

for group_id, pheno_list in group_to_phenos.items():
    group_name = "/".join(sorted(pheno_list))
    barcodes = np.vstack([phenotype_barcodes[p] for p in pheno_list])
    mean_barcode = np.round(barcodes.mean(axis=0)).astype(int)
    grouped_barcodes[group_name] = mean_barcode

    for pheno in pheno_list:
        phenotype_to_group[pheno] = group_name

# Optional preview
print("Grouped Barcode Names:")
print(list(grouped_barcodes.keys()))

print("\nPhenotype to Group Mapping:")
print(phenotype_to_group)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Display labels for the "mz_<value>" feature columns: "<lipid annotation> (<m/z> m/z)".
# Identical entries to the mapping defined for visualize_feature_importance;
# presumably repeated so this cell can be run on its own. Features missing here
# fall back to their raw column name via .get(mz, mz) below.
# NOTE(review): lipid assignments are hard-coded — confirm against the annotation source.
mz_shortnames = {
    'mz_699.4500': 'PA 34:1 (699.45 m/z)',
    'mz_701.5500': 'PA 36:1 (701.55 m/z)',
    'mz_719.5500': 'PE 34:0 (719.55 m/z)',
    'mz_745.5500': 'PG 34:2 (745.55 m/z)',
    'mz_747.4500': 'PG 34:1 (747.45 m/z)',
    'mz_748.5500': 'PE 36:1 (748.55 m/z)',
    'mz_762.5500': 'PE 38:6 (762.55 m/z)',
    'mz_772.5500': 'PE 38:1 (772.55 m/z)',
    'mz_790.5500': 'PS 36:1 (790.55 m/z)',
    'mz_794.5500': 'PS 36:0 (794.55 m/z)',
    'mz_880.6500': 'PI 38:4 (880.65 m/z)',
    'mz_888.6500': 'ST C24:1 (888.65 m/z)',
    'mz_889.6500': 'ST 42:2 (889.65 m/z)',
    'mz_890.6500': 'ST C24:0 (890.65 m/z)',
    'mz_905.6500': 'ST C24:0 (OH) (905.65 m/z)',
}

# Column labels: lipid annotation where known, raw m/z column name otherwise.
col_labels_short = [mz_shortnames.get(mz, mz) for mz in top_k_features]

# ------------------------------------------------------------
# 1.  Mean barcode per merged group
# ------------------------------------------------------------
# Element-wise mean of the member barcodes, rounded to the nearest level
# (np.round rounds .5 to the nearest even level).
grouped_barcodes = {}
for gid, phenos in group_to_phenos.items():
    barcodes      = np.vstack([phenotype_barcodes[p] for p in phenos])
    mean_barcode  = np.round(barcodes.mean(axis=0)).astype(int)
    grouped_name  = "/".join(sorted(phenos))       # e.g. "Astrocytes/Neurons"
    grouped_barcodes[grouped_name] = mean_barcode

# ------------------------------------------------------------
# 2.  Colour map (0-num_thresholds discrete levels)
# ------------------------------------------------------------
# One 'tab20' colour per level (distinct for up to 20 levels). plt.get_cmap is
# used because plt.cm.get_cmap (matplotlib.cm.get_cmap) was deprecated in
# Matplotlib 3.7 and removed in 3.9.
color_cmap = plt.get_cmap('tab20')
color_list = [mcolors.to_hex(color_cmap(i)) for i in range(num_thresholds + 1)]
color_map = {i: color_list[i] for i in range(num_thresholds + 1)}

# ------------------------------------------------------------
# 3.  Barcode matrix + labels
# ------------------------------------------------------------
group_labels   = list(grouped_barcodes.keys())
barcode_matrix = np.vstack([grouped_barcodes[g] for g in group_labels])

#  -- wrap Y-tick labels: "A/B/C" -> one phenotype per line ----
wrapped_labels = [lbl.replace("/", "\n") for lbl in group_labels]

# ------------------------------------------------------------
# 4.  Plot
# ------------------------------------------------------------
def plot_barcode_matrix(matrix, row_labels, col_labels, color_map, filename):
    """Draw a level matrix as a grid of coloured diamonds, save it, then show it.

    Cell (r, c) is plotted at x=c, y=r; its colour is ``color_map[level]``,
    falling back to black for levels not in *color_map*. Axis limits are set
    to exactly fit the grid, and the figure is saved to *filename* at 600 dpi.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    for r, c in np.ndindex(n_rows, n_cols):
        ax.scatter(
            c,
            r,
            color=color_map.get(matrix[r, c], "black"),
            marker="D",
            s=160,
            edgecolors="black",
            linewidth=0.6,
        )

    # Ticks and labels on both axes.
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha="right", fontsize=14)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=14, va="center")

    # Half a cell of padding around the outermost markers.
    ax.set_xlim(-0.5, len(col_labels) - 0.5)
    ax.set_ylim(-0.5, len(row_labels) - 0.5)

    ax.grid(True, linestyle="--", linewidth=0.4, alpha=0.3)
    ax.tick_params(axis="both", which="major", length=0)

    plt.tight_layout()
    plt.savefig(filename, dpi=600, bbox_inches="tight", pad_inches=0.02)
    plt.show()

# ------------------------------------------------------------
# 5.  Save figure
# ------------------------------------------------------------
# Render the grouped barcode figure (stacked Y-labels, annotated columns) and
# also save the underlying integer matrix for reuse.
output_path = os.path.join(save_dir, "barcode_matrix_grouped.png")
plot_barcode_matrix(
    barcode_matrix,
    wrapped_labels,
    col_labels_short,
    color_map,
    output_path,
)

np.save(os.path.join(save_dir, "barcode_matrix_grouped.npy"), barcode_matrix)

In [None]:
import matplotlib.pyplot as plt

# Stand-alone legend image: one diamond per barcode level, highest level first
# so the legend reads top-down. Labels are 1-based (barcode value 0 -> "Level 1").
legend_handles = [
    plt.Line2D(
        [0], [0],
        marker='D',
        color='w',
        markerfacecolor=color_map[level],
        markeredgecolor='black',
        markersize=10,
        label=f"Level {level + 1}",
    )
    for level in reversed(range(num_thresholds + 1))
]

# Compact vertical layout with the axes hidden.
legend_height = len(legend_handles) * 0.45
fig, ax = plt.subplots(figsize=(2.2, legend_height))
ax.axis('off')

legend_style = dict(
    loc='center',
    fontsize=12,
    frameon=False,
    handletextpad=0.5,
    labelspacing=0.3,
    borderpad=0.1,
    title="Barcode Level",
    title_fontsize=13,
)
ax.legend(handles=legend_handles, **legend_style)

plt.savefig(f"{save_dir}/scatter_barcode_legend.png", dpi=600, bbox_inches='tight', pad_inches=0)
plt.show()