In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

dataname = "PeMS08"
data_path = f"../../data/y_hat/GAF_{dataname}_prediction.npz"
# np.load on an .npz archive returns an NpzFile that holds the file open until it
# is closed; the context manager closes it once the three arrays have been read
# into memory (NpzFile reads each array fully on item access).
# x / y / y_hat presumably hold model inputs, targets and predictions — TODO confirm
# against the script that writes this file.
with np.load(data_path) as data:
    x = data["x"]
    y = data["y"]
    y_hat = data["y_hat"]

In [None]:
# Sanity-check the shapes of the loaded arrays before sampling from them.
for label, arr in (("x", x), ("y", y), ("y_hat", y_hat)):
    print(f"{label}.shape:", arr.shape)

In [None]:
def get_random_sample(x, y, y_hat, rng=None):
    """Pick one random (sample, vertex) pair and return its three time series.

    Parameters
    ----------
    x, y, y_hat : np.ndarray
        3-D arrays indexed as ``[sample, time, vertex]``. The sample and vertex
        indices are drawn from ``x.shape`` and applied to all three arrays, so
        ``y`` and ``y_hat`` must have at least as many samples and vertices as ``x``.
    rng : np.random.Generator, optional
        Source of randomness. Pass ``np.random.default_rng(seed)`` to make the
        pick reproducible. When omitted, the legacy global ``np.random`` state
        is used, which matches the previous behaviour.

    Returns
    -------
    tuple of np.ndarray
        ``(x[s, :, v], y[s, :, v], y_hat[s, :, v])`` for the drawn ``s`` and ``v``.
    """
    num_samples, _, num_vertex = x.shape

    if rng is None:
        # Backward-compatible path: unseeded global RNG.
        sample_idx = np.random.randint(num_samples)
        vertex_idx = np.random.randint(num_vertex)
    else:
        sample_idx = rng.integers(num_samples)
        vertex_idx = rng.integers(num_vertex)

    select_x = x[sample_idx, :, vertex_idx]
    select_y = y[sample_idx, :, vertex_idx]
    select_y_hat = y_hat[sample_idx, :, vertex_idx]

    return select_x, select_y, select_y_hat

# Draw one random (sample, vertex) series. x, y and y_hat are rebound here to
# the 1-D slices returned by get_random_sample.
# NOTE(review): this overwrites the full arrays from the loading cell, so
# re-running this cell on its own fails — consider giving the slices new names.
x, y, y_hat = get_random_sample(x, y, y_hat)
for label, series in (("x=", x), ("y=", y), ("y_hat=", y_hat)):
    print(label, series)
# History followed by ground-truth future; only its shape is inspected below.
y_values = np.concatenate((x, y), axis=0)
print("shape of y_values:", y_values.shape)

In [None]:
def _prediction_frame(x, y, y_hat):
    """Assemble the long-format DataFrame that ``plot_prediction`` draws.

    Columns are ``'Time Step'``, ``'Traffic Flow'`` and ``'Data Type'``
    (``'X'``, ``'Y'`` or ``'Y_hat'``). ``x`` occupies steps ``0 .. len(x) - 1``.
    ``y`` and ``y_hat`` are each prefixed with ``x[-1]`` and start at step
    ``len(x) - 1``, so their lines join the end of the ``X`` line without a gap.

    Raises
    ------
    ValueError
        If ``x`` is empty (there is no last observation to anchor the future on).
    """
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError("x must contain at least one observation")

    # 1-element slice keeps x's dtype; np.concatenate then promotes, whereas
    # np.insert would cast x[-1] to y's dtype (truncating floats into int arrays).
    anchor = x[-1:]
    y_joined = np.concatenate((anchor, np.asarray(y)))
    y_hat_joined = np.concatenate((anchor, np.asarray(y_hat)))
    last_step = len(x) - 1

    history = pd.DataFrame({
        'Time Step': np.arange(len(x)),
        'Traffic Flow': x,
        'Data Type': ['X'] * len(x),
    })
    future = pd.DataFrame({
        'Time Step': np.concatenate((
            np.arange(last_step, last_step + len(y_joined)),
            np.arange(last_step, last_step + len(y_hat_joined)),
        )),
        'Traffic Flow': np.concatenate((y_joined, y_hat_joined)),
        'Data Type': ['Y'] * len(y_joined) + ['Y_hat'] * len(y_hat_joined),
    })
    return pd.concat([history, future], ignore_index=True)


def plot_prediction(x, y, y_hat, dataname):
    """Plot an observed window followed by its ground-truth and predicted continuation.

    Parameters
    ----------
    x : array-like, 1-D
        Observed history, drawn at time steps ``0 .. len(x) - 1`` as ``'X'``.
    y : array-like, 1-D
        Ground-truth future, drawn as ``'Y'`` starting right after the history.
    y_hat : array-like, 1-D
        Predicted future, drawn as ``'Y_hat'`` on the same steps as ``y``.
    dataname : str
        Dataset name shown in the plot title.

    Window and horizon lengths are taken from the inputs rather than fixed at 12.
    """
    frame = _prediction_frame(x, y, y_hat)
    n_steps = len(x) + max(len(y), len(y_hat))

    sns.set_theme(style='whitegrid')
    _, ax = plt.subplots(figsize=(8, 4))
    # Hue levels appear in order X, Y, Y_hat -> blue, green, red.
    sns.lineplot(x='Time Step', y='Traffic Flow', hue='Data Type', data=frame,
                 palette=['blue', 'green', 'red'], ax=ax)
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Traffic Flow')
    ax.set_title(f'Traffic Flow Prediction--{dataname}')
    ax.legend(title='Data Type', loc='upper right')
    ax.set_xticks(np.arange(n_steps))
    plt.show()


plot_prediction(x=x, y=y, y_hat=y_hat, dataname=dataname)  # draw the sampled series