In [1]:
import os

import tqdm
import anndata
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.metrics import matthews_corrcoef, roc_auc_score

In [2]:
adata_rna = anndata.read_h5ad('../output/scRNA/adata_rna.h5ad')

In [3]:
adata_rna

AnnData object with n_obs × n_vars = 25034 × 22778
    obs: 'cell_type', 'donor_id', 'cell_type_original', 'Donor', 'Cell type'

In [3]:
gene_names = adata_rna.var.index.to_numpy()
gene_name_dict = {gene_name: i for i, gene_name in enumerate(gene_names)}
gene_names

array(['A1BG', 'A1BG-AS1', 'A2M', ..., 'ZYG11B', 'ZYX', 'ZZEF1'],
      dtype=object)

In [4]:
# Download GRN
df_collectri = pd.read_csv("https://github.com/pablormier/omnipath-static/raw/main/op/collectri-26.09.2023.zip")


In [5]:

# Iterate over the edges (regulatory relationships)
edge_idx = set()
for gene_a, gene_b in zip(df_collectri['source'], df_collectri['target']):
    if (gene_a not in gene_name_dict) or (gene_b not in gene_name_dict):
        continue  # Consider only gene names that are present in the training data
    i = gene_name_dict[gene_a]  # Index of first gene
    j = gene_name_dict[gene_b]  # Index of second gene
    edge_idx.add((i, j))
edge_idx = np.asarray(list(edge_idx), dtype=int)

# Convert list of edges into an adjacency matrix
grn = np.zeros((len(gene_names), len(gene_names)))
grn[edge_idx[:, 0], edge_idx[:, 1]] = 1


In [13]:
(grn==0).sum()/grn.size
grn.shape

(5333, 5333)

In [6]:
# Remove rows and columns with no annotation
grn_idx = np.where(np.logical_or(grn.sum(axis=0) > 0, grn.sum(axis=1) > 0))[0]
grn = grn[grn_idx, :][:, grn_idx]

In [9]:
class Scaler(torch.nn.Module):
    
    def __init__(self, m: int) -> None:
        torch.nn.Module.__init__(self)
        self.m: int = m
        self.a: torch.Tensor = torch.nn.Parameter(torch.ones((1, self.m)))
        self.b: torch.Tensor = torch.nn.Parameter(torch.zeros((1, self.m)))
    
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.a * X + self.b


class VAE(torch.nn.Module):
    
    def __init__(self, n_genes: int) -> None:
        torch.nn.Module.__init__(self)
        self.n_genes: int = n_genes
        self.A = torch.nn.Parameter(0.5 * torch.eye(n_genes).float())
        # self.A = 0.5 * torch.eye(n_genes).float()
        
        self.encoder = torch.nn.Sequential(
            Scaler(self.n_genes),
            torch.nn.LeakyReLU(),
            Scaler(self.n_genes),
            torch.nn.LeakyReLU(),
        )
        
        self.mu_regressor = torch.nn.Sequential(
            Scaler(self.n_genes)
        )
        self.logvar_regressor = torch.nn.Sequential(
            Scaler(self.n_genes)
        )
        
        self.decoder = torch.nn.Sequential(
            Scaler(self.n_genes),
            torch.nn.LeakyReLU(),
            Scaler(self.n_genes)
        )
    
    def reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        if self.training:
            std = torch.sqrt(torch.exp(logvar) + 1e-10)
            return std * torch.randn_like(mu) + mu
        else:
            return mu
    
    def forward(self, X_rna: torch.Tensor) -> torch.Tensor:        
        encoded_rna = self.encoder(X_rna)
        
        mu = self.mu_regressor(encoded_rna)
        logvar = self.logvar_regressor(encoded_rna)
        
        grn = torch.eye(self.n_genes) - self.A.t()  # TODO: transpose
        mu = torch.matmul(mu, grn)
        logvar = torch.matmul(logvar, grn)
        
        Z = self.reparametrize(mu, logvar)
        
        grn_inv = torch.linalg.inv(grn)
        Z = torch.matmul(Z, grn_inv)
        
        decoded_rna = self.decoder(Z)
        
        return decoded_rna, mu, logvar
n_genes = len(grn_idx)
n_epochs = 10

vae = VAE(n_genes)
vae.train()

batch_size = 256

optimizer = torch.optim.SGD(vae.parameters(), lr=1e-3, weight_decay=1e-6)
# TODO: Adam?

losses = []

for epoch in range(n_epochs):
    losses.append(0)
    n_batches = 0
    
    idx = np.arange(len(adata_rna))
    np.random.shuffle(idx)

    for batch_idx in tqdm.tqdm(np.array_split(idx, int(len(idx) / batch_size)), desc=f'Epoch {epoch + 1}'):
        
        x_rna = torch.FloatTensor(adata_rna[batch_idx, grn_idx].X.todense())
        
        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        decoded, mu, logvar = vae(x_rna)

        
        # Compute loss
        reconstruction_loss = torch.sum(torch.square(x_rna - decoded))
        kld = 0.5 * torch.sum(1.0 + logvar - mu.pow(2) - logvar.exp())
        loss = reconstruction_loss - kld
        loss = loss / (len(batch_idx) * n_genes)
        
        # Backward pass
        loss.backward()
        
        # Update parameters
        optimizer.step()

        # Update total loss
        losses[-1] += loss.item()
        n_batches += 1
        
        mask = ~np.eye(n_genes, dtype=bool)
        grn_pred = np.abs(vae.A.cpu().data.numpy().T)
        print('AUROC', roc_auc_score(grn[mask], grn_pred[mask]))
        print('Loss', loss.item())

    losses[-1] /= n_batches
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, n_epochs, losses[-1]))

Epoch 1:   0%|          | 0/97 [00:05<?, ?it/s]


ValueError: Input contains NaN.

In [15]:
grn_idx.shape

(5333,)