In [1]:
from mm79.train_modules.utils import get_ensemble_results
from mm79 import EXPERIMENT_DIR
from mm79.train_modules.utils import check_constraints, convert_results_to_latex
import os

  from .autonotebook import tqdm as notebook_tqdm


## Pretrain the model

Configure your sweep; see `example_sweep.yaml` for an example.

Run the sweep:

`python run_sweep.py --config_name=example_sweep`

Write down the name of your sweep; you will need it in the steps below.

You can evaluate the performance of your different models with:

`python precompute_AR_results.py`

Before running it, change the sweep name in that script to match the name of the pre-training sweep.

The cells below extract the results and select the best runs to fine-tune.

In [9]:
# Name of the pre-training sweep whose runs are evaluated in this cell.
sweep_name = "un80dv5dgm" #[insert sweep name]

# Values passed as `constraints`: the folds and bootstrap seeds to include
# (presumably used to filter the sweep's runs - see get_ensemble_results).
pretrain_constraints = {
    "fold": [0, 1, 2],
    "bootstrap_seed": [0, 1, 2, 3, 4],
}

# Values passed as `evaluation_params` for the evaluation itself.
pretrain_eval_params = {
    "t_cond": 6,
    "t_horizon": 6,
    "subgroup_strat": "myeloma-type",
}

df_ensemble = get_ensemble_results(
    sweep_name=sweep_name,
    constraints=pretrain_constraints,
    evaluation_params=pretrain_eval_params,
)

Evaluating on MM2 dataset...
Constraint
{'data_type': 'Custom', 'dropout_p': 0.1, 'emission_proba': False, 'fold': 0, 'gpus': 0, 'hidden_dim': 16, 'lambda_reg': 0.0, 'max_epochs': 1, 'nhead_treat': 0, 'nheads': 2, 'num_layers': 2, 'planned_treatments': True, 'restricted_input_features_set': False, 'restricted_pred_features_set': False, 't_cond': -1, 'use_rna': False, 'bootstrap_seed': 0}
Processing run /Users/edebrouwer/YaleLocal/MIT-D79-Repo/mm79/../experiments/logs/un80dv5dgm_Seq2Seq_transformer_Custom/version_1
Constraint
{'data_type': 'Custom', 'dropout_p': 0.1, 'emission_proba': False, 'fold': 0, 'gpus': 0, 'hidden_dim': 16, 'lambda_reg': 0.0, 'max_epochs': 1, 'nhead_treat': 0, 'nheads': 2, 'num_layers': 2, 'planned_treatments': True, 'restricted_input_features_set': False, 'restricted_pred_features_set': False, 't_cond': -1, 'use_rna': False, 'bootstrap_seed': 1}
Processing run /Users/edebrouwer/YaleLocal/MIT-D79-Repo/mm79/../experiments/logs/un80dv5dgm_Seq2Seq_transformer_Custom

In [10]:
# Choose best hyper-params:
# for every (fold, subgroup) pair keep the rows whose validation MSE equals
# the minimum validation MSE of that pair.
best_idx = df_ensemble.groupby(["fold","subgroup"])["val_mse"].transform("min") == df_ensemble["val_mse"]
df_ens_best = df_ensemble[best_idx]
# Only the rows of the "all" subgroup are used to pick the runs to fine-tune.
df_best_all = df_ens_best.loc[df_ens_best.subgroup=="all"]
print(df_best_all)

# Choose runs to fine-tune

# Result/metric columns are removed from each best row; what remains is used
# as the constraint dictionary handed to `check_constraints`.
cols_to_drop = ["val_mse","test_mse","val_auc","test_auc","subgroup",
                "val_mse_serum","test_mse_serum","val_mse_chem","test_mse_chem"]
cols_to_drop = cols_to_drop + [c for c in df_best_all.columns if "concordance" in c]

# The directory listing does not depend on the row being processed, so it is
# done once here instead of once per best row.
log_dir = os.path.join(EXPERIMENT_DIR, "logs")
exp_dirs = [os.path.join(log_dir, d)
            for d in os.listdir(log_dir) if sweep_name in d]
assert len(exp_dirs) == 1, (
    f"Expected exactly one experiment directory containing '{sweep_name}' "
    f"in {log_dir}, found {len(exp_dirs)}"
)

run_dirs = [os.path.join(exp_dir, d)
            for exp_dir in exp_dirs
            for d in os.listdir(exp_dir)]

versions = []
for idx in range(len(df_best_all)):
    constraint_ = dict(df_best_all.iloc[idx].drop(cols_to_drop))
    #constraint_["bootstrap_seed"] = [0,1,2,3,4]

    # os.path.basename (rather than splitting on "/") keeps this correct on
    # platforms whose path separator is not "/".
    versions += [os.path.basename(r)
                 for r in run_dirs if check_constraints(r, constraint_)]

for version in versions:
    print(f"- {version}")

  data_type  dropout_p  emission_proba  fold  gpus  hidden_dim  lambda_reg  \
0    Custom        0.1           False     0     0          16         0.0   
0    Custom        0.1           False     1     0          16         0.0   
0    Custom        0.1           False     2     0          16         0.0   

   max_epochs  nhead_treat  nheads  ...  test_concordance_ae_7  \
0           1            0       2  ...                    NaN   
0           1            0       2  ...                    NaN   
0           1            0       2  ...                    NaN   

   val_concordance_ae_8  test_concordance_ae_8  val_concordance_ae_9  \
0                   NaN                    NaN                   NaN   
0                   NaN                    NaN                   NaN   
0                   NaN                    NaN                   NaN   

   test_concordance_ae_9  val_concordance_ae_10  test_concordance_ae_10  \
0                    NaN                    NaN           

## Fine-tuning

Copy-paste the version names printed above and edit the `pretrained_sweep_name` in `example_fine_tune_sweep.yaml`.

Change `outcome` and `event_type` according to your use case (`pfs` for progression-free survival, `OS` for overall survival). Adverse events are fine-tuned by default.

Then run the fine-tuning sweep using:

`python run_sweep.py --config_name=example_fine_tune_sweep`

You can then precompute your results using

`python precompute_AR_results.py` (first change the sweep name in that script to the name of the fine-tuning sweep).

Then run the cells below to collect the final results!

In [3]:
sweep_name = "k5cq332uzw" # This should be the sweep name of the fine-tuning sweep
dataset_name = "MM2"
var_bin = None

cond_times = [6]  # The different condition times for the predictions.
horizons = [6]    # The different time horizons for the predictions.

# One result frame is collected per (t_cond, t_horizon) combination.
df_res = []
for t_cond in cond_times:
    for t_horizon in horizons:
        run_constraints = {
            "fold": [0, 1, 2],
            "early_stopping": [50],
        }
        eval_params = {
            "t_cond": t_cond,
            "t_horizon": t_horizon,
            "subgroup_strat": "myeloma-type",
            "var_bin": var_bin,
            "dataset_name": dataset_name,
        }
        df_ensemble = get_ensemble_results(
            sweep_name=sweep_name,
            constraints=run_constraints,
            evaluation_params=eval_params,
        )
        # Tag every row with the settings it was evaluated under, so the
        # frames can be concatenated and grouped afterwards.
        df_ensemble["t_cond"] = t_cond
        df_ensemble["t_horizon"] = t_horizon
        df_res.append(df_ensemble)

Evaluating on MM2 dataset...
Constraint
{'early_stopping': 50, 'emission_type': 'non_linear', 'emission_window': 1, 'event_type': 'pfs', 'include_baseline': True, 'include_last': True, 'outcome': 'pfs', 'fold': 0, 'bootstrap_seed': 0}
Processing run /Users/edebrouwer/YaleLocal/MIT-D79-Repo/mm79/../experiments/logs/k5cq332uzw_FineTune_Seq2Seq_transformer_Custom/version_0
Constraint
{'early_stopping': 50, 'emission_type': 'non_linear', 'emission_window': 1, 'event_type': 'pfs', 'include_baseline': True, 'include_last': True, 'outcome': 'pfs', 'fold': 0, 'bootstrap_seed': 1}
Processing run /Users/edebrouwer/YaleLocal/MIT-D79-Repo/mm79/../experiments/logs/k5cq332uzw_FineTune_Seq2Seq_transformer_Custom/version_5
Constraint
{'early_stopping': 50, 'emission_type': 'non_linear', 'emission_window': 1, 'event_type': 'pfs', 'include_baseline': True, 'include_last': True, 'outcome': 'pfs', 'fold': 0, 'bootstrap_seed': 2}
Processing run /Users/edebrouwer/YaleLocal/MIT-D79-Repo/mm79/../experiments/l

In [4]:
import pandas as pd

# Metrics to aggregate: the fixed list below plus every adverse-event
# concordance column found in the first result frame.
metrics = ["val_mse","test_mse","val_mse_serum","test_mse_serum","val_mse_chem","test_mse_chem","val_auc","test_auc","val_concordance_event","test_concordance_event"]
metrics = metrics + [c for c in df_res[0] if "concordance_ae" in c]
# Grouping keys: one output row is produced per unique combination of these.
protected_cols = ["emission_type","emission_window","subgroup","t_cond","t_horizon"]

# Concatenate and group once; mean and std are computed from the same groupby,
# so row i of `mu_df` and row i of `std_df` describe the same group.
grouped = pd.concat(df_res).groupby(protected_cols)[metrics]
mu_df = grouped.mean().reset_index()
std_df = grouped.std().reset_index()

aggregate_cols = [c for c in mu_df.columns if c not in protected_cols]

# Build every "$mean \pm std$" column in one pass instead of appending one-row
# frames in a loop (`DataFrame.append` no longer exists in pandas >= 2.0).
# The raw f-string keeps the backslash of "\pm" literal without relying on an
# invalid escape sequence.
formatted_cols = {
    c: [rf"${mu:.3f} \pm {sd:.3f}$" for mu, sd in zip(mu_df[c], std_df[c])]
    for c in aggregate_cols
}
df = pd.concat(
    [mu_df[protected_cols], pd.DataFrame(formatted_cols, index=mu_df.index)],
    axis=1,
)
df = df[protected_cols+aggregate_cols]

# escape=False so the LaTeX math written into the cells is emitted verbatim.
print(df.to_latex(escape = False, index = False))

\begin{tabular}{lrlrrllllllllllllllllllllllllllllllllll}
\toprule
emission_type &  emission_window & subgroup &  t_cond &  t_horizon &           val_mse &          test_mse &     val_mse_serum &    test_mse_serum &      val_mse_chem &     test_mse_chem &           val_auc &          test_auc & val_concordance_event & test_concordance_event & val_concordance_ae_0 & test_concordance_ae_0 & val_concordance_ae_1 & test_concordance_ae_1 & val_concordance_ae_2 & test_concordance_ae_2 & val_concordance_ae_3 & test_concordance_ae_3 & val_concordance_ae_4 & test_concordance_ae_4 & val_concordance_ae_5 & test_concordance_ae_5 & val_concordance_ae_6 & test_concordance_ae_6 & val_concordance_ae_7 & test_concordance_ae_7 & val_concordance_ae_8 & test_concordance_ae_8 & val_concordance_ae_9 & test_concordance_ae_9 & val_concordance_ae_10 & test_concordance_ae_10 & val_concordance_ae_11 & test_concordance_ae_11 \\
\midrule
   non_linear &                1 &      IGA &       6 &          6 & $1.002 \p

  df = df.append(pd.DataFrame(df_, index=[0]))
  df = df.append(pd.DataFrame(df_, index=[0]))
  df = df.append(pd.DataFrame(df_, index=[0]))
  print(df.to_latex(escape = False, index = False))
