In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import ndpretty

import data_prep
import gnn_prep
from gnn_prep import Experiment
import widget_util

import plotly.graph_objects as go
import plotly.express as px
import itertools

In [2]:
# Configure ndpretty's notebook display defaults (presumably pretty-printing of numpy arrays — verify).
ndpretty.default()

# Air-quality data for the Hebei city devices.
# To analyse Beijing instead, load it with data_prep.load_beijing_data().
data = data_prep.load_hebei_city_data()

HBox(children=(IntProgress(value=0, max=13), HTML(value='')))


Loaded air quality data from 13 devices. No weather data for []


# Analyse experiment results

In [3]:
# Root of the PM2.5-GNN data directory and the results sub-folder beneath it.
gnn_data_path = "Previous work/PM2.5-GNN/data/"
gnn_experiment_path = f"{gnn_data_path}results/"

dataset_name = "DS3_city"

# Space-separated file without a header: column 0 becomes the index and
# column 1 holds the device identifiers used throughout the notebook.
devices_path = gnn_data_path + "city_" + dataset_name + ".txt"
city_txt = pd.read_csv(devices_path, sep=' ', header=None, index_col=0)
devices = [device for device in city_txt[1]]

In [4]:
# Experiment run folders (relative to gnn_experiment_path), keyed by display name.
# Earlier runs are kept commented out for reference; uncomment to include them.
experiment_paths = {
    # "PM25_GNN": "0_336/1/PM25_GNN/2021-03-29_09-16-24/01/",
    # **{"SGNN3_0-24_%02d" % i: "0_24/1/SplitGNN_3/2021-04-07_23-09-24/%02d/" % i for i in range(5)},
    # **{"SGNN3_0-336_%02d" % i: "0_336/1/SplitGNN_3/2021-04-19_09-49-03/%02d/" % i for i in range(5)},
    **{f"SGNN3_0_168_{i:02d}": f"0_168/1/SplitGNN_3/2021-04-21_07-51-46/{i:02d}/" for i in range(5)},
}

start_date = "2020-01-01 01:00:00"
end_date = "2020-03-01 01:00:00"  # NOTE(review): not used in this cell

# Build a fresh dict rather than overwriting values of the dict being iterated:
# `experiment_paths` keeps the raw paths, `experiments` holds only Experiment
# objects, and a failure mid-load no longer leaves a half-converted dict behind.
experiments = {
    name: Experiment(name, gnn_experiment_path + path, start_date, devices)
    for name, path in tqdm(experiment_paths.items())
}

print(f"Loaded {len(experiments)} experiments")

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


Loaded 5 experiments


In [19]:
# Measurements indexed by timestamp; 'time' also stays a column because the
# plotting code filters on ground_truth['time'].
df = data.measurements.copy().set_index('time', drop=False)

# Restrict to the devices listed in the GNN city file.
ground_truth = df.loc[df['device_id'].isin(devices)]

In [20]:
# Pick a single device from the city file's device list.
device_widget = widgets.Select(options=devices, rows=5, description='Device')

# Pick any number of loaded experiments whose predictions should be overlaid.
exp_widget = widgets.SelectMultiple(options=tuple(experiments), rows=10, description='Exps')

# Project helper producing the (start, end) range consumed by plot_results; seeded with start_date.
date_widget = widget_util.get_date_slider(data, start=start_date)

In [21]:
def _select_window(frame, device, start, end):
    """Return rows of `frame` for `device` whose 'time' lies in [start, end)."""
    mask = (frame['device_id'] == device) & (frame['time'] >= start) & (frame['time'] < end)
    return frame[mask]


def plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events):
    """Plot ground truth vs. selected experiment predictions for one device.

    After the main time-series figure, an area plot of R_in and R_out is drawn
    for every selected experiment that has both available.

    Parameters
    ----------
    ground_truth : DataFrame with 'device_id', 'time' and `feature` columns.
    exps : mapping from experiment name to an object exposing `preds`
        (same columns as `ground_truth`), `R_in` and `R_out`.
    selected_exps : iterable of keys of `exps` to overlay.
    device : device id to plot.
    date_range : (start, end) pair; rows with start <= time < end are shown.
    feature : column to plot, e.g. 'pm25'.
    events : iterable of objects with `start`, `end`, `color` and `name`
        attributes, drawn as shaded vertical spans.
    """
    start, end = date_range

    fig, ax = plt.subplots(figsize=(10, 5))

    for event in events:
        ax.axvspan(event.start, event.end, color=event.color, alpha=0.1, label=event.name)

    truth = _select_window(ground_truth, device, start, end)
    ax.plot(truth['time'], truth[feature], label='ground truth')

    for name in selected_exps:
        preds = _select_window(exps[name].preds, device, start, end)
        ax.plot(preds['time'], preds[feature], label=name)

    ax.set_title(f"{feature} @ {device}")
    ax.legend(loc='upper left')
    ax.set_xlim([start, end])
    # tight_layout must run after the title and legend exist so it accounts for them.
    fig.tight_layout()
    plt.show()

    # Give every relevance plot its own axes. Without an explicit `ax`, pandas
    # draws a Series onto the current axes, so if R_in[device] is a Series,
    # R_in, R_out and further experiments would all pile onto one axes.
    for name in selected_exps:
        R_in, R_out = exps[name].R_in, exps[name].R_out
        if R_in is None or R_out is None:
            continue
        for r_label, R in (("R_in", R_in), ("R_out", R_out)):
            _, r_ax = plt.subplots()
            R[device].plot(kind='area', title=f"{r_label} for {name} @ {device}", ax=r_ax)
    plt.show()

# Trailing ';' suppresses the echoed `<function plot_results>` repr that
# interact_manual returns; only the widget UI is displayed.
interact_manual(
    plot_results,
    ground_truth=fixed(ground_truth),
    exps=fixed(experiments),
    selected_exps=exp_widget,
    device=device_widget,
    date_range=date_widget,
    feature=fixed('pm25'),
    events=fixed(data.events),
);

interactive(children=(SelectMultiple(description='Exps', options=('SGNN3_0_168_00', 'SGNN3_0_168_01', 'SGNN3_0…

<function __main__.plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events)>

## Sankey diagram of R for one experiment and time step

In [117]:
# Sankey settings: experiment to inspect, indices into its R_npy array, and edge cut-off.
exp = experiments["SGNN3_0_168_00"]
start_time_idx, pred_time_idx = 0, 0
plot_threshold = 0.2  # plot only edges with an r_i higher than this threshold

In [118]:
# Node labels are the experiment's devices; R is the slice of exp.R_npy at the
# chosen indices (presumably an n x n matrix — TODO confirm axis meaning).
labels = exp.devices
R = exp.R_npy[start_time_idx, pred_time_idx]
n_nodes = len(labels)

# Enumerate every node pair in the same row-major order as R.flatten():
# entry k = i * n_nodes + j becomes the edge source j -> target i.
target = list(itertools.chain.from_iterable([i] * n_nodes for i in range(n_nodes)))
source = list(range(n_nodes)) * n_nodes
values = list(R.flatten())

# The filtering step zips these three lists together; a size mismatch would make
# zip() silently drop edges or pair weights with the wrong nodes.
assert len(values) == len(target), (
    f"R has {len(values)} entries, expected {n_nodes}x{n_nodes} for {n_nodes} devices"
)

Keep only edges whose weight r_i exceeds `plot_threshold`, so the diagram shows just the strongest connections.

In [119]:
# Keep only edges whose weight exceeds the threshold; the three lists stay index-aligned.
kept_edges = [
    (tgt, src, val)
    for tgt, src, val in zip(target, source, values)
    if val > plot_threshold
]
target_filtered = [tgt for tgt, _, _ in kept_edges]
source_filtered = [src for _, src, _ in kept_edges]
values_filtered = [val for _, _, val in kept_edges]

print(f"Plotting {len(source_filtered)} edges")

Plotting 24 edges


In [120]:
colormap = px.colors.cyclical.IceFire
# Nodes use colormap indices 0 .. len(labels) - 1, so len(labels) colours suffice.
assert len(colormap) >= len(labels), "Not enough colours for nodes"
node_colors = colormap[:len(labels)]

fig = go.Figure(
    data=[go.Sankey(
        node=dict(
            pad=15,
            label=labels,
            color=node_colors,
            hoverinfo='none',
        ),
        link=dict(
            source=source_filtered,
            target=target_filtered,
            value=values_filtered,
            # Colour each link like its source node so flows can be traced back.
            color=[node_colors[s] for s in source_filtered],
        ),
    )],
    layout={'width': 1200, 'height': 700},
)

fig.update_layout(
    font_size=10,
    title_text=(
        f"Edges with r_i > {plot_threshold} "
        f"(start_time_idx={start_time_idx}, pred_time_idx={pred_time_idx})"
    ),
)
fig.show()