In [2]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Basic utilities
import os
import gc
import glob
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

# Data visualization
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# import plotly.figure_factory as ff
# import plotly.express as px

# Scientific computing
from scipy import stats
from itertools import groupby

# Machine Learning
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import TruncatedSVD
from sklearn.svm import LinearSVR
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import LinearSVR
from sklearn.multioutput import MultiOutputRegressor
from sklearn.impute import SimpleImputer
from sklearn.decomposition import PCA


# Set the folder path for data
folder_path = "./input"


In [3]:
import warnings
warnings.simplefilter('ignore')

import pandas as pd

pd.set_option('display.max_columns', 30)

import numpy as np

SEED = 6174
np.random.seed(SEED)

In [11]:
de_train = pd.read_parquet(f'{folder_path}/de_train.parquet')
genes = de_train.columns[5:]
id_map = pd.read_csv (f'{folder_path}/id_map.csv')

sm_lincs_id = de_train.set_index('sm_name')["sm_lincs_id"].to_dict()
sm_name_to_smiles = de_train.set_index('sm_name')['SMILES'].to_dict()

id_map['sm_lincs_id'] = id_map['sm_name'].map(sm_lincs_id)
id_map['SMILES'] = id_map['sm_name'].map(sm_name_to_smiles)

de_train

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZSWIM5,ZSWIM6,ZSWIM7,ZSWIM8,ZSWIM9,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.104720,-0.077524,-1.625596,-0.144545,0.143555,0.073229,-0.016823,0.101717,-0.005153,1.043629,...,0.299807,0.319123,0.179530,0.220086,-0.206053,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.884380,0.371834,-0.081677,-0.498266,0.203559,0.604656,0.498592,-0.317184,0.375550,...,0.091576,0.717595,1.262570,0.357003,-0.168803,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,-0.480681,0.467144,-0.293205,-0.005098,0.214918,...,-0.590645,-0.542832,0.225485,0.131672,-0.393695,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,0.718590,-0.162145,0.157206,-3.654218,-0.212402,...,0.760570,-0.217246,-0.203936,2.060546,0.899520,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,2.022829,0.600011,1.231275,0.236739,0.338703,...,1.005788,0.106344,-0.145054,0.965736,0.248029,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,LSM-5771,CC(C)c1c(C(=O)Nc2ccccc2)c(-c2ccccc2)c(-c2ccc(F...,False,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,-0.544709,0.282458,-0.431359,-0.364961,0.043123,...,0.092460,-0.960509,0.000051,-0.626368,-0.261534,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
610,NK cells,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,-0.455549,0.188181,0.595734,-0.100299,0.786192,0.090954,0.169523,0.428297,0.106553,0.435088,...,0.883842,0.611697,-0.538152,0.047483,-0.602049,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
611,T cells CD4+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.338168,-0.109079,0.270182,-0.436586,-0.069476,-0.061539,0.002818,-0.027167,-0.383696,0.226289,...,0.169480,-0.084077,0.697416,0.225507,0.063579,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
612,T cells CD8+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,-0.706087,-0.620919,-1.485381,0.059303,-0.032584,...,-1.149889,-0.977296,0.369929,0.625152,-0.885209,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [12]:
def mrrmse_pd(y_pred: pd.DataFrame, y_true: pd.DataFrame):
	return ((y_pred - y_true)**2).mean(axis=1).apply(np.sqrt).mean()

def mrrmse_np(y_pred, y_true):
	return np.sqrt(np.square(y_true - y_pred).mean()).mean()

In [13]:
def split_sign(text):
    text = text.replace(')(', ' ')
    text = text.replace('(' , ' ')
    text = text.replace(')' , ' ')
    return text.split(" ")

de_train['_SMILES'] = [split_sign(text) for text in de_train['SMILES'].values]

sign = []
for row in de_train['_SMILES'].values:
    for ele in row:
        sign.append(ele)
        
sign_list = list(set(sign))

data = np.zeros((len(de_train), len(sign_list)), dtype=int)
de_features = pd.DataFrame(data=data, columns=sign_list)

for sign in sign_list:
    for i in range(len(de_train)):
        row = de_train['_SMILES'].values[i]

        for ele in row:
            if ele == sign:
                de_features[sign][i] += 1

                
id_map['_SMILES'] = [split_sign(text) for text in id_map['SMILES'].values]

sign = []
for row in id_map['_SMILES'].values:
    for ele in row:
        sign.append(ele)
        
sign_list = list(set(sign))

data = np.zeros((len(id_map), len(sign_list)), dtype=int)
test_features = pd.DataFrame(data=data, columns=sign_list)

for sign in sign_list:
    for i in range(len(id_map)):
        row = id_map['_SMILES'].values[i]

        for ele in row:
            if ele == sign:
                test_features[sign][i] += 1
                
uncommon = [f for f in de_features if f not in test_features]
de_features = de_features.drop(columns=uncommon)

de_features = de_features.sort_index(axis = 1)
test_features = test_features.sort_index(axis = 1)

print("Columns Check", list(de_features.columns) == list(test_features.columns))

Columns Check True


In [14]:
sm_name = pd.get_dummies(de_train['sm_name'], dtype=float)
de_features = pd.concat([sm_name, de_features], axis=1)
de_features

Unnamed: 0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,ABT-199 (GDC-0199),ABT737,AMD-070 (hydrochloride),AT 7867,AT13387,AVL-292,AZ628,AZD-8330,AZD3514,AZD4547,Alogliptin,Alvocidib,Amiodarone,Atorvastatin,...,ncnc3c2,nn12,nn1Cc1ccnc,nnc5C,no1,no2,noc1C,noc4C,o1,oc6,on4,s2,s3,sc2cc,sc3cc
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
610,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
611,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
612,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [15]:
cell_type = pd.get_dummies(de_train['cell_type'], dtype=float)
de_features = pd.concat([cell_type, de_features], axis=1)
de_features

Unnamed: 0,B cells,Myeloid cells,NK cells,T cells CD4+,T cells CD8+,T regulatory cells,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,ABT-199 (GDC-0199),ABT737,AMD-070 (hydrochloride),AT 7867,AT13387,AVL-292,AZ628,AZD-8330,...,ncnc3c2,nn12,nn1Cc1ccnc,nnc5C,no1,no2,noc1C,noc4C,o1,oc6,on4,s2,s3,sc2cc,sc3cc
0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
610,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
611,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
612,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [65]:
sm_name

Unnamed: 0,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,ABT-199 (GDC-0199),ABT737,AMD-070 (hydrochloride),AT 7867,AT13387,AVL-292,AZ628,AZD-8330,AZD3514,AZD4547,Alogliptin,Alvocidib,Amiodarone,Atorvastatin,...,TR-14035,Tacalcitol,Tamatinib,Tipifarnib,Tivantinib,Tivozanib,Topotecan,Tosedostat,Trametinib,UNII-BXU45ZH6LI,Vandetanib,Vanoxerine,Vardenafil,Vorinostat,YK 4-279
0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
251,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
252,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
253,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [16]:
sm_name = sm_name.iloc[:len(id_map)]
sm_name.iloc[:, :] = 0.0
for i, item in enumerate(id_map['sm_name']):
    if item in sm_name.columns:
        sm_name.loc[i, item] = 1.0
test_features = pd.concat([sm_name, test_features], axis=1)

cell_type = cell_type.iloc[:len(id_map)]
cell_type.iloc[:, :] = 0.0
cell_type.iloc[:, :2] = pd.get_dummies(id_map['cell_type'], dtype=float)
test_features = pd.concat([cell_type, test_features], axis=1)
test_features

Unnamed: 0,B cells,Myeloid cells,NK cells,T cells CD4+,T cells CD8+,T regulatory cells,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,ABT-199 (GDC-0199),ABT737,AMD-070 (hydrochloride),AT 7867,AT13387,AVL-292,AZ628,AZD-8330,...,ncnc3c2,nn12,nn1Cc1ccnc,nnc5C,no1,no2,noc1C,noc4C,o1,oc6,on4,s2,s3,sc2cc,sc3cc
0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
251,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
252,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
253,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [17]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def convert_to_image_format(data, output_size=(224, 224)):
    square_side = int(math.ceil(math.sqrt(data.shape[1])))
    padding_size = square_side ** 2 - data.shape[1]
    
    data_padded = np.pad(data, ((0, 0), (0, padding_size)), 'constant', constant_values=0)
    
    # We are reshaping to [N, H, W, C] because the final torch tensor needs to be [N, C, H, W]
    data_reshaped = data_padded.reshape(-1, square_side, square_side, 1)
    
    # Expand the last dimension to three channels by repeating the data
    data_rgb = np.repeat(data_reshaped, 3, -1)
    
    return data_rgb

X_scaler = StandardScaler()
y_scaler = StandardScaler()

# Assuming de_features and genes are defined elsewhere in your script
X_scaled = X_scaler.fit_transform(de_features.values)
y_scaled = y_scaler.fit_transform(de_train[genes].values)

# Convert scaled data to image format
X_img = convert_to_image_format(X_scaled)

# Convert to tensors
X_tensor = torch.tensor(X_img, dtype=torch.float32).permute(0, 3, 1, 2)
y_tensor = torch.tensor(y_scaled, dtype=torch.float32)

# KFold setup
num_epochs = 100  # Define the number of epochs

X_tensor, cv_X_tensor, y_tensor, cv_y_tensor = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=6174)


In [18]:
from torchvision import models

models.resnext101_32x8d(pretrained=True)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1

In [19]:
%%time

# Define the model architecture that corresponds to the checkpoints
def get_model():
    model = models.regnet_y_800mf(pretrained=True)
    
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(genes))
    
    # num_ftrs = model.classifier[2].in_features
    # model.classifier[2] = nn.Linear(num_ftrs, len(genes))
    
    # num_ftrs = model.classifier.in_features
    # model.classifier = nn.Linear(num_ftrs, len(genes))
    
    
    return model.cuda()

kf = KFold(n_splits=5, shuffle=True, random_state=6174)
for fold, (train_index, val_index) in enumerate(kf.split(X_tensor)):
    print(f'Fold {fold+1}')

    # Split data into training and validation sets
    X_train_tensor, y_train_tensor = X_tensor[train_index], y_tensor[train_index]
    X_val_tensor, y_val_tensor = X_tensor[val_index], y_tensor[val_index]
    
    # Move data to GPU
    X_train_tensor, y_train_tensor = X_train_tensor.cuda(), y_train_tensor.cuda()
    X_val_tensor, y_val_tensor = X_val_tensor.cuda(), y_val_tensor.cuda()

    # DataLoader setup
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    # Model, loss, optimizer setup
    model = get_model()
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # Training loop for each fold
    best_val_loss = float('inf')
    early_stopping_patience = 10
    early_stopping_counter = 0
    gradient_clip_value = 1
    
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)  # Gradient clipping
            optimizer.step()

        # Validation loop
        model.eval()
        valid_loss_squared_sum = 0  # For MRRMSE

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                valid_loss_squared_sum += loss.item() * inputs.size(0)  # Sum of squared errors
            valid_mrrmse = np.sqrt(valid_loss_squared_sum / len(val_loader.dataset))
            print(f'Epoch {epoch}, MRRMSE: {valid_mrrmse}')

            # Update the learning rate scheduler based on validation loss
            scheduler.step(valid_mrrmse)

            if valid_mrrmse < best_val_loss:
                best_val_loss = valid_mrrmse
                early_stopping_counter = 0
                torch.save(model.state_dict(), f'best_model_fold_{fold+1}.pth')  # Save best model checkpoint
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping triggered")
                    print()
                    break
    print()


Fold 1
Epoch 0, MRRMSE: 1.0047153499962092
Epoch 1, MRRMSE: 1.0084927101677392
Epoch 2, MRRMSE: 1.0426744226083833
Epoch 3, MRRMSE: 1.23896465885814
Epoch 4, MRRMSE: 1.1638127635054814
Epoch 5, MRRMSE: 1.4890449759182902
Epoch 6, MRRMSE: 0.9906527371697679
Epoch 7, MRRMSE: 1.059663677998504
Epoch 8, MRRMSE: 0.9552710828090751
Epoch 9, MRRMSE: 0.9986350581653908
Epoch 10, MRRMSE: 1.5427229646140297
Epoch 11, MRRMSE: 1.2187074205755466
Epoch 12, MRRMSE: 1.4005966736794389
Epoch 13, MRRMSE: 1.069363867032207
Epoch 14, MRRMSE: 1.0403135234128698
Epoch 00015: reducing learning rate of group 0 to 1.0000e-04.
Epoch 15, MRRMSE: 1.0790136100145296
Epoch 16, MRRMSE: 1.015792021841304
Epoch 17, MRRMSE: 0.990410492513625
Epoch 18, MRRMSE: 1.102805383445942
Early stopping triggered


Fold 2
Epoch 0, MRRMSE: 0.8553150788339188
Epoch 1, MRRMSE: 0.8402973896161112
Epoch 2, MRRMSE: 0.8196683777440318
Epoch 3, MRRMSE: 0.8879514396524437
Epoch 4, MRRMSE: 0.9357821670989924
Epoch 5, MRRMSE: 1.019059323812

In [None]:
from sklearn.metrics import mean_squared_error


def predict(model, X_data, X_scaler=None, y_scaler=None, is_image_format=False):
    """
    Preprocess the features if needed, predict using the trained model, and inverse transform the predictions.
    """
    if X_scaler is not None and not is_image_format:
        X_scaled = X_scaler.transform(X_data)
        X_img = convert_to_image_format(X_scaled)
        X_tensor = torch.tensor(X_img, dtype=torch.float32).permute(0, 3, 1, 2).cuda()
    else:
        X_tensor = X_data

    model.eval()
    with torch.no_grad():
        test_outputs = model(X_tensor)
        
    test_predictions = test_outputs.cpu().numpy()
    if y_scaler is not None:
        test_predictions = y_scaler.inverse_transform(test_predictions)
    
    return test_predictions

ensemble_cv_predictions = []
ensemble_sub_predictions = []

# Load models and collect their predictions for the validation and test set
for fold in range(1, 6):
    # Load the pre-trained model
    model = get_model()
    checkpoint_path = f'best_model_fold_{fold}.pth'
    model.load_state_dict(torch.load(checkpoint_path))
    model = model.cuda()
    
    # Make predictions on the validation set
    # X_val_tensor is already in image format, no need for scaling
    val_predictions = predict(model, X_val_tensor, is_image_format=True)
    ensemble_cv_predictions.append(val_predictions)
    
    # Make predictions on the test set
    # Assuming test_features is in the original format and needs scaling and conversion
    test_predictions = predict(model, test_features, X_scaler, y_scaler)
    ensemble_sub_predictions.append(test_predictions)
    
# Calculate the average prediction across all folds
ensemble_cv_predictions = np.mean(ensemble_cv_predictions, axis=0)
ensemble_sub_predictions = np.mean(ensemble_sub_predictions, axis=0)

# Assign ensemble predictions to the submission DataFrame
id_map.loc[:, genes] = ensemble_sub_predictions
id_map = id_map.loc[:, ["id"] + genes.to_list()]
id_map.to_csv('submission.csv', index=False)

# We can also calculate the ensemble CV score if 'y_val_tensor' holds the true validation labels
ensemble_mrrmse = np.sqrt(mean_squared_error(y_val_tensor.cpu().numpy(), ensemble_cv_predictions))
print(f'Ensemble CV MRRMSE: {ensemble_mrrmse}')

Ensemble CV MRRMSE: 1.100290298461914
