In [81]:
%cd ~/SSMuLA

/disk2/fli/SSMuLA


In [5]:
%load_ext autoreload
%autoreload 2
%load_ext blackcellmagic

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic


In [8]:
from SSMuLA.get_corr import LANDSCAPE_ATTRIBUTES, val_list

In [72]:
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from matplotlib.colors import LinearSegmentedColormap, to_rgb
import matplotlib.pyplot as plt
import seaborn as sns

In [80]:
len(LANDSCAPE_ATTRIBUTES)

33

In [15]:
# Function to apply different gradients
def apply_gradient(s, colormap='YlGnBu', mse_colormap='coolwarm'):
    # Copy to avoid modifying the original data
    is_mse = s.name == 'mse'
    if is_mse:
        return [f'background-color: {x}' for x in sns.color_palette(mse_colormap, len(s))]
    else:
        return [f'background-color: {x}' for x in sns.color_palette(colormap, len(s))]


In [78]:
# Custom colormap for the MSE row, using greens
colors = ["#FFFFFF", "#9bbb59"]  # dark to light green
cmap_mse = LinearSegmentedColormap.from_list("mse_cmap_r", colors[::-1], N=100)

# def text_color(val):
#     rgb = to_rgb(cmap_mse(val))
#     # Perceived luminance formula: 0.299*R + 0.587*G + 0.114*B
#     luminance = 0.299*rgb[0] + 0.587*rgb[1] + 0.114*rgb[2]
#     return 'white' if luminance < 0.5 else 'black'

# Styling the DataFrame
def style_dataframe(df):
    # Define a function to apply gradient selectively
    def apply_gradient(row):
        if row.name == 'mse':
            # Generate colors for the MSE row based on its values
            norm = plt.Normalize(row.min(), row.max())
            rgba_colors = [cmap_mse(norm(value)) for value in row]
            return [f'background-color: rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {rgba[3]})' for rgba in rgba_colors]
        else:
            return [''] * len(row)  # No style for other rows
    
    # Apply gradient across all rows
    styled_df = df.style.background_gradient(cmap='Blues')
    # Apply the custom gradient to the MSE row
    styled_df = styled_df.apply(apply_gradient, axis=1)
    return styled_df.format("{:.2f}").apply(lambda x: ['color: black' if x.name == 'mse' else '' for _ in x], axis=1)


In [88]:
df = pd.read_csv("results/corr/384/boosting|ridge-top96/merge_all.csv")

# Load your dataset
# data = pd.read_csv('path_to_your_data.csv')

# Select features and targets
features = df[LANDSCAPE_ATTRIBUTES]
targets = df[val_list]

importance_df_list = []

# Splitting the dataset for each target and fitting a model
for target in targets.columns:
    X_train, X_test, y_train, y_test = train_test_split(features, targets[target], test_size=0.2, random_state=42)
    
    # Model initialization and training
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    
    # Prediction and performance evaluation
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    # print(f'Model for {target}, Mean Squared Error: {mse}')
    
    # Feature importance
    feature_importances = pd.DataFrame(model.feature_importances_,
                                       index = X_train.columns,
                                       columns=[target])
    feature_importances.loc["mse"] = mse
    importance_df_list.append(feature_importances)
importance_df = pd.concat(importance_df_list, axis=1)

In [79]:
style_dataframe(importance_df)

Unnamed: 0,single_step_DE_mean_all,single_step_DE_median_all,single_step_DE_mean_top96,single_step_DE_median_top96,single_step_DE_mean_top384,single_step_DE_median_top384,single_step_DE_fraction_max,recomb_SSM_mean_all,recomb_SSM_median_all,recomb_SSM_mean_top96,recomb_SSM_median_top96,recomb_SSM_mean_top384,recomb_SSM_median_top384,recomb_SSM_fraction_max,top96_SSM_mean_all,top96_SSM_median_all,top96_SSM_mean_top96,top96_SSM_median_top96,top96_SSM_mean_top384,top96_SSM_median_top384,top96_SSM_fraction_max,top_maxes,top_means,ndcgs,rhos,if_truemaxs,maxes_Triad,means_Triad,ndcgs_Triad,rhos_Triad,if_truemaxs_Triad,maxes_ev,means_ev,ndcgs_ev,rhos_ev,if_truemaxs_ev,maxes_esm,means_esm,ndcgs_esm,rhos_esm,if_truemaxs_esm,maxes_esmif,means_esmif,ndcgs_esmif,rhos_esmif,if_truemaxs_esmif,mlde_single_step_DE_delta,mlde_recomb_SSM_delta,mlde_top96_SSM_delta,Triad_single_step_DE_delta,Triad_recomb_SSM_delta,Triad_top96_SSM_delta,ev_single_step_DE_delta,ev_recomb_SSM_delta,ev_top96_SSM_delta,esm_single_step_DE_delta,esm_recomb_SSM_delta,esm_top96_SSM_delta,esmif_single_step_DE_delta,esmif_recomb_SSM_delta,esmif_top96_SSM_delta,delta_ft_mlde,delta_ft_de
n_site,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.03,0.0,0.0,0.0,0.0,0.01,0.0,0.0,0.0,0.0,0.04,0.0,0.0,0.01,0.0,0.04,0.0,0.0,0.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
numb_measured,0.04,0.02,0.01,0.01,0.0,0.01,0.17,0.25,0.19,0.0,0.01,0.0,0.02,0.03,0.11,0.01,0.02,0.02,0.01,0.02,0.25,0.2,0.02,0.02,0.01,0.24,0.13,0.01,0.02,0.01,0.34,0.16,0.01,0.02,0.02,0.48,0.33,0.02,0.02,0.01,0.29,0.15,0.01,0.03,0.02,0.49,0.05,0.01,0.01,0.0,0.03,0.01,0.01,0.0,0.02,0.01,0.0,0.01,0.01,0.01,0.01,0.02,0.0
percent_measured,0.08,0.12,0.11,0.09,0.07,0.09,0.01,0.03,0.08,0.1,0.11,0.14,0.13,0.1,0.07,0.11,0.15,0.15,0.15,0.15,0.06,0.08,0.03,0.04,0.0,0.04,0.14,0.02,0.08,0.01,0.01,0.11,0.03,0.05,0.01,0.01,0.08,0.01,0.02,0.02,0.01,0.12,0.03,0.04,0.04,0.01,0.02,0.01,0.02,0.03,0.02,0.01,0.01,0.0,0.01,0.01,0.01,0.01,0.02,0.02,0.02,0.07,0.01
numb_active,0.04,0.04,0.08,0.08,0.04,0.08,0.02,0.02,0.02,0.1,0.04,0.09,0.02,0.12,0.02,0.04,0.06,0.04,0.07,0.04,0.02,0.03,0.02,0.02,0.01,0.03,0.05,0.02,0.07,0.02,0.11,0.04,0.02,0.05,0.01,0.05,0.03,0.03,0.12,0.0,0.08,0.04,0.03,0.05,0.01,0.07,0.02,0.0,0.02,0.01,0.0,0.0,0.02,0.01,0.0,0.02,0.01,0.01,0.02,0.0,0.0,0.04,0.0
percent_active,0.03,0.02,0.05,0.04,0.03,0.04,0.01,0.06,0.03,0.07,0.04,0.05,0.03,0.05,0.01,0.02,0.02,0.03,0.02,0.03,0.01,0.0,0.06,0.02,0.05,0.01,0.03,0.06,0.0,0.05,0.0,0.02,0.06,0.03,0.06,0.01,0.03,0.04,0.03,0.03,0.0,0.04,0.03,0.01,0.03,0.0,0.01,0.0,0.01,0.02,0.01,0.0,0.02,0.0,0.01,0.02,0.01,0.02,0.02,0.01,0.01,0.01,0.01
active_fit_min,0.02,0.02,0.01,0.0,0.0,0.0,0.07,0.02,0.02,0.01,0.01,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.07,0.0,0.03,0.03,0.02,0.07,0.01,0.04,0.01,0.01,0.01,0.04,0.03,0.01,0.02,0.01,0.02,0.03,0.01,0.05,0.02,0.03,0.03,0.04,0.01,0.0,0.0,0.01,0.01,0.01,0.01,0.02,0.01,0.01,0.02,0.01,0.01,0.01,0.01,0.01,0.01,0.03,0.01
parent_fit,0.03,0.06,0.01,0.0,0.0,0.0,0.02,0.01,0.01,0.0,0.02,0.01,0.06,0.01,0.02,0.02,0.02,0.01,0.03,0.01,0.01,0.01,0.01,0.05,0.01,0.02,0.01,0.02,0.08,0.0,0.02,0.0,0.01,0.04,0.0,0.03,0.01,0.01,0.03,0.0,0.02,0.01,0.02,0.08,0.0,0.02,0.04,0.09,0.01,0.02,0.04,0.0,0.01,0.04,0.0,0.0,0.04,0.0,0.0,0.01,0.0,0.02,0.03
parent_rank,0.08,0.12,0.0,0.01,0.05,0.01,0.02,0.04,0.04,0.02,0.15,0.07,0.15,0.05,0.12,0.15,0.16,0.16,0.15,0.16,0.11,0.08,0.01,0.09,0.0,0.05,0.01,0.01,0.18,0.01,0.01,0.01,0.02,0.07,0.0,0.01,0.01,0.02,0.11,0.01,0.01,0.0,0.01,0.09,0.02,0.02,0.0,0.0,0.01,0.01,0.01,0.03,0.0,0.01,0.03,0.0,0.01,0.03,0.0,0.01,0.03,0.11,0.01
parent_rank_percent,0.06,0.04,0.02,0.02,0.03,0.02,0.01,0.02,0.05,0.01,0.08,0.02,0.05,0.03,0.06,0.08,0.1,0.06,0.09,0.06,0.02,0.05,0.01,0.06,0.0,0.03,0.01,0.0,0.06,0.01,0.01,0.02,0.01,0.0,0.0,0.01,0.01,0.01,0.05,0.0,0.0,0.02,0.01,0.05,0.0,0.02,0.0,0.01,0.0,0.01,0.02,0.02,0.01,0.02,0.02,0.01,0.01,0.01,0.01,0.01,0.02,0.05,0.01
mean,0.03,0.03,0.08,0.07,0.0,0.07,0.0,0.04,0.02,0.1,0.03,0.09,0.03,0.03,0.03,0.04,0.03,0.01,0.03,0.01,0.0,0.03,0.05,0.03,0.02,0.0,0.07,0.07,0.0,0.06,0.01,0.02,0.03,0.03,0.05,0.01,0.05,0.05,0.01,0.06,0.01,0.04,0.05,0.0,0.05,0.02,0.0,0.01,0.01,0.02,0.02,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.02,0.01,0.02,0.02,0.02


In [89]:
X_train, X_test

(    n_site  numb_measured  percent_measured  numb_active  percent_active  \
 12     3.0         7784.0         97.300000       2494.0       32.040082   
 5      3.0         7996.0         99.950000         18.0        0.225113   
 8      3.0         7964.0         99.550000        161.0        2.021597   
 2      3.0         7882.0         98.525000       7248.0       91.956356   
 1      3.0         7882.0         98.525000       6533.0       82.885055   
 13     4.0       159129.0         99.455625       9783.0        6.147842   
 4      3.0         7971.0         99.637500         59.0        0.740183   
 7      3.0         7763.0         97.037500        719.0        9.261883   
 10     3.0         7891.0         98.637500        108.0        1.368648   
 3      4.0       149361.0         93.350625      34545.0       23.128528   
 6      3.0         7994.0         99.925000         35.0        0.437828   
 
     active_fit_min  parent_fit  parent_rank  parent_rank_percent      mea

In [90]:
from sklearn.linear_model import LinearRegression

# Train a linear regression model as the surrogate
surrogate_model = LinearRegression()
surrogate_model.fit(X_train, model.predict(X_train))  # Use predictions as the target


In [91]:
# Coefficients and intercept
coefficients = surrogate_model.coef_
intercept = surrogate_model.intercept_

# Create formula
formula = f"single_step_DE_mean_all = {intercept:.4f} + " + " + ".join(f"{coef:.4f}*{name}" for coef, name in zip(coefficients, LANDSCAPE_ATTRIBUTES))
print("Approximation Formula:\n", formula)


Approximation Formula:
 single_step_DE_mean_all = -2.5560 + -0.0000*n_site + -0.0000*numb_measured + 0.0377*percent_measured + 0.0003*numb_active + -0.0251*percent_active + -0.0129*active_fit_min + 0.0107*parent_fit + -0.0015*parent_rank + -0.0000*parent_rank_percent + 0.0026*mean + 0.0019*std + -0.0197*range + -0.0030*iqr + -0.0040*std_dev + -0.0004*variance + 0.0323*skewness + -0.0015*kurt + 0.0017*loc + -0.0020*scale + -0.0248*numb_kde_peak + 0.0044*Q1 + 0.0019*Q2 + 0.0014*Q3 + 0.0129*numb_loc_opt + 0.0000*frac_loc_opt_total + 0.0149*frac_loc_opt_hd2_escape_numb + -0.0149*frac_loc_opt_hd2_cannot_escape_numb + -0.1578*numb_loc_opt_norm_cannot_escape + -0.0000*frac_loc_opt_norm_cannot_escape + 0.0060*fraction_non-magnitude + -0.0011*fraction_reciprocal-sign + -0.0176*norm_non-magnitude + -0.0072*norm_reciprocal-sign


In [92]:
import numpy as np
from sklearn.metrics import r2_score

# Predictions from both models
y_pred_complex = model.predict(X_test)
y_pred_surrogate = surrogate_model.predict(X_test)

# Calculate R-squared score
r2_complex = r2_score(y_test, y_pred_complex)
r2_surrogate = r2_score(y_test, y_pred_surrogate)

print(f"R^2 score for complex model: {r2_complex:.4f}")
print(f"R^2 score for surrogate model: {r2_surrogate:.4f}")


R^2 score for complex model: 0.9224
R^2 score for surrogate model: -10.5631


In [82]:
zs_df = pd.read_csv("results/zs_sum/none/zs_stat_scale2max.csv")
zs_df

Unnamed: 0,lib,n_mut,scale_type,Triad_score,ev_score,esm_score,esmif_score,struc-comb_score,msanoif-comb_score,msa-comb_score,structnmsa-comb_score
0,DHFR,all,max,"{'rho': 0.017800602737355977, 'ndcg': 0.874706...","{'rho': 0.3396500541186067, 'ndcg': 0.94705812...","{'rho': 0.2997147334642927, 'ndcg': 0.93412887...","{'rho': 0.3420102237927515, 'ndcg': 0.94530250...","{'rho': 0.21873645360229155, 'ndcg': 0.9010558...","{'rho': 0.3329620349497645, 'ndcg': 0.94020086...","{'rho': 0.35054254084284997, 'ndcg': 0.9449015...","{'rho': 0.3152181002051771, 'ndcg': 0.92293221..."
1,GB1,all,max,"{'rho': 0.2882819770014877, 'ndcg': 0.76404913...","{'rho': 0.18950853197890194, 'ndcg': 0.7215575...","{'rho': 0.08033543850645883, 'ndcg': 0.7016796...","{'rho': 0.29349648314848165, 'ndcg': 0.7536404...","{'rho': 0.3165424232001679, 'ndcg': 0.76648219...","{'rho': 0.14175871216515606, 'ndcg': 0.7115697...","{'rho': 0.21368724000610942, 'ndcg': 0.7341220...","{'rho': 0.2591091741511955, 'ndcg': 0.74743562..."
2,ParD2,all,max,"{'rho': 0.22903279275812125, 'ndcg': 0.9096686...","{'rho': 0.43950580411448875, 'ndcg': 0.9509634...","{'rho': 0.5086085705240099, 'ndcg': 0.96848213...","{'rho': 0.5474025505415586, 'ndcg': 0.96654986...","{'rho': 0.4307416392960829, 'ndcg': 0.94590129...","{'rho': 0.5060856131857661, 'ndcg': 0.96738107...","{'rho': 0.5578082401554308, 'ndcg': 0.97195573...","{'rho': 0.5462223690973439, 'ndcg': 0.96811505..."
3,ParD3,all,max,"{'rho': 0.2850094413250776, 'ndcg': 0.92655737...","{'rho': 0.46081667207091903, 'ndcg': 0.9712806...","{'rho': 0.5831577250662513, 'ndcg': 0.98204147...","{'rho': 0.6010278304483516, 'ndcg': 0.96236309...","{'rho': 0.48086112698498235, 'ndcg': 0.9456578...","{'rho': 0.5431052882445438, 'ndcg': 0.97931590...","{'rho': 0.6111206470597818, 'ndcg': 0.97596838...","{'rho': 0.6499730615911018, 'ndcg': 0.97073551..."
4,TrpB3A,all,max,"{'rho': -0.09805055619421527, 'ndcg': 0.981718...","{'rho': 0.1309686909154868, 'ndcg': 0.98910504...","{'rho': 0.1992294634923689, 'ndcg': 0.98892047...","{'rho': 0.1848187734802036, 'ndcg': 0.98972581...","{'rho': 0.04240529262441126, 'ndcg': 0.9850867...","{'rho': 0.16904314947804597, 'ndcg': 0.9893348...","{'rho': 0.1802074618735106, 'ndcg': 0.98954086...","{'rho': 0.11820848404960753, 'ndcg': 0.9865950..."
5,TrpB3B,all,max,"{'rho': -0.006996473076675202, 'ndcg': 0.98392...","{'rho': 0.07449505394200119, 'ndcg': 0.9885023...","{'rho': 0.053334641477870905, 'ndcg': 0.988071...","{'rho': 0.03670935228427115, 'ndcg': 0.9888879...","{'rho': 0.01586836931927659, 'ndcg': 0.9852379...","{'rho': 0.07031237885418128, 'ndcg': 0.9886116...","{'rho': 0.06065981334611515, 'ndcg': 0.9885747...","{'rho': 0.04993095890561982, 'ndcg': 0.9877549..."
6,TrpB3C,all,max,"{'rho': 0.05907320438643337, 'ndcg': 0.9846945...","{'rho': 0.11150958079771961, 'ndcg': 0.9843786...","{'rho': 0.06578349740417924, 'ndcg': 0.9844109...","{'rho': 0.1583678560333045, 'ndcg': 0.98520240...","{'rho': 0.12826386284923846, 'ndcg': 0.9852108...","{'rho': 0.09389268150482442, 'ndcg': 0.9844660...","{'rho': 0.12797819590603884, 'ndcg': 0.9850419...","{'rho': 0.11810772451050155, 'ndcg': 0.9850181..."
7,TrpB3D,all,max,"{'rho': 0.20064421832834228, 'ndcg': 0.9742951...","{'rho': 0.2295252291376624, 'ndcg': 0.98052225...","{'rho': 0.22118235783315962, 'ndcg': 0.9797853...","{'rho': 0.18619806055804317, 'ndcg': 0.9774962...","{'rho': 0.2110467820321757, 'ndcg': 0.97756330...","{'rho': 0.2312909551462471, 'ndcg': 0.98054282...","{'rho': 0.22367560909694612, 'ndcg': 0.9797003...","{'rho': 0.22995407699589082, 'ndcg': 0.9792312..."
8,TrpB3E,all,max,"{'rho': -0.018924718611103345, 'ndcg': 0.97891...","{'rho': 0.04401275888150764, 'ndcg': 0.9913618...","{'rho': 0.053774549927072084, 'ndcg': 0.988008...","{'rho': 0.001764806370509995, 'ndcg': 0.986639...","{'rho': -0.01242743092728458, 'ndcg': 0.980301...","{'rho': 0.050418892269663, 'ndcg': 0.990540469...","{'rho': 0.03718295498648975, 'ndcg': 0.9903196...","{'rho': 0.023717092226122675, 'ndcg': 0.983215..."
9,TrpB3F,all,max,"{'rho': 0.0529635806829533, 'ndcg': 0.98353252...","{'rho': 0.10184995204990809, 'ndcg': 0.9884003...","{'rho': 0.135039170949484, 'ndcg': 0.988355109...","{'rho': 0.1033826037236943, 'ndcg': 0.98749508...","{'rho': 0.07752773970749613, 'ndcg': 0.9890600...","{'rho': 0.1184841585649222, 'ndcg': 0.98852585...","{'rho': 0.11544324100515356, 'ndcg': 0.9885128...","{'rho': 0.09939019678891052, 'ndcg': 0.9891668..."
