### Metabolic Barcoding with GSR

In [None]:
import numpy as np
import os

# Load the GSR lipid image stack for the high-dose sample.
# NOTE(review): presumably laid out as (channel, height, width) -- it is indexed
# as lipids_gsr[c, :, :] in a later cell; confirm.
data_dir = r"..\data\high_dose"
lipids_gsr = np.load(os.path.join(data_dir, "lipid_gsr.npy"))
# NOTE(review): absolute network path, unlike the relative paths used in the
# other cells; a file with the same name is later loaded from
# ..\results\high_low_no -- confirm both hold the same feature list.
base_dir = r"Y:\coskun-lab\Efe\GSR_MSI\experiments\AmberGen Tri-Modal Lipid Data\high_low_no_2"
# Keep the first 15 feature names and strip the "mz_" prefix to get numeric m/z values.
sorted_features = np.load(os.path.join(base_dir,'sorted_features.npy'))[:15]
sorted_mzs = [float(c.replace("mz_", "")) for c in sorted_features]

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import re

# Root folder for the high-dose outputs; later cells in this notebook reuse it.
results_dir = r"..\results\high_dose"

# ----------------------------
# CONFIGURATION
# ----------------------------
# Folder containing the per-tile mask files (mask_tile_<N>.npy).
mask_dir = os.path.join(results_dir, "masks")
# Step between tile origins, in pixels.
tile_size = 1000
stitched_size = (4600, 4600)  # Final output size with edge tiles
# (height, width) of the stitched canvas.
target_h, target_w = stitched_size

# ----------------------------
# LOAD MASK TILES (correct numeric order)
# ----------------------------
def extract_tile_index(filename):
    """Return the integer N found in a ``mask_tile_N.npy`` name, or -1 if there is none."""
    found = re.search(r"mask_tile_(\d+)\.npy", filename)
    if not found:
        return -1
    return int(found.group(1))

# List the mask tile files and order them by their numeric index, so that
# e.g. tile 10 sorts after tile 9 (a plain string sort would not do that).
tile_files = sorted(
    [f for f in os.listdir(mask_dir) if f.startswith("mask_tile_") and f.endswith(".npy")],
    key=extract_tile_index
)
# Load every tile into memory, in index order.
mask_tiles = [np.load(os.path.join(mask_dir, f)) for f in tile_files]

# ----------------------------
# INFER COORDINATES IN RASTER ORDER
# ----------------------------
def infer_tile_coords(image_shape, tile_size):
    """List tile origins as (y, x) pairs in raster order.

    Rows are walked top to bottom and, within each row, columns left to right,
    stepping by ``tile_size`` over ``image_shape[0]`` x ``image_shape[1]``.
    """
    row_starts = range(0, image_shape[0], tile_size)
    col_starts = range(0, image_shape[1], tile_size)
    return [(top, left) for top in row_starts for left in col_starts]

# Tile origins for the stitched canvas; they are paired with mask_tiles by position,
# so this assumes the tile files were written in the same raster order -- TODO confirm.
coords = infer_tile_coords(stitched_size, tile_size)

# ----------------------------
# FUNCTION: STITCH MASKS WITH UNIQUE LABELS
# ----------------------------
def reconstruct_mask_with_unique_labels(mask_tiles, coords, stitched_size):
    """Stitch per-tile label masks into one mask whose labels are globally unique.

    Parameters:
    - mask_tiles: iterable of 2-D label arrays (0 or negative = background).
    - coords: iterable of (y, x) top-left positions, paired with ``mask_tiles``
      by position (extra items on either side are ignored, as with ``zip``).
    - stitched_size: (height, width) of the output canvas.

    Returns:
    - int32 array of shape ``stitched_size``. Within each tile the positive
      labels are renumbered in ascending order, continuing the running count
      from the previous tiles, so labels run 1..K without gaps.

    Tiles that overhang the bottom/right edge of the canvas are clipped to the
    canvas before relabelling, instead of failing on a shape mismatch.
    """
    combined_mask = np.zeros(stitched_size, dtype=np.int32)
    canvas_h, canvas_w = combined_mask.shape
    label_offset = 0

    for tile, (y, x) in zip(mask_tiles, coords):
        # Clip first so labels that only exist outside the canvas do not
        # consume label ids.
        tile = np.asarray(tile)[:max(canvas_h - y, 0), :max(canvas_w - x, 0)]
        tile = tile.astype(np.int32)

        unique_labels = np.unique(tile)
        unique_labels = unique_labels[unique_labels > 0]

        # Every positive pixel value is present in the sorted unique_labels, so
        # searchsorted returns its rank; one pass replaces a scan per label.
        rank = np.searchsorted(unique_labels, tile)
        new_tile = np.where(tile > 0, rank + label_offset + 1, 0).astype(np.int32)
        label_offset += int(unique_labels.size)

        h, w = tile.shape
        combined_mask[y:y+h, x:x+w] = new_tile

    return combined_mask

# ----------------------------
# STITCH AND SAVE
# ----------------------------
# Full-size label mask; it is kept in memory for the next cell (nothing is written to disk here).
stitched_mask = reconstruct_mask_with_unique_labels(mask_tiles, coords, stitched_size)

In [None]:
import numpy as np
import cv2
import os
import pandas as pd
from skimage.measure import regionprops
from tqdm import tqdm

# Resize target
# NOTE(review): target_h / target_w (and the cv2 import) are not used in this cell;
# no resizing happens here, so lipids_gsr must already match stitched_mask in
# height and width -- confirm.
target_h, target_w = 4600, 4600

# ----------------------------
# EXTRACT REGIONPROPS
# ----------------------------
# One region per labelled cell in the stitched mask.
props = regionprops(stitched_mask)
print(f"Found {len(props)} labeled cell regions...")

# Extract centroids first
# regionprops centroids are (row, col), i.e. (y, x).
df_expression = pd.DataFrame({
    'y_centroid': [prop.centroid[0] for prop in props],
    'x_centroid': [prop.centroid[1] for prop in props]
})

# ----------------------------
# EXTRACT LIPID FEATURES
# ----------------------------
print("Extracting Lipid features...")
# Channel c of lipids_gsr is paired with the c-th entry of sorted_mzs.
# NOTE(review): assumes the stack's channel order matches sorted_features -- confirm.
for c, mz in enumerate(tqdm(sorted_mzs, desc="Lipids")):
    img = lipids_gsr[c, :, :]
    # Mean intensity over the pixels belonging to each cell.
    channel_means = [img[prop.coords[:, 0], prop.coords[:, 1]].mean() for prop in props]
    df_expression[f"mz_{mz:.4f}"] = channel_means

# ----------------------------
# SAVE EXPRESSION MATRIX
# ----------------------------
df_filename = os.path.join(results_dir, "lipid_gsr_df.csv")
df_expression.to_csv(df_filename, index=False)
print(f"Saved expression matrix to: {df_filename}")

In [None]:
import scanpy as sc
import anndata

# Extract feature columns (all intensity values)
# Everything except the two *_centroid columns is treated as a feature.
feature_cols = [col for col in df_expression.columns if not col.endswith('_centroid')]

# Create AnnData
adata = anndata.AnnData(
    X=df_expression[feature_cols].values.astype(np.float32),  # expression matrix
    obs=df_expression[['x_centroid', 'y_centroid']].copy()     # metadata
)

# Assign feature (channel) names
adata.var_names = feature_cols

# Assign cell names
# Names are positional ("cell_0", "cell_1", ...), so they repeat across samples.
adata.obs_names = [f"cell_{i}" for i in range(adata.n_obs)]

# Add spatial coordinates
# Stored in (x, y) order.
adata.obsm['spatial'] = df_expression[['x_centroid', 'y_centroid']].values
# NOTE(review): written to results_dir directly, while the next cell reads
# <results>/lipid_analysis/lipid_gsr_adata.h5ad -- confirm the file is moved or fix one of the paths.
adata.write(os.path.join(results_dir, "lipid_gsr_adata.h5ad"))

In [None]:
import os
import pandas as pd
import scanpy as sc
import numpy as np
from anndata import AnnData

# === Load data ===
# One per-cell AnnData per dosing condition; base_dir is rebound for each load.
# High Dose
base_dir = r"..\results\high_dose"
adata_high = sc.read_h5ad(os.path.join(base_dir, "lipid_analysis", "lipid_gsr_adata.h5ad"))

# Low Dose
base_dir = r"..\results\low_dose"
adata_low = sc.read_h5ad(os.path.join(base_dir, "lipid_analysis", "lipid_gsr_adata.h5ad"))

# No Dose
base_dir = r"..\results\no_dose"
adata_no = sc.read_h5ad(os.path.join(base_dir, "lipid_analysis", "lipid_gsr_adata.h5ad"))

# Add condition labels
adata_high.obs["condition"] = "High"
adata_low.obs["condition"] = "Low"
adata_no.obs["condition"] = "No"

# Concatenate rows (cells)
# Row order is High, then Low, then No. index_unique=None keeps the original
# obs names unchanged.
# NOTE(review): if every sample names its cells "cell_<i>", the combined obs
# index contains duplicates -- confirm this is acceptable downstream.
# NOTE(review): AnnData.concatenate is deprecated in favour of anndata.concat -- confirm
# against the installed anndata version.
adata_combined_gsr = adata_high.concatenate(
    [adata_low, adata_no],
    batch_key=None,
    index_unique=None
)

# Reload the feature list (first 15 names) and the matching numeric m/z values.
base_dir = r"..\results\high_low_no"
sorted_features = np.load(os.path.join(base_dir,'sorted_features.npy'))[:15]
sorted_mzs = [float(c.replace("mz_", "")) for c in sorted_features]

# === Done ===
print(adata_combined_gsr)
print("Features:", adata_combined_gsr.var_names[:10])
print("obs:", adata_combined_gsr.obs.columns)

In [None]:
import pandas as pd
import os
import scanpy as sc
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler 
import numpy as np

# Load the clustered AnnData that carries the 'leiden_merged' cluster labels.
base_dir = r"..\results\high_low_no"
adata = sc.read_h5ad(os.path.join(base_dir, "combined_adata_leiden_merged.h5ad"))
# NOTE(review): the scaled X is not used again in this cell (only adata.obs is read below).
adata.X = MaxAbsScaler().fit_transform(adata.X.astype(np.float32))

# Map cluster to phenotype
# Keys are cluster ids as strings; clusters missing from the map become NaN after .map().
phenotype_map = {
    "0": "Neurons",
    "1": "Neurons",
    "2": "Astrocytes",
    "3": "Neurons",
    "4": "Neurons",
    "5": "Oligodendrocytes",
    "6": "Endothelial cells",
    "7": "Neurons",
    "8": "Endothelial cells",
    "9": "Neurons"
}

adata.obs['cell_phenotype'] = adata.obs['leiden_merged'].map(phenotype_map)

# Copy the phenotype labels onto the GSR AnnData built in the previous cell.
# NOTE(review): assigning a pandas Series aligns on the obs index, not on row
# position. This is only a row-for-row copy if both objects share the same obs
# index in the same order -- confirm, or assign the underlying values instead.
adata_combined_gsr.obs['cell_phenotype'] = adata.obs['cell_phenotype'].copy()

output_path = r"..\results\high_low_no\combined_adata_gsr.h5ad"
adata_combined_gsr.write(output_path)

In [None]:
import networkx as nx
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
from scipy.spatial import cKDTree

def prepare_graph_data_cell_type(msi_adata, distance_threshold=10):
    """
    Create one PyG graph per dosing condition ('High', 'Low', 'No') from the
    lipid features, excluding cells labeled as 'Others' in 'cell_phenotype'.

    Nodes are cells, node features are the "mz_*" columns of .X, node labels
    are the encoded 'cell_phenotype', and two cells are connected when their
    centroids are at most `distance_threshold` apart.

    Parameters:
    - msi_adata: AnnData object with 'condition', 'cell_phenotype',
      'x_centroid' and 'y_centroid' in .obs
    - distance_threshold: max centroid distance for graph edges

    Returns:
    - graph_data_dict: {condition: PyG Data}; conditions with no remaining
      cells are omitted
    - label_encoder: LabelEncoder fitted on all kept cells, so label ids are
      consistent across conditions
    """
    graph_data_dict = {}
    label_encoder = LabelEncoder()

    # === Global label fitting on all valid cells (every condition, excluding 'Others') ===
    valid_mask = msi_adata.obs['cell_phenotype'] != 'Others'
    label_encoder.fit(msi_adata.obs.loc[valid_mask, 'cell_phenotype'])

    for condition in ['High', 'Low', 'No']:
        adata_cond = msi_adata[
            (msi_adata.obs['condition'] == condition) &
            (msi_adata.obs['cell_phenotype'] != 'Others')
        ].copy()

        if adata_cond.n_obs == 0:
            continue

        # === Extract lipid features
        lipid_channels = [c for c in adata_cond.var_names if c.startswith("mz_")]
        lipid_X = adata_cond[:, lipid_channels].X
        # Densify sparse matrices before handing them to torch.
        lipid_X = lipid_X.toarray() if hasattr(lipid_X, 'toarray') else lipid_X
        features = torch.tensor(lipid_X, dtype=torch.float)

        # === Encode labels
        labels = torch.tensor(
            label_encoder.transform(adata_cond.obs['cell_phenotype']),
            dtype=torch.long
        )

        # === Build edges using cKDTree
        centroids = adata_cond.obs[['x_centroid', 'y_centroid']].values
        tree = cKDTree(centroids)
        # (n_pairs, 2) array of index pairs with i < j; shape (0, 2) when no
        # two cells are within range.
        pairs = tree.query_pairs(r=distance_threshold, output_type='ndarray')

        # reshape(-1, 2) keeps the tensor 2-D even with zero pairs, so the
        # transpose/concatenate below yields an empty (2, 0) edge_index
        # instead of failing.
        edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # undirected

        # === Create PyG graph object
        graph_data = Data(x=features, edge_index=edge_index, y=labels)
        graph_data_dict[f"{condition}"] = graph_data

    return graph_data_dict, label_encoder

In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold

# === Reproducibility ===
# Seed every RNG used below (numpy, random, torch CPU/CUDA).
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# === Prepare single-graph data ===
graph_data_dict, label_encoder = prepare_graph_data_cell_type(adata_combined_gsr, distance_threshold=10)  # uses 'cell_phenotype'
# Merge the per-condition graphs into one batch so all cells share one node index space.
graphs = list(graph_data_dict.values())
full_batch = Batch.from_data_list(graphs)
y_all = full_batch.y.cpu().numpy()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): relative to the current directory, without the "..\" prefix used
# by the other result paths in this notebook -- confirm this is the intended folder.
save_dir = r"results\high_low_no"
os.makedirs(save_dir, exist_ok=True)

# === K-fold setup ===
k = 5
skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
results = []

# === Cross-validation loop ===
# Only the labels matter for the stratified split, hence the dummy feature array.
for fold, (train_val_idx, test_idx) in enumerate(skf.split(np.zeros(len(y_all)), y_all)):
    print(f"\n[Fold {fold+1}/{k}]")

    # Split train/val
    # 15% of the non-test cells become validation; this split is random, not stratified.
    val_split = int(0.15 * len(train_val_idx))
    np.random.shuffle(train_val_idx)
    val_idx = train_val_idx[:val_split]
    train_idx = train_val_idx[val_split:]

    # Assign masks
    # Fresh boolean node masks per fold, stored on the shared batch object.
    full_batch.train_mask = torch.zeros(len(y_all), dtype=torch.bool)
    full_batch.val_mask = torch.zeros(len(y_all), dtype=torch.bool)
    full_batch.test_mask = torch.zeros(len(y_all), dtype=torch.bool)
    full_batch.train_mask[train_idx] = True
    full_batch.val_mask[val_idx] = True
    full_batch.test_mask[test_idx] = True

    # Save batch for evaluation
    torch.save(full_batch.cpu(), f"{save_dir}/fold_{fold+1}_graph_gsr.pt")

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.stats import mode

# === Parameters ===
num_folds = 5
num_thresholds = 9
phenotype_to_barcodes = {}  # collect per fold
phenotype_names_all = set()

# === Helper Functions ===
def compute_thresholds(data, num_thresholds):
    """Return per-column quantile cut points.

    ``num_thresholds`` evenly spaced quantiles strictly between 0 and 1 are
    evaluated for every column of ``data``; the result maps each column name
    to the array of its cut points.
    """
    # linspace includes both endpoints; drop them to keep interior quantiles only.
    inner_quantiles = np.linspace(0, 1, num_thresholds + 2)[1:-1]
    cutoffs = {}
    for column in data.columns:
        cutoffs[column] = data[column].quantile(inner_quantiles).values
    return cutoffs

def categorize_values(data, thresholds):
    """Discretise each column of ``data`` into integer levels.

    The level of a value is the number of entries in ``thresholds[column]``
    that it strictly exceeds. Returns an int array with the shape of ``data``.
    """
    levels = np.zeros(data.shape, dtype=int)
    for col_pos, column in enumerate(data.columns):
        values = data[column].values
        cuts = np.asarray(thresholds[column])
        # Compare every value with every cut point at once and count the hits per row.
        levels[:, col_pos] += (values[:, None] > cuts[None, :]).sum(axis=1)
    return levels

def plot_barcode_matrix(matrix, row_labels, col_labels, color_map, filename):
    """Draw a barcode matrix as a grid of diamond markers, save it, and show it.

    Each cell ``matrix[i, j]`` is drawn at (x=j, y=i) in the colour
    ``color_map[matrix[i, j]]``. The figure is written to ``filename`` at 600 dpi.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    marker_style = dict(marker='D', s=160, edgecolors="black", linewidth=0.6)
    for row in range(matrix.shape[0]):
        for col in range(matrix.shape[1]):
            ax.scatter(col, row, color=color_map[matrix[row, col]], **marker_style)

    # X axis: one tick per feature column.
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha='right', fontsize=14)
    # Y axis: one tick per phenotype row, padded by half a cell on each side.
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=14)
    ax.set_ylim(-0.5, len(row_labels) - 0.5)

    ax.set_xlabel("Metabolite Features", fontsize=16)
    ax.set_ylabel("Cell Phenotypes", fontsize=16)
    ax.grid(True, linestyle="--", linewidth=0.4, alpha=0.3)
    ax.tick_params(axis='both', which='major', length=0)

    plt.tight_layout()
    plt.savefig(filename, dpi=600, bbox_inches='tight', pad_inches=0.0)
    plt.show()

# Discrete palette: one 'tab20' colour (as a hex string) per barcode level 0..num_thresholds.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# matplotlib 3.7 and removed in 3.9.
color_cmap = plt.get_cmap('tab20')
color_list = [mcolors.to_hex(color_cmap(i)) for i in range(num_thresholds + 1)]
color_map = dict(enumerate(color_list))

# === Loop through folds ===
# For every saved fold: derive quantile thresholds from the training cells,
# turn each phenotype's mean feature vector into a discrete barcode, and plot it.
for fold_id in range(1, num_folds + 1):
    graph_path = f"{save_dir}/fold_{fold_id}_graph_gsr.pt"
    if not os.path.exists(graph_path):
        # Missing folds are skipped silently.
        continue

    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which may refuse a pickled PyG batch -- confirm
    # against the installed version.
    graph = torch.load(graph_path)
    train_mask = graph.train_mask.cpu().numpy()

    # NOTE(review): assumes the columns of graph.x are in sorted_features order -- confirm.
    X_all = pd.DataFrame(graph.x.cpu().numpy(), columns=sorted_features)
    X_train = X_all.loc[train_mask, sorted_features].reset_index(drop=True)
    y_train = graph.y.cpu().numpy()[train_mask]
    cell_types_train = label_encoder.inverse_transform(y_train)

    # Thresholds come from the training cells only.
    thresholds = compute_thresholds(X_train, num_thresholds)

    # Compute phenotype mean barcodes
    thresholds_fold = {feat: thresholds[feat] for feat in sorted_features}
    phenotype_means = []
    phenotype_names = []

    for phenotype in np.unique(cell_types_train):
        subset = X_train[cell_types_train == phenotype]
        # Single-row frame holding the phenotype's mean value per feature.
        mean_vals = subset.mean(axis=0).to_frame().T
        barcode_row = categorize_values(mean_vals, thresholds_fold)[0]
        phenotype_names_all.add(phenotype)

        # Store barcode in per-phenotype list
        phenotype_to_barcodes.setdefault(phenotype, []).append(barcode_row)

        phenotype_means.append(barcode_row)
        phenotype_names.append(phenotype)

    barcode_matrix = np.vstack(phenotype_means)

    # Plot fold barcode
    plot_barcode_matrix(
        matrix=barcode_matrix,
        row_labels=phenotype_names,
        col_labels=sorted_features,
        color_map=color_map,
        filename=f"{save_dir}/barcode_matrix_train_fold_{fold_id}_gsr.png"
    )

# === Compute final barcodes via majority vote ===
# For each phenotype, take the most frequent level per feature across the folds.
phenotype_barcodes = {}  # final dict
final_matrix = []
final_phenotypes = sorted(phenotype_names_all)

for phenotype in final_phenotypes:
    barcode_stack = np.stack(phenotype_to_barcodes[phenotype])  # [num_folds, num_features]
    voted_barcode = mode(barcode_stack, axis=0).mode.flatten()
    phenotype_barcodes[phenotype] = voted_barcode
    final_matrix.append(voted_barcode)

# === Plot final consensus barcode ===
final_matrix_np = np.vstack(final_matrix)
plot_barcode_matrix(
    matrix=final_matrix_np,
    row_labels=final_phenotypes,
    col_labels=sorted_features,
    color_map=color_map,
    filename=f"{save_dir}/barcode_matrix_majority_vote_gsr.png"
)

In [None]:
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import matplotlib.pyplot as plt

# Fixed plot colour per phenotype; labels not listed here fall back to a default where it is used.
cell_group_color_dict = {
    "Astrocytes": "#1f77b4",         # blue
    "Oligodendrocytes": "#2ca02c",   # green
    "Neurons": "#d62728",            # red
    "Microglia": "#9467bd",          # purple
    "Endothelial cells": "#e377c2",  # pink
}

# Stack barcode vectors
# Rows follow the insertion order of phenotype_barcodes.
barcode_matrix = np.vstack(list(phenotype_barcodes.values()))
phenotype_labels = list(phenotype_barcodes.keys())

# Compute pairwise cosine distances
dist_matrix = squareform(pdist(barcode_matrix, metric='cosine'))

# Plot heatmap
sns.heatmap(dist_matrix, annot=True, xticklabels=phenotype_labels, yticklabels=phenotype_labels, cmap='viridis')
plt.title("Pairwise Cosine Distance Between Phenotype Barcodes")
plt.show()

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Same matrix and labels as above, rebuilt before the 2-D projection.
barcode_matrix = np.vstack(list(phenotype_barcodes.values()))
phenotype_labels = list(phenotype_barcodes.keys())

# Project the barcodes onto their first two principal components.
pca = PCA(n_components=2)
barcode_2d = pca.fit_transform(barcode_matrix)

plt.figure(figsize=(6, 6))
for i, label in enumerate(phenotype_labels):
    # Unknown phenotypes are drawn in black.
    color = cell_group_color_dict.get(label, "#000000")
    plt.scatter(barcode_2d[i, 0], barcode_2d[i, 1], label=label, color=color, s=100)

plt.legend()
plt.title("PCA of Phenotype Barcodes")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.axis('equal')
plt.grid()
plt.savefig(f"{save_dir}/pca_phenotype_barcodes.png", dpi=600, bbox_inches='tight', pad_inches=0.0)
plt.show()

import scipy.cluster.hierarchy as sch

# Average-linkage hierarchical clustering on the condensed cosine distances.
dists = pdist(barcode_matrix, metric='cosine')
linkage = sch.linkage(dists, method='average')

plt.figure(figsize=(6, 4))
# color_threshold=0 with a black above-threshold colour draws the whole tree in black.
sch.dendrogram(linkage, labels=phenotype_labels, color_threshold=0, above_threshold_color='black')
plt.title("Hierarchical Clustering of Phenotype Barcodes")
plt.ylabel("Cosine Distance")
plt.xticks(rotation=45, ha='right')
plt.savefig(f"{save_dir}/hierarchical_clustering_phenotype_barcodes_gsr.png", dpi=600, bbox_inches='tight', pad_inches=0.0)
plt.show()

In [None]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.stats import mode
from sklearn.cluster import AgglomerativeClustering

# Step 1: Stack barcodes and get labels
barcode_matrix = np.vstack(list(phenotype_barcodes.values()))
phenotype_labels = list(phenotype_barcodes.keys())

# Step 2: Compute pairwise distances
dist_matrix = squareform(pdist(barcode_matrix, metric='cosine'))  # or 'hamming'

# Step 3: Cluster similar barcodes
# dist_matrix already holds pairwise distances, so the estimator must be told
# the input is precomputed. With the default (euclidean) metric it would treat
# each row of the distance matrix as a feature vector, and the 0.01 threshold
# would no longer be a cosine distance.
# The `metric` keyword needs scikit-learn >= 1.2 (older releases call it `affinity`).
clustering = AgglomerativeClustering(
    n_clusters=None,
    metric='precomputed',
    distance_threshold=0.01,
    linkage='average'
)
cluster_labels = clustering.fit_predict(dist_matrix)

# Step 4: Group phenotypes by cluster ID
group_to_phenos = {}
for pheno, group_id in zip(phenotype_labels, cluster_labels):
    group_to_phenos.setdefault(group_id, []).append(pheno)

# Step 5: Create readable group names
# A group is named after its members joined with "/"; its barcode is the
# rounded per-feature mean of the member barcodes.
grouped_barcodes = {}
phenotype_to_group = {}

for group_id, pheno_list in group_to_phenos.items():
    group_name = "/".join(sorted(pheno_list))
    barcodes = np.vstack([phenotype_barcodes[p] for p in pheno_list])
    mean_barcode = np.round(barcodes.mean(axis=0)).astype(int)
    grouped_barcodes[group_name] = mean_barcode

    for pheno in pheno_list:
        phenotype_to_group[pheno] = group_name

# Optional preview
print("Grouped Barcode Names:")
print(list(grouped_barcodes.keys()))

print("\nPhenotype to Group Mapping:")
print(phenotype_to_group)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os

# Human-readable axis labels for the feature names.
mz_shortnames = {
    'mz_699.4500': 'PA 34:1 (699.45 m/z)',
    'mz_701.5500': 'PA 36:1 (701.55 m/z)',
    'mz_719.5500': 'PE 34:0 (719.55 m/z)',
    'mz_745.5500': 'PG 34:2 (745.55 m/z)',
    'mz_747.4500': 'PG 34:1 (747.45 m/z)',
    'mz_748.5500': 'PE 36:1 (748.55 m/z)',
    'mz_762.5500': 'PE 38:6 (762.55 m/z)',
    'mz_772.5500': 'PE 38:1 (772.55 m/z)',
    'mz_790.5500': 'PS 36:1 (790.55 m/z)',
    'mz_794.5500': 'PS 36:0 (794.55 m/z)',
    'mz_880.6500': 'PI 38:4 (880.65 m/z)',
    'mz_888.6500': 'ST C24:1 (888.65 m/z)',
    'mz_889.6500': 'ST 42:2 (889.65 m/z)',
    'mz_890.6500': 'ST C24:0 (890.65 m/z)',
    'mz_905.6500': 'ST C24:0 (OH) (905.65 m/z)',
}

# The barcode columns follow sorted_features (the names used to build the
# per-fold feature frames). The previous name, top_k_features, is not defined
# anywhere in this notebook. Names without a short form are shown unchanged.
col_labels_short = [mz_shortnames.get(mz, mz) for mz in sorted_features]

# ------------------------------------------------------------
# 1.  Mean barcode per merged group
# ------------------------------------------------------------
grouped_barcodes = {}
for phenos in group_to_phenos.values():
    barcodes      = np.vstack([phenotype_barcodes[p] for p in phenos])
    mean_barcode  = np.round(barcodes.mean(axis=0)).astype(int)
    grouped_name  = "/".join(sorted(phenos))       # e.g. "Myeloid Cells/T Cells"
    grouped_barcodes[grouped_name] = mean_barcode

# ------------------------------------------------------------
# 2.  Colour map (0-num_thresholds discrete levels)
# ------------------------------------------------------------
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# matplotlib 3.7 and removed in 3.9.
color_cmap = plt.get_cmap('tab20')
color_list = [mcolors.to_hex(color_cmap(i)) for i in range(num_thresholds + 1)]
color_map = {i: color_list[i] for i in range(num_thresholds + 1)}

# ------------------------------------------------------------
# 3.  Barcode matrix + labels
# ------------------------------------------------------------
group_labels   = list(grouped_barcodes.keys())
barcode_matrix = np.vstack([grouped_barcodes[g] for g in group_labels])

# Wrap Y-tick labels: "A/B/C" becomes one name per line.
wrapped_labels = [lbl.replace("/", "\n") for lbl in group_labels]

# ------------------------------------------------------------
# 4.  Plot
# ------------------------------------------------------------
def plot_barcode_matrix(matrix, row_labels, col_labels, color_map, filename):
    """Draw a barcode matrix as diamond markers, save it to ``filename``, and show it.

    ``matrix[r, c]`` is plotted at (x=c, y=r); its colour is looked up in
    ``color_map`` and falls back to black for levels that are not in the map.
    This definition replaces the earlier one of the same name in the notebook.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            fill = color_map.get(matrix[row, col], "black")
            ax.scatter(col, row, color=fill, marker="D", s=160,
                       edgecolors="black", linewidth=0.6)

    # Tick positions and labels for both axes.
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha="right", fontsize=14)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=14, va="center")

    # Half a cell of padding around the grid.
    ax.set_xlim(-0.5, len(col_labels) - 0.5)
    ax.set_ylim(-0.5, len(row_labels) - 0.5)

    ax.grid(True, linestyle="--", linewidth=0.4, alpha=0.3)
    ax.tick_params(axis="both", which="major", length=0)

    plt.tight_layout()
    plt.savefig(filename, dpi=600, bbox_inches="tight", pad_inches=0.02)
    plt.show()

# ------------------------------------------------------------
# 5.  Save figure
# ------------------------------------------------------------
output_path = os.path.join(save_dir, "barcode_matrix_grouped_gsr.png")
plot_barcode_matrix(
    matrix=barcode_matrix,
    row_labels=wrapped_labels,      # <- stacked Y-labels
    col_labels=col_labels_short,
    color_map=color_map,
    filename=output_path,
)

# Persist the grouped barcode matrix (rows follow group_labels) for the comparison cells.
# NOTE(review): saved under save_dir, while the comparison cells load this file
# from ..\results\high_low_no -- confirm both point at the same folder.
np.save(os.path.join(save_dir, f"barcode_matrix_grouped_gsr.npy"),barcode_matrix)

In [None]:
import os, torch, numpy as np, pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt, seaborn as sns

# Rebuild the phenotype -> group lookup by splitting each group name on "/".
# NOTE(review): this breaks for a phenotype whose own name contains "/" -- confirm none does.
phenotype_to_group = {
    ph: grp
    for grp in grouped_barcodes          # grouped_barcodes comes from an earlier notebook cell
    for ph in grp.split("/")
}

# ------------------------------------------------------------------
# 2.  Evaluation hyper-params
# ------------------------------------------------------------------
num_thresholds = 9
auc_per_fold   = []      # stores dicts
acc_per_fold   = []      # stores dicts

# ------------------------------------------------------------------
# 3.  Cross-validation
# ------------------------------------------------------------------
# For every saved fold: build group barcodes from the training cells, assign
# each test cell to its nearest group barcode (cosine distance), and record
# per-phenotype ROC-AUC and accuracy.
for fold_id in range(1, num_folds + 1):
    print(f"\n=== Fold {fold_id} ===")
    graph_path = f"{save_dir}/fold_{fold_id}_graph_gsr.pt"
    if not os.path.exists(graph_path):
        print("  … file missing, skipping.")
        continue

    # 3-A.  Load fold data ----------------------------------------
    graph      = torch.load(graph_path)
    train_mask = graph.train_mask.cpu().numpy()
    test_mask  = graph.test_mask.cpu().numpy()

    X_all   = pd.DataFrame(graph.x.cpu().numpy(), columns=sorted_features)
    X_train = X_all.loc[train_mask, sorted_features].reset_index(drop=True)
    X_test  = X_all.loc[test_mask,  sorted_features].reset_index(drop=True)

    y_all_enc  = graph.y.cpu().numpy()
    y_train_ph = label_encoder.inverse_transform(y_all_enc[train_mask])
    y_test_ph  = label_encoder.inverse_transform(y_all_enc[test_mask])

    # 3-B.  Thresholds & phenotype barcodes -----------------------
    # Thresholds are fitted on the training cells only, matching the cell that
    # built the barcodes. Fitting them on X_all would leak the test cells'
    # distribution into the discretisation.
    thresholds = compute_thresholds(X_train, num_thresholds)

    pheno_barcodes = {}
    for ph in np.unique(y_train_ph):
        cells = X_train[y_train_ph == ph]
        if not cells.empty:
            pheno_barcodes[ph] = categorize_values(
                cells.mean().to_frame().T, thresholds)[0]

    # 3-C.  Merge phenotype barcodes into group barcodes ----------
    group_barcodes = {}
    for grp in np.unique(list(phenotype_to_group.values())):
        members = [ph for ph in pheno_barcodes if phenotype_to_group[ph] == grp]
        if members:
            arr = np.vstack([pheno_barcodes[m] for m in members])
            group_barcodes[grp] = np.round(arr.mean(axis=0)).astype(int)

    group_names    = list(group_barcodes.keys())
    grp_to_idx     = {g: i for i, g in enumerate(group_names)}
    barcode_matrix = np.vstack([group_barcodes[g] for g in group_names])

    # 3-D.  Discretise test cells --------------------------------
    test_barcodes = categorize_values(X_test, thresholds)

    probs_groups, pred_groups, true_ph_labels = [], [], []

    for bc, true_ph in zip(test_barcodes, y_test_ph):
        dist = cdist([bc], barcode_matrix, metric="cosine")[0]
        # Cosine distance is undefined for an all-zero barcode; treat it as distance 1.
        dist = np.where(np.isnan(dist) | np.isinf(dist), 1.0, dist)

        # Inverse-distance scores, normalised to sum to 1 across groups.
        inv      = 1.0 / (dist + 1e-6)
        inv[np.isinf(inv)] = 1e6
        s        = inv.sum()
        pvec     = inv / s if s > 0 else np.ones_like(inv) / len(inv)
        pvec     = np.nan_to_num(pvec, nan=1.0/len(inv))

        probs_groups.append(pvec)
        true_ph_labels.append(true_ph)
        pred_groups.append(group_names[np.argmax(pvec)])

    probs_groups = np.vstack(probs_groups)

    # 3-E.  ROC-AUC per phenotype ---------------------------------
    # A phenotype's score is the score of the group it belongs to.
    phenotype_names = sorted(np.unique(y_test_ph))
    ph_probs        = np.zeros((len(true_ph_labels), len(phenotype_names)))

    for j, ph in enumerate(phenotype_names):
        grp_idx         = grp_to_idx[phenotype_to_group.get(ph, "Others")]
        ph_probs[:, j]  = probs_groups[:, grp_idx]

    ph_probs = np.nan_to_num(ph_probs)

    y_true_bin = label_binarize(true_ph_labels, classes=phenotype_names)

    auc_row = {"fold_id": fold_id}
    for i, ph in enumerate(phenotype_names):
        try:
            auc_row[ph] = roc_auc_score(y_true_bin[:, i], ph_probs[:, i])
        except ValueError:
            # roc_auc_score raises when only one class is present in the column.
            auc_row[ph] = np.nan
    auc_per_fold.append(auc_row)

    # 3-F.  Accuracy per phenotype -------------------------------
    # A prediction counts as correct when the predicted group equals the
    # group of the cell's true phenotype.
    acc_row = {"fold_id": fold_id}
    pred_grp_arr = np.array(pred_groups)
    true_ph_arr  = np.array(true_ph_labels)
    true_grp_arr = np.array([phenotype_to_group.get(ph, "Others")
                             for ph in true_ph_labels])
    for ph in phenotype_names:
        mask = (true_ph_arr == ph)
        if mask.sum() == 0:
            acc_row[ph] = np.nan
        else:
            acc_row[ph] = (pred_grp_arr[mask] == true_grp_arr[mask]).mean()
    acc_per_fold.append(acc_row)

# ------------------------------------------------------------------
# 4.  Convert to DataFrames
# ------------------------------------------------------------------
# One row per fold, one column per phenotype (plus "fold_id").
auc_df = pd.DataFrame(auc_per_fold)
acc_df = pd.DataFrame(acc_per_fold)

def melt_metric(df, metric_name):
    """Reshape a per-fold metric table to long format and attach plot colours.

    Returns one row per (fold, phenotype) with a non-missing value, with
    columns ``fold_id``, ``Phenotype``, ``metric_name`` and ``Color``. The
    colour is looked up in ``cell_group_color_dict`` and falls back to "gray".
    """
    long_df = df.melt(id_vars="fold_id", var_name="Phenotype",
                      value_name=metric_name)
    long_df = long_df[long_df["Phenotype"] != "fold_id"]
    long_df = long_df[long_df[metric_name].notna()]
    colors = long_df["Phenotype"].map(cell_group_color_dict).fillna("gray")
    return long_df.assign(Color=colors)

# Long-format tables consumed by the plot helpers below.
auc_melt = melt_metric(auc_df, "AUC")
acc_melt = melt_metric(acc_df, "ACC")

# ------------------------------------------------------------------
# 5.  Plot helpers
# ------------------------------------------------------------------
def boxplot_metric(melted, metric, title, filename):
    """Box plot of ``metric`` per phenotype across folds.

    Colours come from the ``Color`` column of ``melted``. The figure is saved
    as ``<save_dir>/<filename>.png`` at 600 dpi and then shown.
    """
    plt.figure(figsize=(12, 6))
    phenotype_colors = melted.set_index("Phenotype")["Color"].to_dict()
    sns.boxplot(data=melted, x="Phenotype", y=metric, palette=phenotype_colors)

    plt.title(title, fontsize=18)
    plt.ylabel(metric, fontsize=16)
    plt.xlabel("Phenotype", fontsize=16)
    plt.xticks(rotation=45, ha="right", fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid()
    plt.tight_layout()

    out_path = f"{save_dir}/{filename}.png"
    plt.savefig(out_path, dpi=600)
    plt.show()

def barplot_metric_with_errors(melted, metric, title, filename):
    """Bar plot of the mean ``metric`` per phenotype with +/- one std error bars.

    Mean and standard deviation are taken over the rows of ``melted`` for each
    phenotype. The figure is saved as ``<save_dir>/<filename>.png`` at 600 dpi
    and then shown.
    """
    per_phenotype = melted.groupby("Phenotype").agg(
        mean=(metric, "mean"),
        std=(metric, "std"),
        Color=("Color", "first"),
    )
    summary = per_phenotype.reset_index()

    plt.figure(figsize=(12, 6))
    bar_colors = summary.set_index("Phenotype")["Color"].to_dict()
    ax = sns.barplot(data=summary, x="Phenotype", y="mean",
                     palette=bar_colors,
                     errorbar=None)
    # Error bars are drawn manually; after reset_index the row label equals the bar position.
    for bar_pos, row in summary.iterrows():
        ax.errorbar(bar_pos, row["mean"], yerr=row["std"],
                    fmt="none", ecolor="black",
                    elinewidth=1.5, capsize=4)

    plt.title(title, fontsize=18)
    plt.ylabel(metric, fontsize=16)
    plt.xlabel("Phenotype", fontsize=16)
    plt.xticks(rotation=45, ha="right", fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid()
    plt.tight_layout()
    plt.savefig(f"{save_dir}/{filename}.png", dpi=600)
    plt.show()

# ------------------------------------------------------------------
# 6.  Generate plots
# ------------------------------------------------------------------
# AUC: per-fold distribution, then mean +/- std.
boxplot_metric(auc_melt, "AUC",
               "Barcode Strategy AUC per Phenotype (K-Fold)",
               "barcode_auc_kfold_phenotypes_boxplot_gsr")

barplot_metric_with_errors(auc_melt, "AUC",
                           "Metabolic Barcoding AUC per Phenotype",
                           "barcode_auc_kfold_phenotypes_barplot_gsr")

# Accuracy: same two views.
boxplot_metric(acc_melt, "ACC",
               "Barcode Strategy Accuracy per Phenotype (K-Fold)",
               "barcode_acc_kfold_phenotypes_boxplot_gsr")

barplot_metric_with_errors(acc_melt, "ACC",
                           "Metabolic Barcoding Accuracy per Phenotype",
                           "barcode_acc_kfold_phenotypes_barplot_gsr")

### Compare Low-res vs GSR Barcodes

In [None]:
import numpy as np
import os

base_dir = r"..\results\high_low_no"

# Grouped barcode matrices: the non-GSR version and the GSR version.
# NOTE(review): 'barcode_matrix_grouped.npy' is not written anywhere in this
# notebook -- presumably produced by the low-res counterpart; confirm.
barcodes = np.load(os.path.join(base_dir, 'barcode_matrix_grouped.npy'))
barcodes_gsr = np.load(os.path.join(base_dir, 'barcode_matrix_grouped_gsr.npy'))

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder, StandardScaler, label_binarize
from sklearn.neighbors import NearestNeighbors
from scipy.stats import mode
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# === Paths & data loading ===
base_dir = r"..\results\high_low_no"

barcode_upsampled = np.load(os.path.join(base_dir, 'barcode_matrix_grouped.npy'))
barcode_gsr       = np.load(os.path.join(base_dir, 'barcode_matrix_grouped_gsr.npy'))

adata = sc.read_h5ad(os.path.join(base_dir, "combined_adata_leiden_merged.h5ad"))
adata_gsr = sc.read_h5ad(os.path.join(base_dir, "combined_adata_gsr.h5ad"))

# === Filter out "Others" phenotype ===
y_str = adata_gsr.obs["cell_phenotype"].to_numpy()
mask = y_str != 'Others'

y_str = y_str[mask]
adata_gsr = adata_gsr[mask].copy()
# NOTE(review): the mask derived from adata_gsr is applied positionally to adata,
# which assumes both objects hold the same cells in the same order -- confirm.
adata = adata[mask].copy()

# === Label encoding ===
# This rebinds label_encoder, replacing the encoder from the graph-building cell.
label_encoder = LabelEncoder()
y_true = label_encoder.fit_transform(y_str)
phenotype_names = label_encoder.classes_

# === MSI input matrices ===
# Restrict the non-GSR data to the GSR feature set so both matrices share columns.
adata_filtered = adata[:, adata_gsr.var_names]  # match filtered rows & features
X_msi_raw = adata_filtered.X
X_msi_gsr = adata_gsr.X

# === Barcode projection ===
def project_msi_to_barcode(X_msi, barcode_matrix, threshold=7):
    """Project MSI intensities onto thresholded barcode rows.

    Barcode entries below ``threshold`` are replaced by 0, entries at or above
    it keep their value. The return value is ``X_msi @ gated.T``, i.e. one
    output column per row of ``barcode_matrix``.
    """
    keep = barcode_matrix >= threshold
    gated = np.where(keep, barcode_matrix, 0)
    weights = gated.T
    return X_msi @ weights

# Barcode-space features for both pipelines, using the default threshold of
# project_msi_to_barcode. Each matrix is paired with its own barcode matrix.
X_barcode_raw = project_msi_to_barcode(X_msi_raw, barcode_upsampled)
X_barcode_gsr = project_msi_to_barcode(X_msi_gsr, barcode_gsr)

# === Cross-validated classifier with per-class AUCs
def evaluate_barcode_classifier(X, y, model=None, n_splits=5):
    """Evaluate a classifier on ``X`` / ``y`` with stratified K-fold CV.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix; it is indexed with integer arrays per fold.
    y : numpy array, shape (n_samples,)
        Class label per sample.
    model : estimator with ``fit`` / ``predict`` / ``predict_proba``, optional
        Defaults to a 10-tree random forest with ``random_state=42``.
        The same estimator object is refitted on every fold.
    n_splits : int
        Number of stratified folds (shuffled, ``random_state=42``).

    Returns
    -------
    tuple
        ``(mean_acc, mean_auc, all_preds, accs, aucs, perclass_aucs)`` where
        ``accs`` / ``aucs`` are per-fold lists, ``all_preds`` holds the
        out-of-fold prediction for every sample and ``perclass_aucs`` has shape
        ``(n_classes, n_folds)``. AUC entries that cannot be computed are NaN.
    """
    if model is None:
        model = RandomForestClassifier(n_estimators=10, random_state=42)

    # Classes come from the labels themselves rather than from a notebook
    # global, so the function can be called with any label vector.
    classes = np.unique(y)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    accs, aucs = [], []
    perclass_aucs = []
    all_preds = np.zeros_like(y)

    for train_idx, test_idx in tqdm(skf.split(X, y), total=n_splits, desc="Cross-validation"):
        # Fit the scaler on the training fold only: scaling with statistics of
        # the full matrix would leak information from the held-out fold.
        scaler = StandardScaler().fit(X[train_idx])
        X_train = scaler.transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])

        model.fit(X_train, y[train_idx])
        y_pred = model.predict(X_test)
        y_prob = model.predict_proba(X_test)

        accs.append(accuracy_score(y[test_idx], y_pred))
        try:
            aucs.append(roc_auc_score(y[test_idx], y_prob, multi_class='ovr'))
        except ValueError:
            # roc_auc_score rejected this fold; record NaN so the fold is
            # ignored by nanmean instead of aborting the whole evaluation.
            aucs.append(np.nan)

        # One-vs-rest AUC per class on the binarised test labels.
        y_test_bin = label_binarize(y[test_idx], classes=classes)
        try:
            aucs_per_class = roc_auc_score(y_test_bin, y_prob, average=None)
        except ValueError:
            aucs_per_class = np.full(len(classes), np.nan)

        perclass_aucs.append(aucs_per_class)
        all_preds[test_idx] = y_pred

    perclass_aucs = np.array(perclass_aucs).T  # shape (n_classes, n_folds)
    return np.mean(accs), np.nanmean(aucs), all_preds, accs, aucs, perclass_aucs

# === Run classification and evaluation
# Identical evaluation for both feature sets.
(acc_raw, auc_raw, pred_raw,
 accs_raw, aucs_raw, perclass_auc_raw_cv) = evaluate_barcode_classifier(X_barcode_raw, y_true)
(acc_gsr, auc_gsr, pred_gsr,
 accs_gsr, aucs_gsr, perclass_auc_gsr_cv) = evaluate_barcode_classifier(X_barcode_gsr, y_true)

# === Per-class accuracy from confusion matrix
# Diagonal / row sum = fraction of each true class predicted correctly,
# computed on the pooled out-of-fold predictions.
class_ids = np.arange(len(phenotype_names))
cm_raw = confusion_matrix(y_true, pred_raw, labels=class_ids)
cm_gsr = confusion_matrix(y_true, pred_gsr, labels=class_ids)
accs_perclass_raw = np.diag(cm_raw) / cm_raw.sum(axis=1)
accs_perclass_gsr = np.diag(cm_gsr) / cm_gsr.sum(axis=1)

# === Mean per-class AUCs from CV
# Average over folds (axis 1), ignoring folds whose AUC was NaN.
mean_perclass_auc_raw = np.nanmean(perclass_auc_raw_cv, axis=1)
mean_perclass_auc_gsr = np.nanmean(perclass_auc_gsr_cv, axis=1)

# === Report
print(f"\nRaw Barcode   → Accuracy: {acc_raw:.3f}, AUC: {auc_raw:.3f}")
print(f"GSR Barcode   → Accuracy: {acc_gsr:.3f}, AUC: {auc_gsr:.3f}")

print("\nPer-class accuracy:")
for name, raw_val, gsr_val in zip(phenotype_names, accs_perclass_raw, accs_perclass_gsr):
    print(f"{name:<25} Raw: {raw_val:.2f}   GSR: {gsr_val:.2f}")

print("\nPer-class AUCs (CV):")
for name, raw_val, gsr_val in zip(phenotype_names, mean_perclass_auc_raw, mean_perclass_auc_gsr):
    print(f"{name:<25} Raw: {raw_val:.2f}   GSR: {gsr_val:.2f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# === Prepare data for the bar plots ===

# Overall accuracy & AUC: one row per (fold, metric, method)
fold_scores = [
    ("Accuracy", "Low-res", accs_raw),
    ("Accuracy", "Super-res", accs_gsr),
    ("AUC", "Low-res", aucs_raw),
    ("AUC", "Super-res", aucs_gsr),
]
overall_df = pd.DataFrame({
    "Score": [value for _, _, scores in fold_scores for value in scores],
    "Metric": [metric for metric, _, scores in fold_scores for _ in scores],
    "Method": [method for _, method, scores in fold_scores for _ in scores],
})

# Per-class accuracy & AUC: a single value per (phenotype, method), so these
# bars carry no spread.
n_phenotypes = len(phenotype_names)
perclass_df = pd.DataFrame({
    "Phenotype": phenotype_names.tolist() * 2,
    "Accuracy": np.concatenate([accs_perclass_raw, accs_perclass_gsr]),
    "AUC": np.concatenate([mean_perclass_auc_raw, mean_perclass_auc_gsr]),
    "Method": ["Low-res"] * n_phenotypes + ["Super-res"] * n_phenotypes
})

# === Overall Accuracy & AUC bar plot (file name still says "boxplot")
plt.figure(figsize=(6, 4))
sns.barplot(data=overall_df, x="Metric", y="Score", hue="Method", palette="Set2")
plt.title("Overall Barcode Accuracy and AUC")
plt.tight_layout()
plt.savefig(os.path.join(base_dir, "overall_acc_auc_boxplot.png"), dpi=300)
plt.show()

# === Per-class bar plots: accuracy first, then AUC
perclass_figures = [
    ("Accuracy", "Per-Class Barcode Accuracy", "per_class_accuracy_barplot.png"),
    ("AUC", "Per-Class Barcode AUC", "per_class_auc_barplot.png"),
]
for metric_col, fig_title, file_name in perclass_figures:
    plt.figure(figsize=(max(6, n_phenotypes*0.75), 4))
    sns.barplot(data=perclass_df, x="Phenotype", y=metric_col, hue="Method", palette="Set2", capsize=0.1)
    plt.title(fig_title)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(base_dir, file_name), dpi=300)
    plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import networkx as nx
from scipy.spatial import cKDTree
from scipy.stats import entropy
from collections import Counter
import scanpy as sc

# ----------------------------------------------------------------------
# 0. USER SETTINGS
# ----------------------------------------------------------------------
save_dir               = r"..\results\high_low_no"  # output root for the figure
base_dir               = save_dir                   # input data is read from the same folder
distance_threshold     = 20    # max centroid distance for two cells to share an edge (coordinate units)
n_hops                 = 5     # graph radius (in edges) of a neighbourhood around a root cell
top_k                  = 10    # number of neighbourhoods to keep
min_dist_between_roots = 300   # minimum Euclidean distance between two selected roots
bbox_colour            = 'red' # edge colour of the bounding boxes drawn on the overlay

# ----------------------------------------------------------------------
# 1. LOAD DATA AND BUILD GRAPH
# ----------------------------------------------------------------------
gsr_path = os.path.join(base_dir, "combined_adata_gsr.h5ad")
adata_gsr = sc.read_h5ad(gsr_path)
is_high = adata_gsr.obs['condition'] == 'High'
adata_gsr = adata_gsr[is_high].copy()

coords = adata_gsr.obs[['x_centroid', 'y_centroid']].values
phenotypes = adata_gsr.obs['cell_phenotype'].astype(str).values

# Proximity graph: one node per cell (node id = row position), one edge for
# every pair of centroids returned by the KD-tree radius query.
tree = cKDTree(coords)
pairs = tree.query_pairs(r=distance_threshold)

G = nx.Graph()
G.add_nodes_from(
    (cell_idx, {"pos": (cx, cy)}) for cell_idx, (cx, cy) in enumerate(coords)
)
G.add_edges_from(pairs)
pos = nx.get_node_attributes(G, 'pos')

# ----------------------------------------------------------------------
# 2. COMPUTE DIVERSITY PER NEIGHBORHOOD
# ----------------------------------------------------------------------
def compute_diversity(labels):
    """Return the entropy of the label frequencies in ``labels``.

    Labels are counted, the counts are turned into proportions and passed to
    ``scipy.stats.entropy`` with its default settings.
    """
    tally = Counter(labels)
    counts = np.array(list(tally.values()))
    total = counts.sum()
    return entropy(counts / total)

# Node id -> phenotype label (node ids are row positions in `phenotypes`).
id_to_phenotype = dict(enumerate(phenotypes))
diversity_scores = []

for node in G.nodes:
    # All nodes reachable from `node` within n_hops edges (the node itself included).
    reachable = nx.single_source_shortest_path_length(G, node, cutoff=n_hops)
    neighborhood = list(reachable.keys())
    phenos = [id_to_phenotype[member] for member in neighborhood]
    # Neighbourhoods made of a single phenotype are not scored at all.
    if len(set(phenos)) <= 1:
        continue
    diversity_scores.append((node, compute_diversity(phenos)))

# Most diverse first; ties keep their original relative order.
diversity_scores.sort(key=lambda entry: entry[1], reverse=True)

# ----------------------------------------------------------------------
# 3. SELECT SPATIALLY DISTINCT ROOTS
# ----------------------------------------------------------------------
# Greedy pass over the candidates in descending diversity: a candidate is kept
# only if it lies at least min_dist_between_roots away from every root kept so far.
selected_roots = []
selected_coords = []

for node, score in diversity_scores:
    candidate_coord = np.array(pos[node])
    far_enough = all(
        np.linalg.norm(candidate_coord - np.array(kept)) >= min_dist_between_roots
        for kept in selected_coords
    )
    if not far_enough:
        continue
    selected_roots.append((node, score))
    selected_coords.append(candidate_coord)
    if len(selected_roots) == top_k:
        break

print("Top diverse and spatially distinct neighborhoods:")
for rank, (node, score) in enumerate(selected_roots, start=1):
    print(f"#{rank}: Node {node} at {pos[node]} → diversity = {score:.3f}")

# ----------------------------------------------------------------------
# 4. PHENOTYPE COLORS
# ----------------------------------------------------------------------
# Fixed hex colour per phenotype label.
phenotype_colors = {
    "Astrocytes": "#1f77b4",
    "Oligodendrocytes": "#2ca02c",
    "Neurons": "#d62728",
    "Microglia": "#9467bd",
    "Endothelial cells": "#e377c2",
}
# One colour per cell, in the same order as `phenotypes`; labels missing from
# the table (e.g. 'Others') fall back to gray "#7f7f7f".
node_colors = [phenotype_colors.get(label, "#7f7f7f") for label in phenotypes]

# ----------------------------------------------------------------------
# 5. SINGLE PLOT WITH ALL BOUNDING BOXES + LABELS
# ----------------------------------------------------------------------
# The figure is written into a sub-folder of save_dir. Create that sub-folder
# (makedirs also creates save_dir itself) so savefig cannot fail on a missing
# directory.
overlay_dir = os.path.join(save_dir, 'subgraph_visualizations')
os.makedirs(overlay_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 12))
ax.set_facecolor('black')
ax.set_aspect('equal')
plt.axis('off')

# Draw all nodes
nx.draw_networkx_nodes(G, pos, ax=ax,
                       node_color=node_colors,
                       node_size=1.0, linewidths=0)

# Draw and label each neighborhood box
for i, (root_node, _) in enumerate(selected_roots):
    # Same n_hops neighbourhood that was used to score the root.
    sub_nodes = list(nx.single_source_shortest_path_length(G, root_node, cutoff=n_hops).keys())
    sub_xy = np.array([pos[n] for n in sub_nodes])
    min_x, min_y = sub_xy.min(axis=0)
    max_x, max_y = sub_xy.max(axis=0)

    # Axis-aligned bounding box around all cells of the neighbourhood
    ax.add_patch(patches.Rectangle(
        (min_x, min_y),
        max_x - min_x,
        max_y - min_y,
        linewidth=2,
        edgecolor=bbox_colour,
        facecolor='none',
        zorder=10
    ))

    # Label (1-based rank of the neighbourhood), placed just inside the
    # (min_x, min_y) corner of the box
    ax.text(
        min_x + 5, min_y + 10,
        str(i + 1),
        color='white',
        fontsize=12,
        weight='bold',
        zorder=11,
        bbox=dict(
            facecolor='black',
            edgecolor='none',
            boxstyle='round,pad=0.2'  # small padding keeps the label compact
        )
    )

ax.set_title(f"Top {top_k} Diverse Neighborhoods", color='white')
ax.invert_yaxis()

plt.savefig(os.path.join(overlay_dir, f"top_{top_k}_diverse_neighborhoods_overlay_labeled.png"),
            dpi=600, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge
import networkx as nx
from sklearn.preprocessing import MinMaxScaler
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
import matplotlib.colors as mcolors
from matplotlib import cm

def find_nearest_node(pos_dict, coord):
    """Return the key of ``pos_dict`` whose stored position is closest to ``coord``.

    ``pos_dict`` maps node id -> position; the nearest position is found with
    a KD-tree query.
    """
    ids = []
    points = []
    for node, xy in pos_dict.items():
        ids.append(node)
        points.append(xy)
    lookup = cKDTree(np.array(points))
    nearest = lookup.query(coord)[1]
    return ids[nearest]

# === Plot simple barcode scatter ===
def plot_scatter_barcode(ax, center, barcode_row, color_map, dot_size=40, spacing=0.5):
    """
    Draw one barcode as a horizontal row of diamond markers around ``center``.

    One marker is drawn per entry of ``barcode_row``; markers are ``spacing``
    apart, the first one sits ``len(barcode_row) * spacing / 2`` to the left of
    ``center[0]`` and all share ``center[1]`` as their y position. Each marker's
    colour is looked up in ``color_map`` by the entry's value ('black' when the
    value has no colour).
    """
    n_slots = len(barcode_row)
    row_y = center[1]
    left_edge = center[0] - (n_slots * spacing) / 2
    for slot in range(n_slots):
        fill = color_map.get(barcode_row[slot], 'black')
        ax.scatter(
            left_edge + slot * spacing,
            row_y,
            color=fill,
            s=dot_size,
            edgecolors='black',
            linewidths=0.5,
            marker='D',
        )

# Colormap used by draw_cont. `matplotlib.cm.get_cmap` was removed in
# Matplotlib 3.9; `plt.get_cmap` returns the same colormap and exists in both
# older and newer releases.
hot = plt.get_cmap("hot")


def draw_cont(ax, ctr, row, s=100, sp=0.6):
    """Draw a continuous barcode as a row of diamond markers around ``ctr``.

    One marker is drawn per value in ``row``; markers are ``sp`` apart,
    starting ``len(row) * sp / 2`` to the left of ``ctr[0]`` at height
    ``ctr[1]``. Each value is mapped through the module-level ``hot``
    colormap, so callers are expected to pass values already scaled to [0, 1].

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    ctr : sequence (x, y) giving the anchor position of the barcode.
    row : iterable of numeric values, one per marker.
    s : marker area passed to ``ax.scatter``.
    sp : horizontal distance between consecutive markers.
    """
    x0 = ctr[0] - (len(row) * sp) / 2
    for j, val in enumerate(row):
        ax.scatter(x0 + j * sp, ctr[1], c=hot(val), s=s,
                   marker="D", edgecolors="black", linewidths=0.5)

# === Cell-group color dictionary ===
# Phenotype label -> hex colour used for the node markers of the subgraph plot.
# Callers look labels up with .get(..., 'gray'), so unlisted labels are drawn gray.
cell_group_color_dict = {
    "Astrocytes": "#1f77b4",         # blue
    "Oligodendrocytes": "#2ca02c",   # green
    "Neurons": "#d62728",            # red
    "Microglia": "#9467bd",          # purple
    "Endothelial cells": "#e377c2",  # pink
}

# === Subgraph visualization ===
def _draw_subgraph_edges(ax, graph, positions):
    """Draw every edge of ``graph`` on ``ax`` as a light-gray line segment."""
    for u, v in graph.edges:
        x1, y1 = positions[u]
        x2, y2 = positions[v]
        ax.plot([x1, x2], [y1, y2], color='lightgray', linewidth=1.0, alpha=0.8)


def _finish_subgraph_axes(ax):
    """Apply the shared axis styling: equal aspect, no frame, image-style y axis."""
    ax.set_aspect('equal')
    ax.axis('off')
    ax.invert_yaxis()


def _draw_group_barcodes(ax, graph, positions, obs, barcodes, color_map, spacing, markersize):
    """Draw each node's phenotype barcode, or a gray dot if its phenotype has none."""
    for nid in graph.nodes:
        center = positions[nid]
        cell_group = obs.iloc[nid]['cell_phenotype']
        if cell_group in barcodes:
            plot_scatter_barcode(ax, center, barcodes[cell_group], color_map,
                                 dot_size=100, spacing=spacing)
        else:
            # Gray dot for 'Others' or unknown groups
            ax.plot(center[0], center[1], marker='o', markersize=markersize,
                    color='lightgray', alpha=0.8)


def plot_subgraph_pie_and_bars(
    raw_adata,
    gsr_adata,
    top_k_features,
    target_coord,
    save_dir,
    raw_barcodes,
    gsr_barcodes,
    color_map,
    condition,
    n_hops=2,
    distance_threshold=20,
    spacing=0.8,
    markersize=8
):
    """Plot the neighbourhood of the cell nearest to ``target_coord`` four ways.

    A graph is built from ``gsr_adata`` with ``create_graph``; the node closest
    to ``target_coord`` is the root and all nodes within ``n_hops`` edges form
    the subgraph. Four PNG files are written to ``save_dir`` (created if
    missing), all at 600 dpi:

    1. ``subgraph_groups_<x>_<y>.png`` - nodes coloured by phenotype
       (``cell_group_color_dict``), the root drawn with a thicker outline.
    2. ``<condition>_subgraph_raw_<x>_<y>.png`` - per-phenotype barcodes from
       ``raw_barcodes``, phenotypes read from ``raw_adata``.
    3. ``<condition>_subgraph_gsr_<x>_<y>.png`` - same with ``gsr_barcodes``
       and phenotypes read from ``gsr_adata``.
    4. ``<condition>_subgraph_continuous_<x>_<y>.png`` - per-cell barcodes of
       the ``top_k_features`` columns of ``gsr_adata``, min-max scaled per
       feature over all cells.

    ``<x>_<y>`` is ``target_coord`` truncated to integers. Figures are left
    open so the notebook can display them.

    Parameters
    ----------
    raw_adata, gsr_adata : AnnData objects with a ``cell_phenotype`` obs column.
        NOTE(review): node ids are used as positional row indices into both
        objects, so this assumes ``create_graph`` numbers nodes by row position
        and that both objects share the same row order - confirm.
    top_k_features : feature names used to slice the columns of ``gsr_adata``.
    target_coord : (x, y) position used to pick the root node.
    save_dir : output directory.
    raw_barcodes, gsr_barcodes : dict phenotype -> barcode row.
    color_map : dict barcode level -> colour, forwarded to ``plot_scatter_barcode``.
    condition : string used as file-name prefix for plots 2-4.
    n_hops : graph radius of the neighbourhood, in edges.
    distance_threshold : forwarded to ``create_graph``.
    spacing : marker spacing of the group barcodes.
    markersize : size of the plain node markers.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Graph and n-hop neighbourhood around the node nearest to target_coord
    G, pos = create_graph(gsr_adata, distance_threshold)
    node_id = find_nearest_node(pos, target_coord)
    neighbors = list(nx.single_source_shortest_path_length(G, node_id, cutoff=n_hops).keys())
    G_sub = G.subgraph(neighbors)
    pos_sub = {nid: pos[nid] for nid in neighbors}

    node_colors_sub = [
        cell_group_color_dict.get(raw_adata.obs.iloc[n]['cell_phenotype'], 'gray')
        for n in neighbors
    ]

    # Continuous per-cell barcodes: densify if needed, then scale each feature to 0-1
    X_all = gsr_adata[:, top_k_features].X
    if not isinstance(X_all, np.ndarray):
        X_all = X_all.toarray()
    cont_barcodes = MinMaxScaler().fit_transform(X_all)  # shape [n_cells, k_features]

    coord_tag = f"{int(target_coord[0])}_{int(target_coord[1])}"
    save_kwargs = dict(dpi=600, bbox_inches='tight', pad_inches=0)

    # === Plot 1: Subgraph Cell Group Colors ===
    fig1, ax1 = plt.subplots(figsize=(10, 10))
    ax1.set_facecolor('black')
    _draw_subgraph_edges(ax1, G_sub, pos_sub)
    for nid, color in zip(neighbors, node_colors_sub):
        x, y = pos_sub[nid]
        # The root node gets a thicker outline so it stands out
        linewidth = 2.5 if nid == node_id else 1
        ax1.plot(x, y, marker='o', markersize=markersize, color=color,
                 markeredgecolor='black', markeredgewidth=linewidth)
    ax1.set_title(f"Subgraph: Cell Groups near {target_coord}", color='white')
    _finish_subgraph_axes(ax1)
    fig1.savefig(f"{save_dir}/subgraph_groups_{coord_tag}.png", **save_kwargs)

    # === Plot 2: Spatial Scatter Raw Barcodes ===
    fig2, ax2 = plt.subplots(figsize=(10, 10))
    ax2.set_facecolor('black')
    _draw_subgraph_edges(ax2, G_sub, pos_sub)
    _draw_group_barcodes(ax2, G_sub, pos_sub, raw_adata.obs, raw_barcodes,
                         color_map, spacing, markersize)
    ax2.set_title(f"Subgraph: Group Scatter Barcodes near {target_coord}", color='white')
    _finish_subgraph_axes(ax2)
    fig2.savefig(f"{save_dir}/{condition}_subgraph_raw_{coord_tag}.png", **save_kwargs)

    # === Plot 3: Spatial Scatter GSR Barcodes ===
    fig3, ax3 = plt.subplots(figsize=(10, 10))
    ax3.set_facecolor('black')
    _draw_subgraph_edges(ax3, G_sub, pos_sub)
    _draw_group_barcodes(ax3, G_sub, pos_sub, gsr_adata.obs, gsr_barcodes,
                         color_map, spacing, markersize)
    ax3.set_title(f"Subgraph: Group Scatter Barcodes near {target_coord}", color='white')
    _finish_subgraph_axes(ax3)
    fig3.savefig(f"{save_dir}/{condition}_subgraph_gsr_{coord_tag}.png", **save_kwargs)

    # === Plot 4: Cell-specific continuous barcodes ===
    fig4, ax4 = plt.subplots(figsize=(10, 10))
    ax4.set_facecolor('black')
    _draw_subgraph_edges(ax4, G_sub, pos_sub)
    for n in neighbors:
        draw_cont(ax4, pos_sub[n], cont_barcodes[n])
    _finish_subgraph_axes(ax4)
    fig4.savefig(f"{save_dir}/{condition}_subgraph_continuous_{coord_tag}.png", **save_kwargs)

# === Load data ===
condition = 'High'
base_dir = r"Y:\coskun-lab\Efe\GSR_MSI\experiments\AmberGen Tri-Modal Lipid Data\high_low_no_2"
barcode_raw = np.load(os.path.join(base_dir, 'barcode_matrix_grouped.npy'))
barcode_gsr = np.load(os.path.join(base_dir, 'barcode_matrix_grouped_gsr.npy'))

adata = sc.read_h5ad(os.path.join(base_dir, "combined_adata_leiden_merged.h5ad"))
adata_gsr = sc.read_h5ad(os.path.join(base_dir, "combined_adata_gsr.h5ad"))

# === Keep only the cells of the selected condition
adata = adata[adata.obs['condition'] == condition].copy()
adata_gsr = adata_gsr[adata_gsr.obs['condition'] == condition].copy()

# Copy the phenotype labels from the GSR object onto the raw object
# (pandas aligns this assignment on the obs index).
adata.obs['cell_phenotype'] = adata_gsr.obs['cell_phenotype'].copy()

# === Preserve full data for plotting
# NOTE(review): the call below passes the filtered objects, not these copies;
# only adata_gsr_plot.var_names is used. Confirm which objects were intended.
adata_plot = adata.copy()
adata_gsr_plot = adata_gsr.copy()

# === Mask for prediction (exclude 'Others')
mask = adata_gsr.obs["cell_phenotype"] != "Others"
adata = adata[mask].copy()
adata_gsr = adata_gsr[mask].copy()

# === Feature matching
adata = adata[:, adata_gsr.var_names]  # match features
top_k_features = list(adata_gsr_plot.var_names)

# Phenotype -> barcode row. NOTE(review): rows 0-3 are assumed to follow the
# alphabetical phenotype order listed here - confirm against how the .npy was built.
raw_barcodes = {
    "Astrocytes": barcode_raw[0],
    "Endothelial cells": barcode_raw[1],
    "Neurons": barcode_raw[2],
    "Oligodendrocytes": barcode_raw[3],
}

gsr_barcodes = {
    "Astrocytes": barcode_gsr[0],
    "Endothelial cells": barcode_gsr[1],
    "Neurons": barcode_gsr[2],
    "Oligodendrocytes": barcode_gsr[3],
}

# Barcode level (0..num_thresholds) -> hex colour taken from the 'tab20' colormap.
# `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` works on
# both older and newer releases.
num_thresholds = 9
color_cmap = plt.get_cmap('tab20')
color_list = [mcolors.to_hex(color_cmap(i)) for i in range(num_thresholds + 1)]
color_map = {i: color_list[i] for i in range(num_thresholds + 1)}

save_dir = os.path.join(base_dir, "subgraph_visualizations")
plot_subgraph_pie_and_bars(
    raw_adata=adata,
    gsr_adata=adata_gsr,
    top_k_features=top_k_features,
    target_coord=(1784.5555555555557, 3042.5925925925926),
    save_dir=save_dir,
    raw_barcodes=raw_barcodes,
    gsr_barcodes=gsr_barcodes,
    color_map=color_map,
    condition=condition,
    n_hops=5,
    distance_threshold=20,
)

### High vs Low vs No Dose Analysis

In [None]:
import scanpy as sc
import numpy as np
import os

base_dir = r"..\results\high_low_no"
adata_combined_gsr = sc.read_h5ad(os.path.join(base_dir, "combined_adata_gsr.h5ad"))

# Map cluster id (as a string) to phenotype label
phenotype_map = {
    "0": "Neurons",
    "1": "Neurons",
    "2": "Astrocytes",
    "3": "Neurons",
    "4": "Neurons",
    "5": "Oligodendrocytes",
    "6": "Endothelial cells",
    "7": "Neurons",
    "8": "Endothelial cells",
    "9": "Neurons"
}
# The dictionary was originally bound to the misspelled name `henotype_map`;
# keep that name as an alias so any cell still using it keeps working.
henotype_map = phenotype_map

sorted_features = np.load(os.path.join(base_dir, "sorted_features.npy"))

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import ttest_ind

# ─────────────────────────────────────────────────────────────
# 1. Short names for top features
# ─────────────────────────────────────────────────────────────
# Feature column name -> human-readable lipid label used on the plot axes.
mz_shortnames = {
    'mz_699.4500': 'PA 34:1 (699.45 m/z)',
    'mz_701.5500': 'PA 36:1 (701.55 m/z)',
    'mz_719.5500': 'PE 34:0 (719.55 m/z)',
    'mz_745.5500': 'PG 34:2 (745.55 m/z)',
    'mz_747.4500': 'PG 34:1 (747.45 m/z)',
    'mz_748.5500': 'PE 36:1 (748.55 m/z)',
    'mz_762.5500': 'PE 38:6 (762.55 m/z)',
    'mz_772.5500': 'PE 38:1 (772.55 m/z)',
    'mz_790.5500': 'PS 36:1 (790.55 m/z)',
    'mz_794.5500': 'PS 36:0 (794.55 m/z)',
    'mz_880.6500': 'PI 38:4 (880.65 m/z)',
    'mz_888.6500': 'ST C24:1 (888.65 m/z)',
    'mz_889.6500': 'ST 42:2 (889.65 m/z)',
    'mz_890.6500': 'ST C24:0 (890.65 m/z)',
    'mz_905.6500': 'ST C24:0 (OH) (905.65 m/z)',
}

# Features without a short name keep their original column name.
# NOTE(review): `top_k_features` is not defined in this cell; it must already
# exist in the notebook session (the subgraph-visualization cell assigns it),
# while the preceding cell loads `sorted_features` — confirm which is intended.
col_labels_short = [mz_shortnames.get(mz, mz) for mz in top_k_features]

# ─────────────────────────────────────────────────────────────
# 2. Build long-form DataFrame for seaborn
# ─────────────────────────────────────────────────────────────
# One row per (cell, feature) with the cell's condition and intensity value.
condition_list = adata_combined_gsr.obs['condition'].unique().tolist()
long_df = []

for condition in condition_list:
    X_cond = adata_combined_gsr[adata_combined_gsr.obs['condition'] == condition][:, top_k_features].X
    # .X may be a sparse matrix; column slices of a sparse matrix cannot be
    # flattened with numpy, so densify once per condition.
    if not isinstance(X_cond, np.ndarray):
        X_cond = X_cond.toarray()
    for i, feature in enumerate(top_k_features):
        values = np.asarray(X_cond[:, i]).ravel()
        long_df.append(pd.DataFrame({
            'Feature': mz_shortnames.get(feature, feature),
            'Condition': condition,
            'Value': values
        }))

# ignore_index gives the combined frame a unique row index instead of the
# repeated 0..n-1 ranges of the individual pieces.
plot_df = pd.concat(long_df, axis=0, ignore_index=True)

# ─────────────────────────────────────────────────────────────
# 3. Plot
# ─────────────────────────────────────────────────────────────
plt.figure(figsize=(18, 8))
# hue_order pins the left-to-right order of the boxes inside each feature group
# to condition_list, the order the annotation step relies on.
ax = sns.boxplot(data=plot_df, x='Feature', y='Value', hue='Condition',
                 hue_order=condition_list, showfliers=False)
plt.xticks(rotation=45, ha='right', fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('Feature', fontsize=14)
plt.ylabel('Intensity', fontsize=14)
plt.title('Top 15 Lipid Distributions by Condition', fontsize=16)
plt.legend(fontsize=12)

# ─────────────────────────────────────────────────────────────
# 4. Add t-test annotations
# ─────────────────────────────────────────────────────────────
def pval_to_stars(p):
    """Convert a p-value into the usual significance label.

    Returns 'ns' for p > 0.05, '*' for (0.01, 0.05], '**' for (0.001, 0.01],
    '***' for (0.0001, 0.001] and '****' for anything smaller. A NaN p-value
    (test could not be computed) is reported as 'ns' rather than falling
    through every comparison and being labelled '****'.
    """
    if p != p or p > 0.05:  # p != p is True only for NaN
        return 'ns'
    if p > 0.01:
        return '*'
    if p > 0.001:
        return '**'
    if p > 0.0001:
        return '***'
    return '****'

def add_stat_annotation(ax, x1, x2, y, h, p_val, fontsize=10):
    """Draw a significance bracket from ``x1`` to ``x2`` on ``ax``.

    The bracket rises from ``y`` to ``y + h``; the star label for ``p_val``
    (see ``pval_to_stars``) is centred above its top edge.
    """
    label = pval_to_stars(p_val)
    top = y + h
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, lw=1.5, color='black')
    ax.text((x1+x2)*0.5, top, label, ha='center', va='bottom', fontsize=fontsize)

feature_order = plot_df['Feature'].unique().tolist()
feature_positions = {f: i for i, f in enumerate(feature_order)}

# seaborn spreads the hue levels evenly over a band of total width 0.8 centred
# on each x tick, so the centre of box k (of n) sits at
# -0.4 + (k + 0.5) * 0.8 / n. Fixed offsets of -0.3, -0.1, ... are only correct
# for exactly four hue levels; computing them from the number of conditions
# keeps the brackets on the boxes for any group count.
hue_width = 0.8 / len(condition_list)
hue_offsets = {
    cond: -0.4 + hue_width * (k + 0.5) for k, cond in enumerate(condition_list)
}
pairs = list(combinations(condition_list, 2))

for feat in feature_order:
    sub = plot_df[plot_df['Feature'] == feat]
    y_max = sub['Value'].max()
    idx = feature_positions[feat]

    for j, (c1, c2) in enumerate(pairs):
        v1 = sub[sub['Condition'] == c1]['Value']
        v2 = sub[sub['Condition'] == c2]['Value']
        # Welch's t-test (unequal variances); p-values are not corrected for
        # the number of comparisons.
        stat, pval = ttest_ind(v1, v2, equal_var=False)

        x1 = idx + hue_offsets[c1]
        x2 = idx + hue_offsets[c2]
        # Stack the brackets of one feature above its largest value, 5 % apart.
        y = y_max * 1.10 + j * 0.05 * y_max
        h = 0.015 * y_max

        add_stat_annotation(ax, x1, x2, y, h, pval)

# ─────────────────────────────────────────────────────────────
# 5. Save and show
# ─────────────────────────────────────────────────────────────
plt.tight_layout()
plt.savefig(os.path.join(base_dir, 'feature_distributions_boxplot.png'), dpi=300)
plt.show()

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import ttest_ind

# ─────────────────────────────────────────────────────────────
# 1. Inputs
# ─────────────────────────────────────────────────────────────
# Feature column name -> human-readable lipid label used on the plot axes.
mz_shortnames = {
    'mz_699.4500': 'PA 34:1 (699.45 m/z)',
    'mz_701.5500': 'PA 36:1 (701.55 m/z)',
    'mz_719.5500': 'PE 34:0 (719.55 m/z)',
    'mz_745.5500': 'PG 34:2 (745.55 m/z)',
    'mz_747.4500': 'PG 34:1 (747.45 m/z)',
    'mz_748.5500': 'PE 36:1 (748.55 m/z)',
    'mz_762.5500': 'PE 38:6 (762.55 m/z)',
    'mz_772.5500': 'PE 38:1 (772.55 m/z)',
    'mz_790.5500': 'PS 36:1 (790.55 m/z)',
    'mz_794.5500': 'PS 36:0 (794.55 m/z)',
    'mz_880.6500': 'PI 38:4 (880.65 m/z)',
    'mz_888.6500': 'ST C24:1 (888.65 m/z)',
    'mz_889.6500': 'ST 42:2 (889.65 m/z)',
    'mz_890.6500': 'ST C24:0 (890.65 m/z)',
    'mz_905.6500': 'ST C24:0 (OH) (905.65 m/z)',
}

# Features without a short name keep their original column name.
# NOTE(review): `top_k_features` is not defined in this cell; it has to exist
# already in the notebook session.
col_labels_short = [mz_shortnames.get(mz, mz) for mz in top_k_features]

# Phenotype label -> hex colour; also acts as the whitelist of phenotypes
# that are plotted further down in this cell.
phenotype_colors = {
    "Astrocytes": "#1f77b4",
    "Oligodendrocytes": "#2ca02c",
    "Neurons": "#d62728",
    "Microglia": "#9467bd",
    "Endothelial cells": "#e377c2"
}

# ─────────────────────────────────────────────────────────────
# 2. Filter AnnData & build melted dataframe
# ─────────────────────────────────────────────────────────────
# Drop 'Others', then keep only phenotypes that have a colour assigned.
adata_filtered = adata_combined_gsr[adata_combined_gsr.obs['cell_phenotype'] != 'Others'].copy()
phenotype_list = [p for p in adata_filtered.obs['cell_phenotype'].unique().tolist() if p in phenotype_colors]

melted_df = []

for phenotype in phenotype_list:
    mask = adata_filtered.obs['cell_phenotype'] == phenotype
    X = adata_filtered[mask][:, top_k_features].X
    # .X may be a sparse matrix; column slices of a sparse matrix cannot be
    # flattened with numpy, so densify once per phenotype.
    if not isinstance(X, np.ndarray):
        X = X.toarray()
    for i, feature in enumerate(top_k_features):
        values = np.asarray(X[:, i]).ravel()
        melted_df.append(pd.DataFrame({
            'Feature': mz_shortnames.get(feature, feature),
            'Cell Phenotype': [phenotype] * len(values),
            'Value': values
        }))

# ignore_index gives the combined frame a unique row index instead of the
# repeated 0..n-1 ranges of the individual pieces.
plot_df = pd.concat(melted_df, axis=0, ignore_index=True)

# ─────────────────────────────────────────────────────────────
# 3. Plot
# ─────────────────────────────────────────────────────────────
plt.figure(figsize=(18, 8))
# hue_order pins the left-to-right order of the boxes inside each feature group
# to phenotype_list, the order the annotation step relies on.
ax = sns.boxplot(
    data=plot_df,
    x='Feature',
    y='Value',
    hue='Cell Phenotype',
    hue_order=phenotype_list,
    palette=phenotype_colors,
    showfliers=False
)

plt.xticks(rotation=45, ha='right', fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('Feature', fontsize=14)
plt.ylabel('Intensity', fontsize=14)
plt.title('Top 15 Lipid Distributions by Cell Phenotype', fontsize=16)
plt.legend(fontsize=11)

# ─────────────────────────────────────────────────────────────
# 4. Stats & annotation
# ─────────────────────────────────────────────────────────────
def pval_to_stars(p):
    """Convert a p-value into the usual significance label.

    Returns 'ns' for p > 0.05, '*' for (0.01, 0.05], '**' for (0.001, 0.01],
    '***' for (0.0001, 0.001] and '****' for anything smaller. A NaN p-value
    (test could not be computed) is reported as 'ns' rather than falling
    through every comparison and being labelled '****'.
    """
    if p != p or p > 0.05:  # p != p is True only for NaN
        return 'ns'
    if p > 0.01:
        return '*'
    if p > 0.001:
        return '**'
    if p > 0.0001:
        return '***'
    return '****'

def add_stat_annotation(ax, x1, x2, y, h, p_val, fontsize=10):
    """Draw a significance bracket from ``x1`` to ``x2`` on ``ax``.

    The bracket rises from ``y`` to ``y + h``; the star label for ``p_val``
    (see ``pval_to_stars``) is centred above its top edge.
    """
    label = pval_to_stars(p_val)
    top = y + h
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, lw=1.5, color='black')
    ax.text((x1+x2)*.5, top, label, ha='center', va='bottom', fontsize=fontsize)

features_order = col_labels_short
feature_positions = {feat: idx for idx, feat in enumerate(features_order)}

# seaborn spreads the hue levels evenly over a band of total width 0.8 centred
# on each x tick, so the centre of box k (of n) sits at
# -0.4 + (k + 0.5) * 0.8 / n. Fixed offsets of -0.3, -0.1, 0.1, 0.3 are only
# correct for exactly four phenotypes, while the colour table admits five;
# computing them from len(phenotype_list) keeps the brackets on the boxes.
hue_width = 0.8 / len(phenotype_list) if phenotype_list else 0.8
hue_offsets = {
    pheno: -0.4 + hue_width * (k + 0.5) for k, pheno in enumerate(phenotype_list)
}
pairs = list(combinations(phenotype_list, 2))

for feat in features_order:
    df_feat = plot_df[plot_df['Feature'] == feat]
    y_max = df_feat['Value'].max()
    idx = feature_positions[feat]

    for j, (p1, p2) in enumerate(pairs):
        v1 = df_feat[df_feat['Cell Phenotype'] == p1]['Value']
        v2 = df_feat[df_feat['Cell Phenotype'] == p2]['Value']
        # Too few observations for a meaningful test: leave this pair unannotated.
        if len(v1) < 3 or len(v2) < 3:
            continue

        # Welch's t-test (unequal variances); p-values are not corrected for
        # the number of comparisons.
        stat, pval = ttest_ind(v1, v2, equal_var=False)

        x1 = idx + hue_offsets[p1]
        x2 = idx + hue_offsets[p2]
        # Stack the brackets of one feature above its largest value, 5 % apart.
        y = y_max * 1.10 + j * 0.05 * y_max
        h = 0.015 * y_max

        add_stat_annotation(ax, x1, x2, y, h, pval)

# ─────────────────────────────────────────────────────────────
# 5. Save and show
# ─────────────────────────────────────────────────────────────
plt.tight_layout()
plt.savefig(os.path.join(base_dir, 'cellphenotype_feature_distributions_boxplot.png'), dpi=300)
plt.show()

### Pseudotime Analysis

In [None]:
import scFates as scf
import scanpy as sc
import numpy as np
import os

base_dir = r"..\results\high_low_no"
adata_combined_gsr = sc.read_h5ad(os.path.join(base_dir, "combined_adata_gsr.h5ad"))

# Drop cells labelled 'Others', then keep every 10th remaining cell as a
# lighter working copy for the trajectory analysis.
adata_filtered = adata_combined_gsr[adata_combined_gsr.obs['cell_phenotype'] != 'Others'].copy()
adata_temp = adata_filtered[::10].copy()
adata_temp.obs_names_make_unique()
# scanpy preprocessing applied to the columns of the matrix ("genes" in scanpy terms).
# NOTE(review): highly_variable_genes is called right after normalize_total,
# without a log1p step; scanpy's default flavor documents log-transformed
# input — confirm this is intended for these intensities.
sc.pp.filter_genes(adata_temp,min_cells=3)
sc.pp.normalize_total(adata_temp)
sc.pp.highly_variable_genes(adata_temp)
# Keep the normalized, unscaled matrix with all remaining features in .raw
# before subsetting to the highly variable ones.
adata_temp.raw=adata_temp
# NOTE(review): this slice is a view and sc.pp.scale below modifies it;
# consider adding .copy() here.
adata_temp=adata_temp[:,adata_temp.var.highly_variable]
sc.pp.scale(adata_temp)
sc.pp.pca(adata_temp)
sc.pp.neighbors(adata_temp, n_neighbors=15)
sc.tl.umap(adata_temp)

In [None]:
# UMAP of the subsampled cells coloured by the `condition` obs column.
sc.pl.umap(adata_temp,color="condition",cmap="RdBu_r")

In [None]:
# Leiden clustering on the neighbour graph, then count the cells per cluster.
sc.tl.leiden(adata_temp, resolution=0.5)
cluster_sizes = adata_temp.obs['leiden'].value_counts()

# Remove clusters with fewer than 300 cells
# NOTE(review): this comment previously said 30 while the code uses 300 —
# confirm which threshold is intended.
small_clusters = cluster_sizes[cluster_sizes < 300].index.tolist()
adata_temp = adata_temp[~adata_temp.obs['leiden'].isin(small_clusters)].copy()

# Re-plot the UMAP after filtering; `save` makes scanpy write the figure to
# its figure directory with this suffix.
sc.pl.umap(adata_temp,color="condition",cmap="RdBu_r", save=f'_condition.png')

In [None]:
# Phenotype label -> hex colour, passed to scanpy as the categorical palette.
phenotype_colors = {
    "Astrocytes": "#1f77b4",
    "Oligodendrocytes": "#2ca02c",
    "Neurons": "#d62728",
    "Microglia": "#9467bd",
    "Endothelial cells": "#e377c2"
}
# UMAP coloured by cell phenotype; `save` makes scanpy write the figure with
# this file-name suffix.
sc.pl.umap(
    adata_temp,
    color="cell_phenotype",
    palette=phenotype_colors,
    save="_cell_phenotype_custom_colors.png"
)

In [None]:
# Fit a scFates curve with 30 nodes on the first 2 dimensions of the UMAP embedding.
scf.tl.curve(adata_temp,Nodes=30,use_rep="X_umap",ndims_rep=2)

In [None]:
# Show the fitted graph on the UMAP with nodes 0-29 passed as `nodes`,
# which helps pick a root node id for the next steps.
scf.pl.graph(adata_temp, basis="umap", nodes=range(30))

In [None]:
# Wrap obsm["X_R"] in a temporary AnnData (reusing the same embeddings) and
# colour the UMAP by its column "1".
# NOTE(review): X_R is presumably the cell-to-node assignment matrix written by
# scf.tl.curve, so this would show the assignment to node 1 — confirm.
sc.pl.umap(sc.AnnData(adata_temp.obsm["X_R"],obsm=adata_temp.obsm),color="1",cmap="Reds")

In [None]:
# Root the trajectory at graph node 28.
# Alternative kept for reference: root selection driven by a feature instead of a node id.
#scf.tl.root(adata_temp, root="mz_790.5500")
scf.tl.root(adata_temp, root=28)

In [None]:
# Compute pseudotime from the chosen root (10 jobs, n_map=10, fixed seed);
# the following cell reads the result from adata_temp.obs['t'].
scf.tl.pseudotime(adata_temp,n_jobs=10,n_map=10,seed=42)

In [None]:
# Min-max rescale pseudotime to [0, 1] and plot it on the UMAP.
t = adata_temp.obs['t']
t_lo, t_hi = t.min(), t.max()
adata_temp.obs['t_norm'] = (t - t_lo) / (t_hi - t_lo)
sc.pl.umap(adata_temp, color="t_norm", save='_pseudotime.png')

In [None]:
# Point the scFates figure directory at the results folder, then draw the
# trajectory on the UMAP and save it with the given file-name suffix.
# NOTE(review): the earlier sc.pl.umap(..., save=...) calls use scanpy's own
# figure directory; confirm both sets of figures end up where expected.
scf.settings.figdir = base_dir
scf.pl.trajectory(adata_temp,basis="umap",arrows=False,arrow_offset=3,save='_pseudotime_trajectory.png')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from statannotations.Annotator import Annotator
import os

# Step 1: Reuse the per-condition colours stored in
# adata_temp.uns["condition_colors"], pairing them positionally with the
# categories of obs["condition"].
umap_colors = adata_temp.uns["condition_colors"]
conditions = list(adata_temp.obs["condition"].cat.categories)
palette_dict = dict(zip(conditions, umap_colors))

# Step 2: Data to plot and the condition pairs to compare
plot_data = adata_temp.obs.loc[:, ["condition", "t_norm"]]
pairs = [("No", "Low"), ("No", "High"), ("Low", "High")]

# Step 3: Boxplot of normalised pseudotime per condition, with significance
# annotations for each pair
plt.figure(figsize=(6, 4))
ax = sns.boxplot(
    data=plot_data,
    x="condition",
    y="t_norm",
    palette=palette_dict,
    ax=plt.gca(),
)

annotator = Annotator(ax, pairs, data=plot_data, x="condition", y="t_norm")
annotator.configure(
    test='t-test_ind',
    text_format='star',
    loc='inside',
    comparisons_correction="bonferroni",
)
annotator.apply_and_annotate()

ax.set_ylabel("Pseudotime")
ax.set_xlabel("Dose Condition")
plt.tight_layout()
plt.savefig(
    os.path.join(base_dir, 'pseudotime_boxplot.png'),
    dpi=300,
)
plt.show()

In [None]:
# Filled density curves of normalised pseudotime, one per condition.
sns.kdeplot(
    data=adata_temp.obs,
    x="t_norm",
    hue="condition",
    fill=True,
)
plt.savefig(
    os.path.join(base_dir, 'pseudotime_kdeplot.png'),
    dpi=300,
)

In [None]:
# Display labels for selected features. Keys are feature names of the form
# 'mz_<value>'; values are the labels shown in figures, each combining a lipid
# annotation with the m/z value. Features missing from this table keep their
# original name wherever the table is applied with `.get(name, name)`.
mz_shortnames = {
    'mz_699.4500': 'PA 34:1 (699.45 m/z)',
    'mz_701.5500': 'PA 36:1 (701.55 m/z)',
    'mz_719.5500': 'PE 34:0 (719.55 m/z)',
    'mz_745.5500': 'PG 34:2 (745.55 m/z)',
    'mz_747.4500': 'PG 34:1 (747.45 m/z)',
    'mz_748.5500': 'PE 36:1 (748.55 m/z)',
    'mz_762.5500': 'PE 38:6 (762.55 m/z)',
    'mz_772.5500': 'PE 38:1 (772.55 m/z)',
    'mz_790.5500': 'PS 36:1 (790.55 m/z)',
    'mz_794.5500': 'PS 36:0 (794.55 m/z)',
    'mz_880.6500': 'PI 38:4 (880.65 m/z)',
    'mz_888.6500': 'ST C24:1 (888.65 m/z)',
    'mz_889.6500': 'ST 42:2 (889.65 m/z)',
    'mz_890.6500': 'ST C24:0 (890.65 m/z)',
    'mz_905.6500': 'ST C24:0 (OH) (905.65 m/z)',
}

# Short labels for the features in top_k_features (defined earlier in the
# notebook); names without an entry in mz_shortnames are kept as-is.
col_labels_short = [mz_shortnames.get(feature, feature) for feature in top_k_features]

# Original feature name -> annotated display name for every feature in adata_temp
rename_dict = {name: mz_shortnames.get(name, name) for name in adata_temp.var_names}

# Rename on a copy so that adata_temp keeps its original var_names
adata_temp_annotated = adata_temp.copy()
adata_temp_annotated.var_names = [rename_dict[name] for name in adata_temp.var_names]

# One UMAP panel per renamed feature, four panels per row. use_raw=False
# makes the lookup use the renamed var_names rather than the .raw snapshot.
sc.pl.umap(adata_temp_annotated, color=list(adata_temp_annotated.var_names), cmap="RdBu_r", ncols=4, use_raw=False, save='_features_grid.png')