# Gene imputation

Here, we try to impute missing data using multiETM

In [1]:
import sys
sys.path.append('../src/scETM/')

import os
os.environ[ 'NUMBA_CACHE_DIR' ] = '/scratch/st-jiaruid-1/yinian/tmp/' # https://github.com/scverse/scanpy/issues/2113

In [2]:
import scanpy as sc
import numpy as np
import anndata as ad
import torch
import yaml
from pathlib import Path
import pickle
import matplotlib.pyplot as plt

np.random.seed(0)

Matplotlib created a temporary config/cache directory at /tmp/pbs.4804264.pbsha.ib.sockeye/matplotlib-hrcqb548 because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [3]:
from batch_sampler import CellSampler
from models.scETM import scETM
from trainers.UnsupervisedTrainer import UnsupervisedTrainer
from eval_utils import evaluate

In [4]:
from batch_sampler import CellSamplerCITE
from models.multiETM import MultiETM
from trainers.UnsupervisedTrainerCITE import UnsupervisedTrainerCITE
from eval_utils import evaluate

In [5]:
config = yaml.safe_load(Path('../experiments/covid_healthy.yaml').read_text())
config

{'files': {'rna': ['/arc/project/st-jiaruid-1/yinian/pbmc/CV0902_rna.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0915_rna.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0917_rna.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0929_rna.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0939_rna.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0940_rna.h5ad'],
  'protein': ['/arc/project/st-jiaruid-1/yinian/pbmc/CV0902_protein.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0915_protein.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0917_protein.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0929_protein.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0939_protein.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0940_protein.h5ad'],
  'combined': ['/arc/project/st-jiaruid-1/yinian/pbmc/CV0902_combined.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0915_combined.h5ad',
   '/arc/project/st-jiaruid-1/yinian/pbmc/CV0917_combined.h5ad',
   '/arc/project/st-

## Load the data

In [6]:
files = config['files']
model_params = config['model_params']
if model_params['cell_type_col'] == 'None':
    model_params['cell_type_col'] = None

In [7]:
rna_files = files['rna']
protein_files = files['protein']
rna_adata = ad.concat([ad.read_h5ad(r_file) for r_file in rna_files], label="batch_indices")
protein_adata = ad.concat([ad.read_h5ad(p_file) for p_file in protein_files], label="batch_indices")
rna_adata

AnnData object with n_obs × n_vars = 12292 × 24737
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'batch_indices'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap'
    layers: 'raw'

### Split into training and test splits

The train split will not zero out the genes we are going to test, the test set will.

In [8]:
train_indices = np.random.choice(np.arange(len(rna_adata)), size=int(len(rna_adata) * 0.85), replace=False)
test_indices = np.array(list(set(np.arange(len(rna_adata))).difference(train_indices)))

In [9]:
# with open('/scratch/st-jiaruid-1/yinian/my_jupyter/scETM/scripts/train_indices_ch.pkl', 'wb') as f:
#     pickle.dump(train_indices, f)

In [10]:
train_rna_adata = rna_adata[train_indices]
train_protein_adata = protein_adata[train_indices]

In [11]:
test_rna_adata = rna_adata[test_indices].copy()
orig_test_rna_adata = rna_adata[test_indices]
test_protein_adata = protein_adata[test_indices]

Select 10% of genes to impute and set them to zero

In [12]:
# On initial run, select the genes to zero out and save them in a file
# gene_indices = np.random.choice(np.arange(rna_adata.n_vars), size=int(rna_adata.n_vars * 0.1), replace=False)
# gene_indices

In [13]:
# with open('/scratch/st-jiaruid-1/yinian/my_jupyter/scETM/scripts/gene_indices_covid2.pkl', 'wb') as f:
#     pickle.dump(gene_indices, f)

In [14]:
# On all subsequent runs, load the saved file and zero out genes saved from there. This includes runs for 
# other models but the same dataset.
with open('/scratch/st-jiaruid-1/yinian/my_jupyter/scETM/scripts/gene_indices_covid2.pkl', 'rb') as f:
    gene_indices = pickle.load(f)

In [15]:
test_rna_adata[:, gene_indices] = 0

  self._set_arrayXarray(i, j, x)


## Train the model

In [16]:
model = MultiETM(train_rna_adata.n_vars, train_protein_adata.n_vars, train_rna_adata.obs.batch_indices.nunique())
trainer = UnsupervisedTrainerCITE(model, train_rna_adata, train_protein_adata,
                                  ckpt_dir='/scratch/st-jiaruid-1/yinian/output/')

In [17]:
trainer.train(n_epochs=12000, eval_every=12000, eval_kwargs = dict(cell_type_col = 'full_clustering'), n_samplers=1, save_model_ckpt=False)

loss:      38.72	nll:      38.72	kl_delta:      166.6	max_norm:      131.7	Epoch     0/12000	Next ckpt:       0

scETM.evaluate assumes discrete cell types. Converting cell_type_col to categorical.
2023-03-25 15:07:11.603971: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-25 15:07:11.760830: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-03-25 15:07:14.479525: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.

loss:      12.68	nll:      12.68	kl_delta:        494	max_norm:     0.2857	Epoch   545/12000	Next ckpt:   12000

KeyboardInterrupt: 

## Run the forward model and see what gets imputed

In [37]:
from sklearn.metrics import mean_squared_error
from scipy.sparse import spmatrix

In [38]:
def correlation_score(y_true, y_pred):
    """Scores the predictions according to the competition rules.

    It is assumed that the predictions are not constant.

    Returns the average of each sample's Pearson correlation coefficient

    Source: https://www.kaggle.com/code/xiafire/lb-t15-msci-multiome-catboostregressor#Predicting
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes are different.")
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

In [39]:
# Get the embeddings (emb) which represent the topic mixture proportions
emb, nll = model.get_cell_embeddings_and_nll(test_rna_adata, test_protein_adata, inplace=False)

In [40]:
if rna_adata.obs.batch_indices.nunique() > 1:
    batch = torch.LongTensor(test_rna_adata.obs['batch_indices'].astype('category').cat.codes)
else:
    batch = None

In [41]:
# Use the model decoder to reconstruct the data, the reconstruction will try to impute the zeroed out data.
pred = model.decode(torch.Tensor(emb['theta']).to('cuda'), batch).detach().cpu().numpy()
pred = pred[:, :rna_adata.n_vars]
pred = np.exp(pred)

In [42]:
indexed_pred = pred[:, gene_indices]
indexed_pred.shape

(1794, 1978)

In [43]:
true_data = orig_test_rna_adata[:, :model_params['rna_n_vars']].copy()
if isinstance(true_data.X, spmatrix):
    true_data.X = true_data.X.toarray()
true_data = true_data.X / true_data.X.sum(1, keepdims=True)
indexed_true_data = true_data[:, gene_indices]
indexed_true_data.shape

(1794, 1978)

In [44]:
mean_squared_error(indexed_pred, indexed_true_data)

2.2981219e-07

In [45]:
correlation_score(indexed_true_data, indexed_pred)

  c /= stddev[:, None]
  c /= stddev[None, :]


nan