In [None]:
import os
import glob
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
import sys
import matplotlib.cm as cm
import matplotlib.lines as mlines
import matplotlib.patches as mpatches

sys.path.append("../scripts")
sys.path.append("../utility")

from network import KoopmanNet
from dataset import KoopmanDatasetCollector

# Name of the output folder that receives the figures (one sub-folder per
# environment). Switch to the commented value to regenerate the archived run.
#project_name = "Koopman_Results_Apr_8_2"
project_name = "Test"

# Discount applied to successive prediction steps when averaging the
# multi-step error (step k is weighted by gamma**k).
gamma = 0.8

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(project_name, exist_ok=True)

In [56]:
def evaluate_model(model, data, u_dim, gamma, state_dim, device):
    """Roll a model forward from the first time step and score the rollout.

    The first frame ``data[0]`` is encoded once with ``model.encode`` and then
    advanced ``steps - 1`` times with ``model.forward``; after each advance the
    first ``state_dim`` columns of the lifted state are compared (MSE) against
    the state recorded in the next frame.

    Args:
        model: Object exposing ``eval()``, ``encode(x)`` and ``forward(X, u)``.
        data: Tensor indexed as ``data[step, sample, feature]``. When
            ``u_dim`` is not ``None`` the first ``u_dim`` feature columns are
            passed to ``forward`` as the input and the remaining columns are
            used as the state; otherwise every column is state.
        u_dim: Number of leading input columns, or ``None`` for data without
            inputs (``forward`` then receives ``None``).
        gamma: Per-step discount; step ``k`` (0-based) is weighted by
            ``gamma ** k`` in the weighted loss.
        state_dim: Number of leading columns of the lifted state that are
            compared against the target state.
        device: Device the tensors are moved to before use.

    Returns:
        tuple: ``(weighted_loss, step_errors, normalized_cov_loss)`` where

        * ``weighted_loss`` (float) is the gamma-weighted mean of the per-step
          MSE values,
        * ``step_errors`` (list[float]) holds the MSE of every rollout step,
        * ``normalized_cov_loss`` (float) is the squared Frobenius norm of the
          off-diagonal part of the sample covariance of the initial encoding
          (columns ``state_dim:``), divided by ``d * (d - 1)`` with ``d`` the
          number of encoded columns when ``d > 1``.

    Raises:
        ValueError: If ``data`` holds fewer than two time steps, in which case
            there is nothing to predict and the weighted loss is undefined.
    """
    steps = data.shape[0]
    if steps < 2:
        # Previously this surfaced as an opaque ZeroDivisionError on beta_sum.
        raise ValueError(
            f"evaluate_model needs at least 2 time steps to roll out, got {steps}"
        )

    model.eval()
    mse = nn.MSELoss()  # built once instead of once per rollout step
    with torch.no_grad():
        if u_dim is None:
            X = model.encode(data[0].to(device))
        else:
            X = model.encode(data[0, :, u_dim:].to(device))

        # Keep the encoding of the first frame; X itself is overwritten below.
        encoded_initial = X[:, state_dim:]

        weighted_loss = 0.0
        beta = 1.0
        beta_sum = 0.0
        step_errors = []

        for i in range(steps - 1):
            if u_dim is None:
                X = model.forward(X, None)
                target = data[i + 1].to(device)
            else:
                X = model.forward(X, data[i, :, :u_dim].to(device))
                target = data[i + 1, :, u_dim:].to(device)
            error = mse(X[:, :state_dim], target)
            step_errors.append(error.item())
            weighted_loss += beta * error
            beta_sum += beta
            beta *= gamma
        # Normalise by the total weight so the result is a weighted mean.
        weighted_loss /= beta_sum

        # Sample covariance (over dim 0) of the initial encoding.
        z_centered = encoded_initial - torch.mean(encoded_initial, dim=0, keepdim=True)
        cov_matrix = (z_centered.t() @ z_centered) / (z_centered.size(0) - 1)
        off_diag = cov_matrix - torch.diag(torch.diag(cov_matrix))
        cov_loss_val = torch.norm(off_diag, p='fro') ** 2

        # Divide by the number of off-diagonal entries so that different
        # encoding sizes are comparable.
        encode_dim = X.shape[1] - state_dim
        if encode_dim > 1:
            normalized_cov_loss = cov_loss_val.item() / (encode_dim * (encode_dim - 1))
        else:
            normalized_cov_loss = cov_loss_val.item()
    return weighted_loss.item(), step_errors, normalized_cov_loss

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sweep selection. The trailing comments list the full sweep; narrow a list to
# evaluate only a subset of the saved checkpoints.
envs = ['Franka']#['LogisticMap', 'DampingPendulum', 'Franka', 'DoublePendulum', 'Polynomial', 'G1', 'Go2']
encode_dims = [64]#[4, 16, 64, 256, 1024]
cov_regs = [0]#[0, 1]
seeds = [1]#[1, 2, 3, 4, 5]

best_models_dir = os.path.join("..", "log", "best_models", project_name)

# Maps (env_name, encode_dim, cov_reg) -> list of (seed, checkpoint_path).
results = {}

# Checkpoint names look like best_model_norm_<env>_<encode_dim>_<cov_reg>_<seed>.pth
pattern = r"best_model_norm_(\w+)_(\d+)_(\d+)_(\d+)\.pth"
# Sorted so the per-key seed order does not depend on filesystem listing order.
all_files = sorted(glob.glob(os.path.join(best_models_dir, "best_model_norm_*.pth")))
for f in all_files:
    m = re.match(pattern, os.path.basename(f))
    if not m:
        continue
    env_name = m.group(1)
    enc_dim = int(m.group(2))
    cov_reg_val = int(m.group(3))
    seed_val = int(m.group(4))
    # Keep only checkpoints that belong to the selected sweep. cov_regs is
    # checked as well, so a narrowed cov_regs list is honoured like the others.
    if (
        env_name not in envs
        or enc_dim not in encode_dims
        or cov_reg_val not in cov_regs
        or seed_val not in seeds
    ):
        continue
    results.setdefault((env_name, enc_dim, cov_reg_val), []).append((seed_val, f))

# Per-key aggregated scalars and per-step error curves, keyed by
# (env_name, encode_dim, cov_reg).
agg_metrics = {}
step_error_curves_agg = {}

# Number of input columns (u_dim) per environment. Environments that are not
# listed are evaluated with u_dim=None, i.e. every column is treated as state.
CONTROL_DIMS = {
    "Franka": 7,
    "DoublePendulum": 2,
    "DampingPendulum": 1,
    "G1": 37,
    "Go2": 12,
}

# Added before taking logs so an exact-zero value does not become -inf.
LOG_EPS = 1e-12


def _geometric_band(values, axis=None, eps=LOG_EPS):
    """Return (geometric mean, lower, upper) of ``values`` computed in log space.

    ``lower``/``upper`` are exp(mean(log) -/+ std(log)), i.e. a one-sigma band
    that is symmetric on a log axis.
    """
    logs = np.log(np.asarray(values) + eps)
    mu = np.mean(logs, axis=axis)
    sigma = np.std(logs, axis=axis)
    return np.exp(mu), np.exp(mu - sigma), np.exp(mu + sigma)


for env in envs:
    # The dataset file name encodes Ksteps: 1 for these two maps, 10 otherwise.
    Ksteps = 1 if env in ["Polynomial", "LogisticMap"] else 10

    norm_str = "norm"
    dataset_path = os.path.join("..", "data", "datasets", 
                                f"dataset_{env}_{norm_str}_Ktrain_60000_Kval_20000_Ktest_20000_Ksteps_{Ksteps}.pt")
    if not os.path.exists(dataset_path):
        print(f"Dataset file {dataset_path} not found for environment {env}, skipping.")
        continue
    # NOTE: weights_only=False unpickles arbitrary objects; only load dataset
    # files from a trusted source.
    data_dict = torch.load(dataset_path, weights_only=False)
    test_data = torch.from_numpy(data_dict["Ktest_data"]).float().to(device)

    # Whatever is not an input column in the last axis is state.
    u_dim = CONTROL_DIMS.get(env)
    state_dim = test_data.shape[2] - (u_dim or 0)

    for key in results:
        if key[0] != env:
            continue
        enc_dim, cov_reg_val = key[1], key[2]
        weighted_errors = []
        norm_cov_losses = []
        step_errors_all = []
        for (seed_val, filepath) in results[key]:
            checkpoint = torch.load(filepath, map_location=device)
            layers = checkpoint['layer']
            # Lifted dimension = original state columns + learned encoding.
            Nkoopman = state_dim + enc_dim
            model = KoopmanNet(layers, Nkoopman, u_dim)
            model.load_state_dict(checkpoint['model'])
            model.to(device)
            w_err, step_errs, norm_cov_loss = evaluate_model(model, test_data, u_dim, gamma, state_dim, device)

            weighted_errors.append(w_err)
            norm_cov_losses.append(norm_cov_loss)
            step_errors_all.append(np.array(step_errs))

        # Aggregate across seeds in log space (geometric mean with a
        # multiplicative one-sigma band).
        w_mean, w_lower, w_upper = _geometric_band(weighted_errors)
        cov_mean, cov_lower, cov_upper = _geometric_band(norm_cov_losses)
        # Rows are seeds, columns are prediction steps; aggregate per step.
        step_mean, step_lower, step_upper = _geometric_band(step_errors_all, axis=0)

        agg_metrics[key] = {
            "WeightedError_mean": w_mean,
            "WeightedError_lower": w_lower,
            "WeightedError_upper": w_upper,
            "NormalizedCovLoss_mean": cov_mean,
            "NormalizedCovLoss_lower": cov_lower,
            "NormalizedCovLoss_upper": cov_upper
        }
        step_error_curves_agg[key] = {
            "mean": step_mean,
            "lower": step_lower,
            "upper": step_upper
        }

# Build one table row per (environment, encode dimension). When results exist
# for both covariance-regularisation settings (0 = off, 1 = on) the row holds
# both values and their difference (off minus on); otherwise one partial row
# is emitted per available setting in cov_regs, with the differences left NaN.
rows = []

for env in envs:
    for enc_dim in encode_dims:
        off_key = (env, enc_dim, 0)
        on_key = (env, enc_dim, 1)

        if off_key in agg_metrics and on_key in agg_metrics:
            off_stats = agg_metrics[off_key]
            on_stats = agg_metrics[on_key]
            rows.append({
                "Environment": env,
                "EncodeDim": enc_dim,
                "WeightedError (CovReg off)": off_stats["WeightedError_mean"],
                "WeightedError (CovReg on)": on_stats["WeightedError_mean"],
                "Diff_WeightedError": off_stats["WeightedError_mean"] - on_stats["WeightedError_mean"],
                "NormCovLoss (CovReg off)": off_stats["NormalizedCovLoss_mean"],
                "NormCovLoss (CovReg on)": on_stats["NormalizedCovLoss_mean"],
                "Diff_NormCovLoss": off_stats["NormalizedCovLoss_mean"] - on_stats["NormalizedCovLoss_mean"],
            })
            continue

        for cov in cov_regs:
            stats = agg_metrics.get((env, enc_dim, cov))
            if stats is None:
                continue
            # Column name depends on which setting this partial row describes.
            setting = "off" if cov == 0 else "on"
            rows.append({
                "Environment": env,
                "EncodeDim": enc_dim,
                f"WeightedError (CovReg {setting})": stats["WeightedError_mean"],
                "Diff_WeightedError": np.nan,
                f"NormCovLoss (CovReg {setting})": stats["NormalizedCovLoss_mean"],
                "Diff_NormCovLoss": np.nan,
            })

df_summary = pd.DataFrame(rows)
print("Summary Table:")
print(df_summary)

# The CSV is written next to the notebook (current working directory).
table_csv_path = "evaluation_summary.csv"
df_summary.to_csv(table_csv_path, index=False)
print(f"Summary table saved to {table_csv_path}")


Summary Table:
        Environment  EncodeDim  WeightedError (CovReg off)  \
0       LogisticMap          4                    0.254342   
1       LogisticMap         16                    0.215659   
2       LogisticMap         64                    0.217502   
3       LogisticMap        256                    0.217140   
4       LogisticMap       1024                    0.218892   
5   DampingPendulum          4                    0.013826   
6   DampingPendulum         16                    0.001175   
7   DampingPendulum         64                    0.000519   
8   DampingPendulum        256                    0.000497   
9   DampingPendulum       1024                    0.000495   
10           Franka          4                    0.516339   
11           Franka         16                    0.449228   
12           Franka         64                    0.371807   
13           Franka        256                    0.297837   
14           Franka       1024                    0.281

In [58]:
def _plot_metric_vs_encode_dim(metric, ylabel, title, file_stem):
    """Save one log-log "metric vs. encode dimension" figure per environment.

    For every environment in ``envs`` a figure is written to
    ``<project_name>/<env>/<file_stem>_<env>.png`` with one line (plus a shaded
    lower/upper band) per setting in ``cov_regs``.

    Args:
        metric: Prefix of the ``agg_metrics`` entries to plot; the keys
            ``<metric>_mean``, ``<metric>_lower`` and ``<metric>_upper`` are read.
        ylabel: Label of the y axis.
        title: Figure title without the trailing `` for <env>`` part.
        file_stem: File name prefix of the saved PNG.
    """
    for env in envs:
        out_dir = os.path.join(project_name, env)
        os.makedirs(out_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(8, 6))
        for cov_reg_val in cov_regs:
            # Collect (encode_dim, mean, lower, upper) for this env/setting,
            # ordered by encode dimension so the line is drawn left to right.
            points = sorted(
                (
                    (
                        key[1],
                        metrics[f"{metric}_mean"],
                        metrics[f"{metric}_lower"],
                        metrics[f"{metric}_upper"],
                    )
                    for key, metrics in agg_metrics.items()
                    if key[0] == env and key[2] == cov_reg_val
                ),
                key=lambda point: point[0],
            )
            if not points:
                continue

            x_vals, y_vals, lower_bounds, upper_bounds = (
                np.array(column) for column in zip(*points)
            )
            label = f"CovReg {'on' if cov_reg_val==1 else 'off'}"
            ax.plot(x_vals, y_vals, marker='o', label=label)
            ax.fill_between(x_vals, lower_bounds, upper_bounds, alpha=0.3)

        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel("Encode Dimension (log scale)")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{title} for {env}")
        # Skip the legend when nothing was drawn to avoid an empty-legend warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, which="both", ls="--")
        fig.tight_layout()

        fig_path = os.path.join(out_dir, f"{file_stem}_{env}.png")
        fig.savefig(fig_path, dpi=300)
        print(f"Saved plot: {fig_path}")
        # Close explicitly so figures do not accumulate across environments.
        plt.close(fig)


# Plot 1: Average Multi-step Prediction Error vs. Encode Dimension
_plot_metric_vs_encode_dim(
    metric="WeightedError",
    ylabel="Average Multi-step Prediction Error (MSE, log scale)",
    title="Multi-step Prediction Error vs. Encode Dimension",
    file_stem="MultiStepError",
)

# Plot 2: Normalized Covariance Loss vs. Encode Dimension
_plot_metric_vs_encode_dim(
    metric="NormalizedCovLoss",
    ylabel="Normalized Covariance Loss (log scale)",
    title="Normalized Covariance Loss vs. Encode Dimension",
    file_stem="NormalizedCovLoss",
)

# Plot 3: Multi-step Loss Curves for each environment
# One figure per environment: colour encodes the encode dimension, line style
# encodes the covariance-regularisation setting (solid = off, dashed = on).
for env in envs:
    out_dir = os.path.join(project_name, env)
    os.makedirs(out_dir, exist_ok=True)
    
    fig, ax = plt.subplots(figsize=(10, 8))
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which is deprecated and
    # removed in Matplotlib 3.9; the second argument resamples the colormap.
    cmap = plt.get_cmap('tab10', len(encode_dims))

    dimension_handles = []
    covreg_handles = []
    
    for i, enc_dim in enumerate(encode_dims):
        color = cmap(i)

        # Proxy artist so the legend lists every encode dimension, even those
        # without a curve in this figure.
        dimension_handles.append(
            mlines.Line2D([], [], color=color, marker='o', linestyle='-', 
                          label=f'{enc_dim}')
        )
        
        for cov_reg_val in cov_regs:
            key = (env, enc_dim, cov_reg_val)
            if key in step_error_curves_agg:
                curves = step_error_curves_agg[key]
                mean_curve = curves["mean"]
                lower_curve = curves["lower"]
                upper_curve = curves["upper"]
                # Prediction steps are reported 1-based on the x axis.
                steps = np.arange(1, len(mean_curve) + 1)
                
                linestyle = '-' if cov_reg_val == 0 else '--'

                ax.plot(
                    steps, mean_curve, marker='o', color=color, linestyle=linestyle
                )
                ax.fill_between(steps, lower_curve, upper_curve, color=color, alpha=0.3)
    
    # Proxy artists for the line-style legend.
    cov_off_line = mlines.Line2D([], [], color='black', marker='', linestyle='-',
                                 label='Off')
    cov_on_line = mlines.Line2D([], [], color='black', marker='', linestyle='--',
                                label='On')
    covreg_handles.extend([cov_off_line, cov_on_line])

    # A second ax.legend() call replaces the first legend, so the first one is
    # re-attached with add_artist to keep both visible.
    legend1 = ax.legend(handles=dimension_handles, loc='upper left', bbox_to_anchor=(0, 1),
                        title='Encode Dimension', frameon=True)
    ax.add_artist(legend1)

    legend2 = ax.legend(handles=covreg_handles, loc='upper left', bbox_to_anchor=(0.17, 1),
                        title='Covariance Loss', frameon=True)
    ax.add_artist(legend2)

    ax.set_xlabel("Prediction Step")
    ax.set_ylabel("MSE Loss at Step (log scale)")
    ax.set_yscale('log')
    ax.set_title(f"Multi-step Loss Curves for {env}")
    ax.grid(True, which="both", ls="--")
    fig.tight_layout()
    
    fig_path = os.path.join(out_dir, f"MultiStepLossCurves_{env}.png")
    # Save through the figure handle rather than the implicit current figure.
    fig.savefig(fig_path, dpi=300)
    print(f"Saved plot: {fig_path}")
    plt.close(fig)

# Every figure created in this cell has been closed above, so this displays
# nothing new; the plots are available as the saved PNG files.
plt.show()


Saved plot: Koopman_Results_Apr_8_2/LogisticMap/MultiStepError_LogisticMap.png
Saved plot: Koopman_Results_Apr_8_2/DampingPendulum/MultiStepError_DampingPendulum.png
Saved plot: Koopman_Results_Apr_8_2/Franka/MultiStepError_Franka.png
Saved plot: Koopman_Results_Apr_8_2/DoublePendulum/MultiStepError_DoublePendulum.png
Saved plot: Koopman_Results_Apr_8_2/Polynomial/MultiStepError_Polynomial.png
Saved plot: Koopman_Results_Apr_8_2/G1/MultiStepError_G1.png
Saved plot: Koopman_Results_Apr_8_2/Go2/MultiStepError_Go2.png
Saved plot: Koopman_Results_Apr_8_2/LogisticMap/NormalizedCovLoss_LogisticMap.png
Saved plot: Koopman_Results_Apr_8_2/DampingPendulum/NormalizedCovLoss_DampingPendulum.png
Saved plot: Koopman_Results_Apr_8_2/Franka/NormalizedCovLoss_Franka.png
Saved plot: Koopman_Results_Apr_8_2/DoublePendulum/NormalizedCovLoss_DoublePendulum.png
Saved plot: Koopman_Results_Apr_8_2/Polynomial/NormalizedCovLoss_Polynomial.png
Saved plot: Koopman_Results_Apr_8_2/G1/NormalizedCovLoss_G1.png
Sav

  cmap = cm.get_cmap('tab10', len(encode_dims))


Saved plot: Koopman_Results_Apr_8_2/LogisticMap/MultiStepLossCurves_LogisticMap.png
Saved plot: Koopman_Results_Apr_8_2/DampingPendulum/MultiStepLossCurves_DampingPendulum.png
Saved plot: Koopman_Results_Apr_8_2/Franka/MultiStepLossCurves_Franka.png
Saved plot: Koopman_Results_Apr_8_2/DoublePendulum/MultiStepLossCurves_DoublePendulum.png
Saved plot: Koopman_Results_Apr_8_2/Polynomial/MultiStepLossCurves_Polynomial.png
Saved plot: Koopman_Results_Apr_8_2/G1/MultiStepLossCurves_G1.png
Saved plot: Koopman_Results_Apr_8_2/Go2/MultiStepLossCurves_Go2.png
