In [1]:
from itertools import product
from display import *

# Output-format switches (all off). NOTE(review): none of the savefig calls in
# the figure cells of this notebook consult these flags — TODO wire up or drop.
save_svg = save_png = save_pdf = False

  self[key] = other[key]


## Loading data

In [2]:
import pickle
from dynalearn.experiments import Experiment

path_to_covid = "../../data/covid/"
lag = 5
bias = 0.5
models = ["DynamicsGATConv", "IndependentGNN", "FullyConnectedGNN"]


def path(m):
    """Path of the zipped RNN experiment summary for model ``m``."""
    return os.path.join(path_to_covid, f"summaries/exp-rnn-{m}-l{lag}-b{bias}.zip")


covid_exp = {m: Experiment.unzip(path(m), label_with_mode=False) for m in models}

# Kapoor2020 baseline experiment. Every later cell looks it up as
# "OG-KapoorConv", so store it under that key; "kapoor" is kept as an alias of
# the same object for any code still using the old key.
covid_exp["OG-KapoorConv"] = Experiment.unzip(
    os.path.join(path_to_covid, "summaries/exp-Kapoor2020.zip"),
    label_with_mode=False,
)
covid_exp["kapoor"] = covid_exp["OG-KapoorConv"]

Did not find file `exp-sis-ba.zip`, kept proceding.
Did not find file `exp-plancksis-ba.zip`, kept proceding.


## Additional functions

In [None]:
from dynalearn.utilities import color_dark, color_pale
from dynalearn.dynamics import VARDynamics

# Column index <-> province name lookups, following the key order of
# covid_data["cases"]. NOTE(review): covid_data is not defined in the cells
# shown here; presumably it is provided by `from display import *` — confirm.
index_prov = dict(enumerate(covid_data["cases"].keys()))
prov_index = {name: idx for idx, name in index_prov.items()}

def ridgeline_diff(
    dataref, 
    data, 
    x=None, 
    ax=None, 
    xticks=None, 
    overlap=0, 
    yshift=0., 
    ocolor=color_dark["blue"], 
    ucolor=color_dark["red"], 
    zorder=0, 
    alpha=0.1, 
    withline=True, 
    witharea=True,
    order=None,
    skip=None,
    linestyles=None
):
    """Ridgeline plot comparing reference curves with a second set of curves.

    Each key gets its own row with baseline ``i + yshift``, where ``i`` is the
    key's position in ``order`` (rows stack bottom to top). Both curves of a
    row are divided by their joint maximum and scaled by ``1 + overlap``.
    Where ``data`` lies above ``dataref`` the gap is filled with ``ocolor``,
    where it lies below with ``ucolor``, and the area under the lower of the
    two curves is filled grey.

    Parameters
    ----------
    dataref, data : dict
        Mappings with identical keys to 1-D array-likes of equal length.
    x : array-like, optional
        Abscissa shared by all rows; defaults to ``np.arange`` of the first
        drawn row's length.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    xticks : dict, optional
        Tick label -> tick position. The x ticks are left untouched when None.
    overlap : float
        Extra vertical scale so that rows may overlap their neighbours.
    yshift : float
        Offset added to every row baseline.
    ocolor, ucolor : colour
        Fill colours for ``data`` above / below ``dataref``.
    zorder : float
        Offset added to every row's z-order.
    alpha : float
        Transparency of the fills.
    withline, witharea : bool
        Whether to draw the two curves as lines / the filled bands.
    order : list, optional
        Keys in stacking order; defaults to ``dataref``'s key order.
    skip : list, optional
        Keys that get a y tick and label but no drawing (spacer rows). They do
        not need to be present in ``dataref``/``data``.
    linestyles : pair of str, optional
        Line styles of the ``dataref`` and ``data`` curves.

    Returns
    -------
    float or None
        Baseline of the last row, or None when ``order`` is empty.
    """
    if ax is None:
        ax = plt.gca()
    if order is None:
        order = list(dataref.keys())
    if skip is None:
        skip = []
    if linestyles is None:
        linestyles = ["-", "--"]
    assert dataref.keys() == data.keys()

    ys = []
    labels = []
    y = None  # baseline of the last row; stays None when `order` is empty
    for i, k in enumerate(order):
        y = i + yshift
        ys.append(y)
        labels.append(k)
        if k in skip:
            # Spacer row: reserve its tick and label but draw nothing.
            continue

        v = np.asarray(dataref[k], dtype=float)
        vv = np.asarray(data[k], dtype=float)
        # Both curves share one scale so that their difference is meaningful.
        valmax = max(v.max(), vv.max())
        if valmax == 0:
            # All-zero row: avoid 0/0 (NaN) and draw it flat at its baseline.
            d = np.zeros_like(v)
            dd = np.zeros_like(vv)
        else:
            d = v / valmax * (1. + overlap)
            dd = vv / valmax * (1. + overlap)
        if x is None:
            x = np.arange(len(d))

        # Upper / lower envelopes of the two curves. Where the curves coincide
        # both envelopes equal the curve, so neither coloured band has area.
        overshot = np.maximum(d, dd)
        undershot = np.minimum(d, dd)

        # Rows earlier in `order` (lower on the plot) are drawn on top.
        row_zorder = zorder + len(data) - i + 1
        if witharea:
            ax.fill_between(x, d + y, overshot + y, zorder=row_zorder, color=ocolor, alpha=alpha)
            ax.fill_between(x, undershot + y, d + y, zorder=row_zorder, color=ucolor, alpha=alpha)
            ax.fill_between(
                x,
                np.full(len(undershot), y, dtype=float),
                undershot + y,
                zorder=row_zorder,
                color=color_dark["grey"],
                alpha=alpha,
            )
        if withline:
            ax.plot(x, d + y, c="k", zorder=row_zorder, lw=1.2, ls=linestyles[0])
            ax.plot(x, dd + y, c="k", zorder=row_zorder, lw=1.2, ls=linestyles[1])

    ax.set_yticks(ys)
    ax.set_yticklabels(labels, fontsize=8)
    if xticks is not None:
        ax.set_xticks(list(xticks.values()))
        ax.set_xticklabels(list(xticks.keys()), rotation=45, fontsize=14)
    return y
            
def plot_timeseries(target, pred, ax=None, x=None, index_dict=None, palette=color_dark,
                    best_labels=None, worst_labels=None, xticks=None, num_nodes=52):
    """Ridgeline plot of target vs. predicted incidence, one row per province.

    Parameters
    ----------
    target, pred : np.ndarray
        Arrays indexed ``[time, province, ...]``; column ``i`` is labelled
        ``index_dict[i]``. Only entries with ``i < num_nodes`` are used, which
        still admits a ``-1`` spacer entry (labelled "..." by the caller).
    ax, x :
        Passed through to ``ridgeline_diff``.
    index_dict : dict, optional
        Column index -> row label; defaults to ``index_prov``.
    palette : dict
        Colour palette; "red" marks over-prediction and "blue" under-prediction.
    best_labels, worst_labels : list of str, optional
        Ranked province names; default to the globals ``best_l`` / ``worst_l``.
    xticks : dict, optional
        Tick label -> position; defaults to the global ``cumul_times``.
    num_nodes : int
        Number of provinces; also the rank of the worst one ("top 52", ...).

    When ``index_dict`` contains the "..." spacer, rows are stacked (bottom to
    top) as: worst provinces, the spacer, then best provinces in reverse rank;
    this needs labels of the form "<name> (top <rank>)". Without the spacer
    (e.g. the default ``index_prov``) rows simply follow ``index_dict`` order.

    Returns
    -------
    The value returned by ``ridgeline_diff``.
    """
    if index_dict is None:
        index_dict = index_prov
    if xticks is None:
        xticks = cumul_times

    kept = {i: k for i, k in index_dict.items() if i < num_nodes}
    target_dict = {k: target[:, i].squeeze() for i, k in kept.items()}
    pred_dict = {k: pred[:, i].squeeze() for i, k in kept.items()}

    if "..." in target_dict:
        if best_labels is None:
            best_labels = best_l
        if worst_labels is None:
            worst_labels = worst_l
        order = [l + rf" (top {num_nodes - i})" for i, l in enumerate(worst_labels)]
        order += ["..."]
        order += [l + rf" (top {i + 1})" for i, l in reversed(list(enumerate(best_labels)))]
    else:
        order = None  # plain index: keep index_dict order

    return ridgeline_diff(
        target_dict, 
        pred_dict, 
        ax=ax, 
        x=x, 
        xticks=xticks, 
        withline=False, 
        witharea=True, 
        ocolor=palette["red"], 
        ucolor=palette["blue"], 
        alpha=0.5,
        skip=[r"..."],
        order=order,
    )

def forecast(model, dataset, steps=7):
    """Roll ``model`` forward over every time step of ``dataset``.

    Every ``steps`` time steps the input window is reset to the observed
    input. In between, the window is shifted by one along its last (lag)
    axis and the newest slot ``x0[:, :, -1]`` is filled with the model's
    previous prediction. ``steps=1`` therefore gives plain one-step-ahead
    predictions.

    Parameters
    ----------
    model : object
        Anything with a ``predict(x)`` method returning one time step.
    dataset : object
        Exposes ``inputs[0].data``, ``targets[0].data`` and
        ``state_weights[0].data``. Each ``inputs[0].data[t]`` is assumed to
        be 3-D with the lag on its last axis (TODO confirm for all datasets).
    steps : int
        Forecast horizon, must be >= 1.

    Returns
    -------
    targets, pred : np.ndarray
        Observed targets and the forecasts, with identical shapes.

    Raises
    ------
    ValueError
        If ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Deliberately an all-ones mask: every time step is forecast regardless of
    # the dataset's own state weights (presumably so that in- and
    # out-of-sample periods are both kept).
    w = np.ones(dataset.state_weights[0].data.shape)
    inputs = dataset.inputs[0].data[w > 0]
    targets = dataset.targets[0].data[w > 0]

    pred = np.zeros(targets.shape)
    for t in range(pred.shape[0]):
        if t % steps == 0:
            x0 = inputs[t]
        else:
            # Slide the lag window by one step (last axis, not the node axis)
            # and append the latest forecast. np.roll returns a copy, so the
            # dataset inputs are never modified.
            x0 = np.roll(x0, -1, axis=-1)
            x0[:, :, -1] = pred[t - 1, :, :]
        pred[t] = model.predict(x0)
    return targets, pred

def get_best(target, pred, index_dict, n=1, best=True, num_nodes=52):
    """Rank provinces by standardized forecast error.

    Each column of ``target`` and ``pred`` is standardized with the target's
    per-column mean and standard deviation. The mean squared difference is
    then computed for every ``index_dict`` entry whose index is below
    ``num_nodes``.

    Parameters
    ----------
    target, pred : np.ndarray
        Arrays indexed ``[time, province, ...]``.
    index_dict : dict
        Column index -> label.
    n : int
        Number of labels to return (capped at the number of candidates).
    best : bool
        True returns the ``n`` smallest errors. False returns the ``n``
        largest, and the returned values are the negated errors.
    num_nodes : int
        Only indices below this value are ranked.

    Returns
    -------
    keep : list
        Selected labels, best-first (or worst-first when ``best=False``).
    val : list
        The corresponding criterion values.
    """
    m = target.mean(axis=0, keepdims=True)
    s = target.std(axis=0, keepdims=True)
    # A constant column has zero std; leave it unscaled instead of dividing by 0.
    s = np.where(s == 0, 1.0, s)
    z_target = (target - m) / s
    z_pred = (pred - m) / s

    sign = 1.0 if best else -1.0
    criterion = {
        k: sign * np.mean((z_target[:, i].squeeze() - z_pred[:, i].squeeze()) ** 2)
        for i, k in index_dict.items()
        if i < num_nodes
    }
    # Stable sort: ties keep index_dict order, like repeated min() did.
    keep = sorted(criterion, key=criterion.get)[:n]
    val = [criterion[k] for k in keep]
    return keep, val

## Post-processing

In [None]:
# Pull the models out of the experiments. The GNN, VAR and metapopulation
# models run on the DynamicsGATConv network; the Kapoor model keeps the
# network of its own experiment.
gat_exp = covid_exp["DynamicsGATConv"]
kapoor_exp = covid_exp["OG-KapoorConv"]

gnn = gat_exp.model
kapoor = kapoor_exp.model
var = VARDynamics(gat_exp.model.num_states, config=gat_exp.config.model)
metapop = gat_exp.dynamics
network = gat_exp.dataset.networks[0].data
gnn.network = network
var.network = network
kapoor.network = kapoor_exp.dataset.networks[0].data
metapop.network = network

# The VAR baseline is fitted on the positively weighted time steps only.
c = gat_exp.dataset.state_weights[0].data > 0
X = gat_exp.dataset.inputs[0].data
Y = gat_exp.dataset.targets[0].data

num_steps = 1
var.fit(X[c], Y=Y[c])

# Forecast with each model; Kapoor uses its own experiment's dataset, the
# others use the DynamicsGATConv dataset.
_forecast_jobs = [
    ("var", var, gat_exp),
    ("gnn", gnn, gat_exp),
    ("kapoor", kapoor, kapoor_exp),
]
yt, yp = {}, {}
for _name, _model, _exp in _forecast_jobs:
    yt[_name], yp[_name] = forecast(_model, _exp.dataset, steps=num_steps)
# Metapopulation baseline. At the start of every forecast window the model's
# latent state is re-seeded from compartments tracked with the observed
# incidence. Columns 0/1/2 are presumably S/I/R fractions (column 0 starts
# at 1) — confirm against the dynamics class.
yp["metapop"] = np.zeros(Y.shape)
metapop.latent_state = np.zeros((Y.shape[1], 3))
metapop.latent_state[:, 0] = 1
susceptible = np.ones(Y.shape[1])
infected = np.zeros(Y.shape[1])
recovered = np.zeros(Y.shape[1])
for i, x in enumerate(X):
    if i % num_steps == 0:
        metapop.latent_state[:, 0] = susceptible
        metapop.latent_state[:, 1] = infected
        metapop.latent_state[:, 2] = recovered
        x0 = x
    else:
        # Slide the lag window (last axis, not the node axis) and append the
        # previous prediction.
        x0 = np.roll(x0, -1, axis=-1)
        x0[:, :, -1] = y
    y = metapop.predict(x0)
    yp["metapop"][i] = y

    # Euler update of the tracked compartments from the observed incidence.
    # Recoveries use the pre-update `infected`, and new cases leave
    # `susceptible`, so susceptible + infected + recovered stays constant.
    new_cases = Y[i].squeeze() / metapop.population
    new_recoveries = metapop.recovery_prob * infected
    susceptible -= new_cases
    infected += new_cases - new_recoveries
    recovered += new_recoveries

    
# Five best and five worst provinces by the GNN's standardized error.
best_l, best_v = get_best(yt["gnn"], yp["gnn"], index_prov, n=5, best=True)
worst_l, worst_v = get_best(yt["gnn"], yp["gnn"], index_prov, n=5, best=False)

# Row labels for the ridgeline plot. Best provinces get ranks 1, 2, ... and
# worst ones 52, 51, ...; the -1 entry is a "..." spacer between the groups.
best_dict = {
    prov_index[name]: name + rf" (top {rank})"
    for rank, name in enumerate(best_l, start=1)
}
worst_dict = {
    prov_index[name]: name + rf" (top {52 - pos})"
    for pos, name in enumerate(worst_l)
}
index = {**best_dict, -1: r"...", **worst_dict}

## Making the plot

In [None]:

def get_dataset(key, dataset):
    """Return one data split of the experiment ``covid_exp[key]``.

    Parameters
    ----------
    key : str
        Experiment key in ``covid_exp``.
    dataset : str
        "train", "val" or "test".

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the three known split names.
    """
    split_attr = {
        "train": "dataset",
        "val": "val_dataset",
        "test": "test_dataset",
    }.get(dataset)
    if split_attr is None:
        raise ValueError(
            f"unknown dataset split {dataset!r}; expected 'train', 'val' or 'test'"
        )
    return getattr(covid_exp[key], split_attr)

def get_true_target(dataset, steps=1):
    """Observed targets of ``dataset`` at its positively weighted time steps.

    The first ``steps - 1`` entries are dropped, presumably to align with
    ``steps``-step-ahead forecasts.

    Raises
    ------
    ValueError
        If ``steps < 1``. Previously a non-positive value silently sliced from
        the end of the array.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    weights = dataset.state_weights[0].data
    targets = dataset.targets[0].data[weights > 0]
    return targets[steps - 1:]

def score(x, y):
    """Pearson correlation coefficient over all entries of ``x`` and ``y``.

    Both arrays are flattened first, so any shapes with equal size work.
    """
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    correlation = pearsonr(flat_x, flat_y)
    return correlation[0]


def _add_parity_inset(parent, rect, true, pred, line, color, label, nopoints, hide_xticks):
    """Add one inset parity plot (prediction vs. target) to ``parent``.

    Draws the identity ``line``. Unless ``nopoints`` is set, also draws the
    scatter of ``pred`` against ``true``. Returns ``(inset_axes, handle)``,
    where ``handle`` is a square legend marker carrying ``label`` when
    ``nopoints`` is set and None otherwise.
    """
    inset = add_subplot_axes(parent, rect)
    inset.plot(line, line, linestyle="--", lw=1, alpha=1, color=color_dark["grey"])
    handle = None
    if nopoints:
        handle = Line2D([0], [0], marker="s", markersize=12, color=color, linestyle="None", label=label)
    else:
        inset.scatter(true.flatten(), pred.flatten(), s=2, marker=".", color=color)
    inset.tick_params(labelsize=8)
    if hide_xticks:
        inset.set_xticks([])
    return inset, handle


# Inset models, top to bottom:
# (experiment key, metrics entry, colour, legend name, hide x ticks).
# VAR forecasts are stored in the DynamicsGATConv experiment's metrics, so they
# are scored against that experiment's targets.
INSET_MODELS = [
    ("OG-KapoorConv", "GNNForecastMetrics", color_dark["orange"], "KP-GNN", False),
    ("IndependentGNN", "GNNForecastMetrics", color_pale["red"], "IND", True),
    ("FullyConnectedGNN", "GNNForecastMetrics", color_dark["purple"], "FC", True),
    ("DynamicsGATConv", "VARForecastMetrics", color_dark["green"], "VAR", True),
]

# Two versions of figure 5ab: legend-only (saved as SVG) and with scatter
# points (saved as PNG).
for nopoints in [True, False]:
    y = 0.94   # vertical position of the top inset (parent-axes fraction)
    x = 0.81   # horizontal position of the insets
    dd = 0.25  # vertical spacing between insets
    s = 1      # forecast horizon of the plotted metrics
    fig, ax = plt.subplots(2, 1, figsize=(5, 8), sharex=True, sharey=True)
    handles = []
    for a, k in zip(ax, ["train", "test"]):
        # Main panel: GNN (DynamicsGATConv).
        true = get_true_target(get_dataset("DynamicsGATConv", k), steps=s)
        xmin, xmax = true.min(), true.max()
        line = np.linspace(xmin, xmax, 100)
        pred = covid_exp["DynamicsGATConv"].metrics["GNNForecastMetrics"].data[f"{k}-{s}"]
        r = np.round(score(true, pred), 4)

        a.plot(line, line, linestyle="--", lw=1, alpha=1, color=color_dark["grey"])
        label = r"\textbf{GNN} ($\boldsymbol{r = %6.4f}$)" % r
        if not nopoints:
            a.scatter(true.flatten(), pred.flatten(), s=2, marker=".", color=color_dark["blue"])
        else:
            handles = [
                Line2D([0], [0], marker="s", markersize=12, color=color_dark["blue"], linestyle="None", label=label)
            ]

        # Inset panels for the comparison models, stacked downwards.
        panels = [a]
        for j, (key, metric, color, name, hide_xticks) in enumerate(INSET_MODELS):
            m_true = get_true_target(get_dataset(key, k), steps=s)
            m_pred = covid_exp[key].metrics[metric].data[f"{k}-{s}"]
            m_r = np.round(score(m_true, m_pred), 4)
            inset, handle = _add_parity_inset(
                a,
                [x, y - j * dd, 0.28, 0.25],
                m_true,
                m_pred,
                line,
                color,
                rf"{name} ($r = {m_r}$)",
                nopoints,
                hide_xticks,
            )
            panels.append(inset)
            if handle is not None:
                handles.append(handle)

        # Shade the in-sample panel and give every axes the main panel's range.
        for panel in panels:
            if k == "train":
                panel.axvspan(0, 2 * xmax, alpha=0.15, color=color_pale["grey"])
            panel.set_xlim([xmin, xmax])
            panel.set_ylim([xmin, xmax])
        a.tick_params(labelsize=14)
        if nopoints:
            a.legend(handles=handles, loc="upper left", fontsize=12)

        y -= 0.15
    ax[0].set_title("Daily Forecasts (in-sample)", fontsize=18)
    ax[1].set_title("Daily Forecasts (out-of-sample)", fontsize=18)
    ax[0].set_ylabel(r"Predicted incidence [$\hat{y}_i(t)$]", fontsize=16)
    ax[1].set_ylabel(r"Predicted incidence [$\hat{y}_i(t)$]", fontsize=16)
    ax[1].set_xlabel(r"Target incidence [$y_i(t)$]", fontsize=16)

    label_plot(ax[0], r"(a)", (0.95, 0.02, "bottom", "right"), fontsize=16)
    label_plot(ax[1], r"(b)", (0.95, 0.02, "bottom", "right"), fontsize=16)
    plt.tight_layout()
    filename = "manuscript-figure5ab"
    if nopoints:
        fig.savefig(f"svg/{filename}.svg", dpi=150)
    else:
        fig.savefig(f"png/{filename}.png", dpi=150)
        

In [None]:
# Figure 5cd: per-province (c) and global (d) incidence over time.
# Time landmarks: the in-sample period ends at t = 335 and the plotted window
# ends at 450 - lag.
t_split = 335
t_end = 450 - lag
# Horizontal positions (axes fraction) of the "In-sample"/"Out-of-sample" captions.
in_sample_x = t_split / t_end * 0.94
out_sample_x = t_split / t_end + (t_end - t_split) / t_end * 0.75

fig, _ax = plt.subplots(2, 1, figsize=(16, 8), sharex=True)

# (c) Ridgeline of target vs. GNN prediction for the best/worst provinces.
ax = _ax[0]
plot_timeseries(yt["gnn"], yp["gnn"], ax=ax, index_dict=index, palette=color_dark)
ax.axvspan(0, t_split, alpha=0.15, color=color_pale["grey"])
ax.set_xlim([0, t_end])
ax.set_ylabel(r"Incidence", fontsize=18)
ax.tick_params(axis="x", labelsize=14)
ax.tick_params(axis="y", labelsize=12)

handles = [
    Line2D([0], [0], color=color_pale["blue"], label="GNN under.", marker="s", linestyle="None", markersize=12),
    Line2D([0], [0], color=color_pale["red"], label="GNN over.", marker="s", linestyle="None", markersize=12),
]
ax.legend(handles=handles, fontsize=16, loc="upper left")
label_plot(ax, r"(c)", loc=(.04, 0.05, "bottom", "right"), fontsize=16)
label_plot(ax, r"In-sample", loc=(in_sample_x, 0.9, "bottom", "center"), fontsize=16)
label_plot(ax, r"Out-of-sample", loc=(out_sample_x, 0.9, "bottom", "center"), fontsize=16)

# (d) Incidence summed over provinces. Distinct *_total names keep the model
# objects gnn / kapoor / var / metapop from being overwritten by arrays.
ax = _ax[1]
true = yt["gnn"].sum((1, 2))
gnn_total = yp["gnn"].sum((1, 2))
kapoor_total = yp["kapoor"].sum((1, 2))
var_total = yp["var"].sum((1, 2))
metapop_total = yp["metapop"].sum((1, 2))

ax.plot(kapoor_total, color=color_dark["orange"], linewidth=2, linestyle="--", marker="None")
ax.plot(var_total, color=color_dark["green"], linewidth=2, linestyle="dotted", marker="None")
# Metapop. gets its own y axis; dividing by 52 presumably averages over the
# provinces — TODO confirm.
axx = ax.twinx()
axx.plot(metapop_total / 52, color=color_dark["red"], linewidth=2, linestyle="-.", marker="None")

ax.plot(true, color=color_dark["grey"], linewidth=3, linestyle="-", marker="None")
ax.plot(gnn_total, color=color_dark["blue"], linewidth=2, linestyle="--", marker="None")

ax.set_xticks(list(cumul_times.values()))
ax.set_xticklabels(list(cumul_times.keys()), rotation=45, fontsize=14)
ax.set_xlim([0, t_end])
ax.set_ylim([0, ax.get_ylim()[1]])
axx.set_ylim([0, axx.get_ylim()[1]])

ax.set_ylabel(r"Global incidence", fontsize=18)
axx.set_ylabel(r"Global incidence (Metapop.)", fontsize=18, rotation=90)
ax.tick_params(axis="x", labelsize=14)
ax.tick_params(axis="y", labelsize=14)
axx.tick_params(axis="y", labelsize=14)
ax.axvspan(0, t_split, alpha=0.15, color=color_pale["grey"])

# Colour-code each y axis to match its curves.
ax.ticklabel_format(style="sci", scilimits=(0, 0), axis="y")
ax.tick_params(axis='y', labelsize=14, colors=color_dark["grey"], length=5, width=2)
ax.spines['left'].set_color(color_dark["grey"])
ax.spines['left'].set_linewidth(3)

axx.ticklabel_format(style="sci", scilimits=(0, 0), axis="y")
axx.tick_params(axis='y', labelsize=14, colors=color_dark["red"], length=5, width=2)
axx.spines['right'].set_color(color_dark["red"])
axx.spines['right'].set_linewidth(3)

handles = [
    Line2D([0], [0], color=color_dark["grey"], label="GT", marker="None", linestyle="-", linewidth=3),
    Line2D([0], [0], color=color_dark["blue"], label="GNN", marker="None", linestyle="--", linewidth=2),
    Line2D([0], [0], color=color_dark["orange"], label="KP-GNN", marker="None", linestyle="--", linewidth=2),
    Line2D([0], [0], color=color_dark["green"], label="VAR", marker="None", linestyle="dotted", linewidth=2),
    Line2D([0], [0], color=color_dark["red"], label="Metapop.", marker="None", linestyle="-.", linewidth=2),
]
ax.legend(handles=handles, fontsize=16, loc="upper left")
label_plot(ax, r"(d)", loc=(.04, 0.05, "bottom", "right"), fontsize=16)
label_plot(ax, r"In-sample", loc=(in_sample_x, 0.9, "bottom", "center"), fontsize=16)
label_plot(ax, r"Out-of-sample", loc=(out_sample_x, 0.9, "bottom", "center"), fontsize=16)
plt.tight_layout()

# Saved as PNG explicitly instead of branching on the `nopoints` loop variable
# left over from the figure-5ab cell (which is False once that loop ends).
filename = "manuscript-figure5cd"
fig.savefig(f"png/{filename}.png", dpi=150)

