## Notebook for debugging ESCE

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import json

In [2]:
# use to prepare features, targets and covariates
from prepare_data import prepare_data
from generate_splits import write_splitfile
from fit_model import fit
from aggregate import aggregate

rule prepare_features_or_targets:

In [3]:
# Run configuration shared by the cells below.

# Interpolated into the result file names (targets, splits, fits).
target = "bmi-0"
covariate = "none"  # also passed on as the sampling type when writing the split
grid = "genetic"    # selects config/grids/<grid>.yaml for the model fit

# Split parameters handed to write_splitfile.
n_train = 92682
seed = 1

In [4]:
# Run prepare_data for the "whole-genetic" features of the ukbb dataset.
# The config is read inside a `with` block so the file handle is closed as soon
# as parsing is done (the inline open(...) previously left it open).
with open("/ritter/share/projects/lauraf/esce/config/config_debugging.yaml", "r") as config_file:
    custom_datasets = yaml.safe_load(config_file)['custom_datasets']

prepare_data(
    out_path="/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/features/whole-genetic_none.npy",
    dataset='ukbb',
    features_targets_covariates="features",  # or "targets" or "covariates"
    variant="whole-genetic",
    custom_datasets=custom_datasets,
)


__Reminder: Check if target is listed in "/ritter/share/projects/lauraf/esce/config/config_debugging.yaml"__

In [8]:

# Run prepare_data for the currently selected target of the ukbb dataset.
# The config is read inside a `with` block so the file handle is closed as soon
# as parsing is done (the inline open(...) previously left it open).
with open("/ritter/share/projects/lauraf/esce/config/config_debugging.yaml", "r") as config_file:
    custom_datasets = yaml.safe_load(config_file)['custom_datasets']

prepare_data(
    out_path=f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/targets/whole-{target}_none.npy",
    dataset='ukbb',
    features_targets_covariates="targets",  # or "features" or "covariates"
    variant=f"whole-{target}",
    custom_datasets=custom_datasets,
)


rule prepare_covariates:

In [9]:
# Run prepare_data for both covariate variants ("none" and "balanced").
# The config is parsed once inside a `with` block so the file handle is closed
# right away; previously each call opened the file inline and never closed it.
with open("/ritter/share/projects/lauraf/esce/config/config_debugging.yaml", "r") as config_file:
    custom_datasets = yaml.safe_load(config_file)['custom_datasets']

# The loop variable is deliberately not called `covariate`, so the run
# configuration defined earlier in the notebook is left untouched.
for covariate_variant in ("none", "balanced"):
    prepare_data(
        out_path=f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/covariates/{covariate_variant}.npy",
        dataset='ukbb',
        features_targets_covariates="covariates",  # or "features" or "targets"
        variant=covariate_variant,
        custom_datasets=custom_datasets,
    )



rule split:

In [12]:
# Sizing rules for the validation and test sets (both get the same size).
val_test_frac = .25   # fraction of n_train used as the starting size
val_test_max = 1000   # upper cap; a falsy value disables the cap
val_test_min = 100    # lower floor; a falsy value disables the floor

# Start from the fraction of the training size, then clamp step by step.
n_holdout = round(n_train * val_test_frac)
if val_test_max:
    n_holdout = min(n_holdout, val_test_max)
if val_test_min:
    n_holdout = max(n_holdout, val_test_min)
n_val = n_test = n_holdout

# Sanity check: every partition needs more than one sample.
assert n_train > 1 and n_val > 1 and n_test > 1

# Write the split file for the selected target / covariate / n_train / seed.
write_splitfile(
    features_path="/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/features/whole-genetic_none.npy",
    targets_path=f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/targets/whole-{target}_none.npy",
    split_path=f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/splits/whole-genetic_none_whole-{target}_none_{covariate}_{str(n_train)}_{str(seed)}.json",
    sampling_path=f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/covariates/{covariate}.npy",
    sampling_type=f"{covariate}",
    n_train=n_train,
    n_val=n_val,
    n_test=n_test,
    seed=seed,
    stratify=True,
)


In [13]:
# Read the split file back for inspection. The `with` block closes the file
# handle once the JSON has been parsed (it was previously left open).
with open(f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/splits/whole-genetic_none_whole-{target}_none_{covariate}_{str(n_train)}_{str(seed)}.json") as split_file:
    split_indices = json.load(split_file)

# Bare last expression so the notebook still displays the parsed split.
split_indices

{'idx_train': [367686,
  132633,
  128655,
  155523,
  475613,
  53196,
  462737,
  400874,
  235287,
  372855,
  192463,
  446950,
  459113,
  287758,
  44246,
  76763,
  287820,
  346618,
  264730,
  364485,
  223017,
  139529,
  28095,
  318363,
  313485,
  143680,
  350455,
  332364,
  54106,
  270054,
  245627,
  114709,
  25057,
  483937,
  351022,
  188946,
  203339,
  231089,
  123180,
  102846,
  109986,
  393581,
  109334,
  76884,
  163052,
  221716,
  26282,
  172439,
  496218,
  143592,
  196574,
  59591,
  61393,
  278029,
  416182,
  64407,
  135019,
  478826,
  334901,
  454178,
  307766,
  133137,
  375750,
  494917,
  88659,
  67252,
  29409,
  145014,
  110352,
  228595,
  211578,
  169382,
  417848,
  83250,
  264073,
  475074,
  4610,
  67340,
  158797,
  336422,
  73971,
  420137,
  227939,
  37431,
  345467,
  450739,
  161985,
  380413,
  64966,
  65082,
  243009,
  4740,
  335097,
  3714,
  71874,
  470972,
  59381,
  132667,
  327025,
  432662,
  173672,
  110

In [15]:
# Fit the model on the prepared features/targets using the split written above.
# NOTE(review): scores_path uses the "ridge-cls" directory and the
# "whole-genetics" prefix, while model_name is "ridge-reg" and every other path
# in this notebook uses "whole-genetic". Confirm whether that is intended; the
# cell that reads the scores back uses this exact path, so change both together.
fit_kwargs = {
    "features_path": "/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/features/whole-genetic_none.npy",
    "targets_path": f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/targets/whole-{target}_none.npy",
    "split_path": f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/splits/whole-genetic_none_whole-{target}_none_{covariate}_{str(n_train)}_{str(seed)}.json",
    "scores_path": f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/fits/ridge-cls/whole-genetics_none_whole-{target}_none_{covariate}_{str(n_train)}_{str(seed)}_{grid}.csv",
    "model_name": "ridge-reg",
    "grid_path": f"/ritter/share/projects/lauraf/thesis_laura_modalities/config/grids/{grid}.yaml",
    "existing_scores_path_list": [],
}
fit(**fit_kwargs)

In [16]:
fit = pd.read_csv(f"/ritter/share/projects/lauraf/thesis_laura_modalities/results/ukbb/fits/ridge-cls/whole-genetics_none_whole-{target}_none_{covariate}_{str(n_train)}_{str(seed)}_{grid}.csv")

In [17]:
fit

Unnamed: 0,r2_train,r2_val,r2_test,mae_train,mae_val,mae_test,mse_train,mse_val,mse_test,alpha,n,s
0,0.092861,0.00512,-0.014629,3.487732,3.704002,3.69415,20.943166,23.705894,22.72918,5000,92682,1
1,0.078959,0.025251,0.009289,3.509604,3.655748,3.648599,21.264122,23.226224,22.193375,40000,92682,1
2,0.0661,0.032372,0.01759,3.531936,3.636907,3.628429,21.561014,23.056545,22.007423,100000,92682,1
3,0.035606,0.027401,0.018659,3.588251,3.63739,3.620498,22.265035,23.174979,21.98348,500000,92682,1
4,0.028722,0.023733,0.016975,3.601154,3.643953,3.622855,22.423961,23.2624,22.021203,750000,92682,1
5,0.013799,0.01303,0.010804,3.629169,3.666341,3.632214,22.768483,23.517411,22.159445,2500000,92682,1
6,0.005096,0.004863,0.0046,3.645644,3.683846,3.645127,22.96942,23.712011,22.298406,10000000,92682,1
