In [None]:
import sys
import wandb
import pandas as pd
import numpy as np
from pprint import pprint

def mean_and_std(df):
    """Return the element-wise mean and standard deviation over the entries of ``df``.

    The values of ``df`` (via ``to_numpy()``) are stacked along a new leading
    axis and both statistics are reduced over that axis. The standard
    deviation uses NumPy's default ``ddof=0``.

    Returns:
        Tuple ``(mean, std)`` of arrays shaped like a single stacked entry.
    """
    stacked = np.stack(df.to_numpy(), axis=0)
    center = stacked.mean(axis=0)
    spread = stacked.std(axis=0)
    return center, spread

# Presumably the directory meant for files downloaded from W&B.
# NOTE(review): no code visible in this notebook reads this name — confirm before relying on it.
download_root = "."

In [None]:
def get_sweep_regression_df_all(sweep_id, allow_crash=False, project="ngruver/physics-uncertainty-exps"):
    """Gather the config and final metrics of every run in a W&B sweep.

    Args:
        sweep_id: Id of the sweep inside ``project``.
        allow_crash: If False (default) only runs whose state is "finished"
            are kept. If True, unfinished runs are included as well; their
            metrics are taken from the run history instead of the summary.
        project: "entity/project" path that owns the sweep. The default is
            the project this notebook was written against.

    Returns:
        DataFrame with one row per kept run (row labels 0..n-1) whose columns
        are the union of config keys and metric keys. An empty DataFrame is
        returned when no run qualifies.
    """
    api = wandb.Api()
    sweep = api.sweep("{}/{}".format(project, sweep_id))

    results = []
    for run in sweep.runs:
        # Compare the state field directly; matching "finished" against
        # str(run) also matches any run whose path contains that word.
        finished = run.state == "finished"
        if not (finished or allow_crash):
            continue

        config = pd.Series(run.config)
        if finished:
            summary = pd.Series(run.summary)
        else:
            # Use the last value actually logged for each key. The final
            # history row alone can be NaN for keys that were not logged at
            # the last step (assumption about W&B history frames — confirm).
            history = run.history()
            last_values = {}
            for key, column in history.items():
                logged = column.dropna()
                last_values[key] = logged.iloc[-1] if len(logged) else np.nan
            summary = pd.Series(last_values)
        results.append(pd.concat([config, summary]))

    if not results:
        # pd.concat raises on an empty list; return an empty frame instead.
        return pd.DataFrame()
    return pd.concat(results, axis=1).T

In [None]:
# Pull both sweeps (crashed runs included) and combine them into one frame.
# Each per-sweep frame is labelled 0..n-1, so a plain concat would repeat row
# labels; ignore_index=True relabels the combined rows uniquely, which keeps
# label-based selection and index-aligned column assignment unambiguous.
df = get_sweep_regression_df_all("v96kirjy", allow_crash=True)
df2 = get_sweep_regression_df_all("pexiwka8", allow_crash=True)
df = pd.concat((df, df2), ignore_index=True)

In [None]:
# Display the distinct model types present in the loaded sweep results.
df["model_type"].unique()

In [None]:
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Legend names used in the figures for the raw `model_type` values.
reader_friendly_dict = dict(NN="NODE", MechanicsNN="NODE + SO", HNN="HNN")

sns.set_style('whitegrid')
colors = ["#00abdf", "#00058A", "#6A0078", (96/255,74/255,123/255), "#8E6100"]
sns.set_palette(sns.color_palette(colors))

# Keep the three model families of interest and derive the plotting columns
# on a new frame, leaving `df` untouched.
filtered = df.loc[df['model_type'].isin(['HNN','NN','MechanicsNN'])].assign(
    model_type=lambda d: d["model_type"].map(reader_friendly_dict),
    dataset=lambda d: d["system_type"] + d["num_bodies"].astype(str),
    **{
        "Rollout Error": lambda d: d["test_gerr"].astype(float),
        "Energy Violation": lambda d: d["test_Herr"].astype(float),
    },
)
# One marker shape per dataset; only needed for the optional `style=` variant noted below.
filled_markers = ('o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'H', 'D', 'd', 'P', 'X')[:filtered["dataset"].nunique(dropna=False)]

matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 14,
})

fig, ax = plt.subplots(figsize=(4.5, 3.5))
# Optional variant: add style="dataset", markers=filled_markers to vary the marker per dataset.
sns.scatterplot(x='Rollout Error', y='Energy Violation', hue='model_type', data=filtered, ax=ax)
ax.get_legend().remove()
ax.set_yscale('log')
ax.set_xscale('log')

# Replace the per-axes legend with a single figure-level legend below the plot.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=3)
fig.subplots_adjust(bottom=0.3)

fig.savefig('energy_conservation_loglog.pdf', bbox_inches='tight')
plt.show()

In [None]:
from sklearn import datasets, linear_model, metrics
# Fit log(Energy Violation) against log(Rollout Error) by ordinary least squares
# and store each run's residual, scaled by the standard deviation of the
# log energy violation.
regr = linear_model.LinearRegression()
# Build the (n_samples, 1) design matrix from a NumPy array: indexing a Series
# with [:, None] is deprecated and recent pandas versions reject it.
X_log = np.log(filtered["Rollout Error"].to_numpy()[:, None])
y_true = np.log(filtered["Energy Violation"])
regr.fit(X_log, y_true)
y_pred = regr.predict(X_log)
residuals = y_true - y_pred
filtered["residuals"] = residuals / y_true.std()

In [None]:
# R^2 of the log-log linear fit; y_true and y_pred are set by the regression cell.
metrics.r2_score(y_true, y_pred)

In [None]:
plt.figure(figsize=(4, 3))
# Alphabetical dataset names, rotated so the first five entries are drawn last.
order = sorted(filtered["dataset"].unique())
order = [*order[5:], *order[:5]]

# Mean scaled residual of the log-log fit per dataset, split by model type.
plot = sns.barplot(data=filtered, x="dataset", y="residuals", hue="model_type", order=order)
plt.setp(plot.get_xticklabels(), rotation=30)
plot.set_xlabel('')
plot.figure.savefig('energy_conservation_residuals.pdf', bbox_inches='tight')



In [None]:
# Rebind `df` to the runs of sweep "kj4ke9i2" (crashed runs included).
# The frame loaded from the earlier sweeps is no longer reachable through `df` after this line.
df = get_sweep_regression_df_all("kj4ke9i2",allow_crash=True)

In [None]:
# `filtered` is the same object as `df` (no filtering is applied in this cell),
# so the columns added below are written onto `df` itself.
filtered = df
short_system = filtered["system_type"].map(lambda name: name.replace("Pendulum", " "))
filtered["dataset"] = short_system + filtered["num_bodies"].astype(str)

# Reciprocal of the `alpha` column, used as the hue variable.
filtered["SymReg strength"] = filtered["alpha"].astype(float).rdiv(1)
order = sorted(filtered["dataset"].unique())

matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 18,
})
fig, ax = plt.subplots(figsize=(6.75, 5.25))
plot = sns.barplot(x="dataset", y='test_gerr', hue="SymReg strength", data=filtered, order=order, palette="rocket", ax=ax)
ax.get_legend().remove()
ax.grid(False)

ax.set_yscale('log')
ax.set_xlabel('')
ax.set_ylabel("Rollout Error")

# NOTE(review): the legend is titled with alpha while the hue column holds 1/alpha — confirm the intended naming.
handles, labels = ax.get_legend_handles_labels()
leg = fig.legend(handles, labels, loc='lower center', ncol=6, prop={'size': 12}, title="$\\alpha=$")
fig.subplots_adjust(bottom=0.2, left=-.15)

fig.savefig('state_err_reg.pdf', bbox_inches='tight')
plt.show()
plt.close()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter

# `filtered` is `df` itself; every column assigned below is added to `df` in place.
filtered = df
short_system = filtered["system_type"].map(lambda name: name.replace("Pendulum", " "))
filtered["dataset"] = short_system + filtered["num_bodies"].astype(str)
# Both quantities are plotted in log10 space; the tick formatter below maps them back.
filtered["Symplectic Error"] = filtered["Train_symreg"].astype(float).pipe(np.log10)
filtered["Rollout Error"] = filtered["test_gerr"].astype(float).pipe(np.log10)

filtered["SymReg strength"] = filtered["alpha"].astype(float).rdiv(1)
order = sorted(filtered["dataset"].unique())

matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 14,
})

palette = sns.color_palette("Paired", desat=0.8)[4:]
# One linear fit per dataset of log10 rollout error against log10 symplectic error.
g = sns.lmplot(
    data=filtered,
    x="Symplectic Error",
    y='Rollout Error',
    hue="dataset",
    hue_order=order,
    palette=palette,
    height=4,
    aspect=1.3,
    legend_out=True,
)


def formatter(x, pos):
    """Label a log10 tick position with the plain value it stands for."""
    return f'{10. ** x:g}'


ax = g.axes[0][0]
ax.set_xticks(np.arange(-10,4,2))
ax.set_yticks(np.arange(-4,1))
ax.xaxis.set_major_formatter(FuncFormatter(formatter))
ax.yaxis.set_major_formatter(FuncFormatter(formatter))

ax.grid(False)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.set_xlabel("Symplectic Error", labelpad=10)

# Blank out the legend title.
legend = g.legend
legend.set_title("")

plt.savefig('state_err_reg_value.pdf', bbox_inches='tight')
plt.show()

In [None]:
import sys
import wandb
import pandas as pd
import numpy as np
from pprint import pprint

def mean_and_std(df):
    """Compute the element-wise mean and standard deviation over the entries of ``df``.

    NOTE(review): this redefines a function with the same name and behaviour
    that appears earlier in the notebook.

    The values of ``df`` (via ``to_numpy()``) are stacked on a new leading
    axis and reduced over it; the standard deviation uses ``ddof=0``.
    """
    rows = np.stack(df.to_numpy(), axis=0)
    return tuple(reduce(rows, axis=0) for reduce in (np.mean, np.std))

# Presumably the directory meant for downloaded run files.
# NOTE(review): get_sweep_tabular below passes a literal root="." and never reads this name — confirm intent.
download_root = "."
import json
def get_sweep_tabular(sweep_id, allow_crash=True, project="ngruver/physics-uncertainty-exps"):
    """Build a long-format table of the logged ``H_err_vec`` values for a W&B sweep.

    For every kept run the table file referenced by the ``H_err_vec`` summary
    entry is downloaded into the current directory and flattened to one row
    per array element, with the run's config values repeated on every row.

    Args:
        sweep_id: Id of the sweep inside ``project``.
        allow_crash: If True (default) unfinished runs are included, using the
            last logged history value of each key in place of the summary.
            If False only runs whose state is "finished" are used.
        project: "entity/project" path that owns the sweep.

    Returns:
        DataFrame with columns ``logherr`` (table values), ``ics`` (index along
        the first table axis), ``T`` (position along the last axis, evenly
        spaced in [0, 1]) plus one column per config key. Runs without a
        usable ``H_err_vec`` table are skipped. An empty DataFrame is returned
        when nothing could be loaded.
    """
    api = wandb.Api()
    sweep = api.sweep("{}/{}".format(project, sweep_id))

    results = []
    for run in sweep.runs:
        # Compare the state field directly instead of searching str(run).
        finished = run.state == "finished"
        if not (finished or allow_crash):
            continue

        if finished:
            summary = pd.Series(run.summary)
        else:
            history = run.history()
            summary = pd.Series({
                key: column.dropna().iloc[-1]
                for key, column in history.items()
                if column.notna().any()
            })

        logherrs = _load_h_err_table(run, summary)
        if logherrs is None:
            # Without this guard the table of the previously processed run would
            # be reused under this run's config (or a NameError raised on the first run).
            print("Skipping {}: no H_err_vec table found".format(run))
            continue

        # assumes the table is 2-D (first axis: trajectory, last axis: time) — TODO confirm
        ic = np.arange(logherrs.shape[0])[:, None] + np.zeros_like(logherrs)
        T = np.linspace(0, 1, logherrs.shape[-1])[None, :] + np.zeros_like(logherrs)
        df = pd.DataFrame({
            'logherr': logherrs.reshape(-1),
            'ics': ic.reshape(-1),
            'T': T.reshape(-1),
        })
        # Repeat every config entry on all rows of this run.
        for att, value in pd.Series(run.config).items():
            df[att] = value
        results.append(df)

    if not results:
        # pd.concat raises on an empty list; return an empty frame instead.
        return pd.DataFrame()
    return pd.concat(results)


def _load_h_err_table(run, summary):
    """Download and parse the run's ``H_err_vec`` table; return None if it is unavailable."""
    entry = summary.get('H_err_vec')
    try:
        target = entry['path'].split('/')[-1]
    except (TypeError, KeyError, AttributeError):
        # Missing key, NaN placeholder, or an entry without a usable 'path'.
        return None

    for f in run.files():
        if not f.name.endswith(target):
            continue
        # The W&B API documents download() as returning an open file handle;
        # close it so that one handle per run is not leaked.
        handle = f.download(root=".", replace=True)
        if hasattr(handle, "close"):
            handle.close()
        with open(f.name) as fd:
            table = np.array(json.load(fd)['data'])
        print(f.name)
        return table
    return None

In [None]:
# Load the flattened H_err_vec table for sweep "j3sjkwvo". The second positional
# argument is allow_crash=False, so only finished runs are used.
# `df_all` is a second name for the same DataFrame object, not a copy.
df = df_all = get_sweep_tabular("j3sjkwvo",False)

In [None]:
df["dataset"] = df["system_type"] + df["num_bodies"].astype(str)
# Select `logherr` before reducing: aggregating every column first is wasted
# work, and pandas 2.x raises on the non-numeric config columns instead of
# silently dropping them.
group_keys = ['model_type', 'dataset', 'T']
logherr_by_group = df.groupby(group_keys)['logherr']
mean = logherr_by_group.mean().reset_index()
std = logherr_by_group.std().reset_index()

In [None]:
# Map the grouped statistics through exp(): `std` becomes a multiplicative
# factor and `logherr` the exponentiated group mean (presumably `logherr` is a
# natural-log error — confirm).
# Re-running this cell without re-running the groupby cell exponentiates `logherr` a second time.
mean['std'] = np.exp(std['logherr'])
mean['logherr'] = np.exp(mean['logherr'])

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

colors = ["#00abdf", "#00058A", "#6A0078", (96/255,74/255,123/255), "#8E6100"]
sns.set_palette(sns.color_palette(colors))
matplotlib.rcParams['font.size'] = 18

# 2x3 grid: top row ChainPendulum, bottom row SpringPendulum, columns 2/3/4 bodies.
fig1, f1_axes = plt.subplots(nrows=2, ncols=3, figsize=(8,6), sharex=True, sharey=True, constrained_layout=True)
datasets = [f'ChainPendulum{i}' for i in (2,3,4)]+[f'SpringPendulum{i}' for i in (2,3,4)]
for i, (ax, ds) in enumerate(zip(f1_axes.flat, datasets)):
    dfs = mean[mean['dataset']==ds]
    # Draw each model's curve with a band spanning value/std .. value*std.
    for model_type, legend_name in (("HNN", "HNN"), ("NN", "NODE")):
        curve = dfs[dfs['model_type']==model_type]
        ax.plot(curve['T'], curve['logherr'], label=legend_name)
        ax.fill_between(curve['T'], curve['logherr']/curve['std'], curve['logherr']*curve['std'], alpha=.2)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_ylim(bottom=5e-6, top=1e1)
    ax.grid(True)
    ax.tick_params(axis='both', which='both', length=0)
    if i < 3:
        # Only the top row carries the column titles.
        ax.set_title(f"{i+2} link")
# Row and axis captions placed in figure coordinates.
fig1.text(1.01, 0.72, 'Chain', ha='center', va='center', rotation='vertical')
fig1.text(1.01, 0.3, 'Spring', ha='center', va='center', rotation='vertical')
fig1.text(-0.005, 0.5, 'Energy Error', ha='center', va='center', rotation='vertical')
fig1.text(0.54, 0, 'Rollout Time T', ha='center', va='center')

# NOTE(review): tight_layout() is called on a figure created with constrained_layout=True — confirm which layout is intended.
plt.legend()
plt.tight_layout()
plt.show()

fig1.savefig('energy_growth.pdf', bbox_inches='tight')