<a href="https://colab.research.google.com/github/linhtrinh213/Interpretable_VAE_TF_regulons/blob/main/Tidy_VEGA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries and Data Loading


Library Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#### Library Setup ####
# install and import required packages
!pip install scanpy
!pip install decoupler
!pip install omnipath

import torch; torch.manual_seed(100)
import torch.nn as nn
import torch.utils
import torch.distributions
import torchvision
from torchvision import datasets, transforms
import math
import numpy as np
np.random.seed(100)
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
import scanpy as sc
from collections import OrderedDict
from collections import Counter

# select the right device, depending on whether your Colab runs on GPU or CPU
### IMPORTANT: we recommend to change your runtime to GPU, otherwise the training takes much longer
device = 'mps'


Collecting scanpy
  Downloading scanpy-1.11.1-py3-none-any.whl.metadata (9.9 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.4-py3-none-any.whl.metadata (9.3 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting scikit-learn<1.6.0,>=1.1 (from scanpy)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.11.2-py3-none-any.whl.metadata (1.9 kB)
Downloading scanpy-1.11.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m56.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading anndata-0.11.4-py3-none-any.whl (144 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.

PBMC Data download

In [3]:
# download the data
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG' -O PBMC_train.h5ad
# Load data and use Scanpy to convert it into AnnData
PBMC_train = sc.read_h5ad("PBMC_train.h5ad")

--2025-05-15 13:02:17--  https://docs.google.com/uc?export=download&id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG
Resolving docs.google.com (docs.google.com)... 64.233.170.139, 64.233.170.113, 64.233.170.102, ...
Connecting to docs.google.com (docs.google.com)|64.233.170.139|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG&export=download [following]
--2025-05-15 13:02:17--  https://drive.usercontent.google.com/download?id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.68.132, 2404:6800:4003:c02::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|74.125.68.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 45277554 (43M) [application/octet-stream]
Saving to: ‘PBMC_train.h5ad’


2025-05-15 13:02:24 (42.4 MB/s) - ‘PBMC_train.h5ad’ saved [45277554/4

In [4]:
# devide data into control and stimulated


# Subset for a specific condition, e.g., "control"
PBMC_control = PBMC_train[PBMC_train.obs["condition"] == "control"].copy()

# Another example for "treated"
PBMC_stimulated = PBMC_train[PBMC_train.obs["condition"] == "stimulated"].copy()


In [5]:
# Take a look at the data
print(PBMC_train) #the data is stored as an anndata object
print(Counter(PBMC_train.obs["cell_type"])) #summary of cell types
print(Counter(PBMC_train.obs["condition"])) #summary of conditions

AnnData object with n_obs × n_vars = 13515 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'connectivities', 'distances'
Counter({'CD4T': 4452, 'FCGR3A+Mono': 2881, 'CD14+Mono': 2049, 'B': 1448, 'NK': 931, 'CD8T': 892, 'Dendritic': 862})
Counter({'stimulated': 7109, 'control': 6406})


# Regulon Import and Mask Creation

Regulon Import

In [6]:
### CollecTRI network: a curated collection of TFs and their transcriptional targets ###

import decoupler as dc
regulons = dc.get_collectri(organism='human', split_complexes=False)    # processed regulons

  # spilt_complexes: keep or spilt complexes into subunits
  # weight: 1 is activation. -1 is inhibition

print(regulons)

      source  target  weight
0       ABL1     BAX       1
1       ABL1    BCL2      -1
2       ABL1    BCL6      -1
3       ABL1   CCND2       1
4       ABL1  CDKN1A       1
...      ...     ...     ...
40625   ZXDC  CDKN1C       1
40626   ZXDC  CDKN2A       1
40627   ZXDC   CIITA       1
40628   ZXDC   HLA-E       1
40629   ZXDC     IL5       1

[40630 rows x 3 columns]


In [7]:
####### In the following sections, we used and adapted code from https://github.com/LucasESBS/vega ######
####### Here the mask is created and added to adata #######
import pandas as pd
import numpy as np

def create_mask(adata, regulons, add_nodes:int=10, sep = "\t"):
    """
    Initialize mask M that specifies which latent nodes connect to which decoder nodes.
    Args:
        adata (Anndata): Scanpy single-cell object, we will store the computed mask and the names of the biological processes there
        regulons: which TFs affected which genes
        add_nodes (int): Additional latent nodes for capturing additional variance
    Return:
        adata (Anndata): Scanpy single-cell object that now stores the computed mask and the names of biological processes (in the .uns["_vega"] attribute)
        mask (array): mask M that specifies whether a gene is included in the gene set of a pathway (value one) or not (value zero)
    """

    # Create the mask
    # 1. Get unique genes (targets) and TFs (sources)
    genes = regulons['target'].unique()
    tfs = regulons['source'].unique()

    # 2. Initialize matrix M with zeros
    M = pd.DataFrame(0, index=genes, columns=tfs)

    # 3. Set M[i,j] = 1 where the gene i is affected by TF j
    for _, row in regulons.iterrows(): # for each row in regulons
        M.loc[row['target'], row['source']] = 1 #the corresponding genes, TF box = 1


    # Add unannotated nodes
    vec = np.ones((M.shape[0], add_nodes))
    M = np.hstack((M, vec))

    adata.uns['_vega'] = dict() #create attribute "_vega" to store the mask and pathway information
    adata.uns['_vega']['mask'] = M
    adata.uns['_vega']['TFs'] = list(tfs) + ['UNANNOTATED_'+str(k) for k in range(add_nodes)]


    return adata, M

In [20]:
PBMC_train, mask = create_mask(PBMC_train,regulons , add_nodes=1)

print(mask)
print(mask.shape)

[[1. 0. 0. ... 0. 0. 1.]
 [1. 0. 0. ... 0. 0. 1.]
 [1. 0. 0. ... 0. 0. 1.]
 ...
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]]
(6627, 1164)


# VEGA Components
Here are defined the encoder, decoder



In [9]:
# Define Encoder:
class Encoder(nn.Module):
    def __init__(self, latent_dims, input_dims, dropout, z_dropout): #dropout between the dense layers, z_dropout define the dropout rates between the encoder/latent space
        super(Encoder, self).__init__() #run the initialize code from nn.Module -> this class behaves like a Pytorch model
        self.encoder = nn.Sequential(
                                     nn.Linear(input_dims, 800),
                                     nn.ReLU(),
                                     nn.Dropout(p = dropout),
                                     nn.Linear(800, 800),
                                     nn.ReLU(),
                                     nn.Dropout(p = dropout))  #two layer, fully connected encoder with dropout

        # outputs mean vector u
        self.mu = nn.Sequential(nn.Linear(800, latent_dims), # the 800 neurons in the second layers -> latent space
                                nn.Dropout(p = z_dropout))
        # outputs standard variance
        self.sigma = nn.Sequential(nn.Linear(800, latent_dims),
                                   nn.Dropout(p = z_dropout))

        self.N = torch.distributions.Normal(0, 1)  # define Gaussian distribution for each input
        self.N.loc = self.N.loc.to(device) # move to the right device
        self.N.scale = self.N.scale.to(device)
        self.kl = 0 # place holder for storing KL divergence (regularization term)
        # KL measures how far the learned Gaussian is from the standard normal (0,1) -> this is a regularization term in VAE
    def forward(self, x):
        x = self.encoder(x) # pass the data to the encoder
        mu =  self.mu(x) # predict mean vector
        sigma = torch.exp(self.sigma(x)) # predict standard var exp for numeric stability
        z = mu + sigma*self.N.sample(mu.shape)  # Sample z using reparameterization trick

        self.kl = (0.5*sigma**2 + 0.5*mu**2 - torch.log(sigma) - 1/2).sum() #calculation of kullback-leibler divergence

        return z # output is the sampled latent vector


In [10]:
# define VEGA's decoder (which is adapted to enable biologcal intepretability )

class DecoderVEGA(nn.Module):
  """
  Define VEGA's decoder (sparse, one-layer, linear)
  """
  def __init__(self, mask):
        super(DecoderVEGA, self).__init__()

        self.sparse_layer = nn.Sequential(SparseLayer(mask)) # we define the architecture of the decoder below with the class "SparseLayer"

  def forward(self, x):
    z = self.sparse_layer(x.to(device)) # pass the latent representation x -> predict gene expression z
    return(z)


# define a class SparseLayer, that specifies the decoder architecture (sparse connections based on the mask)
class SparseLayer(nn.Module):
  def __init__(self, mask):
        """
        Extended torch.nn module which mask connection
        """
        super(SparseLayer, self).__init__()

        self.mask = nn.Parameter(torch.tensor(mask, dtype=torch.float).t(), requires_grad=False) # mask: binary matrix of shape (genes, TFs) — entry (i, j) = 1 if TF j regulates gene i.
        self.weight = nn.Parameter(torch.Tensor(mask.shape[1], mask.shape[0])) # mask dimension: in our cases 1164, 6627 ()
        self.bias = nn.Parameter(torch.Tensor(mask.shape[1])) #there should be in total 1164 nodes aka TF in the latent space -> hence 1164 biases
        self.reset_parameters()

        # mask weight
        self.weight.data = self.weight.data * self.mask # All connections not allowed by the mask are zeroed out

  def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1)) #
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
        # See the autograd section for explanation of what happens here
        return SparseLayerFunction.apply(input, self.weight, self.bias, self.mask)


######### You don't need to understand this part of the code in detail #########
# SparseLayerFunction: Custom gradient-aware layer to apply the mask during both forward and backward passes
class SparseLayerFunction(torch.autograd.Function):
    """
    We define our own autograd function which masks it's weights by 'mask'.
    For more details, see https://pytorch.org/docs/stable/notes/extending.html
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, weight, bias, mask):

        weight = weight * mask # change weight to 0 where mask == 0
        #calculate the output
        output = input.mm(weight.t())
        output += bias.unsqueeze(0).expand_as(output) # Add bias to all values in output
        ctx.save_for_backward(input, weight, bias, mask)
        return output

    @staticmethod
    def backward(ctx, grad_output): # define the gradient formula
        input, weight, bias, mask = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = grad_mask = None

        # These needs_input_grad checks are optional and only to improve efficiency
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
            # change grad_weight to 0 where mask == 0
            grad_weight = grad_weight * mask
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias, grad_mask



In [11]:
class VEGA(nn.Module):
    def __init__(self, latent_dims, input_dims, mask, dropout = 0.3, z_dropout = 0.3):
        super(VEGA, self).__init__()
        self.encoder = Encoder(latent_dims, input_dims, dropout, z_dropout) # we use the same encoder as before (two-layer, fully connected, non-linear)
        self.decoder = DecoderVEGA(mask)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

The training loop is defined below

In [12]:
#training loop
def trainVEGA(vae, data, epochs=50, beta = 0.0001, learning_rate = 0.001):
    opt = torch.optim.Adam(vae.parameters(), lr = learning_rate, weight_decay = 5e-4)
    vae.train() #train mode
    losses = []
    klds = []
    mses = []

    for epoch in range(epochs):
        loss_e = 0
        kld_e = 0
        mse_e = 0

        for x in data:
            x = x.to(device)
            opt.zero_grad()
            x_hat = vae(x)
            mse = ((x - x_hat)**2).sum()
            kld = beta* vae.encoder.kl
            loss = mse +  kld # loss calculation
            loss.backward()
            opt.step()
            loss_e += loss.to('cpu').detach().numpy()
            kld_e += kld.to('cpu').detach().numpy()
            mse_e += mse.to('cpu').detach().numpy()
            #vae.decoder.positive_weights() # we restrict the decoder for positive weights

        losses.append(loss_e/(len(data)*128))
        klds.append(kld_e/(len(data)*128))
        mses.append(mse_e/(len(data)*128))

        print("epoch: ", epoch, " loss: ", loss_e/(len(data)*128))

    return vae, losses, klds, mses

# Model Training

In [13]:
import torch
# NOTE: In Runtime, Change type (hardware accelerator) -> GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [14]:
# devide data into control and stimulated

# Control subset
PBMC_control = PBMC_train[PBMC_train.obs["condition"] == "control"].copy()

# Stimulated subset
PBMC_stimulated = PBMC_train[PBMC_train.obs["condition"] == "stimulated"].copy()

In [15]:
# Set up training data for each VEGA for pyTorch
PBMC_controlX = torch.utils.data.DataLoader(PBMC_control.X.toarray(), batch_size=128)
PBMC_stimulatedX = torch.utils.data.DataLoader(PBMC_stimulated.X.toarray(), batch_size=128)

# training data (gene expression data) I DONT KNOW WHAT THIS IS FOR
#latent_dims = 50 #choose the number of latent variables


In [16]:
# instantiate the model as vega_ctr
vega_ctr = VEGA(latent_dims= mask.shape[1], input_dims = PBMC_control.shape[0], mask = mask.T, z_dropout = 0.5, dropout = 0.3).to(device)
# Train vega_ctr on the control data
vega_ctr, vega_losses_ctr, vega_klds_ctr, vega_mses_ctr = trainVEGA(vega_ctr, PBMC_controlX, epochs = 50, beta = 0.0001) #takes about 2 mins on GPU

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x6998 and 6406x800)