In [1]:
import os
import warnings

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy import stats
from sklearn.linear_model import Ridge
from tqdm import tqdm

# Ignore every warning raised in this kernel process.
warnings.simplefilter("ignore")
# Export the same setting through the environment as well; presumably so that
# processes spawned later (e.g. parallel workers) inherit it — TODO confirm.
os.environ["PYTHONWARNINGS"] = "ignore" 

In [2]:
# Input locations, relative to the notebook's working directory.
TEST_TABLE_PATH = "../DrugCell/data_rcellminer/test_rcell_wo_other.txt"
PREDICTION_PATH = "../DrugCell/code/Result/drugcell.predict"

# Tab-separated table without a header row; column 1 is the drug column used
# for grouping further down.
test = pd.read_csv(TEST_TABLE_PATH, sep="\t", header=None)

# Numeric predictions; later indexed with row labels taken from `test`.
pred = np.loadtxt(PREDICTION_PATH)

In [3]:
def get_list(i):
    """
    Collect the row labels of ``test`` that belong to one drug

    Parameters
    ----------
    i : Drug value, compared for equality against column 1 of the
        module-level ``test`` table

    Returns
    -------
    list
        Two-element list ``[row_labels, drug]``: the index labels of every
        matching row, and the drug value read back from the first match.

    Raises
    ------
    IndexError
        If no row of ``test`` matches ``i``.

    """

    # Filter once and reuse the selection for both outputs.
    rows = test[test[1] == i]
    return [list(rows.index), list(rows[1])[0]]

In [4]:
# Visit the drugs in order of first appearance (Series.unique) instead of via
# set(): the iteration order of a set of strings changes between interpreter
# runs, which made the row order of `t` non-reproducible.
t = Parallel(n_jobs=-1)(
    delayed(get_list)(drug) for drug in tqdm(test[1].unique())
)
# One row per drug: the `test` row labels of that drug, and the drug itself.
t = pd.DataFrame(t, columns=["drug_index", "drug"])

100%|██████████| 309/309 [00:00<00:00, 468.44it/s]


In [5]:
def get_corr(X, y):
    """
    Spearman correlation between ``y`` and a ridge-regression fit of ``y`` on ``X``

    A Ridge model with default settings is fitted on (X, y) and evaluated on
    the same X; the rank correlation between those fitted values and y is
    the score.

    Parameters
    ----------

    X : Hidden feature
    y : Final prediction

    Returns
    -------
    Spearman rank correlation coefficient (the accompanying p-value is dropped)

    """

    # In-sample prediction: the model is scored on the data it was fitted on.
    fitted = Ridge().fit(X, y).predict(X)
    return stats.spearmanr(fitted, y)[0]

In [6]:
def collect_corr(term, t):
    """
    Score one GO term's hidden features against the final prediction, drug by drug

    Parameters
    ----------

    term : File name of the GO term's hidden-feature table inside
        ``../DrugCell/code/Hidden/``
    t : Table whose ``drug_index`` column holds, per drug, the row labels to use

    Returns
    -------
    list
        One ``get_corr`` score per entry of ``t["drug_index"]``, in that order.

    """

    hidden = pd.read_csv("../DrugCell/code/Hidden/" + term, sep=" ", header=None)

    def score(rows):
        # The same row labels select the target from the module-level `pred`
        # and the features from the hidden-feature table.
        target = pred[rows]
        return get_corr(hidden.loc[rows], target)

    return [score(rows) for rows in t["drug_index"]]

In [7]:
# Distinct values of column 0 of go.txt, in order of first appearance.
GO = pd.read_csv(
    "../DrugCell/data_rcellminer/go.txt", sep="\t", header=None
)[0].unique().tolist()

# One task per GO term; each task returns that term's per-drug scores, read
# from the file "<term>.hidden".
p = Parallel(n_jobs=-1)(
    delayed(collect_corr)(go_term + ".hidden", t) for go_term in tqdm(GO)
)

# Rows are GO terms, columns are the drugs in the order listed in `t`.
importance = pd.DataFrame(p, index=GO, columns=t["drug"].tolist())

100%|██████████| 2086/2086 [02:24<00:00, 14.39it/s]


In [8]:
# Two-column, tab-separated table without a header. Column 1 is the lookup key
# that matches the current column labels of `importance`; column 0 is the label
# that replaces it (presumably SMILES and PubChem ID — confirm against the file).
smiles_table = pd.read_csv(
    "../DrugCell/data_rcellminer/SMILES_from_PubchemID.txt",
    header=None,
    sep="\t",
)
# Build the key -> label mapping in one pass instead of per-element lookups.
pubchem_id = dict(zip(smiles_table[1], smiles_table[0]))
# A drug missing from the mapping raises KeyError rather than being kept silently.
importance.columns = [pubchem_id[drug] for drug in importance.columns]
# Order the columns by their new labels directly, without transposing twice.
importance = importance.sort_index(axis=1)

In [9]:
# Show the score table: one row per GO term, one column per relabelled drug.
importance

Unnamed: 0,2051,2179,2538,2569,3213,3973,4005,4033,4212,4261,...,54613769,60147550,60147553,60148416,60148419,60148441,135400182,135400916,135410875,135453292
GO:0007005,0.727763,0.801432,0.927099,0.936364,0.750000,0.312664,0.902570,0.771429,0.808514,0.680706,...,-0.150073,0.631418,0.954545,0.917690,0.554008,0.953983,0.900000,0.895104,0.715861,0.715626
GO:0006281,0.748608,0.846155,0.910592,0.918182,0.964286,0.675907,0.923280,1.000000,0.942254,0.893999,...,0.470407,0.800373,0.981818,1.000000,0.773540,0.928878,0.500000,0.965352,0.764058,0.954169
GO:0051052,0.670628,0.892667,0.984870,0.909091,0.892857,0.464398,0.643498,1.000000,0.930095,0.500880,...,0.400195,0.786239,0.972727,0.986014,0.757106,0.995825,0.900000,0.904726,0.837928,0.935819
GO:1903047,0.733012,0.940968,0.957360,0.881818,0.964286,0.381634,0.949094,0.942857,0.930095,0.954392,...,0.706595,0.970055,0.963636,0.993007,0.865097,0.928878,0.600000,0.834773,0.900883,0.944994
GO:0006631,0.803194,0.754921,0.969697,0.936364,0.964286,0.258075,0.530376,1.000000,0.917937,0.840072,...,0.225110,0.780938,0.981818,0.971855,0.536430,0.912142,0.900000,0.830110,0.924781,0.926645
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GO:0036151,0.174531,0.628363,0.743673,0.485828,0.486506,0.708910,0.701225,0.880406,0.687385,0.553199,...,0.066935,0.084059,0.477928,0.647234,0.310562,0.757509,0.353553,0.483131,0.846773,0.658934
GO:0036152,0.174531,0.418585,0.746566,0.606780,0.889499,0.708910,0.677694,0.676123,0.687385,0.561549,...,0.066935,0.566722,0.468369,0.670628,0.287049,-0.113625,0.353553,0.623061,0.803747,0.658934
GO:0015695,0.291920,0.406964,0.784836,0.067420,0.512272,0.293707,0.111187,0.033806,0.030938,0.442138,...,0.376891,0.510855,0.031684,0.499072,0.667056,0.709414,0.353553,0.518786,0.259794,0.100922
GO:0043252,0.721739,0.486954,0.873151,0.614645,0.778312,0.268984,0.546877,0.507093,0.702474,0.333533,...,0.040022,-0.014691,0.124257,0.678426,0.428858,0.687524,0.359092,0.809980,0.545206,0.495349


In [10]:
# Write the score table to CSV with pandas' default settings.
importance.to_csv(path_or_buf="../DrugCell/data_rcellminer/corr_score.csv")