In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import sc_utils
import catboost
import pickle
import os
import requests

In [2]:
import h5py
import hdf5plugin

In [3]:
pd.options.display.max_rows = 200

In [4]:
%config InlineBackend.figure_format = "retina"

## 1. Load targets

In [5]:
f = h5py.File("../data/train_multi_targets.h5")

In [6]:
all_genes = f["train_multi_targets"]["axis0"].asstr()[:]

In [7]:
all_genes = pd.Series(all_genes)

In [8]:
f["train_multi_targets"]["block0_values"]

<HDF5 dataset "block0_values": shape (105942, 23418), type "<f4">

## 2. Load train latent

In [9]:
GFF3_URL = "https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_42/gencode.v42.annotation.gff3.gz"
GFF3 = "gencode.v42.annotation.gff3.gz"
if not os.path.exists(GFF3):
    with requests.get(GFF3_URL, stream=True) as res:
        res.raise_for_status()
        filesize = int(res.headers["Content-Length"])
        with open(GFF3, "wb") as df:
            total_bytes_received = 0
            for chunk in res.iter_content(chunk_size=1024 * 1024):
                df.write(chunk)
                total_bytes_received += len(chunk)
                percent_of_total_upload = float("{:.1f}".format(total_bytes_received / filesize * 100))
                color = "\033[38;5;10m" if percent_of_total_upload == 100 else ""
                print(f"\033[1m{color}{percent_of_total_upload}% downloaded\033[0m\r", end="")

In [10]:
gencode = pd.read_table(
    GFF3, 
    comment="#",
    sep="\t", 
    names=['seqname', 'source', 'feature', 'start' , 'end', 'score', 'strand', 'frame', 'attribute']
)

In [11]:
gencode = gencode.loc[gencode.feature.eq("gene"), :]

In [12]:
gencode["gene_id"] = gencode.attribute.str.split(";").str[1].str.split("=").str[1]

In [13]:
gencode.gene_id = gencode.gene_id.str.split(".").str[0]

In [14]:
all_genes.isin(gencode.gene_id).value_counts()

True     23308
False      110
dtype: int64

This is good enough, some 110 retired genes we can fill with 0

In [15]:
train_latent = {}
for file in os.listdir("20_atac/"):
    if file.endswith("latent.npy"):
        c = file.split(".")[0]
        train_latent[c] = np.load(f"20_atac/{file}")

In [16]:
chr2genes = {}
for c in train_latent.keys():
    chr_genes = gencode.gene_id[gencode.seqname.eq(c)]
    selected_genes = all_genes.isin(chr_genes).values
    targets = f["train_multi_targets"]["block0_values"][:, selected_genes]
    colsums = targets.sum(axis=0)
    selected_gene_names = all_genes[all_genes.isin(chr_genes)]
    selected_gene_names = selected_gene_names[colsums > 0]
    chr2genes[c] = selected_gene_names

## 3. Catboost models

TODO: fix function for chrY test case

In [17]:
def correlation_score(y_true, y_pred):
    if type(y_true) == pd.DataFrame: y_true = y_true.values
    if type(y_pred) == pd.DataFrame: y_pred = y_pred.values
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

In [19]:
catboost_params = {
    'learning_rate': 0.1, 
    'depth': 7, 
    'l2_leaf_reg': 4, 
    'loss_function': 'MultiRMSE', 
    'eval_metric': 'MultiRMSE', 
    'task_type': 'CPU', 
    'iterations': 100,
    'od_type': 'Iter', 
    'boosting_type': 'Plain', 
    'bootstrap_type': 'Bayesian', 
    'allow_const_label': True, 
    'random_state': 1
}

In [21]:
for c in train_latent.keys():
    genes = gencode.gene_id[gencode.seqname.eq(c)]
    selected_names = names.isin(genes).values
    targets = f["train_multi_targets"]["block0_values"][:, selected_names]
    colsums = targets.sum(axis=0)
    targets = targets[:, colsums > 0]
    train_x = train_latent[c]
    model = catboost.CatBoostRegressor(**catboost_params)
    model.fit(train_x, targets)
    y_pred = model.predict(train_x)
    model.save_model(f"22_catboost/{c}.model")
    np.save(f"22_catboost/{c}.names", selected_names[selected_names][colsums > 0])
    score = correlation_score(targets, y_pred)
    print(f"Model for {c}: {score}")

0:	learn: 28.9968345	total: 12.8s	remaining: 21m 6s
1:	learn: 28.9473624	total: 26.2s	remaining: 21m 21s
2:	learn: 28.9030295	total: 39.2s	remaining: 21m 8s
3:	learn: 28.8607267	total: 52.2s	remaining: 20m 52s
4:	learn: 28.8223521	total: 1m 5s	remaining: 20m 42s
5:	learn: 28.7940962	total: 1m 18s	remaining: 20m 27s
6:	learn: 28.7707502	total: 1m 30s	remaining: 20m 8s
7:	learn: 28.7484825	total: 1m 43s	remaining: 19m 51s
8:	learn: 28.7310476	total: 1m 56s	remaining: 19m 38s
9:	learn: 28.7131472	total: 2m 9s	remaining: 19m 23s
10:	learn: 28.6983682	total: 2m 21s	remaining: 19m 8s
11:	learn: 28.6843176	total: 2m 34s	remaining: 18m 53s
12:	learn: 28.6714466	total: 2m 47s	remaining: 18m 39s
13:	learn: 28.6607675	total: 3m	remaining: 18m 27s
14:	learn: 28.6512898	total: 3m 13s	remaining: 18m 13s
15:	learn: 28.6415389	total: 3m 25s	remaining: 17m 58s
16:	learn: 28.6328728	total: 3m 38s	remaining: 17m 45s
17:	learn: 28.6240203	total: 3m 51s	remaining: 17m 34s
18:	learn: 28.6187842	total: 4m 4s

  c /= stddev[:, None]
  c /= stddev[None, :]


Model for chrY: nan


In [18]:
for c in train_latent.keys():
    if not os.path.exists(f"22_catboost/{c}.model"):
        print(f"No model for chromosome {c}")

In [19]:
preds = {}

In [20]:
for c in train_latent.keys():
    model = catboost.CatBoostRegressor()
    model.load_model(f"22_catboost/{c}.model")
    latent = np.load(f"21_latent/{c}.latent.npy")
    preds[c] = model.predict(latent)
    print(f"Predicted for chromosome {c}")

Predicted for chromosome chr18
Predicted for chromosome chr5
Predicted for chromosome chr11
Predicted for chromosome chr15
Predicted for chromosome chrX
Predicted for chromosome chr22
Predicted for chromosome chr12
Predicted for chromosome chr3
Predicted for chromosome chr16
Predicted for chromosome chr7
Predicted for chromosome chr17
Predicted for chromosome chr13
Predicted for chromosome chr2
Predicted for chromosome chr8
Predicted for chromosome chr14
Predicted for chromosome chr19
Predicted for chromosome chr6
Predicted for chromosome chr4
Predicted for chromosome chr1
Predicted for chromosome chr20
Predicted for chromosome chr21
Predicted for chromosome chr9
Predicted for chromosome chr10
Predicted for chromosome chrY


In evaluation different cells have different genes selected, so we need to match them manually

In [21]:
submission = pd.read_csv("../data/evaluation_ids.csv")

In [22]:
submission = submission.loc[submission.gene_id.str.startswith("ENSG0")]

In [23]:
submission

Unnamed: 0,row_id,cell_id,gene_id
6812820,6812820,8d287040728a,ENSG00000204091
6812821,6812821,8d287040728a,ENSG00000198938
6812822,6812822,8d287040728a,ENSG00000168495
6812823,6812823,8d287040728a,ENSG00000165527
6812824,6812824,8d287040728a,ENSG00000167414
...,...,...,...
65744175,65744175,2c53aa67933d,ENSG00000134419
65744176,65744176,2c53aa67933d,ENSG00000186862
65744177,65744177,2c53aa67933d,ENSG00000170959
65744178,65744178,2c53aa67933d,ENSG00000107874


Load cell names from test

In [24]:
f = h5py.File("../data/test_multi_inputs.h5")

In [25]:
cells = f["test_multi_inputs"]["axis1"].asstr()[:]

In [26]:
cells = pd.Series(cells)

Figure out gene names for each chromosome

In [27]:
gene2chr = gencode[["seqname", "gene_id"]].set_index("gene_id").seqname
gene2chr = gene2chr.loc[~gene2chr.index.duplicated()]

TODO: rewrite to multiindex

In [None]:
%%time
subs = []
for i, (_, r) in enumerate(submission.iterrows()):
    if r.gene_id not in gene2chr.index:
        # absent gene
        subs.append(0)
        continue
    c = gene2chr[r.gene_id]
    cell = r.cell_id
    cell_idx = cells.eq(cell)
    if c not in chr2genes:
        # absent chromosome
        subs.append(0)
        continue
    chr_genes = chr2genes[c]
    gene_idx = chr_genes.eq(r.gene_id)
    if gene_idx.sum() == 0:
        # constant gene, we didn't train on it
        subs.append(0)
        continue
    val = preds[c][cell_idx.values, gene_idx.values]
    subs.append(val)
    if i % 10_000 == 0:
        print(".", end="")

........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

In [56]:
genes1 = submission.gene_id[submission.cell_id.eq("8d287040728a")]

In [57]:
submission

Unnamed: 0,row_id,cell_id,gene_id
6812820,6812820,8d287040728a,ENSG00000204091
6812821,6812821,8d287040728a,ENSG00000198938
6812822,6812822,8d287040728a,ENSG00000168495
6812823,6812823,8d287040728a,ENSG00000165527
6812824,6812824,8d287040728a,ENSG00000167414
...,...,...,...
65744175,65744175,2c53aa67933d,ENSG00000134419
65744176,65744176,2c53aa67933d,ENSG00000186862
65744177,65744177,2c53aa67933d,ENSG00000170959
65744178,65744178,2c53aa67933d,ENSG00000107874


In [49]:
with open("05_jax_preds/cite_test_preds.pickle", "wb") as f:
    pickle.dump(test_pred, f)

In [50]:
with open("05_jax_preds/partial_submission_multi.pickle", "rb") as f:
    submission = pickle.load(f)

submission.iloc[:len(test_pred.ravel())] = test_pred.ravel()
assert not submission.isna().any()
submission = submission.round(6) # reduce the size of the csv
submission.to_csv("05_jax_preds/submission.csv")
submission

row_id
0           0.273072
1           0.293944
2           0.658493
3           4.448602
4           4.789261
              ...   
65744175    7.271666
65744176    0.017597
65744177    0.025040
65744178    1.933743
65744179    5.111444
Name: target, Length: 65744180, dtype: float64

This submission scores 0.794

In [None]:
# from https://www.kaggle.com/code/geraseva/magic-correlation
cite_cols_important={'CD86': ['ENSG00000114013_CD86'],
             'CD274': ['ENSG00000120217_CD274'],
             'CD270': ['ENSG00000157873_TNFRSF14'],
             'CD155': ['ENSG00000073008_PVR'],
             'CD112': ['ENSG00000130202_NECTIN2'],
             'CD47': ['ENSG00000196776_CD47'],
             'CD48': ['ENSG00000117091_CD48'],
             'CD40': ['ENSG00000101017_CD40'],
             'CD154': ['ENSG00000102245_CD40LG'],
             'CD52': ['ENSG00000169442_CD52'],
             'CD3': ['ENSG00000167286_CD3D'],
             'CD8': [],
             'CD56': ['ENSG00000149294_NCAM1'],
             'CD19': ['ENSG00000177455_CD19'],
             'CD33': ['ENSG00000105383_CD33'],
             'CD11c': ['ENSG00000140678_ITGAX'],
             'HLA-A-B-C': ['ENSG00000204525_HLA-C',
              'ENSG00000206503_HLA-A',
              'ENSG00000234745_HLA-B'],
             'CD45RA': ['ENSG00000081237_PTPRC'],
             'CD123': ['ENSG00000185291_IL3RA'],
             'CD7': ['ENSG00000173762_CD7'],
             'CD105': ['ENSG00000106991_ENG'],
             'CD49f': ['ENSG00000091409_ITGA6'],
             'CD194': ['ENSG00000183813_CCR4'],
             'CD4': ['ENSG00000010610_CD4'],
             'CD44': ['ENSG00000026508_CD44'],
             'CD14': ['ENSG00000170458_CD14'],
             'CD16': [],
             'CD25': ['ENSG00000134460_IL2RA'],
             'CD45RO': ['ENSG00000081237_PTPRC'],
             'CD279': [],
             'TIGIT': [],
             'Mouse-IgG1': [],
             'Mouse-IgG2a': [],
             'Mouse-IgG2b': [],
             'Rat-IgG2b': [],
             'CD20': ['ENSG00000156738_MS4A1'],
             'CD335': ['ENSG00000189430_NCR1'],
             'CD31': ['ENSG00000261371_PECAM1'],
             'Podoplanin': [],
             'CD146': ['ENSG00000076706_MCAM'],
             'IgM': ['ENSG00000211899_IGHM'],
             'CD5': [],
             'CD195': ['ENSG00000160791_CCR5'],
             'CD32': ['ENSG00000143226_FCGR2A'],
             'CD196': [],
             'CD185': ['ENSG00000160683_CXCR5'],
             'CD103': ['ENSG00000083457_ITGAE'],
             'CD69': ['ENSG00000110848_CD69'],
             'CD62L': ['ENSG00000188404_SELL'],
             'CD161': ['ENSG00000111796_KLRB1'],
             'CD152': [],
             'CD223': ['ENSG00000089692_LAG3'],
             'KLRG1': ['ENSG00000139187_KLRG1'],
             'CD27': ['ENSG00000139193_CD27'],
             'CD107a': ['ENSG00000185896_LAMP1'],
             'CD95': ['ENSG00000026103_FAS'],
             'CD134': ['ENSG00000186827_TNFRSF4'],
             'HLA-DR': ['ENSG00000204287_HLA-DRA'],
             'CD1c': ['ENSG00000158481_CD1C'],
             'CD11b': ['ENSG00000169896_ITGAM'],
             'CD64': ['ENSG00000150337_FCGR1A'],
             'CD141': ['ENSG00000178726_THBD'],
             'CD1d': ['ENSG00000158473_CD1D'],
             'CD314': [],
             'CD35': ['ENSG00000203710_CR1'],
             'CD57': [],
             'CD272': [],
             'CD278': ['ENSG00000163600_ICOS'],
             'CD58': ['ENSG00000116815_CD58'],
             'CD39': ['ENSG00000138185_ENTPD1'],
             'CX3CR1': ['ENSG00000168329_CX3CR1'],
             'CD24': ['ENSG00000272398_CD24'],
             'CD21': ['ENSG00000117322_CR2'],
             'CD11a': ['ENSG00000005844_ITGAL'],
             'CD79b': ['ENSG00000007312_CD79B'],
             'CD244': ['ENSG00000122223_CD244'],
             'CD169': [],
             'integrinB7': ['ENSG00000139626_ITGB7'],
             'CD268': ['ENSG00000159958_TNFRSF13C'],
             'CD42b': ['ENSG00000185245_GP1BA'],
             'CD54': ['ENSG00000090339_ICAM1'],
             'CD62P': ['ENSG00000174175_SELP'],
             'CD119': ['ENSG00000027697_IFNGR1'],
             'TCR': [],
             'Rat-IgG1': [],
             'Rat-IgG2a': [],
             'CD192': ['ENSG00000121807_CCR2'],
             'CD122': ['ENSG00000100385_IL2RB'],
             'FceRIa': ['ENSG00000179639_FCER1A'],
             'CD41': ['ENSG00000005961_ITGA2B'],
             'CD137': ['ENSG00000049249_TNFRSF9'],
             'CD163': ['ENSG00000177575_CD163'],
             'CD83': ['ENSG00000112149_CD83'],
             'CD124': ['ENSG00000077238_IL4R'],
             'CD13': ['ENSG00000166825_ANPEP'],
             'CD2': ['ENSG00000116824_CD2'],
             'CD226': ['ENSG00000150637_CD226'],
             'CD29': ['ENSG00000150093_ITGB1'],
             'CD303': ['ENSG00000198178_CLEC4C'],
             'CD49b': ['ENSG00000164171_ITGA2'],
             'CD81': ['ENSG00000110651_CD81'],
             'IgD': ['ENSG00000211898_IGHD'],
             'CD18': ['ENSG00000160255_ITGB2'],
             'CD28': [],
             'CD38': ['ENSG00000004468_CD38'],
             'CD127': ['ENSG00000168685_IL7R'],
             'CD45': ['ENSG00000081237_PTPRC'],
             'CD22': ['ENSG00000012124_CD22'],
             'CD71': ['ENSG00000072274_TFRC'],
             'CD26': ['ENSG00000197635_DPP4'],
             'CD115': ['ENSG00000182578_CSF1R'],
             'CD63': ['ENSG00000135404_CD63'],
             'CD304': ['ENSG00000099250_NRP1'],
             'CD36': ['ENSG00000135218_CD36'],
             'CD172a': ['ENSG00000198053_SIRPA'],
             'CD72': ['ENSG00000137101_CD72'],
             'CD158': [],
             'CD93': ['ENSG00000125810_CD93'],
             'CD49a': ['ENSG00000213949_ITGA1'],
             'CD49d': ['ENSG00000115232_ITGA4'],
             'CD73': [],
             'CD9': ['ENSG00000010278_CD9'],
             'TCRVa7.2': [],
             'TCRVd2': [],
             'LOX-1': ['ENSG00000173391_OLR1'],
             'CD158b': [],
             'CD158e1': [],
             'CD142': ['ENSG00000117525_F3'],
             'CD319': ['ENSG00000026751_SLAMF7'],
             'CD352': ['ENSG00000162739_SLAMF6'],
             'CD94': ['ENSG00000134539_KLRD1'],
             'CD162': ['ENSG00000110876_SELPLG'],
             'CD85j': ['ENSG00000104972_LILRB1'],
             'CD23': ['ENSG00000104921_FCER2'],
             'CD328': ['ENSG00000168995_SIGLEC7'],
             'HLA-E': ['ENSG00000204592_HLA-E'],
             'CD82': ['ENSG00000085117_CD82'],
             'CD101': ['ENSG00000134256_CD101'],
             'CD88': ['ENSG00000197405_C5AR1'],
             'CD224': ['ENSG00000100031_GGT1']}