_Neural Data Science_

Lecturer: Dr. Jan Lause, Prof. Dr. Philipp Berens

Tutors: Jonas Beck, Fabio Seel, Julius Würzler

Summer term 2025

Student names: Nina Lutz, Mathis Nommensen

LLM Disclaimer: Copilot was used to speed up the coding process in VS. ChatGPT was used for code review, nicer plots and debugging. ChatGPT was also used to aggregate information from papers on good practice and also used for quick lookups on biological information regarding genes and their functions.


# Project 3: Single-cell data analysis.

In [None]:
# %matplotlib notebook --> had to change this to compile
%matplotlib inline

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import string

import scipy as sp
from scipy import sparse
import sklearn

## add your packages ##

import time
import pickle
import memory_profiler
import seaborn as sns
import scipy.stats as stats
import umap.umap_ as umap
import igraph as ig
import leidenalg
#import umap

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import recall_score
from sklearn.metrics import silhouette_score, adjusted_mutual_info_score
from sklearn.feature_selection import VarianceThreshold
from sklearn.cross_decomposition import CCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, ElasticNet
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.pipeline import make_pipeline

from scipy.stats import pearsonr
from scipy.stats import nbinom

from statsmodels.stats.multitest import fdrcorrection

from gprofiler import GProfiler


%load_ext memory_profiler

from pathlib import Path

In [None]:
import black
import jupyter_black

jupyter_black.load(line_length=79)

In [None]:
variables_path = Path("../results/variables")
figures_path = Path("../results/figures")
data_path = Path("../data")

In [None]:
# plt.style.use("matplotlib_style.txt")
plt.style.use(
    "../matplotlib_style.txt"
)  # had to change this as well to successfully compile

In [None]:
np.random.seed(42)

## Project and data description

In this project, we are going to work with the typical methods and pipelines used in single-cell data analysis and get some hands-on experience with the techniques used in the field. For that, we will be using Patch-seq multimodal data from cortical neurons in mice, from Scala et al. 2021 (https://www.nature.com/articles/s41586-020-2907-3#Sec7). From the different data modalities they used, we will focus on transcriptomics and electrophysiological data. 

In a real-world scenario, single-cell data rarely comes with any "ground truth" labels. Often, the goal of researchers after measuring cells is to precisely classify them, grouping them into families or assigning them cell types based on the recorded features. This is normally done using unsupervised methods, such as clustering.

However, the single-cell data that we are using in this project has some cell types assigned to each cell. These are not "ground truth" type annotations, but were one of the results from the original Scala et al. work. Still, we are going to use those annotations for validation (despite them not really being ground truth) to sanity-check some of our analyses, such as visualizations, clustering, etc. We will mainly work with cell types (`rna_types`, 77 unique types) and cell families (`rna_families`, 9 unique families).

From the transcriptomics mRNA counts, we will only work with the exon counts for simplicity. Some of the electrophysiological features are not high-quality recordings, therefore we will also filter them out.

## Import data

### Meta data

In [None]:
# META DATA

meta = pd.read_csv(data_path / "m1_patchseq_meta_data.csv", sep="\t")

cells = meta["Cell"].values

layers = meta["Targeted layer"].values.astype("str")
cre = meta["Cre"].values
yields = meta["Yield (pg/µl)"].values
yields[yields == "?"] = np.nan
yields = yields.astype("float")
depth = meta["Soma depth (µm)"].values
depth[depth == "Slice Lost"] = np.nan
depth = depth.astype(float)
thickness = meta["Cortical thickness (µm)"].values
thickness[thickness == 0] = np.nan
thickness = thickness.astype(float)
traced = meta["Traced"].values == "y"
exclude = meta["Exclusion reasons"].values.astype(str)
exclude[exclude == "nan"] = ""

mice_names = meta["Mouse"].values
mice_ages = meta["Mouse age"].values
mice_cres = np.array(
    [
        c if c[-1] != "+" and c[-1] != "-" else c[:-1]
        for c in meta["Cre"].values
    ]
)
mice_ages = dict(zip(mice_names, mice_ages))
mice_cres = dict(zip(mice_names, mice_cres))

print("Number of cells with measured depth:    ", np.sum(~np.isnan(depth)))
print("Number of cells with measured thickness:", np.sum(~np.isnan(thickness)))
print("Number of reconstructed cells:          ", np.sum(traced))

sliceids = meta["Slice"].values
a, b = np.unique(sliceids, return_counts=True)
assert np.all(b <= 2)
print("Number of slices with two cells:        ", np.sum(b == 2))

# Some consistency checks
assert np.all(
    [
        np.unique(meta["Date"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse age"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse gender"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse genotype"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse"].values[sliceids == s]).size == 1
        for s in sliceids
    ]
)

In [None]:
meta.columns

### "Ground truth labels"

In [None]:
# filter out low quality cells in term of RNA
print(
    "There are",
    np.sum(meta["RNA family"] == "low quality"),
    "cells with low quality RNA recordings.",
)
exclude_low_quality = meta["RNA family"] != "low quality"

In [None]:
rna_family = meta["RNA family"][exclude_low_quality]
rna_type = meta["RNA type"][exclude_low_quality]

In [None]:
print(len(np.unique(rna_family)))
print(len(np.unique(rna_type)))

In [None]:
# use a context manager so the file handle is closed after loading
with open(data_path / "dict_rna_type_colors.pkl", "rb") as pickle_in:
    dict_rna_type_colors = pickle.load(pickle_in)

In [None]:
rna_type_colors = np.vectorize(dict_rna_type_colors.get)(rna_type)

### Transcriptomic data

In [None]:
# READ COUNTS
data_exons = pd.read_csv(
    data_path / "m1_patchseq_exon_counts.csv.gz", na_filter=False, index_col=0
)

assert all(cells == data_exons.columns)
genes = np.array(data_exons.index)

# filter out low quality cells in term of rna family
exonCounts = data_exons.values.transpose()[exclude_low_quality]
print("Count matrix shape (exon):  ", exonCounts.shape)

In [None]:
# GENE LENGTH

data = pd.read_csv(data_path / "gene_lengths.txt")
assert all(data["GeneID"] == genes)
exonLengths = data["exon_bp"].values

### Electrophysiological features

In [None]:
# EPHYS DATA

ephysData = pd.read_csv(data_path / "m1_patchseq_ephys_features.csv")
ephysNames = np.array(ephysData.columns[1:]).astype(str)
ephysCells = ephysData["cell id"].values
ephysData = ephysData.values[:, 1:].astype("float")
names2ephys = dict(zip(ephysCells, ephysData))
ephysData = np.array(
    [
        names2ephys[c] if c in names2ephys else ephysData[0] * np.nan
        for c in cells
    ]
)

print("Number of cells with ephys data:", np.sum(np.isin(cells, ephysCells)))

assert np.sum(~np.isin(ephysCells, cells)) == 0

In [None]:
# Filtering ephys data

features_exclude = [
    "Afterdepolarization (mV)",
    "AP Fano factor",
    "ISI Fano factor",
    "Latency @ +20pA current (ms)",
    "Wildness",
    "Spike frequency adaptation",
    "Sag area (mV*s)",
    "Sag time (s)",
    "Burstiness",
    "AP amplitude average adaptation index",
    "ISI average adaptation index",
    "Rebound number of APs",
]
features_log = [
    "AP coefficient of variation",
    "ISI coefficient of variation",
    "ISI adaptation index",
    "Latency (ms)",
]

X = ephysData[exclude_low_quality]
print(X.shape)
for e in features_log:
    X[:, ephysNames == e] = np.log(X[:, ephysNames == e])
X = X[:, ~np.isin(ephysNames, features_exclude)]

keepcells = ~np.isnan(np.sum(X, axis=1))
X = X[keepcells, :]
print(X.shape)

X = X - X.mean(axis=0)
ephysData_filtered = X / X.std(axis=0)

In [None]:
np.sum(np.isnan(ephysData_filtered))

# Research questions to investigate

**1) Inspect the data by computing key statistics.** For RNA counts, you can compute and plot statistics, e.g. total counts per cell, number of expressed genes per cell, mean count per gene, variance per gene, mean-variance relationship... See https://www.embopress.org/doi/full/10.15252/msb.20188746 for common quality control statistics. Keep in mind that the RNA data in this project is read counts, not UMI counts, so it is not supposed to follow a Poisson distribution. To get an idea of the technical noise in the data, you can plot count distributions of single genes within cell types (like in the lecture).

Similarly, you can compute and plot statistics over the electrophysiological data. Also, investigate the distribution of "ground truth" labels. Comment about other relevant metadata, and think if you can use it as some external validation for other analyses. If you do use other metadata throughout the project, explain why and what you get out of it. Take into account that certain features may not be very informative for our purposes (e.g. mouse age), so only choose features that provide you with useful information in this context. If you want to get additional information about the metadata, have a look at the extended data section in the original publication (e.g., cre-lines in Figure 1c in the extended data).

**2) Normalize & transform the data; select genes & apply PCA.** There are several ways of normalizing the RNA count data (Raw, CPM, CPMedian, RPKM, see https://www.reneshbedre.com/blog/expression_units.html, https://translational-medicine.biomedcentral.com/articles/10.1186/s12967-021-02936-w). Take into account that there are certain normalizations that only make sense for UMI data, but not for this read count data. You also explored different transformations in the assignment (none, log, sqrt). Compare how the different transformations change the two-dimensional visualization. After normalization and transformation, choose a set of highly variable genes (as demonstrated in the lecture) and apply PCA. Play with the number of selected genes and the number of PCA components, and again compare their effects on the two-dimensional visualization.

**3) Two-dimensional visualization.** To visualize the RNA count data after normalization, transformation, gene selection and PCA, try different methods (just PCA, t-SNE, UMAP, ..) and vary their parameters (exaggeration, perplexity, ..). Compare them using quantitative metrics (e.g., kNN accuracy in high-dim vs. two-dim, kNN recall). Please refer to Lause et al., 2024 (https://doi.org/10.1371/journal.pcbi.1012403) where many of these metrics are discussed and explained to make an informed choice on which metrics to use. Think about also using the electrophysiological features and other metadata to enhance different visualizations.

**4) Clustering.** To find cell types in the RNA count data, you will need to look for clusters. Try different clustering methods (Leiden, GMM). Implement a negative binomial mixture model. For that you can follow a method similar to the one described in Harris et al. 2018 (https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.2006387#abstract0), with fixed r (r=2). Feel free to simplify the setup from the paper and not optimize over the set of important genes S but fix it instead, or skip the split and merge part of their clustering algorithm. A vanilla NBMM implementation should suffice. Take into account that the NBMM tries to cluster data that follows a negative binomial distribution. Therefore, it does not make sense to apply this clustering method to all kinds of normalized and transformed data. Please refer to the Harris et al. 2018 publication for the appropriate choice of normalization, and reflect on why this normalization makes sense. Evaluate your clustering results (metrics, compare number of clusters to original labels,...).

**5) Correlation between electrophysiological features and genes/PCs.** Finally, connect RNA counts and functional data: Most likely, there will be interesting relationships between the transcriptomic and electrophysiological features in this data. Find these correlations and a way of visualizing them. In studying correlations using the PCA-reduced version of the transcriptomics data, it could be interesting to study PC loadings to see which genes are dominating which PCs. For other advanced analyses, you can get inspiration from Kobak et al., 2021 (https://doi.org/10.1111/rssc.12494).
    

# Task 1

Outline of data exploration steps was mainly inspired by Luecken, M.D. & Theis, F.J. (2019)

## 1.1 QC Statistics per cell



In [None]:
# total counts per cell (count depth)

# exonCounts  # shape = 1232 cells x 42466 genes
total_counts_per_cell = exonCounts.sum(axis=1)

# plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(
    total_counts_per_cell, bins=50, color="lightblue", edgecolor="black"
)
axes[0].set_xlabel("Total RNA counts per cell")
axes[0].set_ylabel("Number of cells")
axes[0].set_title("Distribution of Total Counts Per Cell")

axes[1].bar(
    range(len(total_counts_per_cell)),
    total_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[1].set_yscale("log")
axes[1].set_xlabel("Cell index")
axes[1].set_ylabel("Total RNA counts")
axes[1].set_title("Total Counts Per Cell")

# Rank-ordered plot of total counts per cell - Figure 2C in the paper
sorted_counts = np.sort(total_counts_per_cell)[::-1]

axes[2].plot(sorted_counts, color="blue")
axes[2].set_yscale("log")
axes[2].set_xlabel("Ranked Cell Index")
axes[2].set_ylabel("Total Counts (log scale)")
axes[2].set_title("Rank-ordered Total Counts per Cell")

*Interpretation:*

Left:
* The histogram reveals a right-skewed distribution of total RNA counts per cell, with most cells clustered between 0 and 5 million counts.

Middle:
* The bar plot shows that most cells have a comparable number of total counts, with a few outliers having markedly higher or lower counts (e.g. cells around index 400 and 500).

Right:
* The rank-ordered curve shows a gradual decay in RNA counts, with a steep drop-off in low-quality cells at the tail, which is typical of single-cell data
* most cells have comparable total counts
* the first ~20 cells have a lot more counts than the rest
* the last ~100 cells have very few counts
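One way to turn the visual impression of "very high" and "very low" count cells into an explicit rule is a median absolute deviation (MAD) threshold on log-transformed total counts. The sketch below is only an illustration on simulated counts: the function name `flag_count_outliers`, the cutoff of 3 MADs and the toy data are our assumptions, not values taken from the dataset or from Scala et al.

```python
import numpy as np


def flag_count_outliers(total_counts, n_mads=3.0):
    """Flag cells whose log10 total count is far from the median.

    Returns two boolean masks (too low, too high). The cutoff n_mads
    is an assumed default, not a value tuned on the project data.
    """
    log_counts = np.log10(np.asarray(total_counts, dtype=float) + 1.0)
    median = np.median(log_counts)
    mad = np.median(np.abs(log_counts - median))
    # 1.4826 rescales the MAD to a standard deviation for normal data
    scaled_mad = 1.4826 * mad
    low = log_counts < median - n_mads * scaled_mad
    high = log_counts > median + n_mads * scaled_mad
    return low, high


# toy data: log-normally distributed depths plus a few artificial outliers
rng = np.random.default_rng(0)
toy_counts = rng.lognormal(mean=14.0, sigma=0.5, size=500)
toy_counts[:5] = 100.0  # artificially shallow cells
toy_counts[5:8] = 5e8  # artificially deep cells

low, high = flag_count_outliers(toy_counts)
print("flagged as too shallow:", low.sum(), "| flagged as too deep:", high.sum())
```

Applied to `total_counts_per_cell`, the two masks could be used to highlight candidate low-quality cells in the rank-ordered plot rather than to remove them outright.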

In [None]:
# number of expressed genes per cell
expressed_genes_per_cell = (exonCounts > 0).sum(axis=1)

# plot
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].scatter(range(len(expressed_genes_per_cell)), expressed_genes_per_cell)
axes[0].set_xlabel("Cell Index")
axes[0].set_ylabel("Expressed Genes")
axes[0].set_title("Total Number of Expressed Genes per Cell")

axes[1].hist(expressed_genes_per_cell, bins=50)
axes[1].set_xlabel("Total number of expressed genes per cell")
axes[1].set_ylabel("Number of cells")
axes[1].set_title("Distribution of Expressed Genes Per Cell")

*Interpretation:*

Left: Total number of expressed genes per cell
* The scatter plot shows considerable variation in the number of genes detected per cell, ranging from around 2,000 to over 17,000 genes, with most cells expressing between 5,000 and 11,000 genes.

Right: Distribution of expressed genes per cell
* Shows a similar pattern as the left plot, with most cells having between 2,000 and 11,000 genes expressed.

In [None]:
# fraction of mitochondrial genes

# mt_like_genes = [g for g in genes if "mt" in g.lower()]
# print(mt_like_genes)
# after exploring the gene names, we assume that mitochondrial genes start with "mt-" or "MT-"
mt_gene_mask = np.char.startswith(
    np.char.lower(genes.astype(str)), "mt-"
)  # lower-casing first makes the prefix match case-insensitive

print("Number of mitochondrial genes found:", np.sum(mt_gene_mask))

# Sum counts over mitochondrial genes per cell
mt_counts_per_cell = exonCounts[:, mt_gene_mask].sum(axis=1)

# Fraction mitochondrial
fraction_mito = mt_counts_per_cell / total_counts_per_cell

# Print statistics
print(f"Mean mitochondrial fraction: {np.mean(fraction_mito):.3f}")
print("Cells with >20% mitochondrial:", np.sum(fraction_mito > 0.2))

# Plot
plt.figure(figsize=(8, 4))
plt.hist(fraction_mito, bins=50, color="salmon", edgecolor="black")
plt.xlabel("Fraction of Mitochondrial Counts per Cell")
plt.ylabel("Number of Cells")
plt.title("Distribution of Mitochondrial RNA Content")
plt.show()

*Interpretation:*

* Most cells have a low fraction of mitochondrial counts (~1.7%), with a few outliers having a higher fraction (e.g. cells around index 400 and 500).

## 1.2 QC Statistics per gene


In [None]:
# mean expression across all cells
mean_expression_across_cells = exonCounts.mean(axis=0)
print("Mean expression across all cells:", mean_expression_across_cells.shape)

# variance across all cells
variance_expression_across_cells = exonCounts.var(axis=0)
print("Variance across all cells:", variance_expression_across_cells.shape)

# plot looked sparse, so log transform the data
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(
    mean_expression_across_cells + 1e-3,  # avoid log(0)
    variance_expression_across_cells + 1e-3,
    s=5,
    alpha=0.5,
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Mean Expression Across Cells (log scale)")
ax.set_ylabel("Variance Across Cells (log scale)")
ax.set_title("Mean vs Variance of Gene Expression (log scale)")
x_vals = np.linspace(
    min(mean_expression_across_cells + 1e-3),
    max(mean_expression_across_cells),
    100,
)
ax.plot(x_vals, x_vals, color="red", linestyle="--", label="y = x")
ax.legend()
plt.show()

*Interpretation:*

* The log-log plot reveals a clear mean-variance relationship, where genes with higher mean expression across cells also tend to have higher variance
* Almost all genes lie above the y=x line (variance = mean), indicating that the variance is generally higher than the mean for most genes (overdispersion).
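Overdispersion of this kind is often summarized with a negative binomial model, where the variance is $\mu + \mu^2 / r$. As a rough illustration, the sketch below recovers $r$ from the mean and variance by the method of moments on simulated counts. The helper `nb_dispersion_moments` and the simulated data (with $r = 2$) are assumptions for demonstration only; on real read counts, differences in sequencing depth and cell type inflate the variance further, so the estimate would only be a crude summary.

```python
import numpy as np


def nb_dispersion_moments(counts):
    """Method-of-moments estimate of the NB parameter r per gene.

    Solves var = mu + mu**2 / r for r. Genes without excess variance
    (var <= mu) get r = inf, i.e. they look Poisson-like or underdispersed.
    """
    mu = counts.mean(axis=0)
    var = counts.var(axis=0, ddof=1)
    excess = var - mu
    r = np.full(mu.shape, np.inf)
    ok = excess > 0
    r[ok] = mu[ok] ** 2 / excess[ok]
    return mu, var, r


# simulate three "genes" with different means and the same r
rng = np.random.default_rng(1)
true_r = 2.0
means = np.array([5.0, 50.0, 500.0])
# NumPy parametrizes the NB by (n, p); p = r / (r + mu) gives mean mu
p = true_r / (true_r + means)
toy = rng.negative_binomial(true_r, p, size=(20000, 3))

mu_hat, var_hat, r_hat = nb_dispersion_moments(toy)
print("estimated means:", np.round(mu_hat, 1))
print("estimated r:    ", np.round(r_hat, 2))
```

In the log-log plot, a fixed $r$ corresponds to points following the line $y = x$ at low means and bending towards slope 2 at high means.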

In [None]:
# dropout rate / fraction of cells where a gene has zero counts
dropout_rate_per_gene = (exonCounts == 0).sum(axis=0) / exonCounts.shape[0]

# Quick summary
print(
    f"Mean dropout rate across all genes: {np.mean(dropout_rate_per_gene):.3f}"
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot Histogram of dropout rates
axes[0].hist(dropout_rate_per_gene, bins=50, color="green", edgecolor="black")
axes[0].set_xlabel("Dropout Rate (Fraction of Cells with Zero Counts)")
axes[0].set_ylabel("Number of Genes")
axes[0].set_title("Dropout Rate per Gene")

# Scatter plot of dropout rate vs mean expression
axes[1].scatter(
    mean_expression_across_cells + 1e-3,  # avoid log(0)
    dropout_rate_per_gene,
    s=5,
    alpha=0.5,
)
axes[1].set_xscale("log")
axes[1].set_ylim(0, 1.05)
axes[1].set_xlabel("Mean Expression (log scale)")
axes[1].set_ylabel("Dropout Rate")
axes[1].set_title("Dropout Rate vs. Mean Expression")

plt.show()

*Interpretation:*

Left:
* the majority of genes have a very high dropout rate (mean ~83%), with a large number expressed in fewer than 15-20% of cells and a peak near 100% dropout, i.e. many genes are completely unexpressed in most cells

Right:
* Strong inverse relationship: Highly expressed genes are detected in nearly all cells (low dropout), while lowly expressed genes are often undetected (high dropout)
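The inverse relationship between mean expression and dropout rate can be compared with what a count model predicts. Under a Poisson model the expected zero fraction is $e^{-\mu}$; under a negative binomial with parameter $r$ it is $(r / (r + \mu))^r$, which is always at least as large. The following sketch checks this on simulated data. The function `expected_zero_fraction` and the choice $r = 2$ are illustrative assumptions and are not fitted to the project data.

```python
import numpy as np


def expected_zero_fraction(mu, r=None):
    """Expected fraction of zeros for a Poisson (r=None) or NB model."""
    mu = np.asarray(mu, dtype=float)
    if r is None:
        return np.exp(-mu)  # Poisson: P(X = 0) = exp(-mu)
    return (r / (r + mu)) ** r  # NB: P(X = 0) = (r / (r + mu))**r


# simulate NB counts for three mean expression levels
rng = np.random.default_rng(4)
mu = np.array([0.1, 1.0, 10.0])
r = 2.0
samples = rng.negative_binomial(r, r / (r + mu), size=(50000, 3))

observed = (samples == 0).mean(axis=0)
poisson_zero = expected_zero_fraction(mu)
nb_zero = expected_zero_fraction(mu, r=r)

print("observed zero fraction:", np.round(observed, 3))
print("NB prediction:         ", np.round(nb_zero, 3))
print("Poisson prediction:    ", np.round(poisson_zero, 3))
```

Overlaying such curves on the dropout-versus-mean scatter plot would show whether the observed zeros are roughly in line with an overdispersed count model or exceed it.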

**count distributions of single genes within cell types**

In [None]:
# identify cell types
unique_cell_types = np.unique(rna_type)

# pick genes:
# Sst: somatostatin, an inhibitory neuron marker
# Slc17a7: an excitatory neuron marker (VGLUT1)
# Snap25: a pan-neuronal gene
genes_to_plot = ["Sst", "Slc17a7", "Snap25"]
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 8))

for ax, gene_name in zip(axes, genes_to_plot):
    gene_indices = np.where(genes == gene_name)[0]
    gene_index = gene_indices[0]

    # get gene counts
    gene_counts = exonCounts[:, gene_index]

    # group counts by cell type
    data = []
    for ct in unique_cell_types:
        mask = np.array(rna_type) == ct
        counts_in_group = gene_counts[mask]
        data.append(counts_in_group)

    ax.boxplot(
        data, tick_labels=unique_cell_types, showfliers=False
    )  # showfliers=False option removes extreme outliers
    ax.set_title(f"Expression of {gene_name}")

    # log-scale y-axis for better visibility
    # (note: zero counts cannot be drawn on a log axis)
    ax.set_yscale("log")

    ax.set_ylabel("Counts")
    ax.set_xlabel("Cell Type")
    ax.tick_params(axis="x", rotation=90, labelsize=8)

plt.show()

*Interpretation:*

Sst:
* Expression is highly enriched in specific inhibitory cell types (notably those starting with "Sst"), with occasionally enriched expression in other cell types, but no broader pattern is visible.

Slc17a7:
* Expression is enriched in many cells, but most obviously in cell types from L2 to L6, which belong to a family of excitatory neurons. This suggests that Slc17a7 is a marker for excitatory neurons in the cortex, as it is known to be expressed in glutamatergic neurons.

Snap25:
* This gene shows widespread expression across nearly all cell types, reflecting its function as a general neuronal marker involved in synaptic transmission

## 1.3 Statistics for electrophysiological features

In [None]:
# features overview

# List of cleaned electrophysiological features retained for analysis
ephysNames_filtered = ephysNames[
    ~np.isin(ephysNames, features_exclude)
]  # see above

print(
    "Remaining electrophysiological features (n = {}) for analysis".format(
        len(ephysNames_filtered)
    )
)
for i, name in enumerate(ephysNames_filtered, 1):
    print(f"{i:2d}. {name}")

In [None]:
# descriptive statistics of electrophysiological features (keep in mind that data is already standardized)

# dictionary to collect stats
stats_dict = {
    "Mean": [],
    "Std": [],
    "Min": [],
    "Max": [],
    "Median": [],
    "Skewness": [],
}

# Compute stats per feature
for i in range(ephysData_filtered.shape[1]):
    # data = X[:, i]
    data = ephysData_filtered[:, i]

    # Collect statistics
    stats_dict["Mean"].append(np.mean(data))
    stats_dict["Std"].append(np.std(data))
    stats_dict["Min"].append(np.min(data))
    stats_dict["Max"].append(np.max(data))
    stats_dict["Median"].append(np.median(data))
    stats_dict["Skewness"].append(stats.skew(data))

# Convert to DataFrame
feature_stats_df = pd.DataFrame(stats_dict, index=ephysNames_filtered)

print("Basic statistics of electrophysiological features (standardized):")
display(feature_stats_df)

*Interpretation:*

Overall:
* mean and standard deviation are (roughly) 0 and 1, respectively, as the data is standardized
* Min/Max values show range in z-score units

Notable observations:
* ISI adaptation index (skew = +2.48), Rebound (mV) (+2.05), Sag ratio (+2.06), and Rheobase (pA) (+1.64) are strongly right-skewed, indicating that most cells have low values, but a smaller subset shows much higher values.
* AP amplitude adaptation index (skew = -1.13) is left-skewed, indicating that most cells have high values (i.e. high adaptation), but a smaller subset shows much lower values.
* Membrane time constant, Input resistance, and Max number of APs have high max values (>3.8 z-score), indicating large variability among cells in how they respond to current input.
* Several features show approximately symmetric distributions, like AP threshold (skew = −0.06), Afterhyperpolarization (mV) (skew = +0.03) and Resting membrane potential (skew = −0.12). This suggests that these properties are well-centered and consistent across cells.
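For reference, the skewness values discussed above are the third standardized moment, $g_1 = m_3 / m_2^{3/2}$, where $m_k$ is the $k$-th central sample moment; this is what `scipy.stats.skew` computes with its default settings. The sketch below implements the formula directly and applies it to simulated symmetric and skewed samples. The helper name `sample_skewness` and the toy distributions are our own choices for illustration.

```python
import numpy as np


def sample_skewness(values):
    """Biased sample skewness g1 = m3 / m2**1.5 (central moments)."""
    values = np.asarray(values, dtype=float)
    centered = values - values.mean()
    m2 = np.mean(centered**2)
    m3 = np.mean(centered**3)
    return m3 / m2**1.5


rng = np.random.default_rng(2)
symmetric = rng.normal(size=20000)
right_skewed = rng.lognormal(mean=0.0, sigma=1.0, size=20000)

skew_symmetric = sample_skewness(symmetric)
skew_right = sample_skewness(right_skewed)
skew_left = sample_skewness(-right_skewed)  # mirroring flips the sign

print(f"symmetric sample:    {skew_symmetric:+.2f}")
print(f"right-skewed sample: {skew_right:+.2f}")
print(f"left-skewed sample:  {skew_left:+.2f}")
```

Because skewness is invariant to shifting and positive rescaling, standardizing the features does not change these values.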


In [None]:
# Plot distributions of standardized electrophysiological features
n_features = len(ephysNames_filtered)
n_cols = 4
n_rows = int(np.ceil(n_features / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3))

axes = axes.flatten()

for i, feature in enumerate(ephysNames_filtered):
    ax = axes[i]
    sns.histplot(
        ephysData_filtered[:, i],
        bins=30,
        kde=False,  # the density curve is drawn by sns.kdeplot below
        color="skyblue",
        stat="density",
        ax=ax,
    )
    sns.kdeplot(
        ephysData_filtered[:, i],
        color="darkblue",
        linewidth=1,
        ax=ax,
    )  # smoothed density plot (darkblue line)

    ax.axvline(0, color="red", linestyle="--", linewidth=1)
    ax.set_title(feature, fontsize=9, fontweight="bold")
    ax.set_xlabel("Standardized value")
    ax.set_ylabel("Density")

# Remove empty subplots
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])

fig.suptitle(
    "Distributions of Standardized Electrophysiological Features",
    fontsize=14,
    fontweight="bold",
)
plt.tight_layout(rect=[0, 0, 1, 0.97])

plt.show()

*Interpretation:*
* means (0) and standard deviations (1) are consistent with standardized data
* skewness is clearly visible
* generally, the distributions show clearly what the statistics suggest: all features are centered around zero with unit variance, and the observed skewness in the plots aligns with the computed skewness values, highlighting the diversity in distribution shapes.

In [None]:
# DataFrame with standardized ephys data
ephys_df = pd.DataFrame(ephysData_filtered, columns=ephysNames_filtered)

# Compute Pearson correlation matrix
corr_matrix = ephys_df.corr()

# Test
print("Correlation matrix shape:", corr_matrix.shape)

# "high correlation" threshold
# commonly, a correlation |r| > 0.6 is considered "strong"
threshold = 0.6
high_corr_pairs = []

# Iterate over the upper triangle of the correlation matrix (excluding the diagonal) because the matrix is symmetric
for i in range(len(corr_matrix.columns)):
    for j in range(i + 1, len(corr_matrix.columns)):
        corr_value = corr_matrix.iloc[i, j]
        if abs(corr_value) > threshold:
            feature_i = corr_matrix.columns[i]
            feature_j = corr_matrix.columns[j]
            high_corr_pairs.append((feature_i, feature_j, corr_value))

# Print results
print(
    "Feature pairings with high correlations (|corr| > {:.2f}):".format(
        threshold
    )
)
for feature1, feature2, corr in sorted(
    high_corr_pairs, key=lambda x: -abs(x[2])
):
    print(f"{feature1:30s} x {feature2:30s}: {corr:+.2f}")

In [None]:
# Plotting the correlation matrix
plt.figure(figsize=(12, 10))
sns.heatmap(
    corr_matrix,
    annot=True,
    fmt=".2f",
    cmap="vlag",
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.8},
)
plt.title("Pairwise Pearson Correlation of Electrophysiological Features")
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.show()

*Interpretation:*

Positive correlations:
* AP width (ms) x Upstroke-to-downstroke ratio: +0.94. Suggests these may capture overlapping aspects of action potential shape.
* Afterhyperpolarization (mV) x Upstroke-to-downstroke ratio: +0.67. Suggests that afterhyperpolarization size covaries with action potential shape.
* AP amplitude (mV) x Afterhyperpolarization (mV): +0.67. Larger APs may lead to deeper afterhyperpolarizations, possibly reflecting cell type-specific properties or biophysical mechanisms.

Negative correlations:
* AP width (ms) x Max number of APs: -0.73. Cells firing more APs tend to have narrower APs, possibly indicating a trade-off between firing rate and action potential width.
* ISI coefficient of variation x Max number of APs: -0.70. More consistent ISIs are found in cells with higher firing rates, hinting at firing pattern specialization.
* Max number of APs x Upstroke-to-downstroke ratio: -0.68. Suggests that cells with more APs tend to have lower upstroke-to-downstroke ratios, possibly indicating differences in firing patterns or cell types.

Generally, highly correlated features may be redundant or functionally related; such relationships are a main focus of the following tasks.
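The search for strongly correlated feature pairs can also be written without explicit loops by indexing the upper triangle of the correlation matrix. The sketch below shows this on simulated features; the function `strong_pairs`, the feature names `a` to `d` and the toy data are hypothetical and only serve to illustrate the indexing.

```python
import numpy as np


def strong_pairs(data, names, threshold=0.6):
    """Return (name_i, name_j, r) for all pairs with |r| > threshold."""
    corr = np.corrcoef(data, rowvar=False)  # columns are features
    rows, cols = np.triu_indices(corr.shape[0], k=1)  # skip the diagonal
    keep = np.abs(corr[rows, cols]) > threshold
    pairs = [
        (names[i], names[j], float(corr[i, j]))
        for i, j in zip(rows[keep], cols[keep])
    ]
    # strongest correlations first
    return sorted(pairs, key=lambda pair: -abs(pair[2]))


# toy data: b follows a closely, c is anti-correlated with a, d is noise
rng = np.random.default_rng(3)
base = rng.normal(size=1000)
toy = np.column_stack(
    [
        base,
        base + 0.3 * rng.normal(size=1000),
        -base + 0.8 * rng.normal(size=1000),
        rng.normal(size=1000),
    ]
)
toy_names = ["a", "b", "c", "d"]

pairs = strong_pairs(toy, toy_names)
for name_i, name_j, r in pairs:
    print(f"{name_i} x {name_j}: {r:+.2f}")
```

Since Pearson correlation is sensitive to outliers and several features are strongly skewed, repeating the analysis with a rank-based correlation could be a useful robustness check.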


## 1.4 Investigate Ground Truth Labels

In [None]:
# Analyze and visualize ground truth distribution
# ground truth data:
# - rna_type: 77 transcriptomic subtypes
# - rna_family: 9 broader transcriptomic families

# Barplot of RNA types
rna_type_counts = pd.Series(rna_type).value_counts()

plt.figure(figsize=(12, 5))
sns.barplot(
    x=rna_type_counts.index,
    y=rna_type_counts.values,
    hue=rna_type_counts.index,
    palette="muted",
    dodge=False,
    legend=False,
)
plt.xticks(rotation=90)
plt.title("RNA Types by Cell Count")
plt.ylabel("Number of Cells")
plt.xlabel("RNA Type")
plt.show()

# Barplot of RNA families
plt.figure(figsize=(8, 4))
sns.countplot(
    x=rna_family,
    hue=rna_family,
    order=np.unique(rna_family),
    palette="Set2",
    legend=False,
)
plt.title("Distribution of RNA Families")
plt.ylabel("Number of Cells")
plt.xlabel("RNA Family")
plt.xticks(rotation=45)
plt.show()

*Interpretation:*

Top: RNA types by cell count
* barplot shows the number of cells annotated with each RNA type
* only a few RNA types are strongly represented (e.g. Pvalb Itrap2, L2/3 IT_2)
* many RNA types have fewer than 10 cells, suggesting class imbalance --> could impact clustering and classification methods 

Bottom: RNA Families by cell count
* the plot aggregates rna types into broader families, which are essentially coarser labels
* Families like IT, Pvalb and Sst dominate, whereas others like NP and Sncg are strongly underrepresented
* Compared to the top plot, the family distribution is more balanced, with fewer families having very few cells, yet the class imbalance is still present
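The degree of class imbalance can be summarized with a couple of simple numbers, for example the count of classes below a minimum size and the normalized Shannon entropy of the label distribution (1 means perfectly balanced, values near 0 mean one class dominates). The sketch below uses made-up labels; the function `imbalance_summary` and the threshold of 10 cells are illustrative assumptions.

```python
import numpy as np


def imbalance_summary(labels, min_cells=10):
    """Summarize how unevenly cells are spread over label classes."""
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    freqs = counts / counts.sum()
    entropy = -np.sum(freqs * np.log(freqs))
    n_classes = len(counts)
    # divide by log(n_classes), the entropy of a uniform distribution
    normalized = entropy / np.log(n_classes) if n_classes > 1 else 0.0
    return {
        "n_classes": n_classes,
        "n_rare": int(np.sum(counts < min_cells)),
        "normalized_entropy": float(normalized),
    }


# toy labels: two large classes and two rare ones
toy_labels = ["A"] * 100 + ["B"] * 50 + ["C"] * 5 + ["D"] * 3
summary = imbalance_summary(toy_labels)
balanced = imbalance_summary(["A"] * 20 + ["B"] * 20)

print("imbalanced:", summary)
print("balanced:  ", balanced)
```

Calling the function on `rna_type` and `rna_family` would give a direct comparison of how much more balanced the family labels are than the type labels.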


## 1.5 Metadata

In [None]:
# print metadata labels
print("Metadata labels:")
for col in meta.columns:
    print(f"- {col}")

Generally, we want to focus on metadata that improves analysis quality. Not all recorded metadata is useful for our purposes, so we inspected the distributions and properties of the metadata. 

We additionally did some background research, which led us to focus on the following features:

Cre Line (Tasic et al. 2016):
* Cre lines indicate genetic targeting strategies, which can enrich for specific neuron types
* They help validate whether observed patterns are biologically driven or due to experimental targeting

Inferred layer (Loo et al. 2019):
* Cortical layers correspond to distinct cell populations and functions; many electrophysiological and transcriptomic differences are layer-specific.
* It could serve as a meaningful covariate or stratification variable

Sequencing batch (Tung et al. 2017):
* Sequencing batch can introduce technical variability (batch effects) unrelated to biology
* Identifying batch effects can help to avoid spurious conclusions and ensure generalizability



In [None]:
# Distribution of Cre lines
cre_counts = meta["Cre"][exclude_low_quality].value_counts()

# Plot
plt.figure(figsize=(10, 4))
sns.countplot(
    y=meta["Cre"][exclude_low_quality],
    order=cre_counts.index,
    palette="muted",
    hue=meta["Cre"][exclude_low_quality],
    legend=False,
)
plt.title("Distribution of Cre Lines")
plt.xlabel("Cell Count")
plt.ylabel("Cre Line")
plt.show()

In [None]:
# Distribution of cortical layers
layer_counts = meta["Inferred layer"][exclude_low_quality].value_counts()

# Plot
plt.figure(figsize=(8, 4))
sns.countplot(
    x=meta["Inferred layer"][exclude_low_quality],
    order=layer_counts.index,
    palette="coolwarm",
    hue=meta["Inferred layer"][exclude_low_quality],
    legend=False,
)
plt.title("Distribution of Inferred Cortical Layers")
plt.xlabel("Inferred Layer")
plt.ylabel("Number of Cells")
plt.show()

In [None]:
# Distribution of batches
batch_counts = meta["Sequencing batch"][exclude_low_quality].value_counts()

# plot
plt.figure(figsize=(8, 4))
sns.barplot(
    x=batch_counts.index.astype(str),
    y=batch_counts.values,
    palette="coolwarm",
    hue=batch_counts.index.astype(str),
    legend=False,
)
plt.title("Distribution of Sequencing Batches")
plt.xlabel("Sequencing Batch (ordered by cell count)")
plt.ylabel("Cell Count")
plt.show()

#### Metadata Conclusions

Cre Lines:
* The distribution shows several dominant lines (e.g. SST+, PV+), with many others represented sparsely.
* A large number of unique Cre lines suggests some are too underrepresented for individual modeling but may help explain biological variability when grouped or used selectively.

Inferred layer:
* The distribution is uneven, with most cells in layer 5, followed by 2/3 and 6, and very few in layer 1.
* This may impact how well certain layers are represented in modeling, and highlights that layer-specific trends could be explored in later tasks.

Sequencing batch:
* The barplot revealed an uneven distribution across batches, ranging from very large (e.g. batch 8) to tiny (batch 12).

**How could they be used in further analyses?**

Cre Lines:
* use after clustering --> compare predicted clusters to Cre line distributions
* helps biologically validate clusters, e.g. a cluster dominated by SST+ Cre lines might represent inhibitory neurons (task 4)
* can examine how Cre lines associate with PCs or gene expression signatures (task 5)

Inferred layer:
* use inferred layer as a coloring option to see if spatial cell identity is reflected in transcriptomic space (task 3)
* could be added as a covariate or group in regression or correlation analyses to see how gene-PC correlations differ across layers (task 5)

Sequencing batch:
* use batch information to check for batch effects after normalization and transformation 
* could color 2D visualizations (e.g. PCA/t-SNE) by batch to visually inspect clustering by batch, which is undesired and indicates technical confounding (task 2)
* use batch labels as negative control labels: ideally batches should not be separable in a meaningful embedding (task 3)
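
The comparison of clusters with Cre lines (or with sequencing batches) listed above could be sketched as a row-normalized contingency table. This is only a minimal illustration on made-up labels: the `clusters` and `cre_lines` Series and their values are hypothetical and not taken from our data.

```python
import pandas as pd

# Toy stand-ins (hypothetical): 8 cells, each with a cluster label and a Cre line
clusters = pd.Series([0, 0, 0, 1, 1, 1, 1, 0], name="cluster")
cre_lines = pd.Series(
    ["Sst", "Sst", "Sst", "Pvalb", "Pvalb", "Pvalb", "Sst", "Sst"], name="Cre"
)

# Row-normalized contingency table: composition of each cluster by Cre line
composition = pd.crosstab(clusters, cre_lines, normalize="index")
print(composition.round(2))

# Most frequent Cre line per cluster
dominant = composition.idxmax(axis=1)
print(dominant)
```

A cluster whose row is dominated by one Cre line would support a biological interpretation of that cluster. The same table built with batch labels should ideally look similar across all rows; a cluster dominated by a single batch would hint at technical confounding.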





# Task 2

Preparation steps:

* we want to use the metadata in the following tasks
* therefore we need to restrict the metadata to the cells we are interested in
* low-quality cells have so far only been excluded from rna_type, rna_family and exonCounts, so we apply the same mask (`exclude_low_quality`) to the metadata
* to avoid misalignment, we reset the indices so that the metadata indices match the row order of exonCounts

In [None]:
# preparation steps

# RNA counts
exonCounts_filtered = exonCounts

# Ground truth labels
rna_type_filtered = rna_type.reset_index(drop=True)
rna_family_filtered = rna_family.reset_index(drop=True)

# Filter Metadata variables of interest
cre_filtered = meta["Cre"][exclude_low_quality].reset_index(drop=True)
layer_filtered = meta["Inferred layer"][exclude_low_quality].reset_index(
    drop=True
)
batch_filtered = meta["Sequencing batch"][exclude_low_quality].reset_index(
    drop=True
)

assert (
    exonCounts_filtered.shape[0]
    == len(rna_type_filtered)
    == len(cre_filtered)
    == len(layer_filtered)
    == len(batch_filtered)
)
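
Why the index reset matters can be shown on a toy example: pandas aligns Series on their index labels, not on their position, so a filtered Series that keeps its old labels can silently misalign with positionally indexed data. The names `meta_toy`, `keep` and `labels` below are hypothetical stand-ins, not objects from our notebook.

```python
import numpy as np
import pandas as pd

# Toy metadata for 5 cells; `keep` plays the role of a quality mask
meta_toy = pd.DataFrame({"layer": ["1", "2/3", "5", "5", "6"]})
keep = np.array([True, False, True, True, False])

filtered = meta_toy["layer"][keep]          # index labels are still 0, 2, 3
aligned = filtered.reset_index(drop=True)   # index labels are now 0, 1, 2

# A second Series for the same 3 cells that uses positional labels 0..2
labels = pd.Series(["a", "b", "c"])

# Without reset_index, pandas aligns on labels and introduces NaNs
misaligned = pd.DataFrame({"layer": filtered, "label": labels})
ok = pd.DataFrame({"layer": aligned, "label": labels})

print(misaligned)
print(ok)
```

The first table has four rows with missing values, the second has the expected three complete rows.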

## 2.1 normalization + transformation

log-CPM normalization - appropriately tailored for read count data

In [None]:
# compute total amount of reads per cell
total_counts_per_cell = exonCounts_filtered.sum(axis=1)

# avoid dividing by zero by setting any 0s to 1
total_counts_per_cell[total_counts_per_cell == 0] = 1

# NORMALIZATION
# calculate counts per million
cpm = (
    exonCounts_filtered.T / total_counts_per_cell
).T * 1e6  # transpose so that broadcasting divides each cell (row) by its total count

# TRANSFORMATION
# log-transform with log(1+x)
log_cpm = np.log1p(cpm)
# sqrt-transform with sqrt(x)
sqrt_cpm = np.sqrt(cpm)

# create dataframe
cpm_df = pd.DataFrame(cpm, index=rna_type_filtered, columns=genes)
log_cpm_df = pd.DataFrame(log_cpm, index=rna_type_filtered, columns=genes)
sqrt_cpm_df = pd.DataFrame(sqrt_cpm, index=rna_type_filtered, columns=genes)

print("Log-CPM shape:", log_cpm.shape)
print("Sqrt-CPM shape:", sqrt_cpm.shape)

assert log_cpm_df.shape[0] == len(rna_type_filtered), "check for alignment"
assert sqrt_cpm_df.shape[0] == len(rna_type_filtered), "check for alignment"

In [None]:
# show how normalization changes total RNA counts
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].bar(
    range(len(total_counts_per_cell)),
    total_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[0].set_yscale("log")
axes[0].set_xlabel("Cell index")
axes[0].set_ylabel("Raw RNA counts")
axes[0].set_title("Total RNA Counts Per Cell")

total_cpm_counts_per_cell = cpm.sum(axis=1)


axes[1].bar(
    range(len(total_cpm_counts_per_cell)),
    total_cpm_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[1].set_yscale("log")
axes[1].set_xlabel("Cell index")
axes[1].set_ylabel("Normalized RNA counts")
axes[1].set_title("Total Normalized RNA Counts Per Cell")

=> the raw total counts vary strongly between cells, whereas after CPM normalization every cell sums to the same total (10^6), so the right panel is flat
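
The effect of the normalization and of the two transforms can be checked on a small toy matrix. This is a sketch with made-up counts (the array `counts` is hypothetical); it follows the same steps as the cell above, but uses `keepdims=True` instead of transposing.

```python
import numpy as np

# Toy count matrix (hypothetical): 3 cells x 4 genes with very different depths
counts = np.array(
    [
        [10, 0, 30, 60],
        [1, 0, 3, 6],
        [500, 100, 200, 200],
    ],
    dtype=float,
)

depth = counts.sum(axis=1, keepdims=True)  # total reads per cell, shape (3, 1)
depth[depth == 0] = 1                      # guard against empty cells
cpm_toy = counts / depth * 1e6             # each row is divided by its own depth

log_cpm_toy = np.log1p(cpm_toy)            # log(1 + x), keeps zeros at zero
sqrt_cpm_toy = np.sqrt(cpm_toy)

print("CPM totals per cell:", cpm_toy.sum(axis=1))
print("cells 0 and 1 identical after CPM:", np.allclose(cpm_toy[0], cpm_toy[1]))
```

Cells 0 and 1 have the same relative expression profile and differ only in depth, so they become identical after CPM normalization.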

In [None]:
# show change after log transform

fig, axes = plt.subplots(1, 3, figsize=(10, 4))

axes[0].scatter(
    mean_expression_across_cells,
    variance_expression_across_cells,
    s=5,
    alpha=0.5,
)
axes[0].set_xscale("log")
axes[0].set_yscale("log")
axes[0].set_xlabel("Mean Expression Across Cells (log scale)")
axes[0].set_ylabel("Variance Across Cells (log scale)")
axes[0].set_title("Mean vs Variance of Gene Expression")


mean_log_expression = log_cpm.mean(axis=0)
var_log_expression = log_cpm.var(axis=0)


axes[1].scatter(
    mean_log_expression,
    var_log_expression,
    s=5,
    alpha=0.5,
)
axes[1].set_xscale("log")
axes[1].set_yscale("log")
axes[1].set_xlabel("Mean Expression Across Cells (log scale)")
axes[1].set_ylabel("Variance Across Cells (log scale)")
axes[1].set_title("Mean vs Variance of Gene Expression\nAfter Log-Transform")

# show change after sqrt transform
mean_sqrt_expression = sqrt_cpm.mean(axis=0)
var_sqrt_expression = sqrt_cpm.var(axis=0)


axes[2].scatter(
    mean_sqrt_expression,
    var_sqrt_expression,
    s=5,
    alpha=0.5,
)
axes[2].set_xscale("log")
axes[2].set_yscale("log")
axes[2].set_xlabel("Mean Expression Across Cells (log scale)")
axes[2].set_ylabel("Variance Across Cells (log scale)")
axes[2].set_title("Mean vs Variance of Gene Expression\nAfter Sqrt-Transform")

# add identity lines (log-spaced points, since both axes are on a log scale)
x_vals = np.geomspace(1e-3, 1e4, 100)
axes[0].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)
x_vals = np.geomspace(1e-4, 1e1, 100)
axes[1].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)
x_vals = np.geomspace(1e-4, 1e2, 100)
axes[2].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)

axes[0].legend()
axes[1].legend()
axes[2].legend()

*Interpretation:*

Left: Raw CPM (before transformation)
* variance increases rapidly with the mean, indicating overdispersion
* many genes show a much higher variance than mean, deviating strongly from the Poisson relationship variance = mean. This is expected for count data from heterogeneous cell populations, which is typically overdispersed relative to a Poisson distribution.
* Strong heteroscedasticity (the variance depends on the magnitude of the mean). This suggests that PCA on the untransformed data would be dominated by a few highly expressed, high-variance genes, since PCA is driven by the directions of largest variance.

Middle: after log transformation
* the cloud of points is more compact, with less spread in variance for high means
* points are now more evenly distributed around the identity line, indicating that the log transformation has stabilized the variance and made the data more homoscedastic (i.e. it compresses extreme values and reduces skew)
* the log-transformed data now looks more suitable for PCA

Right: after sqrt transformation
* despite not changing the overall shape much, the sqrt transform has also reduced the dependence of the variance on the mean
* the sqrt-transformed data also looks more suitable for PCA
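
What "variance stabilization" means can be illustrated with simulated counts. The sketch below draws pure Poisson counts, which is an idealization and not a model of our data: for Poisson counts the variance of the sqrt-transformed values stays close to 1/4 regardless of the mean, whereas the log transform compresses high-mean genes more and more strongly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated Poisson "genes" with increasing mean (idealized, not our data)
results = {}
for lam in (10, 100, 1000):
    x = rng.poisson(lam, size=200_000)
    results[lam] = {
        "raw": x.var(),
        "sqrt": np.sqrt(x).var(),
        "log1p": np.log1p(x).var(),
    }
    print(
        f"mean={lam:5d}  var(raw)={results[lam]['raw']:8.1f}  "
        f"var(sqrt)={results[lam]['sqrt']:.3f}  "
        f"var(log1p)={results[lam]['log1p']:.4f}"
    )
```

Real single-cell counts are overdispersed, so neither transform is exactly variance-stabilizing for them; the simulation only shows the direction in which each transform acts.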

## 2.2 select genes

In [None]:
# choose highly variable genes (hvgs) from normalized data
# we DO NOT use log-transformed data here, since that would flatten the mean-variance relationship
# (and we want to find the genes with high variance relative to their mean, i.e. high dispersion)

# keep only genes with a mean expression above 0.1
mask_raw = mean_expression_across_cells > 0.1

# dispersion (Fano factor) = variance / mean; the small constant avoids division by zero
dispersion = variance_expression_across_cells / (
    mean_expression_across_cells + 1e-8
)

# drop the lowly expressed genes, whose dispersion estimates are unreliable
filtered_dispersion = dispersion[mask_raw]
filtered_mean = mean_expression_across_cells[mask_raw]


# 2000 HVGs is the most common default, but even 1000 HVGs require a very high number of principal components in PCA later on.
# Fewer than 500 HVGs risks missing subtle biological signals, so we chose 500 as the lowest reasonable number.

# We also looked at 1000 HVGs to compare them with the 500 HVGs later on.

top_500_indices = np.argsort(filtered_dispersion)[-500:]
top_1000_indices = np.argsort(filtered_dispersion)[-1000:]

# get gene names
hvg_genes_500 = np.array(genes)[mask_raw][top_500_indices]
hvg_genes_1000 = np.array(genes)[mask_raw][top_1000_indices]
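
The difference between ranking genes by dispersion and ranking them by raw variance can be seen on simulated data. This is a minimal sketch under stated assumptions: the matrix `X`, the gene names and the distribution parameters are all hypothetical and only chosen to make the effect visible.

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells = 500

# Toy expression matrix (hypothetical): 500 cells x 5 genes
X = np.column_stack(
    [
        rng.poisson(5, n_cells),                          # gene0: Poisson noise only
        rng.poisson(5000, n_cells),                       # gene1: Poisson, very high mean
        rng.negative_binomial(2, 2 / (2 + 50), n_cells),  # gene2: overdispersed, mean ~50
        rng.negative_binomial(1, 1 / (1 + 20), n_cells),  # gene3: overdispersed, mean ~20
        rng.poisson(0.01, n_cells),                       # gene4: almost never detected
    ]
).astype(float)
gene_names = np.array(["gene0", "gene1", "gene2", "gene3", "gene4"])

mean = X.mean(axis=0)
var = X.var(axis=0)

mask = mean > 0.1           # drop genes that are barely expressed
fano = var / (mean + 1e-8)  # dispersion = variance / mean

k = 2
top_by_fano = gene_names[mask][np.argsort(fano[mask])[-k:]]
top_by_variance = gene_names[np.argsort(var)[-k:]]

print("top by dispersion:", sorted(top_by_fano))
print("top by variance:  ", sorted(top_by_variance))
```

Ranking by raw variance favors the highly expressed Poisson gene, while ranking by dispersion picks the two overdispersed genes, which is the behavior we want from an HVG selection.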

In [None]:
# plot mean-variance with hvgs highlighted

# mask of genes that are in hvg list
is_hvg_500 = np.isin(genes, hvg_genes_500)
is_hvg_1000 = np.isin(genes, hvg_genes_1000)

# fig, ax = plt.subplots(figsize=(6, 5))
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# for 500 hvgs
# all genes
axes[0].scatter(
    mean_expression_across_cells[mask_raw],
    variance_expression_across_cells[mask_raw],
    s=5,
    alpha=0.4,
    label="all genes",
    color="lightblue",
)

# overlay HVGs
axes[0].scatter(
    mean_expression_across_cells[is_hvg_500],
    variance_expression_across_cells[is_hvg_500],
    s=10,
    alpha=0.7,
    label="Top 500 HVGs by dispersion",
    color="darkblue",
)

# for 1000 hvgs
# all genes
axes[1].scatter(
    mean_expression_across_cells[mask_raw],
    variance_expression_across_cells[mask_raw],
    s=5,
    alpha=0.4,
    label="all genes",
    color="lightblue",
)

# overlay HVGs
axes[1].scatter(
    mean_expression_across_cells[is_hvg_1000],
    variance_expression_across_cells[is_hvg_1000],
    s=10,
    alpha=0.7,
    label="Top 1000 HVGs by dispersion",
    color="darkblue",
)

for ax in axes:
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Mean Expression Across Cells")
    ax.set_ylabel("Variance Across Cells")
    ax.set_title("HVGs Highlighted in Mean–Variance Plot")
    ax.legend()
plt.tight_layout()
plt.show()

## 2.3 PCA

note: use normalized + log-transformed data here!

In [None]:
# find indices of HVGs
hvg_indices_500 = np.where(np.isin(genes, hvg_genes_500))[0]
hvg_indices_1000 = np.where(np.isin(genes, hvg_genes_1000))[0]

# subset the expression matrix: shape = (cells, HVGs)
hvg_expression_500 = exonCounts[:, hvg_indices_500]
hvg_expression_1000 = exonCounts[:, hvg_indices_1000]

# use normalized + transformed data from before
log_hvg_df_500 = log_cpm_df.iloc[:, hvg_indices_500].copy()
sqrt_hvg_df_500 = sqrt_cpm_df.iloc[:, hvg_indices_500].copy()
log_hvg_df_1000 = log_cpm_df.iloc[:, hvg_indices_1000].copy()
sqrt_hvg_df_1000 = sqrt_cpm_df.iloc[:, hvg_indices_1000].copy()

# non-transformed data
raw_hvg_df_500 = cpm_df.iloc[:, hvg_indices_500].copy()
raw_hvg_df_1000 = cpm_df.iloc[:, hvg_indices_1000].copy()

In [None]:
# fit PCA + play with the number of components
# 400 components was shown to be the lowest number that still explains at least 60% of variance in the data
pca_log_500 = PCA(n_components=400)
pca_log_1000 = PCA(n_components=400)
X_pca_log_500 = pca_log_500.fit_transform(log_hvg_df_500)
X_pca_log_1000 = pca_log_1000.fit_transform(log_hvg_df_1000)

pca_sqrt_500 = PCA(n_components=400)
pca_sqrt_1000 = PCA(n_components=400)
X_pca_sqrt_500 = pca_sqrt_500.fit_transform(sqrt_hvg_df_500)
X_pca_sqrt_1000 = pca_sqrt_1000.fit_transform(sqrt_hvg_df_1000)

# non-transformed data
pca_raw_500 = PCA(n_components=400)
pca_raw_1000 = PCA(n_components=400)
X_pca_raw_500 = pca_raw_500.fit_transform(raw_hvg_df_500)
X_pca_raw_1000 = pca_raw_1000.fit_transform(raw_hvg_df_1000)

# Check if the number of cells matches the number of RNA types
assert (
    len(rna_type_filtered) == X_pca_log_500.shape[0]
), "Mismatch between RNA types and PCA output"

print("PCA result shape:", X_pca_log_500.shape)  # (num_cells, n_components)
print("PCA result shape:", X_pca_sqrt_1000.shape)  # (num_cells, n_components)

In [None]:
# plot explained variance to compare
explained_log_500 = pca_log_500.explained_variance_ratio_
explained_sqrt_500 = pca_sqrt_500.explained_variance_ratio_
explained_log_1000 = pca_log_1000.explained_variance_ratio_
explained_sqrt_1000 = pca_sqrt_1000.explained_variance_ratio_

explained_raw_500 = pca_raw_500.explained_variance_ratio_
explained_raw_1000 = pca_raw_1000.explained_variance_ratio_

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(
    np.cumsum(explained_log_1000), color="blue", label="log (1000 HVGs)"
)
axes[0].plot(
    np.cumsum(explained_sqrt_1000), color="orange", label="sqrt (1000 HVGs)"
)
axes[0].plot(
    np.cumsum(explained_raw_1000), color="green", label="raw (1000 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_log_500), color="blue", label="log (500 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_sqrt_500), color="orange", label="sqrt (500 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_raw_500), color="green", label="raw (500 HVGs)"
)

axes[0].axhline(0.9, color="red", linestyle="--", label="90% threshold")
axes[1].axhline(0.9, color="red", linestyle="--", label="90% threshold")

for ax in axes:
    ax.set_xlabel("Number of Components")
    ax.set_ylabel("Cumulative Explained Variance")
    ax.set_title("Explained Variance by PCA Components")
    ax.legend()
    ax.grid(True)
plt.tight_layout()
plt.show()

Looking at the raw data first, the curve is very steep and the first few components already explain most of the variance. However, this is most likely not biologically meaningful variance, since in untransformed CPM values a few very highly expressed genes and technical noise dominate the variance. For this reason, we chose not to use the raw data in the following analyses.

For 1000 HVGs and log-transform, 390 components would be needed to capture ~90% of the variance. The sqrt-transform performs somewhat better; around 275 components would be needed here. Even for 500 HVGs, well over 100 PCs are needed to explain 90% of the variance. This means that the variance is spread out across many dimensions and most individual PCs explain only a small amount of variance. Even though more PCs can capture more information, most of that may not be biologically meaningful, so choosing fewer PCs still makes more sense.

While the sqrt-transformed data seems a little better at capturing variance, we cannot be sure whether that variance is meaningful or noise. Since the log-transform is more commonly used for biological data and better at reducing noise, we will mostly use it in the following. For 500 HVGs, ... PCs can account for ~80% of the variance, which seems fair given that it is better to retain fewer PCs that reflect meaningful signal than to keep many PCs that mostly capture noise. Still, the sqrt-transform might come in handy for e.g. UMAP visualizations.
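
Instead of reading the number of components off the plot, it can be computed from the cumulative explained variance. The following is a sketch on simulated data: the matrix `X` and the helper `n_components_for` are hypothetical, and the printed numbers say nothing about our dataset.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Toy data (hypothetical): 300 samples x 40 features with 5 strong latent directions
latent = rng.normal(size=(300, 5)) @ rng.normal(size=(5, 40))
X = 3 * latent + rng.normal(size=(300, 40))

pca_full = PCA(svd_solver="full").fit(X)  # keep all components
cumulative = np.cumsum(pca_full.explained_variance_ratio_)


def n_components_for(threshold):
    """Smallest number of PCs whose cumulative explained variance reaches `threshold`."""
    return int(np.searchsorted(cumulative, threshold)) + 1


for threshold in (0.6, 0.8, 0.9):
    print(f"{threshold:.0%} of variance: {n_components_for(threshold)} PCs")

# scikit-learn can also choose the number of components from a variance fraction
pca_90 = PCA(n_components=0.9, svd_solver="full").fit(X)
print("PCA(n_components=0.9) keeps", pca_90.n_components_, "PCs")
```

Applied to the fitted `explained_variance_ratio_` arrays from above, the same lookup would give the exact component counts for any threshold, provided the PCA was fitted with enough components to reach it.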

In [None]:
# Filter RNA types and families for visualization
labels_rna_family_filtered = LabelEncoder().fit_transform(rna_family_filtered)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# plot the first 2 components for rna type and family
# PCA colored by RNA Type
# Get color values for the filtered RNA types
rna_type_colors_filtered = np.vectorize(dict_rna_type_colors.get)(
    rna_type_filtered
)

scatter1 = axes[0].scatter(
    X_pca_log_500[:, 0],
    X_pca_log_500[:, 1],
    c=rna_type_colors_filtered,
    s=5,
    alpha=0.6,
)
axes[0].set_title("Colored by RNA Type")
axes[0].set_xlabel("PC1")
axes[0].set_ylabel("PC2")
axes[0].grid(True)

# PCA colored by RNA Family
scatter2 = axes[1].scatter(
    X_pca_log_500[:, 0],
    X_pca_log_500[:, 1],
    c=labels_rna_family_filtered,
    cmap="tab10",
    s=5,
    alpha=0.6,
)
axes[1].set_title("Colored by RNA Family")
axes[1].set_xlabel("PC1")
axes[1].set_ylabel("PC2")
axes[1].grid(True)
plt.suptitle("PCA of Highly Variable Genes (HVGs)")


plt.tight_layout()
plt.show()

In [None]:
# Set parameters
hvg_counts = [500, 1000, 2000]


# Set up figure
fig, axes = plt.subplots(len(hvg_counts), 2, figsize=(10, 9))
fig.suptitle(
    "Effect of HVG Count on PCA – RNA Type vs RNA Family", fontsize=16
)

# Loop through combinations
for i, n_hvg in enumerate(hvg_counts):
    # HVG selection
    dispersion = variance_expression_across_cells / (
        mean_expression_across_cells + 1e-8
    )
    mask = mean_expression_across_cells > 0.1

    filtered_dispersion = dispersion[mask]
    top_indices = np.argsort(filtered_dispersion)[-n_hvg:]
    hvg_genes = np.array(genes)[mask][top_indices]

    # Subset & scale
    log_hvg_df = log_cpm_df[hvg_genes]
    data = StandardScaler().fit_transform(log_hvg_df.values)

    # PCA to 2D
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(data)

    # Plot rna_type
    axes[i, 0].scatter(
        X_pca[:, 0], X_pca[:, 1], c=rna_type_colors, s=10, alpha=0.6
    )
    axes[i, 0].set_title(f"{n_hvg} HVGs\nColored by RNA Type")
    axes[i, 0].set_xlabel("PC1")
    axes[i, 0].set_ylabel("PC2")

    # Plot rna_family
    axes[i, 1].scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        c=labels_rna_family_filtered,
        cmap="tab10",
        s=10,
        alpha=0.6,
    )
    axes[i, 1].set_title(f"{n_hvg} HVGs\nColored by RNA Family")
    axes[i, 1].set_xlabel("PC1")
    axes[i, 1].set_ylabel("PC2")

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

For 500 HVGs, the plots seem to show the clearest separation between two clusters. This makes sense: with the fewest HVGs there is the least total variance to summarize, so the first two PCs capture a larger share of it. With more HVGs, 2 PCs capture an even smaller share of the variance, so the plots look a little more 'chaotic'.

# Task 3

## 3.1 Visualization

### 3.1.1 t-SNE

In [None]:
# prepare data
log_hvg_df = log_cpm_df[hvg_genes_500]
sqrt_hvg_df = sqrt_cpm_df[hvg_genes_500]
X_hvg = StandardScaler().fit_transform(log_hvg_df.values)
X_hvg_sqrt = StandardScaler().fit_transform(sqrt_hvg_df.values)

# sanity checks
assert X_hvg.shape[0] == len(rna_type_filtered)
assert X_hvg.shape[0] == len(rna_family_filtered)

# filter log-normalized gene expression to the same 1224 cells
X_hvg_filtered = X_hvg[keepcells, :]  ### not actually used below?

# define dimensionality reduction methods and parameters
methods = {
    "t-SNE (perp=5)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=5, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "t-SNE (perp=30)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=30, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "t-SNE (perp=100)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=100, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
}

# create subplots: 2 rows (RNA type + RNA family) x 3 columns (t-SNE variants)
fig, axes = plt.subplots(2, len(methods), figsize=(16, 9))
fig.suptitle(
    "t-SNE Projections Colored by RNA Type and RNA Family", fontsize=16
)

for col_idx, (title, conf) in enumerate(methods.items()):
    X_embedded = conf["func"](X_hvg)

    # row 0: RNA type
    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_type_filtered,
        palette=dict_rna_type_colors,
        s=15,
        alpha=0.7,
        legend=False,
        ax=axes[0, col_idx],
    )
    axes[0, col_idx].set_title(f"{title} – RNA Type")
    axes[0, col_idx].set_xlabel("Dim 1")
    axes[0, col_idx].set_ylabel("Dim 2")

    # row 1: RNA family
    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_family_filtered,
        palette="tab10",
        s=15,
        alpha=0.7,
        legend=False,
        ax=axes[1, col_idx],
    )
    axes[1, col_idx].set_title(f"{title} – RNA Family")
    axes[1, col_idx].set_xlabel("Dim 1")
    axes[1, col_idx].set_ylabel("Dim 2")

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

The comparison shows how t-SNE's perplexity parameter affects local versus global structure:
* perplexity=5: clusters are small and tightly packed => emphasis on fine local relationships
* perplexity=30: balance between local detail & global structure => seems to provide the best overview of the data
* perplexity=100: the plot appears more globally smoothed and there is less distinct clustering => highlighting broader structure at the expense of local detail
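
Perplexity can be read as an effective number of neighbors: t-SNE places a Gaussian kernel on each point and tunes its width so that the neighbor distribution has the requested perplexity. The sketch below computes this quantity for a single point and several fixed kernel widths. It is a simplified illustration with hypothetical names (`conditional_perplexity`, `points`), not scikit-learn's implementation.

```python
import numpy as np


def conditional_perplexity(distances_sq, sigma):
    """Perplexity 2**H(P_i) of the Gaussian neighbor distribution around one point.

    distances_sq: squared distances from point i to all other points.
    """
    logits = -distances_sq / (2.0 * sigma**2)
    logits = logits - logits.max()  # shift for numerical stability
    p = np.exp(logits)
    p /= p.sum()
    entropy = -np.sum(p * np.log2(p + 1e-12))
    return 2.0**entropy


rng = np.random.default_rng(0)
points = rng.normal(size=(200, 10))
d2 = np.sum((points[1:] - points[0]) ** 2, axis=1)  # point 0 vs. the other 199

perplexities = {
    sigma: conditional_perplexity(d2, sigma) for sigma in (0.3, 1.0, 3.0, 100.0)
}
for sigma, perp in perplexities.items():
    print(f"sigma={sigma:6.1f}  perplexity={perp:7.2f}")
```

A narrow kernel concentrates almost all weight on the closest neighbor, while a very wide kernel weights all 199 other points nearly equally. This is why a low perplexity emphasizes fine local structure and a high perplexity smooths over it.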

### 3.1.2 UMAP

In [None]:
import umap.umap_ as umap
import matplotlib.pyplot as plt
import seaborn as sns

# define dimensionality reduction methods and parameters
methods = {
    "UMAP (n=15)": lambda X: umap.UMAP(
        n_neighbors=15, min_dist=0.1, random_state=42
    ).fit_transform(X),
    "UMAP (n=30)": lambda X: umap.UMAP(
        n_neighbors=30, min_dist=0.3, random_state=42
    ).fit_transform(X),
    "UMAP (n=50)": lambda X: umap.UMAP(
        n_neighbors=50, min_dist=0.3, random_state=42
    ).fit_transform(X),
}

# create 2 rows (log + sqrt), n columns = number of methods
n_methods = len(methods)
fig, axes = plt.subplots(2, n_methods, figsize=(5 * n_methods, 10))
axes = axes.reshape(2, n_methods)

# plot for log-transformed data
for col, (title, func) in enumerate(methods.items()):
    X_embedded_log = func(X_hvg)
    sns.scatterplot(
        x=X_embedded_log[:, 0],
        y=X_embedded_log[:, 1],
        hue=rna_family_filtered,
        palette="tab10",
        s=30,
        alpha=0.7,
        legend=False,
        ax=axes[0, col],
    )
    axes[0, col].set_title(f"{title} — log(CPM)")
    axes[0, col].set_xlabel("Dim 1")
    axes[0, col].set_ylabel("Dim 2")

# plot for sqrt-transformed data
for col, (title, func) in enumerate(methods.items()):
    X_embedded_sqrt = func(X_hvg_sqrt)
    sns.scatterplot(
        x=X_embedded_sqrt[:, 0],
        y=X_embedded_sqrt[:, 1],
        hue=rna_family_filtered,
        palette="tab10",
        s=30,
        alpha=0.7,
        legend=False,
        ax=axes[1, col],
    )
    axes[1, col].set_title(f"{title} — sqrt(CPM)")
    axes[1, col].set_xlabel("Dim 1")
    axes[1, col].set_ylabel("Dim 2")

plt.suptitle(
    "UMAP Embeddings (log vs sqrt CPM) by Different Neighborhood Sizes",
    fontsize=16,
    y=1.03,
)
plt.tight_layout()
plt.show()

The comparison shows how the 'n_neighbors' parameter and the type of data transformation affect the balance between local and global structure:

**type of data transformation:**

Log-transformed and sqrt-transformed data actually look very similar, so, contrary to what we speculated above, we seem to gain no advantage from the sqrt-transformed data for UMAP. For this reason and for simplicity, we will fully focus on the log-transformed data from now on.

**number of neighbors:**

n=15: the plot reveals finer substructures and potential cell states

n=30 / n=50:
* clusters are more connected or blended.
* increasing `n_neighbors` favors global structure, but makes us lose finer distinctions
* useful to highlight broad cell types instead of subtypes

=> smaller 'n_neighbors' values are better for identifying subpopulations, while larger values show more global cell type structure
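
Why larger neighborhoods blend clusters can be illustrated with a plain k-nearest-neighbor graph, which is the starting point of UMAP's graph construction. The sketch below is a simplified proxy on hypothetical toy data: it uses an unweighted kNN graph from scikit-learn and counts connected components, and it does not reproduce UMAP's weighted graph or its optimization.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.neighbors import kneighbors_graph

rng = np.random.default_rng(0)

# Toy data (hypothetical): three well-separated groups of 30 points each
centers = np.array([[0.0, 0.0], [50.0, 0.0], [0.0, 50.0]])
X = np.vstack([c + rng.normal(size=(30, 2)) for c in centers])

n_components_by_k = {}
for k in (5, 15, 40):
    graph = kneighbors_graph(X, n_neighbors=k, mode="connectivity")
    n_comp, _ = connected_components(graph, directed=True, connection="weak")
    n_components_by_k[k] = n_comp
    print(f"k={k:2d}: {n_comp} connected component(s)")
```

As long as k is smaller than the group size, every neighbor lies in the same group and the groups stay disconnected. Once k exceeds the group size, each point is forced to link to other groups and the graph becomes one connected piece.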


### 3.1.3 visual comparison of PCA, t-SNE, UMAP

In [None]:
# define dimensionality reduction methods and parameters
methods = {
    "PCA": {"func": lambda X: PCA(n_components=2).fit_transform(X)},
    "t-SNE (perp=30)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=30, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "UMAP (n=15)": {
        "func": lambda X: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=42
        ).fit_transform(X)
    },
}

# create subplots
n_methods = len(methods)
fig, axes = plt.subplots(1, n_methods, figsize=(5 * n_methods, 5))

# loop through and plot
for ax, (title, conf) in zip(axes, methods.items()):
    X_embedded = conf["func"](X_hvg)

    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        # no `hue` is set, so all points share one color (a palette would be ignored)
        s=30,
        alpha=0.7,
        legend=False,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

plt.tight_layout()
plt.show()

Interpretation across visualization types: 

PCA

* no clear clusters.
* seems to capture global variance, but fails to separate subtle biological subpopulations
=> rather use for noise reduction / preprocessing

t-SNE (`perplexity`=30)

* multiple well-separated clusters.
* seems to reveal local structure, separating even small subpopulations
* downside: t-SNE is not ideal for preserving global relationships

UMAP (`n_neighbors`=15)

* distinct clusters, some continuity
* seems to balance local and global structure, showing gradients and groupings


=> next steps:

* Use PCA as input to UMAP/t-SNE
* Use UMAP or t-SNE for cluster discovery or cell type visualization.


TODO: we can try coloring by RNA type, cluster ID, or marker expression to aid interpretation.


## 3.2 Comparison using Quantitative Metrics

In [None]:
def evaluate_knn_projection(
    cpm_df,
    hvg_genes,
    labels,
    n_neighbors=10,
    test_size=0.3,
    random_state=42,
    pca_components_for_tsne_umap=50,
):
    """
    Compare dimensionality reduction methods using kNN classification metrics.

    Parameters:
        cpm_df (DataFrame): CPM-normalized+transformed gene expression (cells × genes)
        hvg_genes (list): list of highly variable gene names
        labels (array-like): class labels for each cell (e.g., RNA type)
        n_neighbors (int): number of neighbors for kNN
        test_size (float): train/test split size
        random_state (int): reproducibility seed
        pca_components_for_tsne_umap (int): number of PCs to feed into t-SNE/UMAP

    Returns:
        DataFrame with kNN accuracy, macro-averaged kNN recall, silhouette score and AMI for each method.
    """

    # standardize input data
    X_hvg = StandardScaler().fit_transform(
        cpm_df[hvg_genes].values
    )
    y = np.array(labels)

    # precompute PCA-reduced input for non-linear methods
    X_pca_for_embedding = PCA(
        n_components=pca_components_for_tsne_umap
    ).fit_transform(X_hvg)

    methods = {
        "High-dimensional (raw)": lambda X: X,
        "PCA (2D)": lambda X: PCA(n_components=2).fit_transform(X),
        "PCA (10D)": lambda X: PCA(n_components=10).fit_transform(X),
        "t-SNE (2D)": lambda X: TSNE(
            n_components=2, perplexity=30, random_state=random_state
        ).fit_transform(X),
        "t-SNE (2D from PCA)": lambda _: TSNE(
            n_components=2, perplexity=30, random_state=random_state
        ).fit_transform(X_pca_for_embedding),
        "t-SNE (10D)": lambda X: TSNE(
            n_components=10,
            perplexity=30,
            method="exact",
            random_state=random_state,
        ).fit_transform(X),
        "t-SNE (10D from PCA)": lambda _: TSNE(
            n_components=10,
            perplexity=30,
            method="exact",
            random_state=random_state,
        ).fit_transform(X_pca_for_embedding),
        "UMAP (2D)": lambda X: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=random_state
        ).fit_transform(X),
        "UMAP (2D from PCA)": lambda _: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=random_state
        ).fit_transform(X_pca_for_embedding),
        "UMAP (10D from PCA)": lambda _: umap.UMAP(
            n_neighbors=15,
            min_dist=0.1,
            n_components=10,
            random_state=random_state,
        ).fit_transform(X_pca_for_embedding),
    }

    results = []
    for name, func in methods.items():
        try:
            X_proj = func(X_hvg)
            X_train, X_test, y_train, y_test = train_test_split(
                X_proj,
                y,
                test_size=test_size,
                random_state=random_state,
            )
            print(f"{name}: X_proj shape = {X_proj.shape}")

            clf = KNeighborsClassifier(n_neighbors=n_neighbors)
            clf.fit(X_train, y_train)
            acc = clf.score(X_test, y_test)
            y_pred = clf.predict(X_test)
            avg_recall = recall_score(
                y_test, y_pred, average="macro", zero_division=0
            )

            # compute silhouette score and AMI as in Lause, Berens, Kobak (2024):

            # silhouette score
            sil = silhouette_score(X_proj, y)
            # unsupervised clustering for AMI
            n_clusters = len(np.unique(y))
            cluster_labels = KMeans(
                n_clusters=n_clusters, random_state=random_state
            ).fit_predict(X_proj)
            ami = adjusted_mutual_info_score(y, cluster_labels)

            results.append(
                {
                    "Method": name,
                    "kNN Accuracy": acc,
                    "kNN Recall (avg)": avg_recall,
                    "Silhouette Score": sil,
                    "AMI": ami,
                }
            )
        except Exception as e:
            results.append(
                {
                    "Method": name,
                    "kNN Accuracy": None,
                    "kNN Recall (avg)": None,
                    "Silhouette Score": None,
                    "AMI": None,
                    "Error": str(e),
                }
            )

    return pd.DataFrame(results)

In [None]:
print("log_cpm_df shape:", log_cpm_df.shape)
print("Sample of columns:", log_cpm_df.columns[:5])
print("Sample of hvg_genes_500:", hvg_genes_500[:5])

In [None]:
# use the 500 HVGs that the embeddings above are based on
# (`hvg_genes` would still hold the 2000 HVGs left over from the HVG-count loop)
results_df = evaluate_knn_projection(
    cpm_df=log_cpm_df,
    hvg_genes=hvg_genes_500,
    labels=rna_type,
    n_neighbors=10,
    pca_components_for_tsne_umap=50,
)
print(results_df)

TODO: UPDATE THESE VALUES ONCE MORE AT THE VERY END!

Best Overall (based on all metrics): t-SNE (10D from PCA):
* kNN Accuracy: 0.46
* Recall: 0.26
* AMI: 0.44 (highest)
* Silhouette: ~ -0.21

kNN-performance of the models:
* t-SNE (10D) is our best performer in terms of accuracy
* t-SNE (10D from PCA) has the highest recall
* PCA (2D) performs poorly: linear reduction to 2D doesn't seem to work well for this data
* UMAP is competitive but doesn't benefit much from going to 10D
* Raw high-dimensional data still performs surprisingly well, showing that the data has strong native structure.
* a kNN accuracy of around 40% is not high but seems acceptable given the biological data

Overall low silhouette scores:
* All low-dimensional methods have negative silhouette scores, meaning that cells of different types overlap or are poorly separated in the embeddings.
* This is common for t-SNE/UMAP, as they optimize local structure rather than global cluster separation.

AMI scores:
* Most AMI scores surpass 0.4, which indicates moderately good label-cluster alignment.
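
How the silhouette score and AMI behave in two extreme cases can be checked on simulated blobs. This is a sketch on hypothetical toy data, not on our embeddings: labels that match the geometry give a high silhouette and an AMI near 1, whereas labels that are unrelated to the geometry give a silhouette around zero or below and an AMI near 0.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_mutual_info_score, silhouette_score

# Toy data (hypothetical): three well-separated blobs with known labels
X, y_true = make_blobs(
    n_samples=300,
    centers=[[0, 0], [10, 0], [0, 10]],
    cluster_std=1.0,
    random_state=0,
)

rng = np.random.default_rng(0)
y_shuffled = rng.permutation(y_true)  # same class sizes, unrelated to the geometry

cluster_labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

sil_true = silhouette_score(X, y_true)
sil_shuffled = silhouette_score(X, y_shuffled)
ami_true = adjusted_mutual_info_score(y_true, cluster_labels)
ami_shuffled = adjusted_mutual_info_score(y_shuffled, cluster_labels)

print(f"silhouette: matching labels {sil_true:.2f}, shuffled labels {sil_shuffled:.2f}")
print(f"AMI:        matching labels {ami_true:.2f}, shuffled labels {ami_shuffled:.2f}")
```

Negative silhouette scores for the ground-truth labels therefore mean that, on average, cells are at least as close to cells of another type as to cells of their own type in the respective embedding.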

In [None]:
# plot kNN accuracy vs. recall for different dimensionality reduction methods

def plot_knn_accuracy_vs_recall(results_df, figsize=(8, 6)):
    """
    Parameters:
        results_df (DataFrame): must contain columns 'Method', 'kNN Accuracy', and 'kNN Recall (avg)'
        figsize (tuple): size of the plot (width, height)
    """
    methods = results_df["Method"]
    accuracy = results_df["kNN Accuracy"]
    recall = results_df["kNN Recall (avg)"]

    plt.figure(figsize=figsize)
    texts = []

    # plot points
    for i in range(len(methods)):
        plt.scatter(accuracy[i], recall[i], s=200, label=methods[i])
        plt.text(accuracy[i], recall[i], methods[i], fontsize=8, rotation=15)

    plt.xlabel("kNN Accuracy")
    plt.ylabel("kNN Recall (macro avg)")
    plt.title("kNN Accuracy vs. Recall for Dimensionality Reduction Methods")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
plot_knn_accuracy_vs_recall(results_df)

Interpretation:

* t-SNE (2D) nearly matches the accuracy of the high-dimensional data while even slightly improving average recall. This shows that it seems to preserve neighborhood structure quite well in 2D.
* UMAP (2D) performs reasonably, a bit behind t-SNE, but still far ahead of PCA (2D).
* PCA (2D) performs poorly — this confirms PCA doesn't capture non-linear relationships or local structure well in low dimensions.
* t-SNE and UMAP seem to have benefitted from being run on PCA-reduced input, so we visualize these again below.
* High-dimensional kNN is very strong (as expected), but it does not outperform t-SNE or UMAP, and in particular not the higher-dimensional embeddings. Consequently, even the 2D t-SNE visualization seems to be biologically meaningful.

In [None]:
# visualize the 2D "winners"
# define 2D projections
X_pca2 = PCA(n_components=2).fit_transform(X_hvg)

# precompute PCA-reduced input for non-linear methods
X_pca_for_embedding = PCA(n_components=50).fit_transform(X_hvg)

# t-SNE and UMAP from PCA
X_tsne2 = TSNE(
    n_components=2, perplexity=30, method="exact", random_state=42
).fit_transform(X_pca_for_embedding)
X_umap2 = umap.UMAP(
    n_components=2, n_neighbors=15, min_dist=0.1, random_state=42
).fit_transform(X_pca_for_embedding)

In [None]:
# plot models from above
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

titles = ["PCA (2D)", "t-SNE (2D from PCA)", "UMAP (2D from PCA)"]
embeddings = [X_pca2, X_tsne2, X_umap2]

for ax, title, embed in zip(axes, titles, embeddings):
    sns.scatterplot(
        x=embed[:, 0],
        y=embed[:, 1],
        hue=rna_family,
        palette="tab10",
        s=30,
        alpha=0.8,
        legend=False,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

plt.tight_layout()
plt.show()

Confirming the quantitative comparison, t-SNE and UMAP computed on PCA-reduced input look better than without it, showing clearer clusters.

# Task 4

## 4.1 prepare data

Outline of clustering steps was mainly inspired by Harris et al. 2018

In [None]:
# as Harris et al. suggest:
# use either raw count data or counts normalized per cell by size factors

# we already have them:
# log_cpm_df --> log-transformed CPM-normalized data
# exonCounts --> raw counts
# cpm --> CPM-normalized data
# total_counts_per_cell --> total counts per cell
# hvg_genes --> highly variable gene names
# For Leiden, we also need: log_hvg_df (= log_cpm_df[hvg_genes])

# for GMM we can (optionally) standardize the data
X_log_hvg_scaled = StandardScaler().fit_transform(log_hvg_df.values)

# for NBMM, we will need
exonCounts_df = pd.DataFrame(exonCounts, columns=genes)
raw_hvg_df = exonCounts_df[hvg_genes_500]

# and normalize by "library size", i.e. per-cell size factors
size_factors = raw_hvg_df.sum(axis=1) / np.median(raw_hvg_df.sum(axis=1))
normalized_counts_nbmm = raw_hvg_df.div(size_factors, axis=0)

# Rescale the entire size-factor-normalized matrix to a stable magnitude (mean of ~1000)
scaling_factor = 1e3 / normalized_counts_nbmm.values.mean()  # bring mean ~1000
normalized_counts_nbmm *= scaling_factor

print("log_hvg_df shape:", log_hvg_df.shape)
print("raw_hvg_df shape:", raw_hvg_df.shape)
print("normalized_counts_nbmm shape:", normalized_counts_nbmm.shape)

## 4.2 Leiden clustering

In [None]:
# Input data
X = X_log_hvg_scaled
pca = PCA(n_components=25)
X_pca = pca.fit_transform(X)

In [None]:
# Store results
results = []
ari_type_leiden = []
ari_family_leiden = []
K_leiden = []

n_neighbors_list = [15, 20, 30, 50]
resolution_list = [0.5, 1.0, 1.5, 2.0]

for n_neighbors in n_neighbors_list:
    # kNN graph
    knn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean").fit(
        X_pca
    )
    distances, indices = knn.kneighbors(X_pca)

    # Create edges from kNN indices
    edges = []
    for i, neighbors in enumerate(indices):
        for neighbor in neighbors[1:]:  # skip self-loop
            edges.append((i, neighbor))

    # Build igraph graph
    g = ig.Graph(edges=edges, directed=False)
    g.simplify()

    for resolution in resolution_list:
        # Run Leiden clustering
        partition = leidenalg.find_partition(
            g,
            leidenalg.RBConfigurationVertexPartition,
            resolution_parameter=resolution,
        )
        leiden_labels = np.array(partition.membership)

        # Compute ARIs
        assert (
            len(leiden_labels) == X.shape[0]
        ), "Clustering output doesn't match data"
        assert len(rna_type_filtered) >= X.shape[0], "Too few labels!"

        # Align label arrays to X
        # Slice to match number of cells in X / clustering output
        rna_type_subset = np.array(rna_type_filtered)[: len(leiden_labels)]
        rna_family_subset = np.array(rna_family_filtered)[: len(leiden_labels)]

        ari_type_score = adjusted_rand_score(rna_type_subset, leiden_labels)
        ari_family_score = adjusted_rand_score(
            rna_family_subset, leiden_labels
        )

        # store for plotting
        K = len(np.unique(leiden_labels))
        K_leiden.append(K)
        ari_type_leiden.append(ari_type_score)
        ari_family_leiden.append(ari_family_score)

        # Store result
        results.append(
            {
                "n_neighbors": n_neighbors,
                "resolution": resolution,
                "n_clusters": len(np.unique(leiden_labels)),
                "ARI (RNA Type)": ari_type_score,
                "ARI (RNA Family)": ari_family_score,
            }
        )

# Convert to DataFrame
results_df = pd.DataFrame(results)
display(results_df)

Quick note: 

→ Adjusted Rand Index (ARI) is used as a clustering quality measure because it quantifies the similarity between predicted clusters and true labels while correcting for chance, making it robust and interpretable for comparing clustering performance.
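
To illustrate the two properties mentioned above (invariance to cluster numbering and correction for chance), here is a small sketch on synthetic labels. The labels are made up for illustration and are unrelated to our cells:

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score

rng = np.random.default_rng(0)
true_labels = np.repeat([0, 1, 2], 100)  # three hypothetical groups of 100 cells

# Same partition, but the cluster ids are renamed: ARI stays at 1
renamed = (true_labels + 1) % 3
ari_renamed = adjusted_rand_score(true_labels, renamed)

# Purely random assignment: ARI is close to 0 (it can be slightly negative)
random_labels = rng.integers(0, 3, size=true_labels.size)
ari_random = adjusted_rand_score(true_labels, random_labels)

print(f"renamed clusters: {ari_renamed:.2f}, random clusters: {ari_random:.2f}")
```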

*Interpretation:*

* Increasing the resolution parameter generally increases the number of clusters.
    * e.g. for n_neighbors = 15, the number of clusters grows from 7 (res = 0.5) to 17 (res = 2.0)
* The highest ARI for RNA family is 0.528 at
    * n_neighbors = 15
    * resolution = 1.0
    * n_clusters = 9 (matches the number of RNA families)
    * → This is likely the most biologically meaningful Leiden clustering: an ARI of ~0.52 with a biologically matched number of clusters (9 RNA families) supports that the clustering captures biologically meaningful structure and that the preprocessing and clustering choices were reasonable.
* trade-off in resolution
    * Higher resolutions increase granularity, but ARI doesn’t necessarily improve
    * In some cases (e.g., res = 2.0), more clusters lead to lower ARI, suggesting over-clustering may obscure true biological structure
* RNA Types vs. RNA Family
    * ARI (RNA Family) scores are consistently higher than ARI (RNA Type), across the board.
    * → This suggests the clustering aligns more closely with RNA families than types — possibly because families are broader or better captured by the feature space.
* Stability across neighbors
    * The top ARI scores for RNA family are relatively consistent across different n_neighbors, especially at res=1.0 and 1.5.
    * This suggests that the Leiden results are not overly sensitive to this parameter in our case.


In [None]:
# plot results
fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

sns.lineplot(
    data=results_df,
    x="resolution",
    y="ARI (RNA Type)",
    hue="n_neighbors",
    marker="o",
    ax=axes[0],
)
axes[0].set_title("ARI vs Resolution (RNA Type)")
axes[0].set_ylabel("Adjusted Rand Index")
axes[0].set_xlabel("Leiden Resolution")

sns.lineplot(
    data=results_df,
    x="resolution",
    y="ARI (RNA Family)",
    hue="n_neighbors",
    marker="o",
    ax=axes[1],
)
axes[1].set_title("ARI vs Resolution (RNA Family)")
axes[1].set_ylabel("Adjusted Rand Index")
axes[1].set_xlabel("Leiden Resolution")

plt.tight_layout()
plt.show()

*Interpretation:*

The plot generally supports the numbers seen and described above. 

Left - RNA Type:

* ARI increases consistently with resolution, regardless of neighbor count.
* Lower neighbor counts (15, 20) tend to produce slightly higher ARI for RNA Type.
* However, overall ARIs stay < 0.25, meaning RNA Type is less distinctly separated in the data.
* RNA Type labels reflect less distinct expression profiles or more gradual transitions between groups. High resolution helps slightly, but separation is still weak.

Right - RNA Family:

* A clear performance peak at resolution = 1.0, especially for 15 neighbors.
* ARIs decline for higher resolutions, likely due to over-segmentation (splitting coherent RNA families).
* Best score (~0.52) occurs at resolution 1.0, 15 neighbors → matches biological ground truth of 9 families
* Leiden clustering captures biologically relevant structure best at moderate resolution and tighter neighborhood definitions. Pushing to higher resolutions divides the RNA families into smaller, less meaningful clusters



## 4.3 Gaussian Mixture Model



In [None]:
# Range of cluster numbers to try
cluster_range = range(5, 21)  # try 5 to 20 clusters

# Store results
gmm_results = []

for n_clusters in cluster_range:
    gmm = GaussianMixture(n_components=n_clusters, random_state=42)
    gmm_labels = gmm.fit_predict(X_log_hvg_scaled)

    ari_type_gmm = adjusted_rand_score(rna_type, gmm_labels)
    ari_family_gmm = adjusted_rand_score(rna_family, gmm_labels)

    gmm_results.append(
        {
            "n_clusters": n_clusters,
            "ari_rna_type": ari_type_gmm,
            "ari_rna_family": ari_family_gmm,
        }
    )

# Convert to DataFrame
df_gmm_results = pd.DataFrame(gmm_results)

# Display

plt.figure(figsize=(8, 5))
plt.plot(
    df_gmm_results["n_clusters"],
    df_gmm_results["ari_rna_type"],
    label="ARI vs RNA Type",
    marker="o",
)
plt.plot(
    df_gmm_results["n_clusters"],
    df_gmm_results["ari_rna_family"],
    label="ARI vs RNA Family",
    marker="s",
)
plt.xlabel("Number of GMM Clusters")
plt.ylabel("Adjusted Rand Index (ARI)")
plt.title("GMM Clustering: ARI vs Number of Clusters")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

*Interpretation:*

RNA Family (orange):

* reaches highest ARI (~ 0.45-0.46) at K = 8-10, again aligning well with the true number of RNA families (9)
* ARI drops sharply after K > 10, indicating over-segmentation
* peak is slightly lower than Leiden’s best (~0.52)
* GMM captures meaningful structure for RNA families when the cluster count is close to ground truth. Like Leiden, it performs best when K ≈ number of RNA families. However, it is slightly less expressive or flexible than Leiden in this task


RNA Type (blue):

* ARIs increase steadily with K and then plateau around 0.16–0.17 from K = 14 onward.
* Overall scores are modest, again suggesting that RNA type is not strongly clustered.
* This trend matches the Leiden results: RNA type differences are subtle and not easily captured by unsupervised clustering, even in GMM which assumes Gaussian structure.



## 4.4 Negative Binomial Mixture Model

In [None]:
from scipy.stats import nbinom  # negative binomial log-pmf used in run_nbmm


# Initialization via kmeans
def initialize_responsibilities_kmeans(X, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, n_init=5, random_state=42).fit(X)
    labels_init = kmeans.labels_
    resp = np.zeros((X.shape[0], n_clusters))
    resp[np.arange(X.shape[0]), labels_init] = 1
    return resp


# put NBMM computation in a function
def run_nbmm(X, K, r=2, max_iter=100, tol=1e-4, seed=42, verbose=False):
    np.random.seed(seed)
    N, G = X.shape

    # Initialize with kmeans (see above)
    resp = initialize_responsibilities_kmeans(X, K)
    Z = np.argmax(resp, axis=1)

    # Initial parameters
    pi = resp.sum(axis=0) / N
    p = np.zeros((K, G))

    for k in range(K):
        X_k = X[Z == k]
        if X_k.shape[0] == 0:
            X_k = X[np.random.choice(N, size=5, replace=False)]
        mu_k = np.maximum(X_k.mean(axis=0), 1e-2)
        p[k] = r / (r + mu_k)

    prev_log_likelihood = -np.inf

    for iteration in range(max_iter):
        log_resp = np.zeros((N, K))

        for k in range(K):
            with np.errstate(divide="ignore", invalid="ignore"):

                p_k = np.clip(p[k], 1e-5, 1 - 1e-5)
                log_prob = nbinom.logpmf(X, n=r, p=p_k)
                if not np.isfinite(log_prob).all():
                    print(f"Cluster {k}: log_prob contains NaN or -inf")
                log_resp[:, k] = np.sum(log_prob, axis=1) + np.log(
                    pi[k] + 1e-8
                )

        # log-sum-exp for numerical stability; keep the row maxima so that
        # the data log-likelihood can be recovered below
        log_resp_max = np.max(log_resp, axis=1, keepdims=True)
        log_resp -= log_resp_max
        resp = np.exp(log_resp)
        resp_sum = resp.sum(axis=1, keepdims=True)
        resp_sum[resp_sum == 0] = 1e-8
        resp /= resp_sum

        # M-step
        Nk = resp.sum(axis=0)
        pi = Nk / N
        for k in range(K):
            weighted_sum = np.dot(resp[:, k], X)
            mean_k = weighted_sum / (Nk[k] + 1e-8)
            p[k] = r / (r + np.maximum(mean_k, 1e-2))

        # Convergence: log-likelihood = sum_i [max_i + log sum_k exp(log_resp_ik - max_i)]
        current_log_likelihood = np.sum(log_resp_max + np.log(resp_sum))
        if verbose:
            print(
                f"Iter {iteration}, Log Likelihood: {current_log_likelihood:.2f}"
            )

        if np.abs(current_log_likelihood - prev_log_likelihood) < tol:
            break
        prev_log_likelihood = current_log_likelihood

    labels = np.argmax(resp, axis=1)
    return labels

In [None]:
# X_nbmm = normalized_counts_nbmm.values
X_nbmm = np.round(raw_hvg_df.values).astype(int)
rna_type_array = np.array(rna_type_filtered)
rna_family_array = np.array(rna_family_filtered)

K_values = range(5, 21)
ari_type_nbmm = []
ari_family_nbmm = []

for K in K_values:
    print(f"Running NBMM for K={K}")
    labels = run_nbmm(X_nbmm, K, r=2)
    ari_type_nbmm.append(adjusted_rand_score(rna_type_array, labels))
    ari_family_nbmm.append(adjusted_rand_score(rna_family_array, labels))

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(K_values, ari_type_nbmm, label="ARI vs RNA Type", marker="o")
plt.plot(K_values, ari_family_nbmm, label="ARI vs RNA Family", marker="s")
plt.xlabel("Number of Clusters (K)")
plt.ylabel("Adjusted Rand Index")
plt.title("NBMM Clustering Performance")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

*Interpretation:*

RNA Family (orange):

* ARI scores peak at K = 10 (~0.32) and again around K = 12.
* This is lower than Leiden (~0.52) and GMM (~0.46).
* Still shows a recognizable signal for RNA family structure, but less sharp and stable.
* NBMM captures RNA family structure to *some* extent, but it seems less expressive than Leiden or GMM here. This might be due to:
    * Model assumptions (NBMM assumes overdispersed count data).
    * Sensitivity to initialization or data preprocessing (e.g., normalization, rounding).

RNA Type (blue):

* ARI scores stay consistently low (0.04–0.11), much like with other methods.
* Slight improvement with increasing K, but still limited.
* Again, RNA type structure is weakly represented in the data, and NBMM (as the other methods) struggles to reveal it.
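
As a side note on the model assumptions: `run_nbmm` converts a cluster mean `mu` into the success probability via `p = r / (r + mu)`. A quick check, using hypothetical means, that this gives a negative binomial with mean `mu` and variance `mu + mu**2 / r` under SciPy's `nbinom(n, p)` parameterization:

```python
import numpy as np
from scipy.stats import nbinom

r = 2.0                           # fixed dispersion parameter, as in run_nbmm(..., r=2)
mu = np.array([0.5, 5.0, 50.0])   # hypothetical per-gene cluster means
p = r / (r + mu)                  # same conversion as in run_nbmm

mean, var = nbinom.stats(n=r, p=p, moments="mv")

print("mean:", mean)      # equals mu
print("variance:", var)   # equals mu + mu**2 / r (overdispersed relative to Poisson)
```

With a fixed `r = 2`, the variance grows quadratically with the mean for every gene, which is a strong assumption and may contribute to the weaker NBMM results.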

Leiden results come from varying n_neighbors and resolution, so K_leiden contains duplicate values.

Therefore, we do not plot all results but only the maximum ARI score for each value of K.

In [None]:
# Extract GMM results
K_gmm = df_gmm_results["n_clusters"].values
ari_type_gmm = df_gmm_results["ari_rna_type"].values
ari_family_gmm = df_gmm_results["ari_rna_family"].values

# Aggregate by number of clusters
leiden_df = pd.DataFrame(
    {
        "K": K_leiden,
        "ARI_RNA_Type": ari_type_leiden,
        "ARI_RNA_Family": ari_family_leiden,
    }
)

# Take maximum ARI per K
leiden_summary = (
    leiden_df.groupby("K")
    .agg({"ARI_RNA_Type": "max", "ARI_RNA_Family": "max"})
    .reset_index()
)

plt.figure(figsize=(12, 6))

# RNA Type
plt.plot(
    leiden_summary["K"],
    leiden_summary["ARI_RNA_Type"],
    "o-",
    label="Leiden (RNA Type)",
)
plt.plot(
    df_gmm_results["n_clusters"],
    df_gmm_results["ari_rna_type"],
    "s-",
    label="GMM (RNA Type)",
)
plt.plot(K_values, ari_type_nbmm, "^-", label="NBMM (RNA Type)")

# RNA Family
plt.plot(
    leiden_summary["K"],
    leiden_summary["ARI_RNA_Family"],
    "o--",
    label="Leiden (RNA Family)",
)
plt.plot(
    df_gmm_results["n_clusters"],
    df_gmm_results["ari_rna_family"],
    "s--",
    label="GMM (RNA Family)",
)
plt.plot(K_values, ari_family_nbmm, "^--", label="NBMM (RNA Family)")

plt.xlabel("Number of Clusters (K)")
plt.ylabel("Adjusted Rand Index (ARI)")
plt.title("Clustering Method Comparison by ARI (Max ARI per K for Leiden)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

Leiden clustering performs best overall, especially for RNA family labels. This indicates:

* Graph-based clustering (Leiden) is more effective than GMM or NBMM in this context.
* RNA family identity is more strongly reflected in expression data than RNA type.


In [None]:
# GMM
best_idx_gmm_type = df_gmm_results["ari_rna_type"].idxmax()
best_idx_gmm_family = df_gmm_results["ari_rna_family"].idxmax()

# NBMM
best_idx_nbmm_type = pd.Series(ari_type_nbmm).idxmax()
best_idx_nbmm_family = pd.Series(ari_family_nbmm).idxmax()

# Leiden
best_idx_leiden_type = pd.Series(ari_type_leiden).idxmax()
best_idx_leiden_family = pd.Series(ari_family_leiden).idxmax()

# Build summary table
summary_data = {
    "Method": ["GMM", "NBMM", "Leiden"],
    "Best ARI (RNA Type)": [
        df_gmm_results["ari_rna_type"].iloc[best_idx_gmm_type],
        ari_type_nbmm[best_idx_nbmm_type],
        ari_type_leiden[best_idx_leiden_type],
    ],
    "Best K (RNA Type)": [
        df_gmm_results["n_clusters"].iloc[best_idx_gmm_type],
        K_values[best_idx_nbmm_type],
        K_leiden[best_idx_leiden_type],
    ],
    "Best ARI (RNA Family)": [
        df_gmm_results["ari_rna_family"].iloc[best_idx_gmm_family],
        ari_family_nbmm[best_idx_nbmm_family],
        ari_family_leiden[best_idx_leiden_family],
    ],
    "Best K (RNA Family)": [
        df_gmm_results["n_clusters"].iloc[best_idx_gmm_family],
        K_values[best_idx_nbmm_family],
        K_leiden[best_idx_leiden_family],
    ],
}

summary_df = pd.DataFrame(summary_data)
display(summary_df)

In [None]:
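# Run t-SNE on the PCA-reduced, scaled log-HVG data (X_pca)
# The iteration count is left at its default (1000); the keyword was renamed
# from n_iter to max_iter in newer scikit-learn versions
tsne = TSNE(n_components=2, perplexity=30, random_state=42)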
X_tsne = tsne.fit_transform(X_pca)

# Plot, colored by RNA Type
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sns.scatterplot(
    x=X_tsne[:, 0],
    y=X_tsne[:, 1],
    hue=np.array(rna_type_filtered)[: X_tsne.shape[0]],
    palette="tab20",
    s=10,
    alpha=0.7,
    legend=False,
)
plt.title("t-SNE colored by RNA Type")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")

# Plot, colored by best Leiden clustering
# Get best Leiden labels (highest ARI(RNA Family))
best_index = np.argmax(ari_family_leiden)
best_labels = results[best_index]["n_clusters"]  # note: this is the cluster count, not the label vector

# Recompute labels
best_n_neighbors = results[best_index]["n_neighbors"]
best_resolution = results[best_index]["resolution"]

# Build graph again to get the best clustering
knn = NearestNeighbors(n_neighbors=best_n_neighbors).fit(X_pca)
_, indices = knn.kneighbors(X_pca)
edges = [
    (i, neighbor)
    for i, neighbors in enumerate(indices)
    for neighbor in neighbors[1:]
]
g = ig.Graph(edges=edges, directed=False)
g.simplify()

# Run Leiden clustering with the best resolution
partition = leidenalg.find_partition(
    g,
    leidenalg.RBConfigurationVertexPartition,
    resolution_parameter=best_resolution,
)
leiden_labels_best = np.array(partition.membership)

# Plot
plt.subplot(1, 2, 2)
sns.scatterplot(
    x=X_tsne[:, 0],
    y=X_tsne[:, 1],
    hue=leiden_labels_best,
    palette="tab20",
    s=10,
    alpha=0.7,
    legend=False,
)
plt.title(
    f"t-SNE colored by Leiden Clusters\n(Res={best_resolution}, Neigh={best_n_neighbors})"
)
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")

plt.tight_layout()
plt.show()

Leiden clustering works well visually; it’s producing meaningful structure that resembles real biological types.

* The ARI score for RNA Family was highest for this configuration, and the right plot confirms that: there are 9 clear clusters, matching the number of RNA families.
* The partial overlap seen in the RNA Type plot reflects biological or technical noise and illustrates why clustering by RNA type is harder (lower ARI).


*Thoughts on NBMM performance:*

Reasons why NBMM could perform so poorly:
* Leiden and GMM use more expressive features
    * Leiden uses neighborhood graphs on PCA-reduced, log-normalized data — it’s robust to noise and captures manifold structure.
    * GMM works on scaled log-expression and handles ellipsoidal clusters well
    * In contrast, NBMM operates directly on raw count data, which retain a lot of technical noise and sparsity (see Task 1)
* Initialization & Optimization Challenges
    * NBMM was initialized with KMeans and uses an EM-like algorithm, but this can:
        * get stuck in local optima
        * be slow to converge, especially on sparse, high-dimensional count matrices
    * without careful parameter tuning (e.g. better initialization, r selection, regularization), it may not find meaningful clusters.
* RNA Type (77 labels) vs. RNA Family (9 labels)
    * Our ARI scores show that NBMM performs better for RNA Family than for RNA Type.
    * This suggests NBMM can capture coarse structure, but not finer subclass separations, likely due to limited flexibility or resolution in the model.
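
One of the optimization pitfalls on high-dimensional count data is numerical underflow in the E-step, because the per-cell log-probabilities summed over hundreds of genes become very negative. A minimal sketch of a stable E-step using `scipy.special.logsumexp`. The helper `e_step` and the toy numbers are hypothetical and not part of our pipeline:

```python
import numpy as np
from scipy.special import logsumexp


def e_step(log_joint):
    """log_joint[i, k] = log(pi_k) + log p(x_i | cluster k).

    Returns the responsibilities and the data log-likelihood.
    """
    log_norm = logsumexp(log_joint, axis=1, keepdims=True)  # log p(x_i)
    resp = np.exp(log_joint - log_norm)                     # rows sum to 1
    return resp, float(log_norm.sum())


# Hypothetical log-joint values for 2 cells and 2 clusters
log_joint = np.array([[-1000.0, -1001.0], [-2000.0, -1998.0]])
resp, log_likelihood = e_step(log_joint)

print(np.exp(log_joint).sum())   # naive exponentiation underflows to 0.0
print(resp.round(3), round(log_likelihood, 2))
```

The log-likelihood returned here is the quantity an EM convergence check should monitor.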


# Task 5

## 5.1 prepare data

* transcriptomics data is already normalized, hvgs were selected, PCA applied
* electrophysiological data: standardize/normalize (z-score), maybe apply PCA

then:
* match cells across modalities (=align them via cell IDs)
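
The matching step can be sketched on toy data as follows. The cell IDs, column names and values are hypothetical and only show the index-intersection pattern:

```python
import pandas as pd

# Two hypothetical modalities with partially overlapping, differently ordered cell IDs
rna = pd.DataFrame({"geneA": [1.0, 2.0, 3.0]}, index=["cell_1", "cell_2", "cell_3"])
ephys = pd.DataFrame({"rheobase": [10.0, 30.0, 20.0]}, index=["cell_3", "cell_1", "cell_4"])

# Keep only cells present in both tables and bring them into the same order
shared = rna.index.intersection(ephys.index)
rna_aligned = rna.loc[shared]
ephys_aligned = ephys.loc[shared]

print(rna_aligned.join(ephys_aligned))
```

Aligning by label rather than by position only works if each table's index really holds that cell's ID.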

In [None]:
meta_df = meta.copy()
meta_df = meta_df[
    exclude_low_quality
]  # filter out low-quality cells from meta
meta_df["Cell"] = meta_df["Cell"].astype(str).str.strip()
meta_df = meta_df.set_index("Cell")

# set the same index on RNA and ephys
log_cpm_df.index = meta_df.index[: log_cpm_df.shape[0]]
ephys_df.index = meta_df.index[: ephys_df.shape[0]]

# z-score electrophysiological data
scaler = StandardScaler()
ephys_df_scaled = pd.DataFrame(
    scaler.fit_transform(ephys_df),
    index=ephys_df.index,
    columns=ephys_df.columns,
)
ephys_df = ephys_df_scaled

shared_cells = log_cpm_df.index.intersection(ephys_df.index).intersection(
    meta_df.index
)
print("Number of shared cells:", len(shared_cells))

log_cpm_df = log_cpm_df.loc[shared_cells]
ephys_df = ephys_df.loc[shared_cells]
meta_df = meta_df.loc[shared_cells]

# double check that alignment is correct
assert all(log_cpm_df.index == ephys_df.index)
assert all(ephys_df.index == meta_df.index)

In [None]:
print("Final aligned shapes:")
print("RNA:", log_cpm_df.shape)
print("Ephys:", ephys_df.shape)
print("Meta:", meta_df.shape)

In [None]:
# recompute HVGs

# use variance thresholding to mimic HVG selection
selector = VarianceThreshold(threshold=0.5)  # adjusted threshold
_ = selector.fit(log_cpm_df)

hvg_genes = log_cpm_df.columns[selector.get_support()]
print(f"Selected {len(hvg_genes)} HVGs based on current log_cpm_df.")

# rerun PCA
X_hvg = StandardScaler().fit_transform(log_cpm_df[hvg_genes])

pca = PCA(n_components=25)  # chose to keep 25 PCs

X_pca = pca.fit_transform(X_hvg)

# convert to df for easier correlation
pca_df = pd.DataFrame(
    X_pca,
    columns=[f"PC{i+1}" for i in range(X_pca.shape[1])],
    index=log_cpm_df.index,
)

now: explore relationships

* correlation analysis
* canonical correlation analysis
* UMAP
* PCA loadings to see gene contributions
* regression models
* biological interpretation

## 5.2 correlation analysis

In [None]:
# for all features vs all PCs:
cor_matrix_full = pd.DataFrame(index=ephys_df.columns, columns=pca_df.columns)
for ef in ephys_df.columns:
    for pc in pca_df.columns:
        cor_matrix_full.loc[ef, pc] = ephys_df[ef].corr(pca_df[pc])
cor_matrix_full = cor_matrix_full.astype(float)

In [None]:
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(
    cor_matrix_full,
    annot=True,
    cmap="vlag",
    center=0,
    annot_kws={"size": 7},
    ax=ax,
)
ax.set_title(
    "Correlation between electrophysiological features and PCA components"
)
ax.set_xlabel("Principal Components")
ax.set_ylabel("Ephys Features")
plt.show()

The heatmap shows Pearson correlation coefficients between the electrophysiological features and the transcriptomic PCs. We can see some negative values, indicating inverse correlation:
* Upstroke-to-downstroke ratio is negatively correlated with PC3, which suggests that higher values on PC3 go along with lower upstroke-downstroke ratios
* AP width and Afterhyperpolarization are also negatively correlated with PC3

There are also some positive values indicating positive correlation:
* higher Input resistance seems to go along with a high score for PC4
* higher Max number of APs seems to go along with a high score for PC3

Also, there are many near-zero values (white) that suggest no linear relationship between both types of data. Overall, PC3 stands out with the strongest ephys-correlation.
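
For larger feature sets, the double loop used for the correlation matrix can be replaced by a single matrix product of z-scored data. A sketch on random toy data, where `cross_corr` is a hypothetical helper that is not used elsewhere in this notebook:

```python
import numpy as np
import pandas as pd


def cross_corr(a, b):
    """Pearson correlation between every column of a and every column of b.

    Both DataFrames must share the same row index (same cells, same order).
    """
    a_z = (a - a.mean()) / a.std(ddof=0)
    b_z = (b - b.mean()) / b.std(ddof=0)
    return a_z.T @ b_z / len(a)


rng = np.random.default_rng(0)
a = pd.DataFrame(rng.normal(size=(50, 3)), columns=["x0", "x1", "x2"])
b = pd.DataFrame(rng.normal(size=(50, 2)), columns=["y0", "y1"])

corr = cross_corr(a, b)
print(corr.round(3))
```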

Now we want to isolate ion channel gene expression from data.

In [None]:
# convert all column names to uppercase to standardize for matching
log_cpm_df.columns = log_cpm_df.columns.str.upper()

# choose channel gene families of interest
ion_channel_prefixes = ["KCN", "SCN", "HCN", "CACNA", "KCNA", "GRIN", "GABR"]

# select genes that contain any of those chosen ion channel prefixes
ion_genes = [
    gene
    for gene in log_cpm_df.columns
    if any(prefix in gene for prefix in ion_channel_prefixes)
]

# subset expression matrix to only those genes
ion_expr = log_cpm_df[ion_genes]

print(f"Found {len(ion_genes)} ion channel genes in log_cpm_df.")
print("Top ion channel genes found:", ion_genes[:10])

In [None]:
# double check that index alignment is correct
common_cells = ion_expr.index.intersection(ephys_df.index)
ion_expr = ion_expr.loc[common_cells]
ephys_subset = ephys_df.loc[common_cells][
    ["AP threshold (mV)", "Rheobase (pA)", "Membrane time constant (ms)"]
]
# check for same order
assert all(ion_expr.index == ephys_subset.index), "index order mismatch"

After aligning them, we can now correlate each ion gene with each ephys feature:

In [None]:
# correlation matrix between genes and selected ephys features
cor_matrix = ion_expr.corrwith(
    ephys_subset["AP threshold (mV)"], axis=0
).to_frame(name="AP threshold (mV)")

cor_matrix["Rheobase (pA)"] = ion_expr.corrwith(
    ephys_subset["Rheobase (pA)"], axis=0
)

cor_matrix["Membrane time constant (ms)"] = ion_expr.corrwith(
    ephys_subset["Membrane time constant (ms)"], axis=0
)

# make matrix easier to read + ensure numeric
cor_matrix = cor_matrix.T.astype(float)

# sort genes by correlation strength to Rheobase => makes the plot easier to read
sorted_genes = (
    cor_matrix.loc["Rheobase (pA)"].abs().sort_values(ascending=False).index
)

In [None]:
# visualize the correlations:
plt.figure(figsize=(12, 6))
sns.heatmap(
    cor_matrix[sorted_genes],
    annot=None,
    cmap="vlag",
    center=0,
    fmt="",
)
plt.title("Correlation between Ion Channel Gene Expression and Ephys Features")
plt.xlabel("Ion Channel Genes")
plt.ylabel("Ephys Features")
plt.show()

*Interpretation*

* Rheobase shows the strongest correlations with many genes. For the positively correlated genes, higher expression goes along with a higher rheobase, i.e. more current is needed to elicit spikes. Examples are GRIN2A, KCNMA1 and SCN4B.
* KCNK family genes (potassium channels) correlate with membrane time constant and AP threshold, at least to a medium degree. This is expected, since they play a known role in setting resting potential and excitability.
* Some SCN genes (sodium channels) also show rather small correlations, especially with AP threshold and Rheobase, which is also expected, since it is consistent with their role in spike initiation.

Overall, most correlations are pretty weak => check for significance!

In [None]:
from scipy.stats import pearsonr
from statsmodels.stats.multitest import fdrcorrection

# compute correlations & p-values
results = []
for gene in ion_expr.columns:
    for feat in ephys_subset.columns:
        r, p = pearsonr(ion_expr[gene], ephys_subset[feat])
        results.append({"gene": gene, "feature": feat, "r": r, "p": p})

cor_df = pd.DataFrame(results)

# Benjamini-Hochberg FDR correction (returns reject flags and adjusted p-values)
cor_df["fdr_pass"], cor_df["qval"] = fdrcorrection(cor_df["p"])

In [None]:
# reuse correlation matrix from before
# this time add significance stars
annot_matrix = pd.DataFrame(index=cor_matrix.index, columns=cor_matrix.columns)

r_matrix = cor_df.pivot(index="feature", columns="gene", values="r")
qval_matrix = cor_df.pivot(index="feature", columns="gene", values="qval")


for row in cor_matrix.index:
    for col in cor_matrix.columns:
        r_val = cor_matrix.at[row, col]
        q_val = qval_matrix.at[row, col]

        if pd.isna(r_val) or pd.isna(q_val):
            annot_matrix.at[row, col] = ""
            continue

        stars = ""
        if q_val < 0.001:
            stars = "*\n*\n*"
        elif q_val < 0.01:
            stars = "*\n*"
        elif q_val < 0.05:
            stars = "*"

        annot_matrix.at[row, col] = stars

# plot
plt.figure(figsize=(16, 6))
sns.heatmap(
    cor_matrix[sorted_genes],  # sort by rheobase
    annot=annot_matrix[sorted_genes],  # sorted annotation
    fmt="",
    cmap="vlag",
    center=0,
    annot_kws={"size": 10},
    cbar_kws={"label": "Pearson r"},
)
plt.title(
    "Correlation Between Ion Channel Genes and Ephys Features\n(* = FDR < 0.05)"
)
plt.xlabel("Ion Channel Genes (sorted by Rheobase correlation)")
plt.ylabel("Electrophysiological Features")
plt.show()

This time, significant correlations are marked with 3, 2 or 1 stars, depending on the strength of the significance (q_val < 0.001 / 0.01 / 0.05)
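
For reference, the q-values come from a Benjamini-Hochberg correction (assuming `fdrcorrection` is the statsmodels function with its default settings). A small NumPy sketch of that procedure on hypothetical p-values, written out by hand to show what the adjustment does:

```python
import numpy as np


def benjamini_hochberg(pvals):
    """Return Benjamini-Hochberg adjusted p-values (q-values)."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    # scale the i-th smallest p-value by m / i
    scaled = p[order] * m / np.arange(1, m + 1)
    # enforce monotonicity, starting from the largest p-value
    q_sorted = np.minimum.accumulate(scaled[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.clip(q_sorted, 0, 1)
    return q


pvals = [0.01, 0.04, 0.03, 0.005]  # hypothetical p-values
qvals = benjamini_hochberg(pvals)
print(qvals)  # [0.02 0.04 0.04 0.02]
```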

To investigate the significant correlations further, look at the top 6 strongest ones:

In [None]:
# filter + select top 6 FDR-significant correlations
top_hits = (
    cor_df[cor_df["fdr_pass"]]
    .sort_values("r", key=abs, ascending=False)
    .head(6)
    .reset_index(drop=True)
)

# plot scatterplots
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for i, row in top_hits.iterrows():
    gene = row["gene"]
    feat = row["feature"]
    r_val = row["r"]
    q_val = row["qval"]

    x = ion_expr[gene]
    y = ephys_subset[feat]

    sns.scatterplot(x=x, y=y, ax=axes[i], alpha=0.6, edgecolor=None)
    sns.regplot(x=x, y=y, ax=axes[i], scatter=False, color="red", ci=None)

    axes[i].set_title(
        f"{gene} vs. {feat}\n" f"r = {r_val:.2f}, q = {q_val:.1e}", fontsize=12
    )
    axes[i].set_xlabel(f"{gene} expression (log CPM)")
    axes[i].set_ylabel(feat)

plt.tight_layout()
plt.suptitle(
    "Top 6 FDR-Significant Gene–Ephys Correlations", fontsize=16, y=1.02
)
plt.show()

The plot above shows the top 6 statistically significant correlations between ion channel gene expression and ephys features. Numerically, the effect sizes are small. Also, the vertical streaks at 0 mean that a large fraction of cells had undetectable expression for the gene in question. This sparsity might weaken the correlations.
Nevertheless, the very low q-values (<3e-5) indicate that these correlations are statistically robust, even though they are weak.
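
That weak correlations can still be highly significant follows from the sample size. A short sketch using the standard t-test for a Pearson correlation, where the values of `r` and `n` are hypothetical and not taken from our results:

```python
import numpy as np
from scipy import stats


def pearson_p_value(r, n):
    """Two-sided p-value for a Pearson correlation r observed on n samples."""
    t = r * np.sqrt((n - 2) / (1 - r**2))
    return 2 * stats.t.sf(abs(t), df=n - 2)


r, n = 0.2, 1000            # hypothetical: weak correlation, many cells
p_value = pearson_p_value(r, n)
explained_variance = r**2   # fraction of variance explained by a linear fit

print(f"p = {p_value:.1e}, explained variance = {explained_variance:.0%}")
```

In this example the correlation explains only about 4% of the variance, yet the p-value is far below any usual threshold.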

Interpretations of the single plots:

* CACNA2D3 vs. Membrane Time Constant: CACNA2D3 encodes a subunit of voltage-gated calcium channels. Higher expression is associated with slower membrane dynamics (increased time constant), which might reflect more sustained calcium-mediated currents or associated modulation of channel density.
* GRIN2A vs Rheobase: GRIN2A encodes an NMDA receptor subunit. Higher expression correlates with higher rheobase. This suggests that NMDA receptor signaling may contribute to increased excitability thresholds or delayed spiking, possibly through tonic depolarization or dendritic filtering.
* CACNA1E vs. Membrane Time Constant: CACNA1E encodes the R-type Ca^2+ channel α1E subunit. Again, we see a positive correlation with the time constant, supporting a general link between Ca^2+ channel expression and slower membrane charging.
* KCNH7/KCNH4 vs. Rheobase: KCNH7 and KCNH4 encode voltage-gated potassium channels. Our results suggest that higher KCNH7 or KCNH4 expression is associated with a greater current needed to fire, consistent with increased leak or repolarizing conductance reducing excitability.
* KCNQ1OT1 vs. Rheobase: KCNQ1OT1 does not encode a channel protein itself; it is a long non-coding RNA associated with KCNQ1. The correlation found here might reflect regulation of KCNQ1 expression.

## 5.3 canonical correlation analysis

=> inspired by Lause et al. (2024)

Identify pairs of latent dimensions that co-vary across RNA and e-phys.

In [None]:
# standardize both views
X_rna = StandardScaler().fit_transform(pca_df)
X_ephys = StandardScaler().fit_transform(ephys_df)

# apply CCA
cca = CCA(n_components=5)
X_rna_cca, X_ephys_cca = cca.fit_transform(X_rna, X_ephys)

# visualize CCA components
cca_df = pd.DataFrame(
    np.concatenate([X_rna_cca, X_ephys_cca], axis=1),
    columns=[f"RNA_CCA{i+1}" for i in range(X_rna_cca.shape[1])]
    + [f"EPHYS_CCA{i+1}" for i in range(X_ephys_cca.shape[1])],
    index=pca_df.index,
)

# correlation between CCA components
corr_matrix = pd.DataFrame(
    np.corrcoef(X_rna_cca.T, X_ephys_cca.T)[:5, 5:],
    index=[f"RNA_CCA{i+1}" for i in range(5)],
    columns=[f"EPHYS_CCA{i+1}" for i in range(5)],
)



sns.heatmap(corr_matrix, annot=True, cmap="vlag", center=0)
plt.title("Canonical Correlation between RNA and Ephys components")
plt.show()

The diagonal values in the heatmap above are the canonical correlations between matched components. The red values (around 0.5) indicate a moderate linear relationship between the respective RNA and ephys components; note that the RNA view consists of PCs rather than individual genes. The off-diagonal values are very close to zero. This is expected, because CCA constructs successive component pairs to be uncorrelated with each other, so each pair captures a distinct shared signal rather than redundant information.

We've seen that the transcriptomic and ephys data share structure => which RNA components and which electrophysiological features are responsible for this coupling?

In [None]:
# weights = contribution of each input variable (RNA PCs / ephys features)
# to the CCA components (x_weights_ / y_weights_, not x_loadings_ / y_loadings_)
gene_loadings = pd.DataFrame(
    cca.x_weights_,
    index=pca_df.columns,  # original RNA PC names
    columns=[f"RNA_CCA{i+1}" for i in range(cca.n_components)],
)

ephys_loadings = pd.DataFrame(
    cca.y_weights_,
    index=ephys_df.columns,
    columns=[f"EPHYS_CCA{i+1}" for i in range(cca.n_components)],
)

# look at top contributing features to RNA_CCA1 and EPHYS_CCA1
top_rna = gene_loadings["RNA_CCA1"].abs().sort_values(ascending=False).head(10)
top_ephys = (
    ephys_loadings["EPHYS_CCA1"].abs().sort_values(ascending=False).head(10)
)

print("Top RNA PCs (RNA CCA1):\n", top_rna, sep="")
print("-------------------------", sep="")
print("Top ephys features (Ephys CCA1):\n", top_ephys, sep="")
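Because the CCA above was fitted on `pca_df`, the weights are indexed by RNA PCs rather than by genes. One possible way to get at gene-level contributions is to correlate each gene's expression with the canonical score of a component (sometimes called structure correlations). The sketch below is an illustration on synthetic data; `structure_correlations` and the toy gene names are hypothetical.

```python
import numpy as np
import pandas as pd


def structure_correlations(features, scores, name="canonical_score"):
    """Pearson correlation of every column in `features` with one score vector."""
    scores = pd.Series(
        np.asarray(scores, dtype=float), index=features.index, name=name
    )
    return features.corrwith(scores).sort_values(key=np.abs, ascending=False)


# synthetic illustration: three toy "genes" and one canonical score
rng = np.random.default_rng(1)
score = rng.normal(size=300)
genes = pd.DataFrame(
    {
        "geneA": score + rng.normal(scale=0.3, size=300),  # strongly positive
        "geneB": -score + rng.normal(scale=1.0, size=300),  # negative
        "geneC": rng.normal(size=300),  # unrelated
    }
)

ranking = structure_correlations(genes, score)
print(ranking)
```

With the notebook's objects this could be called as `structure_correlations(log_cpm_df, X_rna_cca[:, 0])`, assuming `log_cpm_df` has the same cells in the same row order as `pca_df`.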

Below, we visualize cell-level CCA projections, colored by metadata (here: Cre line). This is to see whether biological classes are reflected in the shared gene–ephys structure.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

cca_proj = pd.DataFrame(
    {
        "RNA_CCA1": X_rna_cca[:, 0],
        "EPHYS_CCA1": X_ephys_cca[:, 0],
        "Cre": meta_df["Cre"],
    }
)

# custom palette for Cre lines
unique_cres = cca_proj["Cre"].unique()
cre_palette = dict(
    zip(unique_cres, sns.color_palette("tab20", len(unique_cres)))
)

plt.figure(figsize=(10, 6))
sns.scatterplot(
    data=cca_proj,
    x="RNA_CCA1",
    y="EPHYS_CCA1",
    hue="Cre",
    palette=cre_palette,
    s=30,  # uniform point size
    edgecolor="none",  # cleaner look
    alpha=0.7,
)

plt.axhline(0, color="gray", linestyle="--", lw=0.7)
plt.axvline(0, color="gray", linestyle="--", lw=0.7)

plt.title("CCA Projection of Cells (colored by Cre Line)", fontsize=14)
plt.xlabel("RNA CCA Component 1", fontsize=12)
plt.ylabel("Ephys CCA Component 1", fontsize=12)

# place legend outside the plot
plt.legend(
    title="Cre Line",
    bbox_to_anchor=(1.05, 1),
    loc="upper left",
    borderaxespad=0.0,
    fontsize=9,
    title_fontsize=10,
)

plt.grid(True, linestyle="--", alpha=0.3)
plt.tight_layout()
plt.show()

The axes in the plot are the projections onto the first canonical components from CCA. Coloring was done based on Cre line identity, as drawn from metadata.

The plot shows a diagonal trend, which is consistent with a coupling between gene expression and ephys features along the first canonical component pair. Also, even though not very clearly, some Cre lines cluster along specific parts of the diagonal. For example, PV-, SST+ and VIP+ cells occupy somewhat distinct regions, which points to distinct gene-ephys signatures. This suggests that the canonical axes are biologically meaningful. To sum up, the first CCA component captures a shared axis of transcriptomic and electrophysiological variation that also aligns with Cre-defined cell types.
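The visual impression that groups occupy different regions of the canonical axis could be backed up numerically. The following sketch summarizes a score per group and runs a Kruskal-Wallis test; it uses synthetic scores and hypothetical group labels (`lineA`, `lineB`, `lineC`), and the helper `score_by_group` is an assumption rather than part of the analysis.

```python
import numpy as np
import pandas as pd
from scipy import stats


def score_by_group(scores, groups):
    """Per-group summary of a canonical score plus a Kruskal-Wallis test."""
    df = pd.DataFrame(
        {"score": np.asarray(scores, dtype=float), "group": np.asarray(groups)}
    )
    summary = df.groupby("group")["score"].agg(["count", "mean", "std"])
    samples = [g["score"].to_numpy() for _, g in df.groupby("group")]
    h_stat, p_value = stats.kruskal(*samples)
    return summary, h_stat, p_value


# synthetic illustration: three groups shifted along the axis
rng = np.random.default_rng(2)
labels = np.repeat(["lineA", "lineB", "lineC"], 50)
scores = np.concatenate(
    [rng.normal(-1, 1, 50), rng.normal(0, 1, 50), rng.normal(1, 1, 50)]
)

summary, h_stat, p_value = score_by_group(scores, labels)
print(summary)
print(f"Kruskal-Wallis H = {h_stat:.1f}, p = {p_value:.2e}")
```

Applied to the notebook's data, the inputs would be `X_rna_cca[:, 0]` and `meta_df["Cre"]`, assuming both refer to the same cells in the same order.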

## 5.4 UMAP over combined RNA PCs and scaled ephys features for joint manifold visualization

In [None]:
from sklearn.preprocessing import StandardScaler
import umap.umap_ as umap
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

meta_for_color = "Cre"  # "RNA family" / "Slice" / "Sequencing batch"

# align indices across data
shared_cells = pca_df.index.intersection(ephys_df.index).intersection(
    meta_df.index
)

# combine RNA PCs + scaled ephys
X_rna = pca_df.loc[shared_cells].values
X_ephys = StandardScaler().fit_transform(ephys_df.loc[shared_cells])
X_combined = np.concatenate([X_rna, X_ephys], axis=1)

# UMAP
reducer = umap.UMAP(
    n_neighbors=15, min_dist=0.3, metric="euclidean", random_state=42
)
embedding = reducer.fit_transform(X_combined)

embedding_df = pd.DataFrame(
    embedding, columns=["UMAP1", "UMAP2"], index=shared_cells
)
embedding_df["ColorGroup"] = meta_df.loc[shared_cells, meta_for_color]

# plot
plt.figure(figsize=(8, 6))
sns.scatterplot(
    data=embedding_df,
    x="UMAP1",
    y="UMAP2",
    hue="ColorGroup",
    s=30,
    alpha=0.7,
    edgecolor=None,
    palette="tab20" if embedding_df["ColorGroup"].nunique() <= 20 else "husl",
)
plt.axhline(0, color="gray", linestyle="--", lw=0.5)
plt.axvline(0, color="gray", linestyle="--", lw=0.5)
plt.title("Joint UMAP: RNA PCs + Electrophysiological Features", fontsize=14)
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title=meta_for_color)
plt.tight_layout()
plt.show()

*Interpretation:*

Some Cre lines (e.g., PV+, VIP+) are concentrated in specific clusters, suggesting that these lines correspond to functionally and molecularly coherent cell types.
Others (e.g., WT, VIPR2-) are more spread out, which may reflect broader biological variability or technical artifacts. Also, the cells do not appear to be clustered purely by RNA or purely by ephys. This suggests that both modalities contribute to the structure, which is the goal of a joint UMAP.
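In the cell above, the RNA PCs are concatenated unscaled while the ephys features are z-scored, so the Euclidean distances used by UMAP could be dominated by whichever block has the larger total variance. One possible remedy, shown here as a sketch on synthetic data, is to rescale each block to unit total variance before concatenating. The helper `balance_blocks` is hypothetical.

```python
import numpy as np


def balance_blocks(*blocks):
    """Center each block and scale it to unit total variance, then concatenate."""
    scaled = []
    for block in blocks:
        block = np.asarray(block, dtype=float)
        centered = block - block.mean(axis=0)
        total_var = centered.var(axis=0).sum()
        scaled.append(centered / np.sqrt(total_var))
    return np.concatenate(scaled, axis=1)


# synthetic illustration: one block with much larger variance than the other
rng = np.random.default_rng(3)
rna_pcs_demo = rng.normal(scale=10.0, size=(200, 20))
ephys_demo = rng.normal(scale=1.0, size=(200, 5))

combined_demo = balance_blocks(rna_pcs_demo, ephys_demo)
print(combined_demo[:, :20].var(axis=0).sum(), combined_demo[:, 20:].var(axis=0).sum())
```

After this step both modalities contribute the same total variance to the joint distance, which makes it easier to argue that the embedding reflects both of them.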

## 5.5 study PCA loadings to see gene contributions

=> look at what the PCs are made of

In [None]:
loadings = pd.DataFrame(
    pca.components_.T,
    index=hvg_genes,
    columns=[f"PC{i+1}" for i in range(pca.n_components_)],
)

In [None]:
top_loadings = {}
for pc in loadings.columns:
    top = (
        loadings[pc].abs().sort_values(ascending=False).head(10).index.tolist()
    )
    top_loadings[pc] = top

In [None]:
top_genes_pc1 = top_loadings["PC1"]  # gene names

# select just those genes' loadings across all PCs (or just PC1)
sns.heatmap(
    loadings.loc[top_genes_pc1],
    cmap="coolwarm",
    annot=True,
    fmt=".2f",
    linewidths=0.5,
    linecolor="gray",
    annot_kws={"size": 5},
    cbar_kws={"label": "Loading Value"},
)

plt.title("Top Gene Loadings for PC1")
plt.xlabel("Principal Components")
plt.ylabel("Genes")
plt.show()

The cells in the plot represent how strongly (positively or negatively) each gene contributes to each PC. Look at the loadings in more detail:

In [None]:
# choose which PCs to visualize
pcs_to_plot = ["PC1", "PC2", "PC3"]  # change or extend this list

# set up subplots
fig, axes = plt.subplots(
    1, len(pcs_to_plot), figsize=(6 * len(pcs_to_plot), 5), sharey=False
)

if len(pcs_to_plot) == 1:
    axes = [axes]  # ensure axes is always iterable

for ax, pc in zip(axes, pcs_to_plot):
    top_genes = loadings[pc].abs().sort_values(ascending=False).head(10).index
    top_values = loadings.loc[top_genes, pc]

    sns.barplot(
        x=top_values.values,
        y=top_genes,
        hue=top_genes,
        ax=ax,
        palette="viridis" if top_values.iloc[0] > 0 else "rocket",
        orient="h",
        legend=False,
    )
    ax.set_title(f"Top 10 Loadings for {pc}")
    ax.set_xlabel("Loading Value")
    ax.set_ylabel("Gene")

plt.tight_layout()
plt.show()

PC1 (left panel):

Genes like TRIM37, HDLPP, PCP4L1, DUSP10 are the top contributors.
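All values have the same (positive) sign, suggesting PC1 is dominated by a co-expressed gene module. The overall sign of a PC is arbitrary, so only the relative signs of the loadings carry information.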
If PC1 separates samples by cell type, region, or batch, these genes may underlie that variation.

PC2 (middle panel):

Mix of positive and negative loadings (e.g., UCK1: +0.07, RB1CC1: -0.06).
Suggests PC2 contrasts two opposing gene sets: samples with high PC2 scores express UCK1 etc., and low PC2 scores express RB1CC1, PER2 etc.

PC3 (right panel):

Similar contrast between MTUS2/ARID5B (positive) and MAP2K4/VIPAS39 (negative).
May represent a different biological axis (e.g., membrane vs. cytoskeletal genes => open for investigation).
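When reading the signs of loadings, it helps to remember that a principal component is only defined up to its sign. The short sketch below, on synthetic data, flips the sign of the first component and its scores and shows that the reconstruction of the data is unchanged. The name `pca_demo` is used to avoid overwriting the notebook's `pca` object.

```python
import numpy as np
from sklearn.decomposition import PCA

# synthetic correlated data, not the project data
rng = np.random.default_rng(4)
X_demo = rng.normal(size=(100, 6)) @ rng.normal(size=(6, 6))

pca_demo = PCA(n_components=3).fit(X_demo)
scores = pca_demo.transform(X_demo)
components = pca_demo.components_.copy()

# flip the sign of PC1 in both the loadings and the scores
flipped_components = components.copy()
flipped_scores = scores.copy()
flipped_components[0] *= -1
flipped_scores[:, 0] *= -1

recon = scores @ components + pca_demo.mean_
recon_flipped = flipped_scores @ flipped_components + pca_demo.mean_
print(np.allclose(recon, recon_flipped))
```

This is why "all loadings positive" and "all loadings negative" describe the same gene module, whereas a mix of signs within one PC indicates two opposing gene sets.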

## 5.6 regression models

Predict e-phys features from transcriptomic data:

* Use linear regression, random forests, and elastic net on the 50 genes (log CPM) most strongly correlated with each target.
* Evaluate using cross-validation (R^2, RMSE).

In [None]:
# features to predict
ephys_targets = [
    "AP threshold (mV)",
    "Rheobase (pA)",
    "Membrane time constant (ms)",
]

models = {
    "Linear Regression": LinearRegression(),
    "Random Forest": RandomForestRegressor(n_estimators=100, random_state=42),
    "Elastic Net": ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=42),
}

results = []

# subplot grid: 3 ephys targets × 3 models
fig, axes = plt.subplots(len(ephys_targets), len(models), figsize=(18, 12))
axes = axes.reshape(len(ephys_targets), len(models))

for i, target in enumerate(ephys_targets):
    # filter out genes with zero std dev to avoid divide-by-zero in correlation
    variable_genes = log_cpm_df.loc[:, log_cpm_df.std() > 0]

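    # select top 50 genes correlated with target
    # (note: the selection uses all cells, including those later held out,
    # so the scores below are likely optimistic)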
    corrs = variable_genes.corrwith(ephys_df[target]).abs()
    top_genes = corrs.sort_values(ascending=False).head(50).index

    X = variable_genes[top_genes]
    y = ephys_df[target]

    # remove missing values
    valid_idx = X.dropna().index.intersection(y.dropna().index)
    X = X.loc[valid_idx]
    y = y.loc[valid_idx]

    # train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    for j, (name, model) in enumerate(models.items()):
        pipe = make_pipeline(StandardScaler(), model)

        # cross-validation metrics
        scores = cross_validate(
            pipe, X, y, cv=5, scoring=["r2", "neg_root_mean_squared_error"]
        )
        results.append(
            {
                "Model": name,
                "Ephys Feature": target,
                "Mean R2": scores["test_r2"].mean(),
                "Mean RMSE": -scores[
                    "test_neg_root_mean_squared_error"
                ].mean(),
            }
        )

        # fit & predict on held-out test set
        pipe.fit(X_train, y_train)
        y_pred = pipe.predict(X_test)

        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))

        ax = axes[i, j]
        sns.scatterplot(x=y_test, y=y_pred, ax=ax, alpha=0.7)
        ax.plot(
            [y_test.min(), y_test.max()], [y_test.min(), y_test.max()], "r--"
        )
        ax.set_title(f"{target}\n{name}\nR² = {r2:.2f}, RMSE = {rmse:.2f}")
        ax.set_xlabel("True")
        ax.set_ylabel("Predicted")

plt.tight_layout()
plt.suptitle(
    "Predicted vs. True Ephys Features (All Models)", fontsize=18, y=1.03
)
plt.show()

# summary of results
results_df = pd.DataFrame(results)
print(results_df)

The results show very poor model performance, i.e. weak predictive power. Either transcriptomic data alone is insufficient to predict these ephys traits, or the input features need substantial refinement, which we could not do in the available time.
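One refinement that could be tried is to move the gene selection into the cross-validated pipeline, so that genes are chosen on the training folds only. The sketch below contrasts both variants on pure noise, where any positive score must come from information leakage. `SelectKBest` with `f_regression` is swapped in here as a stand-in for the correlation-based selection above; all data and parameter values are synthetic assumptions.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# pure noise: there is nothing to predict
rng = np.random.default_rng(5)
X_noise = rng.normal(size=(200, 1000))
y_noise = rng.normal(size=200)

# variant 1: select features on ALL samples, then cross-validate (leaky)
selector = SelectKBest(f_regression, k=20).fit(X_noise, y_noise)
X_leaky = selector.transform(X_noise)
leaky_r2 = cross_val_score(
    make_pipeline(StandardScaler(), LinearRegression()),
    X_leaky,
    y_noise,
    cv=5,
    scoring="r2",
).mean()

# variant 2: selection happens inside each training fold
honest_pipe = make_pipeline(
    StandardScaler(), SelectKBest(f_regression, k=20), LinearRegression()
)
honest_r2 = cross_val_score(
    honest_pipe, X_noise, y_noise, cv=5, scoring="r2"
).mean()

print(f"leaky R2: {leaky_r2:.2f}, honest R2: {honest_r2:.2f}")
```

The in-pipeline variant gives the more trustworthy estimate; with real data it may lower the reported scores further, but it rules out selection bias as an explanation.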

## 5.7 Biological Interpretability

### 5.7.1 Identify Genes Most Correlated with Electrophysiological Properties

=> Correlate all genes with Rheobase (pA)

In [None]:
# filter out non-variable genes
variable_genes = log_cpm_df.loc[:, log_cpm_df.std() > 0]

correlations = variable_genes.corrwith(ephys_df["Rheobase (pA)"])

# drop any NaNs that may still result from missing values
correlations = correlations.dropna()

top_genes_pos = correlations.sort_values(ascending=False).head(50)
top_genes_neg = correlations.sort_values(ascending=True).head(50)

# combine for enrichment analysis
top_genes = pd.concat([top_genes_pos, top_genes_neg]).index.tolist()

### 5.7.2 Perform Gene Ontology (GO) or Pathway Enrichment

In [None]:
# requires the gprofiler-official package
from gprofiler import GProfiler

gp = GProfiler(return_dataframe=True)
results = gp.profile(organism="mmusculus", query=top_genes)

# view top GO terms
results[
    ["name", "p_value", "term_size", "query_size", "intersection_size"]
].head()

In [None]:
# sort by p-value + select top 15 terms
top_terms = results.sort_values("p_value").head(15)

sns.barplot(
    y=top_terms["name"],
    x=-np.log10(top_terms["p_value"]),
    hue=top_terms["name"],
    legend=False,
    palette="viridis",
)
plt.xlabel("-log10(p-value)")
plt.title("Top GO Terms Enriched in Ephys-Linked Genes")
plt.tight_layout()
plt.show()

This plot shows the terms most enriched among the 100 genes most strongly correlated with Rheobase (pA), i.e. the 50 most positively and the 50 most negatively correlated genes. The results suggest that genes whose expression levels are strongly correlated with rheobase tend to belong to functional categories critical for neuronal communication. This is consistent with the idea that electrophysiological phenotypes (like firing threshold) are coupled to the transcriptomic programs that define neuronal identity.
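For intuition about what an enrichment p-value means, the sketch below computes a plain hypergeometric over-representation test from the quantities reported in the result table (`term_size`, `query_size`, `intersection_size`). The background size of 20000 genes and the example numbers are assumptions, and g:Profiler additionally applies its own multiple-testing correction, which this sketch does not reproduce.

```python
from scipy.stats import hypergeom


def enrichment_pvalue(n_background, term_size, query_size, intersection_size):
    """P(overlap >= intersection_size) when drawing the query genes at random."""
    # sf(k - 1) gives P(X >= k) for the hypergeometric distribution
    return hypergeom.sf(intersection_size - 1, n_background, term_size, query_size)


# toy numbers: a term with 500 genes, a query of 100 genes
p_small_overlap = enrichment_pvalue(20000, 500, 100, 3)
p_large_overlap = enrichment_pvalue(20000, 500, 100, 15)
print(p_small_overlap, p_large_overlap)
```

With these numbers, about 2.5 overlapping genes are expected by chance, so an overlap of 3 is unremarkable while an overlap of 15 is very unlikely under random sampling.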

# References

- Harris, K. D., Hochgerner, H., Skene, N. G., Magno, L., Katona, L., Bengtsson Gonzales, C., Somogyi, P., Kessaris, N., Linnarsson, S., & Hjerling-Leffler, J. (2018). Classes and continua of hippocampal CA1 inhibitory neurons revealed by single-cell transcriptomics. *PLoS Biology*, 16(6), e2006387. https://doi.org/10.1371/journal.pbio.2006387

- Lause, J., Berens, P., & Kobak, D. (2024). The art of seeing the elephant in the room: 2D embeddings of single-cell data do make sense. *PLoS Computational Biology*, 20(10), e1012403. https://doi.org/10.1371/journal.pcbi.1012403

- Loo, L., Simon, J. M., Xing, L., et al. (2019). Single-cell transcriptomic analysis of mouse neocortical development. *Nature Communications*, 10, 134. https://doi.org/10.1038/s41467-018-08079-9

- Luecken, M. D., & Theis, F. J. (2019). Current best practices in single-cell RNA-seq analysis: a tutorial. *Molecular Systems Biology*, 15(6), e8746. https://doi.org/10.15252/msb.20188746

- Tasic, B., Menon, V., Nguyen, T. N., Kim, T. K., Jarsky, T., Yao, Z., ... & Zeng, H. (2016). Adult mouse cortical cell taxonomy revealed by single cell transcriptomics. *Nature Neuroscience*, 19(2), 335-346. https://doi.org/10.1038/nn.4216

- Tung, P.-Y., Blischak, J., Hsiao, C., et al. (2017). Batch effects and the effective design of single-cell gene expression studies. *Scientific Reports*, 7, 39921. https://doi.org/10.1038/srep39921