In [1]:
import os
import warnings

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy import stats
from sklearn.linear_model import Ridge
from tqdm import tqdm

warnings.simplefilter("ignore")
os.environ["PYTHONWARNINGS"] = "ignore" 

In [2]:
test = pd.read_csv(
    "../DrugCell/data_rcellminer/test_rcell_wo_other.txt",
    header=None,
    sep="\t",
)
pred = np.loadtxt("../DrugCell/code/Result/drugcell.predict")

In [3]:
def get_list(i):
    """
    Get list of drug indices for a given drug

    Parameters
    ----------
    i : Drug
    
    """
    
    tmp = test[test[1] == i]
    return [list(tmp.index), list(test[test[1] == i][1])[0]]

In [4]:
t = Parallel(n_jobs=-1)(delayed(get_list)(i) for i in tqdm(set(test[1])))
t = pd.DataFrame(t)
t.columns = ["drug_index", "drug"]

100%|██████████| 309/309 [00:00<00:00, 346.34it/s]


In [5]:
def get_corr(X, y):
    """
    Get correlation score between final prediction and prediction from each hidden feature

    Parameters
    ----------

    X : Hidden feature
    y : Final prediction

    """

    regr = Ridge()
    regr.fit(X, y)
    y_pred = regr.predict(X)
    p_rho, _ = stats.spearmanr(y_pred, y)
    return p_rho

In [6]:
def collect_corr(term, t):
    """
    Collect correlation score for each drug

    term : GO term
    t : List of drug indices and drug name

    """

    hidden = pd.read_csv(
        "../DrugCell/code/Hidden/" + term, header=None, sep=" "
    )
    corr = []

    for i in list(t["drug_index"]):
        y = pred[i]
        X = hidden.loc[i]
        corr.append(get_corr(X, y))

    return corr

In [7]:
GO = (
    pd.read_csv("../DrugCell/data_rcellminer/go.txt", header=None, sep="\t")[0]
    .unique()
    .tolist()
)
p = Parallel(n_jobs=-1)(
    delayed(collect_corr)(i + ".hidden", t) for i in tqdm(GO)
)
importance = pd.DataFrame(p, columns=list(t["drug"]), index=GO)

100%|██████████| 2086/2086 [02:12<00:00, 15.78it/s]


In [8]:
pubchem_id = pd.read_csv(
    "../DrugCell/data_rcellminer/SMILES_from_PubchemID.txt",
    header=None,
    sep="\t",
)
pubchem_id = {pubchem_id[1][i]: pubchem_id[0][i] for i in pubchem_id.index}
importance.columns = [pubchem_id[i] for i in importance.columns]

In [9]:
importance

Unnamed: 0,348483,401057,3114175,11589884,5459315,11704671,4005,54608427,24205286,281834,...,141295,321870,325912,11682081,382058,1724503,157348,25136944,401298,775540
GO:0007005,0.626776,0.307705,0.604644,0.924402,0.945615,0.789024,0.902570,1.000000,0.979088,0.977234,...,0.745455,0.381746,0.634742,0.815035,0.892361,0.881098,0.880952,0.587701,0.984548,0.842088
GO:0006281,0.951049,0.813665,0.881484,0.916667,0.652725,0.816548,0.923280,0.900000,0.945615,0.591945,...,0.672727,0.949107,0.952381,0.975684,0.788792,0.966570,0.880952,0.856494,0.952171,0.776954
GO:0051052,0.944056,0.689260,0.732800,0.566667,0.577411,0.761500,0.643498,0.700000,0.870301,0.781087,...,0.872727,0.734526,0.952381,0.919678,0.585352,0.930095,0.880952,0.929387,0.925681,0.623424
GO:1903047,0.965035,0.664644,0.948745,0.900000,0.878669,0.889946,0.949094,0.900000,0.995825,0.811189,...,0.690909,0.869327,0.928571,0.990422,0.846293,0.924016,0.976190,0.883829,0.971303,0.911875
GO:0006631,0.930070,0.631949,0.817762,0.750000,0.652725,0.963343,0.530376,1.000000,0.979088,0.818182,...,0.800000,0.786796,0.928571,0.782288,0.316256,0.576220,0.898220,0.710708,0.955114,0.516419
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GO:0036151,0.711439,0.569406,0.497399,0.730297,0.576693,0.388735,0.701225,0.353553,0.449182,0.673994,...,0.352303,0.648499,0.436436,0.553995,-0.177747,0.016325,0.748336,0.310582,0.702804,0.326927
GO:0036152,0.711439,0.752664,0.544770,0.564384,0.576693,0.520688,0.677694,0.353553,0.339178,0.682315,...,0.248961,0.658579,0.436436,0.553995,0.007030,-0.101438,0.646867,0.156689,0.641154,0.312018
GO:0015695,0.626139,0.280217,0.204454,0.207931,-0.183340,0.654294,0.111187,0.447214,0.834136,0.354787,...,0.135305,0.275329,0.783547,0.532464,0.153044,0.556888,0.109109,0.634698,0.574369,0.525863
GO:0043252,0.470132,0.463116,0.166161,0.297044,0.874038,0.654330,0.546877,0.335410,0.333090,0.665674,...,0.327150,0.548573,0.503953,0.353015,0.254611,0.366489,0.659550,0.253879,0.731860,0.366667


In [10]:
importance.to_csv('../DrugCell/data_rcellminer/corr_score.csv')