_Neural Data Science_

Lecturer: Dr. Jan Lause, Prof. Dr. Philipp Berens

Tutors: Jonas Beck, Fabio Seel, Julius Würzler

Summer term 2025

Student names: Nina Lutz, Mathis Nommensen

LLM Disclaimer: <span style='background: yellow'>*Did you use an LLM to solve this exercise? If yes, which one and where did you use it? [Copilot, Claude, ChatGPT, etc.]* </span>

# Project 3: Single-cell data analysis.

In [None]:
#!pip install memory-profiler #N: i needed to install this. delete code block if not needed anymore

In [None]:
# %matplotlib notebook #N: had to change this
%matplotlib inline

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import string

import scipy as sp
from scipy import sparse
import sklearn

## add your packages ##

import time
import pickle
import memory_profiler
import seaborn as sns
import scipy.stats as stats

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
from sklearn.preprocessing import StandardScaler

%load_ext memory_profiler

from pathlib import Path

In [None]:
import black
import jupyter_black

jupyter_black.load(line_length=79)

In [None]:
variables_path = Path("../results/variables")
figures_path = Path("../results/figures")
data_path = Path("../data")

In [None]:
# plt.style.use("matplotlib_style.txt")
plt.style.use("../matplotlib_style.txt")  # N: had to change this as well

In [None]:
np.random.seed(42)

## Project and data description

In this project, we are going to work with the typical methods and pipelines used in single-cell data analysis and get some hands-on experience with the techniques used in the field. For that, we will be using Patch-seq multimodal data from cortical neurons in mice, from Scala et al. 2021 (https://www.nature.com/articles/s41586-020-2907-3#Sec7). From the different data modalities they used, we will focus on transcriptomics and electrophysiological data. 

In a real-world scenario, single cell data rarely comes with any "ground truth" labels. Often, the goal of researchers after measuring cells is to precisely classify them, grouping them into families or assigning them cell types based on the recorded features. This is normally done using unsupervised methods, such as clustering methods.

However, the single-cell data that we are using in this project has some cell types assigned to each cell. These are not "ground truth" type annotations, but were one of the results from the original Scala et al. work. Still, we are going to use those annotations for validation (despite them not really being ground truth) to sanity-check some of our analyses, such as visualizations, clustering, etc. We will mainly work with cell types (`rna_types`, 77 unique types) and cell families (`rna_families`, 9 unique families).

From the transcriptomics mRNA counts, we will only work with the exon counts for simplicity. Some of the electrophysiological features are not high-quality recordings, therefore we will also filter them out.

## Import data

### Meta data

In [None]:
# META DATA

meta = pd.read_csv(data_path / "m1_patchseq_meta_data.csv", sep="\t")

cells = meta["Cell"].values

layers = meta["Targeted layer"].values.astype("str")
cre = meta["Cre"].values
yields = meta["Yield (pg/µl)"].values
yields[yields == "?"] = np.nan
yields = yields.astype("float")
depth = meta["Soma depth (µm)"].values
depth[depth == "Slice Lost"] = np.nan
depth = depth.astype(float)
thickness = meta["Cortical thickness (µm)"].values
thickness[thickness == 0] = np.nan
thickness = thickness.astype(float)
traced = meta["Traced"].values == "y"
exclude = meta["Exclusion reasons"].values.astype(str)
exclude[exclude == "nan"] = ""

mice_names = meta["Mouse"].values
mice_ages = meta["Mouse age"].values
mice_cres = np.array(
    [
        c if c[-1] != "+" and c[-1] != "-" else c[:-1]
        for c in meta["Cre"].values
    ]
)
mice_ages = dict(zip(mice_names, mice_ages))
mice_cres = dict(zip(mice_names, mice_cres))

print("Number of cells with measured depth:    ", np.sum(~np.isnan(depth)))
print("Number of cells with measured thickness:", np.sum(~np.isnan(thickness)))
print("Number of reconstructed cells:          ", np.sum(traced))

sliceids = meta["Slice"].values
a, b = np.unique(sliceids, return_counts=True)
assert np.all(b <= 2)
print("Number of slices with two cells:        ", np.sum(b == 2))

# Some consistency checks
assert np.all(
    [
        np.unique(meta["Date"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse age"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse gender"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse genotype"].values[mice_names == m]).size == 1
        for m in mice_names
    ]
)
assert np.all(
    [
        np.unique(meta["Mouse"].values[sliceids == s]).size == 1
        for s in sliceids
    ]
)

In [None]:
meta.columns

### "Ground truth labels"

In [None]:
# filter out low quality cells in term of RNA
print(
    "There are",
    np.sum(meta["RNA family"] == "low quality"),
    "cells with low quality RNA recordings.",
)
exclude_low_quality = meta["RNA family"] != "low quality"

In [None]:
rna_family = meta["RNA family"][exclude_low_quality]
rna_type = meta["RNA type"][exclude_low_quality]

In [None]:
print(len(np.unique(rna_family)))
print(len(np.unique(rna_type)))

In [None]:
# use a context manager so the file handle is closed after loading
with open(data_path / "dict_rna_type_colors.pkl", "rb") as pickle_in:
    dict_rna_type_colors = pickle.load(pickle_in)

In [None]:
rna_type_colors = np.vectorize(dict_rna_type_colors.get)(rna_type)

### Transcriptomic data

In [None]:
# READ COUNTS
data_exons = pd.read_csv(
    data_path / "m1_patchseq_exon_counts.csv.gz", na_filter=False, index_col=0
)

assert all(cells == data_exons.columns)
genes = np.array(data_exons.index)

# filter out low quality cells in term of rna family
exonCounts = data_exons.values.transpose()[exclude_low_quality]
print("Count matrix shape (exon):  ", exonCounts.shape)

In [None]:
# GENE LENGTH

data = pd.read_csv(data_path / "gene_lengths.txt")
assert all(data["GeneID"] == genes)
exonLengths = data["exon_bp"].values

### Electrophysiological features

In [None]:
# EPHYS DATA

ephysData = pd.read_csv(data_path / "m1_patchseq_ephys_features.csv")
ephysNames = np.array(ephysData.columns[1:]).astype(str)
ephysCells = ephysData["cell id"].values
ephysData = ephysData.values[:, 1:].astype("float")
names2ephys = dict(zip(ephysCells, ephysData))
ephysData = np.array(
    [
        names2ephys[c] if c in names2ephys else ephysData[0] * np.nan
        for c in cells
    ]
)

print("Number of cells with ephys data:", np.sum(np.isin(cells, ephysCells)))

assert np.sum(~np.isin(ephysCells, cells)) == 0

In [None]:
# Filtering ephys data

features_exclude = [
    "Afterdepolarization (mV)",
    "AP Fano factor",
    "ISI Fano factor",
    "Latency @ +20pA current (ms)",
    "Wildness",
    "Spike frequency adaptation",
    "Sag area (mV*s)",
    "Sag time (s)",
    "Burstiness",
    "AP amplitude average adaptation index",
    "ISI average adaptation index",
    "Rebound number of APs",
]
features_log = [
    "AP coefficient of variation",
    "ISI coefficient of variation",
    "ISI adaptation index",
    "Latency (ms)",
]

X = ephysData[exclude_low_quality]
print(X.shape)
for e in features_log:
    X[:, ephysNames == e] = np.log(X[:, ephysNames == e])
X = X[:, ~np.isin(ephysNames, features_exclude)]

keepcells = ~np.isnan(np.sum(X, axis=1))
X = X[keepcells, :]
print(X.shape)

X = X - X.mean(axis=0)
ephysData_filtered = X / X.std(axis=0)

In [None]:
# delete later
print(ephysNames.shape)
print(ephysData.shape)
print(ephysData_filtered.shape)

In [None]:
np.sum(np.isnan(ephysData_filtered))

# Research questions to investigate

**1) Inspect the data by computing key statistics.** For RNA counts, you can compute and plot statistics, e.g. total counts per cell, number of expressed genes per cell, mean count per gene, variance per gene, mean-variance relationship... See https://www.embopress.org/doi/full/10.15252/msb.20188746 for common quality control statistics. Keep in mind that the RNA data in this project is read counts, not UMI counts, so it is not supposed to follow a Poisson distribution. To get an idea of the technical noise in the data, you can plot count distributions of single genes within cell types (like in the lecture).

Similarly, you can compute and plot statistics over the electrophysiological data. Also, investigate the distribution of "ground truth" labels. Comment about other relevant metadata, and think if you can use it as some external validation for other analyses. If you do use other metadata throughout the project, explain why and what you get out of it. Take into account that certain features may not be very informative for our purposes (e.g. mouse age), so only choose features that provide you with useful information in this context. If you want to get additional information about the metadata, have a look at the extended data section in the original publication (e.g., cre-lines in Figure 1c in the extended data).

**2) Normalize & transform the data; select genes & apply PCA.** There are several ways of normalizing the RNA count data (Raw, CPM, CPMedian, RPKM, see https://www.reneshbedre.com/blog/expression_units.html, https://translational-medicine.biomedcentral.com/articles/10.1186/s12967-021-02936-w). Take into account that there are certain normalizations that only make sense for UMI data, but not for this read count data. You also explored different transformations in the assignment (none, log, sqrt). Compare how the different transformations change the two-dimensional visualization. After normalization and transformation, choose a set of highly variable genes (as demonstrated in the lecture) and apply PCA. Play with the number of selected genes and the number of PCA components, and again compare their effects on the two-dimensional visualization.

**3) Two-dimensional visualization.** To visualize the RNA count data after normalization, transformation, gene selection and PCA, try different methods (just PCA, t-SNE, UMAP, ..) and vary their parameters (exaggeration, perplexity, ..). Compare them using quantitative metrics (e.g., kNN accuracy in high-dim vs. two-dim, kNN recall). Please refer to Lause et al., 2024 (https://doi.org/10.1371/journal.pcbi.1012403) where many of these metrics are discussed and explained to make an informed choice on which metrics to use. Think about also using the electrophysiological features and other metadata to enhance different visualizations.

**4) Clustering.** To find cell types in the RNA count data, you will need to look for clusters. Try different clustering methods (Leiden, GMM). Implement a negative binomial mixture model. For that you can follow a method similar to the one described in Harris et al. 2018 (https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.2006387#abstract0), with fixed r (r=2). Feel free to simplify the setup from the paper and not optimize over the set of important genes S but fix it instead, or skip the split and merge part of their clustering algorithm. A vanilla NBMM implementation should suffice. Take into account that the NBMM tries to cluster data that follows a negative binomial distribution. Therefore, it does not make sense to apply this clustering method to all kinds of normalized and transformed data. Please refer to the Harris et al. 2018 publication for the appropriate choice of normalization, and reflect on why this normalization makes sense. Evaluate your clustering results (metrics, compare number of clusters to original labels,...).

**5) Correlation between electrophysiological features and genes/PCs.** Finally, connect RNA counts and functional data: Most likely, there will be interesting relationships between the transcriptomic and electrophysiological features in this data. Find these correlations and a way of visualizing them. In studying correlations using the PCA-reduced version of the transcriptomics data, it could be interesting to study PC loadings to see which genes are dominating which PCs. For other advanced analyses, you can get inspiration from Kobak et al., 2021 (https://doi.org/10.1111/rssc.12494).
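
As a rough illustration of the normalizations named in question 2, the sketch below computes CPM and RPKM on a toy count matrix. The names `toy_counts`, `toy_lengths_bp` and `toy_depth` are made up for this example; in this notebook the corresponding inputs would presumably be `exonCounts` and `exonLengths`.

```python
import numpy as np

# Toy data (hypothetical): 3 cells x 4 genes, plus exon lengths in base pairs
toy_counts = np.array(
    [
        [10, 0, 30, 60],
        [5, 5, 0, 40],
        [0, 20, 20, 160],
    ],
    dtype=float,
)
toy_lengths_bp = np.array([1000, 500, 2000, 4000], dtype=float)

# Sequencing depth per cell, kept as a column so it broadcasts over genes
toy_depth = toy_counts.sum(axis=1, keepdims=True)

# CPM: counts per million reads in the cell
cpm = toy_counts / toy_depth * 1e6

# RPKM: reads per kilobase of exon per million reads in the cell
rpkm = cpm / (toy_lengths_bp / 1e3)

# Log transform with a pseudocount so that zeros stay defined
log_cpm = np.log2(cpm + 1)

print(np.round(rpkm, 1))
```

CPM only corrects for sequencing depth, while RPKM additionally divides by gene length, which matters for full-length read counts because longer genes collect more reads.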
    

# Task 1

## 1.1 QC Statistics per cell

Total RNA counts, number of expressed genes, and mitochondrial fraction per cell.

In [None]:
exonCounts.shape[0]
exonCounts.shape[1]
exonCounts.shape

In [None]:
# total counts per cell (count depth)

# exonCounts  # shape = 1232 cells x 42466 genes
total_counts_per_cell = exonCounts.sum(axis=1)

# plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(
    total_counts_per_cell, bins=50, color="lightblue", edgecolor="black"
)
axes[0].set_xlabel("Total RNA counts per cell")
axes[0].set_ylabel("Number of cells")
axes[0].set_title("Distribution of Total Counts Per Cell")

axes[1].bar(
    range(len(total_counts_per_cell)),
    total_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[1].set_yscale("log")
axes[1].set_xlabel("Cell index")
axes[1].set_ylabel("Total RNA counts")
axes[1].set_title("Total Counts Per Cell")

# Rank-ordered plot of total counts per cell - Figure 2C in the paper
sorted_counts = np.sort(total_counts_per_cell)[::-1]

axes[2].plot(sorted_counts, color="blue")
axes[2].set_yscale("log")
axes[2].set_xlabel("Ranked Cell Index")
axes[2].set_ylabel("Total Counts (log scale)")
axes[2].set_title("Rank-ordered Total Counts per Cell")

*Interpretation:*

Left:
* The histogram reveals a right-skewed distribution of total RNA counts per cell, with most cells between 0 and 5 million counts.

Middle:
* Most cells have a comparable number of total counts, with a few outliers having considerably higher or lower counts (e.g. cells around index 400 and 500).

Right:
* The rank-ordered curve shows a gradual decay in RNA counts, with a steep drop-off for low-count cells at the tail, which is typical of single-cell data.
* Most cells have comparable total counts.
* The first ~20 cells have far more counts than the rest.
* The last ~100 cells have very few counts.
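
One common way to turn the visual impression of outliers into a rule is a median-absolute-deviation (MAD) cutoff on log total counts. The sketch below uses simulated depths; the cutoff of 5 MADs and the names `flag_outliers_mad` and `sim_total_counts` are assumptions for illustration, not thresholds used in this project.

```python
import numpy as np


def flag_outliers_mad(values, n_mads=5.0):
    """Return a boolean mask of values more than n_mads MADs from the median."""
    values = np.asarray(values, dtype=float)
    median = np.median(values)
    mad = np.median(np.abs(values - median))
    return np.abs(values - median) > n_mads * mad


rng = np.random.default_rng(0)
# Simulated depths: 200 typical cells, two very shallow cells, one very deep cell
sim_total_counts = np.concatenate(
    [rng.lognormal(mean=14.0, sigma=0.3, size=200), [1e3, 2e3, 5e8]]
)

# Work on the log scale because count depth is strongly right-skewed
outlier_mask = flag_outliers_mad(np.log10(sim_total_counts))
print("Flagged cell indices:", np.flatnonzero(outlier_mask))
```

Whether such cells should actually be removed is a separate decision; here the mask only marks candidates for inspection.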

In [None]:
# number of expressed genes per cell
expressed_genes_per_cell = (exonCounts > 0).sum(axis=1)

# plot
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].scatter(range(len(expressed_genes_per_cell)), expressed_genes_per_cell)
axes[0].set_xlabel("Cell Index")
axes[0].set_ylabel("Expressed Genes")
axes[0].set_title("Total Number of Expressed Genes per Cell")

axes[1].hist(expressed_genes_per_cell, bins=50)
axes[1].set_xlabel("Total number of expressed genes per cell")
axes[1].set_ylabel("Number of cells")
axes[1].set_title("Distribution of Expressed Genes Per Cell")

*Interpretation:*

Left: Total number of expressed genes per cell
* The scatter plot shows considerable variation in the number of genes detected per cell, ranging from around 2,000 to over 17,000 genes; most cells have between 5,000 and 11,000 expressed genes.

Right: Distribution of expressed genes per cell
* The histogram shows the same pattern as the left plot, with most cells having between 2,000 and 11,000 expressed genes.

In [None]:
# fraction of mitochondrial genes

# mt_like_genes = [g for g in genes if "mt" in g.lower()]
# print(mt_like_genes)
# after checking the gene names, we assume that mitochondrial genes start with "mt-" or "MT-"
# lower-case the names first so that both prefixes are matched
mt_gene_mask = np.char.startswith(
    np.char.lower(genes.astype(str)), "mt-"
)

print("Number of mitochondrial genes found:", np.sum(mt_gene_mask))

# Sum counts over mitochondrial genes per cell
mt_counts_per_cell = exonCounts[:, mt_gene_mask].sum(axis=1)

# Fraction mitochondrial
fraction_mito = mt_counts_per_cell / total_counts_per_cell

# Print statistics
print("Mean mitochondrial fraction:", np.mean(fraction_mito))
print("Cells with >20% mitochondrial:", np.sum(fraction_mito > 0.2))

# Plot
plt.figure(figsize=(8, 4))
plt.hist(fraction_mito, bins=50, color="salmon", edgecolor="black")
plt.xlabel("Fraction of Mitochondrial Counts per Cell")
plt.ylabel("Number of Cells")
plt.title("Distribution of Mitochondrial RNA Content")
plt.show()

*Interpretation:*

* Most cells have a low fraction of mitochondrial counts (mean ~1.7%), with a few outliers having a higher fraction (e.g. cells around index 400 and 500).

In [None]:
"""could be deleted later"""

# Look for any gene names containing 'mt' or 'MT'
mt_like_genes = [g for g in genes if "mt" in g.lower()]
print(mt_like_genes)

## 1.2 QC Statistics per gene


In [None]:
# mean expression across all cells
mean_expression_across_cells = exonCounts.mean(axis=0)
print("Mean expression across all cells:", mean_expression_across_cells.shape)
# variance across all cells
variance_expression_across_cells = exonCounts.var(axis=0)
print("Variance across all cells:", variance_expression_across_cells.shape)

# plot looks sparse, so log transform the data

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(
    mean_expression_across_cells + 1e-3,  # avoid log(0)
    variance_expression_across_cells + 1e-3,
    s=5,
    alpha=0.5,
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Mean Expression Across Cells (log scale)")
ax.set_ylabel("Variance Across Cells (log scale)")
ax.set_title("Mean vs Variance of Gene Expression (log scale)")
x_vals = np.linspace(
    min(mean_expression_across_cells + 1e-3),
    max(mean_expression_across_cells),
    100,
)
ax.plot(x_vals, x_vals, color="red", linestyle="--", label="y = x")
ax.legend()
plt.show()

*Interpretation:*

* The log-log plot reveals a clear mean-variance relationship, where genes with higher mean expression across cells also tend to have higher variance
* Almost all genes lie above the y=x line (variance = mean), indicating that the variance is generally higher than the mean for most genes (overdispersion).
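
To make the reference line concrete: for Poisson counts the variance equals the mean, whereas a negative binomial with dispersion parameter r has variance μ + μ²/r. The sketch below checks this on simulated counts; the values of `mu` and `r` and all variable names are chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)
n_sim_cells = 20_000
mu = 50.0  # common mean of both simulated genes
r = 2.0  # negative binomial dispersion parameter (assumed for illustration)

# Poisson gene: variance should be close to the mean
poisson_counts = rng.poisson(lam=mu, size=n_sim_cells)

# Negative binomial gene with mean mu: NumPy parametrizes by (n, p) with
# mean n * (1 - p) / p, so p = r / (r + mu) yields the requested mean
nb_counts = rng.negative_binomial(n=r, p=r / (r + mu), size=n_sim_cells)

for name, counts in [("Poisson", poisson_counts), ("NB", nb_counts)]:
    fano = counts.var() / counts.mean()  # Fano factor = variance / mean
    print(
        f"{name:8s} mean={counts.mean():7.2f} "
        f"var={counts.var():9.2f} Fano={fano:6.2f}"
    )

# Theoretical negative binomial variance for comparison
expected_nb_var = mu + mu**2 / r
print("Expected NB variance:", expected_nb_var)
```

Points far above the y = x line in the log-log plot correspond to Fano factors well above 1, as for the negative binomial gene here.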

In [None]:
# dropout rate / fraction of cells where a gene has zero counts

dropout_rate_per_gene = (exonCounts == 0).sum(axis=0) / exonCounts.shape[0]

# Quick summary
print("Mean dropout rate across all genes:", np.mean(dropout_rate_per_gene))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot Histogram of dropout rates
axes[0].hist(dropout_rate_per_gene, bins=50, color="green", edgecolor="black")
axes[0].set_xlabel("Dropout Rate (Fraction of Cells with Zero Counts)")
axes[0].set_ylabel("Number of Genes")
axes[0].set_title("Dropout Rate per Gene")

# Scatter plot of dropout rate vs mean expression
axes[1].scatter(
    mean_expression_across_cells + 1e-3,  # avoid log(0)
    dropout_rate_per_gene,
    s=5,
    alpha=0.5,
)
axes[1].set_xscale("log")
axes[1].set_ylim(0, 1.05)
axes[1].set_xlabel("Mean Expression (log scale)")
axes[1].set_ylabel("Dropout Rate")
axes[1].set_title("Dropout Rate vs. Mean Expression")

plt.show()

*Interpretation:*

Left:
* The majority of genes have a very high dropout rate (mean ~83%): a large number are detected in fewer than 15-20% of cells, and the distribution peaks near 100% dropout, i.e. many genes are not detected in most cells.

Right:
* Strong inverse relationship: Highly expressed genes are detected in nearly all cells (low dropout), while lowly expressed genes are often undetected (high dropout)
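
The inverse relationship can be compared with a simple baseline. Under a Poisson model a gene with mean μ is zero in a fraction exp(-μ) of cells, and under a negative binomial with dispersion r in a fraction (r / (r + μ))^r. A minimal sketch, assuming r = 2 purely for illustration (it is not estimated from the data):

```python
import numpy as np

mean_grid = np.array([0.01, 0.1, 1.0, 10.0, 100.0])
r_assumed = 2.0  # illustrative dispersion, not fitted

# Expected fraction of zero counts per gene under each model
p_zero_poisson = np.exp(-mean_grid)
p_zero_nb = (r_assumed / (r_assumed + mean_grid)) ** r_assumed

for m, p_pois, p_nb in zip(mean_grid, p_zero_poisson, p_zero_nb):
    print(f"mean={m:7.2f}  P(0) Poisson={p_pois:.4f}  P(0) NB={p_nb:.4f}")
```

Overlaying such curves on the dropout scatter would indicate whether the observed zeros exceed what count noise alone predicts.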

**count distributions of single genes within cell types**

In [None]:
# identify cell types
unique_cell_types = np.unique(rna_type)

# pick genes:
# Sst: somatostatin, an inhibitory neuron marker
# Slc17a7: an excitatory neuron marker (VGLUT1)
# Snap25: a pan-neuronal gene
genes_to_plot = ["Sst", "Slc17a7", "Snap25"]
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 8))

for ax, gene_name in zip(axes, genes_to_plot):
    gene_indices = np.where(genes == gene_name)[0]
    gene_index = gene_indices[0]

    # get gene counts
    gene_counts = exonCounts[:, gene_index]

    # group counts by cell type
    data = []
    for ct in unique_cell_types:
        mask = np.array(rna_type) == ct
        counts_in_group = gene_counts[mask]
        data.append(counts_in_group)

    ax.boxplot(
        data, tick_labels=unique_cell_types, showfliers=False
    )  # the showfliers=False option removes extreme outliers
    ax.set_title(f"Expression of {gene_name}")

    ###M: added logscale to y-axis
    ax.set_yscale("log")  # log scale for better visibility

    ax.set_ylabel("Counts")
    ax.set_xlabel("Cell Type")
    ax.tick_params(axis="x", rotation=90, labelsize=8)

# plt.tight_layout()
plt.show()

*Interpretation:*

Sst:
* Expression is highly enriched in specific inhibitory cell types (notably those starting with "Sst"), with occasionally enriched expression in other cell types, but no broader pattern is visible.

Slc17a7:
* Expression is enriched in many cells, but most obviously in cell types from L2 to L6, which belong to a family of excitatory neurons. This suggests that Slc17a7 is a marker for excitatory neurons in the cortex, as it is known to be expressed in glutamatergic neurons.

Snap25:
* This gene shows widespread expression across nearly all cell types, reflecting its function as a general neuronal marker involved in synaptic transmission
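
Zero counts cannot be drawn on a logarithmic axis, so a per-type detection rate can complement the boxplots. The sketch below computes it on a toy example; `toy_gene_counts` and `toy_types` are hypothetical stand-ins for one column of `exonCounts` and for `rna_type`.

```python
import numpy as np

# Hypothetical toy data: counts of one marker gene in eight cells
toy_gene_counts = np.array([120, 0, 300, 0, 0, 2, 80, 0])
toy_types = np.array(["Sst", "Sst", "Sst", "IT", "IT", "IT", "Sst", "Vip"])

detection_rate = {}
for cell_type in np.unique(toy_types):
    in_type = toy_types == cell_type
    # Fraction of cells of this type with at least one count
    rate = float(np.mean(toy_gene_counts[in_type] > 0))
    detection_rate[str(cell_type)] = rate
    print(f"{str(cell_type):4s} detected in {rate:.0%} of cells")
```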

## 1.3 Statistics for electrophysiological features

In [None]:
# features overview

# List of cleaned electrophysiological features retained for analysis
ephysNames_filtered = ephysNames[
    ~np.isin(ephysNames, features_exclude)
]  # see above

print(
    "Remaining electrophysiological features (n = {}) for analysis".format(
        len(ephysNames_filtered)
    )
)
for i, name in enumerate(ephysNames_filtered, 1):
    print(f"{i:2d}. {name}")

In [None]:
# descriptive statistics of electrophysiological features (keep in mind that data is already standardized)

# dictionary to collect stats
stats_dict = {
    "Mean": [],
    "Std": [],
    "Min": [],
    "Max": [],
    "Median": [],
    "Skewness": [],
}

# Compute stats per feature
### replace all Xs with ephysData_filtered
# for i in range(X.shape[1]):
for i in range(ephysData_filtered.shape[1]):
    # data = X[:, i]
    data = ephysData_filtered[:, i]

    # Collect statistics
    stats_dict["Mean"].append(np.mean(data))
    stats_dict["Std"].append(np.std(data))
    stats_dict["Min"].append(np.min(data))
    stats_dict["Max"].append(np.max(data))
    stats_dict["Median"].append(np.median(data))
    stats_dict["Skewness"].append(stats.skew(data))

# Convert to DataFrame
feature_stats_df = pd.DataFrame(stats_dict, index=ephysNames_filtered)

print("Basic statistics of electrophysiological features (standardized):")
display(feature_stats_df)

*Interpretation:*

Overall:
* mean and standard deviation are (roughly) 0 and 1, respectively, as the data is standardized
* Min/Max values show range in z-score units

Notable observations:
* ISI adaptation index (skew = +2.48), Rebound (mV) (+2.05), Sag ratio (+2.06), and Rheobase (pA) (+1.64) are strongly right-skewed, indicating that most cells have low values while a smaller subset shows much higher values.
* AP amplitude adaptation index (skew = −1.13) is left-skewed, indicating that most cells have high values (i.e. high adaptation) while a smaller subset shows much lower values.
* Membrane time constant, Input resistance, and Max number of APs have high max values (>3.8 z-score), indicating large variability among cells in how they respond to current input.
* Several features show approximately symmetric distributions, like AP threshold (skew = −0.06), Afterhyperpolarization (mV) (skew = +0.03) and Resting membrane potential (skew = −0.12). This suggests that these properties are well-centered and consistent across cells.
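
How a log transform, as applied to the features in `features_log`, affects skewness can be illustrated with simulated data. This is a sketch on a synthetic log-normal feature (`sim_feature` is hypothetical), not on the ephys recordings:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(2)
# Synthetic right-skewed, strictly positive feature
sim_feature = rng.lognormal(mean=0.0, sigma=0.8, size=5_000)

skew_raw = stats.skew(sim_feature)
skew_log = stats.skew(np.log(sim_feature))

print(f"Skewness before log transform: {skew_raw:+.2f}")
print(f"Skewness after log transform:  {skew_log:+.2f}")
```

Features that remain strongly right-skewed after standardization might be candidates for the same treatment, provided they are strictly positive.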


In [None]:
n_features = len(ephysNames_filtered)
n_cols = 4  # Number of columns in the grid
n_rows = int(
    np.ceil(n_features / n_cols)
)  # Number of rows in the grid, rounded up

fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3))

axes = axes.flatten()

for i, feature in enumerate(ephysNames_filtered):
    ax = axes[i]
    sns.histplot(
        # X[:, i],
        ephysData_filtered[:, i],
        bins=30,
        kde=True,
        color="skyblue",
        stat="density",
        ax=ax,
    )
    sns.kdeplot(
        # X[:, i], color="darkblue", linewidth=1, ax=ax
        ephysData_filtered[:, i],
        color="darkblue",
        linewidth=1,
        ax=ax,
    )  # smoothed density plot (darkblue line)

    ax.axvline(0, color="red", linestyle="--", linewidth=1)
    ax.set_title(feature, fontsize=9, fontweight="bold")
    ax.set_xlabel("Standardized value")
    ax.set_ylabel("Density")

# Remove empty subplots
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])

fig.suptitle(
    "Distributions of Standardized Electrophysiological Features",
    fontsize=14,
    fontweight="bold",
)
plt.tight_layout(rect=[0, 0, 1, 0.97])

plt.show()

*Interpretation:*
* means (0) and standard deviations (1) are consistent with standardized data
* skewness is clearly visible
* generally, the distributions show clearly what the statistics suggest: all features are centered around zero with unit variance, and the observed skewness in the plots aligns with the computed skewness values, highlighting the diversity in distribution shapes.

In [None]:
# DataFrame with standardized ephys data
# ephys_df = pd.DataFrame(X, columns=ephysNames_filtered)
ephys_df = pd.DataFrame(ephysData_filtered, columns=ephysNames_filtered)

# Compute Pearson correlation matrix
corr_matrix = ephys_df.corr()

# Test
print("Correlation matrix shape:", corr_matrix.shape)

# "high correlation" threshold
threshold = 0.6  # whats the best value?
high_corr_pairs = []

# Iterate over the upper triangle of the correlation matrix (excluding the diagonal) because the matrix is symmetric
for i in range(len(corr_matrix.columns)):
    for j in range(i + 1, len(corr_matrix.columns)):
        corr_value = corr_matrix.iloc[i, j]
        if abs(corr_value) > threshold:
            feature_i = corr_matrix.columns[i]
            feature_j = corr_matrix.columns[j]
            high_corr_pairs.append((feature_i, feature_j, corr_value))

# Print results
print(
    "Feature pairings with high correlations (|corr| > {:.2f}):".format(
        threshold
    )
)
for feature1, feature2, corr in sorted(
    high_corr_pairs, key=lambda x: -abs(x[2])
):
    print(f"{feature1:30s} x {feature2:30s}: {corr:+.2f}")

In [None]:
# Plotting the correlation matrix
plt.figure(figsize=(12, 10))
sns.heatmap(
    # sns.clustermap(  # clustermap clusters features based on correlation; can be uncommented
    corr_matrix,
    annot=True,
    fmt=".2f",
    cmap="vlag",
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.8},
)
plt.title("Pairwise Pearson Correlation of Electrophysiological Features")
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.show()

*Interpretation:*

Positive correlations:
* AP width (ms) x Upstroke-to-downstroke ratio: +0.94. Suggests these may capture overlapping aspects of action potential shape.
* Afterhyperpolarization (mV) x Upstroke-to-downstroke ratio: +0.67. Both features are shaped by the repolarization phase of the action potential, so the correlation may reflect shared repolarization kinetics.
* AP amplitude (mV) x Afterhyperpolarization (mV): +0.67. Larger APs may lead to deeper afterhyperpolarizations, possibly reflecting cell type-specific properties or biophysical mechanisms.

Negative correlations:
* AP width (ms) x Max number of APs: -0.73. Cells firing more APs tend to have narrower APs, possibly indicating a trade-off between firing rate and action potential width.
* ISI coefficient of variation x Max number of APs: -0.70. More consistent ISIs are found in cells with higher firing rates, hinting at firing pattern specialization.
* Max number of APs x Upstroke-to-downstroke ratio: -0.68. Suggests that cells with more APs tend to have lower upstroke-to-downstroke ratios, possibly indicating differences in firing patterns or cell types.

In general, highly correlated features may be redundant or functionally related; such relationships are examined further in the following tasks.
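
The pair search in the code cell above loops over the upper triangle of the correlation matrix; the same selection can be written with `np.triu_indices_from`. A sketch on a small synthetic matrix (the feature names `f0` to `f3` are placeholders):

```python
import numpy as np

rng = np.random.default_rng(3)
# Synthetic data: f1 is constructed to correlate strongly with f0
toy_features = rng.normal(size=(500, 4))
toy_features[:, 1] = toy_features[:, 0] + 0.3 * rng.normal(size=500)
toy_names = np.array(["f0", "f1", "f2", "f3"])

toy_corr = np.corrcoef(toy_features, rowvar=False)

# Indices of the strict upper triangle (k=1 skips the diagonal)
rows, cols = np.triu_indices_from(toy_corr, k=1)
pair_corr = toy_corr[rows, cols]

toy_threshold = 0.6
keep = np.abs(pair_corr) > toy_threshold
toy_pairs = sorted(
    zip(toy_names[rows[keep]], toy_names[cols[keep]], pair_corr[keep]),
    key=lambda t: -abs(t[2]),
)
for name_a, name_b, value in toy_pairs:
    print(f"{name_a} x {name_b}: {value:+.2f}")
```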


## 1.4 Investigate Ground Truth Labels

In [None]:
# Analyze and visualize ground truth distribution
# ground truth data:
# - rna_type: 77 transcriptomic subtypes
# - rna_family: 9 broader transcriptomic families

# Barplot of RNA types
rna_type_counts = pd.Series(rna_type).value_counts()

plt.figure(figsize=(12, 5))
sns.barplot(
    x=rna_type_counts.index,
    y=rna_type_counts.values,
    hue=rna_type_counts.index,
    palette="muted",
    dodge=False,
    legend=False,
)
plt.xticks(rotation=90)
plt.title("RNA Types by Cell Count")
plt.ylabel("Number of Cells")
plt.xlabel("RNA Type")
# plt.tight_layout()
plt.show()

# Barplot of RNA families
plt.figure(figsize=(8, 4))
sns.countplot(
    x=rna_family,
    hue=rna_family,
    order=np.unique(rna_family),
    palette="Set2",
    legend=False,
)
plt.title("Distribution of RNA Families")
plt.ylabel("Number of Cells")
plt.xlabel("RNA Family")
plt.xticks(rotation=45)
# plt.tight_layout()
plt.show()

*Interpretation:*

Top: RNA types by cell count
* barplot shows the number of cells annotated with each RNA type
* only a few RNA types are strongly represented (e.g. Pvalb Itrap2, L2/3 IT_2)
* many RNA types have fewer than 10 cells, suggesting class imbalance --> could impact clustering and classification methods 

Bottom: RNA Families by cell count
* the plot aggregates rna types into broader families, which are essentially coarser labels
* Families like IT, Pvalb and Sst dominate, whereas others like NP and Sncg are strongly underrepresented
* Compared to the top plot, the family distribution is more balanced, with fewer families having very few cells, yet the class imbalance is still present
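
Class imbalance can also be summarized numerically, for example by listing rare classes and by the normalized Shannon entropy of the label distribution (1 means perfectly balanced). A sketch with made-up labels; the cutoff of 10 cells mirrors the observation above, and the label names are hypothetical:

```python
import numpy as np

# Hypothetical labels: two large classes and three rare ones
toy_labels = np.array(
    ["A"] * 120 + ["B"] * 80 + ["C"] * 6 + ["D"] * 3 + ["E"] * 1
)

classes, class_counts = np.unique(toy_labels, return_counts=True)
proportions = class_counts / class_counts.sum()

# Classes with fewer than 10 cells
rare_classes = classes[class_counts < 10]

# Normalized Shannon entropy: 1.0 for a uniform distribution, lower if imbalanced
entropy = -np.sum(proportions * np.log(proportions))
normalized_entropy = entropy / np.log(len(classes))

print("Rare classes:", list(rare_classes))
print(f"Normalized entropy: {normalized_entropy:.2f}")
```

Applied to `rna_type` and `rna_family`, the same two numbers would allow a direct comparison of how balanced the two label sets are.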


## 1.5 Metadata

In [None]:
# print metadata labels
print("Metadata labels:")
for col in meta.columns:
    print(f"- {col}")

**TODO: we should revise this section again and phrase it more precisely**


Generally, we want to focus on metadata that improves analysis quality. Not all recorded metadata is useful for our purposes, so we will focus on the following features:
* sequencing batch - technical batch effects can confound biological signals, so we will use this to control for batch effects in our analyses
* targeted layer
* inferred layer
* soma depth
* cortical thickness
* cre lines

Specifically interesting:

Cre Line:
* Cre lines indicate genetic targeting strategies, which can enrich for specific neuron types
* They help validate whether observed patterns are biologically driven or due to experimental targeting

Inferred layer:
* Cortical layers correspond to distinct cell populations and functions; many electrophysiological and transcriptomic differences are layer-specific.
* It could serve as a meaningful covariate or stratification variable

Sequencing batch:
* Sequencing batch can introduce technical variability (batch effects) unrelated to biology
* Identifying batch effects can help to avoid spurious conclusions and ensure generalizability
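
One way to check whether batch and biology are confounded is a contingency table of batch against cell family, followed by a chi-squared test of independence. The sketch below runs on simulated labels (batch names, family probabilities and all variable names are made up); with the real data the inputs would presumably be `meta["Sequencing batch"][exclude_low_quality]` and `rna_family`.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency

rng = np.random.default_rng(4)
n_sim = 600

# Simulated batches; family composition depends on the batch by construction
sim_batch = rng.choice(["b1", "b2", "b3"], size=n_sim)
family_probs = {
    "b1": [0.7, 0.2, 0.1],
    "b2": [0.2, 0.6, 0.2],
    "b3": [0.1, 0.2, 0.7],
}
sim_family = np.array(
    [rng.choice(["IT", "Pvalb", "Sst"], p=family_probs[b]) for b in sim_batch]
)

# Rows: batches, columns: families, entries: number of cells
batch_table = pd.crosstab(sim_batch, sim_family)
chi2, p_value, dof, expected = chi2_contingency(batch_table.values)

print(batch_table)
print(f"chi2 = {chi2:.1f}, dof = {dof}, p = {p_value:.2e}")
```

With very small batches the expected counts in some cells of the table become tiny, in which case the chi-squared approximation is unreliable and the table itself is more informative than the p-value.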



In [None]:
# Distribution of Cre lines
cre_counts = meta["Cre"][exclude_low_quality].value_counts()
print("Cre line distribution:\n", cre_counts)

# Plot
plt.figure(figsize=(10, 4))
sns.countplot(
    y=meta["Cre"][exclude_low_quality],
    order=cre_counts.index,
    palette="muted",
    hue=meta["Cre"][exclude_low_quality],
    legend=False,
)
plt.title("Distribution of Cre Lines")
plt.xlabel("Cell Count")
plt.ylabel("Cre Line")
# plt.tight_layout()
plt.show()

# Cross-tab with RNA type
# cre_rna_crosstab = pd.crosstab(meta["Cre"][exclude_low_quality], rna_type)
# print("Cre Line vs RNA Type Crosstab:")
# display(cre_rna_crosstab)

In [None]:
# Distribution of cortical layers
layer_counts = meta["Inferred layer"][exclude_low_quality].value_counts()
print("Inferred layer distribution:\n", layer_counts)

# Plot
plt.figure(figsize=(8, 4))
sns.countplot(
    x=meta["Inferred layer"][exclude_low_quality],
    order=layer_counts.index,
    palette="coolwarm",
    hue=meta["Inferred layer"][exclude_low_quality],
    legend=False,
)
plt.title("Distribution of Inferred Cortical Layers")
plt.xlabel("Inferred Layer")
plt.ylabel("Number of Cells")
# plt.tight_layout()
plt.show()

In [None]:
# Distribution of batches
batch_counts = meta["Sequencing batch"][exclude_low_quality].value_counts()
plt.figure(figsize=(8, 4))
sns.barplot(
    x=batch_counts.index.astype(str),
    y=batch_counts.values,
    palette="coolwarm",
    hue=batch_counts.index.astype(str),
    legend=False,
)
plt.title("Distribution of Sequencing Batches")
plt.xlabel("Sequencing Batch (ordered by cell count)")
plt.ylabel("Cell Count")
# plt.tight_layout()
plt.show()

#### Metadata Conclusions

Cre Lines:
* The distribution shows several dominant lines (e.g. SST+, PV+), with many others represented sparsely.
* A large number of unique Cre lines suggests some are too underrepresented for individual modeling but may help explain biological variability when grouped or used selectively.

Inferred layer:
* The distribution is uneven, with most cells in layer 5, followed by 2/3 and 6, and very few in layer 1.
* This may impact how well certain layers are represented in modeling, and highlights that layer-specific trends could be explored in later tasks.

Sequencing batch:
* The barplot revealed an uneven distribution across batches, ranging from very large (e.g. batch 8) to tiny (batch 12).

**How could they be used in further analyses?**

Cre Lines:
* use after clustering --> compare predicted clusters to Cre line distributions
* helps biologically validate clusters, e.g. a cluster dominated by SST+ Cre lines might represent inhibitory neurons (task 4)
* can examine how cre lines associate with PCs or gene expression signatures (task 5)

Inferred layer:
* use inferred layer as a coloring option to see if spatial cell identity is reflected in transcriptomic space (task 3)
* could be added as a covariate or group in regression or correlation analyses to see how gene-PC correlations differ across layers (task 5)

Sequencing batch:
* use batch information to check for batch effects after normalization and transformation 
* could color 2D visualizations (e.g. PCA/t-SNE) by batch to visually inspect clustering by batch, which is undesired and indicates technical confounding (task 2)
* use batch labels as negative control labels: ideally batches should not be separable in a meaningful embedding (task 3)
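
The negative-control idea for sequencing batches can be sketched on synthetic data. This is a minimal illustration, not the embedding from this notebook: `embedding`, `cell_type` and `batch` are hypothetical stand-ins. The silhouette score is computed once with the biological labels and once with the batch labels. A batch score close to zero suggests that batches are not separable in the embedding.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)

# Synthetic 2D "embedding" with three well-separated groups
# (hypothetical stand-in for cell types)
embedding, cell_type = make_blobs(
    n_samples=600,
    centers=[[0, 0], [10, 0], [0, 10]],
    cluster_std=1.0,
    random_state=0,
)

# Hypothetical batch labels, assigned independently of the cell type
batch = rng.integers(0, 4, size=embedding.shape[0])

sil_type = silhouette_score(embedding, cell_type)
sil_batch = silhouette_score(embedding, batch)

print(f"silhouette by cell type: {sil_type:.2f}")
print(f"silhouette by batch:     {sil_batch:.2f}")
```

On real data the same two calls could be applied to a PCA, t-SNE or UMAP embedding together with `batch_filtered`. A clearly positive batch score would then hint at technical confounding.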





# Task 2

In [None]:
# preparation steps
# we want to use the metadata alongside the expression data,
# so the metadata must be restricted to the cells we are interested in
# (i.e. the low-quality cells are excluded)
# this filtering was already applied to exonCounts, rna_type and rna_family
# to avoid misalignment, we reset the indices so that the metadata indices match the rows of exonCounts

# RNA counts
exonCounts_filtered = exonCounts
# this alias is not strictly needed because exonCounts is already filtered, but it keeps the naming consistent with the other variables

# Ground truth labels
rna_type_filtered = rna_type.reset_index(drop=True)
rna_family_filtered = rna_family.reset_index(drop=True)

# Filter Metadata variables of interest
cre_filtered = meta["Cre"][exclude_low_quality].reset_index(drop=True)
layer_filtered = meta["Inferred layer"][exclude_low_quality].reset_index(
    drop=True
)
batch_filtered = meta["Sequencing batch"][exclude_low_quality].reset_index(
    drop=True
)

assert (
    exonCounts_filtered.shape[0]
    == len(rna_type_filtered)
    == len(cre_filtered)
    == len(layer_filtered)
    == len(batch_filtered)
)

## 2.1 normalization + transformation

log-CPM normalization - appropriately tailored for read count data
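
As a small worked example of this normalization: with $x_{ij}$ the count of gene $j$ in cell $i$, log-CPM is $\log\left(1 + 10^6 \cdot x_{ij} / \sum_k x_{ik}\right)$. The sketch below applies it to a made-up toy matrix (the names `counts_toy`, `cpm_toy` and `log_cpm_toy` are illustrative only) and checks that every non-empty cell sums to one million after the CPM step.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 4 genes (columns); the values are made up
counts_toy = np.array(
    [
        [10, 0, 30, 60],
        [1, 1, 1, 1],
        [0, 0, 0, 0],  # empty cell, guarded against division by zero below
    ],
    dtype=float,
)

total_per_cell = counts_toy.sum(axis=1)
total_per_cell[total_per_cell == 0] = 1  # same guard as used in this notebook

# Broadcasting over rows: every cell is divided by its own total
cpm_toy = counts_toy / total_per_cell[:, None] * 1e6
log_cpm_toy = np.log1p(cpm_toy)

print(cpm_toy.sum(axis=1))
```

Indexing the totals with `[:, None]` is an alternative to the double transpose and makes the row-wise division explicit.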

In [None]:
# exonCounts.shape = (n_cells, n_genes)
# genes = list of gene names
# rna_type_filtered = transcriptomic type label of each cell

# compute total amount of reads per cell
total_counts_per_cell = exonCounts_filtered.sum(axis=1)

# avoid dividing by zero by setting any 0s to 1
total_counts_per_cell[total_counts_per_cell == 0] = 1

# NORMALIZATION
# calculate counts per million
cpm = (
    exonCounts_filtered.T / total_counts_per_cell
).T * 1e6  # transpose so that each cell (row) is divided by its own total

# TRANSFORMATION
# log-transform with log(1+x)
log_cpm = np.log1p(cpm)
# sqrt-transform with sqrt(x)
sqrt_cpm = np.sqrt(cpm)

# create dataframe
cpm_df = pd.DataFrame(cpm, index=rna_type_filtered, columns=genes)
log_cpm_df = pd.DataFrame(log_cpm, index=rna_type_filtered, columns=genes)
sqrt_cpm_df = pd.DataFrame(sqrt_cpm, index=rna_type_filtered, columns=genes)

print("Log-CPM shape:", log_cpm.shape)
print("Sqrt-CPM shape:", sqrt_cpm.shape)
assert log_cpm_df.shape[0] == len(rna_type_filtered), "check for alignment"
assert sqrt_cpm_df.shape[0] == len(rna_type_filtered), "check for alignment"

In [None]:
# show how normalization changes total RNA counts
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].bar(
    range(len(total_counts_per_cell)),
    total_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[0].set_yscale("log")
axes[0].set_xlabel("Cell index")
axes[0].set_ylabel("Raw RNA counts")
axes[0].set_title("Total RNA Counts Per Cell")

# cpm.shape = (n_cells, n_genes)
total_cpm_counts_per_cell = cpm.sum(axis=1)


axes[1].bar(
    range(len(total_cpm_counts_per_cell)),
    total_cpm_counts_per_cell,
    width=1,
    alpha=0.5,
    color="blue",
)
axes[1].set_yscale("log")
axes[1].set_xlabel("Cell index")
axes[1].set_ylabel("Normalized RNA counts")
axes[1].set_title("Total Normalized RNA Counts Per Cell")

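=> the raw totals vary from cell to cell, whereas after CPM normalization every cell with nonzero counts sums to $10^6$, so the right panel is flat. The log transform is applied afterwards and is not shown in this figure.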

In [None]:
# show change after log transform

fig, axes = plt.subplots(1, 3, figsize=(10, 4))

axes[0].scatter(
    mean_expression_across_cells,
    variance_expression_across_cells,
    s=5,
    alpha=0.5,
)
axes[0].set_xscale("log")
axes[0].set_yscale("log")
axes[0].set_xlabel("Mean Expression Across Cells (log scale)")
axes[0].set_ylabel("Variance Across Cells (log scale)")
axes[0].set_title("Mean vs Variance of Gene Expression")


mean_log_expression = log_cpm.mean(axis=0)
var_log_expression = log_cpm.var(axis=0)


axes[1].scatter(
    mean_log_expression,
    var_log_expression,
    s=5,
    alpha=0.5,
)
axes[1].set_xscale("log")
axes[1].set_yscale("log")
axes[1].set_xlabel("Mean Expression Across Cells (log scale)")
axes[1].set_ylabel("Variance Across Cells (log scale)")
axes[1].set_title("Mean vs Variance of Gene Expression\nAfter Log-Transform")

# show change after sqrt transform
mean_sqrt_expression = sqrt_cpm.mean(axis=0)
var_sqrt_expression = sqrt_cpm.var(axis=0)


axes[2].scatter(
    mean_sqrt_expression,
    var_sqrt_expression,
    s=5,
    alpha=0.5,
)
axes[2].set_xscale("log")
axes[2].set_yscale("log")
axes[2].set_xlabel("Mean Expression Across Cells (log scale)")
axes[2].set_ylabel("Variance Across Cells (log scale)")
axes[2].set_title("Mean vs Variance of Gene Expression\nAfter Sqrt-Transform")

# add identity lines
x_vals = np.linspace(1e-3, 1e4, 100)  # log scale range
axes[0].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)
x_vals = np.linspace(1e-4, 1e1, 100)  # log scale range
axes[1].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)
x_vals = np.linspace(1e-4, 1e2, 100)  # log scale range
axes[2].plot(
    x_vals, x_vals, linestyle="--", color="red", label="variance = mean"
)

axes[0].legend()
axes[1].legend()
axes[2].legend()

*Interpretation:*

Left: Raw CPM (before transformation)
* variance increases rapidly with mean, indicating overdispersion
* many genes show much higher variance than mean, deviating strongly from the Poisson expectation variance = mean. This is expected for single-cell expression data, which combine sampling noise with biological variability and therefore need not follow a Poisson distribution.
* Strong heteroscedasticity (variance depends on the magnitude of the mean). PCA is driven by variance, so without a variance-stabilizing transform a few highly expressed genes would dominate the leading components.

Middle: after log transformation
* cloud of points is more compact, with less spread in variance for high means
* points are now more evenly distributed around the identity line, indicating that the log transformation has stabilized the variance and made the data more homoscedastic (it compresses extreme values and reduces skew)
* log transformed data looks now more suitable for PCA

Right: after sqrt transformation
* despite not changing the overall shape much, sqrt transform has also stabilized variance and made the data more homoscedastic
* sqrt transformed data also looks more suitable for PCA
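
The variance-stabilizing effect can be illustrated in an idealized setting. The sketch below assumes purely Poisson-distributed counts, which is a simplification (the real data are overdispersed, so neither transform stabilizes the variance perfectly). For Poisson data the variance of the raw counts grows with the mean, while the variance after the sqrt transform stays close to 1/4.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells = 100_000

# Variance of simulated Poisson counts before and after each transform
stabilized = {}
for lam in (10, 100, 1000):
    x = rng.poisson(lam, size=n_cells)
    stabilized[lam] = {
        "raw": x.var(),
        "sqrt": np.sqrt(x).var(),
        "log1p": np.log1p(x).var(),
    }
    print(
        f"mean={lam:5d}  var(raw)={stabilized[lam]['raw']:9.1f}  "
        f"var(sqrt)={stabilized[lam]['sqrt']:.3f}  "
        f"var(log1p)={stabilized[lam]['log1p']:.4f}"
    )
```

In this idealized case the log transform even shrinks the variance of highly expressed genes. With overdispersed data the picture is less clear-cut, which is why both transforms are compared empirically in this notebook.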

## 2.2 select genes

In [None]:
# mask out genes with very low mean expression (mean <= 0.1)
mask_raw = mean_expression_across_cells > 0.1

# compute fano factor for raw counts or cpm
fano = (
    variance_expression_across_cells[mask_raw]
    / mean_expression_across_cells[mask_raw]
)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# plot raw data fano
axes[0].scatter(
    mean_expression_across_cells[mask_raw],
    fano,
    s=5,
    alpha=0.5,
)
axes[0].set_xscale("log")
axes[0].set_yscale("log")
axes[0].set_xlabel("Mean Expression Across Cells")
axes[0].set_ylabel("Fano Factor (Var / Mean)")
axes[0].set_title("Fano Factor")

##### x-scale seems wrong???

# plot variance vs mean
mask_log = mean_log_expression > 0
axes[1].scatter(
    mean_log_expression[mask_log],
    var_log_expression[mask_log],
    s=5,
    alpha=0.5,
)
axes[1].set_xscale("log")
axes[1].set_yscale("log")
axes[1].set_xlabel("Mean (log CPM)")
axes[1].set_ylabel("Variance (log CPM)")
axes[1].set_title("Variance vs Mean After Log Transform")

plt.tight_layout()
plt.show()

*Interpretation:*

Left:
* 

In [None]:
# choose highly variable genes (hvgs) from normalized data
# we DO NOT use log-transformed data here, since that would flatten the mean-variance relationship
# (and we want to find the genes with high absolute variance)
dispersion = variance_expression_across_cells / (
    mean_expression_across_cells + 1e-8
)

# remove genes with very low mean (avoid divide by zero)
filtered_dispersion = dispersion[mask_raw]
filtered_mean = mean_expression_across_cells[mask_raw]

# play with the number of selected genes:
# n_top_genes = 2000
# n_top_genes = 1000
n_top_genes = 500

# 2000 HVGs is the most common default, but even 1000 leads to a very high necessary number of principal components in PCA later on.
# Fewer than 500 HVGs risks missing subtle biological signals, so we chose 500 as the lowest "possible" number.

# top_indices = np.argsort(filtered_dispersion)[-n_top_genes:]
top_500_indices = np.argsort(filtered_dispersion)[-500:]
top_1000_indices = np.argsort(filtered_dispersion)[-1000:]

# get gene names
# hvg_genes = np.array(genes)[mask_raw][top_indices]
hvg_genes_500 = np.array(genes)[mask_raw][top_500_indices]
hvg_genes_1000 = np.array(genes)[mask_raw][top_1000_indices]

# print(hvg_genes)
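
The selection by dispersion (Fano factor) can be checked on synthetic data where the truly variable genes are known. This is a sketch under stated assumptions: the Poisson rates, the group structure and all names (`toy_counts`, `expressed`, `selected`) are hypothetical. Only the first `n_variable` genes differ between two cell groups, and the top genes by Fano factor should recover exactly those.

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_genes, n_variable = 500, 200, 20

# Hypothetical data: every gene is Poisson with the same mean, except the
# first `n_variable` genes, whose mean differs between two cell groups
group = rng.integers(0, 2, size=n_cells)
rates = np.full((n_cells, n_genes), 5.0)
rates[:, :n_variable] = np.where(group[:, None] == 1, 20.0, 1.0)
toy_counts = rng.poisson(rates)

mean_toy = toy_counts.mean(axis=0)
var_toy = toy_counts.var(axis=0)

expressed = mean_toy > 0.1  # drop near-silent genes before dividing
fano_toy = var_toy[expressed] / mean_toy[expressed]

# argsort is ascending, so the last entries have the largest Fano factor
top = np.argsort(fano_toy)[-n_variable:]

# map positions within the masked array back to original gene indices
selected = np.flatnonzero(expressed)[top]
print(sorted(selected.tolist()))
```

The last step mirrors `np.array(genes)[mask_raw][top_500_indices]` above: indices returned by `argsort` refer to the masked array and have to be mapped back through the mask.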

In [None]:
# plot mean-variance with hvgs highlighted

# mask of genes that are in hvg list
is_hvg_500 = np.isin(genes, hvg_genes_500)
is_hvg_1000 = np.isin(genes, hvg_genes_1000)

# fig, ax = plt.subplots(figsize=(6, 5))
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# for 500 hvgs
# all genes
axes[0].scatter(
    mean_expression_across_cells[mask_raw],
    variance_expression_across_cells[mask_raw],
    s=5,
    alpha=0.4,
    label="all genes",
    color="lightblue",
)

# overlay HVGs
axes[0].scatter(
    mean_expression_across_cells[is_hvg_500],
    variance_expression_across_cells[is_hvg_500],
    s=10,
    alpha=0.7,
    # label="HVGs",
    label="Top 500 HVGs by dispersion",
    color="darkblue",
)

# for 1000 hvgs
# all genes
axes[1].scatter(
    mean_expression_across_cells[mask_raw],
    variance_expression_across_cells[mask_raw],
    s=5,
    alpha=0.4,
    label="all genes",
    color="lightblue",
)

# overlay HVGs
axes[1].scatter(
    mean_expression_across_cells[is_hvg_1000],
    variance_expression_across_cells[is_hvg_1000],
    s=10,
    alpha=0.7,
    # label="HVGs",
    label="Top 1000 HVGs by dispersion",
    color="darkblue",
)

for ax in axes:
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Mean Expression Across Cells")
    ax.set_ylabel("Variance Across Cells")
    ax.set_title("HVGs Highlighted in Mean–Variance Plot")
    ax.legend()
plt.tight_layout()
plt.show()


## 2.3 PCA

note: use normalized + log-transformed data here!

In [None]:
# find indices of HVGs
hvg_indices_500 = np.where(np.isin(genes, hvg_genes_500))[0]
hvg_indices_1000 = np.where(np.isin(genes, hvg_genes_1000))[0]

# subset the expression matrix: shape = (cells, HVGs)
hvg_expression_500 = exonCounts[:, hvg_indices_500]
hvg_expression_1000 = exonCounts[:, hvg_indices_1000]

# use normalized + transformed data from before
log_hvg_df_500 = log_cpm_df.iloc[:, hvg_indices_500].copy()
sqrt_hvg_df_500 = sqrt_cpm_df.iloc[:, hvg_indices_500].copy()
log_hvg_df_1000 = log_cpm_df.iloc[:, hvg_indices_1000].copy()
sqrt_hvg_df_1000 = sqrt_cpm_df.iloc[:, hvg_indices_1000].copy()

# non-transformed data
raw_hvg_df_500 = cpm_df.iloc[:, hvg_indices_500].copy()
raw_hvg_df_1000 = cpm_df.iloc[:, hvg_indices_1000].copy()

In [None]:
from sklearn.decomposition import PCA

# fit PCA + play with the number of components
# 400 components are kept for every variant so that the cumulative
# explained-variance curves can be compared below; the number of PCs that is
# actually used downstream is chosen from those curves
n_pcs_full = 400

pca_log_500 = PCA(n_components=n_pcs_full)
pca_log_1000 = PCA(n_components=n_pcs_full)
X_pca_log_500 = pca_log_500.fit_transform(log_hvg_df_500)
X_pca_log_1000 = pca_log_1000.fit_transform(log_hvg_df_1000)

pca_sqrt_500 = PCA(n_components=n_pcs_full)
pca_sqrt_1000 = PCA(n_components=n_pcs_full)
X_pca_sqrt_500 = pca_sqrt_500.fit_transform(sqrt_hvg_df_500)
X_pca_sqrt_1000 = pca_sqrt_1000.fit_transform(sqrt_hvg_df_1000)

# non-transformed data
pca_raw_500 = PCA(n_components=n_pcs_full)
pca_raw_1000 = PCA(n_components=n_pcs_full)
X_pca_raw_500 = pca_raw_500.fit_transform(raw_hvg_df_500)
X_pca_raw_1000 = pca_raw_1000.fit_transform(raw_hvg_df_1000)

# from here on: either run these checks for all variants or drop them?
assert (
    len(rna_type_filtered) == X_pca_log_500.shape[0]
), "Mismatch between RNA types and PCA output"

print(
    "PCA result shape (log, 500 HVGs):", X_pca_log_500.shape
)  # (num_cells, n_components)
print(
    "PCA result shape (sqrt, 1000 HVGs):", X_pca_sqrt_1000.shape
)  # (num_cells, n_components)

In [None]:
# plot explained variance to compare
explained_log_500 = pca_log_500.explained_variance_ratio_
explained_sqrt_500 = pca_sqrt_500.explained_variance_ratio_
explained_log_1000 = pca_log_1000.explained_variance_ratio_
explained_sqrt_1000 = pca_sqrt_1000.explained_variance_ratio_

explained_raw_500 = pca_raw_500.explained_variance_ratio_
explained_raw_1000 = pca_raw_1000.explained_variance_ratio_

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(
    np.cumsum(explained_log_1000), color="blue", label="log (1000 HVGs)"
)
axes[0].plot(
    np.cumsum(explained_sqrt_1000), color="orange", label="sqrt (1000 HVGs)"
)
axes[0].plot(
    np.cumsum(explained_raw_1000), color="green", label="raw (1000 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_log_500), color="blue", label="log (500 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_sqrt_500), color="orange", label="sqrt (500 HVGs)"
)
axes[1].plot(
    np.cumsum(explained_raw_500), color="green", label="raw (500 HVGs)"
)

axes[0].axhline(0.9, color="red", linestyle="--", label="90% threshold")
axes[1].axhline(0.9, color="red", linestyle="--", label="90% threshold")

for ax in axes:
    ax.set_xlabel("Number of Components")
    ax.set_ylabel("Cumulative Explained Variance")
    ax.set_title("Explained Variance by PCA Components")
    ax.legend()
    ax.grid(True)
plt.tight_layout()
plt.show()

Looking at the raw data first, the curve is very steep and the first few components already explain most of the variance. However, this is most likely not biologically meaningful: without a variance-stabilizing transform, the variance of CPM values is dominated by the most highly expressed genes (see the mean-variance plot above). For this reason, we chose not to use the raw data in the following analyses.

For 1000 HVGs with log transform, about 390 components are needed to capture ~90% of the variance. The sqrt transform does somewhat better, at around 275 components. Even for 500 HVGs, well over 100 PCs are needed to explain 90% of the variance. This means that the variance is spread across many dimensions and most individual PCs explain only a small share of it. Although more PCs capture more information, much of it may not be biologically meaningful, so choosing fewer PCs still makes more sense.

Since the sqrt-transformed data reach a given share of explained variance with the fewest components, we will use them in the following. For 500 HVGs, 50 PCs account for ~80% of the variance, which seems a fair trade-off: it is better to retain fewer PCs that reflect meaningful signal than to fit noise.
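
Reading the number of components off the cumulative curve can also be done programmatically. The sketch below uses synthetic data with five strong latent directions (an assumption made for illustration; `X_toy` and the helper `n_components_for` are hypothetical names) and returns the smallest number of PCs whose cumulative explained variance reaches a given threshold.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Hypothetical data: 300 samples, 50 features,
# built from 5 strong latent directions plus unit noise
latent = rng.normal(size=(300, 5)) * 10
loadings = rng.normal(size=(5, 50))
X_toy = latent @ loadings + rng.normal(size=(300, 50))

pca_toy = PCA().fit(X_toy)  # keep all components to get the full curve
cumulative = np.cumsum(pca_toy.explained_variance_ratio_)


def n_components_for(threshold):
    """Smallest number of PCs whose cumulative variance reaches `threshold`."""
    # searchsorted returns the first index with cumulative[i] >= threshold
    return int(np.searchsorted(cumulative, threshold) + 1)


for thr in (0.6, 0.8, 0.9):
    print(f"{thr:.0%} of variance: {n_components_for(thr)} PCs")
```

Applied to `pca_sqrt_500.explained_variance_ratio_`, the same helper would give the exact counts behind the approximate numbers quoted above, as long as the threshold is reached within the fitted components.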

N: BUT: log is more commonly used in biology and is better at reducing noise. t-SNE started from the log PCA performed slightly better than with the sqrt start. I am tempted to change everything back to log. On the other hand, sqrt is apparently better for UMAP (maybe try both there!)

#### TODO: go back to log...?

In [None]:
from sklearn.preprocessing import LabelEncoder

# labels_rna_type = LabelEncoder().fit_transform(rna_type_filtered)
labels_rna_family_filtered = LabelEncoder().fit_transform(rna_family_filtered)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# plot the first 2 components for rna type and family
# PCA colored by RNA Type
# Get color values for the filtered RNA types
rna_type_colors_filtered = np.vectorize(dict_rna_type_colors.get)(
    rna_type_filtered
)

scatter1 = axes[0].scatter(
    X_pca_sqrt_500[:, 0],
    X_pca_sqrt_500[:, 1],
    c=rna_type_colors_filtered,
    s=5,
    alpha=0.6,
)
axes[0].set_title("Colored by RNA Type")
axes[0].set_xlabel("PC1")
axes[0].set_ylabel("PC2")
axes[0].grid(True)

# PCA colored by RNA Family
scatter2 = axes[1].scatter(
    X_pca_sqrt_500[:, 0],
    X_pca_sqrt_500[:, 1],
    c=labels_rna_family_filtered,
    cmap="tab10",
    s=5,
    alpha=0.6,
)
axes[1].set_title("Colored by RNA Family")
axes[1].set_xlabel("PC1")
axes[1].set_ylabel("PC2")
axes[1].grid(True)
plt.suptitle("PCA of Highly Variable Genes (HVGs)")


plt.tight_layout()
plt.show()

In [None]:
##### N: compact version that varies the parameters once more. It shows no change at all?? How would one even see a change in 2D?

#### M: I would just vary the number of HVGs here and save the variation of PCs for UMAP and t-SNE
from sklearn.preprocessing import StandardScaler

# Set parameters
hvg_counts = [500, 1000, 2000]
# pca_components = [15, 25, 50]

# Set up figure
fig, axes = plt.subplots(len(hvg_counts), 2, figsize=(10, 9))
fig.suptitle(
    "Effect of HVG Count on PCA – RNA Type vs RNA Family", fontsize=16
)

# Loop through combinations
for i, n_hvg in enumerate(hvg_counts):
    # HVG selection
    dispersion = variance_expression_across_cells / (
        mean_expression_across_cells + 1e-8
    )
    mask = mean_expression_across_cells > 0.1

    filtered_dispersion = dispersion[mask]
    top_indices = np.argsort(filtered_dispersion)[-n_hvg:]
    hvg_genes = np.array(genes)[mask][top_indices]

    # Subset & scale
    sqrt_hvg_df = sqrt_cpm_df[hvg_genes]
    data = StandardScaler().fit_transform(sqrt_hvg_df.values)

    # PCA to 2D
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(data)

    # Plot rna_type
    axes[i, 0].scatter(
        X_pca[:, 0], X_pca[:, 1], c=rna_type_colors, s=10, alpha=0.6
    )
    axes[i, 0].set_title(f"{n_hvg} HVGs\nColored by RNA Type")
    axes[i, 0].set_xlabel("PC1")
    axes[i, 0].set_ylabel("PC2")

    # Plot rna_family
    axes[i, 1].scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        c=labels_rna_family_filtered,
        cmap="tab10",
        s=10,
        alpha=0.6,
    )
    axes[i, 1].set_title(f"{n_hvg} HVGs\nColored by RNA Family")
    axes[i, 1].set_xlabel("PC1")
    axes[i, 1].set_ylabel("PC2")

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

In [None]:
# pip install umap-learn scikit-learn

# Task 3

## 3.1 Visualization

### 3.1.1 t-SNE

In [None]:
# prepare data
# log_hvg_df = log_cpm_df[hvg_genes]
sqrt_hvg_df = sqrt_cpm_df[hvg_genes_500]
X_hvg = StandardScaler().fit_transform(sqrt_hvg_df.values)

# Sanity checks
assert X_hvg.shape[0] == len(rna_type_filtered)
assert X_hvg.shape[0] == len(rna_family_filtered)

# filter the scaled sqrt-CPM HVG matrix to the same 1224 cells
X_hvg_filtered = X_hvg[keepcells, :]  ### not actually used below?

# Define dimensionality reduction methods and parameters
methods = {
    "t-SNE (perp=5)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=5, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "t-SNE (perp=30)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=30, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "t-SNE (perp=100)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=100, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
}

# Create subplots
"""n_methods = len(methods)
fig, axes = plt.subplots(1, n_methods, figsize=(5 * n_methods, 5))

# Loop through and plot
for ax, (title, conf) in zip(axes, methods.items()):
    X_embedded = conf["func"](X_hvg)

    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_family_filtered,
        # hue=rna_type,  # color by cell type
        # c=ephys_df["AP width (ms)"],  # v1
        # c=ephys_df["Max number of APs"],  # v2
        # c=ephys_df["AP amplitude (mV)"],  # v3
        # c=ephys_df["Latency (ms)"],  # v4
        palette="tab10",
        # cmap="inferno",  # change colormap here
        s=30,
        alpha=0.7,
        legend=False,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")"""

### M: thought I would try this out, but it does not have to be the final version

# Create subplots: 2 rows (RNA type + RNA family) x 3 columns (t-SNE variants)
fig, axes = plt.subplots(2, len(methods), figsize=(16, 9))
fig.suptitle(
    "t-SNE Projections Colored by RNA Type and RNA Family", fontsize=16
)

for col_idx, (title, conf) in enumerate(methods.items()):
    X_embedded = conf["func"](X_hvg)

    # Row 0: RNA type
    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_type_filtered,
        palette=dict_rna_type_colors,
        s=15,
        alpha=0.7,
        legend=False,
        ax=axes[0, col_idx],
    )
    axes[0, col_idx].set_title(f"{title} – RNA Type")
    axes[0, col_idx].set_xlabel("Dim 1")
    axes[0, col_idx].set_ylabel("Dim 2")

    # Row 1: RNA family
    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_family_filtered,
        palette="tab10",
        s=15,
        alpha=0.7,
        legend=False,
        ax=axes[1, col_idx],
    )
    axes[1, col_idx].set_title(f"{title} – RNA Family")
    axes[1, col_idx].set_xlabel("Dim 1")
    axes[1, col_idx].set_ylabel("Dim 2")

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

The comparison shows how t-SNE's perplexity parameter affects local versus global structure:
* perplexity=5: clusters are small and tightly packed => emphasis on fine local relationships
* perplexity=30: balance between local detail & global structure => seems to provide the best overview of the data
* perplexity=100: the plot appears more globally smoothed and there is less distinct clustering => highlighting broader structure at the expense of local detail
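
Perplexity can be read as a rough effective number of neighbors that each point attends to. A minimal sketch of such a perplexity sweep on synthetic data is shown here; the blob data and the names `X_toy` and `embeddings` are assumptions for illustration, not part of the analysis above.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE

# Hypothetical data: 150 points in 20 dimensions forming 3 groups
X_toy, y_toy = make_blobs(
    n_samples=150, n_features=20, centers=3, random_state=0
)

embeddings = {}
for perplexity in (5, 30):
    # perplexity must be smaller than the number of samples
    tsne = TSNE(
        n_components=2, perplexity=perplexity, init="pca", random_state=0
    )
    embeddings[perplexity] = tsne.fit_transform(X_toy)
    print(
        f"perplexity={perplexity}: shape={embeddings[perplexity].shape}, "
        f"KL divergence={tsne.kl_divergence_:.2f}"
    )
```

The KL divergence values are not comparable across perplexities, because the perplexity changes the target distribution. The visual comparison therefore remains the main criterion.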

### 3.1.2 UMAP

In [None]:
import umap.umap_ as umap


# Define dimensionality reduction methods and parameters
methods = {
    "UMAP (n=15)": {
        "func": lambda X: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=42
        ).fit_transform(X)
    },
    "UMAP (n=30)": {
        "func": lambda X: umap.UMAP(
            n_neighbors=30, min_dist=0.3, random_state=42
        ).fit_transform(X)
    },
    "UMAP (n=50)": {
        "func": lambda X: umap.UMAP(
            n_neighbors=50, min_dist=0.3, random_state=42
        ).fit_transform(X)
    },
}

# Create subplots
n_methods = len(methods)
fig, axes = plt.subplots(1, n_methods, figsize=(5 * n_methods, 5))

# Loop through and plot
for ax, (title, conf) in zip(axes, methods.items()):
    X_embedded = conf["func"](X_hvg)

    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_family_filtered,
        # hue=rna_type,  # color by cell type
        # c=ephys_df["AP width (ms)"],  # v1
        # c=ephys_df["Max number of APs"],  # v2
        # c=ephys_df["AP amplitude (mV)"],  # v3
        # c=ephys_df["Latency (ms)"],  # v4
        palette="tab10",
        # cmap="inferno",  # change colormap here
        s=30,
        alpha=0.7,
        legend=False,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

plt.tight_layout()
# plt.title("UMAP with different n, colored according to AP width")
plt.show()

The comparison shows how the `n_neighbors` parameter affects the balance between local and global structure. Note that `min_dist` also changes (0.1 for n=15, 0.3 for n=30 and n=50), so the two effects cannot be fully separated here.

n=15: the plot reveals finer substructures and potential cell states

n=30 / n=50:
* clusters are more connected or blended
* increasing `n_neighbors` favors global structure at the cost of finer distinctions
* useful to highlight broad cell types instead of subtypes

=> smaller `n_neighbors` values are better for identifying subpopulations, while larger values show more global cell type structure


### 3.1.3 visual comparison of PCA, t-SNE, UMAP

In [None]:
import umap.umap_ as umap

# Prepare your data
# log_hvg_df = log_cpm_df[hvg_genes]
# sqrt_hvg_df = sqrt_cpm_df[hvg_genes]
# X = StandardScaler().fit_transform(log_hvg_df.values)

# Define dimensionality reduction methods and parameters
methods = {
    "PCA": {"func": lambda X: PCA(n_components=2).fit_transform(X)},
    "t-SNE (perp=30)": {
        "func": lambda X: TSNE(
            n_components=2, perplexity=30, max_iter=1000, random_state=42
        ).fit_transform(X)
    },
    "UMAP (n=15)": {
        "func": lambda X: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=42
        ).fit_transform(X)
    },
}

# Create subplots
n_methods = len(methods)
fig, axes = plt.subplots(1, n_methods, figsize=(5 * n_methods, 5))

# Loop through and plot
for ax, (title, conf) in zip(axes, methods.items()):
    X_embedded = conf["func"](X_hvg)

    sns.scatterplot(
        x=X_embedded[:, 0],
        y=X_embedded[:, 1],
        hue=rna_family_filtered,  # rna_type_filtered,  # color by cell type
        # c=ephys_df["AP width (ms)"],  ["Max number of APs"] ["AP amplitude (mV)"] ["Latency (ms)"]
        # cmap="inferno",  # change colormap here
        palette="tab10",
        # c=rna_type_colors,
        s=30,
        alpha=0.7,
        legend=False,
        ax=ax,
    )

    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

plt.tight_layout()
plt.show()

Interpretation across visualization types: 

PCA

* no clear clusters.
* seems to capture global variance, but fails to separate subtle biological subpopulations
=> rather use for noise reduction / preprocessing

t-SNE (`perplexity`=30)

* multiple well-separated clusters.
* seems to reveal local structure, separating even small subpopulations
* downside: t-SNE is not ideal for preserving global relationships

UMAP (`n_neighbors`=15)

* distinct clusters, some continuity
* seems to balance local and global structure, showing gradients and groupings


=> next steps:

* Use PCA as input to UMAP/t-SNE
* Use UMAP or t-SNE for cluster discovery or cell type visualization.
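
The first of these steps, PCA as input to a nonlinear embedding, can be sketched as a three-stage pipeline. The data are synthetic and the names (`X_toy`, `X_pcs`, `X_2d`) and the choice of 30 PCs are assumptions for illustration only.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Hypothetical data: 300 "cells" x 200 "genes",
# with two groups that differ in the first 10 genes
labels_toy = np.repeat([0, 1], 150)
X_toy = rng.normal(size=(300, 200))
X_toy[labels_toy == 1, :10] += 3.0

# 1) standardize, 2) compress and denoise with PCA, 3) embed the PCs
X_scaled = StandardScaler().fit_transform(X_toy)
X_pcs = PCA(n_components=30, random_state=0).fit_transform(X_scaled)
X_2d = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(
    X_pcs
)

print(X_pcs.shape, X_2d.shape)
```

Running the embedding on a few dozen PCs instead of all genes reduces run time and discards low-variance directions that are more likely to be noise. How many PCs to keep remains a judgment call.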


TODO: we can try coloring by RNA type, cluster ID, or marker expression to aid interpretation.


## 3.2 Comparison using Quantitative Metrics

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import recall_score
from sklearn.cluster import KMeans
import umap.umap_ as umap
from sklearn.metrics import silhouette_score, adjusted_mutual_info_score

""" def safe_tsne_10d(X, name="t-SNE"):
    if X.shape[0] < 50 or X.shape[1] < 10:
        raise ValueError(f"{name}: Too few samples or features for 10D t-SNE")
    return TSNE(n_components=10, perplexity=30, random_state=42).fit_transform(
        X
    ) """


def evaluate_knn_projection(
    cpm_df,
    hvg_genes,
    labels,
    n_neighbors=10,
    test_size=0.3,
    random_state=42,
    pca_components_for_tsne_umap=50,
):
    """
    Compare dimensionality reduction methods using kNN classification metrics.

    Parameters:
        cpm_df (DataFrame): CPM-normalized+transformed gene expression (cells × genes)
        hvg_genes (list): list of highly variable gene names
        labels (array-like): class labels for each cell (e.g., RNA type)
        n_neighbors (int): number of neighbors for kNN
        test_size (float): train/test split size
        random_state (int): reproducibility seed
        pca_components_for_tsne_umap (int): number of PCs to feed into t-SNE/UMAP

    Returns:
        DataFrame with kNN accuracy, macro-averaged recall, silhouette score and AMI for each method.
    """
    # Standardize input data
    # X_hvg = StandardScaler().fit_transform(log_cpm_df[hvg_genes].values)
    X_hvg = StandardScaler().fit_transform(
        cpm_df[hvg_genes].values
    )  # strictly speaking, this was already done above as well
    # X_hvg =X_hvg_filtered
    y = np.array(labels)

    # Precompute PCA-reduced input for non-linear methods
    X_pca_for_embedding = PCA(
        n_components=pca_components_for_tsne_umap
    ).fit_transform(X_hvg)

    # print("X shape:", X.shape)
    # print("X_pca_for_embedding shape:", X_pca_for_embedding.shape)

    methods = {
        "High-dimensional (raw)": lambda X: X,
        "PCA (2D)": lambda X: PCA(n_components=2).fit_transform(X),
        "PCA (10D)": lambda X: PCA(n_components=10).fit_transform(X),
        # "PCA (25D)": lambda X: PCA(n_components=25).fit_transform(X),  # using 25D here would be "unfair"
        "t-SNE (2D)": lambda X: TSNE(
            n_components=2, perplexity=30, random_state=random_state
        ).fit_transform(X),
        "t-SNE (2D from PCA)": lambda _: TSNE(
            n_components=2, perplexity=30, random_state=random_state
        ).fit_transform(X_pca_for_embedding),
        "t-SNE (10D)": lambda X: TSNE(
            n_components=10,
            perplexity=30,
            method="exact",
            random_state=random_state,
        ).fit_transform(X),
        "t-SNE (10D from PCA)": lambda _: TSNE(
            n_components=10,
            perplexity=30,
            method="exact",
            random_state=random_state,
        ).fit_transform(X_pca_for_embedding),
        "UMAP (2D)": lambda X: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=random_state
        ).fit_transform(X),
        "UMAP (2D from PCA)": lambda _: umap.UMAP(
            n_neighbors=15, min_dist=0.1, random_state=random_state
        ).fit_transform(X_pca_for_embedding),
        "UMAP (10D from PCA)": lambda _: umap.UMAP(
            n_neighbors=15,
            min_dist=0.1,
            n_components=10,
            random_state=random_state,
        ).fit_transform(X_pca_for_embedding),
    }

    results = []
    for name, func in methods.items():
        try:
            X_proj = func(X_hvg)
            X_train, X_test, y_train, y_test = train_test_split(
                X_proj,
                y,
                test_size=test_size,
                random_state=random_state,
            )
            print(f"{name}: X_proj shape = {X_proj.shape}")

            clf = KNeighborsClassifier(n_neighbors=n_neighbors)
            clf.fit(X_train, y_train)
            acc = clf.score(X_test, y_test)
            y_pred = clf.predict(X_test)
            avg_recall = recall_score(
                y_test, y_pred, average="macro", zero_division=0
            )

            # compute silhouette score and AMI as in Lause, Berens, Kobak (2024):

            # silhouette score
            sil = silhouette_score(X_proj, y)
            # unsupervised clustering for AMI
            n_clusters = len(np.unique(y))
            cluster_labels = KMeans(
                n_clusters=n_clusters, random_state=random_state
            ).fit_predict(X_proj)
            ami = adjusted_mutual_info_score(y, cluster_labels)

            results.append(
                {
                    "Method": name,
                    "kNN Accuracy": acc,
                    "kNN Recall (avg)": avg_recall,
                    "Silhouette Score": sil,
                    "AMI": ami,
                }
            )
        except Exception as e:
            results.append(
                {
                    "Method": name,
                    "kNN Accuracy": None,
                    "kNN Recall (avg)": None,
                    "Silhouette Score": None,
                    "AMI": None,
                    "Error": str(e),
                }
            )

    return pd.DataFrame(results)

In [None]:
print("sqrt_cpm_df shape:", sqrt_cpm_df.shape)
print("Sample of columns:", sqrt_cpm_df.columns[:5])
print("Sample of hvg_genes:", hvg_genes[:5])

In [None]:
# results_df = evaluate_knn_projection(log_cpm_df, hvg_genes, rna_type)
results_df = evaluate_knn_projection(
    cpm_df=sqrt_cpm_df,
    hvg_genes=hvg_genes,
    labels=rna_type,
    n_neighbors=10,
    pca_components_for_tsne_umap=50,
)
print(results_df)

TODO: update these values again at the very end!

Best overall (across all metrics): t-SNE (10D from PCA)
* kNN accuracy: 0.46
* Recall: 0.26
* AMI: 0.44 (highest)
* Silhouette: ~-0.21

kNN performance of the methods:
* t-SNE (10D) performs best in terms of accuracy.
* t-SNE (10D from PCA) has the highest recall.
* PCA (2D) performs poorly; a linear reduction to 2D does not seem to work well for this data.
* UMAP is competitive but does not benefit much from going to 10D.
* The raw high-dimensional data still performs surprisingly well, which shows that the data has strong native structure.
* A kNN accuracy of around 40% is not high but seems acceptable given the biological data.

Overall low silhouette scores:
All low-dimensional methods have negative silhouette scores, meaning that the labelled groups overlap or are poorly separated.
This is common for t-SNE/UMAP, as they optimize local neighborhood structure rather than global cluster separation.

AMI scores:
Most AMI scores surpass 0.4, which indicates moderately good label-cluster alignment.
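
To make the two label-based metrics above concrete, here is a minimal sketch on toy data. The blob data and the variable names are assumptions for illustration only and are not part of our pipeline. The silhouette score uses the known labels directly, whereas the AMI compares the labels with an unsupervised k-means clustering.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_mutual_info_score, silhouette_score

# Toy data (assumption): three well-separated Gaussian blobs with known labels
X_toy, y_toy = make_blobs(
    n_samples=300,
    centers=[[0, 0], [10, 0], [0, 10]],
    cluster_std=1.0,
    random_state=0,
)

# Silhouette score: how well do the known labels separate in this space?
sil_separated = silhouette_score(X_toy, y_toy)

# Shuffling the labels destroys the link between labels and geometry
rng = np.random.default_rng(0)
sil_shuffled = silhouette_score(X_toy, rng.permutation(y_toy))

# AMI: cluster without labels, then compare the clusters with the labels
clusters = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_toy)
ami = adjusted_mutual_info_score(y_toy, clusters)

print(f"silhouette (true labels):     {sil_separated:.2f}")
print(f"silhouette (shuffled labels): {sil_shuffled:.2f}")
print(f"AMI (k-means vs. labels):     {ami:.2f}")
```

A silhouette score near or below zero, as for the shuffled labels here, means that points are on average no closer to their own group than to the nearest other group.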

In [None]:
# plots kNN accuracy vs. recall for different dimensionality reduction methods
# N: not sure whether we need/want this?

# from adjustText import adjust_text


def plot_knn_accuracy_vs_recall(results_df, figsize=(8, 6)):
    """
    Parameters:
        results_df (DataFrame): must contain columns 'Method', 'kNN Accuracy', and 'kNN Recall (avg)'
        figsize (tuple): size of the plot (width, height)
    """
    methods = results_df["Method"]
    accuracy = results_df["kNN Accuracy"]
    recall = results_df["kNN Recall (avg)"]

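    plt.figure(figsize=figsize)

    # plot one labelled point per method
    # (positional indexing, so this also works for a non-default index)
    for i in range(len(methods)):
        plt.scatter(
            accuracy.iloc[i], recall.iloc[i], s=200, label=methods.iloc[i]
        )
        plt.text(
            accuracy.iloc[i],
            recall.iloc[i],
            methods.iloc[i],
            fontsize=8,
            rotation=15,
        )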

    plt.xlabel("kNN Accuracy")
    plt.ylabel("kNN Recall (macro avg)")
    plt.title("kNN Accuracy vs. Recall for Dimensionality Reduction Methods")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
plot_knn_accuracy_vs_recall(results_df)

TODO: rephrase


* t-SNE (2D) nearly matches the accuracy of the high-dimensional data and even slightly improves average recall, which shows that it preserves neighborhood structure quite well in 2D.
* UMAP (2D) performs reasonably, a bit behind t-SNE but still far ahead of PCA.
* PCA (2D) performs poorly, which suggests that a linear 2D projection does not capture non-linear relationships or local structure well.
* High-dimensional kNN is strongest overall (as expected), but not dramatically better than t-SNE, suggesting that our 2D t-SNE visualizations are biologically meaningful.

In [None]:
# define 2D projections
X_pca2 = PCA(n_components=2).fit_transform(X_hvg)
X_tsne2 = TSNE(
    n_components=2, perplexity=30, method="exact", random_state=42
).fit_transform(X_hvg)
X_umap2 = umap.UMAP(
    n_components=2, n_neighbors=15, min_dist=0.1, random_state=42
).fit_transform(X_hvg)

In [None]:
# plot models from above
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

titles = ["PCA (2D)", "t-SNE (2D)", "UMAP (2D)"]
embeddings = [X_pca2, X_tsne2, X_umap2]

for ax, title, embed in zip(axes, titles, embeddings):
    sns.scatterplot(
        x=embed[:, 0],
        y=embed[:, 1],
        hue=rna_family,  # alternatively: rna_type
        palette="tab10",
        s=30,
        alpha=0.8,
        legend=False,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

plt.tight_layout()
plt.show()

# Task 4

# Task 5

## 5.1 prepare data

* transcriptomics data: already normalized, HVGs selected, PCA applied
* electrophysiological data: standardize (z-score), maybe apply PCA

then:
* match cells across modalities (i.e., align them via cell IDs)
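
Matching cells by ID rather than by row position can be sketched as follows. The toy tables, cell IDs and column names are made up for illustration. The point is that `.loc` with the shared index puts both tables in the same row order, even if the original tables were ordered differently or contained different cells.

```python
import pandas as pd

# Hypothetical toy tables; cell IDs and values are made up
rna_toy = pd.DataFrame(
    {"GeneA": [1.0, 2.0, 3.0]}, index=["cell_1", "cell_2", "cell_3"]
)
ephys_toy = pd.DataFrame(
    {"Rheobase (pA)": [50.0, 80.0, 120.0]},
    index=["cell_3", "cell_1", "cell_4"],
)

# Keep only cells present in both modalities, in one common order
shared = rna_toy.index.intersection(ephys_toy.index)
rna_aligned = rna_toy.loc[shared]
ephys_aligned = ephys_toy.loc[shared]

print(pd.concat([rna_aligned, ephys_aligned], axis=1))
```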

In [None]:
print(meta.columns)

In [None]:
# Set the index properly on meta
meta_df = meta.copy()
meta_df["Cell"] = meta_df["Cell"].astype(str).str.strip()
meta_df = meta_df.set_index("Cell")

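# Now set the same index on RNA and ephys.
# NOTE: this assigns cell IDs by row position, so it assumes that the rows
# of sqrt_cpm_df and ephys_df are in the same order as the rows of meta.
sqrt_cpm_df.index = meta_df.index[: sqrt_cpm_df.shape[0]]
ephys_df.index = meta_df.index[: ephys_df.shape[0]]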

In [None]:
shared_cells = sqrt_cpm_df.index.intersection(ephys_df.index).intersection(
    meta_df.index
)
print("Number of shared cells:", len(shared_cells))

sqrt_cpm_df = sqrt_cpm_df.loc[shared_cells]
ephys_df = ephys_df.loc[shared_cells]
meta_df = meta_df.loc[shared_cells]

print("✅ Final aligned shapes:")
print("RNA:", sqrt_cpm_df.shape)
print("Ephys:", ephys_df.shape)
print("Meta:", meta_df.shape)

In [None]:
# Standardize gene expression data
# X_hvg = StandardScaler().fit_transform(log_cpm_df[hvg_genes])  # old version
X_hvg = StandardScaler().fit_transform(
    sqrt_cpm_df[hvg_genes]
)  # already done above, but must be recomputed so that the dimensions match

# Run PCA
pca = PCA(n_components=25)  # choose number of PCs to keep
X_pca = pca.fit_transform(X_hvg)

# Convert to DataFrame for easier correlation
pca_df = pd.DataFrame(
    X_pca,
    columns=[f"PC{i+1}" for i in range(X_pca.shape[1])],
    index=sqrt_cpm_df.index,
)

## 5.2 explore relationships

* correlation analysis
* regression models

**canonical correlation analysis?**

* Identify pairs of latent dimensions that co-vary across RNA and e-phys.
* Seurat can run CCA; in Python, scikit-learn provides `sklearn.cross_decomposition.CCA`.

**clustering / dimensionality reduction?**
* Jointly embed both modalities (e.g., MOFA+, MultiVI, or Concatenated PCA).
* Then cluster cells and see if these clusters correspond to known cell types or exhibit interesting features.

### 5.2.1 correlation analysis

* compute correlations between electrophysiological features and PCA components
* compute correlations between electrophysiological features and genes
* => visualize top hits using heatmaps or scatter plots.
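
The all-against-all correlation can also be computed in one step from z-scored tables. The sketch below uses a hypothetical helper `cross_correlation` and toy data with made-up names. It assumes that both tables share the same row index and contain no NaNs.

```python
import numpy as np
import pandas as pd


def cross_correlation(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
    """Pearson r between every column of `a` and every column of `b`.

    Assumes both frames share the same row index and contain no NaNs.
    """
    a_z = (a - a.mean()) / a.std(ddof=0)
    b_z = (b - b.mean()) / b.std(ddof=0)
    # the mean of the products of z-scores is the Pearson correlation
    return a_z.T @ b_z / len(a)


# Toy data (made-up names): 50 cells, 3 "ephys features", 2 "PCs"
rng = np.random.default_rng(0)
cells = [f"cell_{i}" for i in range(50)]
features_toy = pd.DataFrame(
    rng.normal(size=(50, 3)), index=cells, columns=["f1", "f2", "f3"]
)
pcs_toy = pd.DataFrame(
    rng.normal(size=(50, 2)), index=cells, columns=["PC1", "PC2"]
)

cor_toy = cross_correlation(features_toy, pcs_toy)
print(cor_toy.round(2))
```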

In [None]:
# for all features vs all PCs:
cor_matrix_full = pd.DataFrame(index=ephys_df.columns, columns=pca_df.columns)
for ef in ephys_df.columns:
    for pc in pca_df.columns:
        cor_matrix_full.loc[ef, pc] = ephys_df[ef].corr(pca_df[pc])
cor_matrix_full = cor_matrix_full.astype(float)

In [None]:
plt.figure(figsize=(12, 8))
sns.heatmap(
    cor_matrix_full, annot=True, cmap="vlag", center=0, annot_kws={"size": 7}
)
plt.title(
    "Correlation between electrophysiological features and PCA components"
)
plt.xlabel("Principal Components")
plt.ylabel("Ephys Features")
plt.tight_layout()
plt.show()

In [None]:
# Convert all column names to uppercase to standardize for matching
sqrt_cpm_df.columns = sqrt_cpm_df.columns.str.upper()

# Define ion channel gene families of interest
ion_channel_prefixes = ["KCNK", "SCN", "HCN"]

# Select genes whose names start with one of the ion channel prefixes
# (substring matching would also pick up unrelated genes such as OBSCN)
ion_genes = [
    gene
    for gene in sqrt_cpm_df.columns
    if gene.startswith(tuple(ion_channel_prefixes))
]

# Subset expression matrix to only those genes
ion_expr = sqrt_cpm_df[ion_genes]

print(f"Found {len(ion_genes)} ion channel genes in sqrt_cpm_df.")

In [None]:
# Make sure index alignment is correct
common_cells = ion_expr.index.intersection(ephys_df.index)
ion_expr = ion_expr.loc[common_cells]
ephys_subset = ephys_df.loc[common_cells][
    ["AP threshold (mV)", "Rheobase (pA)", "Membrane time constant (ms)"]
]

In [None]:
# Compute correlation matrix between genes and selected ephys features
cor_matrix = ion_expr.corrwith(
    ephys_subset["AP threshold (mV)"], axis=0
).to_frame(name="AP threshold (mV)")

cor_matrix["Rheobase (pA)"] = ion_expr.corrwith(
    ephys_subset["Rheobase (pA)"], axis=0
)

cor_matrix["Membrane time constant (ms)"] = ion_expr.corrwith(
    ephys_subset["Membrane time constant (ms)"], axis=0
)

# Transpose to make it easier to read
cor_matrix = cor_matrix.T

In [None]:
plt.figure(figsize=(12, 6))
sns.heatmap(cor_matrix, annot=True, cmap="vlag", center=0, fmt=".2f")
plt.title("Correlation between Ion Channel Gene Expression and Ephys Features")
plt.xlabel("Ion Channel Genes")
plt.ylabel("Ephys Features")
plt.tight_layout()
plt.show()

*Interpretation*

TODO: rephrase a bit

* Rheobase shows the most pronounced correlations (up to ~0.13), suggesting that some ion channel genes might influence the current needed to trigger a spike. In absolute terms, these correlations are weak.
* KCNK family genes (potassium channels) correlate modestly with membrane time constant and AP threshold, as expected from their role in setting resting potential and excitability.
* Some SCN genes (sodium channels) also show modest correlations, especially with AP threshold and rheobase, which is consistent with their role in spike initiation.

In [None]:
print(ephys_df.columns)

In [None]:
# plot scatter plots for strongest gene-feature pairs

# identify strongest correlations
cor_df = cor_matrix.stack().reset_index()
cor_df.columns = ["Ephys Feature", "Gene", "Pearson r"]
cor_df["abs_r"] = cor_df["Pearson r"].abs()
top_corrs = cor_df.sort_values("abs_r", ascending=False).head(6)

# plot scatterplots for top gene-feature pairs
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for i, (_, row) in enumerate(top_corrs.iterrows()):
    ephys_feature = row["Ephys Feature"]
    gene = row["Gene"]
    r_val = row["Pearson r"]

    # extract data
    x = sqrt_cpm_df[gene]
    y = ephys_df[ephys_feature]

    # scatter plot with regression line
    sns.scatterplot(x=x, y=y, ax=axes[i], alpha=0.6, edgecolor=None)
    sns.regplot(x=x, y=y, ax=axes[i], scatter=False, color="red", ci=None)

    axes[i].set_title(f"{gene} vs. {ephys_feature}\nPearson r = {r_val:.2f}")
    axes[i].set_xlabel(f"{gene} expression (sqrt CPM)")
    axes[i].set_ylabel(ephys_feature)

plt.tight_layout()
plt.suptitle("Top 6 Correlated Gene-Feature Pairs", fontsize=16, y=1.02)
plt.show()

TODO: there are always a lot of values at 0 here. Why, and how can we fix it?
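
The many values at exactly 0 are likely due to the sparsity of single-cell expression data: many genes are simply not detected in many cells, so the transformed expression is 0. The sketch below simulates this with toy data; the dropout rate and effect size are assumptions. It compares the Pearson correlation across all cells with the one restricted to cells in which the gene was detected.

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
n_cells = 500

# Toy ephys feature and a gene whose true expression depends on it
ephys_value = rng.normal(size=n_cells)
expr_true = np.sqrt(
    np.clip(5 + 2 * ephys_value + rng.normal(size=n_cells), 0, None)
)

# Simulated dropout (assumption): the gene is detected in ~40% of cells
detected = rng.random(n_cells) < 0.4
expr_observed = np.where(detected, expr_true, 0.0)

zero_fraction = np.mean(expr_observed == 0)
expressing = expr_observed > 0

r_all, _ = pearsonr(expr_observed, ephys_value)
r_expressing, _ = pearsonr(expr_observed[expressing], ephys_value[expressing])

print(f"fraction of zeros:         {zero_fraction:.2f}")
print(f"Pearson r (all cells):     {r_all:.2f}")
print(f"Pearson r (detected only): {r_expressing:.2f}")
```

Restricting to expressing cells changes the question being asked (the relationship given detection), so it should be reported as such. Filtering genes by a minimum fraction of expressing cells before computing correlations is another option.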

In [None]:
# look at statistical significance of correlations
from scipy.stats import pearsonr

# Select genes of interest
# (OBSCN encodes obscurin, a structural protein, not an ion channel)
ion_channel_genes = [
    "OBSCN",
    "SCN4B",
    "SCN8A",
]  # list of genes => which ones do we want to look at??

# Subset gene expression
X_genes = sqrt_cpm_df[ion_channel_genes].loc[ephys_df.index]

# Subset ephys features
ephys_features = [
    "AP threshold (mV)",
    "Rheobase (pA)",
    "Membrane time constant (ms)",
]
X_ephys = ephys_df[ephys_features]

# Initialize results
cor_matrix = pd.DataFrame(index=ephys_features, columns=ion_channel_genes)
pval_matrix = pd.DataFrame(index=ephys_features, columns=ion_channel_genes)

# Compute correlations and p-values
for feature in ephys_features:
    for gene in ion_channel_genes:
        corr, pval = pearsonr(X_genes[gene], X_ephys[feature])
        cor_matrix.at[feature, gene] = corr
        pval_matrix.at[feature, gene] = pval

# Mask non-significant values
# (uncorrected p < 0.05; no multiple-testing correction is applied)
significance_mask = pval_matrix.astype(float) < 0.05
annot_matrix = cor_matrix.round(2).astype(str)
annot_matrix[~significance_mask] = ""

In [None]:
# heatmap to visualize significance

plt.figure(figsize=(10, 4))
sns.heatmap(
    cor_matrix.astype(float),
    annot=annot_matrix,
    fmt="s",
    cmap="vlag",
    center=0,
    cbar_kws={"label": "Pearson r"},
)
plt.title(
    "Ion Channel Gene Expression vs Ephys Features (Significant Correlations Annotated)"
)
plt.xlabel("Ion Channel Genes")
plt.ylabel("Ephys Features")
plt.tight_layout()
plt.show()

SCN8A shows a weak positive correlation (~0.13) with rheobase, which may suggest that higher expression of SCN8A (a voltage-gated sodium channel gene) is associated with a larger current required to elicit an action potential.

OBSCN expression correlates positively with the membrane time constant, so structural genes might be linked to passive membrane properties.

### 5.2.2 study PCA loadings to see gene contributions

In [None]:
# PCA was fit on the standardized HVG matrix,
# so the loadings are indexed by hvg_genes
loadings = pd.DataFrame(
    pca.components_.T,
    index=hvg_genes,
    columns=[f"PC{i+1}" for i in range(pca.n_components_)],
)

In [None]:
# top 10 genes per PC by absolute loading
top_loadings = {}
for pc in loadings.columns:
    top = (
        loadings[pc].abs().sort_values(ascending=False).head(10).index.tolist()
    )
    top_loadings[pc] = top

### 5.2.3 regression models

Predict e-phys features from transcriptomic data:

* Use linear regression, random forests, or elastic net on PCA-reduced gene expression.
* Evaluate using cross-validation (R², RMSE).

Partial Least Squares Regression (PLSR) is popular for this kind of analysis.

In [None]:
from sklearn.linear_model import LinearRegression, ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import pandas as pd

# Recompute HVGs from the current sqrt_cpm_df
gene_variances = sqrt_cpm_df.var().sort_values(ascending=False)
hvg_genes = gene_variances.head(2000).index.tolist()

top_n = 100
hvg_subset = hvg_genes[:top_n]

# Scaling happens inside the pipeline below, so it is fit on the training
# folds only; no separate StandardScaler pass is needed here.
X = sqrt_cpm_df[hvg_subset].values
ephys_targets = [
    "AP threshold (mV)",
    "Rheobase (pA)",
    "Membrane time constant (ms)",
]

# Define models
models = {
    "Linear Regression": LinearRegression(),
    "Random Forest": RandomForestRegressor(n_estimators=100, random_state=42),
    "Elastic Net": ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=42),
}

# Store results
results = []

# Loop through each ephys feature
for feature in ephys_targets:
    y = ephys_df[feature].values

    for name, model in models.items():
        pipe = make_pipeline(StandardScaler(), model)
        scores = cross_validate(
            pipe, X, y, cv=5, scoring=["r2", "neg_root_mean_squared_error"]
        )
        results.append(
            {
                "Model": name,
                "Ephys Feature": feature,
                "Mean R2": scores["test_r2"].mean(),
                "Mean RMSE": -scores[
                    "test_neg_root_mean_squared_error"
                ].mean(),
            }
        )

results_df = pd.DataFrame(results)
print(results_df)

=> very poor performance at the moment :/

## 5.3 Biological Interpretability

* Identify genes most correlated with electrophysiological properties.
* Perform gene ontology (GO) enrichment or pathway analysis on these genes.
* Validate known relationships: e.g., do cells with fast APs express SCN1A highly?
* Explore novel marker relationships for specific firing behaviors.
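
The core of a GO over-representation analysis is a hypergeometric test: given a background of measured genes, is the overlap between our selected genes and the genes annotated to a term larger than expected by chance? The sketch below uses a hypothetical helper `enrichment_pvalue` and made-up gene sets. A real analysis would take the annotations from a GO database and correct for testing many terms.

```python
from scipy.stats import hypergeom


def enrichment_pvalue(selected_genes, term_genes, background_genes):
    """One-sided hypergeometric p-value for over-representation of a term."""
    background = set(background_genes)
    selected = set(selected_genes) & background
    term = set(term_genes) & background
    overlap = len(selected & term)
    # P(X >= overlap) when drawing len(selected) genes from the background
    return hypergeom.sf(overlap - 1, len(background), len(term), len(selected))


# Toy example with made-up gene names
background_toy = [f"gene_{i}" for i in range(1000)]
term_toy = background_toy[:20]  # 20 genes annotated to the toy term

# 8 of the 10 selected genes belong to the term
selected_enriched = background_toy[:8] + background_toy[500:502]
# none of the 10 selected genes belong to the term
selected_none = background_toy[500:510]

p_enriched = enrichment_pvalue(selected_enriched, term_toy, background_toy)
p_none = enrichment_pvalue(selected_none, term_toy, background_toy)

print(f"p (8/10 genes in term): {p_enriched:.2e}")
print(f"p (0/10 genes in term): {p_none:.2f}")
```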

# References

Lause, J., Berens, P., & Kobak, D. (2024). The art of seeing the elephant in the room: 2D embeddings of single-cell data do make sense. *PLoS Computational Biology*, 20(10), e1012403. https://doi.org/10.1371/journal.pcbi.1012403