In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report

# Path to data folder
directory = '../GATSol/dataset/'

# Load CSVs
df_train = pd.read_pickle(directory + 'eSol_train_blosum.pkl')
df_test = pd.read_pickle(directory + 'eSol_test_blosum.pkl')

continuous_column = 'solubility'  # Replace with your target column name
binary_column = 'binary_solubility'

# Identify scalar features
scalar_features = [
    col for col in df_train.columns
    if pd.api.types.is_numeric_dtype(df_train[col]) and col != continuous_column and col != binary_column
]

# Extract scalar features and target for regression and classification
X_train = df_train[scalar_features]
y_train_continuous = df_train[continuous_column]
y_train_binary = df_train[binary_column]
X_test = df_test[scalar_features]
y_test_continuous = df_test[continuous_column]
y_test_binary = df_test[binary_column]

# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [2]:
X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_train.columns)
X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X_test.columns)
print(X_train_scaled_df.shape)
X_train_scaled_df.head()


(2019, 67)


Unnamed: 0,molecular_weight,aromaticity,gravy,isoelectric_point,length,Hydrophobicity_ARGP820101-G1,Hydrophobicity_ARGP820101-G2,Hydrophobicity_ARGP820101-G3,Hydrophobicity_CASG920101-G1,Hydrophobicity_CASG920101-G2,...,mean_contacts_per_residue,solvent_exposed_fraction,N_atom_type_proportion,C_atom_type_proportion,O_atom_type_proportion,S_atom_type_proportion,polar_exposed_residue_proportion,nonpolar_exposed_residue_proportion,positive_exposed_residue_proportion,negative_exposed_residue_proportion
0,-1.319825,5.116634,4.85512,0.712178,-1.338477,-4.88284,-1.552149,6.088027,-5.171033,2.092709,...,-1.081553,-2.66963,-3.970462,6.26426,-4.504737,0.033367,-2.294576,3.959105,-1.893561,-2.649554
1,2.397372,0.468661,0.3134,1.588408,2.346673,-0.844988,0.240913,0.587177,-0.121754,0.173917,...,0.277,0.396178,-0.10195,0.916193,-0.989096,-0.255787,-0.840569,0.449496,0.49528,-0.368379
2,-0.362025,0.998023,-0.295374,0.082699,-0.394581,-0.497344,-0.433464,0.872151,-0.094455,-0.454403,...,0.37691,0.350622,0.746207,0.25697,-1.166814,0.999747,-0.757406,0.514899,0.484577,-0.457816
3,0.656888,-0.098094,0.154044,-0.652186,0.747137,0.560535,1.190676,-1.625646,-0.487905,1.264258,...,1.112271,0.747016,0.838893,-0.870924,0.436635,-0.19544,0.108377,0.757118,-0.497346,-0.677186
4,0.942892,0.237995,-0.116025,-0.692881,1.00148,0.782001,-0.04787,-0.703667,1.6e-05,-0.060356,...,0.680268,0.88629,0.018547,-0.49395,0.640156,-0.186341,0.916508,0.136768,-0.792905,-0.529838


In [3]:
test_embeddings = df_test[["gene", "embedding", "blosum62_embedding"]]
train_embeddings = df_train[["gene", "embedding", "blosum62_embedding"]]
train_embeddings

Unnamed: 0,gene,embedding,blosum62_embedding
0,aaeX,"[[0.025740903, -0.06068451, 0.04562243, -0.098...","[5, 4, 4, 6, 7, 4, 4, 4, 4, 6, 6, 4, 4, 6, 7, ..."
1,aas,"[[-0.037975986, -0.046647176, -0.014576814, 0....","[5, 4, 6, 4, 6, 6, 5, 6, 4, 9, 5, 4, 4, 7, 5, ..."
2,aat,"[[0.037574537, -0.042643685, 0.05279441, -0.02...","[5, 5, 4, 4, 5, 4, 4, 5, 8, 4, 4, 4, 6, 7, 4, ..."
3,abgA,"[[-0.00765643, -0.13189432, 0.014674729, -0.01...","[5, 5, 4, 4, 6, 5, 6, 4, 6, 4, 4, 4, 7, 5, 4, ..."
4,abgB,"[[0.0036755833, -0.15647556, -0.01813248, 0.04...","[5, 5, 5, 4, 7, 5, 6, 4, 6, 6, 4, 4, 5, 4, 6, ..."
...,...,...,...
2014,zapA,"[[-0.017177405, -0.01794975, 0.029380163, -0.0...","[5, 4, 4, 5, 7, 4, 6, 4, 5, 4, 6, 6, 5, 4, 4, ..."
2015,zntA,"[[-0.021385752, -0.07058593, 0.024256472, 0.03...","[5, 4, 5, 7, 6, 6, 8, 6, 5, 5, 4, 7, 5, 6, 4, ..."
2016,znuA,"[[-0.0033216071, -0.061825093, 0.004684513, 0....","[5, 4, 8, 5, 5, 5, 4, 4, 6, 4, 4, 4, 4, 4, 4, ..."
2017,zur,"[[0.035278082, -0.063037135, -0.0353482, 0.028...","[5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 5, 4, 5, 5, ..."


In [6]:
import numpy as np

# Flatten ESM embeddings
train_embeddings['flattened_embedding'] = train_embeddings['embedding'].apply(lambda x: np.array(x).flatten())
test_embeddings['flattened_embedding'] = test_embeddings['embedding'].apply(lambda x: np.array(x).flatten())

# Verify dimensions
print("Train flattened ESM embedding dimensions:", train_embeddings['flattened_embedding'].apply(len).unique())
print("Test flattened ESM embedding dimensions:", test_embeddings['flattened_embedding'].apply(len).unique())


Train flattened ESM embedding dimensions: [1280]
Test flattened ESM embedding dimensions: [1280]


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_embeddings['flattened_embedding'] = train_embeddings['embedding'].apply(lambda x: np.array(x).flatten())
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_embeddings['flattened_embedding'] = test_embeddings['embedding'].apply(lambda x: np.array(x).flatten())


In [7]:
# Extract BLOSUM embeddings
X_train_blosum = np.stack(train_embeddings['blosum62_embedding'])
X_test_blosum = np.stack(test_embeddings['blosum62_embedding'])

# Verify shapes
print("Train BLOSUM embeddings shape:", X_train_blosum.shape)  # (n_samples, 50)
print("Test BLOSUM embeddings shape:", X_test_blosum.shape)


Train BLOSUM embeddings shape: (2019, 50)
Test BLOSUM embeddings shape: (660, 50)


In [None]:
# Verify scalar feature shapes
print("Train scalar features shape:", X_train_scaled.shape)  # (n_samples, n_scalar_features)
print("Test scalar features shape:", X_test_scaled.shape)


Train scalar features shape: (2019, 67)
Test scalar features shape: (660, 67)


In [9]:
# Extract flattened ESM embeddings
X_train_esm = np.stack(train_embeddings['flattened_embedding'])
X_test_esm = np.stack(test_embeddings['flattened_embedding'])

# Combine scalar features, ESM embeddings, and BLOSUM embeddings
X_train_combined = np.hstack([X_train_scaled, X_train_esm, X_train_blosum])
X_test_combined = np.hstack([X_test_scaled, X_test_esm, X_test_blosum])

# Verify combined feature shapes
print("Combined train features shape:", X_train_combined.shape)
print("Combined test features shape:", X_test_combined.shape)


Combined train features shape: (2019, 1397)
Combined test features shape: (660, 1397)


In [11]:
from xgboost import XGBRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Initialize the XGBoost regressor
xgb_model = XGBRegressor(
    n_estimators=100,       # Number of trees
    learning_rate=0.1,      # Learning rate
    max_depth=5,            # Maximum depth
    random_state=42         # For reproducibility
)

# Train the model
xgb_model.fit(X_train_combined, y_train_continuous)

# Predict on the test set
y_pred = xgb_model.predict(X_test_combined)

# Evaluate the model
mse = mean_squared_error(y_test_continuous, y_pred)
r2 = r2_score(y_test_continuous, y_pred)
print(f"Mean Squared Error: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")


Mean Squared Error: 0.0524
R^2 Score: 0.4933


In [12]:
# Extract feature importance
feature_importance = xgb_model.feature_importances_

# Create feature names
scalar_feature_names = [f"Scalar_{i}" for i in range(X_train_scaled.shape[1])]
esm_feature_names = [f"ESM_{i}" for i in range(X_train_esm.shape[1])]
blosum_feature_names = [f"BLOSUM_{i}" for i in range(X_train_blosum.shape[1])]

feature_names = scalar_feature_names + esm_feature_names + blosum_feature_names

# Create a DataFrame for interpretability
importance_df = pd.DataFrame({
    "Feature": feature_names,
    "Importance": feature_importance
}).sort_values(by="Importance", ascending=False)

# Display the top 10 most important features
print("\nTop 10 Features by Importance:")
print(importance_df.head(10))



Top 10 Features by Importance:
       Feature  Importance
66   Scalar_66    0.034200
46   Scalar_46    0.020026
51   Scalar_51    0.015239
7     Scalar_7    0.011917
45   Scalar_45    0.010327
474    ESM_407    0.010240
4     Scalar_4    0.008914
61   Scalar_61    0.008854
429    ESM_362    0.008229
44   Scalar_44    0.006623
