# cross_val.py
# Runs cross_val over a fixed list of datasets and collects the returned
# scores per dataset.
import torch
from DGE_experiments import cross_val
from DGE_data import get_real_and_synthetic
from DGE_utils import get_folder_names
# ---------------------------------------------------------------------------
# Experiment configuration (module-level so the values stay inspectable after
# the script has run).
# ---------------------------------------------------------------------------
num_runs = 10               # multiplies n_models when requesting synthetic data
model_type = 'deepish_mlp'  # forwarded to cross_val as task_type
model_name = 'ctgan_deep'   # passed to get_folder_names / get_real_and_synthetic
nsyn = 5000                 # presumably synthetic samples per generated set — TODO confirm
max_n = 5000                # presumably a cap on real samples used — TODO confirm
p_train = 0.8               # presumably the training fraction of real data — TODO confirm
n_models = 20               # scaled by num_runs before being passed on (see loop)
cross_fold = 5              # forwarded to cross_val as cross_fold
load_syn = True             # forwarded to get_real_and_synthetic as load_syn
load = True                 # forwarded to cross_val as load
save = True                 # forwarded to cross_val as save
verbose = True              # forwarded to both data loading and cross_val

# Datasets evaluated, in this order.
datasets = ['moons', 'circles', 'breast_cancer', 'adult', 'covid', 'seer']

# Results keyed by dataset name, filled in by the loop below.
scores_s_all = {}
scores_r_all = {}

for dataset in datasets:
    workspace_folder, results_folder = get_folder_names(
        dataset, model_name, max_n=max_n, nsyn=nsyn)

    # Request n_models * num_runs synthetic sets. Presumably this gives each of
    # the num_runs repeats its own n_models; confirm against
    # get_real_and_synthetic / cross_val.
    X_gt, X_syns = get_real_and_synthetic(dataset=dataset,
                                          p_train=p_train,
                                          n_models=n_models * num_runs,
                                          model_name=model_name,
                                          load_syn=load_syn,
                                          verbose=verbose,
                                          max_n=max_n,
                                          nsyn=nsyn)
    print(f'Dataset {dataset}\n')

    scores_s, scores_r = cross_val(X_gt,
                                   X_syns,
                                   workspace_folder=workspace_folder,
                                   results_folder=results_folder,
                                   save=save,
                                   load=load,
                                   task_type=model_type,
                                   cross_fold=cross_fold,
                                   verbose=verbose)

    # NOTE(review): the _s / _r suffixes presumably denote synthetic vs real
    # scores; confirm against cross_val.
    scores_s_all[dataset] = scores_s
    scores_r_all[dataset] = scores_r