In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from models import dnn, lnn
import gymnasium

In [None]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not
# currently using (useful when re-running the notebook in a kernel that has
# already done GPU work).
torch.cuda.empty_cache()

In [None]:
# Gymnasium pendulum environment; its transitions serve as the ground truth
# that the learned transition models are compared against below.
env = gymnasium.make("Pendulum-v1")

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs end-to-end on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `batch_size` and `n` are not referenced by the cells visible in
# this notebook (the LNN is constructed with a literal n=1) — confirm whether
# they are still needed.
batch_size = 1
n = 1

In [None]:
# Multiplier applied to sampled actions when building the action tensors that
# are fed to the transition models.
# NOTE(review): 2.0 presumably mirrors Pendulum-v1's maximum torque — confirm.
a_scale = torch.tensor(
    [2.0],
    dtype=torch.float64,
    device=device,
)
a_scale

In [None]:
# Lengths of the flat observation and action vectors, read from the env spaces.
obs_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
a_zeros = None

# Instantiate both transition models, cast them to double precision and move
# them to the selected device.
with torch.no_grad():
    dnn_model = dnn(obs_size, action_size)
    dnn_model = dnn_model.double().to(device)

    lnn_model = lnn(
        env_name="pendulum",
        n=1,
        obs_size=obs_size,
        action_size=action_size,
        dt=0.05,
        a_zeros=a_zeros,
    )
    lnn_model = lnn_model.double().to(device)

In [None]:
# Show the DNN transition model's textual representation.
print(dnn_model)

In [None]:
# Show the LNN transition model's textual representation.
print(lnn_model)

In [None]:
# Restore the trained transition-model weights for both networks.
# `map_location=device` remaps the stored tensors onto the device selected for
# this notebook, so checkpoints saved on a GPU can still be loaded when that
# GPU is not available.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
dnn_ckpt = torch.load('dnn_49.ckpt', map_location=device)
lnn_ckpt = torch.load('lnn_99.ckpt', map_location=device)

dnn_model.load_state_dict(dnn_ckpt['transition_model'])
lnn_model.load_state_dict(lnn_ckpt['transition_model'])

In [None]:
# Trajectories collected during the rollout: one per learned model plus the
# ground truth produced by the environment.
dnn_pred_os = []
lnn_pred_os = []
true_os = []

# Seed the environment and NumPy's global RNG so that a fresh kernel
# reproduces the same initial state and the same sequence of random actions.
seed = 0
np.random.seed(seed)

o, _ = env.reset(seed=seed)
o_tensor = torch.tensor(o, dtype=torch.float64, device=device).unsqueeze(0)

# All three trajectories start from the same true initial observation.
dnn_pred_os.append(o_tensor)
lnn_pred_os.append(o_tensor)
true_os.append(o_tensor)

# Take a random action.
# Sample a normalised action in [-1, 1] and scale it exactly once, then hand
# the *same* scaled action to both the models (`a_tensor`) and the environment
# (`a`). Previously the action was drawn from [-2, 2] and scaled again for the
# models only, so the models were evaluated on a different action than the one
# applied to the environment.
# NOTE(review): this assumes the models expect actions in the environment's
# units — confirm against the training code.
a_norm = np.random.uniform(-1.0, 1.0, size=action_size)
a_tensor = (torch.tensor(a_norm, dtype=torch.float64, device=device) * a_scale).unsqueeze(0)
a = a_tensor.squeeze(0).cpu().numpy()

In [None]:
# DNN: predict the next observation from the initial observation and action
dnn_pred_o = dnn_model(o_tensor, a_tensor)
print(dnn_pred_o)

In [None]:
# LNN: predict the next observation from the initial observation and action
lnn_pred_o = lnn_model(o_tensor, a_tensor)
print(lnn_pred_o)

In [None]:
# Ground-truth transition: apply the sampled action `a` in the environment.
# env.step returns a 5-tuple; only the first three elements are kept here and
# the last two are discarded.
o_t, r, done, _, _ = env.step(a)
# Convert the observation to a (1, obs_size) double tensor on the same device
# as the model predictions.
o_t_tensor = torch.tensor(o_t, dtype=torch.float64, device=device).unsqueeze(0)
print(o_t_tensor)

In [None]:
# Record the first-step predictions and the matching true observation so the
# three trajectories stay aligned index-by-index.
dnn_pred_os.append(dnn_pred_o)
lnn_pred_os.append(lnn_pred_o)
true_os.append(o_t_tensor)

In [None]:
# Number of additional open-loop steps to roll out after the first one.
n_steps = 20

for _ in range(n_steps):
    # Take a new random action.
    # Sample a normalised action in [-1, 1] and scale it exactly once, then
    # give the *same* scaled action to the models and to the environment.
    # Previously the action was drawn from [-2, 2] and scaled again for the
    # models only, so predictions and ground truth used different actions.
    # NOTE(review): this assumes the models expect actions in the
    # environment's units — confirm against the training code.
    a_norm = np.random.uniform(-1.0, 1.0, size=action_size)
    a_tensor = (torch.tensor(a_norm, dtype=torch.float64, device=device) * a_scale).unsqueeze(0)
    a = a_tensor.squeeze(0).cpu().numpy()

    # Ground-truth transition from the environment.
    o_t, r, done, _, _ = env.step(a)
    o_t_tensor = torch.tensor(o_t, dtype=torch.float64, device=device).unsqueeze(0)

    # Each model is fed its OWN previous prediction (open-loop rollout), so
    # errors compound over time. The outputs are detached so the autograd
    # graph does not keep growing across steps; the values are unchanged.
    dnn_pred_o = dnn_model(dnn_pred_o, a_tensor).detach()
    lnn_pred_o = lnn_model(lnn_pred_o, a_tensor).detach()

    dnn_pred_os.append(dnn_pred_o)
    lnn_pred_os.append(lnn_pred_o)
    true_os.append(o_t_tensor)

In [None]:
# Per-step prediction error of each model: the sum of absolute differences
# (L1 distance) between the predicted and the true observation.
dnn_error = []
lnn_error = []
# The loop variables are deliberately NOT named `dnn` / `lnn`: those names are
# the model classes imported from `models`, and rebinding them to tensors would
# break the model-construction cell if it were re-run in the same kernel.
for dnn_o, lnn_o, true_o in zip(dnn_pred_os, lnn_pred_os, true_os):
    dnn_error.append(torch.abs(dnn_o - true_o).sum().detach().cpu().numpy())
    lnn_error.append(torch.abs(lnn_o - true_o).sum().detach().cpu().numpy())

In [None]:
# Plot both models' per-step prediction error on a shared set of axes.
err_fig, err_ax = plt.subplots()
err_ax.plot(dnn_error, label='DNN Error')
err_ax.plot(lnn_error, label='LNN Error')

err_ax.set_xlabel('Time Step')
err_ax.set_ylabel('Error')
err_ax.legend()
plt.show()

In [None]:
# Bar-chart data.
# NOTE(review): these totals are hard-coded; the cells that produced them are
# not part of the visible notebook — record their provenance.
labels = ['DNN', 'LNN', 'Ground Truth']
values = [-248, -246, -123]

# One colour per bar, in the same order as `labels`.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Draw the bars on an explicit figure/axes pair.
bar_fig, bar_ax = plt.subplots(figsize=(10, 6))
bars = bar_ax.bar(labels, values, color=colors)

# Title and axis labels, in a larger font for readability.
bar_ax.set_title('Comparison of Total Reward Across Models', fontsize=16)
bar_ax.set_xlabel('Model', fontsize=14)
bar_ax.set_ylabel('Total Reward', fontsize=14)

# Cleaner axes: dashed horizontal grid, no top/right spines, thin grey
# left/bottom spines.
bar_ax.grid(axis='y', linestyle='--', alpha=0.7)
for side in ('top', 'right'):
    bar_ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    bar_ax.spines[side].set_color('gray')
    bar_ax.spines[side].set_linewidth(0.5)

# Light background colour for the plotting area.
bar_ax.set_facecolor('#f5f5f5')

# Annotate every bar with its value, horizontally centred on the bar and
# anchored 5 data units above the bar's height.
for rect in bars:
    height = rect.get_height()
    x_center = rect.get_x() + rect.get_width()/2
    bar_ax.text(x_center, height + 5, round(height, 1),
                ha='center', va='bottom', fontsize=12)

# Tight layout so titles and labels are not clipped, then render.
bar_fig.tight_layout()
plt.show()