In [None]:
import os
# Move to the repository root so the relative paths used below (./video_frames, custom_envs)
# resolve. Presumably the notebook lives one directory below `koopman-rl`; both Windows ('\\')
# and POSIX ('/') separators are checked.
if 'koopman-rl' not in (os.getcwd().split('\\')[-1], os.getcwd().split('/')[-1]):
    os.chdir('../')
print(f"Current working directory: {os.getcwd()}")

In [None]:
import gym
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import os
import pandas as pd
import torch

from custom_envs import *
from movies.env_enum import EnvEnum

In [3]:
# Helper function to create environments
def make_env(env_id, seed):
    """Build a zero-argument factory for one seeded gym environment.

    The returned callable is what ``gym.vector.SyncVectorEnv`` consumes (a list of
    env factories). When invoked, it creates ``env_id`` wrapped in
    ``RecordEpisodeStatistics`` and seeds the env and both of its spaces with ``seed``.

    Args:
        env_id: Registered gym environment id.
        seed: Seed applied to the env, its action space and its observation space.

    Returns:
        A callable taking no arguments and returning the wrapped, seeded env.
    """
    def thunk():
        wrapped = gym.wrappers.RecordEpisodeStatistics(gym.make(env_id))
        wrapped.seed(seed)
        # Seed the spaces too, so that any space.sample() calls are reproducible
        for space in (wrapped.action_space, wrapped.observation_space):
            space.seed(seed)
        return wrapped

    return thunk

In [4]:
class Args:
    """Plain container for the notebook's tunable settings.

    Attributes (set in ``__init__``, in this order):
        seed: Seed used for numpy, torch and the gym environment.
        data_folder: Directory holding the saved ``*.npy`` rollout files.
            NOTE(review): the default ``True`` is not a usable path; callers are
            expected to pass a real folder.
        save_every_n_steps: Stored as given; not read anywhere else in this notebook.
        plot_uncontrolled: Whether to load and draw the zero-policy (uncontrolled) rollout.
        ma_window_size: Window length for the moving average of log cost ratios.
    """

    def __init__(
        self,
        seed=123,
        data_folder=True,
        save_every_n_steps=None,
        plot_uncontrolled=False,
        ma_window_size=None,
    ):
        settings = {
            "seed": seed,
            "data_folder": data_folder,
            "save_every_n_steps": save_every_n_steps,
            "plot_uncontrolled": plot_uncontrolled,
            "ma_window_size": ma_window_size,
        }
        for name, value in settings.items():
            setattr(self, name, value)

# Saved rollout folders, one per environment. The numeric suffix looks like a Unix
# timestamp of the recording run (unverified).
data_folder_paths = dict([
    (EnvEnum.LinearSystem, "./video_frames/LinearSystem-v0_1733955895"),
    (EnvEnum.FluidFlow, "./video_frames/FluidFlow-v0_1733955905"),
    (EnvEnum.Lorenz, "./video_frames/Lorenz-v0_1733955911"),
    (EnvEnum.DoubleWell, "./video_frames/DoubleWell-v0_1733955917"),
])

# To visualise another environment, swap the key used for `data_folder`
# (EnvEnum.FluidFlow, EnvEnum.Lorenz or EnvEnum.DoubleWell).
args = Args(
    seed=123,
    data_folder=data_folder_paths[EnvEnum.LinearSystem],
    save_every_n_steps=100,
    plot_uncontrolled=True,  # False skips loading and drawing the zero-policy rollout
    ma_window_size=20,  # a larger window (e.g. 200) gives a smoother average
)

In [None]:
# Rollouts saved for the main policy and for the baseline it is compared against
main_policy_trajectories, main_policy_actions, main_policy_costs = (
    np.load(f"{args.data_folder}/main_policy_trajectories.npy"),
    np.load(f"{args.data_folder}/main_policy_actions.npy"),
    np.load(f"{args.data_folder}/main_policy_costs.npy"),
)
baseline_policy_trajectories, baseline_policy_actions, baseline_policy_costs = (
    np.load(f"{args.data_folder}/baseline_policy_trajectories.npy"),
    np.load(f"{args.data_folder}/baseline_policy_actions.npy"),
    np.load(f"{args.data_folder}/baseline_policy_costs.npy"),
)

# The zero-policy rollout is only loaded when the uncontrolled system will be drawn
if args.plot_uncontrolled:
    zero_policy_trajectories, zero_policy_actions, zero_policy_costs = (
        np.load(f"{args.data_folder}/zero_policy_trajectories.npy"),
        np.load(f"{args.data_folder}/zero_policy_actions.npy"),
        np.load(f"{args.data_folder}/zero_policy_costs.npy"),
    )

# metadata.npy holds a pickled object (used as a dict below), hence allow_pickle=True.
# Only point data_folder at directories you trust: unpickling can execute code.
metadata = np.load(f"{args.data_folder}/metadata.npy", allow_pickle=True).item()
env_id = metadata['env_id']


def reset_seed():
    """Re-seed the global numpy and torch RNGs with ``args.seed``."""
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)


reset_seed()

# Single-environment vector wrapper; later cells reach the env itself via envs.envs[0]
envs = gym.vector.SyncVectorEnv([make_env(env_id, args.seed)])

In [None]:
# Report the array shapes of every loaded rollout, one per line
shape_report = [
    f"Main policy trajectories shape: {main_policy_trajectories.shape}",
    f"Baseline policy trajectories shape: {baseline_policy_trajectories.shape}",
]
if args.plot_uncontrolled:
    shape_report.append(f"Zero policy trajectories shape: {zero_policy_trajectories.shape}")
print(*shape_report, sep="\n")

In [7]:
# Trailing moving average, optionally keeping the first n - 1 raw values
def moving_average(a, n, keep_first):
    """Trailing simple moving average of a (presumably 1-D) series.

    Args:
        a: Array-like series of numbers.
        n: Window length. Must be >= 1.
        keep_first: If True, prepend the first ``n - 1`` raw values of ``a`` so the
            result is as long as ``a``. If False, return only the
            ``len(a) - n + 1`` full-window averages (empty if ``len(a) < n``).

    Returns:
        A float ``np.ndarray``. Element ``i`` of the averaged part is ``mean(a[i:i + n])``.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"Moving-average window must be >= 1, got {n}")

    # Running-sum trick: the sum of the window ending at k is cumsum[k] - cumsum[k - n]
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    moving_avg = ret[n - 1:] / n

    if keep_first:
        return np.concatenate((a[:n - 1], moving_avg))
    return moving_avg

In [8]:
# Every compared rollout must start from the same state, otherwise the trajectory and
# cost comparisons below are not like-for-like. The zero-policy rollout only exists
# when args.plot_uncontrolled is set, so its check is short-circuited otherwise
# (previously this raised NameError in that configuration).
assert np.array_equal(main_policy_trajectories[0, 0], baseline_policy_trajectories[0, 0]) and (
    not args.plot_uncontrolled
    or np.array_equal(main_policy_trajectories[0, 0], zero_policy_trajectories[0, 0])
), "Your trajectories have different initial conditions"

In [None]:
# Split episode 0 of each rollout into its x / y / z state columns.
# NOTE(review): column 2 is read unconditionally; if DoubleWell states are 2-D this
# raises before the DoubleWell branch below is reached. Confirm the state size.
full_x, full_y, full_z = (main_policy_trajectories[0, :, axis] for axis in (0, 1, 2))
if args.plot_uncontrolled:
    full_x_zero, full_y_zero, full_z_zero = (zero_policy_trajectories[0, :, axis] for axis in (0, 1, 2))

# Short aliases used by the plotting code (this notebook avoids looping over episodes)
x, y, z = full_x, full_y, full_z
if args.plot_uncontrolled:
    x_zero, y_zero, z_zero = full_x_zero, full_y_zero, full_z_zero


def _value_range(*arrays):
    """Return ``(min, max)`` taken over all values of the given 1-D arrays."""
    pooled = np.concatenate(arrays)
    return np.min(pooled), np.max(pooled)


trajectory_fig = plt.figure(figsize=(21, 14), dpi=300, constrained_layout=True)
trajectory_ax = trajectory_fig.add_subplot(111, projection='3d')

# Mark the environment's reference point (drawn at z = 0 for DoubleWell)
trajectory_ax.scatter3D(
    envs.envs[0].reference_point[0],
    envs.envs[0].reference_point[1],
    0.0 if env_id == EnvEnum.DoubleWell else envs.envs[0].reference_point[2],
    color='black',
    s=100,
)

# Fit the axes to the main rollout, widened to the zero-policy rollout when it is drawn
if args.plot_uncontrolled:
    min_x, max_x = _value_range(full_x, full_x_zero)
    min_y, max_y = _value_range(full_y, full_y_zero)
    min_z, max_z = _value_range(full_z, full_z_zero)
else:
    min_x, max_x = _value_range(full_x)
    min_y, max_y = _value_range(full_y)
    min_z, max_z = _value_range(full_z)

trajectory_ax.set_xlim(min_x, max_x)
trajectory_ax.set_ylim(min_y, max_y)
# Lorenz uses hand-tuned scale factors on its z limits
z_limits = (min_z * 2, max_z * 0.8) if env_id == EnvEnum.Lorenz else (min_z, max_z)
trajectory_ax.set_zlim(*z_limits)

# Number of leading time steps to draw (e.g. 10 shows the first 10 states)
step_num = 5

if env_id != EnvEnum.DoubleWell:
    if step_num == 1:
        # A single state cannot form a line, so draw it as a dot
        if args.plot_uncontrolled:
            trajectory_ax.scatter3D(x_zero[0], y_zero[0], z_zero[0], color='tab:blue')
        trajectory_ax.scatter3D(x[0], y[0], z[0], linewidth=3, color='tab:orange')
    else:
        if args.plot_uncontrolled:
            trajectory_ax.plot3D(x_zero[:step_num], y_zero[:step_num], z_zero[:step_num], color='tab:blue')
        trajectory_ax.plot3D(x[:step_num], y[:step_num], z[:step_num], linewidth=3, color='tab:orange')

    # Bare trajectory: no grid, no axes
    trajectory_ax.grid(False)
    trajectory_ax.set_axis_off()
# TODO: Get double well working for this notebook (DoubleWell currently draws only the reference point)

plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _padded_axis_bounds(values, min_scaling_factor=(0.8, 1.2), max_scaling_factor=(1.2, 0.8)):
    """Return ``(low, high)`` of ``values``, pushed outward by sign-aware scaling.

    A positive minimum is multiplied by ``min_scaling_factor[0]`` and a non-positive one by
    ``min_scaling_factor[1]``. A positive maximum is multiplied by ``max_scaling_factor[0]``
    and a non-positive one by ``max_scaling_factor[1]``. With the defaults, both bounds move
    away from the data; a bound of exactly 0 gets no padding.
    """
    low = np.min(values)
    high = np.max(values)
    low = low * min_scaling_factor[0] if low > 0 else low * min_scaling_factor[1]
    high = high * max_scaling_factor[0] if high > 0 else high * max_scaling_factor[1]
    return low, high


def _draw_rollout_prefix(ax, trajectories, step_num, color):
    """Draw the first ``step_num`` states of episode 0 of ``trajectories`` on a 3-D axis.

    A single state is drawn as a dot; a longer prefix is drawn as a line.
    """
    prefix = trajectories[0, :step_num]
    if step_num == 1:
        ax.scatter3D(prefix[0, 0], prefix[0, 1], prefix[0, 2], color=color, s=100)
    else:
        ax.plot3D(prefix[:, 0], prefix[:, 1], prefix[:, 2], color=color, linewidth=4)


def plot_vector_field(env, resolution=20, step_num=5):
    """Plot the zero-action one-step displacement field of ``env`` around the main rollout.

    The grid spans the padded bounding box of episode 0 of ``main_policy_trajectories``.
    At each grid point, ``env.f(point, zero_action) - point`` is drawn as a normalised
    quiver arrow. The grid points and their unit-length arrow tips are also written to
    ``{args.data_folder}/{env_id}_vector_field.csv``. Overlaid on the field are:
    the environment's reference point, the first ``step_num`` states of the main rollout,
    and the zero-policy rollout when ``args.plot_uncontrolled`` is set.

    Args:
        env: Environment exposing ``f(state, action)``, ``action_dim`` and
            ``reference_point``. Assumed to have a 3-D state (TODO confirm for DoubleWell).
        resolution: Number of grid points per axis (``resolution ** 3`` points in total).
        step_num: Number of leading rollout states to overlay.
    """
    # Padded bounding box of the main rollout, one (low, high) pair per state dimension
    main_rollout = main_policy_trajectories[0]
    bounds = [_padded_axis_bounds(main_rollout[:, dim]) for dim in range(3)]
    state_min = np.array([low for low, _ in bounds])
    state_max = np.array([high for _, high in bounds])

    # Regular grid over the box (np.meshgrid's default 'xy' indexing)
    X, Y, Z = np.meshgrid(
        *(np.linspace(state_min[dim], state_max[dim], resolution) for dim in range(3))
    )
    grid_points = np.stack([X.flatten(), Y.flatten(), Z.flatten()], axis=1)

    # One-step displacement under zero action; env.f receives column vectors
    zero_action = np.zeros((1, env.action_dim))
    deltas = []
    for point in grid_points:
        column = point.reshape(-1, 1)
        deltas.append(env.f(column, zero_action) - column)
    deltas = np.array(deltas)
    dX = deltas[:, 0].reshape(X.shape)
    dY = deltas[:, 1].reshape(Y.shape)
    dZ = deltas[:, 2].reshape(Z.shape)

    fig = plt.figure(figsize=(21, 14), dpi=300, constrained_layout=True)
    ax = fig.add_subplot(111, projection='3d')
    ax.quiver(
        X, Y, Z,  # Grid points
        dX, dY, dZ,  # Vector components
        length=1.0,
        normalize=True,
        color='black',
        alpha=0.5
    )

    # Export grid points plus unit-length arrow tips. Fixed points (zero displacement)
    # keep a zero direction instead of producing NaN from 0 / 0.
    directions = np.stack([dX.flatten(), dY.flatten(), dZ.flatten()], axis=1)
    norms = np.linalg.norm(directions, axis=1, keepdims=True)
    unit_directions = np.divide(
        directions,
        norms,
        out=np.zeros_like(directions, dtype=float),
        where=norms > 0,
    )
    flat_X, flat_Y, flat_Z = X.flatten(), Y.flatten(), Z.flatten()
    df = pd.DataFrame(
        {
            "X": flat_X,
            "Y": flat_Y,
            "Z": flat_Z,
            "X_prime": flat_X + unit_directions[:, 0],
            "Y_prime": flat_Y + unit_directions[:, 1],
            "Z_prime": flat_Z + unit_directions[:, 2],
        }
    )
    df.to_csv(f"{args.data_folder}/{env_id}_vector_field.csv")

    ax.set_xlim(state_min[0], state_max[0])
    ax.set_ylim(state_min[1], state_max[1])
    ax.set_zlim(state_min[2], state_max[2])
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")

    # Reference point of the env being plotted (drawn at z = 0 for DoubleWell)
    reference_point = env.reference_point
    ax.scatter3D(
        reference_point[0],
        reference_point[1],
        reference_point[2] if env_id != EnvEnum.DoubleWell else 0.0,
        color='black',
        s=100,
    )

    # Overlay rollout prefixes. Zero-policy data is only loaded when
    # args.plot_uncontrolled is set; previously this raised NameError otherwise.
    if args.plot_uncontrolled:
        _draw_rollout_prefix(ax, zero_policy_trajectories, step_num, 'tab:blue')
    _draw_rollout_prefix(ax, main_policy_trajectories, step_num, 'tab:orange')

    # Bare plot: no grid, no axes
    ax.grid(False)
    ax.set_axis_off()

    plt.show()

# Example usage
plot_vector_field(envs.envs[0], resolution=8)

In [None]:
# Create cost figure
cost_fig = plt.figure(figsize=(21, 14), dpi=300)
cost_ax = cost_fig.add_subplot(111)
cost_ax.set_ylabel('')  # Removes the y-axis label on a specific axis

# Per-step cost ratio (main / baseline) for episode 0, and its log. Assuming positive
# costs, a negative log ratio means the main policy was cheaper at that step.
all_main_costs = main_policy_costs[0]
all_baseline_costs = baseline_policy_costs[0]
all_cost_ratios = all_main_costs / all_baseline_costs
log_all_cost_ratios = np.log(all_cost_ratios)
moving_average_log_all_cost_ratios = moving_average(
    log_all_cost_ratios,
    n=args.ma_window_size,
    keep_first=False
)

# Each moving-average value is placed at the last step its window covers
cost_x_values_start_index = args.ma_window_size - 1
moving_average_log_all_cost_ratio_x_values = np.arange(
    cost_x_values_start_index,
    moving_average_log_all_cost_ratios.shape[0] + cost_x_values_start_index
)

# In this notebook, we avoid looping to make this easier
log_cost_ratio = log_all_cost_ratios
moving_average_log_cost_ratio = moving_average_log_all_cost_ratios
moving_average_log_cost_ratio_x_value = moving_average_log_all_cost_ratio_x_values

# Overall min and max for consistent scaling
min_log_cost_ratio = np.min(log_all_cost_ratios)
max_log_cost_ratio = np.max(log_all_cost_ratios)

cost_ax.set_xlim(0, main_policy_costs.shape[1])
# Pad each y-limit 10% away from the data. A plain `* 1.1` only widens the window when
# the minimum is negative and the maximum positive; otherwise it clips the extremes.
cost_y_lower = min_log_cost_ratio * (1.1 if min_log_cost_ratio < 0 else 0.9)
cost_y_upper = max_log_cost_ratio * (1.1 if max_log_cost_ratio > 0 else 0.9)
cost_ax.set_ylim(cost_y_lower, cost_y_upper)

cost_ax.set_title(f"Cost Ratio: {metadata['main_policy_name']} / {metadata['baseline_policy_name']}")
cost_ax.title.set_fontsize(20)

# y = 0 marks equal cost for the two policies
cost_ax.axhline(y=0, color='black', linestyle='--')

cost_ax.grid(True)

# Number of leading steps to draw
step_num = 100
log_cost_ratio_opacity = 0.4

if step_num == 1:
    cost_ax.scatter(0.0, log_cost_ratio[0], alpha=log_cost_ratio_opacity, color='tab:orange')
else:
    cost_ax.plot(log_cost_ratio[:step_num], alpha=log_cost_ratio_opacity, color='tab:orange')
    # A window of n steps ending at step k is complete once k <= step_num - 1, so
    # exactly step_num - n + 1 moving-average values are available after step_num steps.
    num_ma_points = step_num - args.ma_window_size + 1
    if num_ma_points == 1:
        # One value: scatter a single (x, y) pair (the old code passed an x-slice with a scalar y)
        cost_ax.scatter(
            moving_average_log_cost_ratio_x_value[0],
            moving_average_log_cost_ratio[0],
            color='tab:blue',
            linewidth=3,
        )
    elif num_ma_points > 1:
        cost_ax.plot(
            moving_average_log_cost_ratio_x_value[:num_ma_points],
            moving_average_log_cost_ratio[:num_ma_points],
            color='tab:blue',
            linewidth=3,
        )

plt.show()