In [4]:
from os import makedirs
import torch
import importlib
import contextlib

from src.utils_data import load_PeMS04_flow_data, preprocess_PeMS_data, local_dataset, plot_prediction
from src.utils_training import train_model, testmodel
from src.utils_fed import fed_training_plan
from src.metrics import calculate_metrics, metrics_table, Percentage_of_Superior_Predictions
import src.config
import sys

import json 
params = src.config.Params('config.json')

# Resolve the model class whose name is given by `params.model` from src.models.
module_name = 'src.models'
class_name = params.model
module = importlib.import_module(module_name)
model = getattr(module, class_name)

# Model dimensions used to instantiate `model` below; presumably they must match
# the architecture of the saved checkpoints that testmodel loads.
input_size = 1
hidden_size = 32
num_layers = 6  # NOTE(review): never passed to the model constructor below — confirm whether it should be.
output_size = 1

# Load traffic flow dataframe and graph dataframe from PeMS
df_PeMS, distance = load_PeMS04_flow_data()
df_PeMS, adjmat, meanstd_dict = preprocess_PeMS_data(df_PeMS, distance, params.init_node, params.n_neighbours,
                                                    params.smooth, params.center_and_reduce,
                                                    params.normalize, params.sort_by_mean)
datadict = local_dataset(df=df_PeMS,
                         nodes=params.nodes_to_filter,
                         window_size=params.window_size,
                         stride=params.stride,
                         prediction_horizon=params.prediction_horizon)

# Results keyed by the node's *position* in params.nodes_to_filter (not by sensor id).
y_true, y_pred, y_true_fed, y_pred_fed = {}, {}, {}, {}

# For each selected sensor, evaluate two checkpoints on that sensor's test split:
# `local{node}.pth` (stored in y_true / y_pred) and
# `bestmodel_node{node}.pth` (stored in y_true_fed / y_pred_fed).
for node, sensor in enumerate(params.nodes_to_filter):
    y_true[node], y_pred[node] = testmodel(model(input_size, hidden_size, output_size),
                                           datadict[node]['test'],
                                           f'{params.save_model_path}local{node}.pth',
                                           meanstd_dict=meanstd_dict,
                                           sensor_order_list=[sensor])

    y_true_fed[node], y_pred_fed[node] = testmodel(model(input_size, hidden_size, output_size),
                                                   datadict[node]['test'],
                                                   f'{params.save_model_path}bestmodel_node{node}.pth',
                                                   meanstd_dict=meanstd_dict,
                                                   sensor_order_list=[sensor])

def plot_comparison(y_true, y_pred, y_pred_fed, node):
    """Display a slider-driven comparison of federated vs. local predictions for one node.

    For the test window selected with the slider, two stacked panels are drawn.
    The top panel overlays the federated prediction and the bottom panel the
    local prediction. Each panel also shows the input window and the true values,
    in de-normalised traffic-flow units.

    Args:
        y_true: dict keyed by node position; only ``len(y_true[node])`` is used,
            to size the slider.
        y_pred: dict keyed by node position holding local-model predictions,
            indexed as ``y_pred[node][i, :]`` (presumably shape
            (n_windows, prediction_horizon) — TODO confirm).
        y_pred_fed: same layout as ``y_pred`` for the federated model.
        node: position of the sensor in ``params.nodes_to_filter``.

    Note:
        Reads the notebook-level globals ``datadict``, ``meanstd_dict`` and ``params``.
    """
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import ipywidgets as widgets
    from IPython.display import display

    window = params.window_size
    horizon = params.prediction_horizon
    sensor = params.nodes_to_filter[node]

    y_true_node = y_true[node]
    y_pred_node = y_pred[node]
    y_pred_fed_node = y_pred_fed[node]

    # Undo the per-sensor standardisation so the curves are in original flow units.
    stats = meanstd_dict[sensor]
    test_set = datadict[node]['test_data'] * stats['std'] + stats['mean']
    index = test_set.index

    def _draw_panel(ax, i, preds, color, label, title):
        """Draw one prediction panel on `ax` for the window starting at row `i`."""
        window_end = i + window             # first row after the input window
        target_end = window_end + horizon   # one past the last predicted row

        ax.axvspan(index[i], index[window_end - 1], alpha=0.1, color='gray')
        ax.plot(index[i:window_end], test_set.iloc[i:window_end], label='Window')
        # Start y_true at the last input point so the two curves join visually.
        ax.plot(index[window_end - 1:target_end], test_set.iloc[window_end - 1:target_end], label='y_true')
        ax.scatter(index[window_end:target_end], preds[i, :], color=color, label=label)
        ax.plot(index[window_end:target_end], preds[i, :], color=color, linestyle='-', linewidth=1)

        ax.xaxis.set_major_locator(mdates.HourLocator(interval=1))
        ax.xaxis.set_minor_locator(mdates.MinuteLocator(interval=5))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        ax.set_xlabel('Time')
        ax.set_ylabel('Traffic Flow')
        ax.set_title(f"{title} for the {index[i].strftime('%Y-%m-%d')}", fontsize=18, fontweight='bold')
        ax.legend(fontsize='large')

    def plot_slider(i):
        """Render both panels for slider position `i`."""
        fig, (ax_fed, ax_local) = plt.subplots(2, 1, figsize=(20, 9))
        _draw_panel(ax_fed, i, y_pred_fed_node, 'blue', 'Federated prediction', 'Federated Prediction')
        _draw_panel(ax_local, i, y_pred_node, 'green', 'Local prediction', 'Local Prediction')
        fig.tight_layout()
        fig.subplots_adjust(hspace=0.5)
        plt.show()

    # NOTE(review): plotting assumes prediction row i corresponds to the window that
    # starts at row i of test_set (i.e. params.stride == 1) — confirm against
    # local_dataset. Under that assumption this upper bound is conservative: the
    # last valid window would be len(y_true_node) - 1.
    slider = widgets.IntSlider(min=0, max=len(y_true_node) - window, value=0, description='Index')

    def update_slider_description(change):
        # Label the slider with the time of the first predicted step (row i + window).
        slider.description = f"Index: {index[change.new + window].strftime('%H:%M')}"

    slider.observe(update_slider_description, 'value')

    interactive_plot = widgets.interactive(plot_slider, i=slider)
    display(interactive_plot)

In [3]:
# Iterate the same node positions used to fill y_true / y_pred / y_pred_fed, and
# pass `node`, which plot_comparison requires (omitting it raises TypeError).
for node, sensor in enumerate(params.nodes_to_filter):
    print(f'For node {sensor}')
    plot_comparison(y_true, y_pred, y_pred_fed, node)

For node 118


interactive(children=(IntSlider(value=0, description='Index', max=2370), Output()), _dom_classes=('widget-inte…

For node 168


interactive(children=(IntSlider(value=0, description='Index', max=2370), Output()), _dom_classes=('widget-inte…

For node 261


interactive(children=(IntSlider(value=0, description='Index', max=2370), Output()), _dom_classes=('widget-inte…