# Figures that summarize DDN and JGL results

In [1]:
import matplotlib.pyplot as plt
from ddn3_extra import plot_simulation as ps

%load_ext autoreload
%autoreload 2

In [3]:
from cycler import cycler

# Fixed colour order for every Axes created after this cell, so the same series
# position gets the same colour in all saved figures.
default_cycler = cycler(
    color=[
        "#1F77B4",  # blue
        "#FF7F0E",  # orange
        "#3A3A3A",  # dark grey
        "#5CCEB3",  # teal
    ]
)
plt.rc("axes", prop_cycle=default_cycler)

# Optional: larger text in all figures.
# import matplotlib
# matplotlib.rcParams.update({"font.size": 12})

Figure sets (one per simulation setting):
- 100 features, 200+200 samples
- 100 features, 50+50 samples
- 100, 200, and 400 features, 200+200 and 50+50 samples
- 100 features, 50+500 samples (unbalanced; compared with 275+275)

Each curve uses the lambda2 value that achieves the best performance. Four network types are compared.
- Fig. 1: ROC curves for the common network — 8 curves: 4 for DDN, 4 for JGL.
- Fig. 2: F1 score for the differential network.
- Fig. 3: F1 score for the overall network.

In [4]:
import os

# Root directory of the DDN/JGL simulation results; every
# `ps.collect_curves(res_name_dict, dat_dir)` call below reads from here.
# NOTE(review): the folder name appears to encode the lambda1/lambda2 grids
# that were swept — confirm.
dat_dir = "../../../x_output/ddn/ddn_jgl/l1_002-002-100_l2_000-0025-015/"

# Fail fast with the resolved path. A bare `os.path.isdir(...)` only displays
# `False`, which is easy to overlook before the loading cells run.
if not os.path.isdir(dat_dir):
    raise FileNotFoundError(
        f"DDN/JGL result directory not found: {os.path.abspath(dat_dir)}"
    )

## Graph types

200 + 200 samples

In [5]:
# 100 features, 200 + 200 samples, one entry per network type.
# Keys are curve labels; values name the result sets that `ps.collect_curves`
# looks up under `dat_dir` (presumably file/folder stems — TODO confirm).
res_name_dict = {
    "random": "res_ddn_jgl_random_n_20_n-node_100_200p200_group_0",
    "hub": "res_ddn_jgl_hub_n_20_n-node_100_200p200",
    "cluster": "res_ddn_jgl_cluster_n_20_n-node_100_200p200",
    "scale_free_1": "res_ddn_jgl_scale-free_n_20_n-node_100_200p200",
    "scale_free_2": "res_ddn_jgl_scale-free-multi_n_20_n-node_100_200p200_group_2",
    "scale_free_4": "res_ddn_jgl_scale-free-multi_n_20_n-node_100_200p200_group_4",
}

curve_dict = ps.collect_curves(res_name_dict, dat_dir)

In [7]:
import os

# This is the first cell that saves a figure. `Figure.savefig` does not create
# missing directories, so make sure the output folder exists; otherwise a fresh
# checkout fails here on Restart & Run All.
os.makedirs("./figures", exist_ok=True)

# Common-network ROC, 100 features, 200 + 200 samples.
fig, ax = ps.draw_roc(curve_dict, scale_free=4)
fig.savefig("./figures/common_roc_node_100_sample_200.svg")
# Uncomment to also export a raster copy:
# fig.savefig("./figures/common_roc_node_100_sample_200.png")

In [10]:
# Differential-network F1 (x_type="f1_diff"), 100 features, 200 + 200 samples.
fig, ax = ps.draw_f1(curve_dict, x_type="f1_diff", scale_free=4)
ax.set(xlim=(0, 0.6), ylim=(0, 1.0))
fig.savefig("./figures/diff_f1_node_100_sample_200.svg")

In [11]:
# Overall F1 (x_type="f1_mean"), 100 features, 200 + 200 samples.
fig, ax = ps.draw_f1(curve_dict, x_type="f1_mean", scale_free=4)
ax.set(xlim=(0, 0.6), ylim=(0, 1.0))
fig.savefig("./figures/mean_f1_node_100_sample_200.svg")

50 + 50 samples

In [12]:
# 100 features, 50 + 50 samples, one entry per network type.
# NOTE(review): the `n_<k>` tag differs between entries (20 / 40 / 50);
# presumably the number of simulation replicates — confirm this is intended.
res_name_dict = {
    "random": "res_ddn_jgl_random_n_40_n-node_100_50p50_group_0",
    "hub": "res_ddn_jgl_hub_n_20_n-node_100_50p50",
    "cluster": "res_ddn_jgl_cluster_n_20_n-node_100_50p50",
    "scale_free_1": "res_ddn_jgl_scale-free_n_50_n-node_100_50p50",
    "scale_free_2": "res_ddn_jgl_scale-free-multi_n_40_n-node_100_50p50_group_2",
    "scale_free_4": "res_ddn_jgl_scale-free-multi_n_40_n-node_100_50p50_group_4",
}

curve_dict = ps.collect_curves(res_name_dict, dat_dir)

In [13]:
# 100 features, 50 + 50 samples: common-network ROC first, then the two F1 views.
fig, ax = ps.draw_roc(curve_dict, scale_free=4)
fig.savefig("./figures/common_roc_node_100_sample_50.svg")

# Each F1 view uses its own (tighter) y-range at this sample size.
for f1_kind, y_top, out_path in (
    ("f1_diff", 0.45, "./figures/diff_f1_node_100_sample_50.svg"),
    ("f1_mean", 0.65, "./figures/mean_f1_node_100_sample_50.svg"),
):
    fig, ax = ps.draw_f1(curve_dict, x_type=f1_kind, scale_free=4)
    ax.set(xlim=(0, 0.6), ylim=(0, y_top))
    fig.savefig(out_path)

## Number of features

200 + 200 samples

In [14]:
# Multi-group scale-free networks of increasing size at 200 + 200 samples:
# 100 / 200 / 400 features tagged group_4 / group_8 / group_16.
res_name_dict = {
    "scale_free_4": "res_ddn_jgl_scale-free-multi_n_20_n-node_100_200p200_group_4",
    "scale_free_8_200": "res_ddn_jgl_scale-free-multi_n_20_n-node_200_200p200_group_8",
    "scale_free_16_400": "res_ddn_jgl_scale-free-multi_n_20_n-node_400_200p200_group_16",
}
curve_dict = ps.collect_curves(res_name_dict, dat_dir)

In [15]:
# Feature-count comparison at 200 + 200 samples: common-network ROC, then F1 views.
fig, ax = ps.draw_roc_feature_num(curve_dict)
fig.savefig("./figures/common_roc_node_100_200_400_sample_200.svg")

for f1_kind, out_path in (
    ("f1_diff", "./figures/diff_f1_node_100_200_400_sample_200.svg"),
    ("f1_mean", "./figures/mean_f1_node_100_200_400_sample_200.svg"),
):
    fig, ax = ps.draw_f1_feature_num(curve_dict, x_type=f1_kind)
    ax.set(xlim=(0, 0.6), ylim=(0, 1))
    fig.savefig(out_path)

50 + 50 samples

In [16]:
# Multi-group scale-free networks of increasing size at 50 + 50 samples:
# 100 / 200 / 400 features tagged group_4 / group_8 / group_16.
res_name_dict = {
    "scale_free_4": "res_ddn_jgl_scale-free-multi_n_40_n-node_100_50p50_group_4",
    "scale_free_8_200": "res_ddn_jgl_scale-free-multi_n_40_n-node_200_50p50_group_8",
    "scale_free_16_400": "res_ddn_jgl_scale-free-multi_n_40_n-node_400_50p50_group_16",
}
curve_dict = ps.collect_curves(res_name_dict, dat_dir)

In [17]:
# Feature-count comparison at 50 + 50 samples: common-network ROC, then F1 views.
fig, ax = ps.draw_roc_feature_num(curve_dict)
fig.savefig("./figures/common_roc_node_100_200_400_sample_50.svg")

# Each F1 view uses its own (tighter) y-range at this sample size.
for f1_kind, y_top, out_path in (
    ("f1_diff", 0.45, "./figures/diff_f1_node_100_200_400_sample_50.svg"),
    ("f1_mean", 0.65, "./figures/mean_f1_node_100_200_400_sample_50.svg"),
):
    fig, ax = ps.draw_f1_feature_num(curve_dict, x_type=f1_kind)
    ax.set(xlim=(0, 0.6), ylim=(0, y_top))
    fig.savefig(out_path)

## Unbalanced samples

In [None]:
# Same total sample size (550), split evenly (275 + 275) vs. unbalanced (50 + 500).
res_name_dict = {
    "balanced": "res_ddn_jgl_scale-free-multi_n_20_n-node_100_275p275_group_2_jgl-weights_equal",
    "balanced_not": "res_ddn_jgl_scale-free-multi_n_20_n-node_100_50p500_group_2_jgl-weights_equal",
}
curve_dict = ps.collect_curves(res_name_dict, dat_dir)

In [None]:
# Balanced vs. unbalanced samples: common-network ROC, then the two F1 views.
fig, ax = ps.draw_roc_balance(curve_dict)
fig.savefig("./figures/common_roc_node_100_sample_balanced.svg")

for f1_kind, out_path in (
    ("f1_diff", "./figures/diff_f1_node_100_sample_balanced.svg"),
    ("f1_mean", "./figures/mean_f1_node_100_sample_balanced.svg"),
):
    fig, ax = ps.draw_f1_balance(curve_dict, x_type=f1_kind)
    ax.set(xlim=(0, 0.6), ylim=(0, 1))
    fig.savefig(out_path)