In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import ndpretty

import data_prep
import gnn_prep
from gnn_prep import Experiment
from ensemble import ensemblify, deep_ensemblify
import widget_util

import plotly.graph_objects as go
import plotly.express as px
import itertools

In [2]:
ndpretty.default()

# Analyse experiment results

In [4]:
gnn_data_path="Previous work/PM2.5-GNN/data/"
gnn_experiment_path = gnn_data_path + "results/"

devices_path = f"{gnn_data_path}city_{dataset_name}.txt"
city_txt = pd.read_csv(devices_path, sep=' ', header=None, index_col=0)
devices = list(city_txt[1])

In [3]:
dataset_name = "DS2" # "DS3_city"

if dataset_name == "DS2":
    data = data_prep.load_beijing_data()
elif dataset_name == "DS3_city":
    data = data_prep.load_hebei_city_data() 
else:
    raise RuntimeError("Unknown dataset")

HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Loaded air quality data from 87 devices. No weather data for []


In [5]:
exp_paths = {
    "PM25_GNN": "0_336/1/PM25_GNN/2021-03-29_09-16-24/01/",
    #**{"SGNN3_0-24_%02d" % i: "0_24/1/SplitGNN_3/2021-04-07_23-09-24/%02d/" % i for i in range(5)},
    **{"SGNN3_0-336_%02d" % i: "0_336/1/SplitGNN_3/2021-04-19_09-49-03/%02d/" % i for i in range(5)}, # good runs!
    # **{"SGNN3_0_168_%02d" % i: "0_168/1/SplitGNN_3/2021-04-21_07-51-46/%02d/" % i for i in range(5)},
    #**{"SGNN3_0_24_%02d" % i: "0_24/1/SplitGNN_3/2021-04-30_07-25-12/%02d/" % i for i in range(5)},
    #**{"SGNN3_1_24_%02d" % i: "1_24/1/SplitGNN_3/2021-04-30_08-37-48/%02d/" % i for i in range(5)},
    #**{"SGNN3.2_1_24_%02d" % i: "1_24/1/SplitGNN_3_2/2021-05-04_09-00-51/%02d/" % i for i in range(5)},
    #**{"PM25_GNN_1_24_%02d" % i: "1_24/1/PM25_GNN/2021-05-04_09-01-57/%02d/" % i for i in range(5)},
}

start_date = "2020-10-01 01:00:00"
end_date = "2020-12-31 01:00:00"

experiments = {}
for name, path in tqdm(exp_paths.items()):
    experiments[name] = Experiment.load(name, gnn_experiment_path + path, start_date, devices)
    # experiments[name + " - with start"] = Experiment.load(name, gnn_experiment_path + path, start_date, devices)
    # experiments[name + " - first pred"] = Experiment.load(name, gnn_experiment_path + path, None, devices)


print(f"Loaded {len(experiments)} experiments")

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))


Loaded 12 experiments


In [13]:
experiments["PM25_GNN - with start"].pred_npy

5536×336×12×1 float32 ndarray


interactive(children=(Text(value='[:100, :100, 0, 0]', description='Slice:', placeholder='e.g. [:100, :100, 0,…



In [6]:
# ensemble_name = "SGNN3_0_168_ensemble"
# ensemble = ensemblify(experiments, ensemble_name)

# deep_ensemble_name = "SGNN3_0_168_deep_ensemble"
# experiments[deep_ensemble_name] = deep_ensemblify(experiments, deep_ensemble_name)
# experiments[ensemble_name] = ensemble

In [7]:
df = data.measurements.copy()
df.index = df['time']
ground_truth = df[df['device_id'].isin(devices)]

In [8]:
device_widget = widgets.Select(
    options=devices,
    rows=5,
    description='Device'
)

exp_widget = widgets.SelectMultiple(
    options=experiments.keys(),
    rows=10,
    description='Exps'
)

date_widget = widget_util.get_date_slider(data, start=None)

In [11]:
experiments

{'PM25_GNN - with start': Experiment 'PM25_GNN',
 'PM25_GNN - first pred': Experiment 'PM25_GNN',
 'SGNN3_0-336_00 - with start': Experiment 'SGNN3_0-336_00',
 'SGNN3_0-336_00 - first pred': Experiment 'SGNN3_0-336_00',
 'SGNN3_0-336_01 - with start': Experiment 'SGNN3_0-336_01',
 'SGNN3_0-336_01 - first pred': Experiment 'SGNN3_0-336_01',
 'SGNN3_0-336_02 - with start': Experiment 'SGNN3_0-336_02',
 'SGNN3_0-336_02 - first pred': Experiment 'SGNN3_0-336_02',
 'SGNN3_0-336_03 - with start': Experiment 'SGNN3_0-336_03',
 'SGNN3_0-336_03 - first pred': Experiment 'SGNN3_0-336_03',
 'SGNN3_0-336_04 - with start': Experiment 'SGNN3_0-336_04',
 'SGNN3_0-336_04 - first pred': Experiment 'SGNN3_0-336_04'}

In [9]:
def plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events):
    plt.figure(figsize=(10,5))

    # plot events
    for event in events:
        plt.axvspan(event.start, event.end, color=event.color, alpha=0.1, label=event.name)

    # plot ground truth
    (start, end) = date_range
    y = ground_truth[(ground_truth['device_id'] == device) & (ground_truth['time'] >= start) & (ground_truth['time'] < end)]
    plt.plot(y['time'], y[feature], label='ground truth')

    # plot experiment predictions
    for e in selected_exps:
        exp = exps[e]
        df = exp.preds
        y = df[(df['device_id'] == device) & (df['time'] >= start) & (df['time'] < end)]
        plt.plot(y['time'], y[feature], label=e)

    plt.tight_layout()
    plt.title(f"{feature} @ {device}")
    plt.legend(loc='upper left')
    axes = plt.gca()
    axes.set_xlim([start,end])
    plt.show()

    # plot R values for device
    for e in selected_exps:
        R_in, R_out = exps[e].R_in, exps[e].R_out
        if R_in is not None and R_out is not None:
            R_in[device].plot(kind='area', title=f"R_in for {e} @ {device}")
            R_out[device].plot(kind='area', title=f"R_out for {e} @ {device}")

interact_manual(plot_results, ground_truth=fixed(ground_truth), exps=fixed(experiments), selected_exps=exp_widget, device=device_widget, date_range=date_widget, feature=fixed('pm25'), events=fixed(data.events))

interactive(children=(SelectMultiple(description='Exps', options=('PM25_GNN - with start', 'PM25_GNN - first p…

<function __main__.plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events)>

## Sankey diagram

In [10]:
exp = experiments["SGNN3_0_168_00"]
start_time_idx = 0
pred_time_idx = 0
plot_threshold = 0.3 # plot only edges with an r_i higher than this threshold

KeyError: 'SGNN3_0_168_00'

In [17]:
labels = exp.devices
R = exp.R_npy[start_time_idx, pred_time_idx]

target = list(itertools.chain(*[[i] * len(labels) for i in range(len(labels))]))
source = list(range(len(labels))) * len(labels)
values = list(R.flatten())

Filter according to threshold.

In [18]:
target_filtered, source_filtered, values_filtered = [], [], []

for t, s, v in zip(target, source, values):
    if v > plot_threshold:
        target_filtered.append(t)
        source_filtered.append(s)
        values_filtered.append(v)

print(f"Plotting {len(source_filtered)} edges")

Plotting 16 edges


In [19]:
colormap = px.colors.cyclical.IceFire
assert len(colormap) > len(labels), "Not enough colours for nodes"

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      label = labels,
      color = list(map(lambda i: colormap[i], range(len(labels)))),
      hoverinfo = 'none'
    ),
    link = dict(
      source = source_filtered,
      target = target_filtered,
      value = values_filtered,
      color = list(map(lambda i: colormap[i], source_filtered)),
      #hoverinfo = 'none'
  ))], layout={'width': 1200, 'height': 700})

fig.update_layout(font_size=10)
fig.show()

## New plotting

In [31]:
exp = experiments["SGNN3_0-336_00"]
exp

Experiment 'SGNN3_0-336_00'

In [39]:
exp.devices

['1029A',
 '1037A',
 '1046A',
 '1056A',
 '1061A',
 '1065A',
 '1068A',
 '1072A',
 'changping',
 'daxing',
 'qianmen',
 'shunyi']

In [35]:
time_df = pd.DataFrame(exp.time_npy).applymap(pd.Timestamp.fromtimestamp)

In [42]:
exp.preds

Unnamed: 0_level_0,device_id,pm25,time
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2020-12-02 01:00:00,1029A,51.640736,2020-12-02 01:00:00
2020-12-02 02:00:00,1029A,49.936935,2020-12-02 02:00:00
2020-12-02 03:00:00,1029A,53.263573,2020-12-02 03:00:00
2020-12-02 04:00:00,1029A,56.981144,2020-12-02 04:00:00
2020-12-02 05:00:00,1029A,60.321495,2020-12-02 05:00:00
...,...,...,...
2020-12-19 18:00:00,shunyi,46.884476,2020-12-19 18:00:00
2020-12-19 19:00:00,shunyi,49.094585,2020-12-19 19:00:00
2020-12-19 20:00:00,shunyi,50.890945,2020-12-19 20:00:00
2020-12-19 21:00:00,shunyi,51.976299,2020-12-19 21:00:00


In [45]:
pred_df.melt(ignore_index=False, var_name='device_id', value_name='pm25')

Unnamed: 0_level_0,device_id,pm25
0,Unnamed: 1_level_1,Unnamed: 2_level_1
2020-01-01 01:00:00,1029A,52.416451
2020-01-01 02:00:00,1029A,52.635353
2020-01-01 03:00:00,1029A,52.768593
2020-01-01 04:00:00,1029A,53.352798
2020-01-01 05:00:00,1029A,54.792267
...,...,...
2020-12-05 22:00:00,shunyi,66.045181
2020-12-05 23:00:00,shunyi,65.519463
2020-12-06 00:00:00,shunyi,64.921768
2020-12-06 01:00:00,shunyi,60.260269


In [32]:
exp.pred_npy

5536×336×12×1 float32 ndarray


interactive(children=(Text(value='[:100, :100, 0, 0]', description='Slice:', placeholder='e.g. [:100, :100, 0,…

