In [1]:
from itertools import product
from display import *
from util import load_experiments

# Output toggles for the final figure: set any of these to True to write the
# figure to the matching "svg"/"png"/"pdf" subdirectory.
save_svg = save_png = save_pdf = False

  self[key] = other[key]


## Loading data

In [2]:
# Dynamics to compare and the network type they were simulated on.
dynamics = ["sis", "plancksis"]
networks = "ba"

# Experiment summaries are looked up under `path` by name "exp-<dynamics>-<network>".
path = "../../data/case-study/summaries"
exp_names = dict((dyn, f"exp-{dyn}-{networks}") for dyn in dynamics)
exp = load_experiments(path, exp_names)

Did not find file `exp-sis-ba.zip`, kept proceding.
Did not find file `exp-plancksis-ba.zip`, kept proceding.


In [3]:
from dynalearn.experiments.metrics import LTPMetrics

# (in_state, out_state) pairs to plot. Position 0 is drawn in blue and position 1
# in red, matching the "Infection" / "Recovery" legend entries of the figure.
transitions = [(0, 1), (1, 0)]

# Per-curve styling: keyed by estimator ("true", "gnn", "mle"), each value is a
# list indexed by position in `transitions`.
colors = {
    source: [palette["blue"], palette["red"]]
    for source, palette in [("true", color_pale), ("gnn", color_dark), ("mle", color_dark)]
}
linestyles = {"true": ["-"] * 2, "gnn": ["--"] * 2, "mle": ["None"] * 2}
markers = {"true": ["None"] * 2, "gnn": ["None"] * 2, "mle": ["o", "^"]}


def ltp_plot(experiment, ax):
    """Plot ground-truth, GNN and MLE local transition probabilities on ``ax``.

    For every ``(in_state, out_state)`` pair in the module-level ``transitions``
    list, the ground-truth ("true") and GNN estimates are drawn as lines with a
    shaded error band, and the MLE estimate is drawn as error bars. Styling is
    read from the module-level ``colors``, ``linestyles`` and ``markers`` dicts.

    Parameters
    ----------
    experiment :
        Loaded experiment exposing ``metrics["TrueLTPMetrics"]``,
        ``metrics["GNNLTPMetrics"]`` and ``metrics["MLELTPMetrics"]``. Each
        metric's ``data`` dict must hold an ``"ltp"`` entry; the true-LTP metric
        must also hold ``"summaries"``.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``. Its x-limits are restricted to the x range shared by all
        transitions, and its y-limits are set to ``[0, 1.1]``.
    """
    metrics = experiment.metrics
    summary = metrics["TrueLTPMetrics"].data["summaries"]
    # Estimators drawn as line + shaded band, in drawing order.
    curve_ltps = {
        "true": metrics["TrueLTPMetrics"].data["ltp"],
        "gnn": metrics["GNNLTPMetrics"].data["ltp"],
    }
    mle_ltp = metrics["MLELTPMetrics"].data["ltp"]

    def agg(ltp, in_s, out_s):
        # Mean with percentile error bounds; unpacked below as (x, y, y_low, y_high).
        return LTPMetrics.aggregate(
            ltp,
            summary,
            in_state=in_s,
            out_state=out_s,
            axis=1,
            reduce="mean",
            err_reduce="percentile",
        )

    x_min, x_max = -np.inf, np.inf
    for i, (in_s, out_s) in enumerate(transitions):
        for source, ltp in curve_ltps.items():
            x, y, yl, yh = agg(ltp, in_s, out_s)
            ax.plot(
                x,
                y,
                color=colors[source][i],
                linestyle=linestyles[source][i],
                marker=markers[source][i],
                linewidth=3,
            )
            ax.fill_between(x, yl, yh, color=colors[source][i], alpha=0.3)

        x, y, yl, yh = agg(mle_ltp, in_s, out_s)
        # The mean need not lie inside the percentile interval, which would give
        # negative error lengths; matplotlib's errorbar rejects those, so clip at 0.
        yerr = np.clip(np.stack([y - yl, yh - y], axis=0), 0, None)
        ax.errorbar(
            x,
            y,
            yerr=yerr,
            color=colors["mle"][i],
            linestyle=linestyles["mle"][i],
            marker=markers["mle"][i],
            alpha=0.3,
        )

        # Keep only the x range covered by every transition (intersection).
        # Uses the MLE x values; assumes they match the true/GNN x grid (TODO confirm).
        x_min = max(x_min, x.min())
        x_max = min(x_max, x.max())
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([0, 1.1])
    return ax

## Making the plot: transition probabilities (GT vs. GNN vs. MLE)

In [None]:
# Figure 1: transition probabilities for the "sis" (simple) and "plancksis"
# (complex) dynamics, comparing ground truth, GNN and MLE estimates.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

# (dynamics key, panel tag, panel title) for each column of the figure.
panels = [
    ("sis", r"\textbf{(a)}", r"\textbf{Simple}"),
    ("plancksis", r"\textbf{(b)}", r"\textbf{Complex}"),
]
for i, (d, tag, title) in enumerate(panels):
    ltp_plot(exp[d], ax[i])
    ax[i].tick_params(axis="both", which="both", labelsize=small_fontsize)
    ax[i].set_xlabel(r"Number of infected neighbors $[\ell]$", fontsize=large_fontsize)
    ax[i].set_title(title, fontsize=large_fontsize)
    label_plot(ax[i], tag, loc="upper left")
ax[0].set_ylabel(r"Transition probability", fontsize=18)

# Legend proxies: grey lines/markers encode the estimator, coloured squares
# encode the transition. Text comes from `labels=` in the legend call.
gt_handle = Line2D([-1], [-1], linestyle="-", marker="None", linewidth=3,
                   color=color_pale["grey"])
gnn_handle = Line2D([-1], [-1], linestyle="--", marker="None", linewidth=3,
                    color=color_dark["grey"])
# MLE uses one marker per transition; HandlerTuple(ndivide=None) gives each its own slot.
mle_handle = tuple(
    Line2D([-1], [-1], linestyle="None", marker=m, markersize=5, markeredgewidth=1,
           markeredgecolor="k", color=color_dark["grey"], alpha=0.3)
    for m in ("o", "^")
)
transition_handles = [
    Line2D([-1], [-1], linestyle="None", marker="s", markersize=12, color=color_pale[c])
    for c in ("blue", "red")
]
handles = [gt_handle, gnn_handle, mle_handle, *transition_handles]

ax[1].legend(
    handles=handles,
    labels=[r"GT", r"GNN", r"MLE", "Infection", "Recovery"],
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc="upper right", fancybox=True, fontsize=14, framealpha=0.75, ncol=1,
)

# `pad` is keyword-only in current matplotlib; a positional `tight_layout(0.)` fails there.
fig.tight_layout(pad=0.0)

figname = "manuscript-figure1"
# Write the figure in each enabled format into a same-named subdirectory.
for fmt, enabled in (("png", save_png), ("pdf", save_pdf), ("svg", save_svg)):
    if enabled:
        os.makedirs(fmt, exist_ok=True)
        fig.savefig(os.path.join(fmt, f"{figname}.{fmt}"))