In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans

SEED = 6174
np.random.seed(SEED)

# Set the folder path for data
folder_path = "./input"

In [2]:
de_train = pd.read_parquet(f'{folder_path}/de_train.parquet')
genes = de_train.columns[5:]
id_map = pd.read_csv (f'{folder_path}/id_map.csv')
id_map = id_map.reindex(id_map.columns.tolist() + genes.tolist(), axis=1)

sm_lincs_id = de_train.set_index('sm_name')["sm_lincs_id"].to_dict()
sm_name_to_smiles = de_train.set_index('sm_name')['SMILES'].to_dict()

id_map['sm_lincs_id'] = id_map['sm_name'].map(sm_lincs_id)
id_map['SMILES'] = id_map['sm_name'].map(sm_name_to_smiles)

de_train

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.104720,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.884380,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,LSM-5771,CC(C)c1c(C(=O)Nc2ccccc2)c(-c2ccccc2)c(-c2ccc(F...,False,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,...,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
610,NK cells,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,-0.455549,0.188181,0.595734,-0.100299,0.786192,...,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
611,T cells CD4+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.338168,-0.109079,0.270182,-0.436586,-0.069476,...,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
612,T cells CD8+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,...,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [55]:
from sklearn.decomposition import TruncatedSVD

n_components = 500

genes_svd = TruncatedSVD(n_components=n_components, random_state=6174)
genes_svd.fit(de_train[genes])

de_train_svd = pd.concat([de_train[['cell_type', 'sm_name']], pd.DataFrame(genes_svd.transform(de_train[genes]))], axis=1)
svd_col_names = [f'svd_{i}' for i in range(n_components)]
de_train_svd.columns = ['cell_type', 'sm_name'] + svd_col_names

de_train_svd

Unnamed: 0,cell_type,sm_name,svd_0,svd_1,svd_2,svd_3,svd_4,svd_5,svd_6,svd_7,...,svd_490,svd_491,svd_492,svd_493,svd_494,svd_495,svd_496,svd_497,svd_498,svd_499
0,NK cells,Clotrimazole,28.937383,3.256932,-2.177188,0.468545,-5.621313,4.551657,0.579008,-1.745717,...,-1.448222,1.809094,-1.191059,2.155270,-3.115348,0.895780,0.981621,-2.082788,1.613594,-0.276613
1,T cells CD4+,Clotrimazole,9.348546,8.662562,-2.060196,0.569955,-2.757469,5.832881,-3.394118,-3.942866,...,-1.049417,3.727830,0.923928,3.046239,4.430452,0.210493,3.420779,-0.828532,-2.400851,-1.988120
2,T cells CD8+,Clotrimazole,-35.811098,6.782070,-4.381716,4.732942,-7.618345,-0.230074,-3.123814,-2.686061,...,3.575040,-2.297633,-4.722379,-4.372242,4.173352,-3.190065,4.957041,1.355969,0.146838,0.041571
3,T regulatory cells,Clotrimazole,56.717018,-1.823420,1.623762,-7.199412,15.891014,4.050272,0.559890,-1.504857,...,0.846404,2.916620,8.348011,-5.965253,-10.348720,-6.374900,1.964278,3.352593,6.293860,-0.806234
4,NK cells,Mometasone Furoate,143.505709,5.531572,1.678172,2.653576,-40.262067,43.428336,5.563466,9.171054,...,-0.734609,1.590586,0.463078,0.253332,0.982482,-0.192322,0.437641,0.246400,-0.715719,-0.700939
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,-53.364344,6.742935,-8.874138,9.641464,-24.493771,-0.228756,-1.885098,-5.429054,...,2.934324,1.919666,-0.524809,-0.386633,-1.848113,1.086940,-1.079860,-0.973856,-0.566561,0.381697
610,NK cells,Riociguat,14.963827,0.252246,0.760688,-0.486874,-2.301388,3.568497,0.071688,0.281288,...,0.082902,-4.576433,2.534524,-2.201409,2.499091,0.116433,-3.292886,-0.716122,0.553430,1.738724
611,T cells CD4+,Riociguat,-5.031586,6.439422,-0.000287,-1.317794,-0.687400,2.886627,-2.994794,-0.251803,...,-2.301295,0.481109,1.675253,3.691823,-5.992915,-0.620876,-0.247681,-0.257542,-2.386236,-3.254447
612,T cells CD8+,Riociguat,-52.892264,7.502364,-2.890900,7.710026,-12.018751,2.703746,-3.876833,-0.758974,...,2.759349,-0.148490,0.611176,-7.353251,-0.879528,4.564364,3.024275,-0.253538,-3.626467,3.639673


In [56]:
genes_df = de_train_svd[svd_col_names].describe().T
genes_df

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
svd_0,614.0,65.812602,240.717884,-371.761046,-5.695867,12.126818,44.911032,2693.355499
svd_1,614.0,14.743826,109.771568,-91.205022,-1.416336,1.990916,7.075620,2030.014243
svd_2,614.0,7.842530,67.620994,-661.460715,-2.142572,0.345851,4.608089,788.468074
svd_3,614.0,3.284427,57.459416,-366.058868,-3.211833,0.203727,3.336091,820.735195
svd_4,614.0,1.333369,50.129446,-371.542081,-5.788433,0.565848,7.191667,662.879735
...,...,...,...,...,...,...,...,...
svd_495,614.0,-0.005299,2.291777,-9.520510,-0.996399,0.035385,1.018669,10.644157
svd_496,614.0,-0.006205,2.286919,-7.894911,-1.001229,-0.015834,1.065338,8.421964
svd_497,614.0,0.019400,2.279833,-8.414037,-1.027109,0.018879,1.022285,9.727079
svd_498,614.0,0.017803,2.275816,-7.886530,-1.121158,0.005805,1.075749,9.413213


In [57]:
all_sm_names = de_train_svd[de_train_svd["cell_type"]=="B cells"]["sm_name"].to_list()
all_de_train = de_train_svd[de_train_svd["sm_name"].isin(all_sm_names)]

all_de_train

Unnamed: 0,cell_type,sm_name,svd_0,svd_1,svd_2,svd_3,svd_4,svd_5,svd_6,svd_7,...,svd_490,svd_491,svd_492,svd_493,svd_494,svd_495,svd_496,svd_497,svd_498,svd_499
8,B cells,Idelalisib,23.773730,2.837433,-12.976027,11.651752,8.042760,37.405285,-20.805010,-6.955465,...,-0.476947,0.548872,-0.247447,0.312428,0.056117,-0.678452,-0.292642,-0.717953,0.346160,0.593856
9,Myeloid cells,Idelalisib,91.649349,-4.908918,-21.450276,43.214048,-4.684578,-3.460724,-10.541607,9.928496,...,0.133796,0.436068,0.928443,-0.337422,-0.045207,0.680503,0.235870,0.064594,0.185112,-0.337199
10,NK cells,Idelalisib,6.679450,2.727796,-4.994478,4.101686,-0.134387,17.906059,-1.503379,-3.606896,...,1.567086,0.871945,1.360412,-1.141874,-0.373261,-2.366621,0.172069,-0.910399,0.226739,1.517842
11,T cells CD4+,Idelalisib,4.849140,1.155022,-4.466306,6.422451,6.877696,30.257279,-9.298240,-7.313907,...,-2.181301,0.462986,-0.707369,-0.074810,2.795133,-0.495488,1.715914,-0.496254,-4.447414,-2.351793
12,T cells CD8+,Idelalisib,8.125392,1.192314,-1.433962,0.296609,1.720211,7.836396,-1.962089,-2.424426,...,2.704396,-3.853076,-1.574052,-1.430978,4.737484,-9.520510,-4.020194,-2.068768,-3.793253,-1.499627
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
557,Myeloid cells,CHIR-99021,39.138356,4.672281,13.453651,17.100034,-8.487260,-19.908061,-8.355001,17.015393,...,-0.090668,0.135254,0.057720,-0.255454,-0.466318,-0.322611,-0.378972,0.188430,0.067239,-0.179625
558,NK cells,CHIR-99021,36.748560,-2.452927,11.748416,-3.360486,-14.075230,-5.914131,2.889764,-0.069173,...,-0.128324,0.827010,-0.616260,-1.014289,0.700187,-0.415684,0.701156,1.053580,1.637547,-1.584513
559,T cells CD4+,CHIR-99021,-17.897179,10.862666,24.872869,-10.306525,-18.156121,-21.695314,-16.115189,-3.779427,...,0.331059,0.007971,0.762683,-0.330295,0.153163,0.144772,1.203233,-0.793599,-1.084560,0.355356
560,T cells CD8+,CHIR-99021,-21.425692,19.176200,6.598163,7.983050,-19.478631,33.416225,-30.844137,-12.178122,...,-0.381852,-0.822147,-0.856063,2.147645,-0.664049,-0.044471,-0.105949,1.762631,-0.278883,-1.024782


In [58]:
all_sm_names

['Idelalisib',
 'Crizotinib',
 'Linagliptin',
 'Palbociclib',
 'Dabrafenib',
 'Alvocidib',
 'LDN 193189',
 'R428',
 'Porcn Inhibitor III',
 'Belinostat',
 'Foretinib',
 'MLN 2238',
 'Penfluridol',
 'Dactolisib',
 'O-Demethylated Adapalene',
 'Oprozomib (ONX 0912)',
 'CHIR-99021']

In [59]:
de_train_svd[de_train_svd["cell_type"]=="B cells"][svd_col_names[0]].describe()

count      17.000000
mean      252.397271
std       490.773902
min       -21.883059
25%        19.956129
50%        38.410757
75%        76.694570
max      1421.664355
Name: svd_0, dtype: float64

In [60]:
de_train_svd

Unnamed: 0,cell_type,sm_name,svd_0,svd_1,svd_2,svd_3,svd_4,svd_5,svd_6,svd_7,...,svd_490,svd_491,svd_492,svd_493,svd_494,svd_495,svd_496,svd_497,svd_498,svd_499
0,NK cells,Clotrimazole,28.937383,3.256932,-2.177188,0.468545,-5.621313,4.551657,0.579008,-1.745717,...,-1.448222,1.809094,-1.191059,2.155270,-3.115348,0.895780,0.981621,-2.082788,1.613594,-0.276613
1,T cells CD4+,Clotrimazole,9.348546,8.662562,-2.060196,0.569955,-2.757469,5.832881,-3.394118,-3.942866,...,-1.049417,3.727830,0.923928,3.046239,4.430452,0.210493,3.420779,-0.828532,-2.400851,-1.988120
2,T cells CD8+,Clotrimazole,-35.811098,6.782070,-4.381716,4.732942,-7.618345,-0.230074,-3.123814,-2.686061,...,3.575040,-2.297633,-4.722379,-4.372242,4.173352,-3.190065,4.957041,1.355969,0.146838,0.041571
3,T regulatory cells,Clotrimazole,56.717018,-1.823420,1.623762,-7.199412,15.891014,4.050272,0.559890,-1.504857,...,0.846404,2.916620,8.348011,-5.965253,-10.348720,-6.374900,1.964278,3.352593,6.293860,-0.806234
4,NK cells,Mometasone Furoate,143.505709,5.531572,1.678172,2.653576,-40.262067,43.428336,5.563466,9.171054,...,-0.734609,1.590586,0.463078,0.253332,0.982482,-0.192322,0.437641,0.246400,-0.715719,-0.700939
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,-53.364344,6.742935,-8.874138,9.641464,-24.493771,-0.228756,-1.885098,-5.429054,...,2.934324,1.919666,-0.524809,-0.386633,-1.848113,1.086940,-1.079860,-0.973856,-0.566561,0.381697
610,NK cells,Riociguat,14.963827,0.252246,0.760688,-0.486874,-2.301388,3.568497,0.071688,0.281288,...,0.082902,-4.576433,2.534524,-2.201409,2.499091,0.116433,-3.292886,-0.716122,0.553430,1.738724
611,T cells CD4+,Riociguat,-5.031586,6.439422,-0.000287,-1.317794,-0.687400,2.886627,-2.994794,-0.251803,...,-2.301295,0.481109,1.675253,3.691823,-5.992915,-0.620876,-0.247681,-0.257542,-2.386236,-3.254447
612,T cells CD8+,Riociguat,-52.892264,7.502364,-2.890900,7.710026,-12.018751,2.703746,-3.876833,-0.758974,...,2.759349,-0.148490,0.611176,-7.353251,-0.879528,4.564364,3.024275,-0.253538,-3.626467,3.639673


# Make transformed_de_train

In [61]:
%%time

def get_svd_features(svd, column):
		return de_train_svd[column].describe().values[1:]

def get_cell_features(cell_type, column):
		return all_de_train[all_de_train["cell_type"]==cell_type][column].describe().values[1:]

def get_sm_name_features(sm_name, column):
		return de_train_svd[de_train_svd["sm_name"]==sm_name][column].describe().values[1:]

def merge_features(df, features):
    # This function will merge the extracted features back into the dataframe
    for feature in features:
        feature_data = df[feature].apply(pd.Series)
        feature_data.columns = [f'{feature}_{stat}' for stat in ['mean', 'std', 'min', '25%', '50%', '75%', 'max']]
        df = pd.concat([df, feature_data], axis=1)
    return df

def transform_de(df):
		df_long = df.set_index(['sm_name', 'cell_type'])[svd_col_names].stack().reset_index(name='value')
		df_long.columns = ['sm_name', 'cell_type', 'svd', 'value']

		df_long.loc[:, 'svd_features'] = df_long.apply(lambda x: get_svd_features(x['svd'], x['svd']), axis=1)
		df_long.loc[:, 'cell_features'] = df_long.apply(lambda x: get_cell_features(x['cell_type'], x['svd']), axis=1)
		df_long.loc[:, 'sm_name_features'] = df_long.apply(lambda x: get_sm_name_features(x['sm_name'], x['svd']), axis=1)

		features = ['svd_features', 'cell_features', 'sm_name_features']
		df_long = merge_features(df_long, features)
		df_long = df_long.drop(features, axis=1)
		
		return df_long

transformed_de = transform_de(de_train_svd)
transformed_de.to_csv('transformed_de_train.csv', index=False)
transformed_de

CPU times: user 17min 5s, sys: 2.02 s, total: 17min 7s
Wall time: 17min 7s


Unnamed: 0,sm_name,cell_type,svd,value,svd_features_mean,svd_features_std,svd_features_min,svd_features_25%,svd_features_50%,svd_features_75%,...,cell_features_50%,cell_features_75%,cell_features_max,sm_name_features_mean,sm_name_features_std,sm_name_features_min,sm_name_features_25%,sm_name_features_50%,sm_name_features_75%,sm_name_features_max
0,Clotrimazole,NK cells,svd_0,28.937383,65.812602,240.717884,-371.761046,-5.695867,12.126818,44.911032,...,27.125380,132.025202,1862.909197,14.797962,38.936288,-35.811098,-1.941365,19.142964,35.882291,56.717018
1,Clotrimazole,NK cells,svd_1,3.256932,14.743826,109.771568,-91.205022,-1.416336,1.990916,7.075620,...,-1.665262,2.727796,1013.782995,4.219536,4.609808,-1.823420,1.986844,5.019501,7.252193,8.662562
2,Clotrimazole,NK cells,svd_2,-2.177188,7.842530,67.620994,-661.460715,-2.142572,0.345851,4.608089,...,-0.175959,9.549005,34.824513,-1.748834,2.489103,-4.381716,-2.728320,-2.118692,-1.139207,1.623762
3,Clotrimazole,NK cells,svd_3,0.468545,3.284427,57.459416,-366.058868,-3.211833,0.203727,3.336091,...,1.282137,2.870305,32.523218,-0.356993,4.975503,-7.199412,-1.448444,0.519250,1.610701,4.732942
4,Clotrimazole,NK cells,svd_4,-5.621313,1.333369,50.129446,-371.542081,-5.788433,0.565848,7.191667,...,-6.707307,-1.432064,7.197147,-0.026528,10.797584,-7.618345,-6.120571,-4.189391,1.904652,15.891014
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
306995,Riociguat,T regulatory cells,svd_495,1.522856,-0.005299,2.291777,-9.520510,-0.996399,0.035385,1.018669,...,0.197945,1.378657,4.116555,1.395694,2.291994,-0.620876,-0.067894,0.819645,2.283233,4.564364
306996,Riociguat,T regulatory cells,svd_496,-0.890295,-0.006205,2.286919,-7.894911,-1.001229,-0.015834,1.065338,...,0.046053,0.989956,4.990729,-0.351647,2.604399,-3.292886,-1.490943,-0.568988,0.570308,3.024275
306997,Riociguat,T regulatory cells,svd_497,0.899468,0.019400,2.279833,-8.414037,-1.027109,0.018879,1.022285,...,0.108516,1.026124,6.531056,-0.081934,0.689355,-0.716122,-0.372187,-0.255540,0.034713,0.899468
306998,Riociguat,T regulatory cells,svd_498,0.543504,0.017803,2.275816,-7.886530,-1.121158,0.005805,1.075749,...,0.007552,0.638309,3.850495,-1.228942,2.113912,-3.626467,-2.696294,-0.921366,0.545985,0.553430


In [62]:
index = transformed_de[transformed_de["cell_type"] == "NK cells"].index
index = index.difference(all_de_train.index)

transformed_de_val = transformed_de.loc[index]
transformed_de_train = transformed_de.drop(index)

features = transformed_de.columns[4:]

transformed_de_val

Unnamed: 0,sm_name,cell_type,svd,value,svd_features_mean,svd_features_std,svd_features_min,svd_features_25%,svd_features_50%,svd_features_75%,...,cell_features_50%,cell_features_75%,cell_features_max,sm_name_features_mean,sm_name_features_std,sm_name_features_min,sm_name_features_25%,sm_name_features_50%,sm_name_features_75%,sm_name_features_max
0,Clotrimazole,NK cells,svd_0,28.937383,65.812602,240.717884,-371.761046,-5.695867,12.126818,44.911032,...,27.125380,132.025202,1862.909197,14.797962,38.936288,-35.811098,-1.941365,19.142964,35.882291,56.717018
1,Clotrimazole,NK cells,svd_1,3.256932,14.743826,109.771568,-91.205022,-1.416336,1.990916,7.075620,...,-1.665262,2.727796,1013.782995,4.219536,4.609808,-1.823420,1.986844,5.019501,7.252193,8.662562
2,Clotrimazole,NK cells,svd_2,-2.177188,7.842530,67.620994,-661.460715,-2.142572,0.345851,4.608089,...,-0.175959,9.549005,34.824513,-1.748834,2.489103,-4.381716,-2.728320,-2.118692,-1.139207,1.623762
3,Clotrimazole,NK cells,svd_3,0.468545,3.284427,57.459416,-366.058868,-3.211833,0.203727,3.336091,...,1.282137,2.870305,32.523218,-0.356993,4.975503,-7.199412,-1.448444,0.519250,1.610701,4.732942
4,Clotrimazole,NK cells,svd_4,-5.621313,1.333369,50.129446,-371.542081,-5.788433,0.565848,7.191667,...,-6.707307,-1.432064,7.197147,-0.026528,10.797584,-7.618345,-6.120571,-4.189391,1.904652,15.891014
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
305495,Riociguat,NK cells,svd_495,0.116433,-0.005299,2.291777,-9.520510,-0.996399,0.035385,1.018669,...,0.211868,0.628932,3.849936,1.395694,2.291994,-0.620876,-0.067894,0.819645,2.283233,4.564364
305496,Riociguat,NK cells,svd_496,-3.292886,-0.006205,2.286919,-7.894911,-1.001229,-0.015834,1.065338,...,0.166360,0.904012,2.587145,-0.351647,2.604399,-3.292886,-1.490943,-0.568988,0.570308,3.024275
305497,Riociguat,NK cells,svd_497,-0.716122,0.019400,2.279833,-8.414037,-1.027109,0.018879,1.022285,...,0.148542,1.053580,4.662931,-0.081934,0.689355,-0.716122,-0.372187,-0.255540,0.034713,0.899468
305498,Riociguat,NK cells,svd_498,0.553430,0.017803,2.275816,-7.886530,-1.121158,0.005805,1.075749,...,-0.185784,0.226739,2.038103,-1.228942,2.113912,-3.626467,-2.696294,-0.921366,0.545985,0.553430


In [75]:
from sklearn.linear_model import Ridge
from sklearn.svm import LinearSVR
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

# Function to calculate Mean Rowwise Root Mean Squared Error
def mrrmse(y_true, y_pred):
		return np.mean(np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1)))

X_train = transformed_de_train[features]
y_train = transformed_de_train["value"]

# Initialize the model
model = make_pipeline(
	StandardScaler(),
	LinearSVR(dual=True, random_state=6174, tol=1e-5)
)

model.fit(X_train, y_train)

X_val = transformed_de_val[features]
preds = model.predict(X_val)

pred_df = transformed_de_val[["sm_name", "cell_type", "svd"]].copy()
pred_df["value"] = preds
pred_df


Unnamed: 0,sm_name,cell_type,svd,value
0,Clotrimazole,NK cells,svd_0,24.963136
1,Clotrimazole,NK cells,svd_1,3.701831
2,Clotrimazole,NK cells,svd_2,-2.818743
3,Clotrimazole,NK cells,svd_3,0.230879
4,Clotrimazole,NK cells,svd_4,-8.785298
...,...,...,...,...
305495,Riociguat,NK cells,svd_495,0.766631
305496,Riociguat,NK cells,svd_496,-0.476001
305497,Riociguat,NK cells,svd_497,-0.185953
305498,Riociguat,NK cells,svd_498,-0.981559


In [76]:
#from sklearn.linear_model import Ridge
#from sklearn.model_selection import GridSearchCV

## Function to calculate Mean Rowwise Root Mean Squared Error
#def mrrmse(y_true, y_pred):
#		return np.mean(np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1)))


## Initialize the model
#model = Ridge()

#pred_df = transformed_de_val.copy()
#pred_df['value'] = np.nan

#for i, svm_col_name in enumerate(svd_col_names):
#		print(f"Training for {svm_col_name}")
#		# Train the model
#		X_train = transformed_de_train[transformed_de_train["svd"] == svm_col_name][features]
#		y_train = transformed_de_train[transformed_de_train["svd"] == svm_col_name]["value"]

#		model.fit(X_train, y_train)

#		X_val = pred_df[pred_df["svd"] == svm_col_name][features]
#		y_pred = model.predict(X_val)
#		insert_index = pred_df[pred_df["svd"] == svm_col_name].index
#		pred_df.loc[insert_index, 'value'] = y_pred


In [77]:
def inverse_de(df):
		# Group by svd as columns
		wide_df = df.pivot_table(index=['sm_name', 'cell_type'], columns='svd', values='value')
		wide_df = wide_df.reset_index()
		return wide_df

pred_df = inverse_de(pred_df)
pred_df


svd,sm_name,cell_type,svd_0,svd_1,svd_10,svd_100,svd_101,svd_102,svd_103,svd_104,...,svd_90,svd_91,svd_92,svd_93,svd_94,svd_95,svd_96,svd_97,svd_98,svd_99
0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-...,NK cells,26.296017,-1.294505,-1.626875,-0.280500,2.272868,-3.612204,3.542033,-1.620996,...,3.466518,-4.801328,-0.750711,0.068147,1.953735,6.075321,1.499024,-3.520647,1.043418,-2.713171
1,ABT-199 (GDC-0199),NK cells,26.532850,-0.935033,0.429501,-1.186910,-3.205910,-0.581371,-1.407261,-0.739658,...,-1.539178,-0.556773,-1.751180,-0.682532,0.077591,1.497083,1.896087,0.769865,0.160944,-2.522413
2,ABT737,NK cells,60.633969,-1.013220,1.371891,-0.357061,-2.089156,0.822476,2.599465,-1.425928,...,0.028411,-0.012567,-2.826226,-1.302162,-1.436870,-0.578852,0.477671,0.635791,0.740647,0.446885
3,AMD-070 (hydrochloride),NK cells,23.661489,1.562504,1.494583,-0.179945,-0.398419,0.429763,3.307933,-3.256740,...,-0.405386,-1.066821,-0.306479,1.814303,1.533843,-1.759306,0.059576,0.911476,-0.670906,-2.415172
4,AT 7867,NK cells,27.588271,5.154470,-0.585541,-0.200943,-3.517559,-0.276411,0.316595,1.309974,...,4.723079,-4.447446,-0.790851,-2.430166,2.308686,2.977869,0.117026,-0.847913,3.052525,0.393716
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
141,Vandetanib,NK cells,-12.996147,3.340381,1.216760,0.656470,-2.115774,1.059934,-1.452791,0.755146,...,-3.340293,-0.605145,-1.245906,1.948819,-1.263965,0.679002,0.857601,-2.092797,-2.091114,-2.369687
142,Vanoxerine,NK cells,34.162913,3.342623,1.477584,0.007616,-0.193816,0.884427,-1.154434,0.500131,...,-0.109815,-1.929348,1.105671,-0.629280,1.025784,-2.166681,1.066811,-0.879531,-2.869753,4.445199
143,Vardenafil,NK cells,-5.826570,0.158743,1.301059,0.806087,-0.379529,1.022590,-0.076339,-0.362835,...,-1.416260,-0.006774,-1.419150,3.227721,-0.734865,0.809824,-0.033330,-1.867652,-1.823091,-1.138871
144,Vorinostat,NK cells,62.820565,110.358545,9.278139,4.438324,-7.806698,-1.650144,-7.264087,4.697017,...,2.143935,-2.395150,1.047238,9.236462,-2.732329,-3.115463,-5.046280,-3.356418,-1.757928,-3.754304


In [78]:
pred_df[genes] = genes_svd.inverse_transform(pred_df[svd_col_names])
pred_df = pred_df.drop(svd_col_names, axis=1)
pred_df

svd,sm_name,cell_type,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-...,NK cells,-0.156300,-0.063314,-0.114497,0.733592,-0.037314,0.241373,-0.011029,-0.124549,...,-0.422529,-0.610030,0.299421,-0.047426,0.305912,0.358805,0.161990,0.080583,-0.489422,0.432580
1,ABT-199 (GDC-0199),NK cells,-0.158158,0.081403,0.023592,-0.199191,0.213635,0.345770,-0.215693,0.118476,...,0.292336,0.174951,-0.113639,-0.285458,0.180463,-0.162750,0.097626,-0.310797,0.065390,-0.244090
2,ABT737,NK cells,0.240822,0.166800,0.246242,0.246551,0.396952,0.876960,0.481275,0.224197,...,-0.033592,-0.034535,0.007843,-0.052553,0.032273,0.353800,0.054602,-0.166701,-0.061422,-0.521210
3,AMD-070 (hydrochloride),NK cells,0.065485,0.672435,0.002223,-0.066815,-0.004560,0.255012,-0.082702,0.628852,...,-0.058985,-0.077383,0.184032,0.079310,0.249312,0.242974,0.223175,0.409203,0.057551,0.141160
4,AT 7867,NK cells,0.146236,-0.058609,-0.055509,0.114372,0.043183,0.356055,-0.232649,0.016152,...,-0.429203,0.004191,-0.118731,-0.469399,0.114320,0.283176,-0.197038,0.623839,-0.058582,0.387251
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
141,Vandetanib,NK cells,-0.224535,-0.222706,-0.392497,-0.283821,0.148815,-0.195391,-0.106254,0.414728,...,0.260501,-0.159369,-0.450388,-0.313328,-0.215167,-0.064629,0.216355,-0.144923,-0.119817,0.159027
142,Vanoxerine,NK cells,0.101125,-0.312402,-0.102821,0.341886,0.167764,0.532089,-0.077891,0.105633,...,-0.117577,-0.034174,0.048187,0.081943,0.123787,-0.199217,-0.028391,0.122027,-0.046984,0.113754
143,Vardenafil,NK cells,-0.206017,0.416047,-0.314985,-0.080859,-0.087960,-0.143269,-0.259663,-0.406493,...,-0.153322,-0.009457,0.328856,-0.101851,-0.155770,0.073899,-0.185537,0.165818,0.222556,-0.052341
144,Vorinostat,NK cells,0.283633,0.017826,0.185083,0.058592,0.477850,0.280681,0.415923,0.594468,...,0.423114,0.698522,-0.288002,-0.021281,0.096415,0.233919,0.019729,0.266376,-0.155571,0.359097


In [79]:
mrrmse_score = mrrmse(de_train[de_train["cell_type"] == "NK cells"][genes], pred_df[genes])
print("MRRMSE Score on Test Data:", mrrmse_score)

MRRMSE Score on Test Data: 2.019987683550279
