In [1]:
%cd ~/SSMuLA

/disk2/fli/SSMuLA


In [2]:
%load_ext blackcellmagic
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [3]:
import pandas as pd

In [4]:

import holoviews as hv
from holoviews import dim


hv.extension("bokeh")

from SSMuLA.landscape_global import LIB_NAMES, TrpB_names
from SSMuLA.vis import (
    save_bokeh_hv,
    JSON_THEME,
    LIB_COLORS,
    one_decimal_x,
    one_decimal_y,
    fixmargins,
)
from SSMuLA.util import get_file_name, checkNgen_folder

hv.renderer("bokeh").theme = JSON_THEME


In [6]:
# Pairwise-epistasis summary table; the path is relative to the working
# directory set by the `%cd` cell at the top of the notebook.
epistasis_csv = "results/pairwise_epistasis_vis/none/scale2max.csv"
df = pd.read_csv(epistasis_csv)
df

Unnamed: 0,lib,summary_type,epistasis_type,value
0,DHFR,count,magnitude,5.380000e+05
1,DHFR,count,sign,3.690490e+05
2,DHFR,count,reciprocal sign,1.783300e+04
3,DHFR,fraction,magnitude,5.816958e-01
4,DHFR,fraction,sign,3.990228e-01
...,...,...,...,...
67,TrpB4,count,sign,5.249809e+06
68,TrpB4,count,reciprocal sign,1.322707e+06
69,TrpB4,fraction,magnitude,5.391454e-01
70,TrpB4,fraction,sign,3.681085e-01


In [7]:
# Show only the rows whose summary_type is "fraction".
df.loc[df["summary_type"].eq("fraction")]

Unnamed: 0,lib,summary_type,epistasis_type,value
3,DHFR,fraction,magnitude,0.581696
4,DHFR,fraction,sign,0.399023
5,DHFR,fraction,reciprocal sign,0.019281
9,GB1,fraction,magnitude,0.569806
10,GB1,fraction,sign,0.334981
11,GB1,fraction,reciprocal sign,0.095213
15,TrpB3A,fraction,magnitude,0.39678
16,TrpB3A,fraction,sign,0.380682
17,TrpB3A,fraction,reciprocal sign,0.222538
21,TrpB3B,fraction,magnitude,0.457143


In [6]:
from SSMuLA.pairwise_epistasis import EPISTASIS_TYPE

In [9]:
def hook(plot, element):
    """Set the x-axis factor order for the epistasis bar chart.

    Replaces ``x_range.factors`` on the object stored under
    ``plot.handles["plot"]`` with every (library, epistasis type) pair,
    with ``LIB_NAMES`` as the outer loop and ``EPISTASIS_TYPE`` as the inner.

    Args:
        plot: plot object passed in by HoloViews; modified in place.
        element: element being rendered (unused).
    """
    ordered_factors = []
    for lib in LIB_NAMES:
        for epistasis in EPISTASIS_TYPE:
            ordered_factors.append((lib, epistasis))
    plot.handles["plot"].x_range.factors = ordered_factors

# Bars keyed by (lib, epistasis_type), with the "value" column as bar height,
# using only the rows that hold fractions.
fraction_rows = df[df["summary_type"] == "fraction"]
bar_options = dict(
    width=1200,
    height=400,
    show_legend=True,
    legend_position="top",
    legend_offset=(0, 5),
    ylabel="Fraction",
    multi_level=False,
    title="Fraction of pairwise epistasis types",
    xlabel="Library",
    # `hook` runs last so its factor order is the one that sticks.
    hooks=[fixmargins, one_decimal_y, hook],
)
bars = hv.Bars(
    fraction_rows, kdims=["lib", "epistasis_type"], vdims="value"
).opts(**bar_options)
bars

In [7]:
import os
from glob import glob

In [11]:
# Source table: results/pairwise_epistasis_vis/none/scale2max.csv
# Make bar plots based on it and save them to the same directory.
def hook(plot, element):
    """Order the x-axis factors as (library, epistasis type) pairs.

    Args:
        plot: plot object passed in by HoloViews; ``handles["plot"]`` is modified.
        element: element being rendered (unused).
    """
    factors = []
    for lib in LIB_NAMES:
        factors.extend((lib, epistasis) for epistasis in EPISTASIS_TYPE)
    plot.handles["plot"].x_range.factors = factors

# Build the fraction-only Bars element, then pass it to save_bokeh_hv with
# the directory of the source CSV as plot_path.
fraction_bars = hv.Bars(
    df[df["summary_type"] == "fraction"],
    kdims=["lib", "epistasis_type"],
    vdims="value",
)
fraction_bar_opts = {
    "width": 1200,
    "height": 400,
    "show_legend": True,
    "legend_position": "top",
    "legend_offset": (0, 5),
    "ylabel": "Fraction",
    "multi_level": False,
    "title": "Fraction of pairwise epistasis types",
    "xlabel": "Library",
    "hooks": [fixmargins, one_decimal_y, hook],
}
save_bokeh_hv(
    fraction_bars.opts(**fraction_bar_opts),
    plot_name="scale2max",
    plot_path=os.path.join("results/pairwise_epistasis_vis", "none"),
)


In [5]:
import ast

In [14]:
# try zs summary
zs_sum_df = pd.read_csv("results/zs_sum/none/zs_stat_scale2max.csv")

# Long format: one row per (lib, n_mut, zs_type). At this point each "corr"
# cell is still a string (presumably a dict literal of metric -> value).
zs_sum_df_melt = pd.melt(
    zs_sum_df,
    id_vars=["lib", "n_mut"],
    value_vars=["Triad_score", "ev_score", "esm_score"],
    var_name="zs_type",
    value_name="corr",
)

# Parse every "corr" string into a Python object and spread it into columns.
corr_columns = zs_sum_df_melt["corr"].apply(ast.literal_eval).apply(pd.Series)

# Put the parsed metric columns next to the identifying columns.
df_expanded = pd.concat(
    [zs_sum_df_melt.drop(columns="corr"), corr_columns],
    axis=1,
)
df_expanded

Unnamed: 0,lib,n_mut,zs_type,rho,ndcg,rocauc
0,DHFR,all,Triad_score,-0.051567,0.668045,0.451771
1,GB1,all,Triad_score,0.314949,0.932063,0.715623
2,TrpB3A,all,Triad_score,0.046876,0.381388,0.659206
3,TrpB3B,all,Triad_score,0.021349,0.254069,0.633794
4,TrpB3C,all,Triad_score,0.024897,0.447985,0.610434
...,...,...,...,...,...,...
103,TrpB3F,double,esm_score,0.043033,0.493978,0.551364
104,TrpB3G,double,esm_score,0.004594,0.482683,0.505708
105,TrpB3H,double,esm_score,0.152528,0.393976,0.862815
106,TrpB3I,double,esm_score,0.191259,0.859007,0.610453


In [16]:
# Stack the three metric columns into (metric, value) pairs so that each
# metric can be selected with a simple row filter.
df_score = pd.melt(
    df_expanded,
    id_vars=["lib", "n_mut", "zs_type"],
    value_vars=["rho", "ndcg", "rocauc"],
    var_name="metric",
    value_name="value",
)
df_score

Unnamed: 0,lib,n_mut,zs_type,metric,value
0,DHFR,all,Triad_score,rho,-0.051567
1,GB1,all,Triad_score,rho,0.314949
2,TrpB3A,all,Triad_score,rho,0.046876
3,TrpB3B,all,Triad_score,rho,0.021349
4,TrpB3C,all,Triad_score,rho,0.024897
...,...,...,...,...,...
319,TrpB3F,double,esm_score,rocauc,0.551364
320,TrpB3G,double,esm_score,rocauc,0.505708
321,TrpB3H,double,esm_score,rocauc,0.862815
322,TrpB3I,double,esm_score,rocauc,0.610453


In [8]:
import holoviews as hv
hv.extension('bokeh')

In [9]:
from SSMuLA.zs_analysis import ZS_OPTS

In [30]:
def hook(plot, element):
    """Order the x-axis factors as (library, zero-shot score) pairs.

    NOTE(review): this rebinds the name ``hook`` that earlier cells defined
    for (library, epistasis type) pairs; re-running those earlier plotting
    cells after this one would pick up this version instead.

    Args:
        plot: plot object passed in by HoloViews; ``handles["plot"]`` is modified.
        element: element being rendered (unused).
    """
    factors = []
    for lib in LIB_NAMES:
        for zs in ZS_OPTS:
            factors.append((lib, zs))
    plot.handles["plot"].x_range.factors = factors

In [32]:
for metric in ["rho", "ndcg", "rocauc"]:

    # One chart per metric: bars keyed by (lib, zs_type), "value" as height.
    metric_rows = df_score[df_score["metric"] == metric]
    metric_bars = hv.Bars(metric_rows, kdims=["lib", "zs_type"], vdims="value")

    save_bokeh_hv(
        metric_bars.opts(
            width=1200,
            height=400,
            show_legend=True,
            legend_position="top",
            legend_offset=(0, 5),
            ylabel=f"{metric} correlation",
            multi_level=False,
            title=f"ZS fitness {metric} correlation",
            xlabel="Library",
            hooks=[fixmargins, one_decimal_y, hook],
        ),
        plot_name=f"zs_stat_scale2max-{metric}",
        plot_path=os.path.join("results/zs_sum", "none"),
    )


In [33]:
def noesmhook(plot, element):
    """Order the x-axis factors as (library, score) pairs, leaving out ESM.

    Only "Triad_score" and "ev_score" are paired with each entry of
    ``LIB_NAMES``.

    Args:
        plot: plot object passed in by HoloViews; ``handles["plot"]`` is modified.
        element: element being rendered (unused).
    """
    kept_scores = ["Triad_score", "ev_score"]
    factors = []
    for lib in LIB_NAMES:
        for zs in kept_scores:
            factors.append((lib, zs))
    plot.handles["plot"].x_range.factors = factors

for metric in ["rho", "ndcg", "rocauc"]:

    # Same chart as the full ZS plot, with the esm_score rows filtered out.
    is_metric = df_score["metric"] == metric
    not_esm = df_score["zs_type"] != "esm_score"
    noesm_bars = hv.Bars(
        df_score[is_metric & not_esm],
        kdims=["lib", "zs_type"],
        vdims="value",
    )

    save_bokeh_hv(
        noesm_bars.opts(
            width=1200,
            height=400,
            show_legend=True,
            legend_position="top",
            legend_offset=(0, 5),
            ylabel=f"{metric} correlation",
            multi_level=False,
            title=f"ZS fitness {metric} correlation",
            xlabel="Library",
            hooks=[fixmargins, one_decimal_y, noesmhook],
        ),
        plot_name=f"zs_stat_scale2max-{metric}-noesm",
        plot_path=os.path.join("results/zs_sum", "none"),
    )


In [39]:
# try de sim sum
# NOTE: this rebinds `df`, which earlier cells used for the epistasis table.
de_summary_csv = "results/simulations/DE-active/scale2max/all_landscape_de_summary.csv"
df = pd.read_csv(de_summary_csv)
df

Unnamed: 0,lib,de_type,mean_all,median_all,mean_top96,median_top96,mean_top384,median_top384,fraction_max
0,DHFR,single_step_DE,0.889922,0.857847,1.0,1.0,1.0,1.0,0.283568
1,DHFR,recomb_SSM,0.851463,0.847249,0.999315,1.0,0.949198,0.959943,0.090164
2,DHFR,top96_SSM,0.959305,1.0,1.0,1.0,1.0,1.0,0.632319
3,GB1,single_step_DE,0.571523,0.597319,1.0,1.0,1.0,1.0,0.026045
4,GB1,recomb_SSM,0.362927,0.37017,0.978695,1.0,0.887291,0.862211,0.002055
5,GB1,top96_SSM,0.611348,0.620935,1.0,1.0,1.0,1.0,0.02504
6,TrpB3A,single_step_DE,0.401312,0.200246,0.993309,1.0,,,0.254237
7,TrpB3A,recomb_SSM,0.401256,0.190336,,,,,0.220339
8,TrpB3A,top96_SSM,0.428916,0.230815,,,,,0.288136
9,TrpB3B,single_step_DE,0.271319,0.12797,0.294215,0.12797,,,0.166667


In [40]:
# (target column, fallback column) pairs, applied in order: the top96 columns
# are filled from the *_all columns first, so the top384 columns can then
# fall back on the already-filled top96 values.
fallback_chain = [
    ("mean_top96", "mean_all"),
    ("median_top96", "median_all"),
    ("mean_top384", "mean_top96"),
    ("median_top384", "median_top96"),
]
for target_col, fallback_col in fallback_chain:
    df[target_col] = df[target_col].fillna(df[fallback_col])

df

Unnamed: 0,lib,de_type,mean_all,median_all,mean_top96,median_top96,mean_top384,median_top384,fraction_max
0,DHFR,single_step_DE,0.889922,0.857847,1.0,1.0,1.0,1.0,0.283568
1,DHFR,recomb_SSM,0.851463,0.847249,0.999315,1.0,0.949198,0.959943,0.090164
2,DHFR,top96_SSM,0.959305,1.0,1.0,1.0,1.0,1.0,0.632319
3,GB1,single_step_DE,0.571523,0.597319,1.0,1.0,1.0,1.0,0.026045
4,GB1,recomb_SSM,0.362927,0.37017,0.978695,1.0,0.887291,0.862211,0.002055
5,GB1,top96_SSM,0.611348,0.620935,1.0,1.0,1.0,1.0,0.02504
6,TrpB3A,single_step_DE,0.401312,0.200246,0.993309,1.0,0.993309,1.0,0.254237
7,TrpB3A,recomb_SSM,0.401256,0.190336,0.401256,0.190336,0.401256,0.190336,0.220339
8,TrpB3A,top96_SSM,0.428916,0.230815,0.428916,0.230815,0.428916,0.230815,0.288136
9,TrpB3B,single_step_DE,0.271319,0.12797,0.294215,0.12797,0.294215,0.12797,0.166667


In [37]:
# Print the top-fitness row of every non-codon TrpB3 CSV whose maximum
# fitness is not exactly 1.
for csv_path in sorted(glob("data/TrpB/scale2max/TrpB3*.csv")):
    if "codon" in csv_path:
        continue
    trpb_df = pd.read_csv(csv_path)
    best_row = trpb_df.loc[trpb_df["fitness"].idxmax()]
    if best_row["fitness"] != 1:
        print(best_row)

In [None]:
# The original snippet indexed df['C'], but no table loaded into `df` in this
# notebook has a column named "C", so it raised KeyError. Point it at a column
# of the DE summary table instead; change `max_col` to inspect another metric.
max_col = "fraction_max"

# Find the index of the row with the max value in that column
max_index = df[max_col].idxmax()

# Retrieve the row with the max value in that column
max_row = df.loc[max_index]

In [41]:
# Summary column name -> human-readable description used in plot titles and
# y-axis labels. Insertion order is the order in which plots are generated.
de_metric_map = dict(
    [
        ("mean_all", "all simulations fitness mean"),
        ("median_all", "all simulations fitness median"),
        ("mean_top96", "top 96 simulations fitness mean"),
        ("median_top96", "top 96 simulations fitness median"),
        ("mean_top384", "top 384 simulations fitness mean"),
        ("median_top384", "top 384 simulations fitness median"),
        ("fraction_max", "fraction reached max fitness"),
    ]
)

In [44]:
def de_hook(plot, element):
    """Order the x-axis factors as (library, DE strategy) pairs.

    Each entry of ``LIB_NAMES`` is paired, in order, with "single_step_DE",
    "recomb_SSM" and "top96_SSM".

    Args:
        plot: plot object passed in by HoloViews; ``handles["plot"]`` is modified.
        element: element being rendered (unused).
    """
    de_types = ["single_step_DE", "recomb_SSM", "top96_SSM"]
    factors = []
    for lib in LIB_NAMES:
        for de in de_types:
            factors.append((lib, de))
    plot.handles["plot"].x_range.factors = factors

# All DE summary plots go to the same folder.
de_summary_dir = os.path.join("results/simulations/DE-active", "scale2max", "summary")

for metric, metric_dets in de_metric_map.items():

    # The title doubles as the plot name passed to save_bokeh_hv.
    title = f"DE from active variant {metric_dets}"
    de_bars = hv.Bars(df, kdims=["lib", "de_type"], vdims=metric)

    save_bokeh_hv(
        de_bars.opts(
            width=1200,
            height=400,
            show_legend=True,
            legend_position="top",
            legend_offset=(0, 5),
            ylabel=metric_dets.capitalize(),
            multi_level=False,
            title=title,
            xlabel="Library",
            hooks=[fixmargins, one_decimal_y, de_hook],
        ),
        plot_name=title,
        plot_path=de_summary_dir,
    )

In [9]:
# Combined MLDE results table from the older run directory (mlde_old).
mlde_old_csv = "results/mlde_old/vis/all_df.csv"
mlde_df = pd.read_csv(mlde_old_csv)
mlde_df

Unnamed: 0,encoding,model,n_sample,ft_lib,repeats,n_mut_cutoff,lib,zs,n_top,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds
0,esm2_t33_650M_UR50D-flatten_site,ridge,384,4000,0,all,DHFR,Triad_score,384,0.628731,0.169707,1.000000,0.476004,0.920996,0.308924,1.0,3270.0
1,esm2_t33_650M_UR50D-flatten_site,ridge,384,4000,0,all,DHFR,Triad_score,384,0.628112,0.174797,1.000000,0.423837,0.897278,0.325964,1.0,3270.0
2,esm2_t33_650M_UR50D-flatten_site,ridge,384,4000,0,all,DHFR,Triad_score,384,0.632948,0.163465,0.996537,0.496190,0.918804,0.341528,0.0,3270.0
3,esm2_t33_650M_UR50D-flatten_site,ridge,384,4000,1,all,DHFR,Triad_score,384,0.729294,0.173647,1.000000,0.511053,0.906804,0.320921,1.0,3270.0
4,esm2_t33_650M_UR50D-flatten_site,ridge,384,4000,1,all,DHFR,Triad_score,384,0.659609,0.168476,1.000000,0.446297,0.924130,0.346594,1.0,3270.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
245895,one-hot,ridge,384,77,47,single,TrpB4,none,96,0.744631,0.004855,0.739434,0.287961,0.978963,0.322322,0.0,
245896,one-hot,ridge,384,77,48,single,TrpB4,none,96,0.771967,0.016475,0.749107,0.300604,0.979479,0.325815,0.0,
245897,one-hot,ridge,384,77,48,single,TrpB4,none,96,0.697771,0.006323,0.749107,0.333802,0.980334,0.329226,0.0,
245898,one-hot,ridge,384,77,49,single,TrpB4,none,96,0.674717,0.013695,0.749107,0.343322,0.980304,0.325455,0.0,


In [15]:
# NOTE: this display option is global and stays in effect for later cells.
pd.set_option("display.max_rows", None)

# Rows with no zero-shot predictor, n_top of 96, and no mutation-count cutoff.
is_selected = (
    (mlde_df["zs"] == "none")
    & (mlde_df["n_top"] == 96)
    & (mlde_df["n_mut_cutoff"] == "all")
)
group_keys = ["lib", "encoding", "model"]
metric_cols = [
    "maxes_all",
    "means_all",
    "maxes",
    "means",
    "ndcgs",
    "rhos",
    "if_truemaxs",
    "truemax_inds",
]

# Mean of every metric column per (lib, encoding, model) combination.
mlde_df.loc[is_selected, group_keys + metric_cols].groupby(group_keys).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds
lib,encoding,model,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
DHFR,esm2_t33_650M_UR50D-flatten_site,boosting,0.773631,0.158778,0.977409,0.610694,0.940657,0.832849,0.64,
DHFR,esm2_t33_650M_UR50D-flatten_site,ridge,0.660442,0.158332,0.969651,0.616295,0.941632,0.380089,0.6,3270.0
DHFR,esm2_t33_650M_UR50D-mean_all,boosting,0.884329,0.158615,0.985464,0.587484,0.934488,0.801577,0.58,
DHFR,esm2_t33_650M_UR50D-mean_all,ridge,0.323114,0.158794,0.970751,0.525285,0.909034,0.637188,0.76,
DHFR,esm2_t33_650M_UR50D-mean_site,boosting,0.876711,0.158492,0.969474,0.550992,0.917471,0.413603,0.24,3270.0
DHFR,esm2_t33_650M_UR50D-mean_site,ridge,0.716095,0.158744,0.96293,0.559186,0.92239,0.762832,0.69,
DHFR,one-hot,boosting,0.880601,0.155112,0.964598,0.610213,0.942136,0.822785,0.3,
DHFR,one-hot,ridge,0.538627,0.158666,0.938048,0.605672,0.932893,0.784621,0.8,
GB1,esm2_t33_650M_UR50D-flatten_site,boosting,0.365111,0.009849,0.660743,0.199664,0.802327,0.334959,0.05,
GB1,esm2_t33_650M_UR50D-flatten_site,ridge,0.289781,0.009028,0.70865,0.27612,0.839122,0.357882,0.08,57022.0


In [10]:
# Rows that have a missing value in at least one column.
has_missing = mlde_df.isna().any(axis="columns")
nan_rows = mlde_df.loc[has_missing]
nan_rows

Unnamed: 0,encoding,model,n_sample,ft_lib,repeats,n_mut_cutoff,lib,zs,n_top,maxes,means,ndcgs,rhos_pearson,rhos_spearman,iftruemax
28800,esm2_t33_650M_UR50D-flatten_site,boosting,384,1141,0,double,DHFR,Triad_score,384,1.000000,0.552582,0.959451,0.813289,0.386621,
28801,esm2_t33_650M_UR50D-flatten_site,boosting,384,1141,0,double,DHFR,Triad_score,384,1.000000,0.563792,0.967116,0.829891,0.391226,
28802,esm2_t33_650M_UR50D-flatten_site,boosting,384,1141,0,double,DHFR,Triad_score,384,1.000000,0.564975,0.966861,0.858433,0.386016,
28803,esm2_t33_650M_UR50D-flatten_site,boosting,384,1141,0,double,DHFR,Triad_score,384,1.000000,0.545364,0.963988,0.813700,0.394607,
28804,esm2_t33_650M_UR50D-flatten_site,boosting,384,1141,0,double,DHFR,Triad_score,384,1.000000,0.538738,0.958187,0.858015,0.410142,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
167995,one-hot,ridge,384,77,47,single,TrpB4,ev_score,96,0.739434,0.287961,0.978963,0.322322,0.189986,
167996,one-hot,ridge,384,77,48,single,TrpB4,ev_score,96,0.749107,0.300604,0.979479,0.325815,0.188705,
167997,one-hot,ridge,384,77,48,single,TrpB4,ev_score,96,0.749107,0.333802,0.980334,0.329226,0.194473,
167998,one-hot,ridge,384,77,49,single,TrpB4,ev_score,96,0.749107,0.343322,0.980304,0.325455,0.190111,


In [5]:
# Combined MLDE results table from the current run directory.
mlde_csv = "results/mlde/vis/all_df.csv"
mlde_df = pd.read_csv(mlde_csv)
mlde_df

Unnamed: 0,encoding,model,n_sample,ft_lib,repeats,n_mut_cutoff,lib,zs,n_top,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds
0,one-hot,boosting,384,4000,0,all,DHFR,Triad_score,384,0.837839,0.160675,1.000000,0.450315,0.931561,0.512189,1.0,3270.0
1,one-hot,boosting,384,4000,0,all,DHFR,Triad_score,384,0.835335,0.163866,1.000000,0.443002,0.919135,0.520578,1.0,3270.0
2,one-hot,boosting,384,4000,0,all,DHFR,Triad_score,384,0.839829,0.157905,1.000000,0.425063,0.894345,0.528852,1.0,3270.0
3,one-hot,boosting,384,4000,0,all,DHFR,Triad_score,384,0.811491,0.171549,1.000000,0.479756,0.917335,0.499166,1.0,3270.0
4,one-hot,boosting,384,4000,0,all,DHFR,Triad_score,384,0.839540,0.174150,1.000000,0.413084,0.911455,0.507032,1.0,3270.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
119995,one-hot,ridge,384,77,97,single,TrpB4,none,96,0.771698,0.017581,0.749107,0.330329,0.979702,0.190245,0.0,2961.0
119996,one-hot,ridge,384,77,98,single,TrpB4,none,96,0.597178,0.060315,0.749107,0.361424,0.980491,0.190741,0.0,2961.0
119997,one-hot,ridge,384,77,98,single,TrpB4,none,96,0.634966,0.009730,0.752895,0.428029,0.981345,0.193477,0.0,2961.0
119998,one-hot,ridge,384,77,99,single,TrpB4,none,96,0.783760,0.040222,0.749107,0.284615,0.979312,0.188739,0.0,2961.0


In [41]:
# Reload the combined MLDE results table from the current run directory.
mlde_csv = "results/mlde/vis/all_df.csv"
mlde_df = pd.read_csv(mlde_csv)
mlde_df

Unnamed: 0,encoding,model,n_sample,ft_lib,rep,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds,n_mut_cutoff,lib,zs,n_top
0,esm2_t33_650M_UR50D-flatten_site,boosting,384,4000,0,0.840951,0.156936,1.000000,0.499483,0.930975,0.498705,1.0,267.0,all,DHFR,Triad_score,384
1,esm2_t33_650M_UR50D-flatten_site,boosting,384,4000,1,0.837710,0.159552,0.996537,0.500002,0.932816,0.480707,0.0,,all,DHFR,Triad_score,384
2,esm2_t33_650M_UR50D-flatten_site,boosting,384,4000,2,0.841871,0.158632,0.996537,0.464790,0.901309,0.464723,0.0,,all,DHFR,Triad_score,384
3,esm2_t33_650M_UR50D-flatten_site,boosting,384,4000,3,0.812338,0.180189,1.000000,0.474807,0.926957,0.420841,1.0,2.0,all,DHFR,Triad_score,384
4,esm2_t33_650M_UR50D-flatten_site,boosting,384,4000,4,0.841787,0.162283,1.000000,0.485674,0.923643,0.430931,1.0,274.0,all,DHFR,Triad_score,384
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
245995,one-hot,ridge,384,77,95,0.771698,0.017581,0.749107,0.330329,0.979702,0.190245,0.0,,single,TrpB4,none,96
245996,one-hot,ridge,384,77,96,0.597178,0.060315,0.749107,0.361424,0.980491,0.190741,0.0,,single,TrpB4,none,96
245997,one-hot,ridge,384,77,97,0.634966,0.009730,0.752895,0.428029,0.981345,0.193477,0.0,,single,TrpB4,none,96
245998,one-hot,ridge,384,77,98,0.783760,0.040222,0.749107,0.284615,0.979312,0.188739,0.0,,single,TrpB4,none,96


In [42]:
# Number of rows that use the one-hot encoding.
int(mlde_df["encoding"].eq("one-hot").sum())

98400

In [43]:
# One-hot / GB1 / boosting rows for replicate 0.
mlde_df.loc[
    mlde_df["encoding"].eq("one-hot")
    & mlde_df["lib"].eq("GB1")
    & mlde_df["model"].eq("boosting")
    & mlde_df["rep"].eq(0)
]

Unnamed: 0,encoding,model,n_sample,ft_lib,rep,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds,n_mut_cutoff,lib,zs,n_top
4800,one-hot,boosting,384,80000,0,0.460794,0.00945,0.796777,0.062503,0.801123,0.419881,0.0,,all,GB1,Triad_score,384
4900,one-hot,boosting,384,40000,0,0.592693,0.015428,0.862211,0.224275,0.817644,0.389161,0.0,,all,GB1,Triad_score,384
5000,one-hot,boosting,384,20000,0,0.490686,0.009391,1.0,0.258444,0.832224,0.420885,1.0,227.0,all,GB1,Triad_score,384
5400,one-hot,boosting,384,80000,0,0.460794,0.00945,0.796777,0.175173,0.801123,0.419881,0.0,,all,GB1,Triad_score,96
5500,one-hot,boosting,384,40000,0,0.592693,0.015428,0.689827,0.31451,0.817644,0.389161,0.0,,all,GB1,Triad_score,96
5600,one-hot,boosting,384,20000,0,0.490686,0.009391,0.68958,0.364982,0.832224,0.420885,0.0,,all,GB1,Triad_score,96
38600,one-hot,boosting,384,2168,0,0.515941,0.035169,0.735277,0.199068,0.801888,0.313289,0.0,,double,GB1,Triad_score,384
38800,one-hot,boosting,384,2168,0,0.515941,0.035169,0.614051,0.297802,0.801888,0.313289,0.0,,double,GB1,Triad_score,96
59600,one-hot,boosting,384,77,0,0.455008,0.139422,0.91819,0.121756,0.713787,0.170183,0.0,,single,GB1,Triad_score,384
59800,one-hot,boosting,384,77,0,0.455008,0.139422,0.620935,0.233584,0.713787,0.170183,0.0,,single,GB1,Triad_score,96


In [44]:
# Count of one-hot / GB1 / boosting rows for replicate 0.
int(
    (
        mlde_df["encoding"].eq("one-hot")
        & mlde_df["lib"].eq("GB1")
        & mlde_df["model"].eq("boosting")
        & mlde_df["rep"].eq(0)
    ).sum()
)

36

In [45]:
# Distinct values (NaN included) stored in the truemax_inds column.
mlde_df["truemax_inds"].unique()

array([267.,  nan,   2., 274., 226., 129.,   5.,  27., 130.,   1., 313.,
         9., 319., 174.,  42., 159.,  81., 102., 374., 245.,   3., 361.,
        17.,  52., 262.,  47., 306., 187.,  64.,  83.,   4., 106., 122.,
        10.,  59.,  63., 178.,  22., 276., 185., 183.,  11., 153., 266.,
       254.,  33.,  85., 241., 191.,  26.,  99.,  39., 307., 344., 309.,
       101.,  71.,  61.,  89.,  12.,  67., 116., 287., 281.,  60.,  70.,
        13.,  92., 234., 162., 232.,  72., 125., 273., 161.,  21.,  50.,
        62., 337.,  69., 145., 311., 222., 357., 329.,  29., 124., 110.,
        34.,  24., 115., 299., 237.,  43.,   6., 149., 214.,  36., 108.,
       301.,  93.,  14., 244.,  30.,  40., 168., 128., 150.,   8., 184.,
        37.,   7., 126., 182., 176., 236., 202.,  90., 136., 114.,  41.,
        75.,  48., 286., 261., 155., 193., 131.,  15., 160., 147.,  53.,
        51.,  18., 152.,  94.,  79., 339.,  95.,  19., 100., 239., 256.,
        57., 205., 227.,  82., 295., 138., 250., 20

In [27]:
import numpy as np

# Seeded generator so the toy array is identical on every run.
rng = np.random.default_rng(0)

# Generate the 6D random array of integers in [0, 100)
data = rng.integers(0, 100, size=(3, 2, 1, 3, 50, 96))

# For every position over the first five axes, find the first index along the
# last axis where the element is 99; positions with no 99 get NaN.
# argmax on a boolean array returns the first True, and any() separates
# "first hit at index 0" from "no hit at all".
is_target = data == 99
output_indices = np.where(
    is_target.any(axis=-1), is_target.argmax(axis=-1), np.nan
)

# Check the output shape: the last axis has been reduced away
output_indices.shape


(3, 2, 1, 3, 50)

In [14]:
# Rows that have a missing value in at least one column.
mlde_df.loc[mlde_df.isna().any(axis="columns")]

Unnamed: 0,encoding,model,n_sample,ft_lib,rep,maxes_all,means_all,maxes,means,ndcgs,rhos,if_truemaxs,truemax_inds,n_mut_cutoff,lib,zs,n_top
