In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Basic utilities
import os
import gc
import glob
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

# Data visualization
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# import plotly.figure_factory as ff
# import plotly.express as px

# Scientific computing
from scipy import stats
from itertools import groupby

# Machine Learning
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import TruncatedSVD
from sklearn.svm import LinearSVR
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import LinearSVR
from sklearn.multioutput import MultiOutputRegressor
from sklearn.impute import SimpleImputer
from sklearn.decomposition import PCA


# Set the folder path for data
folder_path = "./input"


In [2]:
import warnings
warnings.simplefilter('ignore')

import pandas as pd

pd.set_option('display.max_columns', 30)

import numpy as np

SEED = 6174
np.random.seed(SEED)

In [3]:
de_train = pd.read_parquet('/notebooks/input/de_train.parquet')
genes = de_train.columns[5:]
id_map = pd.read_csv ('/notebooks/input/id_map.csv')

sm_lincs_id = de_train.set_index('sm_name')["sm_lincs_id"].to_dict()
sm_name_to_smiles = de_train.set_index('sm_name')['SMILES'].to_dict()

id_map['sm_lincs_id'] = id_map['sm_name'].map(sm_lincs_id)
id_map['SMILES'] = id_map['sm_name'].map(sm_name_to_smiles)

de_train

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZSWIM5,ZSWIM6,ZSWIM7,ZSWIM8,ZSWIM9,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.104720,-0.077524,-1.625596,-0.144545,0.143555,0.073229,-0.016823,0.101717,-0.005153,1.043629,...,0.299807,0.319123,0.179530,0.220086,-0.206053,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.884380,0.371834,-0.081677,-0.498266,0.203559,0.604656,0.498592,-0.317184,0.375550,...,0.091576,0.717595,1.262570,0.357003,-0.168803,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,-0.480681,0.467144,-0.293205,-0.005098,0.214918,...,-0.590645,-0.542832,0.225485,0.131672,-0.393695,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,0.718590,-0.162145,0.157206,-3.654218,-0.212402,...,0.760570,-0.217246,-0.203936,2.060546,0.899520,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,2.022829,0.600011,1.231275,0.236739,0.338703,...,1.005788,0.106344,-0.145054,0.965736,0.248029,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,LSM-5771,CC(C)c1c(C(=O)Nc2ccccc2)c(-c2ccccc2)c(-c2ccc(F...,False,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,-0.544709,0.282458,-0.431359,-0.364961,0.043123,...,0.092460,-0.960509,0.000051,-0.626368,-0.261534,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
610,NK cells,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,-0.455549,0.188181,0.595734,-0.100299,0.786192,0.090954,0.169523,0.428297,0.106553,0.435088,...,0.883842,0.611697,-0.538152,0.047483,-0.602049,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
611,T cells CD4+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.338168,-0.109079,0.270182,-0.436586,-0.069476,-0.061539,0.002818,-0.027167,-0.383696,0.226289,...,0.169480,-0.084077,0.697416,0.225507,0.063579,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
612,T cells CD8+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23...,False,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,-0.706087,-0.620919,-1.485381,0.059303,-0.032584,...,-1.149889,-0.977296,0.369929,0.625152,-0.885209,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [4]:
def mrrmse_pd(y_pred: pd.DataFrame, y_true: pd.DataFrame):
	return ((y_pred - y_true)**2).mean(axis=1).apply(np.sqrt).mean()

def mrrmse_np(y_pred, y_true):
	return np.sqrt(np.square(y_true - y_pred).mean()).mean()

In [5]:
def split_sign(text):
    text = text.replace(')(', ' ')
    text = text.replace('(' , ' ')
    text = text.replace(')' , ' ')
    return text.split(" ")

de_train['_SMILES'] = [split_sign(text) for text in de_train['SMILES'].values]

sign = []
for row in de_train['_SMILES'].values:
    for ele in row:
        sign.append(ele)
        
sign_list = list(set(sign))

data = np.zeros((len(de_train), len(sign_list)), dtype=int)
de_features = pd.DataFrame(data=data, columns=sign_list)

for sign in sign_list:
    for i in range(len(de_train)):
        row = de_train['_SMILES'].values[i]

        for ele in row:
            if ele == sign:
                de_features[sign][i] += 1

                
id_map['_SMILES'] = [split_sign(text) for text in id_map['SMILES'].values]

sign = []
for row in id_map['_SMILES'].values:
    for ele in row:
        sign.append(ele)
        
sign_list = list(set(sign))

data = np.zeros((len(id_map), len(sign_list)), dtype=int)
test_features = pd.DataFrame(data=data, columns=sign_list)

for sign in sign_list:
    for i in range(len(id_map)):
        row = id_map['_SMILES'].values[i]

        for ele in row:
            if ele == sign:
                test_features[sign][i] += 1
                
uncommon = [f for f in de_features if f not in test_features]
de_features = de_features.drop(columns=uncommon)

de_features = de_features.sort_index(axis = 1)
test_features = test_features.sort_index(axis = 1)

print("Columns Check", list(de_features.columns) == list(test_features.columns))

Columns Check True


In [6]:
cell_type = pd.get_dummies(de_train['cell_type'], dtype=float)
de_features = pd.concat([cell_type, de_features], axis=1)
de_features

Unnamed: 0,B cells,Myeloid cells,NK cells,T cells CD4+,T cells CD8+,T regulatory cells,-c1ccc,-c1cccc,-c2[nH]c,-c2[nH]nc3ccc,-c2cc,-c2cc3c,-c2cc3nccc,-c2cc3nccn3c,-c2ccc,...,ncnc3c2,nn12,nn1Cc1ccnc,nnc5C,no1,no2,noc1C,noc4C,o1,oc6,on4,s2,s3,sc2cc,sc3cc
0,0.0,0.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0.0,0.0,0.0,1.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,0.0,0.0,0.0,0.0,0.0,1.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0.0,0.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,0.0,0.0,0.0,0.0,0.0,1.0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
610,0.0,0.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
611,0.0,0.0,0.0,1.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
612,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [7]:
cell_type = cell_type.iloc[:len(id_map)]
cell_type.iloc[:, :] = 0.0
cell_type.iloc[:, :2] = pd.get_dummies(id_map['cell_type'], dtype=float)
test_features = pd.concat([cell_type, test_features], axis=1)
test_features

Unnamed: 0,B cells,Myeloid cells,NK cells,T cells CD4+,T cells CD8+,T regulatory cells,-c1ccc,-c1cccc,-c2[nH]c,-c2[nH]nc3ccc,-c2cc,-c2cc3c,-c2cc3nccc,-c2cc3nccn3c,-c2ccc,...,ncnc3c2,nn12,nn1Cc1ccnc,nnc5C,no1,no2,noc1C,noc4C,o1,oc6,on4,s2,s3,sc2cc,sc3cc
0,1.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,1.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,1.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,1.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,1.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
251,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
252,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
253,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [31]:
%%time

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def convert_to_image_format(data, output_size=(224, 224)):
    square_side = int(math.ceil(math.sqrt(data.shape[1])))
    padding_size = square_side ** 2 - data.shape[1]
    
    data_padded = np.pad(data, ((0, 0), (0, padding_size)), 'constant', constant_values=0)
    
    # We are reshaping to [N, H, W, C] because the final torch tensor needs to be [N, C, H, W]
    data_reshaped = data_padded.reshape(-1, square_side, square_side, 1)
    
    # Expand the last dimension to three channels by repeating the data
    data_rgb = np.repeat(data_reshaped, 3, -1)
    
    return data_rgb

# Splitting the dataset
X_train, X_val, y_train, y_val = train_test_split(de_features.values, de_train[genes].values, test_size=0.2, random_state=42)

# Scaling the data
X_scaler = StandardScaler()
y_scaler = StandardScaler()
X_train_scaled = X_scaler.fit_transform(X_train)
X_val_scaled = X_scaler.transform(X_val)
y_train_scaled = y_scaler.fit_transform(y_train)
y_val_scaled = y_scaler.transform(y_val)

# Convert scaled data to image format
X_train_img = convert_to_image_format(X_train_scaled)
X_val_img = convert_to_image_format(X_val_scaled)

# Convert to tensors and move to GPU
X_train_tensor = torch.tensor(X_train_img, dtype=torch.float32).permute(0, 3, 1, 2).cuda()
y_train_tensor = torch.tensor(y_train_scaled, dtype=torch.float32).cuda()
X_val_tensor = torch.tensor(X_val_img, dtype=torch.float32).permute(0, 3, 1, 2).cuda()
y_val_tensor = torch.tensor(y_val_scaled, dtype=torch.float32).cuda()

# Create DataLoaders
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Define the model, loss and optimizer
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, y_train.shape[1])
model = model.cuda()
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)  # Using AdamW with weight decay
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

# Training loop
best_val_loss = float('inf')
early_stopping_patience = 10
early_stopping_counter = 0
gradient_clip_value = 1  # Clip gradients to avoid exploding gradients

for epoch in range(num_epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)  # Gradient clipping
        optimizer.step()

    # Validation loop
    model.eval()
    valid_loss_squared_sum = 0  # For MRRMSE
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            valid_loss_squared_sum += loss.item() * inputs.size(0)  # Sum of squared errors
        valid_mrrmse = np.sqrt(valid_loss_squared_sum / len(val_loader.dataset))  # RMSE
        print(f'Epoch {epoch}, MRRMSE: {valid_mrrmse}')
        
        # Update the learning rate scheduler based on validation loss
        scheduler.step(valid_mrrmse)
        
        if valid_mrrmse < best_val_loss:
            best_val_loss = valid_mrrmse
            early_stopping_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')  # Save best model checkpoint
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

Epoch 0, MRRMSE: 18.794604215934484
Epoch 1, MRRMSE: 4.583258753562378
Epoch 2, MRRMSE: 1.441227093077733
Epoch 3, MRRMSE: 7.309315668640815
Epoch 4, MRRMSE: 1.0850921015826618
Epoch 5, MRRMSE: 1.2169897525992843
Epoch 6, MRRMSE: 0.9305677417385684
Epoch 7, MRRMSE: 0.808390068865227
Epoch 8, MRRMSE: 0.9577035474578622
Epoch 9, MRRMSE: 0.9910323621098723
Epoch 10, MRRMSE: 1.0299940503648881
Epoch 11, MRRMSE: 1.099165720088671
Epoch 12, MRRMSE: 1.1572294110550367
Epoch 13, MRRMSE: 1.9558584923273867
Epoch 00014: reducing learning rate of group 0 to 1.0000e-04.
Epoch 14, MRRMSE: 0.899345089746267
Epoch 15, MRRMSE: 0.920019208666931
Epoch 16, MRRMSE: 1.5214196454758373
Epoch 17, MRRMSE: 1.2153624899326663
Early stopping triggered
CPU times: user 12.7 s, sys: 1.79 s, total: 14.5 s
Wall time: 14.3 s


<All keys matched successfully>

In [None]:
%%time

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def convert_to_image_format(data, output_size=(224, 224)):
    square_side = int(math.ceil(math.sqrt(data.shape[1])))
    padding_size = square_side ** 2 - data.shape[1]
    
    data_padded = np.pad(data, ((0, 0), (0, padding_size)), 'constant', constant_values=0)
    
    # We are reshaping to [N, H, W, C] because the final torch tensor needs to be [N, C, H, W]
    data_reshaped = data_padded.reshape(-1, square_side, square_side, 1)
    
    # Expand the last dimension to three channels by repeating the data
    data_rgb = np.repeat(data_reshaped, 3, -1)
    
    return data_rgb

X_scaler = StandardScaler()
y_scaler = StandardScaler()

# Assuming de_features and genes are defined elsewhere in your script
X_scaled = X_scaler.fit_transform(de_features.values)
y_scaled = y_scaler.fit_transform(de_train[genes].values)

# Convert scaled data to image format
X_img = convert_to_image_format(X_scaled)

# Convert to tensors
X_tensor = torch.tensor(X_img, dtype=torch.float32).permute(0, 3, 1, 2)
y_tensor = torch.tensor(y_scaled, dtype=torch.float32)

# KFold setup
num_epochs = 20  # Define the number of epochs


X_tensor, cv_X_tensor, y_tensor, cv_y_tensor = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=6174)

kf = KFold(n_splits=5, shuffle=True, random_state=6174)
# # KFold Cross-Validation
for fold, (train_index, val_index) in enumerate(kf.split(X_tensor)):
    print(f'Fold {fold+1}')

    # Split data into training and validation sets
    X_train_tensor, y_train_tensor = X_tensor[train_index], y_tensor[train_index]
    X_val_tensor, y_val_tensor = X_tensor[val_index], y_tensor[val_index]
    
    # Move data to GPU
    X_train_tensor, y_train_tensor = X_train_tensor.cuda(), y_train_tensor.cuda()
    X_val_tensor, y_val_tensor = X_val_tensor.cuda(), y_val_tensor.cuda()

#     # DataLoader setup
#     train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
#     val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
#     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=32)

#     # Model, loss, optimizer setup
#     model = models.efficientnet_v2_l(pretrained=True)
#     num_ftrs = model.fc.in_features
#     model.fc = nn.Linear(num_ftrs, y_train_tensor.shape[1])
#     model = model.cuda()
#     criterion = nn.MSELoss()
#     optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

#     # Training loop for each fold
#     best_val_loss = float('inf')
#     early_stopping_patience = 10
#     early_stopping_counter = 0
#     gradient_clip_value = 1
    
#     for epoch in range(num_epochs):
#         model.train()
#         for inputs, targets in train_loader:
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)  # Gradient clipping
#             optimizer.step()

#         # Validation loop
#         model.eval()
#         valid_loss_squared_sum = 0  # For MRRMSE

#         with torch.no_grad():
#             for inputs, targets in val_loader:
#                 outputs = model(inputs)
#                 loss = criterion(outputs, targets)
#                 valid_loss_squared_sum += loss.item() * inputs.size(0)  # Sum of squared errors
#             valid_mrrmse = np.sqrt(valid_loss_squared_sum / len(val_loader.dataset))  # RMSE
#             print(f'Epoch {epoch}, MRRMSE: {valid_mrrmse}')

#             # Update the learning rate scheduler based on validation loss
#             scheduler.step(valid_mrrmse)

#             if valid_mrrmse < best_val_loss:
#                 best_val_loss = valid_mrrmse
#                 early_stopping_counter = 0
#                 torch.save(model.state_dict(), 'best_model.pth')  # Save best model checkpoint
#             else:
#                 early_stopping_counter += 1
#                 if early_stopping_counter >= early_stopping_patience:
#                     print("Early stopping triggered")
#                     break

#     # Load the best model
#     model.load_state_dict(torch.load(f'best_model_fold_{fold+1}.pth'))

In [33]:
def predict(test_features, model, X_scaler, y_scaler):
    # Preprocess the test features
    X_test_scaled = X_scaler.transform(test_features)
    X_test_img = convert_to_image_format(X_test_scaled)
    X_test_tensor = torch.tensor(X_test_img, dtype=torch.float32).permute(0, 3, 1, 2).cuda()
    
    # Prediction
    model.eval()
    with torch.no_grad():
        test_outputs = model(X_test_tensor)
        
    # Inverse transform the predictions
    test_predictions = test_outputs.cpu().numpy()
    test_predictions = y_scaler.inverse_transform(test_predictions)
    
    return test_predictions

# Load the best model (assuming it's already saved as 'best_model.pth')
model.load_state_dict(torch.load('best_model.pth'))

# Assuming test_features is your test dataset ready for prediction
test_predictions = predict(test_features, model, X_scaler, y_scaler)
id_map.loc[:, genes] = test_predictions
id_map = id_map.loc[:, genes]
id_map.to_csv('submission.csv')
id_map

Unnamed: 0,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,AAMDC,AAMP,AAR2,AARS,AARS2,...,ZSWIM5,ZSWIM6,ZSWIM7,ZSWIM8,ZSWIM9,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,0.595992,0.368350,0.069918,0.727424,1.306816,1.020369,-0.032189,0.541882,0.139095,0.078677,0.314860,0.370876,0.111040,0.530881,0.163642,...,0.795830,0.171355,0.273660,0.351811,0.142568,-0.026560,0.139775,0.034018,0.307761,0.609765,0.623012,0.170996,0.071483,-0.045470,-0.093183
1,1.635368,0.611189,1.960779,0.800416,11.345181,0.024531,0.188967,-1.058464,-0.677934,0.343438,2.606630,0.316648,0.807392,-1.663576,-1.104179,...,5.414925,2.199144,0.527810,1.591007,0.310127,2.109355,1.346576,1.205944,-1.407982,4.916261,1.752581,0.927014,-0.513901,-3.110033,-0.967385
2,0.320563,0.262383,-0.034656,0.525288,0.821121,0.283104,-0.108166,0.307785,0.060054,0.176851,0.210108,0.382287,0.004322,0.599775,-0.088725,...,0.158864,0.029954,0.326471,0.278313,0.009479,-0.071973,0.007866,-0.061243,0.100344,0.227242,0.359577,0.168527,0.053745,0.056484,-0.120918
3,0.305882,0.227779,-0.011662,0.602747,0.871729,0.292500,-0.117399,0.385910,0.053983,0.160335,0.268133,0.428944,-0.007553,0.439565,-0.091371,...,0.160886,0.000003,0.286655,0.335148,0.076813,-0.070176,0.045169,-0.127068,0.095681,0.232718,0.270746,0.101425,0.038231,0.069187,-0.152434
4,0.298770,0.224731,-0.063120,0.609347,0.916080,0.180063,-0.134084,0.330006,0.029818,0.204493,0.242071,0.449188,-0.039236,0.417514,-0.132584,...,0.110121,-0.036949,0.306736,0.307262,0.058896,-0.067635,0.023876,-0.157682,0.020833,0.193832,0.224724,0.091970,0.049406,0.090982,-0.157920
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.140937,0.091469,-0.062688,0.560157,0.940850,0.003877,-0.137158,0.281616,0.026833,0.217041,0.207183,0.473544,-0.073179,0.232911,-0.177768,...,0.086551,-0.110465,0.231061,0.262388,0.081440,-0.078205,-0.060714,-0.267922,0.026968,0.144325,0.096390,0.014342,-0.050969,0.074711,-0.241556
251,0.359244,0.242999,0.034653,0.679405,0.961106,0.481341,-0.090719,0.445232,0.076480,0.112944,0.221465,0.393391,0.030289,0.453603,0.014161,...,0.310344,0.058815,0.274973,0.339171,0.087351,-0.055950,0.076963,-0.089919,0.161142,0.342803,0.403051,0.113434,0.044976,0.017389,-0.134055
252,0.255392,0.251837,-0.046018,0.504110,0.686087,0.222824,-0.130978,0.284215,0.047566,0.198109,0.210456,0.394679,-0.016577,0.537228,-0.156593,...,0.080824,0.012590,0.332827,0.310022,0.004189,-0.089451,-0.008581,-0.059736,0.088191,0.179826,0.283482,0.159825,0.051516,0.068451,-0.135157
253,0.343447,0.234327,0.017031,0.575445,0.853222,0.413632,-0.108705,0.355444,0.061096,0.155817,0.260164,0.409428,0.017911,0.589433,-0.025048,...,0.078589,0.066642,0.333000,0.313724,0.026045,-0.057731,0.044345,-0.050040,0.085697,0.233773,0.375856,0.141162,0.088335,0.052728,-0.107084
