In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('../..')))

import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from reservoirgrid.helpers import chaos_utils
from reservoirgrid.helpers import viz

In [2]:
# Name of the system whose results sub-directory is read.
system_name = "Lorenz"

# Input (results) and output (plots) directories, relative to the notebook's
# working directory.
# NOTE(review): save_path is not referenced by any cell visible here - confirm
# whether the figures written later are meant to land in it.
path = "../../Examples/Input_Discretization/results/Chaotic/"
save_path = "../../Examples/Input_Discretization/Plots/SingleMetric/"

# Directory that holds the pickled result files for `system_name`.
system_path = os.path.join(path, system_name)

In [3]:
# Load the pickled results stored in "100.0.pkl" under `system_path`.
# SECURITY: pickle.load can execute arbitrary code - only open files you trust.
# NOTE(review): the variable is named data_10 although the file is 100.0.pkl -
# confirm the name still reflects what is loaded.
file = os.path.join(system_path, "100.0.pkl")
with open(file, mode="rb") as handle:
    data_10 = pickle.load(handle)

In [4]:
# One accumulator per quantity; position i in every list refers to data_10[i].
all_lyapunov, all_kldiv, all_params, all_jsdiv, all_skl = [], [], [], [], []

for entry in data_10:
    # Every entry is indexed by "true_value", "predictions" and "parameters".
    truth, forecast = entry["true_value"], entry["predictions"]

    # Score the predictions against the true values with each chaos_utils metric.
    lyap_time = chaos_utils.lyapunov_time(truth, forecast)
    kl_value = chaos_utils.kl_divergence(truth, forecast)
    js_value = chaos_utils.js_divergence(truth, forecast)
    skl_value = chaos_utils.symmetric_kl(truth, forecast)

    all_lyapunov.append(lyap_time)
    all_kldiv.append(kl_value)
    all_jsdiv.append(js_value)
    all_skl.append(skl_value)

    # Keep the parameters alongside the scores so they can be tabulated later.
    all_params.append(entry["parameters"])

In [10]:
# Indices of the (up to) 10 entries with the smallest value of each divergence,
# ordered best (lowest) first.
#
# np.argpartition(x, 10) was replaced by a sorted slice for two reasons:
#   * argpartition raises ValueError when there are 10 or fewer entries
#     (kth must be < len(x)); slicing an argsort result simply returns fewer.
#   * argpartition leaves the 10 selected indices in unspecified order, so
#     positional picks such as `sklidx[3]` and the panel order of the grid plot
#     were arbitrary. A stable argsort makes both deterministic.
jsidx = np.argsort(all_jsdiv, kind="stable")[:10]
klidx = np.argsort(all_kldiv, kind="stable")[:10]
sklidx = np.argsort(all_skl, kind="stable")[:10]

In [16]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

# Draw one 3D panel per index in `sklidx` (the runs selected by symmetric KL),
# each titled with its symmetric-KL value, and export the grid as a PNG.

# Grid shape, derived from the index set that is actually plotted below.
# (It was previously taken from `jsidx`, which only worked because both index
# arrays happen to have the same length.)
# At least 10 panels are needed to see how the KL divergence favours coverage of
# the base attractor over perfect coverage.
n = len(sklidx)
cols = 5
rows = int(np.ceil(n / cols))

# Shared scene styling: fixed camera, no axis decorations, data-driven aspect.
clean_scene = dict(
    camera=dict(
        eye=dict(x=-1.5, y=1.5, z=0.5),
        center=dict(x=0, y=0, z=0),
        up=dict(x=0, y=0, z=1)
    ),
    xaxis=dict(showbackground=False, showgrid=False, zeroline=False,
               showticklabels=False, title=''),
    yaxis=dict(showbackground=False, showgrid=False, zeroline=False,
               showticklabels=False, title=''),
    zaxis=dict(showbackground=False, showgrid=False, zeroline=False,
               showticklabels=False, title=''),
    # optional: aspect ratio
    aspectmode="data"
)

# Create the subplot figure; every cell of the grid is a 3D scene.
fig = make_subplots(
    rows=rows, cols=cols,
    specs=[[{"type": "scene"} for _ in range(cols)] for _ in range(rows)],
    subplot_titles=[f"{all_skl[u]:.4f}" for u in sklidx],
    horizontal_spacing=0,  # no horizontal gap between panels
    vertical_spacing=0     # no vertical gap between panels
)

# Fill the grid row by row.
for i, u in enumerate(sklidx):
    r, c = divmod(i, cols)

    # Stack the true trajectory and the prediction for this run.
    skl_pair = np.array([
        data_10[u]["true_value"],
        data_10[u]["predictions"].numpy()
    ])
    subfig = viz.compare_plot(skl_pair)  # <- returns a plotly figure

    # Move the traces of the stand-alone figure into the matching grid cell
    # (plotly rows/cols are 1-based).
    for trace in subfig.data:
        fig.add_trace(trace, row=r+1, col=c+1)

# Apply the shared styling to every scene; the layout keys are "scene",
# "scene2", "scene3", ...
layout_updates = {}
for i in range(1, rows*cols + 1):
    scene_name = "scene" if i == 1 else f"scene{i}"
    layout_updates[scene_name] = clean_scene

fig.update_layout(
    margin=dict(l=0, r=0, t=30, b=0),  # trim margins, keep room for titles
    showlegend=False,
    **layout_updates
)

# NOTE(review): the image is written to the current working directory; the
# `save_path` defined earlier in the notebook is not used here - confirm which
# location is intended.
fig.write_image("SKL_best.png", width=1250, height=500)
fig.show()

In [7]:
# Position within each "best" index array to inspect.
selected = 3
legend = ("True Value", "Predictions")

# NOTE(review): only the final call is the cell's last expression. Unless
# viz.compare_plot renders the figure itself, the first two figures are built
# but never displayed - confirm.

# Run picked by symmetric KL divergence.
skl_run = data_10[sklidx[selected]]
viz.compare_plot(
    [skl_run["true_value"], skl_run["predictions"]],
    legend_names=list(legend),
    title="Symmetric KL Divergence",
)

# Run picked by Jensen-Shannon divergence.
js_run = data_10[jsidx[selected]]
viz.compare_plot(
    [js_run["true_value"], js_run["predictions"]],
    legend_names=list(legend),
    title="Jensen-Shannon Divergence",
)

# Run picked by KL divergence.
kl_run = data_10[klidx[selected]]
viz.compare_plot(
    [kl_run["true_value"], kl_run["predictions"]],
    legend_names=list(legend),
    title="KL Divergence",
)

In [8]:
# Collect the per-entry scores and parameters into one table; row i corresponds
# to data_10[i] because every list was appended to once per entry.
metric_columns = {
    "Jensen_Divergence": all_jsdiv,
    "KL_Divergence": all_kldiv,
    "Symmetric_KL": all_skl,
    "Lyapunov_Time": all_lyapunov,
    "Parameters": all_params,
}
df = pd.DataFrame(metric_columns)
