In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import pandas as pd
import PIL

from sim.sim_env import MoabSim

In [None]:
# Launch an onnxruntime session to run inference on the exported policy.
# Since ONNX Runtime 1.9, builds that ship execution providers beyond the CPU one
# (e.g. GPU builds) require `providers` to be passed explicitly or session creation
# raises; requesting every available provider mirrors the old implicit default.
session = onnxruntime.InferenceSession(
    "outputs/model.onnx",
    providers=onnxruntime.get_available_providers(),
)

In [None]:
# Build the Moab simulation environment that the policy is evaluated in.
sim = MoabSim(env_config=None)

# Action-space bounds, used later to rescale the policy output into valid actions.
low, high = sim.action_space.low, sim.action_space.high

In [None]:
# Lower clamp bound for log-std logits. Deterministic inference below only uses
# the means, so this constant is no longer consulted; it is kept so any other
# cell referencing it keeps working.
MIN_LOG_VALUE = -1e7

# Define a function for converting logits to actions for a SquashedGaussian distribution
# NOTE: This function is specific to the PPO algorithm used during training.
#       If you use a different algorithm, you might need to modify this function.
def logits_to_actions(logits, low=-1.0, high=1.0):
    """Map policy-network logits to deterministic actions inside [low, high].

    Reproduces the deterministic sample of a squashed-Gaussian action
    distribution: the logits are split into two equal halves
    ``[means..., log_stds...]`` (the ``chunk(inputs, 2, dim=-1)`` layout used by
    RLlib's Gaussian action distributions), each mean is squashed with ``tanh``
    and then rescaled linearly from ``[-1, 1]`` to ``[low, high]``. The log-std
    half is ignored because no random sampling takes place.

    Args:
        logits: Sequence or array whose last axis has length ``2 * n_actions``.
            A leading batch axis is allowed.
        low: Lower action bound(s): a scalar or a sequence of length ``n_actions``.
        high: Upper action bound(s): a scalar or a sequence of length ``n_actions``.

    Returns:
        A list of ``n_actions`` Python floats (a nested list for batched logits).

    Raises:
        ValueError: If the last axis of ``logits`` has an odd length, or if the
            bounds cannot be broadcast to ``n_actions``.
    """
    logits = np.atleast_1d(np.asarray(logits, dtype=np.float64))
    n_actions, remainder = divmod(logits.shape[-1], 2)
    if remainder:
        raise ValueError(
            f"Expected an even number of logits (means + log_stds), got {logits.shape[-1]}."
        )

    # First half are the means; the second half (log_stds) is unused for a
    # deterministic action, which also avoids overflowing exp() on large values.
    means = logits[..., :n_actions]

    # Scalar bounds (the defaults) are broadcast to every action dimension.
    low = np.broadcast_to(np.asarray(low, dtype=np.float64), means.shape)
    high = np.broadcast_to(np.asarray(high, dtype=np.float64), means.shape)

    actions = low + (high - low) * (np.tanh(means) + 1.0) / 2.0
    # Guard against floating-point rounding nudging a value outside the bounds.
    return np.clip(actions, low, high).tolist()



In [None]:
# Roll the exported policy out in the simulator for a fixed number of episodes,
# recording one log row per simulation step.
n_episodes = 30
logs = []
for episode in range(n_episodes):
    sim.reset()
    terminated = truncated = False
    step = 0
    # Keep stepping until the environment reports termination or truncation.
    while not (terminated or truncated):
        # Batch of one observation; take the first output's single row as the logits.
        # NOTE(review): "state_ins": [None] presumably stands in for the recurrent
        # state of a non-recurrent policy — confirm against the model's inputs.
        obs_batch = [sim._get_obs()]
        outputs = session.run(None, {"obs": obs_batch, "state_ins": [None]})
        logits = outputs[0][0]
        actions = logits_to_actions(logits, low=low, high=high)

        state, reward, terminated, truncated, _ = sim.step(actions)
        step += 1

        logs.append({
            "episode": episode,
            "step": step,
            "ball_x": state[0],
            "ball_y": state[1],
            "ball_vel_x": state[2],
            "ball_vel_y": state[3],
            "input_pitch": actions[0],
            "input_roll": actions[1],
            "reward": reward,
            "terminated": terminated,
            "truncated": truncated,
        })

logs_df = pd.DataFrame(logs)

In [None]:
def generate_plot(episode_id, logs_df):
    """Render one episode's ball trajectory as an RGB PIL image.

    Every logged step of the episode is drawn as a point at (ball_x, ball_y),
    colored by its reward on a fixed 0-1 color scale so images of different
    episodes are comparable. The figure is rasterized off-screen and closed
    afterwards so rendering many episodes does not accumulate open figures.

    Args:
        episode_id: Zero-based episode index to select from ``logs_df``.
        logs_df: DataFrame with "episode", "ball_x", "ball_y" and "reward" columns.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    # `import PIL` alone does not guarantee the Image submodule is loaded.
    from PIL import Image

    episode_logs = logs_df.loc[logs_df["episode"] == episode_id]

    fig, ax = plt.subplots()
    # Light-gray disk of radius 0.1125 as the background (presumably the plate).
    ax.add_patch(plt.Circle((0, 0), 0.1125, color="lightgray"))
    ax.set_aspect("equal")
    ax.set_xlim(-0.15, 0.15)
    ax.set_ylim(-0.15, 0.15)

    # Trajectory points colored by reward on a fixed [0, 1] scale.
    scatter = ax.scatter(
        episode_logs["ball_x"],
        episode_logs["ball_y"],
        c=episode_logs["reward"],
        cmap="plasma",
        vmin=0,
        vmax=1,
    )
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.ax.set_title("Reward", fontsize=10)

    ax.set_title(f"Moab Ball Trajectory #{episode_id+1}")
    ax.set_xlabel("ball_x")
    ax.set_ylabel("ball_y")

    # Rasterize, then copy the pixels out before closing the figure.
    # buffer_rgba() replaces the deprecated/removed tostring_rgb() and already has
    # the true (height, width, 4) pixel shape, which also holds on HiDPI canvases.
    fig.canvas.draw()
    img = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")
    plt.close(fig)
    return img

imgs = [generate_plot(episode_id, logs_df) for episode_id in range(n_episodes)]

In [None]:
# Show all rendered trajectories in a grid, three per row. The row count follows
# the number of images (10 rows for 30 episodes), so changing n_episodes neither
# drops images nor leaves framed empty panels.
n_cols = 3
n_rows = max(1, math.ceil(len(imgs) / n_cols))
_, axs = plt.subplots(n_rows, n_cols, figsize=(18, 5 * n_rows), squeeze=False)
for ax in axs.flat:
    ax.axis("off")
for img, ax in zip(imgs, axs.flat):
    ax.imshow(img)
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()