In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
plt.rcParams.update({'font.size': 14})
import os

import re

In [2]:
EQUATIONS = [
    'wave', 
    'gas_dynamics', 
    'reaction_diffusion', 
    'kuramoto_sivashinsky', 
    'advection',
    'burgers',
]
EQUATION_REAL_NAMES = {
    'wave': 'Wave', 
    'gas_dynamics': 'Gas Dynamics', 
    'reaction_diffusion': 'Reaction-Diffusion', 
    'kuramoto_sivashinsky': 'Kuramoto-Sivashinsky', 
    'advection': 'Advection',
    'burgers': 'Burgers',

}
MODELS = [
    'feast', 
    'gat', 
    'gcn', 
    'point_gnn', 
    'point_transformer', 
    'kernelNN', 
    'graphpde', 
    'persistence',
    'cnn', 
    'resnet', 
    'neuralpde'
]
MODEL_REAL_NAMES = {
    'feast': 'FeaSt', 
    'gat': 'GAT', 
    'gcn': 'GCN', 
    'point_gnn': 'PointGNN', 
    'point_net': 'PointNet',
    'point_transformer': 'Point Transformer', 
    'kernelNN': 'KernelNN', 
    'graphpde': 'GraphPDE', 
    'persistence': 'Persistence',
    'cnn': 'CNN', 
    'resnet': 'ResNet', 
    'neuralpde': 'NeuralPDE'
}

TASK = "forecast"
SUPPORTS = ["cloud", "grid"]
NUM_POINTS = "high"
METRIC = [f"test_rollout_{i}" for i in range(1,17)]

listing = glob.glob("../output/csv_logs/*/high/*/*/*/metrics.csv")
rows = []
for file in listing:
    path = file.split("/")
    equation = path[5]
    model = path[6]
    csv = pd.read_csv(file)[[f"test_rollout_{i}" for i in range(1, 17)]].dropna()
    csv['equation'] = EQUATION_REAL_NAMES[equation]
    csv['model'] = MODEL_REAL_NAMES[model]
    csv = csv.set_index(['equation', 'model'])
    rows.append(csv)

results = pd.concat(rows)
results = results.query("model != \"PointNet\"")


In [3]:
results

Unnamed: 0_level_0,Unnamed: 1_level_0,test_rollout_1,test_rollout_2,test_rollout_3,test_rollout_4,test_rollout_5,test_rollout_6,test_rollout_7,test_rollout_8,test_rollout_9,test_rollout_10,test_rollout_11,test_rollout_12,test_rollout_13,test_rollout_14,test_rollout_15,test_rollout_16
equation,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
Wave,CNN,0.001434,0.010996,0.035093,0.073743,0.120123,0.165976,0.207038,0.241552,0.272328,0.304355,0.339347,0.378488,0.421988,0.468417,0.515587,0.561433
Wave,NeuralPDE,0.001699,0.010858,0.029375,0.052151,0.072729,0.088916,0.105090,0.124263,0.144217,0.162839,0.178689,0.191123,0.201686,0.212009,0.226893,0.247704
Wave,ResNet,0.001459,0.010329,0.032706,0.067625,0.102738,0.128124,0.146789,0.162218,0.175900,0.186586,0.195999,0.207434,0.223201,0.244746,0.270747,0.299457
Gas Dynamics,CNN,0.004204,0.017296,0.038840,0.069755,0.111115,0.163608,0.227226,0.301390,0.384521,0.474257,0.568182,0.662843,0.755784,0.843738,0.923799,0.995382
Gas Dynamics,NeuralPDE,0.003734,0.014553,0.030368,0.050412,0.074455,0.102024,0.132273,0.164706,0.199090,0.234429,0.269893,0.304653,0.339796,0.375110,0.409833,0.443498
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Kuramoto-Sivashinsky,GraphPDE,0.007198,0.059355,0.193169,0.433975,0.794523,1.246442,1.720997,2.141097,2.454362,2.641909,2.706388,2.665718,2.550938,2.399127,2.242848,2.104396
Kuramoto-Sivashinsky,KernelNN,0.006687,0.051397,0.158321,0.339433,0.597169,0.911650,1.242362,1.544169,1.784620,1.951613,2.048040,2.085832,2.083302,2.059652,2.029875,2.003340
Kuramoto-Sivashinsky,PointGNN,0.006730,0.052549,0.159624,0.334471,0.575235,0.863840,1.168177,1.457609,1.716235,1.941022,2.136551,2.307629,2.456894,2.588295,2.707710,2.820627
Kuramoto-Sivashinsky,Persistence,0.142243,0.482181,0.871265,1.219627,1.494708,1.691459,1.817970,1.888432,1.917897,1.920181,1.908135,1.893092,1.882097,1.879186,1.885136,1.897520


In [5]:
results_single = results.pivot_table(values="test_rollout_16", index="model", columns="equation")
t = results_single.to_latex(float_format='%.2E')
t = results_single.to_markdown()
print(t)

| model             |       Advection |   Burgers |   Gas Dynamics |   Kuramoto-Sivashinsky |   Reaction-Diffusion |     Wave |
|:------------------|----------------:|----------:|---------------:|-----------------------:|---------------------:|---------:|
| CNN               |     0.00161331  |  0.554554 |       0.995382 |            1.26011     |          0.0183483   | 0.561433 |
| FeaSt             |     1.48288     |  0.561197 |       0.819594 |            3.74448     |          0.130149    | 1.61066  |
| GAT               | 41364.1         |  0.833353 |       1.21436  |            5.68925     |          3.85506     | 2.38418  |
| GCN               |     3.51453e+13 | 13.0876   |       7.20633  |            1.70612e+24 |          1.75955e+07 | 7.89253  |
| GraphPDE          |     1.07953     |  0.729879 |       0.969208 |            2.1044      |          0.0800235   | 1.02586  |
| KernelNN          |     0.897431    |  0.72716  |       0.854015 |            2.00334     |          0

  t = results_single.to_latex(float_format='%.2E')


In [6]:
scores = re.findall("\d\D\d\dE\D\d\d", t)
for s in scores:
    exponent = s[6:] if s[6] != "0" else s[7]
    sign = "" if s[5] == "+" else "-"
    formated_s = f"${s[:4]}\cdot 10^{{{sign}{exponent}}}$"
    t = t.replace(s, formated_s)

print(t)

| model             |       Advection |   Burgers |   Gas Dynamics |   Kuramoto-Sivashinsky |   Reaction-Diffusion |     Wave |
|:------------------|----------------:|----------:|---------------:|-----------------------:|---------------------:|---------:|
| CNN               |     0.00161331  |  0.554554 |       0.995382 |            1.26011     |          0.0183483   | 0.561433 |
| FeaSt             |     1.48288     |  0.561197 |       0.819594 |            3.74448     |          0.130149    | 1.61066  |
| GAT               | 41364.1         |  0.833353 |       1.21436  |            5.68925     |          3.85506     | 2.38418  |
| GCN               |     3.51453e+13 | 13.0876   |       7.20633  |            1.70612e+24 |          1.75955e+07 | 7.89253  |
| GraphPDE          |     1.07953     |  0.729879 |       0.969208 |            2.1044      |          0.0800235   | 1.02586  |
| KernelNN          |     0.897431    |  0.72716  |       0.854015 |            2.00334     |          0

In [7]:
UPPER_LIMIT_DICT = {
    'Wave': 3.1, 
    'Gas Dynamics': 2.1, 
    'Reaction-Diffusion': 4.1, 
    'Kuramoto-Sivashinsky': 4.1, 
    'Advection': 2.6,
    'Burgers': 2.1
}

for equation in EQUATIONS:
    equation = EQUATION_REAL_NAMES[equation]

    for model in MODELS:
        model_name = MODEL_REAL_NAMES[model]
        values = results.query(f"model == \"{model_name}\" and equation == \"{equation}\"").values.flatten()

        if model_name in ["CNN", "ResNet", "NeuralPDE"]:
            linestyle = "dotted"
        else:
            linestyle = "solid"

        if model_name == "Persistence":
            color = "black"
            linewidth = 3
            linestyle = "dashdot"
        else:
            color = None
            linewidth = 2.5
        
        #print(equation, model)
        #print(values)
        plt.plot(range(1, 17), values, label=model_name, linestyle=linestyle, linewidth=linewidth, c=color)
    

    plt.legend()
    plt.ylim(-0.1, UPPER_LIMIT_DICT[equation])
    plt.ylabel("MSE")
    plt.xlabel("Rollout step")
    plt.tight_layout()
    plt.savefig(f"../figures/results/{equation}.svg", dpi=200, bbox_inches="tight")
    plt.clf()

<Figure size 640x480 with 0 Axes>