In [1]:
!pip uninstall scikit-learn -y
!pip install scikit-learn

!pip uninstall auto-sklearn -y
!pip install auto-sklearn


Collecting auto-sklearn
  Downloading auto-sklearn-0.15.0.tar.gz (6.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m64.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting liac-arff
  Downloading liac-arff-2.5.0.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting ConfigSpace<0.5,>=0.4.21
  Downloading ConfigSpace-0.4.21-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting pyrfr<0.9,>=0.8.1
  Downloading pyrfr-0.8.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m42.1 MB/s[0m eta [3

In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans

SEED = 6174
np.random.seed(SEED)

# Set the folder path for data
folder_path = "./input"

In [2]:
de_train = pd.read_parquet(f'{folder_path}/de_train.parquet')
genes = de_train.columns[5:]
id_map = pd.read_csv (f'{folder_path}/id_map.csv')
id_map = id_map.reindex(id_map.columns.tolist() + genes.tolist(), axis=1)

sm_lincs_id = de_train.set_index('sm_name')["sm_lincs_id"].to_dict()
sm_name_to_smiles = de_train.set_index('sm_name')['SMILES'].to_dict()

id_map['sm_lincs_id'] = id_map['sm_name'].map(sm_lincs_id)
id_map['SMILES'] = id_map['sm_name'].map(sm_name_to_smiles)

de_train

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.104720,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.884380,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,LSM-5771,CC(C)c1c(C(=O)Nc2ccccc2)c(-c2ccccc2)c(-c2ccc(F...,False,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,...,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
610,NK cells,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,-0.455549,0.188181,0.595734,-0.100299,0.786192,...,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
611,T cells CD4+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.338168,-0.109079,0.270182,-0.436586,-0.069476,...,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
612,T cells CD8+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,...,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [3]:
## Drop cell_type = 'T cells CD4+', 'T cells CD8+'
#de_train.drop(de_train[de_train['cell_type'] == 'T cells CD4+'].index, inplace = True)
#de_train.drop(de_train[de_train['cell_type'] == 'T cells CD8+'].index, inplace = True)

#de_train

In [4]:
all_sm_names = de_train[de_train["cell_type"]=="B cells"]["sm_name"].to_list()
all_de_train = de_train[de_train["sm_name"].isin(all_sm_names)]
de_train.drop(all_de_train.index, inplace=True)

all_de_train

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
8,B cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.394173,-0.153824,0.178232,0.566241,0.391377,...,-1.052302,-1.176587,-1.220291,-0.278944,-0.095066,1.101790,0.061803,1.406335,-0.264996,-0.119743
9,Myeloid cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.025146,0.316388,1.366885,1.301593,2.317619,...,-0.902546,-1.445523,0.794772,0.428973,0.605834,0.271988,0.492231,0.354721,1.471559,-0.259483
10,NK cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.861487,-0.112313,-0.355217,0.719999,0.655865,...,0.035687,0.138060,-0.776619,-0.109832,-0.189906,0.753086,0.190892,-0.141699,-0.756510,-0.076934
11,T cells CD4+,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.206471,0.014638,-0.247518,0.430198,0.103020,...,-0.018902,-1.013426,-1.196254,-0.983257,1.097309,-0.090271,-1.293485,0.118196,-0.120892,-0.411331
12,T cells CD8+,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.046959,-0.346839,0.023478,0.485611,0.005066,...,0.017437,0.537964,-0.219895,0.376181,0.706930,-0.554368,0.035559,-0.189976,-0.145465,0.109084
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
557,Myeloid cells,CHIR-99021,LSM-1181,Cc1cnc(-c2cnc(NCCNc3ccc(C#N)cn3)nc2-c2ccc(Cl)c...,False,0.338227,-0.431007,0.066335,-0.907600,0.973881,...,0.187125,0.800574,0.214946,-0.104931,-0.125619,-0.160210,-0.886414,-2.955785,-0.866944,-0.053017
558,NK cells,CHIR-99021,LSM-1181,Cc1cnc(-c2cnc(NCCNc3ccc(C#N)cn3)nc2-c2ccc(Cl)c...,False,-1.026443,-0.024840,0.204772,-0.861985,-2.972540,...,0.323512,1.131738,-0.064157,0.011429,0.030404,0.015837,-0.055027,-0.329874,0.327199,-0.830306
559,T cells CD4+,CHIR-99021,LSM-1181,Cc1cnc(-c2cnc(NCCNc3ccc(C#N)cn3)nc2-c2ccc(Cl)c...,False,-0.545092,0.108150,-0.355024,-1.659293,-0.613075,...,2.537207,-0.429731,-0.605248,0.304310,-0.014653,0.000120,-1.117706,-0.130162,0.001642,-0.665771
560,T cells CD8+,CHIR-99021,LSM-1181,Cc1cnc(-c2cnc(NCCNc3ccc(C#N)cn3)nc2-c2ccc(Cl)c...,False,0.342721,0.921447,-2.992502,-0.842775,-0.408755,...,0.240196,0.302058,-0.555482,-0.009563,0.029517,0.037701,-0.257844,-1.374965,-1.202369,-0.676311


In [5]:
grouped_all_de = all_de_train.groupby('cell_type')[genes].mean().reset_index()
grouped_all_de

Unnamed: 0,cell_type,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,B cells,1.38089,0.530585,1.340812,1.594307,4.927551,3.613191,0.02864,0.544636,0.723079,...,0.257778,0.674977,0.217386,1.439374,0.952903,0.581303,0.637408,0.517737,-0.207092,0.079199
1,Myeloid cells,1.570336,0.752564,-2.856826,0.887845,6.658911,4.034911,0.442943,0.403543,0.196285,...,-0.270423,-0.103318,-1.307952,-0.166312,1.883588,0.612681,-0.583563,-0.427938,-0.292768,-0.067723
2,NK cells,1.726861,1.391056,-0.384683,-0.909542,0.735264,4.687871,0.460742,1.023145,0.046517,...,0.511199,0.462575,0.388045,0.784062,0.945736,1.23205,0.10688,0.600556,-0.585519,-0.156448
3,T cells CD4+,0.160548,0.469133,-0.07414,-1.055539,0.006779,2.138308,-0.247477,0.669044,0.630691,...,1.231267,0.601618,0.150675,0.811554,0.616685,0.330778,0.181473,0.175703,-0.472213,-0.09922
4,T cells CD8+,-0.259045,-0.1425,-0.402029,-0.313932,-0.436765,-0.769503,-0.118564,0.039057,-0.710979,...,0.135679,0.029111,-0.293163,-0.188252,0.014929,-0.223888,0.110891,-0.155235,-0.318527,0.253286
5,T regulatory cells,1.740391,1.263407,2.306707,2.578485,2.978225,2.920715,0.274367,1.225981,0.01202,...,-0.018877,1.311463,0.817281,1.549286,2.184906,2.333129,1.334974,0.791717,-0.092324,-0.028916


In [6]:
diff_all_de = grouped_all_de.copy()
diff_all_de[genes] = grouped_all_de.loc[:, genes] - grouped_all_de.loc[:, genes].mean(axis=0)
diff_all_de.columns = ['cell_type'] + ['cell_diff_' + i for i in genes]
diff_all_de

Unnamed: 0,cell_type,cell_diff_A1BG,cell_diff_A1BG-AS1,cell_diff_A2M,cell_diff_A2M-AS1,cell_diff_A2MP1,cell_diff_A4GALT,cell_diff_AAAS,cell_diff_AACS,cell_diff_AAGAB,...,cell_diff_ZUP1,cell_diff_ZW10,cell_diff_ZWILCH,cell_diff_ZWINT,cell_diff_ZXDA,cell_diff_ZXDB,cell_diff_ZXDC,cell_diff_ZYG11B,cell_diff_ZYX,cell_diff_ZZEF1
0,B cells,0.32756,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,...,-0.049993,0.178906,0.222007,0.734422,-0.146888,-0.229706,0.339398,0.267314,0.120982,0.082503
1,Myeloid cells,0.517006,0.041857,-2.845133,0.424241,4.180583,1.263995,0.302835,-0.247358,0.046682,...,-0.578193,-0.599389,-1.303331,-0.871264,0.783797,-0.198328,-0.881573,-0.678362,0.035306,-0.06442
2,NK cells,0.673531,0.680349,-0.37299,-1.373146,-1.743063,1.916955,0.320633,0.372244,-0.103086,...,0.203429,-0.033497,0.392666,0.07911,-0.154055,0.421041,-0.191131,0.350132,-0.257445,-0.153144
3,T cells CD4+,-0.892782,-0.241575,-0.062447,-1.519142,-2.471549,-0.632608,-0.387585,0.018143,0.481089,...,0.923497,0.105547,0.155296,0.106602,-0.483106,-0.48023,-0.116537,-0.07472,-0.14414,-0.095916
4,T cells CD8+,-1.312375,-0.853207,-0.390336,-0.777536,-2.915092,-3.540418,-0.258673,-0.611844,-0.860581,...,-0.172092,-0.46696,-0.288541,-0.893204,-1.084862,-1.034897,-0.18712,-0.405658,0.009547,0.25659
5,T regulatory cells,0.687061,0.552699,2.3184,2.114881,0.499897,0.1498,0.134258,0.57508,-0.137582,...,-0.326647,0.815392,0.821903,0.844334,1.085114,1.52212,1.036963,0.541294,0.23575,-0.025612


In [7]:
grouped_sm_name = de_train.groupby('sm_name')[genes].mean().reset_index()
grouped_sm_name

Unnamed: 0,sm_name,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-...,0.300267,-0.112432,0.413144,1.468632,0.733237,0.722462,0.125359,0.210903,-0.876761,...,-0.769578,-0.690020,0.303616,0.260685,0.555278,0.837875,0.444535,0.432414,-0.219858,0.551906
1,ABT-199 (GDC-0199),-0.081286,0.007314,0.081242,-0.125777,0.219469,0.258288,-0.160568,0.023898,0.317472,...,0.430786,0.094845,-0.088646,-0.140509,0.216322,-0.065943,0.113272,-0.181743,0.068095,-0.093228
2,ABT737,0.408012,0.322574,0.107448,-0.049174,0.422284,1.151523,0.751861,0.189453,-0.121147,...,0.186543,-0.180051,0.028183,0.413515,0.166978,0.327588,0.256550,-0.069630,-0.135967,-0.728025
3,AMD-070 (hydrochloride),-0.031131,0.533648,0.124738,0.241484,-0.017756,0.039647,-0.173965,0.806999,-0.019594,...,-0.100840,0.065319,0.193013,0.310814,0.018807,0.144418,0.000372,0.204476,-0.077820,0.166340
4,AT 7867,0.242736,-0.275840,0.158312,0.267365,-0.003346,0.183553,-0.228290,0.162294,-0.240859,...,-0.704684,-0.088803,-0.120544,-0.337481,0.051235,0.466585,-0.157225,0.622629,0.022401,0.079217
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
124,Vandetanib,-0.006076,-0.672747,-0.230338,-0.492947,0.109427,-0.528983,0.030436,0.155058,-0.130232,...,0.113566,-0.056856,-0.375688,-0.098679,-0.203246,-0.313355,0.105695,-0.004788,0.135301,0.254045
125,Vanoxerine,0.188002,-0.459637,0.107419,0.442630,0.288657,0.239626,-0.191168,-0.257659,-0.149061,...,0.023219,-0.114899,0.148367,0.166318,0.220331,-0.095794,-0.191277,0.358461,-0.246971,-0.221379
126,Vardenafil,-0.097319,0.526795,-0.339189,-0.156595,0.021584,-0.269225,-0.216612,-0.447963,0.047921,...,0.044047,0.050485,0.528853,0.162957,0.079624,-0.091698,-0.154987,0.191808,0.325303,-0.385319
127,Vorinostat,0.075208,0.014575,0.454048,-0.009477,0.342563,-0.226654,0.517033,0.845815,1.743839,...,0.320819,0.542119,-0.402185,-0.104306,-0.586294,0.382534,-0.107183,0.212404,-0.029705,0.270488


In [8]:
diff_sm_name = grouped_sm_name.copy()
diff_sm_name[genes] = grouped_sm_name.loc[:, genes] - grouped_sm_name.loc[:, genes].mean(axis=0)
diff_sm_name.columns = ['sm_name'] + ['sm_diff_' + i for i in genes]
diff_sm_name

Unnamed: 0,sm_name,sm_diff_A1BG,sm_diff_A1BG-AS1,sm_diff_A2M,sm_diff_A2M-AS1,sm_diff_A2MP1,sm_diff_A4GALT,sm_diff_AAAS,sm_diff_AACS,sm_diff_AAGAB,...,sm_diff_ZUP1,sm_diff_ZW10,sm_diff_ZWILCH,sm_diff_ZWINT,sm_diff_ZXDA,sm_diff_ZXDB,sm_diff_ZXDC,sm_diff_ZYG11B,sm_diff_ZYX,sm_diff_ZZEF1
0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-...,0.066678,-0.276644,0.234661,1.378728,0.403626,0.118317,0.150890,-0.079316,-0.859039,...,-0.794320,-0.785941,0.277779,0.075590,0.173820,0.537256,0.214356,0.288601,-0.091651,0.636851
1,ABT-199 (GDC-0199),-0.314874,-0.156898,-0.097241,-0.215681,-0.110141,-0.345857,-0.135036,-0.266321,0.335193,...,0.406043,-0.001076,-0.114482,-0.325603,-0.165137,-0.366562,-0.116907,-0.325555,0.196302,-0.008283
2,ABT737,0.174423,0.158362,-0.071035,-0.139078,0.092673,0.547377,0.777392,-0.100766,-0.103426,...,0.161801,-0.275972,0.002347,0.228421,-0.214481,0.026969,0.026371,-0.213442,-0.007760,-0.643080
3,AMD-070 (hydrochloride),-0.264720,0.369437,-0.053745,0.151580,-0.347366,-0.564498,-0.148433,0.516780,-0.001873,...,-0.125582,-0.030601,0.167177,0.125719,-0.362652,-0.156201,-0.229808,0.060664,0.050388,0.251284
4,AT 7867,0.009148,-0.440052,-0.020171,0.177461,-0.332957,-0.420592,-0.202759,-0.127924,-0.223138,...,-0.729427,-0.184723,-0.146380,-0.522576,-0.330223,0.165965,-0.387405,0.478816,0.150609,0.164162
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
124,Vandetanib,-0.239665,-0.836959,-0.408820,-0.582851,-0.220184,-1.133128,0.055968,-0.135161,-0.112511,...,0.088824,-0.152777,-0.401524,-0.283773,-0.584705,-0.613974,-0.124484,-0.148600,0.263509,0.338990
125,Vanoxerine,-0.045587,-0.623849,-0.071064,0.352726,-0.040954,-0.364520,-0.165637,-0.547878,-0.131340,...,-0.001523,-0.210819,0.122531,-0.018776,-0.161128,-0.396414,-0.421457,0.214648,-0.118764,-0.136434
126,Vardenafil,-0.330907,0.362583,-0.517672,-0.246499,-0.308027,-0.873371,-0.191080,-0.738182,0.065642,...,0.019304,-0.045436,0.503017,-0.022137,-0.301835,-0.392317,-0.385166,0.047995,0.453511,-0.300374
127,Vorinostat,-0.158380,-0.149637,0.275565,-0.099381,0.012953,-0.830799,0.542564,0.555596,1.761560,...,0.296077,0.446198,-0.428021,-0.289400,-0.967753,0.081915,-0.337363,0.068592,0.098502,0.355433


In [9]:
# Connect de_train and diff_all_de by cell_type
features = ["cell_type", "sm_name"]
de_train.reset_index(inplace=True)

X = de_train[features].copy()
X = pd.merge(X, diff_all_de, on='cell_type', how='left')
X = pd.merge(X, diff_sm_name, on='sm_name', how='left')
X.drop(features, axis=1, inplace=True)
X

Unnamed: 0,cell_diff_A1BG,cell_diff_A1BG-AS1,cell_diff_A2M,cell_diff_A2M-AS1,cell_diff_A2MP1,cell_diff_A4GALT,cell_diff_AAAS,cell_diff_AACS,cell_diff_AAGAB,cell_diff_AAK1,...,sm_diff_ZUP1,sm_diff_ZW10,sm_diff_ZWILCH,sm_diff_ZWINT,sm_diff_ZXDA,sm_diff_ZXDB,sm_diff_ZXDC,sm_diff_ZYG11B,sm_diff_ZYX,sm_diff_ZZEF1
0,0.673531,0.680349,-0.372990,-1.373146,-1.743063,1.916955,0.320633,0.372244,-0.103086,0.038950,...,-0.122370,-0.006705,0.010019,-0.080004,-0.496267,-0.718154,-0.274901,0.132798,0.531371,-0.126436
1,-0.892782,-0.241575,-0.062447,-1.519142,-2.471549,-0.632608,-0.387585,0.018143,0.481089,-0.513731,...,-0.122370,-0.006705,0.010019,-0.080004,-0.496267,-0.718154,-0.274901,0.132798,0.531371,-0.126436
2,-1.312375,-0.853207,-0.390336,-0.777536,-2.915092,-3.540418,-0.258673,-0.611844,-0.860581,0.269087,...,-0.122370,-0.006705,0.010019,-0.080004,-0.496267,-0.718154,-0.274901,0.132798,0.531371,-0.126436
3,0.687061,0.552699,2.318400,2.114881,0.499897,0.149800,0.134258,0.575080,-0.137582,-0.777092,...,-0.122370,-0.006705,0.010019,-0.080004,-0.496267,-0.718154,-0.274901,0.132798,0.531371,-0.126436
4,0.673531,0.680349,-0.372990,-1.373146,-1.743063,1.916955,0.320633,0.372244,-0.103086,0.038950,...,-0.182066,0.290490,0.304224,0.019832,-0.150009,0.175004,0.043966,-0.085937,-1.003342,0.698486
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
509,0.687061,0.552699,2.318400,2.114881,0.499897,0.149800,0.134258,0.575080,-0.137582,-0.777092,...,-0.223532,-0.239614,0.329282,-0.497519,-0.117295,-0.110864,-0.299586,-0.066721,0.142047,-0.525575
510,0.673531,0.680349,-0.372990,-1.373146,-1.743063,1.916955,0.320633,0.372244,-0.103086,0.038950,...,-0.255000,-1.155594,-0.298194,-0.345108,-0.707706,-0.068622,-0.096360,-0.412233,0.458510,-0.636464
511,-0.892782,-0.241575,-0.062447,-1.519142,-2.471549,-0.632608,-0.387585,0.018143,0.481089,-0.513731,...,-0.255000,-1.155594,-0.298194,-0.345108,-0.707706,-0.068622,-0.096360,-0.412233,0.458510,-0.636464
512,-1.312375,-0.853207,-0.390336,-0.777536,-2.915092,-3.540418,-0.258673,-0.611844,-0.860581,0.269087,...,-0.255000,-1.155594,-0.298194,-0.345108,-0.707706,-0.068622,-0.096360,-0.412233,0.458510,-0.636464


In [10]:
y = de_train[genes].copy()
y

Unnamed: 0,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,0.104720,-0.077524,-1.625596,-0.144545,0.143555,0.073229,-0.016823,0.101717,-0.005153,1.043629,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,0.915953,-0.884380,0.371834,-0.081677,-0.498266,0.203559,0.604656,0.498592,-0.317184,0.375550,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,-0.387721,-0.305378,0.567777,0.303895,-0.022653,-0.480681,0.467144,-0.293205,-0.005098,0.214918,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,0.232893,0.129029,0.336897,0.486946,0.767661,0.718590,-0.162145,0.157206,-3.654218,-0.212402,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,4.290652,-0.063864,-0.017443,-0.541154,0.570982,2.022829,0.600011,1.231275,0.236739,0.338703,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
509,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,-0.544709,0.282458,-0.431359,-0.364961,0.043123,...,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
510,-0.455549,0.188181,0.595734,-0.100299,0.786192,0.090954,0.169523,0.428297,0.106553,0.435088,...,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
511,0.338168,-0.109079,0.270182,-0.436586,-0.069476,-0.061539,0.002818,-0.027167,-0.383696,0.226289,...,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
512,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,-0.706087,-0.620919,-1.485381,0.059303,-0.032584,...,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [14]:
X_sub = id_map[features].copy()
X_sub = pd.merge(X_sub, diff_all_de, on='cell_type', how='left')
X_sub = pd.merge(X_sub, diff_sm_name, on='sm_name', how='left')
X_sub.drop(features, axis=1, inplace=True)
X_sub

Unnamed: 0,cell_diff_A1BG,cell_diff_A1BG-AS1,cell_diff_A2M,cell_diff_A2M-AS1,cell_diff_A2MP1,cell_diff_A4GALT,cell_diff_AAAS,cell_diff_AACS,cell_diff_AAGAB,cell_diff_AAK1,...,sm_diff_ZUP1,sm_diff_ZW10,sm_diff_ZWILCH,sm_diff_ZWINT,sm_diff_ZXDA,sm_diff_ZXDB,sm_diff_ZXDC,sm_diff_ZYG11B,sm_diff_ZYX,sm_diff_ZZEF1
0,0.327560,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,0.406898,...,-0.794320,-0.785941,0.277779,0.075590,0.173820,0.537256,0.214356,0.288601,-0.091651,0.636851
1,0.327560,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,0.406898,...,0.406043,-0.001076,-0.114482,-0.325603,-0.165137,-0.366562,-0.116907,-0.325555,0.196302,-0.008283
2,0.327560,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,0.406898,...,0.161801,-0.275972,0.002347,0.228421,-0.214481,0.026969,0.026371,-0.213442,-0.007760,-0.643080
3,0.327560,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,0.406898,...,-0.125582,-0.030601,0.167177,0.125719,-0.362652,-0.156201,-0.229808,0.060664,0.050388,0.251284
4,0.327560,-0.180123,1.352505,1.130703,2.449224,0.842276,-0.111469,-0.106265,0.573477,0.406898,...,-0.729427,-0.184723,-0.146380,-0.522576,-0.330223,0.165965,-0.387405,0.478816,0.150609,0.164162
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.517006,0.041857,-2.845133,0.424241,4.180583,1.263995,0.302835,-0.247358,0.046682,0.575888,...,0.088824,-0.152777,-0.401524,-0.283773,-0.584705,-0.613974,-0.124484,-0.148600,0.263509,0.338990
251,0.517006,0.041857,-2.845133,0.424241,4.180583,1.263995,0.302835,-0.247358,0.046682,0.575888,...,-0.001523,-0.210819,0.122531,-0.018776,-0.161128,-0.396414,-0.421457,0.214648,-0.118764,-0.136434
252,0.517006,0.041857,-2.845133,0.424241,4.180583,1.263995,0.302835,-0.247358,0.046682,0.575888,...,0.019304,-0.045436,0.503017,-0.022137,-0.301835,-0.392317,-0.385166,0.047995,0.453511,-0.300374
253,0.517006,0.041857,-2.845133,0.424241,4.180583,1.263995,0.302835,-0.247358,0.046682,0.575888,...,0.296077,0.446198,-0.428021,-0.289400,-0.967753,0.081915,-0.337363,0.068592,0.098502,0.355433


In [11]:
# Get index of cell_type is NK cells
index = de_train[de_train["cell_type"] == "NK cells"].index
X_test = X.loc[index]
y_test = y.loc[index]

# Drop cell_type = 'NK cells'
X_train = X.drop(index)
y_train = y.drop(index)

In [13]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Function to calculate Mean Rowwise Root Mean Squared Error
def mrrmse(y_true, y_pred):
    return np.mean(np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1)))

# Define the hyperparameter grid to search
param_grid = {
    'alpha': [0.1, 0.5, 1, 10, 1000, 10000],
    #'fit_intercept': [True, False],
    #'solver': ['auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga', 'lbfgs'],
    #'max_iter': [1000, 5000, 10000],
    #'tol': [1e-3, 1e-4, 1e-5],
}

# Initialize the model
ridge = Ridge()

# Initialize GridSearchCV
grid_search = GridSearchCV(ridge, param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=1)

# Fit GridSearchCV
grid_search.fit(X_train, y_train)

# Best model
best_ridge = grid_search.best_estimator_

# Predict on test data
y_pred = best_ridge.predict(X_test)

# Calculate MRRMSE
mrrmse_score = mrrmse(y_test, y_pred)
print("Best Ridge Model Parameters:", grid_search.best_params_)
print("MRRMSE Score on Test Data:", mrrmse_score)

Best Ridge Model Parameters: {'alpha': 1000}
MRRMSE Score on Test Data: 1.1464951021784702


In [15]:
# Retrain by using whole X and y by best parameters
best_ridge = Ridge(**grid_search.best_params_)
best_ridge.fit(X, y)

# Predict on X_sub
y_sub_pred = best_ridge.predict(X_sub)

id_map[genes] = y_sub_pred
id_map = id_map.loc[:, ["id"] + genes.to_list()]
id_map.to_csv('submission.csv', index=False)
id_map

Unnamed: 0,id,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,0,0.421180,0.046216,0.479265,1.488925,0.895533,0.932395,0.120605,0.312069,-0.832354,...,-0.642277,-0.642881,0.334695,0.311509,0.561211,0.880780,0.485176,0.524537,-0.287176,0.549943
1,1,0.066268,0.134671,0.089829,-0.048125,0.328376,0.471805,-0.116430,0.134574,0.258221,...,0.476566,0.066763,-0.068217,-0.046447,0.216392,0.077163,0.169464,-0.074178,0.007134,-0.142185
2,2,0.545515,0.387604,0.138310,0.020890,0.556632,1.367078,0.679136,0.287170,-0.123674,...,0.239241,-0.150229,0.062258,0.438863,0.196919,0.406640,0.311722,0.013653,-0.147792,-0.641066
3,3,0.062988,0.584623,0.126479,0.273277,0.126677,0.245417,-0.128797,0.821058,-0.052129,...,-0.004899,0.018904,0.196918,0.328980,0.051411,0.249436,0.063545,0.276805,-0.122279,0.134017
4,4,0.321529,-0.150276,0.108164,0.279592,0.150570,0.392028,-0.163329,0.214360,-0.247636,...,-0.539106,-0.110891,-0.084548,-0.251893,0.076618,0.522469,-0.043280,0.632485,-0.029326,0.099883
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,250,0.123364,-0.446838,-0.230563,-0.369108,0.228340,-0.320425,0.043766,0.182957,-0.191702,...,0.216069,-0.159928,-0.330994,-0.111412,-0.198559,-0.170948,0.150703,0.066679,0.099345,0.216570
251,251,0.336390,-0.272271,0.092601,0.468109,0.436630,0.458797,-0.138561,-0.120472,-0.225945,...,0.115384,-0.163099,0.124014,0.190264,0.193507,0.036854,-0.067296,0.405367,-0.233711,-0.203822
252,252,0.061964,0.576702,-0.267426,-0.027441,0.169214,-0.050537,-0.155163,-0.263655,-0.042741,...,0.151407,-0.036319,0.441483,0.169417,0.067879,0.039701,-0.064865,0.277146,0.238588,-0.302141
253,253,0.225668,0.164098,0.426292,0.046859,0.489696,-0.010520,0.557695,0.925749,1.645311,...,0.419794,0.491761,-0.381563,-0.055898,-0.561067,0.494096,-0.046224,0.311232,-0.055833,0.233846
