# CITE-seq prediction for OpenProblems data format

In this example, we'll apply the best CITE-seq prediction model to the data format of the [modality prediction task](https://openproblems.bio/benchmarks/predict_modality?version=v1.0.0) in [OpenProblems](https://openproblems.bio/). For a general and a more detailed example, take a look at `example.ipynb`

In [1]:
import logging

import mudata as mu
import numpy as np
import scanpy as sc

from senkin_tmp_cite_pred.preprocess import preprocess_data

  from .autonotebook import tqdm as notebook_tqdm


Create directories for data and models:

In [2]:
logging.basicConfig(level=logging.DEBUG)

In [3]:
%%bash

mkdir -p data/
mkdir -p models/

DEBUG:asyncio:Using selector: KqueueSelector


In [4]:
adata_rna_train = sc.read_h5ad("data/train_mod1.h5ad")
adata_rna_test = sc.read_h5ad("data/test_mod1.h5ad")
adata_prot_train = sc.read_h5ad("data/train_mod2.h5ad")
adata_prot_test = sc.read_h5ad("data/test_mod2.h5ad")

DEBUG:h5py._conv:Creating converter from 3 to 5


In [5]:
adata_rna_train.shape, adata_rna_test.shape, adata_prot_train.shape, adata_prot_test.shape

((431, 134), (163, 134), (431, 134), (163, 134))

In the 2022 competition, the batch effect correction for this approach was done per day. Let's follow that, and extract day in a separate `.obs` column:

In [6]:
adata_rna_test.obs

Unnamed: 0,size_factors,batch
TCCGAAAAGAGGACTC-1-s4d8,22.0,s4d8
ATTCATCGTTAGTCGT-1-s4d1,15.0,s4d1
CATCGTCTCTGAGAGG-1-s4d8,23.0,s4d8
CATGGATGTACGCTAT-1-s4d1,7.0,s4d1
TCCTTCTTCCACCCTA-1-s2d1,15.0,s2d1
...,...,...
CATTCCGAGTGCTACT-1-s4d8,30.0,s4d8
AATGGAATCCATTCAT-1-s4d8,39.0,s4d8
GGCTTTCGTGGACCTC-1-s4d8,51.0,s4d8
GTGGTTAAGCATAGGC-1-s3d7,12.0,s3d7


In [7]:
def extract_day_and_donor_from_batch(batch: str) -> tuple[str, str]:
    "Extract day and donor numbers N and M from IDs of format sNdM"
    d_index = batch.find("d")
    day = batch[1: d_index]
    donor = batch[d_index + 1 :]
    return day, donor

adata_rna_train.obs["day"], adata_rna_train.obs["donor"] = zip(*adata_rna_train.obs["batch"].astype(str).map(extract_day_and_donor_from_batch))
adata_rna_test.obs["day"], adata_rna_test.obs["donor"] = zip(*adata_rna_test.obs["batch"].astype(str).map(extract_day_and_donor_from_batch))
adata_prot_train.obs["day"], adata_prot_train.obs["donor"] = zip(*adata_prot_train.obs["batch"].astype(str).map(extract_day_and_donor_from_batch))
adata_prot_test.obs["day"], adata_prot_test.obs["donor"] = zip(*adata_prot_test.obs["batch"].astype(str).map(extract_day_and_donor_from_batch))


In [8]:
adata_rna_train.obs["split"] = "train"
adata_rna_test.obs["split"] = "test"
adata_prot_train.obs["split"] = "train"
adata_prot_test.obs["split"] = "test"

Put all the modalities in one object. This will make sure that the indexes are aligned

In [9]:
adata_rna_train.var

Unnamed: 0,hvg,hvg_score
ENSG00000104320,True,-0.674490
ENSG00000187713,True,-0.251418
ENSG00000260349,True,0.391433
ENSG00000150756,True,-0.720957
ENSG00000105676,True,-0.674490
...,...,...
ENSG00000241343,True,-4.003093
ENSG00000196267,True,0.000000
ENSG00000112078,True,0.686500
ENSG00000261136,True,-0.698529


In [10]:
mdata = mu.MuData({"rna": sc.concat([adata_rna_train, adata_rna_test], axis=0), "prot": sc.concat([adata_prot_train, adata_prot_test], axis=0)})
mdata

  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


We can add a set of known features, which are important for prediction of the given surface proteins. In the original solution, it was a set of genes that encode proteins in the CITE-seq data. Here, we'll use genes that encoded proteins used **in the 2022 competition**. Make sure to select these features appropriately, this is very important! Additionally, double check that IDs have an appropriate format: same type of encoding, no suffixes, etc.

In [11]:
competition_2022_genes = [
    "ENSG00000134256",
    "ENSG00000083457",
    "ENSG00000185896",
    "ENSG00000130202",
    "ENSG00000182578",
    "ENSG00000027697",
    "ENSG00000005844",
    "ENSG00000169896",
    "ENSG00000140678",
    "ENSG00000100385",
    "ENSG00000185291",
    "ENSG00000292332",
    "ENSG00000077238",
    "ENSG00000168685",
    "ENSG00000166825",
    "ENSG00000186827",
    "ENSG00000049249",
    "ENSG00000131495",
    "ENSG00000170458",
    "ENSG00000178726",
    "ENSG00000117525",
    "ENSG00000076706",
    "ENSG00000163599",
    "ENSG00000102245",
    "ENSG00000073008",
    "ENSG00000274108",
    "ENSG00000283708",
    "ENSG00000284241",
    "ENSG00000275623",
    "ENSG00000277317",
    "ENSG00000284236",
    "ENSG00000274410",
    "ENSG00000275008",
    "ENSG00000283702",
    "ENSG00000277484",
    "ENSG00000274830",
    "ENSG00000284044",
    "ENSG00000276590",
    "ENSG00000284510",
    "ENSG00000243772",
    "ENSG00000274952",
    "ENSG00000277924",
    "ENSG00000284333",
    "ENSG00000284132",
    "ENSG00000283996",
    "ENSG00000276218",
    "ENSG00000277554",
    "ENSG00000278327",
    "ENSG00000284504",
    "ENSG00000283790",
    "ENSG00000274402",
    "ENSG00000273947",
    "ENSG00000273887",
    "ENSG00000276459",
    "ENSG00000278369",
    "ENSG00000275658",
    "ENSG00000276011",
    "ENSG00000275914",
    "ENSG00000278731",
    "ENSG00000278692",
    "ENSG00000275546",
    "ENSG00000274412",
    "ENSG00000277725",
    "ENSG00000277251",
    "ENSG00000276731",
    "ENSG00000273661",
    "ENSG00000273578",
    "ENSG00000275960",
    "ENSG00000274518",
    "ENSG00000275452",
    "ENSG00000275735",
    "ENSG00000278300",
    "ENSG00000274438",
    "ENSG00000275253",
    "ENSG00000276425",
    "ENSG00000276139",
    "ENSG00000276258",
    "ENSG00000277885",
    "ENSG00000278152",
    "ENSG00000275737",
    "ENSG00000274948",
    "ENSG00000273775",
    "ENSG00000283731",
    "ENSG00000284093",
    "ENSG00000274036",
    "ENSG00000274920",
    "ENSG00000284426",
    "ENSG00000276329",
    "ENSG00000275717",
    "ENSG00000276423",
    "ENSG00000284177",
    "ENSG00000284589",
    "ENSG00000284342",
    "ENSG00000275659",
    "ENSG00000276379",
    "ENSG00000275786",
    "ENSG00000275486",
    "ENSG00000283827",
    "ENSG00000278079",
    "ENSG00000277175",
    "ENSG00000278856",
    "ENSG00000276501",
    "ENSG00000273518",
    "ENSG00000275545",
    "ENSG00000278368",
    "ENSG00000167633",
    "ENSG00000277272",
    "ENSG00000274146",
    "ENSG00000278427",
    "ENSG00000283954",
    "ENSG00000283729",
    "ENSG00000275288",
    "ENSG00000203747",
    "ENSG00000162747",
    "ENSG00000111796",
    "ENSG00000110876",
    "ENSG00000177575",
    "ENSG00000088827",
    "ENSG00000198053",
    "ENSG00000160255",
    "ENSG00000160683",
    "ENSG00000177455",
    "ENSG00000121807",
    "ENSG00000183813",
    "ENSG00000160791",
    "ENSG00000112486",
    "ENSG00000158481",
    "ENSG00000158473",
    "ENSG00000116824",
    "ENSG00000171431",
    "ENSG00000263057",
    "ENSG00000156738",
    "ENSG00000117322",
    "ENSG00000012124",
    "ENSG00000089692",
    "ENSG00000100031",
    "ENSG00000150637",
    "ENSG00000104921",
    "ENSG00000272398",
    "ENSG00000122223",
    "ENSG00000172183",
    "ENSG00000134460",
    "ENSG00000197635",
    "ENSG00000159958",
    "ENSG00000139193",
    "ENSG00000157873",
    "ENSG00000273936",
    "ENSG00000186265",
    "ENSG00000120217",
    "ENSG00000163600",
    "ENSG00000276977",
    "ENSG00000188389",
    "ENSG00000178562",
    "ENSG00000150093",
    "ENSG00000198178",
    "ENSG00000099250",
    "ENSG00000261371",
    "ENSG00000213809",
    "ENSG00000026751",
    "ENSG00000072694",
    "ENSG00000143226",
    "ENSG00000168995",
    "ENSG00000105383",
    "ENSG00000275521",
    "ENSG00000273916",
    "ENSG00000277442",
    "ENSG00000273506",
    "ENSG00000276450",
    "ENSG00000277824",
    "ENSG00000278362",
    "ENSG00000278025",
    "ENSG00000275637",
    "ENSG00000275822",
    "ENSG00000274053",
    "ENSG00000277629",
    "ENSG00000189430",
    "ENSG00000275156",
    "ENSG00000273535",
    "ENSG00000277334",
    "ENSG00000203710",
    "ENSG00000162739",
    "ENSG00000135218",
    "ENSG00000004468",
    "ENSG00000138185",
    "ENSG00000010610",
    "ENSG00000101017",
    "ENSG00000005961",
    "ENSG00000185245",
    "ENSG00000026508",
    "ENSG00000262418",
    "ENSG00000081237",
    "ENSG00000196776",
    "ENSG00000117091",
    "ENSG00000213949",
    "ENSG00000164171",
    "ENSG00000115232",
    "ENSG00000091409",
    "ENSG00000110448",
    "ENSG00000169442",
    "ENSG00000090339",
    "ENSG00000149294",
    "ENSG00000109956",
    "ENSG00000116815",
    "ENSG00000188404",
    "ENSG00000174175",
    "ENSG00000135404",
    "ENSG00000150337",
    "ENSG00000110848",
    "ENSG00000173762",
    "ENSG00000072274",
    "ENSG00000137101",
    "ENSG00000135318",
    "ENSG00000007312",
    "ENSG00000153563",
    "ENSG00000110651",
    "ENSG00000085117",
    "ENSG00000112149",
    "ENSG00000276452",
    "ENSG00000274669",
    "ENSG00000277134",
    "ENSG00000104972",
    "ENSG00000277807",
    "ENSG00000114013",
    "ENSG00000197405",
    "ENSG00000010278",
    "ENSG00000125810",
    "ENSG00000134539",
    "ENSG00000026103",
    "ENSG00000168329",
    "ENSG00000179639",
    "ENSG00000291905",
    "ENSG00000206493",
    "ENSG00000225201",
    "ENSG00000233904",
    "ENSG00000230254",
    "ENSG00000236632",
    "ENSG00000229252",
    "ENSG00000204592",
    "ENSG00000102245",
    "ENSG00000139187",
    "ENSG00000205274",
    "ENSG00000181847",
]


In [12]:
known_features = np.intersect1d(competition_2022_genes, mdata.mod["rna"].var_names)
print(f"Number of proteins with identical names in RNA: {len(known_features)}")

Number of proteins with identical names in RNA: 1


Now we can perform required preprocessing with just one function. It includes:
- Removal of constant features
- 200 components TSVD of CLR-transformed data
- 100 components PCA of customly normalized data (see documentation of `senkin_normalize` for details)
- Selection of raw expression for known features and for genes correlated with target proteins
- DSB normalization of protein data. For this `empty_count_range` setting is needed. Here, we used the same values as in 2022 competition. But this is data dependent, so make sure to select a meaningful range!

Note that the nomber of cells will become smaller after this step, because some droplets will be used as a background for DSB normalization. Additionally, number of genes will decrease as constant genes will be removed from RNA data.

In [13]:
import logging
logging.basicConfig(level=logging.DEBUG)

Put raw counts in .X for both modalities:

In [14]:
mdata.mod["rna"].X = mdata.mod["rna"].layers["counts"]
mdata.mod["prot"].X = mdata.mod["prot"].layers["counts"]

In [15]:
mdata = preprocess_data(
    mdata,
    empty_counts_range=(0., 1.),  # Note that this is data-dependent! Data must contain empty droplets.
    batch_key="day",
    group_key="donor",
    known_features=known_features
)

INFO:senkin_tmp_cite_pred.preprocess:DSB-normalizing protein data. The number of cells will be reduced.
DEBUG:senkin_tmp_cite_pred.preprocess:Number of cells before DSB: 594
  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)
DEBUG:senkin_tmp_cite_pred.preprocess:Number of cells after DSB: 507
INFO:senkin_tmp_cite_pred.preprocess:Starting RNA preprocessing
INFO:senkin_tmp_cite_pred.preprocess:Normalization and log1p-transformation
INFO:senkin_tmp_cite_pred.preprocess:Removing constant variables
INFO:senkin_tmp_cite_pred.preprocess:Computing CLR-TSVD transformation with 200 components
  warn("adata.X is sparse but not in CSC format. Converting to CSC.")
INFO:senkin_tmp_cite_pred.preprocess:Applying Senkin normalization
  return super()._validate_value(val, key)
INFO:senkin_tmp_cite_pred.preprocess:Computing PCA with 100 components
INFO:senkin_tmp_cite_pred.preprocess:Verifying observation names match between RNA and pro

Note new `obsm` fields in RNA modality, and a `dsb` layer in the protein modality:

In [16]:
mdata

# LGBM predictions

Next, we'll use preprocessed data to make initial prediction of target proteints with Light Gradient Bossting (LGBM) models. The dimensionality of the predictions will be further reduced with TSVD to 100 components. While these predictions are not used per se, they will become features for neural networks. As we showed in our analysis, this makes the resulting prediction much more accurate.

In [17]:
from senkin_tmp_cite_pred.lgbm_models import train_lightgbm_models
from sklearn.model_selection import KFold

In [18]:
adata_rna = mdata.mod["rna"]
adata_rna

AnnData object with n_obs × n_vars = 507 × 134
    obs: 'size_factors', 'batch', 'day', 'donor', 'split'
    uns: 'log1p', 'selected_features'
    obsm: 'X_log_normalized', 'X_clr_tsvd', 'X_sqrt_norm', 'X_pca_sqrt_norm', 'X_raw_selected'
    layers: 'counts', 'normalized'

In [19]:
adata_prot = mdata.mod["prot"]
adata_prot

AnnData object with n_obs × n_vars = 507 × 134
    obs: 'size_factors', 'batch', 'day', 'donor', 'split'
    layers: 'counts', 'normalized', 'dsb'

In [20]:
adata_rna.obs["day"].value_counts()

day
2    147
3    145
4    112
1    103
Name: count, dtype: int64

In [21]:
adata_rna.obs["split"].value_counts()

split
train    353
test     154
Name: count, dtype: int64

Additionally, we'll define cross-validation splits for the training dataset. It'll help us to see how well our models predict the data on an unseen fold. In this example we'll use 2 folds because we have a very small dataset for testing. In the original approach, 5 folds CV was used. **Set `n_splits` to 5 to match the original approach**

In [22]:
folds = KFold(n_splits=2, shuffle=True, random_state=666)

We'll save which cell comes from each fold to make sure results are reproducible after the `folds` variable is lost:

In [23]:
for i, (train_idxs, test_idxs) in enumerate(folds.split(adata_rna)):
    adata_rna.obs[f"fold_{i}"] = "test"
    adata_rna.obs.loc[adata_rna.obs_names[train_idxs], f"fold_{i}"] = "train"

adata_rna.obs["fold_0"].value_counts()

fold_0
test     254
train    253
Name: count, dtype: int64

In [24]:
train_cell_ids = adata_rna.obs_names[adata_rna.obs["split"] == "train"]
test_cell_ids = adata_rna.obs_names[adata_rna.obs["split"] == "test"]

Train LGBM models. 4 models will be trained, using different subsets of training data, and different preprocessing of the target proteins. In this example, we'll only build models for 3 proteins here, and train them for only 3 epochs. Remove `[:, 3]` after `adata_prot` to use all the proteins, and remove `num_boost_round` `early_stopping_rounds`, and `n_tsvd_components` to use the default parameters

In [None]:
adata_rna = train_lightgbm_models(
    adata_rna,
    adata_prot,
    train_cell_ids=train_cell_ids,
    test_cell_ids=test_cell_ids,
    folds=folds,
    num_boost_round=3,
    early_stopping_rounds=2,
    n_tsvd_components=10
)

INFO:senkin_tmp_cite_pred.lgbm_models:Initializing arrays in obsm with zeros
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 1 for predicting DSB-normalized protein expression from log-normalized RNA expression
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.20403000504862948
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predi

Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.247725
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.294214
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 1.62377
Training until validation scores don't improve for 2 rounds
Early stopping, best iteration is:
[1]	valid_0's l2: 0.645025
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.220594
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.331571
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.248225
Training until validation scores don't improve for 2 rounds
Did not m

DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.20412019992339497
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model 3 for predicting DSB-normalized protein expression from raw RNA expression
INFO:senkin_tmp_cite_pred.lgbm_models:Training LightGBM models for 3 targets
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
DEBUG:senkin_tmp_cite_pred.lgbm_models:Training LightGBM model for target 2
DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 0
DEBUG

Training until validation scores don't improve for 2 rounds
Early stopping, best iteration is:
[1]	valid_0's l2: 0.644973
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.218873
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 0.331109
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[2]	valid_0's l2: 0.247778
Training until validation scores don't improve for 2 rounds
Early stopping, best iteration is:
[1]	valid_0's l2: 0.293921
Training until validation scores don't improve for 2 rounds
Early stopping, best iteration is:
[1]	valid_0's l2: 1.62402
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[2]	valid_0's l2: 0.644732
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best i

DEBUG:senkin_tmp_cite_pred.lgbm_models:Fold: 1
INFO:senkin_tmp_cite_pred.lgbm_models:CV score: 0.8122238423558394
INFO:senkin_tmp_cite_pred.lgbm_models:TSVD-reducing predictions to 2 components


Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 439.644
Training until validation scores don't improve for 2 rounds
Did not meet early stopping. Best iteration is:
[3]	valid_0's l2: 804.945


Note that we know have 4 additional `X_lgbm` arrays in the `.obsm`:

In [26]:
adata_rna

AnnData object with n_obs × n_vars = 507 × 134
    obs: 'size_factors', 'batch', 'day', 'donor', 'split', 'fold_0', 'fold_1'
    uns: 'log1p', 'selected_features'
    obsm: 'X_log_normalized', 'X_clr_tsvd', 'X_sqrt_norm', 'X_pca_sqrt_norm', 'X_raw_selected', 'X_lgbm_1', 'X_lgbm_2', 'X_lgbm_3', 'X_lgbm_4'
    layers: 'counts', 'normalized'

# Train neural networks

In [27]:
from senkin_tmp_cite_pred.nn_models import train_nn_models

Finally, we'll train two neural networks with originally proposed architectures to predict target proteins. First, let's prepare datasets:

In [28]:
train_cite_X = np.concatenate([
    adata_rna[train_cell_ids].obsm["X_clr_tsvd"],
    adata_rna[train_cell_ids].obsm["X_pca_sqrt_norm"],
    adata_rna[train_cell_ids].obsm["X_raw_selected"].toarray(),
    adata_rna[train_cell_ids].obsm["X_lgbm_1"],
    adata_rna[train_cell_ids].obsm["X_lgbm_2"],
    adata_rna[train_cell_ids].obsm["X_lgbm_3"],
    adata_rna[train_cell_ids].obsm["X_lgbm_4"],
], axis=1)

test_cite_X = np.concatenate([
    adata_rna[test_cell_ids].obsm["X_clr_tsvd"],
    adata_rna[test_cell_ids].obsm["X_pca_sqrt_norm"],
    adata_rna[test_cell_ids].obsm["X_raw_selected"].toarray(),
    adata_rna[test_cell_ids].obsm["X_lgbm_1"],
    adata_rna[test_cell_ids].obsm["X_lgbm_2"],
    adata_rna[test_cell_ids].obsm["X_lgbm_3"],
    adata_rna[test_cell_ids].obsm["X_lgbm_4"],
], axis=1)

train_cite_y = adata_prot[train_cell_ids].layers["dsb"]

In [29]:
print("Total number of features:", train_cite_X.shape[1])

Total number of features: 360


This function will perform cross-validated training and prediction, and return us aggregated prediction of two models:

In [30]:
train_preds, test_preds = train_nn_models(
    train_cell_ids,
    train_cite_X,
    train_cite_y,
    test_cell_ids,
    test_cite_X,
    folds, 
    EPOCHS=3
)

0


Epoch 1/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4s/step - loss: 0.0000e+00

DEBUG:h5py._conv:Creating converter from 5 to 3


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - loss: 0.0000e+00 - val_loss: -0.3448 - learning_rate: 0.0010
Epoch 2/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 204ms/step - loss: -0.5629 - val_loss: -0.1883 - learning_rate: 0.0010
Epoch 3/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 195ms/step - loss: -0.4221 - val_loss: -0.3197 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 313ms/step
[np.float64(0.3448404183276398)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 375ms/step
1


Epoch 1/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - loss: 0.0000e+00 - val_loss: -0.3524 - learning_rate: 0.0010
Epoch 2/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 181ms/step - loss: -0.5597 - val_loss: -0.1808 - learning_rate: 0.0010
Epoch 3/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 174ms/step - loss: -0.4227 - val_loss: -0.3290 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 277ms/step
[np.float64(0.3448404183276398), np.float64(0.3524359526053535)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 301ms/step
Overall: 0.3486274269193613
0


Epoch 1/3




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - loss: 1.0430 - val_loss: 0.9531 - learning_rate: 5.0000e-04
Epoch 2/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 493ms/step - loss: 0.9554 - val_loss: 0.9331 - learning_rate: 5.0000e-04
Epoch 3/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 134ms/step - loss: 0.8933 - val_loss: 0.9506 - learning_rate: 5.0000e-04




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 365ms/step
[np.float64(0.29387561220797964)]




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 352ms/step
1


Epoch 1/3




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - loss: 1.0392 - val_loss: 0.9713 - learning_rate: 5.0000e-04
Epoch 2/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 356ms/step - loss: 0.9267 - val_loss: 0.9339 - learning_rate: 5.0000e-04
Epoch 3/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 350ms/step - loss: 0.8717 - val_loss: 0.9072 - learning_rate: 5.0000e-04
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 307ms/step
[np.float64(0.29387561220797964), np.float64(0.31625227695276403)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 306ms/step
Overall: 0.30503224958781566
Blend: 0.33954614007112444


Because for this dataset we know the ground truth, we can compute the correlation score for test data:

In [31]:
from senkin_tmp_cite_pred.metrics import correlation_score

In [33]:
correlation_score(adata_prot[test_cell_ids].layers["dsb"], test_preds)

np.float64(0.25758923033475933)