# CITE-seq prediction with the top model from the OpenProblems 2022 competition

In this repository, we provide a reimplementation of the best CITE-seq model from the OpenProblems modality prediction competition in 2022. The original code can be found on [GitHub](https://github.com/senkin13/kaggle/blob/master/Open-Problems-Multimodal-Single-Cell-Integration-2nd-Place-Solution/senkin13/preprocess_cite.ipynb) and on [Kaggle](https://www.kaggle.com/code/senkin13/2nd-place-gru-cite). Here, we adapt this code to work with the MuData format and modern library versions. We also make it more efficient by relying on the `muon` and `scanpy` libraries.

Let's first import some libraries:

In [None]:
import logging

import mudata as mu
import numpy as np
import scanpy as sc
import pooch  # To download mudata from figshare, feel free to drop if manually downloaded

from senkin_tmp_cite_pred.preprocess import preprocess_data



Create directories for data and models:

In [2]:
%%bash

mkdir -p data/
mkdir -p models/

Let's download the data from the 2021 competition to make sure that the code also runs on another dataset:

In [4]:
pooch.retrieve(
    url="https://figshare.com/ndownloader/files/41452287",  # Feel free to manually download the file to avoid installing pooch
    fname="data/cite_filtered.h5mu",
    path=".",
    known_hash=None,
    progressbar=True,
)

mdata = mu.read_h5mu("data/cite_filtered.h5mu")
mdata

Downloading data from 'https://figshare.com/ndownloader/files/41452287' to file '/Users/vladimir.shitov/Documents/programming/senin_tmp_CITE_pred/senkin-tmp-cite-pred/data/cite_filtered.h5mu'.
100%|█████████████████████████████████████| 1.55G/1.55G [00:00<00:00, 2.30TB/s]
SHA256 hash of downloaded file: 2b11fe2ac8cea96cf83992c2ab231f3474d2173ccec49fc6b7b2499ad213d7f4
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)
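As the pooch message above suggests, the printed SHA256 value can be passed as `known_hash` so that later downloads are verified. If the file was downloaded manually instead, a minimal standard-library sketch like the one below can check it against the same value. The helper names `sha256_of_file` and `file_matches` are hypothetical and not part of `pooch` or this package.

```python
import hashlib
from pathlib import Path

# Hash printed by pooch in the output above.
EXPECTED_SHA256 = "2b11fe2ac8cea96cf83992c2ab231f3474d2173ccec49fc6b7b2499ad213d7f4"


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def file_matches(path, expected=EXPECTED_SHA256):
    """True if the file exists and its digest equals the expected value."""
    path = Path(path)
    return path.is_file() and sha256_of_file(path) == expected
```

Calling `file_matches("data/cite_filtered.h5mu")` should then return `True` for an intact copy of the file.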


If the data was downloaded previously, comment out the cell above and run:

In [5]:
mdata = mu.read_h5mu("data/cite_filtered.h5mu")
mdata



Let's select a random subset of cells to test the code:

In [6]:
random_cells = np.random.choice(mdata.obs_names, size=10000, replace=False)
random_cells

array(['GGGCTACTCGAGATGG-1-1-0-0-0-0-0-0-0-0',
       'TGTTACTGTCAATCTG-1-1-0-0-0', 'TCGGGCACAAATTGGA-1-1-0-0-0-0', ...,
       'TCCGGGAAGTAACAGT-1-1-0-0', 'TCTCTGGCAGGCAATG-1-1-0',
       'CTCAACCAGGGCAGTT-1-1-0-0-0-0-0-0-0-0-0-0'],
      shape=(10000,), dtype=object)

In [7]:
mdata = mdata[random_cells].copy()



In [8]:
adata_rna = mdata.mod["rna"]
adata_prot = mdata.mod["prot"]

In the 2022 competition, the batch effect correction for this approach was done per day. Let's follow that and extract the day into a separate `.obs` column:

In [9]:
adata_rna.obs["day"] = adata_rna.obs["donor"].str[-2:]

In [10]:
adata_rna.obs["day"].value_counts()

day
d1    3547
d7    1200
d6    1182
d5     951
d9     832
d3     658
d4     625
d2     613
d8     392
Name: count, dtype: int64
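The slice `.str[-2:]` keeps the last two characters of each label, which works here because every day number is a single digit. A more defensive sketch, assuming that labels end in `d` followed by the day number, uses a regular expression so that two-digit days would survive as well. The labels below are hypothetical and only illustrate the difference.

```python
import re

# Hypothetical donor labels; the real format in your dataset may differ.
donor_labels = ["s1d1", "s2d7", "s3d10"]

DAY_PATTERN = re.compile(r"d(\d+)$")  # 'd' followed by digits at the end


def extract_day(label):
    """Return the trailing day token such as 'd7', or None if absent."""
    match = DAY_PATTERN.search(label)
    return f"d{match.group(1)}" if match else None


days = [extract_day(label) for label in donor_labels]
sliced = [label[-2:] for label in donor_labels]  # what .str[-2:] does

print(days)    # ['d1', 'd7', 'd10']
print(sliced)  # ['d1', 'd7', '10']
```

With pandas, the same pattern can be applied through `Series.str.extract(r"(d\d+)$", expand=False)`.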

Optionally, we can add a set of known features that are important for predicting the given surface proteins. In the original solution, this was the set of genes that encode the proteins in the CITE-seq data. In this example notebook, we'll simply extract the genes whose names are identical to the protein names. Note that this gives an incomplete list:

In [11]:
known_features = adata_prot.var["gene_ids"][adata_prot.var["gene_ids"].isin(adata_rna.var_names)]
print(f"Number of proteins with identical names in RNA: {len(known_features)}")

Number of proteins with identical names in RNA: 37


Now we can perform all the required preprocessing with just one function. It includes:
- Removal of constant features
- 200-component TSVD of CLR-transformed data
- 100-component PCA of custom-normalized data (see the documentation of `senkin_normalize` for details)
- Selection of raw expression for known features and for genes correlated with target proteins
- DSB normalization of protein data. This requires the `empty_counts_range` setting. Here, we use the same values as in the 2022 competition, but the right range is data dependent, so make sure to select a meaningful one!

Note that the number of cells will become smaller after this step, because some droplets are used as background for DSB normalization. Additionally, the number of genes will decrease because constant genes are removed from the RNA data.
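To make the role of `empty_counts_range` more concrete, here is a minimal sketch of the idea behind DSB's background correction. It assumes that a droplet counts as empty when the log10 of its total RNA counts falls inside the range, which is how the parameter is understood in `muon`'s DSB function; whether `preprocess_data` applies exactly this rule is an assumption. The toy counts are invented, and further steps of the full DSB method are omitted.

```python
import math
import statistics

# Toy droplets: (total RNA counts, raw count of one protein). Invented values.
droplets = [(40, 3), (90, 5), (150, 4), (5000, 60), (8000, 120), (12000, 35)]
empty_counts_range = (1.5, 2.8)  # bounds on log10(total counts + 1)


def is_empty(total_counts, bounds):
    """Assumed rule: a droplet is background if its log10 size is in range."""
    low, high = bounds
    return low <= math.log10(total_counts + 1) < high


# Log-transform protein counts separately for background droplets and cells.
background = [math.log(p + 1) for t, p in droplets if is_empty(t, empty_counts_range)]
cells = [math.log(p + 1) for t, p in droplets if not is_empty(t, empty_counts_range)]

# Center and scale the cell values by the background distribution.
bg_mean = statistics.mean(background)
bg_sd = statistics.stdev(background)
dsb_like = [(value - bg_mean) / bg_sd for value in cells]

print(len(background), len(cells))  # 3 3
```

The droplets classified as background are not kept as cells, which is why the number of observations drops after preprocessing.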

In [12]:
mdata = preprocess_data(
    mdata,
    empty_counts_range=(1.5, 2.8),
    batch_key="day",
    group_key="donor",
    known_features=known_features
)

  warn("adata.X is sparse but not in CSC format. Converting to CSC.")
  return super()._validate_value(val, key)
  X = (X - X.mean(axis=0)) / X.std(axis=0)


Note the new `obsm` fields in the RNA modality and the `dsb` layer in the protein modality:

In [13]:
mdata

Let's save preprocessed files to reuse them later:

In [14]:
mdata.mod["rna"].write_h5ad("data/cite_filtered_preprocessed_rna.h5ad")
mdata.mod["prot"].write_h5ad("data/cite_filtered_preprocessed_prot.h5ad")

# LGBM predictions

Next, we'll use the preprocessed data to make initial predictions of the target proteins with Light Gradient Boosting Machine (LGBM) models. The dimensionality of the predictions will then be reduced with TSVD to 100 components. While these predictions are not used per se, they become features for the neural networks. As we showed in our analysis, this makes the resulting prediction much more accurate.

In [15]:
from senkin_tmp_cite_pred.lgbm_models import train_lightgbm_models
from sklearn.model_selection import KFold

In [16]:
adata_rna = sc.read_h5ad("data/cite_filtered_preprocessed_rna.h5ad")
adata_rna

AnnData object with n_obs × n_vars = 9774 × 25156
    obs: 'donor', 'batch', 'day'
    var: 'gene_ids', 'feature_types'
    uns: 'log1p', 'selected_features'
    obsm: 'X_clr_tsvd', 'X_log_normalized', 'X_pca_sqrt_norm', 'X_raw_selected', 'X_sqrt_norm'

In [17]:
adata_prot = sc.read("data/cite_filtered_preprocessed_prot.h5ad")
adata_prot

AnnData object with n_obs × n_vars = 9774 × 140
    obs: 'donor', 'batch'
    var: 'gene_ids', 'feature_types'
    layers: 'dsb'

In [18]:
adata_rna.obs["day"].value_counts()

day
d1    3443
d7    1187
d6    1168
d5     937
d9     803
d3     647
d2     600
d4     600
d8     389
Name: count, dtype: int64

If your data already contains "train" and "test" subsets, feel free to use them. Otherwise, we need to define a train-test split. Let's use days 6-8 as the test set here. For the test set, the ground truth protein levels are not used for training. Here, we have the ground truth to validate the quality of our results, but in a real-life scenario the test set can be any dataset for which you want to impute the protein modality.

In [19]:
adata_rna.obs["split"] = adata_rna.obs["day"].apply(lambda x: "test" if x in ["d6", "d7", "d8"] else "train")

In [20]:
adata_rna.obs["split"].value_counts()

split
train    7030
test     2744
Name: count, dtype: int64
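As a quick sanity check, the split sizes can be reproduced from the per-day counts printed earlier. The snippet below only restates numbers from this notebook's outputs.

```python
# Per-day cell counts after preprocessing, copied from the output above.
day_counts = {
    "d1": 3443, "d7": 1187, "d6": 1168, "d5": 937, "d9": 803,
    "d3": 647, "d2": 600, "d4": 600, "d8": 389,
}
test_days = {"d6", "d7", "d8"}

n_test = sum(n for day, n in day_counts.items() if day in test_days)
n_train = sum(n for day, n in day_counts.items() if day not in test_days)

print(n_train, n_test)  # 7030 2744
```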

Additionally, we'll define cross-validation splits for the training dataset. This helps us see how well our models predict data from an unseen fold. Similarly to the original approach, we'll use a random split into 5 folds.

In [21]:
folds = KFold(n_splits=5, shuffle=True, random_state=666)

We'll save which cells belong to each fold so that the results stay reproducible after the `folds` variable is lost:

In [22]:
for i, (train_idxs, test_idxs) in enumerate(folds.split(adata_rna)):
    adata_rna.obs[f"fold_{i}"] = "test"
    adata_rna.obs.loc[adata_rna.obs_names[train_idxs], f"fold_{i}"] = "train"

adata_rna.obs["fold_0"].value_counts()

fold_0
train    7819
test     1955
Name: count, dtype: int64
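The fold sizes follow from how scikit-learn's `KFold` distributes samples: the first `n_samples % n_splits` folds receive one extra sample. A small sketch of that arithmetic for the 9774 cells in this notebook (the helper `kfold_sizes` is hypothetical):

```python
def kfold_sizes(n_samples, n_splits):
    """Fold sizes as documented for sklearn.model_selection.KFold."""
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]


sizes = kfold_sizes(9774, 5)
print(sizes)            # [1955, 1955, 1955, 1955, 1954]
print(9774 - sizes[0])  # 7819
```

Note that the loop above splits all cells in `adata_rna`, so each `fold_i` column covers cells from both the train and the test subsets.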

In [23]:
train_cell_ids = adata_rna.obs_names[adata_rna.obs["split"] == "train"]
test_cell_ids = adata_rna.obs_names[adata_rna.obs["split"] == "test"]

If you wish to see fewer logs, feel free to remove the cell below:

In [24]:
logging.basicConfig(level=logging.DEBUG)

Train the LGBM models. Four models will be trained, using different subsets of the training data and different preprocessing of the target proteins. In this example, we'll only build models for 3 proteins and train them for only 3 boosting rounds. Remove `[:, :3]` after `adata_prot` to use all the proteins, and remove `num_boost_round`, `early_stopping_rounds`, and `n_tsvd_components` to use the default parameters.

In [25]:
adata_rna = train_lightgbm_models(
    adata_rna,
    adata_prot[:, :3],
    train_cell_ids=train_cell_ids,
    test_cell_ids=test_cell_ids,
    folds=folds,
    num_boost_round=3,
    early_stopping_rounds=2,
    n_tsvd_components=10
)

INFO:senkin_tmp_cite_pred.lgbm_models:Initializing arrays in obsm with zeros
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 1 for predicting DSB-normalized protein expression from log-normalized RNA expression
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.39419


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.60932


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.80761


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.35578


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.20692


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.412599
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.397814


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.380666
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.402292


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.445119


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.146048


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.15668


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.14213
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.152791


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.141665


INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.21122944745599762
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components
INFO:senkin_tmp_cite_pred.lgbm_models:Preparing datasets for LightGBM model 2
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 2 for predicting DSB-normalized protein expression from customly normalized RNA expression data and selected features
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.37773


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.59449


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.78432


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.3403


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.18969


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.406571


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.391986


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.374271


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.395232


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.437211


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.145371


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.156014


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.141127


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.151962


INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.21674261963878433
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 3 for predicting DSB-normalized protein expression from raw RNA expression


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.14075


INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.38625


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.60318


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.79507


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.3513
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2.1967


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.411495


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.396523


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.379092


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.40127
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.443384


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.14576


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.156416


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.141652


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.152378
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.141319


INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.21270658937614365
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 4 for predicting raw protein expression from raw RNA expression
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 177.104


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 382.599


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 507.458


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 169.709


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 180.154


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 424.725


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 6830.37


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 12278.5


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 686.791


DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2615.92


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 555.093


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 2


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 7289.47


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 3


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 7072.69


DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 4


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2442.77
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 2707.63


  c /= stddev[:, None]
  c /= stddev[None, :]
INFO:senkin_tmp_cite_pred.lgbm_models:CV score: nan
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components
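The repeated message `Did not meet early stopping. Best iteration is: [3]` follows from the settings used here: with only 3 boosting rounds and a patience of 2, a validation loss that keeps improving never triggers the stop. The sketch below mimics that logic in plain Python. It is a simplified illustration rather than LightGBM's implementation, and the loss values are invented.

```python
def run_with_early_stopping(valid_losses, patience):
    """Return (best_round, stopped_early) using 1-based round numbers."""
    best_round, best_loss = 0, float("inf")
    for round_number, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_round, best_loss = round_number, loss
        elif round_number - best_round >= patience:
            return best_round, True  # no improvement for `patience` rounds
    return best_round, False


# Loss improves in every round: training runs to the end.
print(run_with_early_stopping([3.1, 2.7, 2.4], patience=2))  # (3, False)

# Loss stalls after round 2: training stops at round 4.
print(run_with_early_stopping([3.1, 2.7, 2.8, 2.9, 2.5], patience=2))  # (2, True)
```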


Note that we now have 4 additional `X_lgbm` arrays in `.obsm`:

In [26]:
adata_rna

AnnData object with n_obs × n_vars = 9774 × 25156
    obs: 'donor', 'batch', 'day', 'split', 'fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_4'
    var: 'gene_ids', 'feature_types'
    uns: 'log1p', 'selected_features'
    obsm: 'X_clr_tsvd', 'X_log_normalized', 'X_pca_sqrt_norm', 'X_raw_selected', 'X_sqrt_norm', 'X_lgbm_1', 'X_lgbm_2', 'X_lgbm_3', 'X_lgbm_4'
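The `X_lgbm_*` arrays hold values for every cell, including the training cells. A common way to build such stacked features without leaking targets is out-of-fold prediction: each training cell is predicted by a model that never saw it. Whether `train_lightgbm_models` does exactly this is an assumption. The sketch below illustrates the pattern in plain Python, with the mean of the training targets standing in for a real learner such as LightGBM.

```python
import random
import statistics


def out_of_fold_predictions(targets, n_splits=5, seed=0):
    """Predict each sample with a model that never saw that sample."""
    indices = list(range(len(targets)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::n_splits] for i in range(n_splits)]

    oof = [None] * len(targets)
    for held_out in folds:
        held_out_set = set(held_out)
        train_targets = [t for i, t in enumerate(targets) if i not in held_out_set]
        prediction = statistics.mean(train_targets)  # "fit" on the other folds
        for i in held_out:
            oof[i] = prediction  # "predict" only the held-out fold
    return oof


# Toy target values, invented for illustration.
targets = [0.2, 1.5, 0.7, 2.1, 0.9, 1.1, 0.4, 1.8, 1.3, 0.6]
oof = out_of_fold_predictions(targets)
```

Every entry of `oof` is filled exactly once, by the model trained without that sample's fold.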

# Train neural networks

In [27]:
from senkin_tmp_cite_pred.nn_models import train_nn_models

Finally, we'll train two neural networks with originally proposed architectures to predict target proteins. First, let's prepare datasets:

In [28]:
train_cite_X = np.concatenate([
    adata_rna[train_cell_ids].obsm["X_clr_tsvd"],
    adata_rna[train_cell_ids].obsm["X_pca_sqrt_norm"],
    adata_rna[train_cell_ids].obsm["X_raw_selected"].toarray(),
    adata_rna[train_cell_ids].obsm["X_lgbm_1"],
    adata_rna[train_cell_ids].obsm["X_lgbm_2"],
    adata_rna[train_cell_ids].obsm["X_lgbm_3"],
    adata_rna[train_cell_ids].obsm["X_lgbm_4"],
], axis=1)

test_cite_X = np.concatenate([
    adata_rna[test_cell_ids].obsm["X_clr_tsvd"],
    adata_rna[test_cell_ids].obsm["X_pca_sqrt_norm"],
    adata_rna[test_cell_ids].obsm["X_raw_selected"].toarray(),
    adata_rna[test_cell_ids].obsm["X_lgbm_1"],
    adata_rna[test_cell_ids].obsm["X_lgbm_2"],
    adata_rna[test_cell_ids].obsm["X_lgbm_3"],
    adata_rna[test_cell_ids].obsm["X_lgbm_4"],
], axis=1)

train_cite_y = adata_prot[train_cell_ids].layers["dsb"]

In [29]:
print("Total number of features:", train_cite_X.shape[1])

Total number of features: 1039


This function performs cross-validated training and prediction and returns the aggregated predictions of the two models. For the purposes of this example, we'll only use 3 target proteins and 3 epochs to train the models. Remove `[:, :3]` from `train_cite_y` and set `EPOCHS` to 100 (or just leave the default) to predict all the proteins with the original parameters:

In [30]:
train_preds, test_preds = train_nn_models(
    train_cell_ids,
    train_cite_X,
    train_cite_y[:, :3],
    test_cell_ids,
    test_cite_X,
    folds,
    EPOCHS=3
)

0


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 577ms/step - loss: -0.2607 - val_loss: -0.2738 - learning_rate: 0.0010
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 5s 517ms/step - loss: -0.3584 - val_loss: -0.3387 - learning_rate: 0.0010
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 476ms/step - loss: -0.4217 - val_loss: -0.3827 - learning_rate: 0.0010
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 292ms/step
[np.float64(0.38866246937395404)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 111ms/step
1


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 8s 557ms/step - loss: -0.2575 - val_loss: -0.3389 - learning_rate: 0.0010
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 3s 377ms/step - loss: -0.3594 - val_loss: -0.3487 - learning_rate: 0.0010
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 398ms/step - loss: -0.3902 - val_loss: -0.4152 - learning_rate: 0.0010
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 190ms/step
[np.float64(0.38866246937395404), np.float64(0.4165862001734295)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 156ms/step
2


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 7s 510ms/step - loss: -0.2684 - val_loss: -0.3761 - learning_rate: 0.0010
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 404ms/step - loss: -0.3751 - val_loss: -0.4531 - learning_rate: 0.0010
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 421ms/step - loss: -0.4377 - val_loss: -0.4867 - learning_rate: 0.0010
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 208ms/step
[np.float64(0.38866246937395404), np.float64(0.4165862001734295), np.float64(0.47648880260325616)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 100ms/step
3


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 8s 585ms/step - loss: -0.2415 - val_loss: -0.3314 - learning_rate: 0.0010
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 412ms/step - loss: -0.3433 - val_loss: -0.3680 - learning_rate: 0.0010
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 424ms/step - loss: -0.3889 - val_loss: -0.4301 - learning_rate: 0.0010
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 204ms/step
[np.float64(0.38866246937395404), np.float64(0.4165862001734295), np.float64(0.47648880260325616), np.float64(0.4320480909972952)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 94ms/step
4


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 8s 595ms/step - loss: -0.2406 - val_loss: -0.3006 - learning_rate: 0.0010
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 461ms/step - loss: -0.3509 - val_loss: -0.3602 - learning_rate: 0.0010
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 4s 478ms/step - loss: -0.4081 - val_loss: -0.4333 - learning_rate: 0.0010
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 222ms/step
[np.float64(0.38866246937395404), np.float64(0.4165862001734295), np.float64(0.47648880260325616), np.float64(0.4320480909972952), np.float64(0.438133903374123)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 103ms/step
Overall: 0.4303838933044107
0


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 7s 308ms/step - loss: 9682.9248 - val_loss: 376.7199 - learning_rate: 5.0000e-04
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 203ms/step - loss: 315.0083 - val_loss: 157.8277 - learning_rate: 5.0000e-04
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 186ms/step - loss: 157.8118 - val_loss: 28.8294 - learning_rate: 5.0000e-04
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 178ms/step
[np.float64(0.08517633307644036)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step
1


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 6s 291ms/step - loss: 4966.2393 - val_loss: 1487.5784 - learning_rate: 5.0000e-04
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 205ms/step - loss: 570.1090 - val_loss: 354.9667 - learning_rate: 5.0000e-04
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 186ms/step - loss: 266.3580 - val_loss: 146.3739 - learning_rate: 5.0000e-04
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 177ms/step
[np.float64(0.08517633307644036), np.float64(0.15231072579708155)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step
2


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 6s 278ms/step - loss: 3252.1018 - val_loss: 711.7197 - learning_rate: 5.0000e-04
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 211ms/step - loss: 544.0789 - val_loss: 220.7589 - learning_rate: 5.0000e-04
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 240ms/step - loss: 371.8804 - val_loss: 173.7226 - learning_rate: 5.0000e-04
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 193ms/step
[np.float64(0.08517633307644036), np.float64(0.15231072579708155), np.float64(0.2247124700300189)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step
3


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 7s 310ms/step - loss: 5247.8599 - val_loss: 185.6742 - learning_rate: 5.0000e-04
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 191ms/step - loss: 357.7753 - val_loss: 141.1672 - learning_rate: 5.0000e-04
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 2s 193ms/step - loss: 184.2556 - val_loss: 75.8332 - learning_rate: 5.0000e-04
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 172ms/step
[np.float64(0.08517633307644036), np.float64(0.15231072579708155), np.float64(0.2247124700300189), np.float64(0.23951225806519605)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 39ms/step
4


Epoch 1/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 6s 257ms/step - loss: 9553.7773 - val_loss: 29.2939 - learning_rate: 5.0000e-04
Epoch 2/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 156ms/step - loss: 370.8620 - val_loss: 105.7917 - learning_rate: 5.0000e-04
Epoch 3/3
9/9 ━━━━━━━━━━━━━━━━━━━━ 1s 158ms/step - loss: 237.3216 - val_loss: 69.3325 - learning_rate: 5.0000e-04
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 169ms/step
[np.float64(0.08517633307644036), np.float64(0.15231072579708155), np.float64(0.2247124700300189), np.float64(0.23951225806519605), np.float64(0.1889106973678733)]
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
Overall: 0.17812449686732168
Blend: 0.38754734322874024


Because we know the ground truth for this dataset, we can compute the correlation score on the test data:

In [31]:
from senkin_tmp_cite_pred.metrics import correlation_score

In [32]:
correlation_score(adata_prot[test_cell_ids].layers["dsb"][:, :3], test_preds)

np.float64(0.717635762821232)
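For reference, the metric of the 2022 competition is the Pearson correlation between predicted and true protein levels, computed per cell and averaged over cells. The sketch below is an assumption about what `correlation_score` computes, written in plain Python with hypothetical helper names and toy values.

```python
import math


def pearson(x, y):
    """Pearson correlation of two equal-length sequences (NaN if constant)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return math.nan  # correlation is undefined for a constant vector
    return cov / math.sqrt(var_x * var_y)


def mean_rowwise_correlation(y_true, y_pred):
    """Average the per-cell (row-wise) Pearson correlations."""
    scores = [pearson(t, p) for t, p in zip(y_true, y_pred)]
    return sum(scores) / len(scores)


# Two toy cells with three proteins each; values are invented.
y_true = [[1.0, 2.0, 3.0], [2.0, 0.5, 1.0]]
y_pred = [[1.1, 1.9, 3.2], [0.4, 1.0, 2.0]]
score = mean_rowwise_correlation(y_true, y_pred)
```

Because the score is averaged over cells, a single cell whose true or predicted values are constant yields NaN for the whole average in this sketch.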