In [1]:
import os

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import ndpretty

import data_prep
import gnn_prep
import widget_util

In [2]:
# Apply ndpretty's defaults before anything is displayed.
# NOTE(review): presumably this registers its rich array rendering — confirm in the ndpretty docs.
ndpretty.default()
# Load the Beijing data set; later cells read `data.measurements` and `data.events` from it.
data = data_prep.load_beijing_data()

HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Loaded air quality data from 87 devices. No weather data for []


# Analyse experiment results

In [3]:
# Root of the PM2.5-GNN data and the results folder beneath it.
# Both keep their trailing slash because later code appends to them by plain concatenation.
gnn_data_path = "Previous work/PM2.5-GNN/data/"
gnn_experiment_path = f"{gnn_data_path}results/"

dataset_name = "DS2"

# city_<dataset>.txt is read as a space-separated table without a header row;
# its first column becomes the index and the column labelled 1 supplies the
# device ids used to filter the measurements further down.
devices_path = gnn_data_path + "city_" + dataset_name + ".txt"
city_txt = pd.read_csv(devices_path, header=None, index_col=0, sep=' ')
devices = list(city_txt[1])

In [4]:
class Experiment:
    """Results of one stored experiment run, loaded from its result directory.

    The directory is expected to contain ``time.npy`` and ``predict.npy`` and
    may optionally contain ``R.npy``. Everything is loaded eagerly when the
    object is constructed.

    Attributes:
        name: Display name of the experiment.
        path: Directory holding the ``.npy`` files (with or without a trailing
            path separator).
        start_date: Passed through unchanged to the ``gnn_prep`` transforms.
        devices: Passed through unchanged to the ``gnn_prep`` transforms.
        preds: Result of ``gnn_prep.transform_preds_back`` with an additional
            ``'time'`` column copied from its index.
        R_in, R_out: Results of ``gnn_prep.transform_R_back``, or ``None`` for
            both when ``R.npy`` could not be found.
    """

    def __init__(self, name, path, start_date, devices):
        """Store the experiment description and immediately load its results."""
        self.name = name
        self.path = path
        self.start_date = start_date
        self.devices = devices
        self.preds, self.R_in, self.R_out = self.load_experiment()

    def load_experiment(self):
        """Read the experiment's ``.npy`` files and convert them back.

        Returns:
            A tuple ``(preds, R_in, R_out)``; ``R_in`` and ``R_out`` are
            ``None`` when a ``FileNotFoundError`` occurs while loading the R
            values (a message is printed in that case).

        Raises:
            FileNotFoundError: If ``time.npy`` or ``predict.npy`` is missing.
        """
        # os.path.join instead of string concatenation so that ``path`` works
        # both with and without a trailing separator.
        time_npy = np.load(os.path.join(self.path, 'time.npy'))
        pred_npy = np.load(os.path.join(self.path, 'predict.npy'))

        preds = gnn_prep.transform_preds_back(time_npy, pred_npy, self.devices, self.start_date, melted=True)

        # R values are optional: not every experiment directory provides them.
        try:
            R_npy = np.load(os.path.join(self.path, 'R.npy'))
            R_in, R_out = gnn_prep.transform_R_back(time_npy, R_npy, self.devices, self.start_date)
        except FileNotFoundError as e:
            print(f"Couldn't load R.npy for '{self.name}': {e}")
            R_in, R_out = None, None

        # Expose the index as a regular column so callers can filter on 'time'.
        preds['time'] = preds.index
        return preds, R_in, R_out

    def __repr__(self):
        return f"Experiment '{self.name}'"
        

In [5]:
"""experiments = {
    "PM25_GNN": "0_336/1/PM25_GNN/2021-03-29_09-16-24/01/",
    "SGNN3_00": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/00/",
    "SGNN3_01": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/01/",
    "SGNN3_02": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/02/",
    "SGNN3_03": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/03/",
    "SGNN3_04": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/04/",
    "SGNN3.1_00": "0_24/1/SplitGNN_3_1/2021-04-09_10-26-53/00/",
    "SGNN3.1_01": "0_24/1/SplitGNN_3_1/2021-04-09_10-26-53/01/",
    "SGNN3.1_02": "0_24/1/SplitGNN_3_1/2021-04-09_10-26-53/02/",
    "SGNN3.1_03": "0_24/1/SplitGNN_3_1/2021-04-09_10-26-53/03/",
    "SGNN3.2_00": "0_24/1/SplitGNN_3_2/2021-04-09_10-36-49/00/",
    "SGNN3.2_01": "0_24/1/SplitGNN_3_2/2021-04-09_10-36-49/01/",
    "SGNN3.2_02": "0_24/1/SplitGNN_3_2/2021-04-09_10-36-49/02/",
    "SGNN3.2_03": "0_24/1/SplitGNN_3_2/2021-04-09_10-36-49/03/",
    "SGNN3.2'_00": "0_24/1/SplitGNN_3_2/2021-04-10_16-22-36/00/",
    "SGNN3.2'_01": "0_24/1/SplitGNN_3_2/2021-04-10_16-22-36/01/",
    "SGNN3.2'_02": "0_24/1/SplitGNN_3_2/2021-04-10_16-22-36/02/",
    "SGNN3.2'_03": "0_24/1/SplitGNN_3_2/2021-04-10_16-22-36/03/",
    "SGNN3.2'_04": "0_24/1/SplitGNN_3_2/2021-04-10_16-22-36/04/",
    "SGNN3.3_00": "0_24/1/SplitGNN_3_3/2021-04-10_20-09-05/00/",
    "SGNN3.3_01": "0_24/1/SplitGNN_3_3/2021-04-10_20-09-05/01/",
    "SGNN3.3_02": "0_24/1/SplitGNN_3_3/2021-04-10_20-09-05/02/",
    "SGNN3.3_03": "0_24/1/SplitGNN_3_3/2021-04-10_20-09-05/03/",
    "SGNN3.3_04": "0_24/1/SplitGNN_3_3/2021-04-10_20-09-05/04/",
    "SGNN3.4_00 aff=True": "0_24/1/SplitGNN_3_4/2021-04-11_10-58-54/00/",
    "SGNN3.4_01 aff=True": "0_24/1/SplitGNN_3_4/2021-04-11_10-58-54/01/",
    "SGNN3.4_02 aff=True": "0_24/1/SplitGNN_3_4/2021-04-11_10-58-54/02/",
    "SGNN3.4_03 aff=True": "0_24/1/SplitGNN_3_4/2021-04-11_10-58-54/03/",
    "SGNN3.4_04 aff=True": "0_24/1/SplitGNN_3_4/2021-04-11_10-58-54/04/",
    "SGNN3.4_00 aff=False": "0_24/1/SplitGNN_3_4/2021-04-11_13-27-15/00/",
    "SGNN3.4_01 aff=False": "0_24/1/SplitGNN_3_4/2021-04-11_13-27-15/01/",
    "SGNN3.4_02 aff=False": "0_24/1/SplitGNN_3_4/2021-04-11_13-27-15/02/",
    "SGNN3.4_03 aff=False": "0_24/1/SplitGNN_3_4/2021-04-11_13-27-15/03/",
    "SGNN3.4_04 aff=False": "0_24/1/SplitGNN_3_4/2021-04-11_13-27-15/04/",
    "SGNN3.4'_00 aff=True": "0_24/1/SplitGNN_3_4/2021-04-12_07-22-27/00/",
    "SGNN3.4'_01 aff=True": "0_24/1/SplitGNN_3_4/2021-04-12_07-22-27/01/",
    "SGNN3.4'_02 aff=True": "0_24/1/SplitGNN_3_4/2021-04-12_07-22-27/02/",
    "SGNN3.4'_03 aff=True": "0_24/1/SplitGNN_3_4/2021-04-12_07-22-27/03/",
    "SGNN3.4'_04 aff=True": "0_24/1/SplitGNN_3_4/2021-04-12_07-22-27/04/",
    "SGNN3.4'_00 aff=False": "0_24/1/SplitGNN_3_4/2021-04-12_07-21-03/00/",
    "SGNN3.4'_01 aff=False": "0_24/1/SplitGNN_3_4/2021-04-12_07-21-03/01/",
    "SGNN3.4'_02 aff=False": "0_24/1/SplitGNN_3_4/2021-04-12_07-21-03/02/",
    "SGNN3.4'_03 aff=False": "0_24/1/SplitGNN_3_4/2021-04-12_07-21-03/03/",
    "SGNN3.4'_04 aff=False": "0_24/1/SplitGNN_3_4/2021-04-12_07-21-03/04/",
    "SGNN3.4'_00 aff=False - 1_24": "1_24/1/SplitGNN_3_4/2021-04-14_07-56-15/00/",
    "SGNN3.4'_01 aff=False - 1_24": "1_24/1/SplitGNN_3_4/2021-04-14_07-56-15/01/", 
    "SGNN3.4'_02 aff=False - 1_24": "1_24/1/SplitGNN_3_4/2021-04-14_07-56-15/02/",
    "SGNN3.4'_03 aff=False - 1_24": "1_24/1/SplitGNN_3_4/2021-04-14_07-56-15/03/",
    "SGNN3.4'_04 aff=False - 1_24": "1_24/1/SplitGNN_3_4/2021-04-14_07-56-15/04/",
    "SGNN3.4'_00 aff=False - 0_336": "0_336/1/SplitGNN_3_4/2021-04-13_07-43-34/00/",
    "SGNN3.4'_01 aff=False - 0_336": "0_336/1/SplitGNN_3_4/2021-04-13_07-43-34/01/",
    "SGNN3.4'_02 aff=False - 0_336": "0_336/1/SplitGNN_3_4/2021-04-13_07-43-34/02/",
    "SGNN3.4'_03 aff=False - 0_336": "0_336/1/SplitGNN_3_4/2021-04-13_07-43-34/03/",
    "SGNN3.4'_04 aff=False - 0_336": "0_336/1/SplitGNN_3_4/2021-04-13_07-43-34/04/",
    "SGNN3.4'_00 aff=False - 1_336": "1_336/1/SplitGNN_3_4/2021-04-13_19-49-40/00/",
    "SGNN3.4'_01 aff=False - 1_336": "1_336/1/SplitGNN_3_4/2021-04-13_19-49-40/01/",
    "SGNN3.4'_02 aff=False - 1_336": "1_336/1/SplitGNN_3_4/2021-04-13_19-49-40/02/",
    "SGNN3.4'_03 aff=False - 1_336": "1_336/1/SplitGNN_3_4/2021-04-13_19-49-40/03/", # running
    "SGNN3.4'_04 aff=False - 1_336": "1_336/1/SplitGNN_3_4/2021-04-13_19-49-40/04/", # running
}"""

# Experiments to analyse: display name -> result directory, relative to
# `gnn_experiment_path` (each entry keeps its trailing slash).
experiment_paths = {
    "PM25_GNN": "0_336/1/PM25_GNN/2021-03-29_09-16-24/01/",
    "SGNN3_00": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/00/",
    "SGNN3_01": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/01/",
    "SGNN3_02": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/02/",
    "SGNN3_03": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/03/",
    "SGNN3_04": "0_24/1/SplitGNN_3/2021-04-07_23-09-24/04/",
}

start_date = "2020-01-01 01:00:00"
# NOTE(review): end_date is not referenced in the cells visible here — confirm it is still needed.
end_date = "2020-03-01 01:00:00"

# Build the loaded experiments in their own dict instead of overwriting the
# path dict while iterating over it; `experiments` therefore always maps
# name -> Experiment and the paths stay available in `experiment_paths`.
experiments = {
    name: Experiment(name, gnn_experiment_path + path, start_date, devices)
    for name, path in tqdm(experiment_paths.items())
}

print(f"Loaded {len(experiments)} experiments")

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))


Loaded 6 experiments


In [13]:
# Work on a copy of the measurements, indexed by their 'time' column
# (the column itself is kept), restricted to the devices of this data set.
df = data.measurements.copy()
df.index = df['time']
ground_truth = df.loc[df['device_id'].isin(devices)]

In [14]:
# Single-choice list of devices.
device_widget = widgets.Select(
    description='Device',
    options=devices,
    rows=5,
)

# Multi-choice list of the loaded experiments (by display name).
exp_widget = widgets.SelectMultiple(
    description='Exps',
    options=experiments.keys(),
    rows=10,
)

# Date-range slider from the project's widget_util helper, starting at the experiments' start date.
date_widget = widget_util.get_date_slider(data, start=start_date)

In [15]:
def plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events):
    """Plot measured and predicted values of ``feature`` for one device.

    Draws one figure with the events as shaded spans, the ground truth and the
    predictions of every selected experiment, followed by separate area plots
    of ``R_in`` / ``R_out`` for each selected experiment that provides them.

    Args:
        ground_truth: DataFrame with at least the columns ``'device_id'``,
            ``'time'`` and ``feature``.
        exps: Mapping from experiment name to an object exposing ``preds``
            (same columns as ``ground_truth``), ``R_in`` and ``R_out``.
        selected_exps: Names (keys of ``exps``) of the experiments to plot.
        device: Device id to plot.
        date_range: ``(start, end)`` pair; rows with ``start <= time < end``
            are shown.
        feature: Name of the column to plot.
        events: Iterable of objects with ``start``, ``end``, ``color`` and
            ``name`` attributes.
    """
    (start, end) = date_range

    def in_window(frame):
        """Rows of ``frame`` for ``device`` with ``start <= time < end``."""
        mask = (frame['device_id'] == device) & (frame['time'] >= start) & (frame['time'] < end)
        return frame[mask]

    fig, ax = plt.subplots(figsize=(10, 5))

    # plot events
    for event in events:
        ax.axvspan(event.start, event.end, color=event.color, alpha=0.1, label=event.name)

    # plot ground truth
    truth = in_window(ground_truth)
    ax.plot(truth['time'], truth[feature], label='ground truth')

    # plot experiment predictions
    for e in selected_exps:
        pred = in_window(exps[e].preds)
        ax.plot(pred['time'], pred[feature], label=e)

    ax.set_title(f"{feature} @ {device}")
    ax.legend(loc='upper left')
    ax.set_xlim(start, end)
    # Lay out last so that the title and legend are taken into account.
    fig.tight_layout()
    plt.show()

    # plot R values for device
    for e in selected_exps:
        R_in, R_out = exps[e].R_in, exps[e].R_out
        if R_in is not None and R_out is not None:
            # Give every plot its own axes explicitly so R_in and R_out can
            # never end up on the same (implicit current) axes.
            _, ax_in = plt.subplots()
            R_in[device].plot(kind='area', ax=ax_in, title=f"R_in for {e} @ {device}")
            _, ax_out = plt.subplots()
            R_out[device].plot(kind='area', ax=ax_out, title=f"R_out for {e} @ {device}")

# Hook the widgets up to plot_results; arguments wrapped in fixed() get no control.
interact_manual(
    plot_results,
    ground_truth=fixed(ground_truth),
    exps=fixed(experiments),
    selected_exps=exp_widget,
    device=device_widget,
    date_range=date_widget,
    feature=fixed('pm25'),
    events=fixed(data.events),
)

interactive(children=(SelectMultiple(description='Exps', options=('PM25_GNN', 'SGNN3_00', 'SGNN3_01', 'SGNN3_0…

<function __main__.plot_results(ground_truth, exps, selected_exps, device, date_range, feature, events)>