In [95]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 14})
import os

In [39]:
# Display name for every PDE key. Dict insertion order is the canonical
# equation order; the key list below is derived from it so the two can
# never drift apart.
EQUATION_REAL_NAMES = {
    'wave': 'Wave',
    'gas_dynamics': 'Gas Dynamics',
    'reaction_diffusion': 'Reaction-Diffusion',
    'kuramoto_sivashinsky': 'Kuramoto-Sivashinsky',
    'advection': 'Advection',
}
EQUATIONS = list(EQUATION_REAL_NAMES)

# Display name for every model key (also used as plot legend labels).
# As above, insertion order defines MODELS, i.e. the load and plot order.
MODEL_REAL_NAMES = {
    'feast': 'FeaSt',
    'gat': 'GAT',
    'gcn': 'GCN',
    'point_gnn': 'PointGNN',
    'point_transformer': 'Point Transformer',
    'kernelNN': 'KernelNN',
    'graphpde': 'GraphPDE',
    'persistence': 'Persistence',
    'cnn': 'CNN',
    'resnet': 'ResNet',
    'neuralpde': 'NeuralPDE',
}
MODELS = list(MODEL_REAL_NAMES)

TASK = "forecast"
SUPPORTS = ["cloud", "grid"]
NUM_POINTS = "high"
# Test-time rollout metric columns read from each metrics.csv, one per step 1..16.
METRIC = [f"test_rollout_{i}" for i in range(1,17)]

# Root of the CSV logs; runs are stored as <root>/<support>/<num_points>/<equation>/<model>/version_<N>/.
LOG_ROOT = "../output/csv_logs"
VERSION_PREFIX = "version_"


def _latest_version_dir(model_path):
    """Return the highest-numbered ``version_<N>`` entry name in ``model_path``.

    Entries that do not match ``version_<integer>`` (stray files, hidden
    files, ...) are ignored instead of crashing ``int()``.

    Returns:
        The entry name (e.g. ``"version_3"``), or ``None`` when no versioned
        run exists in the directory.
    """
    version_numbers = [
        int(name[len(VERSION_PREFIX):])
        for name in os.listdir(model_path)
        if name.startswith(VERSION_PREFIX) and name[len(VERSION_PREFIX):].isdecimal()
    ]
    if not version_numbers:
        return None
    return f"{VERSION_PREFIX}{max(version_numbers)}"


# Collect the rollout metrics of the most recent run for every
# (support, model, equation) combination that has logs on disk.
entries = []
for sup in SUPPORTS:
    for model in MODELS:
        for eq in EQUATIONS:
            model_path = os.path.join(LOG_ROOT, sup, NUM_POINTS, eq, model)

            if not os.path.isdir(model_path):
                continue

            version_last = _latest_version_dir(model_path)
            if version_last is None:
                # Directory exists but holds no finished/versioned run.
                continue

            results_csv = os.path.join(model_path, version_last, "metrics.csv")

            # Keep only rows where every rollout metric was logged.
            run_metrics = pd.read_csv(results_csv)[METRIC].dropna()
            entries.append(
                run_metrics.assign(
                    equation=EQUATION_REAL_NAMES[eq],
                    model=MODEL_REAL_NAMES[model],
                )
            )

if not entries:
    raise ValueError(
        f"No metrics.csv logs found under {LOG_ROOT!r} for NUM_POINTS={NUM_POINTS!r}; "
        "check that the training runs have been executed."
    )

# One row per run, indexed by display names, one column per rollout step.
results = pd.concat(entries).set_index(["model", "equation"])


In [107]:
# Per-equation y-axis upper limit for the rollout plots.
UPPER_LIMIT_DICT = {
    'Wave': 3.1, 
    'Gas Dynamics': 2.1, 
    'Reaction-Diffusion': 4.1, 
    'Kuramoto-Sivashinsky': 4.1, 
    'Advection': 2.6
}

FIGURE_DIR = "../figures/results"
# Models drawn with dotted lines to set them apart from the others.
DOTTED_MODELS = {"CNN", "ResNet", "NeuralPDE"}
# x positions match the METRIC columns (rollout steps 1..len(METRIC)).
rollout_steps = range(1, len(METRIC) + 1)

# savefig does not create missing directories.
os.makedirs(FIGURE_DIR, exist_ok=True)

# One figure per equation: each model's metric as a function of rollout step.
for eq in EQUATIONS:
    equation = EQUATION_REAL_NAMES[eq]
    fig, ax = plt.subplots()

    for model in MODELS:
        model_name = MODEL_REAL_NAMES[model]
        # `@` references local variables, avoiding hand-quoted query strings.
        rows = results.query("model == @model_name and equation == @equation")

        if rows.empty:
            # No logs for this (model, equation) pair; plotting an empty array
            # against rollout_steps would raise a shape-mismatch error.
            continue
        if len(rows) > 1:
            raise ValueError(
                f"Expected one row of rollout metrics for {model_name} on {equation}, "
                f"found {len(rows)} (duplicate runs or supports?)"
            )
        values = rows.to_numpy().ravel()

        if model_name == "Persistence":
            # Reference baseline: thick black dash-dot line.
            color = "black"
            linewidth = 3
            linestyle = "dashdot"
        else:
            color = None  # fall back to the axes colour cycle
            linewidth = 2.5
            linestyle = "dotted" if model_name in DOTTED_MODELS else "solid"

        ax.plot(rollout_steps, values, label=model_name, linestyle=linestyle, linewidth=linewidth, c=color)

    ax.legend()
    ax.set_ylim(-0.1, UPPER_LIMIT_DICT[equation])
    fig.tight_layout()
    fig.savefig(f"{FIGURE_DIR}/{equation}.pdf", dpi=200)
    # Close instead of clf() so no empty figure is left displayed or kept in memory.
    plt.close(fig)

<Figure size 640x480 with 0 Axes>