# Loss Plot

In [None]:
import json
import matplotlib.pyplot as plt
import glob
import numpy as np


def _load_matching_runs(files, filter_params):
    """Read each result file once and keep the runs that satisfy the filter.

    Args:
        files: Paths of JSON result files.
        filter_params: Mapping of parameter name to required value, or None
            to keep every file.

    Returns:
        A list of ``(filename, params, losses)`` tuples, where ``params`` is
        every top-level entry of the file except ``'losses'``.
    """
    runs = []
    for filename in files:
        with open(filename, 'r') as f:
            data = json.load(f)

        # Everything except the loss curves is treated as a training parameter.
        params = {k: v for k, v in data.items() if k != 'losses'}

        if filter_params is not None and not all(
                item in params.items() for item in filter_params.items()):
            continue

        runs.append((filename, params, data['losses']))
    return runs


def _common_params(param_dicts):
    """Return the key/value pairs that are identical in every given dict."""
    # Sentinel so a key that is absent from a later run counts as "different"
    # instead of raising KeyError.
    missing = object()
    common = dict(param_dicts[0])
    for params in param_dicts[1:]:
        common = {k: v for k, v in common.items()
                  if params.get(k, missing) == v}
    return common


def plot_loss(filter_params=None, end_idx=None):
    """Plot the loss curves stored in ``results/*.json``, one subplot per curve.

    Each JSON file must contain a ``'losses'`` entry holding a list of loss
    sequences; all other top-level entries are treated as training
    parameters. Runs are labelled with the parameters that differ between the
    selected runs (falling back to the file name when nothing differs).

    Args:
        filter_params: Optional mapping; only files whose parameters contain
            all of these key/value pairs are plotted.
        end_idx: Optional slice end applied to every loss sequence. When
            truthy, the x-axis is also limited to ``(0, end_idx)``.

    Raises:
        FileNotFoundError: If no file matches ``results/*.json``.
        ValueError: If no file satisfies ``filter_params``.
    """
    pattern = 'results/*.json'
    # glob returns paths in arbitrary order; sort so the line/legend order is
    # reproducible between runs.
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No result files found matching '{pattern}'")

    # Filter before any figure is created so that a bad filter fails early
    # and does not leave an empty figure behind.
    runs = _load_matching_runs(files, filter_params)
    if not runs:
        raise ValueError(
            f'No result file matches filter_params={filter_params!r}')

    # Dimension names
    dim_names = ['x', 'y', 'z', 'Mean']

    # Size the figure from the runs that are actually plotted (not from an
    # arbitrary first file that may have been filtered out).
    num_subplots = max(len(losses) for _, _, losses in runs)

    # Create a subplot for each loss dimension
    fig, axs = plt.subplots(num_subplots, figsize=(14, 10), sharex=True)
    fig.suptitle("VAl Loss - Epochs")

    # Ensure axs is a list even when num_subplots = 1
    if num_subplots == 1:
        axs = [axs]

    common_params = _common_params([params for _, params, _ in runs])

    for filename, params, losses in runs:
        # Label each line with only the parameters that distinguish this run.
        unique_params = {k: v for k, v in params.items()
                         if k not in common_params}
        label = ', '.join(f'{key}: {value}'
                          for key, value in unique_params.items())
        # An empty label would be dropped from the legend; use the file name.
        label = label or filename

        # zip tolerates a run that has fewer curves than there are subplots.
        for ax, curve in zip(axs, losses):
            ax.plot(curve[:end_idx], label=label)

    # Axis decoration is independent of the individual runs, so do it once.
    for j, ax in enumerate(axs):
        name = dim_names[j] if j < len(dim_names) else f'dim {j}'
        ax.set_ylabel(f'Loss for {name}')
        if end_idx:
            ax.set_xlim(0, end_idx)

    axs[-1].legend()
    axs[-1].set_xlabel('Epochs')

    plt.tight_layout()
    plt.show()


In [None]:
# Plot the loss curves of the runs whose stored parameters include
# "message": "silicone R".
plot_loss(filter_params={"message": "silicone R"})

# Data Efficiency Plot

In [None]:
import glob
import numpy as np
import re
import matplotlib.pyplot as plt

# Custom sort key: turn the percentage embedded in a name into an integer.
def custom_sort(item):
    """Return the integer that precedes ``%_`` in ``item``.

    Used as a ``key`` function so that names are ordered numerically by
    their percentage rather than lexicographically.

    Args:
        item: String containing a ``<digits>%_`` token.

    Returns:
        The digits of the first such token, as an ``int``.
    """
    match = re.search(r'(\d+)%\_', item)
    digits = match.group(1)
    return int(digits)

def adjust_x_position(labels, offset):
    """Return the indices ``0..len(labels)-1``, each shifted by ``offset``.

    Args:
        labels: Sized collection; only its length is used.
        offset: Amount added to every index.

    Returns:
        A list with one shifted position per label.
    """
    count = len(labels)
    return [index + offset for index in range(count)]

# Collect the RMSE statistics of every result file with a given suffix.
def get_data(suffix):
    """Load ``plots/*{suffix}.npy`` and summarise each file.

    Files are processed in ascending order of the key computed by
    ``custom_sort``. Each array is averaged along axis 0 (per the original
    author's note this gives one mean RMSE per subset - TODO confirm against
    the code that writes these files). The mean and standard deviation of
    that averaged array are recorded together with the ``<digits>%`` token
    found in the file path.

    Args:
        suffix: File-name suffix placed before ``.npy`` in the glob pattern.

    Returns:
        A 4-tuple of parallel lists ``(rmse_list, std_list, labels,
        subsets_rmse)``: overall mean, standard deviation, percentage label
        and the axis-0 averaged array of each file.
    """
    ordered_paths = glob.glob(f'plots/*{suffix}.npy')
    ordered_paths.sort(key=custom_sort)

    summaries = []
    for path in ordered_paths:
        # Average along axis 0 first; the statistics below are taken over
        # the resulting values.
        averaged = np.mean(np.load(path), axis=0)
        overall_mean = np.mean(averaged)
        spread = np.std(averaged)
        tag = re.search(r'(\d+%)\_', path).group(1)
        summaries.append((overall_mean, spread, tag, averaged))

    rmse_list = [entry[0] for entry in summaries]
    std_list = [entry[1] for entry in summaries]
    labels = [entry[2] for entry in summaries]
    subsets_rmse = [entry[3] for entry in summaries]
    return rmse_list, std_list, labels, subsets_rmse

# Gather the statistics of both models from their result files.
rmse_list_fcnn, std_list_fcnn, labels_fcnn, subsets_rmse_fcnn = get_data('_fcnn')
rmse_list_gnn, std_list_gnn, labels_gnn, subsets_rmse_gnn = get_data('_gnn')

# Scatter plot with error bars, one marker series per model.
fig, ax = plt.subplots(figsize=(12, 8))

# FCNN series, shifted slightly to the left of each tick.
ax.errorbar(
    adjust_x_position(labels_fcnn, -0.1),
    rmse_list_fcnn,
    yerr=std_list_fcnn,
    fmt='o',
    label='FCNN',
    color='blue',
    ecolor='red',
    capsize=5,
)

# GNN series, shifted slightly to the right of each tick.
ax.errorbar(
    adjust_x_position(labels_gnn, 0.1),
    rmse_list_gnn,
    yerr=std_list_gnn,
    fmt='s',
    label='GNN',
    color='green',
    ecolor='orange',
    capsize=5,
)

# Disabled: overlay the individual per-subset FCNN points.
# for i, subset_rmse in enumerate(subsets_rmse_fcnn):
#     ax.scatter([i - 0.2] * len(subset_rmse), subset_rmse, color='blue', alpha=0.5)

# Disabled: overlay the individual per-subset GNN points.
# for i, subset_rmse in enumerate(subsets_rmse_gnn):
#     ax.scatter([i + 0.2] * len(subset_rmse), subset_rmse, color='green', alpha=0.5)

# One tick per entry, labelled with the FCNN percentage labels.
ax.set_xticks(range(len(labels_fcnn)))
ax.set_xticklabels(labels_fcnn)

ax.set_xlabel("Data Usage Percentage")
ax.set_ylabel("Average RMSE Value")
ax.set_title("Average RMSE for Different Data Usages (Stride: 30)")
ax.grid(axis='y')
ax.legend()

fig.tight_layout()
plt.show()

# Time plots

In [None]:
import pickle
import os
import matplotlib.pyplot as plt


def time_plot(set_name):
    """Merge two pickled position-time figures of one image set into a PDF.

    Loads the pickled matplotlib figures ``plots/{set_name}_graph.pickle``
    and ``plots/{set_name}_fcnn.pickle`` (relative to the current working
    directory), copies the lines of their first three axes into a new 3x1
    figure, saves it as ``plots/{set_name}.pdf`` and shows it.

    From the ``_graph`` figure every line is copied: lines whose label
    contains ``'actual'`` keep their label, all others are drawn dashed with
    ``'_gnn'`` appended. From the ``_fcnn`` figure only lines whose label
    contains ``'predict'`` are copied, with ``'_fcnn'`` appended.

    Args:
        set_name: Name of the image set; used for both input file names, the
            figure title and the output file name.

    Raises:
        FileNotFoundError: If either pickle file does not exist.
    """
    plots_dir = os.path.join(os.getcwd(), 'plots')

    # SECURITY: pickle.load can execute arbitrary code. Only load figure
    # pickles that were produced by trusted code.
    with open(os.path.join(plots_dir, f'{set_name}_graph.pickle'), 'rb') as f:
        fig1 = pickle.load(f)

    with open(os.path.join(plots_dir, f'{set_name}_fcnn.pickle'), 'rb') as f:
        fig2 = pickle.load(f)

    # New figure that holds the merged subplots.
    new_fig, new_ax = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    # Use the argument (not a module-level name) so the title always matches
    # the data that was loaded.
    new_fig.suptitle(f"Position - Time plot of imageset {set_name}")

    axes1 = fig1.get_axes()
    axes2 = fig2.get_axes()

    for i in range(3):
        ax1 = axes1[i]
        ax2 = axes2[i]
        target = new_ax[i]

        # Copy every line of the first figure; non-"actual" lines are dashed
        # and tagged as GNN output.
        for line in ax1.lines:
            if 'actual' in line.get_label():
                target.plot(line.get_xdata(), line.get_ydata(),
                            color=line.get_color(), label=line.get_label())
            else:
                target.plot(line.get_xdata(), line.get_ydata(),
                            linestyle='--', color=line.get_color(),
                            label=line.get_label() + '_gnn')

        # From the second figure copy only the prediction lines; the "actual"
        # line was already taken from the first figure.
        for line in ax2.lines:
            if 'predict' in line.get_label():
                target.plot(line.get_xdata(), line.get_ydata(),
                            color=line.get_color(),
                            label=line.get_label() + '_fcnn')

        # Reuse the first figure's axis label and limits.
        target.set_ylabel(ax1.get_ylabel())
        target.set_xlim(ax1.get_xlim())
        target.set_ylim(ax1.get_ylim())

        # A single legend on the middle subplot.
        if i == 1:
            target.legend()

    new_ax[-1].set_xlabel('Time (s)')
    plt.tight_layout()
    # Save the merged figure explicitly, under the name of the plotted set.
    new_fig.savefig(os.path.join(plots_dir, f'{set_name}.pdf'), format='pdf')
    # Close the loaded source figures so that only the merged one is shown.
    plt.close(fig1)
    plt.close(fig2)
    plt.show()


In [None]:
# Names of the image sets to plot; each one is passed to time_plot() below.
test_sets = ['C_M1_T1_8', 'M1_NF', 'M2_NF', 'R1_M1_T1_1',
             'R1_M1_T1_2', 'R3_M1_T1_1', 'R3_M1_T1_2', 'L1_M1_T1_1',
             'L1_M1_T1_2', 'L3_M1_T1_1', 'L3_M1_T1_2', 'Z1_M1_T1_1',
             'Z3_M1_T1_1', 'R2_M1_T1_8', 'M3_R2_NF', 'M5_R2_NF',
             'M5_L2_NF', 'M3_L2_NF', 'M7_L2_NF', 'Z2_M1_T1_5',
             'M9_Z2_NF', 'M8_Z2_NF']

# NOTE(review): the loop variable `set` shadows the built-in `set` type at
# module level. It is left unchanged here because other code in this notebook
# may read the global `set`; confirm before renaming.
for set in test_sets:
    time_plot(set)