In [None]:
import numpy as np
import pertdata as pt
import torch
from scipy import sparse

# Load and Preprocess the scRNA-seq Data

Show available datasets from the `pertdata` package:

In [None]:
# Query pertdata for its catalogue and print one indented name per line.
datasets_dict = pt.datasets()
listing = ["Available datasets:", *(f"  {name}" for name in datasets_dict.keys())]
print("\n".join(listing))

Load the "NormanWeissman2019_filtered" dataset:

In [None]:
# Load the dataset by name, show its summary, and keep a handle on the AnnData object.
dataset_name = "NormanWeissman2019_filtered"
NormanWeissman2019_filtered = pt.PertDataset(name=dataset_name)
print(NormanWeissman2019_filtered)
adata = NormanWeissman2019_filtered.adata

- The expression matrix is in `adata.X`, with cells as rows and genes as columns (AnnData's observations × variables layout).
- Perturbation labels are stored in `adata.obs['perturbation']`.
- Gene names are in `adata.var_names`.

Ensure that the data is in a dense format:

In [None]:
# Convert a scipy sparse expression matrix to a dense array; dense data is left untouched.
X_is_sparse = sparse.issparse(adata.X)
if X_is_sparse:
    adata.X = adata.X.toarray()

Aggregate (pseudobulk) the single-cell data per perturbation condition:

In [None]:
# Pseudobulk: average the expression of all cells that share a perturbation label.
perturbations = adata.obs["perturbation"].unique()
# Extract the labels once as a plain array so each mask below is a cheap
# vectorised comparison instead of building an AnnData view per perturbation.
labels = adata.obs["perturbation"].to_numpy()

mean_profiles = []
for pert in perturbations:
    # Rows of adata.X are cells, so averaging over axis 0 gives one mean per
    # gene. np.asarray(...).ravel() also flattens the (1, n_genes) np.matrix
    # that scipy.sparse returns, in case adata.X is still sparse here.
    cell_mask = labels == pert
    mean_profiles.append(np.asarray(adata.X[cell_mask].mean(axis=0)).ravel())

# Stack the per-perturbation mean profiles as the columns of Y_train.
Y_train = np.array(mean_profiles).T  # Shape: (n_genes, n_perturbations)

## The Linear Gene Expression Model

In the **Linear Gene Expression Model**, we have:
- Embeddings for read-out genes: $G$ (an $n \times K$ matrix, where $n$ is the number of read-out genes and $K$ is the dimensionality of the embeddings).
- Embeddings for perturbations: $P$ (an $m \times K$ matrix, where $m$ is the number of perturbations).
- Weight matrix: $W$ (a $K \times K$ matrix to be learned).
- Bias vector: $b$ (an $n \times 1$ vector of average gene expressions).

The model predicts gene expression values using:
$$
Y_{\text{train}} \approx G W P^\top + b
$$

Note: The bias vector $b$ (with dimensions $n \times 1$) is added to each column of the matrix $G W P^\top$ (with dimensions $n \times m$).
This operation effectively _broadcasts_ the vector $b$ across all $m$ columns, repeating it $m$ times to match the dimensions of $G W P^\top$.
PyTorch performs this broadcasting automatically, provided $b$ is stored as an $n \times 1$ column. A 1-D tensor of length $n$ would instead be aligned with the last (perturbation) axis, which fails unless $n = m$.

## Formulating the Linear Gene Expression Model As a Neural Network

The LGEM formulation matches the standard linear layer in neural networks, where the output is a linear transformation of the input plus a bias term.
By combining $G$ and $W$ into a single matrix $M = G W$ with dimensions $n \times K$, we can write the prediction for each perturbation as:

$$
y = M p^\top + b.
$$

Here:
- $y$ is the predicted gene expression vector ($n \times 1$).
- $p^\top$ is the transpose of the perturbation embedding $p$ (a $1 \times K$ row of $P$), i.e. a $K \times 1$ column vector.
- $M$ serves as the weight matrix in the neural network.
- $b$ is the bias vector.

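This is exactly the standard linear layer

$$
y = W_{\text{layer}} x + b,
$$

with weight matrix $W_{\text{layer}} = M = G W$ and input $x = p^\top$. Note that the layer's weight is $M$, not the LGEM matrix $W$.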

Since both models are linear and $G$ is fixed, the neural network _without activation functions_ is equivalent to the original LGEM.

We can now choose to keep $G$ fixed (i.e., it consists of the top $K$ principal components from a PCA on $Y_{\text{train}}$), and only learn $W$.
$M = G W$ would then be a combination of the fixed $G$ and the learned $W$.
Or we can choose to set $G$ to be learnable.

Overview:
1. Fixed $G$:
   - Model equation: $y = (G W) x + b$
   - Parameters learned: $W$ (optionally also $b$; the implementation below keeps $b$ fixed)
   - Equivalence: maintained with the original model.
2. Learnable $G$:
   - Model equation: $y = (G W) x + b$
   - Parameters learned: $G$, $W$, $b$
   - Equivalence: the model becomes more flexible but differs from the original.
3. Non-linearity
4. Learnable $G$ & non-linearity

In [None]:
import torch.nn as nn


class LinearGeneExpressionModel(nn.Module):
    """Linear gene expression model ``Y ≈ G @ W @ P.T + b`` with fixed ``G`` and ``b``.

    Only the ``K x K`` matrix ``W`` is a learnable parameter. ``G`` and ``b``
    are registered as non-persistent buffers: they are not returned by
    ``parameters()`` (so optimizers leave them alone), but they move with the
    module on ``.to(device)`` / dtype casts.

    Args:
        G: Gene embedding tensor of shape ``(n_genes, K)``.
        b: Bias tensor with ``n_genes`` elements, shaped either ``(n_genes,)``
            or ``(n_genes, 1)``. It is stored as a column so that it is added
            to every perturbation column of the prediction.
    """

    def __init__(self, G, b):
        super().__init__()
        # Buffers rather than Parameters: excluded from training but kept on the
        # same device/dtype as W. persistent=False keeps state_dict == {"W"}.
        self.register_buffer("G", G, persistent=False)
        # Force a column vector: a 1-D bias would broadcast along the trailing
        # (perturbation) axis instead of the gene axis.
        self.register_buffer("b", b.reshape(-1, 1), persistent=False)
        K = G.shape[1]  # Dimensionality of the embeddings
        self.W = nn.Parameter(torch.randn(K, K))  # Learnable weight matrix

    def forward(self, P):
        """Predict expression for each perturbation.

        Args:
            P: Perturbation embeddings of shape ``(n_perturbations, K)``.

        Returns:
            Tensor of shape ``(n_genes, n_perturbations)``.
        """
        GW = self.G @ self.W  # (n_genes, K)
        return GW @ P.T + self.b  # b (n_genes, 1) is broadcast across columns

In [None]:
# Build the model around the fixed gene embeddings and bias, plus an MSE
# objective and an Adam optimizer over the model's trainable parameters.
# NOTE(review): G_tensor and b_tensor are not created in any of the cells
# shown in this notebook; they must be defined before this cell runs
# (presumably derived from Y_train) — TODO confirm where they come from.
model = LinearGeneExpressionModel(G=G_tensor, b=b_tensor)
criterion = nn.MSELoss(reduction="mean")
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

Note: `model.parameters()` yields only the registered `nn.Parameter` $W$. $G$ and $b$ are not registered as parameters, so the optimizer updates only $W$.

In [None]:
num_epochs = 1000  # Number of full-batch training epochs
log_every = 100  # Report the loss at this epoch interval

# Nothing inside the loop changes the module mode, so enter training mode once.
model.train()
for step in range(1, num_epochs + 1):
    optimizer.zero_grad()

    # Full-batch forward pass: Y_pred is (n_genes x n_perturbations).
    Y_pred = model(P_tensor)
    loss = criterion(Y_pred, Y_train_tensor)

    # Backpropagate and take one optimizer step.
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(f"Epoch [{step}/{num_epochs}], Loss: {loss.item():.4f}")

In [None]:
# Score the fitted model on the training data without tracking gradients.
model.eval()
with torch.no_grad():
    Y_pred = model(P_tensor)
    final_loss = criterion(Y_pred, Y_train_tensor)
print(f"Final Training Loss: {final_loss.item():.4f}")