In [48]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import sc_utils
import scvi
import catboost
import pickle

In [6]:
import h5py
import hdf5plugin

In [2]:
# Let pandas print up to 200 rows before truncating long frames/series.
pd.set_option("display.max_rows", 200)

In [3]:
# Render inline matplotlib figures using IPython's high-DPI "retina" format.
%config InlineBackend.figure_format = "retina"

## 1. Load data

In [4]:
# Load the preprocessed RNA AnnData; its obsm["X_jax"] embedding is the model input used below.
ds = sc.read_h5ad("../data/h5ad/jax_rna.h5ad")

In [5]:
# Inspect the loaded AnnData: dimensions and the available obs/var/obsm/obsp slots.
ds

AnnData object with n_obs × n_vars = 119651 × 2000
    obs: 'day', 'donor', 'cell_type', 'technology', 'batch', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'sample', 'leiden', 'diff_leiden', 'dpt_pseudotime', 'jax_leiden'
    var: 'ensembl_id', 'mito', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'cell_type_colors', 'day_colors', 'diff_leiden_colors', 'diff_leiden_sizes', 'diff_neighbors', 'diffmap_evals', 'donor_colors', 'hvg', 'iroot', 'jax_leiden_colors', 'jax_neighbors', 'leiden', 'leiden_colors', 'leiden_sizes', 'neighbors', 'paga', 'pca', 'rank_genes_groups', 'sample_colors', 'umap'
    obsm: 'X_diffmap', 'X_jax', 'X_pca', 'X_umap'
    obsp: 'connectivities', 'diff_neighbors_connectivities', 'diff_neighbors_distances', 'distances', 'jax_neighbors_connectivities', 'jax_neighbors_distances'

## 2. Load targets

In [14]:
# Cell metadata table, indexed by its first column.
# NOTE(review): `meta` is only passed to read_h5_as_df, which never uses it.
meta = pd.read_csv("../data/metadata.csv", index_col=0)

In [7]:
# Removed a stray `f = h5py.File("")` call: opening an empty path raises on a
# fresh kernel, which broke Restart & Run All. HDF5 files are opened (and
# closed) inside `read_h5_as_df` in the next cell instead.

In [18]:
def read_h5_as_df(path, meta, key="train_cite_inputs"):
    """Read one table stored under ``key`` in an HDF5 file into a DataFrame.

    The group is expected to contain ``block0_values`` (the data matrix),
    ``axis0`` (column labels) and ``axis1`` (row labels). This appears to be
    the pandas fixed-format HDF5 layout — TODO confirm. Only the first value
    block is read, so tables with several dtype blocks would be truncated.

    Parameters
    ----------
    path : str or path-like
        HDF5 file to read. It is opened read-only and closed before returning.
    meta : object
        Unused. Kept only so existing calls keep working.
    key : str
        Name of the group holding the table.

    Returns
    -------
    pd.DataFrame
        Rows labelled by ``axis1`` (cells) and columns by ``axis0``.
    """
    # Use a context manager so the file handle is released even on error. The
    # original opened the file and never closed it. All slices below
    # materialize into memory, so they stay valid after the file is closed.
    with h5py.File(path, "r") as f:
        expr = f[key]["block0_values"][:]
        genes = f[key]["axis0"].asstr()[:]
        cells = f[key]["axis1"].asstr()[:]
    return pd.DataFrame(expr, index=cells, columns=genes)

In [19]:
# Protein (CITE-seq) training targets: one row per cell, one column per target.
targets = read_h5_as_df(
    path="../data/train_cite_targets.h5",
    meta=meta,
    key="train_cite_targets",
)

In [21]:
# Preview the targets table (cells × protein targets).
targets

Unnamed: 0,CD86,CD274,CD270,CD155,CD112,CD47,CD48,CD40,CD154,CD52,...,CD94,CD162,CD85j,CD23,CD328,HLA-E,CD82,CD101,CD88,CD224
45006fe3e4c8,1.167804,0.622530,0.106959,0.324989,3.331674,6.426002,1.480766,-0.728392,-0.468851,-0.073285,...,-0.448390,3.220174,-0.533004,0.674956,-0.006187,0.682148,1.398105,0.414292,1.780314,0.548070
d02759a80ba2,0.818970,0.506009,1.078682,6.848758,3.524885,5.279456,4.930438,2.069372,0.333652,-0.468088,...,0.323613,8.407108,0.131301,0.047607,-0.243628,0.547864,1.832587,0.982308,2.736507,2.184063
c016c6b0efa5,-0.356703,-0.422261,-0.824493,1.137495,0.518924,7.221962,-0.375034,1.738071,0.142919,-0.971460,...,1.348692,4.888579,-0.279483,-0.131097,-0.177604,-0.689188,9.013709,-1.182975,3.958148,2.868600
ba7f733a4f75,-1.201507,0.149115,2.022468,6.021595,7.258670,2.792436,21.708519,-0.137913,1.649969,-0.754680,...,1.504426,12.391979,0.511394,0.587863,-0.752638,1.714851,3.893782,1.799661,1.537249,4.407671
fbcf2443ffb2,-0.100404,0.697461,0.625836,-0.298404,1.369898,3.254521,-1.659380,0.643531,0.902710,1.291877,...,0.777023,6.496499,0.279898,-0.841950,-0.869419,0.675092,5.259685,-0.835379,9.631781,1.765445
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
650ee456f0f3,0.905420,0.386141,0.961590,5.090580,2.854346,6.093729,-0.586178,0.452389,0.040806,0.191407,...,1.261118,3.092832,0.003275,0.278930,-0.272002,0.249477,3.789460,0.138330,1.466193,4.278504
cc506e7707f5,2.101247,2.117462,0.112699,2.065512,2.176803,3.900090,-0.586001,-0.175479,1.363232,0.109905,...,0.714624,5.029233,0.909861,0.057322,2.633387,1.340077,11.456146,-1.431453,5.275882,2.510530
a91f1b55a520,1.221313,0.476566,1.437551,5.135631,2.926102,1.615081,-0.586910,1.760421,1.944711,-0.095096,...,-0.176027,5.027534,-0.703609,1.139491,-0.078092,1.592960,9.358179,0.981883,6.911032,3.415310
3a9882c98205,-0.151433,-0.850024,0.461556,3.546561,1.996473,5.702821,0.883038,1.309014,1.029737,-0.072851,...,-0.484493,12.883892,1.579381,-0.382835,-0.065286,-0.021458,7.372662,1.010247,1.864805,3.449289


## 3. CatBoost model

In [30]:
# Strip a trailing "-<digits>" suffix from the cell names so they can be
# matched against `targets.index` (used to select training cells below).
# A raw string is needed: "\d" in a plain literal is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning on Python 3.12+).
ds.obs_names = ds.obs_names.str.replace(r"-\d+$", "", regex=True)

In [31]:
# Model inputs: the X_jax embedding for the cells that have targets.
# Rows are selected by `targets.index` labels, presumably returned in that same
# order — fit() below relies on the rows lining up with `targets`.
train_x = ds[targets.index, :].obsm["X_jax"]

In [36]:
# A single multi-output CatBoost regressor that predicts every target column
# jointly, trained and monitored with the MultiRMSE objective.
# To train on a GPU instead, also pass task_type="GPU".
model = catboost.CatBoostRegressor(
    loss_function="MultiRMSE",
    eval_metric="MultiRMSE",
    iterations=1000,
)

In [37]:
# Train on all labelled cells. No eval_set is passed, so the log reports only
# the training ("learn") loss and there is no held-out validation split.
model.fit(train_x, targets)

0:	learn: 25.9322109	total: 1.35s	remaining: 22m 26s
1:	learn: 25.7480553	total: 2.66s	remaining: 22m 9s
2:	learn: 25.5693188	total: 3.98s	remaining: 22m 3s
3:	learn: 25.3982765	total: 5.21s	remaining: 21m 36s
4:	learn: 25.2298558	total: 6.37s	remaining: 21m 8s
5:	learn: 25.0729520	total: 7.68s	remaining: 21m 12s
6:	learn: 24.9191016	total: 8.84s	remaining: 20m 53s
7:	learn: 24.7797825	total: 10s	remaining: 20m 44s
8:	learn: 24.6365317	total: 11.3s	remaining: 20m 49s
9:	learn: 24.5051204	total: 12.7s	remaining: 20m 57s
10:	learn: 24.3822681	total: 13.7s	remaining: 20m 30s
11:	learn: 24.2592870	total: 14.9s	remaining: 20m 26s
12:	learn: 24.1423655	total: 16.1s	remaining: 20m 23s
13:	learn: 24.0315311	total: 17.5s	remaining: 20m 30s
14:	learn: 23.9204742	total: 18.6s	remaining: 20m 22s
15:	learn: 23.8121009	total: 19.6s	remaining: 20m 7s
16:	learn: 23.7120350	total: 20.8s	remaining: 20m 3s
17:	learn: 23.6168842	total: 22.2s	remaining: 20m 11s
18:	learn: 23.5253306	total: 23.2s	remaining:

<catboost.core.CatBoostRegressor at 0x2b13bc29f730>

In [38]:
def correlation_score(y_true, y_pred):
    """Mean row-wise Pearson correlation between targets and predictions.

    For every row i, the Pearson correlation between ``y_true[i]`` and
    ``y_pred[i]`` is computed across columns. The result is the average over
    all rows.

    Parameters
    ----------
    y_true, y_pred : array-like or pd.DataFrame, shape (n_rows, n_cols)
        Converted with ``np.asarray``. This handles DataFrames and DataFrame
        subclasses alike.

    Returns
    -------
    float
        Mean per-row correlation. A row that is constant in either input has
        an undefined correlation (NaN), which propagates to the result.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, differ in shape, or have no rows.
    """
    # np.asarray avoids the old `type(x) == pd.DataFrame` check. That check let
    # DataFrame subclasses through unconverted, so `x[i]` then picked a column.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.ndim != 2 or y_true.shape != y_pred.shape:
        raise ValueError(
            "expected two 2-D inputs of the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.shape[0] == 0:
        raise ValueError("correlation_score needs at least one row")

    # Vectorized per-row Pearson r: centre each row, then normalize the dot product.
    true_c = y_true - y_true.mean(axis=1, keepdims=True)
    pred_c = y_pred - y_pred.mean(axis=1, keepdims=True)
    numerator = (true_c * pred_c).sum(axis=1)
    denominator = np.sqrt((true_c ** 2).sum(axis=1) * (pred_c ** 2).sum(axis=1))
    per_row = numerator / denominator
    # Clip rounding overshoot (e.g. 1.0000000000000002), as np.corrcoef does.
    per_row = np.clip(per_row, -1.0, 1.0)
    return float(per_row.mean())

In [39]:
# In-sample predictions on the training cells, for a sanity-check score.
y_pred = model.predict(train_x)

In [40]:
# Training-set (in-sample) correlation. This is an optimistic estimate, since
# the model has already seen these cells.
correlation_score(targets, y_pred)

0.8963400960239647

In [43]:
# Persist the trained model so predictions can be regenerated without refitting.
model.save_model("05jax_catboost.model")

In [44]:
# Test inputs: every cell in `ds` that has no training target, kept in `ds` order.
# NOTE(review): the submission cell below assumes this order matches the order
# of the CITE rows in the submission file — verify against the test metadata.
test_x = ds[~ds.obs_names.isin(targets.index), :].obsm["X_jax"]

In [45]:
# Predict the targets for the test cells (one column per target in `targets`).
test_pred = model.predict(test_x)

In [46]:
# Expect (n_test_cells, n_targets).
test_pred.shape

(48663, 140)

In [49]:
# Save the raw (cells × targets) test predictions so later steps can reuse them
# without rerunning the model.
with open("05_jax_preds/cite_test_preds.pickle", mode="wb") as f:
    pickle.dump(test_pred, file=f)

In [50]:
# Insert the CITE test predictions into the partially filled submission, then
# write the final CSV.
# NOTE(review): only unpickle files this project produced itself;
# pickle.load can execute arbitrary code.
with open("05_jax_preds/partial_submission_multi.pickle", "rb") as f:
    submission = pickle.load(f)

# Row-major flattening: all targets of the first test cell, then the next cell.
# NOTE(review): assumes the leading rows of `submission` are the CITE rows in
# exactly this (cell, target) order — confirm against the evaluation ids.
cite_flat = test_pred.ravel()
assert cite_flat.size <= len(submission), "more CITE predictions than submission rows"
submission.iloc[:cite_flat.size] = cite_flat
assert not submission.isna().any()
submission = submission.round(6) # reduce the size of the csv
submission.to_csv("05_jax_preds/submission.csv")
submission

row_id
0           0.273072
1           0.293944
2           0.658493
3           4.448602
4           4.789261
              ...   
65744175    7.271666
65744176    0.017597
65744177    0.025040
65744178    1.933743
65744179    5.111444
Name: target, Length: 65744180, dtype: float64

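This submission scores 0.794, noticeably lower than the 0.896 correlation computed above. That figure was measured on the training cells themselves, so it is an optimistic, in-sample estimate.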

In [None]:
# from https://www.kaggle.com/code/geraseva/magic-correlation
# from https://www.kaggle.com/code/geraseva/magic-correlation
# Maps each CITE-seq protein target name (the keys appear to match the columns
# of `targets`, e.g. 'CD86' ... 'CD224') to a list of RNA feature names of the
# form "<Ensembl ID>_<gene symbol>". Presumably these are the gene(s) encoding
# each protein. An empty list means no gene was assigned to that target.
# NOTE(review): this cell was never executed (In [None]) and the dict is not
# used by any visible cell. The feature names presumably refer to the raw
# train_cite_inputs columns rather than the var names of `ds` — verify before use.
cite_cols_important={'CD86': ['ENSG00000114013_CD86'],
             'CD274': ['ENSG00000120217_CD274'],
             'CD270': ['ENSG00000157873_TNFRSF14'],
             'CD155': ['ENSG00000073008_PVR'],
             'CD112': ['ENSG00000130202_NECTIN2'],
             'CD47': ['ENSG00000196776_CD47'],
             'CD48': ['ENSG00000117091_CD48'],
             'CD40': ['ENSG00000101017_CD40'],
             'CD154': ['ENSG00000102245_CD40LG'],
             'CD52': ['ENSG00000169442_CD52'],
             'CD3': ['ENSG00000167286_CD3D'],
             'CD8': [],
             'CD56': ['ENSG00000149294_NCAM1'],
             'CD19': ['ENSG00000177455_CD19'],
             'CD33': ['ENSG00000105383_CD33'],
             'CD11c': ['ENSG00000140678_ITGAX'],
             'HLA-A-B-C': ['ENSG00000204525_HLA-C',
              'ENSG00000206503_HLA-A',
              'ENSG00000234745_HLA-B'],
             'CD45RA': ['ENSG00000081237_PTPRC'],
             'CD123': ['ENSG00000185291_IL3RA'],
             'CD7': ['ENSG00000173762_CD7'],
             'CD105': ['ENSG00000106991_ENG'],
             'CD49f': ['ENSG00000091409_ITGA6'],
             'CD194': ['ENSG00000183813_CCR4'],
             'CD4': ['ENSG00000010610_CD4'],
             'CD44': ['ENSG00000026508_CD44'],
             'CD14': ['ENSG00000170458_CD14'],
             'CD16': [],
             'CD25': ['ENSG00000134460_IL2RA'],
             'CD45RO': ['ENSG00000081237_PTPRC'],
             'CD279': [],
             'TIGIT': [],
             'Mouse-IgG1': [],
             'Mouse-IgG2a': [],
             'Mouse-IgG2b': [],
             'Rat-IgG2b': [],
             'CD20': ['ENSG00000156738_MS4A1'],
             'CD335': ['ENSG00000189430_NCR1'],
             'CD31': ['ENSG00000261371_PECAM1'],
             'Podoplanin': [],
             'CD146': ['ENSG00000076706_MCAM'],
             'IgM': ['ENSG00000211899_IGHM'],
             'CD5': [],
             'CD195': ['ENSG00000160791_CCR5'],
             'CD32': ['ENSG00000143226_FCGR2A'],
             'CD196': [],
             'CD185': ['ENSG00000160683_CXCR5'],
             'CD103': ['ENSG00000083457_ITGAE'],
             'CD69': ['ENSG00000110848_CD69'],
             'CD62L': ['ENSG00000188404_SELL'],
             'CD161': ['ENSG00000111796_KLRB1'],
             'CD152': [],
             'CD223': ['ENSG00000089692_LAG3'],
             'KLRG1': ['ENSG00000139187_KLRG1'],
             'CD27': ['ENSG00000139193_CD27'],
             'CD107a': ['ENSG00000185896_LAMP1'],
             'CD95': ['ENSG00000026103_FAS'],
             'CD134': ['ENSG00000186827_TNFRSF4'],
             'HLA-DR': ['ENSG00000204287_HLA-DRA'],
             'CD1c': ['ENSG00000158481_CD1C'],
             'CD11b': ['ENSG00000169896_ITGAM'],
             'CD64': ['ENSG00000150337_FCGR1A'],
             'CD141': ['ENSG00000178726_THBD'],
             'CD1d': ['ENSG00000158473_CD1D'],
             'CD314': [],
             'CD35': ['ENSG00000203710_CR1'],
             'CD57': [],
             'CD272': [],
             'CD278': ['ENSG00000163600_ICOS'],
             'CD58': ['ENSG00000116815_CD58'],
             'CD39': ['ENSG00000138185_ENTPD1'],
             'CX3CR1': ['ENSG00000168329_CX3CR1'],
             'CD24': ['ENSG00000272398_CD24'],
             'CD21': ['ENSG00000117322_CR2'],
             'CD11a': ['ENSG00000005844_ITGAL'],
             'CD79b': ['ENSG00000007312_CD79B'],
             'CD244': ['ENSG00000122223_CD244'],
             'CD169': [],
             'integrinB7': ['ENSG00000139626_ITGB7'],
             'CD268': ['ENSG00000159958_TNFRSF13C'],
             'CD42b': ['ENSG00000185245_GP1BA'],
             'CD54': ['ENSG00000090339_ICAM1'],
             'CD62P': ['ENSG00000174175_SELP'],
             'CD119': ['ENSG00000027697_IFNGR1'],
             'TCR': [],
             'Rat-IgG1': [],
             'Rat-IgG2a': [],
             'CD192': ['ENSG00000121807_CCR2'],
             'CD122': ['ENSG00000100385_IL2RB'],
             'FceRIa': ['ENSG00000179639_FCER1A'],
             'CD41': ['ENSG00000005961_ITGA2B'],
             'CD137': ['ENSG00000049249_TNFRSF9'],
             'CD163': ['ENSG00000177575_CD163'],
             'CD83': ['ENSG00000112149_CD83'],
             'CD124': ['ENSG00000077238_IL4R'],
             'CD13': ['ENSG00000166825_ANPEP'],
             'CD2': ['ENSG00000116824_CD2'],
             'CD226': ['ENSG00000150637_CD226'],
             'CD29': ['ENSG00000150093_ITGB1'],
             'CD303': ['ENSG00000198178_CLEC4C'],
             'CD49b': ['ENSG00000164171_ITGA2'],
             'CD81': ['ENSG00000110651_CD81'],
             'IgD': ['ENSG00000211898_IGHD'],
             'CD18': ['ENSG00000160255_ITGB2'],
             'CD28': [],
             'CD38': ['ENSG00000004468_CD38'],
             'CD127': ['ENSG00000168685_IL7R'],
             'CD45': ['ENSG00000081237_PTPRC'],
             'CD22': ['ENSG00000012124_CD22'],
             'CD71': ['ENSG00000072274_TFRC'],
             'CD26': ['ENSG00000197635_DPP4'],
             'CD115': ['ENSG00000182578_CSF1R'],
             'CD63': ['ENSG00000135404_CD63'],
             'CD304': ['ENSG00000099250_NRP1'],
             'CD36': ['ENSG00000135218_CD36'],
             'CD172a': ['ENSG00000198053_SIRPA'],
             'CD72': ['ENSG00000137101_CD72'],
             'CD158': [],
             'CD93': ['ENSG00000125810_CD93'],
             'CD49a': ['ENSG00000213949_ITGA1'],
             'CD49d': ['ENSG00000115232_ITGA4'],
             'CD73': [],
             'CD9': ['ENSG00000010278_CD9'],
             'TCRVa7.2': [],
             'TCRVd2': [],
             'LOX-1': ['ENSG00000173391_OLR1'],
             'CD158b': [],
             'CD158e1': [],
             'CD142': ['ENSG00000117525_F3'],
             'CD319': ['ENSG00000026751_SLAMF7'],
             'CD352': ['ENSG00000162739_SLAMF6'],
             'CD94': ['ENSG00000134539_KLRD1'],
             'CD162': ['ENSG00000110876_SELPLG'],
             'CD85j': ['ENSG00000104972_LILRB1'],
             'CD23': ['ENSG00000104921_FCER2'],
             'CD328': ['ENSG00000168995_SIGLEC7'],
             'HLA-E': ['ENSG00000204592_HLA-E'],
             'CD82': ['ENSG00000085117_CD82'],
             'CD101': ['ENSG00000134256_CD101'],
             'CD88': ['ENSG00000197405_C5AR1'],
             'CD224': ['ENSG00000100031_GGT1']}