The aim of this notebook is to introduce an AutoEncoder model for multi-modal integration of the ATAC and GEX data. The multi-modal AutoEncoder design is taken from Figure 4 of https://ieeexplore.ieee.org/document/8715409.

We'll extend this idea to gene-expression (GEX) and ATAC-seq data. We'll select variable features for both modalities and then jointly encode the GEX and ATAC-seq data into the same latent space. Once we've trained the model and the loss has stopped improving, we'll stop training, take the latent-space embeddings for the multi-modal data, and test how well they do on our evaluation (including whether they beat simply concatenating the PCA reduction of each modality).
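For reference, the PCA-concatenation baseline can be sketched in NumPy as below. This is a minimal sketch assuming dense arrays, with 10 components per modality chosen only so the result has the same width as the 20-dimensional latent space used later; `pca_embed` and `concat_pca_baseline` are hypothetical helpers, and in practice `sc.pp.pca` would handle each modality.

```python
import numpy as np


def pca_embed(X: np.ndarray, n_components: int) -> np.ndarray:
    """Project the rows of X onto their top principal components via SVD."""
    Xc = X - X.mean(axis=0, keepdims=True)  # center each feature
    # Economy SVD: Xc = U @ diag(S) @ Vt, so the PC scores are U * S
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    return U[:, :n_components] * S[:n_components]


def concat_pca_baseline(gex: np.ndarray, atac: np.ndarray, n_per_modality: int = 10) -> np.ndarray:
    """Hypothetical baseline: reduce each modality separately, then concatenate."""
    return np.concatenate(
        [pca_embed(gex, n_per_modality), pca_embed(atac, n_per_modality)], axis=1
    )


rng = np.random.default_rng(0)
gex_toy = rng.random((100, 50)).astype(np.float32)           # toy continuous "GEX"
atac_toy = (rng.random((100, 80)) < 0.2).astype(np.float32)  # toy binary "ATAC"
baseline = concat_pca_baseline(gex_toy, atac_toy, n_per_modality=10)
print(baseline.shape)  # (100, 20)
```

The autoencoder has to beat this embedding, of the same width, for the joint encoding to be worth its extra complexity.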

Let's get started by loading the single-cell libraries and the data, then selecting the variable features in both modalities. We'll restrict ourselves to 2500 features for GEX and 5000 features for ATAC (chromatin accessibility).
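Conceptually, variable feature selection ranks features by a dispersion score and keeps the top `k`. The sketch below uses plain per-feature variance as the score. That is a simplification: it is not the `seurat_v3` criterion used for GEX or episcanpy's variability score for ATAC, and `top_k_variable` is a hypothetical helper.

```python
import numpy as np


def top_k_variable(X: np.ndarray, k: int) -> np.ndarray:
    """Boolean mask over the columns of X marking the k highest-variance features.

    Simplified stand-in: plain variance, NOT seurat_v3 or episcanpy's score.
    """
    variances = X.var(axis=0)
    # argsort is ascending, so the last k indices are the k most variable features
    top = np.argsort(variances)[-k:]
    mask = np.zeros(X.shape[1], dtype=bool)
    mask[top] = True
    return mask


X = np.array([[0., 1., 5.],
              [0., 0., -5.],
              [0., 1., 5.],
              [0., 0., -5.]])
mask = top_k_variable(X, k=2)
print(mask)  # [False  True  True]
```

A fixed-`k` ranking like this always returns exactly `k` features; a score cutoff can return a slightly different number.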

In [1]:
import os

import numpy as np
import pandas as pd
import scanpy as sc 
import anndata as ann
import episcanpy as esc

In [2]:
os.chdir("..")
multiome = sc.read_h5ad("data/multimodal/GSE194122_openproblems_neurips2021_multiome_BMMC_processed.h5ad")

In [3]:
gex = multiome[:, multiome.var["feature_types"] == "GEX"].copy() # Subset all data, not just the counts (copy, so we work on a real AnnData rather than a view)
sc.pp.highly_variable_genes(gex, n_top_genes=2500, flavor="seurat_v3") # Feature selection


In [4]:
atac = multiome[:, multiome.var["feature_types"] == "ATAC"].copy() # Subset all data, not just the counts (copy, so we work on a real AnnData rather than a view)
esc.pp.select_var_feature(atac, nb_features=5000, show=False) # Feature selection: keep the most variable features (subsets `atac` in place)


Let's go ahead and load the TensorFlow libraries.

In [5]:
import matplotlib.pyplot as plt
import tensorflow as tf

from sklearn.metrics import accuracy_score, precision_score, recall_score
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model



We'll extract the GEX and ATAC matrices and convert them into TensorFlow tensors.

In [8]:
hvg_indices = gex.var["highly_variable"].values # Boolean mask over genes
gex_arr = gex[:, hvg_indices].X.toarray() # Subset to the HVGs while still sparse, then densify
gex_tensor = tf.convert_to_tensor(gex_arr)

In [15]:
gex_tensor.shape

TensorShape([69249, 2500])
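Densifying the full gene matrix before selecting the HVG columns materialises many columns that are immediately thrown away. The toy SciPy sketch below shows that subsetting the sparse matrix first gives the same dense block. The matrix size and the every-10th-gene mask are made up for illustration. It also uses `.toarray()`, which returns a plain `ndarray`, whereas `.todense()` returns an `np.matrix`.

```python
import numpy as np
import scipy.sparse as sp

# Toy stand-in for an AnnData .X: 1,000 cells x 3,000 genes, 5% non-zero, CSR format
X = sp.random(1000, 3000, density=0.05, format="csr", dtype=np.float32)

# Hypothetical HVG mask: keep every 10th gene (300 genes)
hvg_mask = np.zeros(X.shape[1], dtype=bool)
hvg_mask[::10] = True

# Subset columns while still sparse, then densify only the selected block
dense_subset = X[:, hvg_mask].toarray()

# Equivalent result, but densifies all 3,000 genes first
dense_full_then_subset = X.toarray()[:, hvg_mask]

print(dense_subset.shape)  # (1000, 300)
```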

In [21]:
atac_arr = atac.X.toarray() # Dense array of the selected ATAC features
atac_tensor = tf.convert_to_tensor(atac_arr)

In [22]:
atac_tensor.shape

TensorShape([69249, 5001])

The ATAC data has one extra feature (5001 were selected rather than 5000), but we don't need to worry about that. We'll just take the dimensionality of our ATAC data to be 5001. Both data types are fed to the model as a single concatenated matrix.

In [45]:
gex_atac_concat = tf.concat([gex_tensor, atac_tensor], axis = 1)
gex_atac_concat.shape

TensorShape([69249, 7501])
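Because both modalities share one input matrix, every downstream step (encoder input, loss) has to split it at exactly the same column. The NumPy sketch below uses toy random data with the same feature counts as above. It shows the concat/split round trip and reads `gex_dim` and `atac_dim` from the array shapes instead of hard-coding them. A cut at 2501 instead of 2500 would shift the first ATAC column into the GEX block.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells = 8
gex = rng.random((n_cells, 2500)).astype(np.float32)           # toy continuous GEX
atac = (rng.random((n_cells, 5001)) < 0.1).astype(np.float32)  # toy binary ATAC

# Derive the split point from the data instead of hard-coding it
gex_dim, atac_dim = gex.shape[1], atac.shape[1]
joint = np.concatenate([gex, atac], axis=1)  # shape (n_cells, 7501)

# Columns [0, gex_dim) are GEX; columns [gex_dim, gex_dim + atac_dim) are ATAC
gex_back = joint[:, :gex_dim]
atac_back = joint[:, gex_dim:]

print(joint.shape, gex_back.shape, atac_back.shape)  # (8, 7501) (8, 2500) (8, 5001)
```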

Let's go ahead and create our AutoEncoder model now.

In [110]:
latent_dim = 20 # Specify the size of our latent dimension
gex_dim = gex_tensor.shape[1]   # 2500 highly variable genes
atac_dim = atac_tensor.shape[1] # 5001 variable ATAC features

# Create the model class for our AutoEncoder. This mostly follows the tutorial at
# https://www.tensorflow.org/tutorials/generative/autoencoder, except that we give it a
# multi-modal flavor so that it takes in and reconstructs both the GEX and ATAC data
class MultiModalAutoencoder(Model):
    def __init__(self, latent_dim, gex_dim, atac_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.gex_dim = gex_dim
        self.atac_dim = atac_dim
        # We have two encoders and two decoders, one of each per modality.
        # Each encoder outputs latent_dim // 2 units because we concatenate
        # the two encodings into a single latent_dim-sized representation
        # and then use that joint representation to reconstruct each modality
        self.gex_encoder = tf.keras.Sequential([
            layers.Dense(250, activation="relu"),
            layers.Dense(latent_dim // 2, activation="relu")
        ])
        self.atac_encoder = tf.keras.Sequential([
            layers.Dense(250, activation="relu"),
            layers.Dense(latent_dim // 2, activation="relu")
        ])
        self.latent_concat = layers.Concatenate(axis=-1)
        self.outputs_concat = layers.Concatenate(axis=-1)
        self.gex_decoder = tf.keras.Sequential([
            layers.Dense(250, activation="relu"),
            layers.Dense(gex_dim, activation="relu")
        ])
        self.atac_decoder = tf.keras.Sequential([
            layers.Dense(250, activation="relu"),
            # Sigmoid keeps ATAC outputs in (0, 1), as binary cross-entropy expects probabilities
            layers.Dense(atac_dim, activation="sigmoid")
        ])

    def encode(self, gex_atac_X):
        # Split the input: the first gex_dim columns are GEX, the rest are ATAC
        gex_X = gex_atac_X[:, :self.gex_dim]
        atac_X = gex_atac_X[:, self.gex_dim:]
        # Encode each modality, then concatenate into the joint latent representation
        gex_Z = self.gex_encoder(gex_X)
        atac_Z = self.atac_encoder(atac_X)
        return self.latent_concat([gex_Z, atac_Z]) # This is the latent we'll use later

    def call(self, gex_atac_X):
        gex_atac_Z = self.encode(gex_atac_X)
        # Use the concatenated representation to reconstruct both GEX and ATAC
        gex_X_decoded = self.gex_decoder(gex_atac_Z)
        atac_X_decoded = self.atac_decoder(gex_atac_Z)
        return self.outputs_concat([gex_X_decoded, atac_X_decoded])
    
# We define a custom loss because each modality needs its own loss term.
# For GEX, since the data is continuous, we use a mean-squared-error loss.
# For ATAC, since the data is binary, we use a binary cross-entropy loss.
# We combine the two with equal weight for now, but the weighting can be played around with
# (and could even be treated as a hyperparameter)
mse = tf.keras.losses.MeanSquaredError()
bce = tf.keras.losses.BinaryCrossentropy()

def multimodal_loss(gex_atac_true, gex_atac_pred):
    # GEX loss on the first gex_dim columns
    gex_loss = mse(gex_atac_true[:, :gex_dim], gex_atac_pred[:, :gex_dim])

    # ATAC loss on the remaining columns
    atac_loss = bce(gex_atac_true[:, gex_dim:], gex_atac_pred[:, gex_dim:])

    # Combine both and return
    loss = 0.5 * gex_loss + 0.5 * atac_loss
    return loss 
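The data flow the comments describe can be traced with a NumPy-only forward pass: two encoders, a concatenated latent, and two decoders that both read that joint latent. This is a minimal sketch with random, untrained weights. The sigmoid on the ATAC output is an assumption, made so that its predictions are valid probabilities for binary cross-entropy, and `dense`, `mlp`, `init` and `forward` are hypothetical helpers.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def dense(x, w, b, activation):
    """One fully connected layer: activation(x @ w + b)."""
    return activation(x @ w + b)


def init(n_in, n_out):
    """Toy initialisation: small random weights, zero bias."""
    return rng.normal(0.0, 0.01, (n_in, n_out)), np.zeros(n_out)


gex_dim, atac_dim, latent_dim, hidden = 2500, 5001, 20, 250
params = {
    "gex_enc": [init(gex_dim, hidden), init(hidden, latent_dim // 2)],
    "atac_enc": [init(atac_dim, hidden), init(hidden, latent_dim // 2)],
    "gex_dec": [init(latent_dim, hidden), init(hidden, gex_dim)],
    "atac_dec": [init(latent_dim, hidden), init(hidden, atac_dim)],
}


def mlp(x, layer_params, out_activation):
    """Two dense layers: ReLU hidden layer, then the given output activation."""
    (w1, b1), (w2, b2) = layer_params
    return dense(dense(x, w1, b1, relu), w2, b2, out_activation)


def forward(x):
    gex_x, atac_x = x[:, :gex_dim], x[:, gex_dim:]
    # Encode each modality, then concatenate into the joint latent z
    z = np.concatenate([mlp(gex_x, params["gex_enc"], relu),
                        mlp(atac_x, params["atac_enc"], relu)], axis=1)
    # Both decoders read the joint latent z
    gex_hat = mlp(z, params["gex_dec"], relu)
    atac_hat = mlp(z, params["atac_dec"], sigmoid)
    return z, np.concatenate([gex_hat, atac_hat], axis=1)


x = rng.random((4, gex_dim + atac_dim))
z, x_hat = forward(x)
print(z.shape, x_hat.shape)  # (4, 20) (4, 7501)
```

Because both decoders see the joint `z`, each modality's reconstruction can draw on information encoded from the other. That is the point of concatenating in latent space.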

Let's compile the autoencoder and train it for 10 epochs.

In [111]:
autoencoder = MultiModalAutoencoder(latent_dim, gex_dim, atac_dim)
autoencoder.compile(optimizer='adam', loss=multimodal_loss)
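To make the loss concrete, here is a NumPy restatement evaluated on a tiny example that can be checked by hand. It is a sketch that should agree with the Keras `MeanSquaredError`/`BinaryCrossentropy` combination up to small numerical-stability details. `multimodal_loss_np` and its `gex_weight` parameter are hypothetical; `gex_weight` just exposes the 0.5/0.5 split as a tunable value.

```python
import numpy as np


def multimodal_loss_np(y_true, y_pred, gex_dim, gex_weight=0.5, eps=1e-7):
    """Weighted sum of GEX mean-squared error and ATAC binary cross-entropy."""
    gex_t, gex_p = y_true[:, :gex_dim], y_pred[:, :gex_dim]
    atac_t, atac_p = y_true[:, gex_dim:], y_pred[:, gex_dim:]
    mse = np.mean((gex_t - gex_p) ** 2)
    # Clip probabilities so that log() stays finite
    atac_p = np.clip(atac_p, eps, 1 - eps)
    bce = -np.mean(atac_t * np.log(atac_p) + (1 - atac_t) * np.log(1 - atac_p))
    return gex_weight * mse + (1 - gex_weight) * bce


# One cell with 2 GEX features followed by 2 ATAC features
y_true = np.array([[1.0, 3.0, 1.0, 0.0]])
y_pred = np.array([[2.0, 3.0, 0.5, 0.5]])
print(round(multimodal_loss_np(y_true, y_pred, gex_dim=2), 4))  # 0.5966
```

By hand: the GEX MSE is (1 + 0) / 2 = 0.5 and the ATAC BCE is ln 2 ≈ 0.6931, so the loss is 0.5 · 0.5 + 0.5 · 0.6931 ≈ 0.5966.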

In [112]:
autoencoder.fit(gex_atac_concat, gex_atac_concat,
                epochs=10,
                shuffle=True,
                batch_size=128
               )

Epoch 1/10

KeyboardInterrupt: 


In [63]:
print(gex_atac_concat)

tf.Tensor(
[[0.        0.        0.        ... 0.        1.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        0.        ... 1.        0.        0.       ]
 ...
 [0.        0.        0.        ... 0.        1.        1.       ]
 [0.        0.        0.        ... 0.        1.        1.       ]
 [0.        0.        2.0074675 ... 0.        0.        1.       ]], shape=(69249, 7501), dtype=float32)
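The printed tensor suggests that the GEX block holds non-negative continuous values and the ATAC block holds 0/1 values. The MSE/BCE split in the loss assumes exactly this. A quick sanity check is sketched below on a toy matrix. `check_modalities` is a hypothetical helper, and on the real data it could be called as `check_modalities(gex_atac_concat.numpy(), gex_dim=2500)`.

```python
import numpy as np


def check_modalities(joint: np.ndarray, gex_dim: int) -> dict:
    """Check whether each block of the joint matrix suits the loss applied to it."""
    gex_block, atac_block = joint[:, :gex_dim], joint[:, gex_dim:]
    return {
        "gex_nonnegative": bool((gex_block >= 0).all()),
        "atac_binary": bool(np.isin(atac_block, (0.0, 1.0)).all()),
        "atac_fraction_open": float(atac_block.mean()),
    }


toy = np.array([[0.0, 2.0074675, 1.0, 0.0],
                [0.5, 0.0,       0.0, 1.0]], dtype=np.float32)
print(check_modalities(toy, gex_dim=2))
```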
