In [91]:
import pandas as pd
import pickle
import numpy as np
from scipy.sparse import csr_matrix
import shap

In [72]:

def _load_pickle(path):
    """Read and return one pickled object from ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at trusted repository artifacts.
    with open(path, "rb") as fh:
        return pickle.load(fh)


alldic = _load_pickle("Repository/Pretrained_model/alldic.pickle")
allnum = _load_pickle("Repository/Pretrained_model/allnum.pickle")
depmap = _load_pickle("Repository/Feature/depmap.pickle")
gene2vec = _load_pickle("Repository/Feature/gene2vec.pickle")
datdic = _load_pickle("Repository/Pretrained_model/datdic.pickle")
datnum = _load_pickle("Repository/Pretrained_model/datnum.pickle")

import gene_to_sym
gene_to_sym = gene_to_sym.get()
# Every candidate symbol gets an (empty) annotation set so that later
# lookups like alldic[SYMBOL] never miss.
for raw_sym in set(gene_to_sym.values()):
    upper_sym = raw_sym.upper()
    alldic.setdefault(upper_sym, set())
    datdic.setdefault(upper_sym, set())

In [73]:
#sparse matrix
def make_spa(fdic,fnum,exptup, col,row, k, C):
  """Append set-comparison indicator features for gene pairs to COO lists.

  Every feature key f in ``fnum`` owns three consecutive columns starting at
  ``C + 3*fnum[f]``:
    +0  f is annotated on both genes,
    +1  f is annotated on the first gene only,
    +2  f is annotated on the second gene only.
  Gene symbols are upper-cased before the ``fdic`` lookup.

  Args:
    fdic: gene symbol -> collection of feature keys.
    fnum: feature key -> integer feature index.
    exptup: sequence of (gene_a, gene_b) pairs; pair i goes to row k + i.
    col, row: lists extended in place with column / row indices.
    k: row index of the first pair.
    C: first column of this feature group.

  Returns:
    (col, row, C + 3*len(fnum)): the same list objects and the next free column.
  """
  add_col = col.append
  for row_idx, pair in enumerate(exptup, start=k):
    left = fdic[pair[0].upper()]
    right = fdic[pair[1].upper()]
    written_before = len(col)
    for feat in left:
      base = C + 3 * fnum[feat]
      add_col(base if feat in right else base + 1)
    for feat in right:
      idx = fnum[feat]
      if feat not in left:
        add_col(C + 3 * idx + 2)
    # one row entry per column entry written for this pair
    row.extend([row_idx] * (len(col) - written_before))
  return col, row, C + 3 * len(fnum)

def makespa_wv(g2v,exptup, col,row, k, C):
  """Append one gene2vec-similarity column (index ``C``) for gene pairs.

  Symbols are lower-cased before the ``g2v`` lookup. A pair contributes an
  entry only when both genes have a vector; its value is
  ``np.corrcoef(v_a, v_b)[0][1]``. Pairs lacking a vector leave their row
  empty in this column. Pair i maps to row ``k + i``.

  Returns:
    (col, row, C + 1, values): ``col``/``row`` extended in place, the next
    free column, and the correlations in the same order as the appended
    entries.
  """
  values = []
  for offset, pair in enumerate(exptup):
    first, second = pair[0].lower(), pair[1].lower()
    if first in g2v and second in g2v:
      col.append(C)
      values.append(np.corrcoef(g2v[first], g2v[second])[0][1])
      row.append(k + offset)
  return col, row, C + 1, values

def makespa_dep(depmap,exptup, col,row, k, C):
  """Append one DepMap co-dependency column (index ``C``) for gene pairs.

  A pair contributes an entry only when both genes have a DepMap profile;
  its value is ``np.corrcoef(p_a, p_b)[0][1]``. Pairs lacking a profile
  leave their row empty in this column. Pair i maps to row ``k + i``.

  Symbols are looked up as given first and, failing that, upper-cased.
  ``make_spa`` upper-cases symbols and ``commoncheck`` tests
  ``query.upper() in depmap``. Previously a lower-case query never matched
  here and silently lost its DepMap feature.

  Returns:
    (col, row, C + 1, values): ``col``/``row`` extended in place, the next
    free column, and the correlations in the same order as the appended
    entries.
  """
  def _resolve(sym):
    # exact key wins (backward compatible); otherwise try the upper-cased form
    if sym in depmap:
      return sym
    upper = sym.upper()
    return upper if upper in depmap else None

  values = []
  for offset, pair in enumerate(exptup):
    first, second = _resolve(pair[0]), _resolve(pair[1])
    if first is None or second is None:
      continue
    col.append(C)
    values.append(np.corrcoef(depmap[first], depmap[second])[0][1])
    row.append(k + offset)
  return col, row, C + 1, values


def expespar(exptup, ya, k):
    """Build sparse COO feature triplets for a list of gene pairs.

    Reads the module-level ``how`` to choose the feature set:
      "all": annotation set features (alldic/allnum), then one gene2vec
             column, then one DepMap column;
      "dat": annotation set features (datdic/datnum), then one DepMap column.

    Args:
      exptup: sequence of (gene_a, gene_b) pairs.
      ya: label value broadcast into ``y`` for every pair.
      k: row index assigned to the first pair.

    Returns:
      (data, row, col, k_next, y, C):
        - data/row/col: COO triplets. Set features carry 1; correlation
          columns carry the correlation value.
        - k_next: equals k + len(exptup). This is the matrix row count when
          k == 0.
        - y: array of length len(exptup) filled with ya.
        - C: total column count.

    Raises:
      ValueError: if ``how`` is neither "all" nor "dat". Previously this
        surfaced as an UnboundLocalError on ``data``.
    """
    row = []
    col = []
    C = 0
    if how == "all":
        col, row, C = make_spa(alldic, allnum, exptup, col, row, k, C)
        col, row, C, d_w2v = makespa_wv(gene2vec, exptup, col, row, k, C)
        col, row, C, d_dep = makespa_dep(depmap, exptup, col, row, k, C)
        dense = d_w2v + d_dep
    elif how == "dat":
        col, row, C = make_spa(datdic, datnum, exptup, col, row, k, C)
        col, row, C, dense = makespa_dep(depmap, exptup, col, row, k, C)
    else:
        raise ValueError(f'unknown feature mode how={how!r}; expected "all" or "dat"')
    # Binary set-membership entries were appended first, the correlation
    # entries last, so ``data`` must follow the same order as ``col``/``row``.
    data = [1] * (len(col) - len(dense)) + dense
    y = np.full(len(exptup), ya)
    return data, row, col, k + len(exptup), y, C

In [2]:
def scoration(query,model, chunk_size=10000):
    """Score ``query`` against every known gene symbol and rank the results.

    The candidates are the upper-cased values of the module-level
    ``gene_to_sym``, and the query itself is included. Pairs are featurised
    with ``expespar`` and scored with ``model.predict``. This happens in chunks
    of ``chunk_size`` pairs to bound the size of each sparse matrix.

    Args:
      query: gene symbol to score against all candidates.
      model: fitted regressor whose ``predict`` accepts a CSR matrix.
      chunk_size: pairs featurised and predicted per batch. The default of
        10000 matches the previous hard-coded slice size.

    Returns:
      DataFrame with columns Gene, Score and Rank, sorted by descending Score
      (Rank 1 is the highest). Candidates are enumerated in sorted order and
      the sort is stable, so ties resolve alphabetically and reruns are
      reproducible.

    Raises:
      ValueError: if ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    # sorted(): iterating a set of str depends on hash randomisation, which
    # made the candidate order (and the ranks of tied scores) vary per run.
    genes = sorted({g.upper() for g in gene_to_sym.values()})
    pairs = [(query, sym) for sym in genes]

    scores = []
    # Chunk over the whole list. This avoids predicting on empty matrices and
    # never lets the final chunk grow unbounded.
    for start in range(0, len(pairs), chunk_size):
        chunk = pairs[start:start + chunk_size]
        data, row, col, k, _, C = expespar(chunk, 1, 0)
        scores.extend(model.predict(csr_matrix((data, (row, col)), (k, C))).tolist())

    score = pd.DataFrame({"Gene": genes, "Score": scores})
    score = score.sort_values("Score", ascending=False, kind="mergesort")
    score["Rank"] = list(range(1, len(score) + 1))
    return score
    


In [1]:
def make_common(g1,g2):
    """Return the annotation features shared by genes ``g1`` and ``g2``.

    The dictionary is ``alldic`` or ``datdic``, chosen by the module-level
    ``how``. Symbols are upper-cased first, which matches the keys
    ``make_spa`` uses when building the model's features. Previously a
    lower-case query raised KeyError here even though scoring it succeeded.

    Raises:
      ValueError: if ``how`` is neither "all" nor "dat". Previously this was
        an UnboundLocalError on ``common``.
      KeyError: if a symbol is absent from the selected dictionary.
    """
    if how == "all":
        fdic = alldic
    elif how == "dat":
        fdic = datdic
    else:
        raise ValueError(f'unknown feature mode how={how!r}; expected "all" or "dat"')
    return fdic[g1.upper()] & fdic[g2.upper()]

def commoncheck(query,suggestions,model):
    """Explain model scores for (query, suggestion) pairs with SHAP values.

    For each suggestion, this collects (round(SHAP*100, 3), label) tuples for:
      - every annotation feature shared by both genes. Each uses the "both"
        column ``3*num[feature]`` built by ``make_spa``;
      - "DepMap": the last column, when the query has a DepMap profile;
      - "Word2Vec": column ``3*len(allnum)``, in "all" mode only, when the
        query has a gene2vec vector.
    Each per-suggestion list is sorted in descending order.

    Args:
      query: gene symbol.
      suggestions: iterable of candidate gene symbols.
      model: tree model accepted by ``shap.TreeExplainer``.

    Returns:
      A list with one list per suggestion, in input order. Empty input
      returns [].

    Raises:
      ValueError: if ``how`` is neither "all" nor "dat". Previously ``num``
        was left unbound.
    """
    if how == "all":
        num = allnum
    elif how == "dat":
        num = datnum
    else:
        raise ValueError(f'unknown feature mode how={how!r}; expected "all" or "dat"')
    suggestions = list(suggestions)  # iterated twice below
    if not suggestions:
        return []

    explainer = shap.TreeExplainer(model=model, feature_perturbation='tree_path_dependent', model_output='raw')
    # Keep the raw query here so the feature matrix matches the one scoration builds.
    pairs = [(query, suggestion) for suggestion in suggestions]
    data, row, col, k, _, C = expespar(pairs, 1, 0)
    shap_values = explainer.shap_values(X=csr_matrix((data, (row, col)), (k, C)))

    # Query-only checks are loop invariant. The DepMap check accepts the
    # exact key or its upper-cased form.
    has_dep = query in depmap or query.upper() in depmap
    has_w2v = how == "all" and query.lower() in gene2vec
    ret = []
    for n, suggestion in enumerate(suggestions):
        sv = shap_values[n]
        # Upper-case the symbols to match the keys make_spa used when building
        # features. Otherwise a lower-case query raises KeyError.
        tmp = [(round(sv[num[c] * 3] * 100, 3), c)
               for c in make_common(query.upper(), suggestion.upper())]
        if has_dep:
            tmp.append((round(sv[-1] * 100, 3), "DepMap"))  # DepMap is always the last column
        if has_w2v:
            tmp.append((round(sv[len(allnum) * 3] * 100, 3), "Word2Vec"))
        tmp.sort(reverse=True)
        ret.append(tmp)
    return ret

In [118]:
def make_table(query,model, top_n=100, max_features=5):
    """Build the top-``top_n`` ranking for ``query`` with SHAP explanations.

    Args:
      query: gene symbol to rank partners for.
      model: fitted tree regressor used by ``scoration`` and ``commoncheck``.
      top_n: number of best-scoring genes to keep (default 100).
      max_features: number of leading (value, feature) tuples shown per gene
        (default 5).

    Returns:
      DataFrame with columns Rank, Gene, Score and "Feature importance".
      "Feature importance" holds the ``max_features`` largest tuples from
      ``commoncheck``, joined with tabs. The index is carried over from
      ``scoration``.
    """
    # .copy(): a positional slice of the full ranking may be a view, and adding
    # a column to it can trigger pandas' SettingWithCopyWarning.
    top = scoration(query, model).iloc[:top_n].copy()
    explanations = commoncheck(query, top["Gene"].tolist(), model)
    top["Feature importance"] = [
        "\t".join(str(c) for c in coef[:max_features]) for coef in explanations
    ]
    return top.reindex(columns=["Rank", "Gene", "Score", "Feature importance"])

In [112]:
import xgboost as xgb

# Feature mode read by expespar, make_common and commoncheck:
#   "all" = annotation sets + gene2vec + DepMap features
#   "dat" = annotation sets + DepMap features
how = "all"
_MODEL_PATHS = {
    "all": "Repository/Pretrained_model/model_all.pickle",
    "dat": "Repository/Pretrained_model/model_data.pickle",
}
# Fail loudly. Previously an unknown mode left an untrained regressor behind.
if how not in _MODEL_PATHS:
    raise ValueError(f'unknown feature mode how={how!r}; expected "all" or "dat"')
model = xgb.XGBRegressor()
model.load_model(_MODEL_PATHS[how])



Loading a native XGBoost model with Scikit-Learn interface.


In [119]:
make_table(query="CEP63", model=model)

Unnamed: 0,Rank,Gene,Score,Feature importance
8112,1,CEP63,0.999705,"(545.467, 'Word2Vec')\t(210.639, 'DepMap')\t(1..."
26265,2,CEP152,0.998968,"(626.227, 'Word2Vec')\t(18.76, 'DepMap')\t(12...."
47920,3,CEP135,0.998609,"(564.45, 'Word2Vec')\t(42.236, 'DepMap')\t(14...."
35190,4,CDK5RAP2,0.998352,"(637.501, 'Word2Vec')\t(9.324, 'GO:0005813')\t..."
48225,5,CEP192,0.998250,"(609.52, 'Word2Vec')\t(14.341, 'GO:0010389')\t..."
...,...,...,...,...
33824,96,ETAA1,0.855161,"(159.202, 'Word2Vec')\t(29.028, 'DepMap')\t(7...."
1024,97,UBAP2L,0.854421,"(207.48, 'Word2Vec')\t(8.781, 'TF:GTF2B')\t(2...."
10031,98,NEK11,0.850381,"(141.343, 'Word2Vec')\t(48.018, 'DepMap')\t(4...."
6245,99,TTLL5,0.848506,"(105.984, 'Word2Vec')\t(69.435, 'DepMap')\t(13..."
