In [1]:
import mlflow
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from biological_fuzzy_logic_networks.DREAM_analysis.model_analysis_utils import get_test_data_formatted

In [2]:
# Client for the MLflow tracking server that holds the experiment runs.
MLFLOW_TRACKING_URI = "http://localhost:5000"
client = mlflow.tracking.MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

In [3]:
# Fetch the runs of the "Subnetwork" experiment and flatten them into one
# row per run (run info + logged params + logged metrics).
experiment = client.get_experiment_by_name("Subnetwork")
if experiment is None:
    # get_experiment_by_name returns None for an unknown name; fail with a
    # clear message instead of an AttributeError on `.experiment_id`.
    raise ValueError('MLflow experiment "Subnetwork" was not found on the tracking server')
exp_id = experiment.experiment_id

# NOTE: at most 3500 runs are fetched; anything beyond that is silently left out.
runs = client.search_runs(experiment_ids=exp_id, max_results=3500)

# Later dict entries win on key collisions: metrics override params, which
# override run-info fields of the same name.
runs_df = pd.DataFrame(
    [{**dict(run.info), **run.data.params, **run.data.metrics} for run in runs]
)

# Drop the run bookkeeping columns that are not used in the analysis.
runs_df = runs_df.drop(columns=['artifact_uri', 'end_time', 'experiment_id', 'lifecycle_stage',
                                'run_uuid', 'start_time', 'status', 'user_id'])

In [None]:
# Keep only the runs that have a value for `test_mse_RB`
# (presumably the runs that completed test evaluation -- TODO confirm).
# `.copy()` makes `runs` an independent frame, so the column assignments in
# later cells do not act on a slice of `runs_df` (SettingWithCopyWarning).
# NOTE: from here on `runs` is a DataFrame, no longer the list of MLflow runs.
runs = runs_df[runs_df["test_mse_RB"].notna()].copy()

In [None]:
# List the columns whose dtype is not float64, i.e. the ones that may still need casting.
runs.dtypes.loc[lambda kinds: kinds != "float64"]

In [None]:
# Cast the hyperparameter columns to numeric types (MLflow returns logged
# params as strings). Each column is cast from ITSELF: the previous version
# filled `learning_rate` with the values of `n_epochs`.
# `astype` with a mapping returns a new frame, so nothing is written onto a
# slice of `runs_df`.
runs = runs.astype({"n_epochs": int, "batch_size": int, "learning_rate": float})

In [None]:
# Number of runs a parameter setting must have to be kept
# (presumably the number of cross-validation folds -- TODO confirm).
N_FOLDS = 5

# `param_number` is the part of `param_setting` before the first underscore.
runs = runs.assign(param_number=runs["param_setting"].str.split("_").str[0])

# Keep only the parameter settings that have exactly N_FOLDS runs.
fold_counts = runs.groupby("param_number").size()
all_cv_params = list(fold_counts[fold_counts == N_FOLDS].index)
sel_runs = runs[runs["param_number"].isin(all_cv_params)]

# Average every numeric column over the runs of each parameter setting.
# numeric_only=True skips the string-valued param columns, which otherwise
# make `mean()` raise on pandas >= 2.0.
avg_runs = (
    sel_runs.groupby("param_number")
    .mean(numeric_only=True)
    .reset_index(drop=False)
)

In [None]:
# Display the per-parameter-setting averages computed by the groupby/mean above.
avg_runs

In [None]:
# Train loss against best validation loss, one point per parameter setting.
ax = sns.scatterplot(
    x="best_val_loss",
    y="train_loss",
    hue="param_number",
    data=avg_runs,
    legend=False,
)
# Dashed grey line from the lower-left to the upper-right corner of the axes.
# NOTE(review): it is drawn in axes coordinates, so it coincides with y = x
# only when both axes have the same limits.
ax.plot((0, 1), (0, 1), ls="--", color="grey", transform=ax.transAxes)

In [None]:
# Validation loss of the averaged runs, grouped by number of epochs.
sns.barplot(x="n_epochs", y="valid_loss", data=avg_runs)

In [None]:
# Validation loss of the averaged runs, grouped by batch size.
sns.barplot(x="batch_size", y="valid_loss", data=avg_runs)

In [None]:
# Validation loss of the averaged runs, grouped by learning rate.
sns.barplot(x="learning_rate", y="valid_loss", data=avg_runs)

In [None]:
# Nodes for which a per-node R² was logged. The validation and test column
# names are both derived from this single list so the two sets cannot drift.
R2_NODES = [
    "cleavedCas", "AKT_S473", "AKT_T308", "AMPK", "CREB", "ERK12", "FAK",
    "GSK3B", "H3", "JNK", "MAPKAPK2", "MEK12", "MKK36", "MKK4", "p38",
    "p53", "p90RSK", "RB", "SMAD23", "SRC", "EGFR",
]

# Long format: one row per (param_number, node column) holding the R² value.
# The `node` column keeps the full column name, e.g. "val_r2_RB" / "test_r2_RB".
val_r2 = pd.melt(
    frame=avg_runs,
    id_vars="param_number",
    value_vars=[f"val_r2_{node}" for node in R2_NODES],
    var_name="node",
    value_name="val_r2",
)
test_r2 = pd.melt(
    frame=avg_runs,
    id_vars="param_number",
    value_vars=[f"test_r2_{node}" for node in R2_NODES],
    var_name="node",
    value_name="test_r2",
)

In [None]:
# Pair the validation R² of each node with the test R² of the SAME node.
# Merging on `param_number` alone produced every (val node, test node)
# combination per parameter setting, so most points compared different nodes.
# `node_name` is the column name without its "val_r2_" / "test_r2_" prefix;
# the original `node` columns survive the merge as `node_x` / `node_y`.
val_keyed = val_r2.assign(node_name=val_r2["node"].str.replace("^val_r2_", "", regex=True))
test_keyed = test_r2.assign(node_name=test_r2["node"].str.replace("^test_r2_", "", regex=True))
temp = pd.merge(val_keyed, test_keyed, on=["param_number", "node_name"])

# Validation vs. test R², coloured by node.
ax = sns.scatterplot(data=temp, x="test_r2", y="val_r2", hue="node_x", legend=False)
# NOTE(review): the line is drawn in axes coordinates (corner to corner), so it
# equals y = x only when both axes share the same limits.
ax.plot([0, 1], [0, 1], transform=ax.transAxes, linestyle="--", color="grey")

In [None]:
# Pair the validation R² of each node with the test R² of the SAME node.
# Merging on `param_number` alone produced every (val node, test node)
# combination per parameter setting, so most points compared different nodes.
# `node_name` is the column name without its "val_r2_" / "test_r2_" prefix;
# the original `node` columns survive the merge as `node_x` / `node_y`.
val_keyed = val_r2.assign(node_name=val_r2["node"].str.replace("^val_r2_", "", regex=True))
test_keyed = test_r2.assign(node_name=test_r2["node"].str.replace("^test_r2_", "", regex=True))
temp = pd.merge(val_keyed, test_keyed, on=["param_number", "node_name"])

# Validation vs. test R², coloured by parameter setting.
ax = sns.scatterplot(data=temp, x="test_r2", y="val_r2", hue="param_number", legend=False)
# NOTE(review): the line is drawn in axes coordinates (corner to corner), so it
# equals y = x only when both axes share the same limits.
ax.plot([0, 1], [0, 1], transform=ax.transAxes, linestyle="--", color="grey")

In [None]:
# Validation R² per node column, aggregated over the parameter settings.
sns.barplot(x="node", y="val_r2", data=val_r2)
# Vertical tick labels; assigning the result keeps it out of the cell output.
t = plt.xticks(rotation=90)

In [None]:
# Test R² per node column, aggregated over the parameter settings.
sns.barplot(x="node", y="test_r2", data=test_r2)
# Vertical tick labels; assigning the result keeps it out of the cell output.
t = plt.xticks(rotation=90)

In [None]:
# Averaged runs ordered by validation loss, lowest first.
avg_runs.sort_values(by="valid_loss")

In [None]:
# Show the `test_cell_lines` value recorded for every run.
runs.loc[:, "test_cell_lines"]

In [None]:
# Parameter setting selected for the detailed analysis below.
# NOTE(review): hard-coded; presumably read off the table sorted by
# `valid_loss` -- confirm it is still the best setting when runs are added.
best_param = 1

In [None]:
# Collect the formatted test data of every fold of the selected parameter
# setting. This cell only gathers the per-fold results into lists; any
# averaging over folds has to happen afterwards.
data_folder = "/dccstor/ipc1/CAR/DREAM/"
# Run folders are named "<param_number>_<fold>"; the parameter number comes
# from `best_param` instead of a second hard-coded "1".
run_base = f"{data_folder}Model/Test/Subnetwork/{best_param}_"

test_outputs = []
test_unscaleds = []
for i in range(5):  # one run folder per fold, numbered 0-4
    print(i)
    run_folder = f"{run_base}{i}/"
    # Returns a pair per fold; NOTE(review): presumably the model output and
    # the unscaled test data -- confirm against get_test_data_formatted.
    test_output, test_unscaled = get_test_data_formatted(run_folder, data_folder)

    test_outputs.append(test_output)
    test_unscaleds.append(test_unscaled)

In [None]:
# Directory that holds the per-fold run folders read above:
# /dccstor/ipc1/CAR/DREAM/Model/Test/Subnetwork/
# (kept as a comment -- a bare path is not valid Python and made this cell fail)