In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import sc_utils
import catboost
import pickle
import os
import requests

In [2]:
import h5py
import hdf5plugin

In [3]:
# Show up to 200 rows when a DataFrame/Series is displayed.
pd.set_option("display.max_rows", 200)

In [4]:
# Render inline matplotlib figures in "retina" (high-DPI) format.
%config InlineBackend.figure_format = "retina"

## 1. Load targets

In [29]:
# Open the multiome training targets explicitly read-only (older h5py versions
# default to append mode, which can create or modify the file). The handle is
# kept open because later cells slice datasets out of it on demand.
f = h5py.File("../data/train_multi_targets.h5", "r")

In [30]:
# Column labels of the target matrix (gene IDs), decoded from bytes to str.
names = f["train_multi_targets/axis0"].asstr()[:]

In [31]:
# Wrap the labels in a Series so we can use vectorised `.isin()` lookups.
names = pd.Series(data=names)

In [34]:
# Peek at the target matrix's shape/dtype; h5py shows a dataset handle here,
# the values themselves are not loaded.
f["train_multi_targets/block0_values"]

<HDF5 dataset "block0_values": shape (105942, 23418), type "<f4">

## 2. Load gene annotation and train latent

In [8]:
# Download the GENCODE v42 gene annotation (GFF3) once; it is used below to map
# target gene IDs to chromosomes.
GFF3_URL = "https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_42/gencode.v42.annotation.gff3.gz"
GFF3 = "gencode.v42.annotation.gff3.gz"
if not os.path.exists(GFF3):
    # Stream into a temporary ".part" file and only move it into place once the
    # transfer completed; an interrupted download would otherwise leave a
    # truncated file that the existence check above would accept on re-runs.
    partial_path = GFF3 + ".part"
    # timeout (seconds) bounds the connect and each read, so a stalled server
    # does not hang the kernel forever.
    with requests.get(GFF3_URL, stream=True, timeout=60) as res:
        res.raise_for_status()
        # Content-Length is optional (e.g. chunked transfer encoding); 0 means
        # "unknown" and switches the progress line to a byte count.
        filesize = int(res.headers.get("Content-Length", 0))
        with open(partial_path, "wb") as fh:
            total_bytes_received = 0
            for chunk in res.iter_content(chunk_size=1024 * 1024):
                fh.write(chunk)
                total_bytes_received += len(chunk)
                if filesize:
                    percent_of_total_upload = round(total_bytes_received / filesize * 100, 1)
                    color = "\033[38;5;10m" if percent_of_total_upload == 100 else ""
                    print(f"\033[1m{color}{percent_of_total_upload}% downloaded\033[0m\r", end="")
                else:
                    print(f"\033[1m{total_bytes_received / 2**20:.1f} MiB downloaded\033[0m\r", end="")
    print()  # finish the "\r"-overwritten progress line
    os.replace(partial_path, GFF3)

In [9]:
# Parse the GFF3: nine tab-separated columns, no header row, and lines starting
# with '#' are treated as comments.
gencode = pd.read_csv(
    GFF3,
    sep="\t",
    comment="#",
    header=None,
    names=["seqname", "source", "feature", "start", "end", "score", "strand", "frame", "attribute"],
)

In [10]:
# Keep only the gene-level records; all other feature types are dropped.
gencode = gencode[gencode["feature"] == "gene"]

In [12]:
# Pull the (versioned) gene ID out of the GFF3 attribute column. The value is
# looked up by its `gene_id=` key instead of taking the second ';'-separated
# field, so the result does not depend on the order of the attributes.
gencode["gene_id"] = gencode["attribute"].str.extract(r"(?:^|;)gene_id=([^;]+)", expand=False)

In [16]:
# Drop everything from the first "." on (the version suffix) so the IDs can be
# compared with the unversioned target column names.
gencode["gene_id"] = gencode["gene_id"].str.replace(r"\..*$", "", regex=True)

In [19]:
# How many target genes have a matching GENCODE v42 gene record?
names.isin(gencode["gene_id"].unique()).value_counts()

True     23308
False      110
dtype: int64

This is good enough: only 110 of the 23,418 target genes are missing from GENCODE v42 (likely retired gene IDs), and we can fill them with 0.

In [26]:
# Collect every "*latent.npy" array from 10_models/, keyed by the part of the
# file name before the first "." (compared with GENCODE seqnames below).
train_latent = {
    fname.split(".")[0]: np.load(f"10_models/{fname}")
    for fname in os.listdir("10_models/")
    if fname.endswith("latent.npy")
}

## 3. CatBoost models

In [47]:
def correlation_score(y_true, y_pred):
    """Mean per-row Pearson correlation between ``y_true`` and ``y_pred``.

    Each row of the two inputs is compared with ``np.corrcoef`` and the row
    coefficients are averaged.

    Parameters
    ----------
    y_true, y_pred : array-like or pandas.DataFrame, shape (n_rows, n_cols)
        Ground truth and predictions. DataFrames are converted to arrays, so
        rows/columns are matched by position, not by label.

    Returns
    -------
    float
        Mean Pearson correlation over rows. A row that is constant in either
        input makes ``np.corrcoef`` return NaN, which propagates to the result.

    Raises
    ------
    ValueError
        If the inputs have different shapes or contain no rows.
    """
    # np.asarray also covers DataFrame subclasses (the old `type(...) ==`
    # check did not).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Without this check, fewer prediction rows raised IndexError and extra
    # prediction rows were silently ignored.
    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shapes differ: y_true {y_true.shape} vs y_pred {y_pred.shape}")
    if len(y_true) == 0:
        raise ValueError("Cannot score empty inputs")
    corrsum = 0.0
    for true_row, pred_row in zip(y_true, y_pred):
        corrsum += np.corrcoef(true_row, pred_row)[1, 0]
    return corrsum / len(y_true)

In [51]:
# Hyper-parameters shared by every per-chromosome CatBoost model.
# NOTE(review): od_type="Iter" enables CatBoost's overfitting detector, which
# only acts when an eval set is passed to fit(); the visible fit calls pass none.
catboost_params = dict(
    learning_rate=0.1,
    depth=7,
    l2_leaf_reg=4,
    loss_function="MultiRMSE",
    eval_metric="MultiRMSE",
    task_type="CPU",
    iterations=100,
    od_type="Iter",
    boosting_type="Plain",
    bootstrap_type="Bayesian",
    allow_const_label=True,
    random_state=1,
)

In [None]:
# Train one multi-output CatBoost model per chromosome. Inputs are that
# chromosome's training latent; outputs are the target columns of the genes
# GENCODE places on it, minus columns whose sum is not positive.
os.makedirs("12_catboost", exist_ok=True)
for c in train_latent.keys():
    genes = gencode.gene_id[gencode.seqname.eq(c)]
    selected_names = names.isin(genes).values  # boolean mask over target columns
    if not selected_names.any():
        print(f"Model for {c}: skipped (no target genes on this chromosome)")
        continue
    targets = f["train_multi_targets"]["block0_values"][:, selected_names]
    keep = targets.sum(axis=0) > 0
    if not keep.any():
        print(f"Model for {c}: skipped (all selected target columns are zero)")
        continue
    targets = targets[:, keep]
    # NOTE(review): assumes the latent rows are in the same cell order as the
    # rows of the HDF5 target matrix — confirm against the 10_models notebook.
    train_x = train_latent[c]
    model = catboost.CatBoostRegressor(**catboost_params)
    model.fit(train_x, targets)
    # In-sample score (the model predicts its own training cells), so it is an
    # optimistic estimate.
    y_pred = model.predict(train_x)
    score = correlation_score(targets, y_pred)
    print(f"Model for {c}: {score}")
    model.save_model(f"12_catboost/{c}.model")
    # Save the gene IDs of the modelled columns, in column order, so the
    # model's outputs can be mapped back to genes. Index the ID array (not the
    # boolean mask, which would just yield True values) and store plain str so
    # np.load works without allow_pickle.
    np.save(f"12_catboost/{c}.names", names.to_numpy()[selected_names][keep].astype(str))

0:	learn: 29.0618318	total: 12.9s	remaining: 21m 13s
1:	learn: 29.0576158	total: 25.6s	remaining: 20m 53s
2:	learn: 29.0533554	total: 38.1s	remaining: 20m 31s
3:	learn: 29.0495244	total: 50.7s	remaining: 20m 16s
4:	learn: 29.0457945	total: 1m 3s	remaining: 20m 8s
5:	learn: 29.0420766	total: 1m 16s	remaining: 19m 56s
6:	learn: 29.0382196	total: 1m 29s	remaining: 19m 43s
7:	learn: 29.0344567	total: 1m 41s	remaining: 19m 29s
8:	learn: 29.0314108	total: 1m 54s	remaining: 19m 12s
9:	learn: 29.0274039	total: 2m 6s	remaining: 18m 56s
10:	learn: 29.0236352	total: 2m 19s	remaining: 18m 46s
11:	learn: 29.0195829	total: 2m 32s	remaining: 18m 35s
12:	learn: 29.0159927	total: 2m 44s	remaining: 18m 24s
13:	learn: 29.0128076	total: 2m 57s	remaining: 18m 11s
14:	learn: 29.0088889	total: 3m 10s	remaining: 17m 58s
15:	learn: 29.0054239	total: 3m 23s	remaining: 17m 46s
16:	learn: 29.0022697	total: 3m 36s	remaining: 17m 35s
17:	learn: 28.9988297	total: 3m 49s	remaining: 17m 24s
18:	learn: 28.9955582	total

In [30]:
# Strip a trailing "-<digits>" suffix from the cell names so they can be looked
# up by the target index. The pattern is a raw string: "\d" in a plain string
# literal is an invalid escape sequence that Python warns about.
# NOTE(review): `ds` (presumably an AnnData carrying obsm["X_jax"]) is never
# created in this notebook; this cell depends on leftover kernel state.
ds.obs_names = ds.obs_names.str.replace(r"-\d+$", "", regex=True)

In [31]:
# Latent features of the training cells, in the row order of `targets`.
# NOTE(review): requires `targets` to be a DataFrame indexed by cell name; the
# per-chromosome loop leaves a NumPy array (no `.index`) in `targets`, so this
# cell relies on state created outside this notebook.
train_x = ds[targets.index].obsm["X_jax"]

In [36]:
# One multi-output CatBoost regressor on the JAX latent space.
# (Pass task_type="GPU" as well to train on a GPU.)
model = catboost.CatBoostRegressor(
    loss_function="MultiRMSE",
    eval_metric="MultiRMSE",
    iterations=1000,
)

In [37]:
# Fit on all training cells; CatBoost logs the training MultiRMSE per iteration.
model.fit(X=train_x, y=targets)

0:	learn: 25.9322109	total: 1.35s	remaining: 22m 26s
1:	learn: 25.7480553	total: 2.66s	remaining: 22m 9s
2:	learn: 25.5693188	total: 3.98s	remaining: 22m 3s
3:	learn: 25.3982765	total: 5.21s	remaining: 21m 36s
4:	learn: 25.2298558	total: 6.37s	remaining: 21m 8s
5:	learn: 25.0729520	total: 7.68s	remaining: 21m 12s
6:	learn: 24.9191016	total: 8.84s	remaining: 20m 53s
7:	learn: 24.7797825	total: 10s	remaining: 20m 44s
8:	learn: 24.6365317	total: 11.3s	remaining: 20m 49s
9:	learn: 24.5051204	total: 12.7s	remaining: 20m 57s
10:	learn: 24.3822681	total: 13.7s	remaining: 20m 30s
11:	learn: 24.2592870	total: 14.9s	remaining: 20m 26s
12:	learn: 24.1423655	total: 16.1s	remaining: 20m 23s
13:	learn: 24.0315311	total: 17.5s	remaining: 20m 30s
14:	learn: 23.9204742	total: 18.6s	remaining: 20m 22s
15:	learn: 23.8121009	total: 19.6s	remaining: 20m 7s
16:	learn: 23.7120350	total: 20.8s	remaining: 20m 3s
17:	learn: 23.6168842	total: 22.2s	remaining: 20m 11s
18:	learn: 23.5253306	total: 23.2s	remaining:

<catboost.core.CatBoostRegressor at 0x2b13bc29f730>

In [39]:
# Predictions for the same cells the model was trained on.
y_pred = model.predict(data=train_x)

In [40]:
# In-sample (training-cell) correlation, hence optimistic.
correlation_score(y_true=targets, y_pred=y_pred)

0.8963400960239647

In [43]:
# Persist the fitted model (CatBoost's default binary format).
model.save_model(fname="05jax_catboost.model")

In [44]:
# Latent features of every cell that is not in the training targets.
# NOTE(review): the row order here becomes the order of the submission rows
# filled below — confirm it matches the expected evaluation order.
is_test_cell = np.logical_not(ds.obs_names.isin(targets.index))
test_x = ds[is_test_cell].obsm["X_jax"]

In [45]:
# Predictions for the held-out (test) cells.
test_pred = model.predict(data=test_x)

In [46]:
# (number of test cells, number of predicted targets)
np.shape(test_pred)

(48663, 140)

In [49]:
# Persist the raw CITE test predictions. The handle is named `fh`, not `f`, so
# it does not overwrite the open HDF5 file bound to `f` earlier in the notebook
# (other cells index `f["train_multi_targets"]`).
os.makedirs("05_jax_preds", exist_ok=True)
with open("05_jax_preds/cite_test_preds.pickle", "wb") as fh:
    pickle.dump(test_pred, fh)

In [50]:
# Fill the leading rows of the partial submission with the CITE test
# predictions and write the final CSV. The handle is named `fh` so the open
# HDF5 file bound to `f` elsewhere in the notebook is not overwritten.
with open("05_jax_preds/partial_submission_multi.pickle", "rb") as fh:
    # NOTE(review): pickle.load can execute arbitrary code — only load files
    # produced by our own pipeline.
    submission = pickle.load(fh)

# Row-major flattening: all targets of the first test cell, then the next, ...
cite_flat = test_pred.ravel()
# NOTE(review): assumes the submission lists the CITE rows first, in this same
# cell/target order — confirm against the evaluation ids.
assert len(cite_flat) <= len(submission), "more CITE predictions than submission rows"
submission.iloc[:len(cite_flat)] = cite_flat
assert not submission.isna().any()
submission = submission.round(6) # reduce the size of the csv
submission.to_csv("05_jax_preds/submission.csv")
submission

row_id
0           0.273072
1           0.293944
2           0.658493
3           4.448602
4           4.789261
              ...   
65744175    7.271666
65744176    0.017597
65744177    0.025040
65744178    1.933743
65744179    5.111444
Name: target, Length: 65744180, dtype: float64

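This submission scores 0.794, noticeably below the 0.896 correlation computed above on the training cells (an in-sample estimate).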

In [None]:
# from https://www.kaggle.com/code/geraseva/magic-correlation
# Maps each protein/antibody target name (keys) to the RNA feature columns
# ("<Ensembl gene ID>_<gene symbol>") that the linked notebook pairs with it.
# An empty list means no column was listed for that target (e.g. 'Mouse-IgG1').
# NOTE(review): not used by any cell shown here. Several keys share a column,
# e.g. CD45, CD45RA and CD45RO all map to ENSG00000081237_PTPRC.
cite_cols_important={'CD86': ['ENSG00000114013_CD86'],
             'CD274': ['ENSG00000120217_CD274'],
             'CD270': ['ENSG00000157873_TNFRSF14'],
             'CD155': ['ENSG00000073008_PVR'],
             'CD112': ['ENSG00000130202_NECTIN2'],
             'CD47': ['ENSG00000196776_CD47'],
             'CD48': ['ENSG00000117091_CD48'],
             'CD40': ['ENSG00000101017_CD40'],
             'CD154': ['ENSG00000102245_CD40LG'],
             'CD52': ['ENSG00000169442_CD52'],
             'CD3': ['ENSG00000167286_CD3D'],
             'CD8': [],
             'CD56': ['ENSG00000149294_NCAM1'],
             'CD19': ['ENSG00000177455_CD19'],
             'CD33': ['ENSG00000105383_CD33'],
             'CD11c': ['ENSG00000140678_ITGAX'],
             'HLA-A-B-C': ['ENSG00000204525_HLA-C',
              'ENSG00000206503_HLA-A',
              'ENSG00000234745_HLA-B'],
             'CD45RA': ['ENSG00000081237_PTPRC'],
             'CD123': ['ENSG00000185291_IL3RA'],
             'CD7': ['ENSG00000173762_CD7'],
             'CD105': ['ENSG00000106991_ENG'],
             'CD49f': ['ENSG00000091409_ITGA6'],
             'CD194': ['ENSG00000183813_CCR4'],
             'CD4': ['ENSG00000010610_CD4'],
             'CD44': ['ENSG00000026508_CD44'],
             'CD14': ['ENSG00000170458_CD14'],
             'CD16': [],
             'CD25': ['ENSG00000134460_IL2RA'],
             'CD45RO': ['ENSG00000081237_PTPRC'],
             'CD279': [],
             'TIGIT': [],
             'Mouse-IgG1': [],
             'Mouse-IgG2a': [],
             'Mouse-IgG2b': [],
             'Rat-IgG2b': [],
             'CD20': ['ENSG00000156738_MS4A1'],
             'CD335': ['ENSG00000189430_NCR1'],
             'CD31': ['ENSG00000261371_PECAM1'],
             'Podoplanin': [],
             'CD146': ['ENSG00000076706_MCAM'],
             'IgM': ['ENSG00000211899_IGHM'],
             'CD5': [],
             'CD195': ['ENSG00000160791_CCR5'],
             'CD32': ['ENSG00000143226_FCGR2A'],
             'CD196': [],
             'CD185': ['ENSG00000160683_CXCR5'],
             'CD103': ['ENSG00000083457_ITGAE'],
             'CD69': ['ENSG00000110848_CD69'],
             'CD62L': ['ENSG00000188404_SELL'],
             'CD161': ['ENSG00000111796_KLRB1'],
             'CD152': [],
             'CD223': ['ENSG00000089692_LAG3'],
             'KLRG1': ['ENSG00000139187_KLRG1'],
             'CD27': ['ENSG00000139193_CD27'],
             'CD107a': ['ENSG00000185896_LAMP1'],
             'CD95': ['ENSG00000026103_FAS'],
             'CD134': ['ENSG00000186827_TNFRSF4'],
             'HLA-DR': ['ENSG00000204287_HLA-DRA'],
             'CD1c': ['ENSG00000158481_CD1C'],
             'CD11b': ['ENSG00000169896_ITGAM'],
             'CD64': ['ENSG00000150337_FCGR1A'],
             'CD141': ['ENSG00000178726_THBD'],
             'CD1d': ['ENSG00000158473_CD1D'],
             'CD314': [],
             'CD35': ['ENSG00000203710_CR1'],
             'CD57': [],
             'CD272': [],
             'CD278': ['ENSG00000163600_ICOS'],
             'CD58': ['ENSG00000116815_CD58'],
             'CD39': ['ENSG00000138185_ENTPD1'],
             'CX3CR1': ['ENSG00000168329_CX3CR1'],
             'CD24': ['ENSG00000272398_CD24'],
             'CD21': ['ENSG00000117322_CR2'],
             'CD11a': ['ENSG00000005844_ITGAL'],
             'CD79b': ['ENSG00000007312_CD79B'],
             'CD244': ['ENSG00000122223_CD244'],
             'CD169': [],
             'integrinB7': ['ENSG00000139626_ITGB7'],
             'CD268': ['ENSG00000159958_TNFRSF13C'],
             'CD42b': ['ENSG00000185245_GP1BA'],
             'CD54': ['ENSG00000090339_ICAM1'],
             'CD62P': ['ENSG00000174175_SELP'],
             'CD119': ['ENSG00000027697_IFNGR1'],
             'TCR': [],
             'Rat-IgG1': [],
             'Rat-IgG2a': [],
             'CD192': ['ENSG00000121807_CCR2'],
             'CD122': ['ENSG00000100385_IL2RB'],
             'FceRIa': ['ENSG00000179639_FCER1A'],
             'CD41': ['ENSG00000005961_ITGA2B'],
             'CD137': ['ENSG00000049249_TNFRSF9'],
             'CD163': ['ENSG00000177575_CD163'],
             'CD83': ['ENSG00000112149_CD83'],
             'CD124': ['ENSG00000077238_IL4R'],
             'CD13': ['ENSG00000166825_ANPEP'],
             'CD2': ['ENSG00000116824_CD2'],
             'CD226': ['ENSG00000150637_CD226'],
             'CD29': ['ENSG00000150093_ITGB1'],
             'CD303': ['ENSG00000198178_CLEC4C'],
             'CD49b': ['ENSG00000164171_ITGA2'],
             'CD81': ['ENSG00000110651_CD81'],
             'IgD': ['ENSG00000211898_IGHD'],
             'CD18': ['ENSG00000160255_ITGB2'],
             'CD28': [],
             'CD38': ['ENSG00000004468_CD38'],
             'CD127': ['ENSG00000168685_IL7R'],
             'CD45': ['ENSG00000081237_PTPRC'],
             'CD22': ['ENSG00000012124_CD22'],
             'CD71': ['ENSG00000072274_TFRC'],
             'CD26': ['ENSG00000197635_DPP4'],
             'CD115': ['ENSG00000182578_CSF1R'],
             'CD63': ['ENSG00000135404_CD63'],
             'CD304': ['ENSG00000099250_NRP1'],
             'CD36': ['ENSG00000135218_CD36'],
             'CD172a': ['ENSG00000198053_SIRPA'],
             'CD72': ['ENSG00000137101_CD72'],
             'CD158': [],
             'CD93': ['ENSG00000125810_CD93'],
             'CD49a': ['ENSG00000213949_ITGA1'],
             'CD49d': ['ENSG00000115232_ITGA4'],
             'CD73': [],
             'CD9': ['ENSG00000010278_CD9'],
             'TCRVa7.2': [],
             'TCRVd2': [],
             'LOX-1': ['ENSG00000173391_OLR1'],
             'CD158b': [],
             'CD158e1': [],
             'CD142': ['ENSG00000117525_F3'],
             'CD319': ['ENSG00000026751_SLAMF7'],
             'CD352': ['ENSG00000162739_SLAMF6'],
             'CD94': ['ENSG00000134539_KLRD1'],
             'CD162': ['ENSG00000110876_SELPLG'],
             'CD85j': ['ENSG00000104972_LILRB1'],
             'CD23': ['ENSG00000104921_FCER2'],
             'CD328': ['ENSG00000168995_SIGLEC7'],
             'HLA-E': ['ENSG00000204592_HLA-E'],
             'CD82': ['ENSG00000085117_CD82'],
             'CD101': ['ENSG00000134256_CD101'],
             'CD88': ['ENSG00000197405_C5AR1'],
             'CD224': ['ENSG00000100031_GGT1']}