# MuZero Visualization Analysis
This code provides a visualization analysis of the MountainCar environment's observation space and of the latent states of a trained MuZero model using PCA (Principal Component Analysis), following the paper [Visualizing MuZero Models](http://arxiv.org/abs/2102.12924). The goal is to gain insights into the specific types of representations learned by these models.

The code performs the following steps:

1. Import the required libraries and modules.
2. Load the MountainCar environment and its observation space.
3. Load a pre-trained MuZero model.
4. Enumerate observation samples on a regular grid covering the MountainCar observation space.
5. Perform PCA to reduce the dimensionality of the original observation space.
6. Extract the latent states from the MuZero model by inputting the observation samples into the model's representation network and extracting the output.
7. Apply PCA to reduce the dimensionality of the latent states.

8. Visualize the reduced observation space and latent states.

By conducting these visualization analyses, we can gain a deeper understanding of how the MuZero model learns and represents information in the MountainCar environment.

For more information about the mountain_car environment, see [mountain_car doc](https://www.gymlibrary.dev/environments/classic_control/mountain_car/).

In [None]:
import os
from functools import partial
from typing import Optional, Tuple

import numpy as np
import torch
from tensorboardX import SummaryWriter
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.torch_utils import to_tensor, to_device, to_ndarray
from ding.worker import BaseLearner
from lzero.worker import MuZeroEvaluator
from lzero.policy import InverseScalarTransform, mz_network_output_unpack

from zoo.classic_control.mountain_car.config.mtcar_muzero_config import main_config, create_config
# from lzero.entry import eval_muzero
import numpy as np

from typing import Optional, Tuple, List

## Load Model
This code segment loads a pre-trained MuZero model and performs evaluation on the MountainCar environment.

It sets up the necessary configurations, components, and dependencies for the evaluation process.
The MuZero model is loaded from the specified model path and the evaluation is performed using the MuZeroEvaluator.
The evaluation results, including trajectories and returns, are stored for further analysis and visualization.

This code provides a convenient way to evaluate the performance of a trained MuZero model in the MountainCar environment.

In [2]:
# Checkpoint of the trained MuZero agent; "your_path" is a placeholder that must be replaced.
model_path = "your_path/mountain_car_muzero_seed0/ckpt/ckpt_best.pth.tar"
# Containers for per-seed results; nothing in this cell fills them.
returns_mean_seeds = []
returns_seeds = []
seed = 0
# A single evaluation episode is run for the single seed above.
num_episodes_each_seed = 1
total_test_episodes = num_episodes_each_seed
create_config.env_manager.type = 'base'  # Visualization requires the 'type' to be set as base
main_config.env.evaluator_env_num = 1  # Visualization requires the 'env_num' to be set as 1
main_config.env.n_evaluator_episode = total_test_episodes
# presumably the directory the environment writes episode replays to -- TODO confirm against the env implementation
main_config.env.replay_path = 'lz_result/video/mtcar_mz'
# The experiment name embeds the configured latent-state dimension.
main_config.exp_name = f'lz_result/eval/muzero_eval_ls{main_config.policy.model.latent_state_dim}'

In [None]:
# Build the evaluation pipeline for the configured MuZero agent and run it once,
# keeping the episode trajectories and returns for the visualizations below.
cfg, create_cfg = main_config, create_config
# Guard against configs for algorithms outside the listed set.
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
    "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"

# Use the GPU only when it is both requested in the config and actually available.
if cfg.policy.cuda and torch.cuda.is_available():
    cfg.policy.device = 'cuda'
else:
    cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
# Only the evaluator environment configs are used here; collector_env_cfg stays unused.
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

# Seed the environment manager and the libraries with the same seed.
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

# model=None: presumably lets the policy build its default model from cfg.policy -- TODO confirm
policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
if model_path is not None:
    policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components. Only the TensorBoard logger and the learner are built here;
# the evaluator is created right below.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
# policy_config is reused by later cells (device lookup, value transform settings).
policy_config = cfg.policy
evaluator = MuZeroEvaluator(
    eval_freq=cfg.policy.eval_freq,
    n_evaluator_episode=cfg.env.n_evaluator_episode,
    stop_value=cfg.env.stop_value,
    env=evaluator_env,
    policy=policy.eval_mode,
    tb_logger=tb_logger,
    exp_name=cfg.exp_name,
    policy_config=policy_config
)

# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')

# ==============================================================
# eval trained model
# ==============================================================
# return_trajectory=True: the returned info is read below for both the trajectories and the returns.
stop_flag, episode_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, return_trajectory=True)
trajectorys = episode_info['trajectory']
returns = episode_info['eval_episode_return']
returns = np.array(returns)

## Original space enumeration

This code segment provides functions to work with the original state space of an environment.

The `create_grid` function creates a grid in the original state space based on the specified resolution.

The `get_state_space` function generates the state space grid using the observation space of the environment.

The `embedding_manifold` function computes the latent states, values, and policy logits for the given state space using a trained model.

The code then applies these functions to the MountainCar environment, printing the shapes of the original state space and the corresponding latent state space.
This code facilitates the exploration and analysis of the original state space and its embeddings using a trained model.

In [4]:
def create_grid(v_mins: List, v_maxs: List, resolution: int) -> np.ndarray:
    """Enumerate a regular grid over an axis-aligned box.

    Every dimension ``k`` is sampled at ``resolution`` evenly spaced points
    between ``v_mins[k]`` and ``v_maxs[k]`` (both ends included); each
    combination of those points becomes one row of the result.

    Returns:
        Array with ``len(v_mins)`` columns, one grid point per row. The first
        dimension changes fastest from row to row.
    """
    axes = [np.linspace(low, high, resolution) for low, high in zip(v_mins, v_maxs)]
    mesh = np.meshgrid(*axes, indexing="ij")
    # Stack to (d, r, ..., r); transposing reverses the axes so that the
    # coordinate index ends up last and dimension 0 varies fastest.
    stacked = np.stack(mesh)
    return stacked.T.reshape(-1, len(v_mins))

def get_state_space(env, resolution: int = 25) -> np.ndarray:
    """Build a regular grid spanning the bounds of ``env.observation_space``.

    Args:
        env: Object exposing ``observation_space`` with ``low`` and ``high`` bounds.
        resolution: Number of grid points per observation dimension.

    Returns:
        The grid produced by ``create_grid`` for those bounds.
    """
    bounds = env.observation_space
    return create_grid(bounds.low, bounds.high, resolution)


def embedding_manifold(state_space, model, return_pis: bool = False, policy_cfg = None) -> Tuple:
    """Run ``model.initial_inference`` on a batch of states and return numpy results.

    Args:
        state_space: Batch of observations passed unchanged to
            ``model.initial_inference``.
        model: Object providing ``initial_inference``.
        return_pis: If True, the policy logits are returned as a third element.
        policy_cfg: Policy config; ``model.support_scale``, ``device`` and
            ``model.categorical_distribution`` are read from it to build the
            inverse scalar transform. Required despite the ``None`` default.

    Returns:
        ``(latent_state, value_real)`` or, with ``return_pis=True``,
        ``(latent_state, value_real, policy_logits)``, all converted with
        ``to_ndarray`` after being moved to the CPU. ``value_real`` is the
        network's value output passed through ``InverseScalarTransform``.

    Raises:
        ValueError: If ``policy_cfg`` is None.
    """
    # Fail before the forward pass: without the config the value transform
    # cannot be built, and the old behaviour was an opaque AttributeError
    # raised only after inference had already run.
    if policy_cfg is None:
        raise ValueError("embedding_manifold requires `policy_cfg` to build the inverse scalar transform")

    with torch.no_grad():
        network_output = model.initial_inference(state_space)
    latent_state, _reward, value, policy_logits = mz_network_output_unpack(network_output)

    inverse_scalar_transform_handler = InverseScalarTransform(
        policy_cfg.model.support_scale,
        policy_cfg.device,
        policy_cfg.model.categorical_distribution)
    value_real = inverse_scalar_transform_handler(value)

    outputs = [latent_state, value_real]
    if return_pis:
        outputs.append(policy_logits)
    return tuple(to_ndarray(tensor.cpu()) for tensor in outputs)

In [None]:
# Number of grid points per observation dimension.
delta = 250
state_space = get_state_space(evaluator_env, resolution=delta)
# Move the grid to the policy's device before feeding it to the model.
state_space_tensor = to_device(to_tensor(state_space), policy_config.device)
latent_state_space, v_state_space = embedding_manifold(
    state_space_tensor,
    policy._model,
    return_pis=False,
    policy_cfg=policy_config,
)
print(state_space.shape, latent_state_space.shape)

## PCA

This code segment provides functions related to Principal Component Analysis (PCA) for latent states.

The `embedding_PCA` function performs PCA on the given latent states array.
It accepts an optional `standardize` parameter to control whether the data should be standardized before performing PCA.
The function computes the principal components and the explained variance ratio.
It then creates a bar chart to visualize the explained variance ratio for each principal component.
Additionally, it creates a violin plot to show the distribution of the projected values.
The function returns the PCA object.

The code applies the `embedding_PCA` function to the `latent_state_space` array twice, once with standardization disabled and once with it enabled.
Finally, it transforms the `latent_state_space` using the PCA fitted without standardization (`pca`) for further analysis or visualization.

This code provides a convenient way to perform PCA on latent states and visualize the results.

In [6]:
def embedding_PCA(latent_states: np.ndarray, standardize: bool = False):
    """Fit a PCA with as many components as latent dimensions and plot diagnostics.

    Two figures are shown: a bar chart of the explained variance ratio per
    principal component (annotated with its value) and a violin plot of the
    projected values per component.

    Args:
        latent_states: Array of latent states; the last axis is treated as the
            feature axis. Expected to be 2-D (samples x features) as PCA requires.
        standardize: If True, every column is centred and divided by its
            standard deviation before fitting.

    Returns:
        The fitted PCA object. With ``standardize=True`` it is fitted on the
        standardized data, so inputs to ``transform`` need the same scaling.
    """
    x = latent_states
    if standardize:
        std = x.std(axis=0)
        # A constant column has zero spread; dividing by it would produce
        # NaN/inf and break the fit. Such columns are left at zero instead.
        std = np.where(std == 0, 1.0, std)
        x = (x - x.mean(axis=0)) / std

    n_dims = x.shape[-1]

    # Perform PCA on latent dimensions. fit_transform both fits and projects,
    # so a separate fit() beforehand would only repeat the decomposition.
    pca = PCA(n_components=n_dims)
    spcs = pca.fit_transform(x)

    # Create barchart
    ns = list(range(n_dims))
    var = pca.explained_variance_ratio_

    plt.bar(ns, var)

    plt.title(f"PCA on latent-states (standardize={standardize})")
    plt.ylabel("Explained Variance Ratio")
    plt.xlabel("Principal Component")

    # Print each ratio just above its bar.
    for component, ratio in zip(ns, var):
        plt.annotate(f'{ratio:.3f}', xy=(component, ratio), ha='center', va='bottom')

    plt.show()

    # Create violinplot
    plt.violinplot(spcs)
    plt.xticks(range(1, n_dims + 1), range(1, n_dims + 1))

    plt.title(f"Projected values distribution (standardize={standardize})")
    plt.ylabel("PC values")
    plt.xlabel("Principal Component")

    plt.show()

    return pca

In [None]:
# Fit once on the raw latent states and once on standardized ones; each call
# also shows its diagnostic plots.
pca = embedding_PCA(latent_states=latent_state_space, standardize=False)
pca_norm = embedding_PCA(latent_states=latent_state_space, standardize=True)
# Project the grid's latent states with the non-standardized PCA.
pca_latent_state_space = pca.transform(latent_state_space)

## Original space/PCA latent state visualization

This code segment provides functions for visualizing the state space.

The `to_grid` function reshapes the input array into a grid with a specified resolution.

The `simple_PC_value_contour` function creates a scatter plot to visualize the PCA-transformed latent states.
It takes the first two PCA components (`pc_1` and `pc_2`) along with the corresponding values (`z`) as inputs.
The scatter points are colored based on the values (`z`) and a colorbar is added to indicate the value scale.

The `simple_MC_value_contour` function creates a contour plot to visualize the original state space of the MountainCar environment.
It takes the position values (`x`), velocity values (`y`), and corresponding values (`z`) as inputs.
The contour levels represent the values (`z`), and a colorbar is added to show the value scale.

The code applies these visualization functions to the PCA-transformed latent state space (`pca_latent_state_space`) and the original state space (`state_space`).
The resulting plots provide visual representations of the value distribution in the latent space and the original state space.

In [8]:
# State space visualization
def to_grid(x: np.ndarray, delta: int) -> np.ndarray:
    """Reshape ``x`` into a square array of shape ``(delta, delta)``."""
    square_shape = (delta, delta)
    return x.reshape(square_shape)


def simple_PC_value_contour(pc_1: np.ndarray, pc_2: np.ndarray, z: np.ndarray) -> None:
    """Scatter-plot latent states in the plane of two principal components.

    Args:
        pc_1: Coordinates along the first principal component (x-axis).
        pc_2: Coordinates along the second principal component (y-axis).
        z: Values used to colour the points; a colorbar is added for them.

    The figure is drawn on the current pyplot axes and not shown; the caller
    is expected to call ``plt.show()``.
    """
    # Draw the latent state after PCA
    plt.scatter(pc_1, pc_2, c=z, alpha=0.5, s=5, cmap='rainbow')

    cbar = plt.colorbar()
    cbar.set_label(r'$V_\theta(o_t)$')

    plt.title("Value Contour MuZero PC-Space")
    # pc_1 is plotted on x and pc_2 on y, so the labels must follow that order
    # (they were previously swapped).
    plt.xlabel(r"First PCA component $h_\theta(o_t)$")
    plt.ylabel(r"Second PCA component $h_\theta(o_t)$")


def simple_MC_value_contour(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:
    """Draw a filled contour plot of ``z`` over the position/velocity plane.

    Args:
        x: Grid values for the x-axis (labelled "Position").
        y: Grid values for the y-axis (labelled "Velocity").
        z: Values to contour; a colorbar is added for them.

    The figure is drawn on the current pyplot axes and not shown.
    """
    # Filled contours of the values over the original 2-d state space.
    plt.contourf(x, y, z, levels=100, cmap='rainbow')

    value_bar = plt.colorbar()
    value_bar.set_label(r'$V_\theta(o_t)$')

    plt.xlabel("Position")
    plt.ylabel("Velocity")
    plt.title("Value Contour MuZero MountainCar")

In [None]:
# Values of the grid points in the plane of the first two principal components.
simple_PC_value_contour(
    pc_1=pca_latent_state_space[:, 0],
    pc_2=pca_latent_state_space[:, 1],
    z=v_state_space,
)
plt.show()
# The same values over the original grid, reshaped to delta x delta.
simple_MC_value_contour(
    x=to_grid(state_space[:, 0], delta),
    y=to_grid(state_space[:, 1], delta),
    z=to_grid(v_state_space, delta),
)
plt.show()

## Get the trajectory obtained by eval

This code segment involves processing game trajectories and obtaining latent state dynamics.

The `get_latent_trajectory` function takes embeddings, actions, and a model as inputs and returns the latent state trajectory.
It initializes the latent state with the first embedding, and then iteratively computes the latent state using the model's recurrent inference.
The resulting latent states are stored in a list and concatenated to form the stacked latent state trajectory.

The code then loads a real state trajectory (`real_state`) and corresponding actions (`actions`).
These trajectories are converted to tensors and processed using the policy model.
The resulting latent state representations (`latent_state_represent`) and value trajectory (`v_trajectorys`) are converted to NumPy arrays.

Next, the code calls the `get_latent_trajectory` function to obtain the latent state dynamics based on the latent state representations and actions.
The obtained latent state dynamics are then projected to the PC-space using the precomputed PCA object (`pca`).
The resulting PC-space trajectories are stored in `pc_embedding_trajectory` and `pc_dynamics_trajectory`.

This code provides a way to process game trajectories, extract latent state dynamics, and project them into the PC-space for further analysis or visualization.

In [10]:
# Game trajectory dynamics latent state processing
def get_latent_trajectory(embeddings: torch.Tensor, actions: torch.Tensor, model) -> np.ndarray:
    """Roll the model forward from the first embedding along ``actions``.

    Starting from ``embeddings[0]`` (given a leading batch axis), each action
    is fed to ``model.recurrent_inference`` together with the current latent
    state, and the latent state it returns becomes the next current state.

    Returns:
        The start state followed by one latent state per action, concatenated
        along the first axis.
    """
    current = embeddings[0].unsqueeze(0)
    rollout = [to_ndarray(current.cpu())]

    with torch.no_grad():
        for action in actions:
            # Note the action: it gets a leading batch axis, like the latent state.
            step_output = model.recurrent_inference(current, action.unsqueeze(0))
            current, _reward, _value, _policy_logits = mz_network_output_unpack(step_output)
            rollout.append(to_ndarray(current.cpu()))

    return np.concatenate(rollout)

In [11]:
#   1 Trajectory loading: observations and actions of the first evaluated episode
real_state = np.array(trajectorys[0].obs_segment)
# Reuse the device of the grid tensor so the model gets inputs on the same device.
real_state_tensor = to_device(to_tensor(real_state), state_space_tensor.device)
actions = np.array(trajectorys[0].action_segment)
# unsqueeze(1) adds an axis after the first one -- presumably (T,) -> (T, 1); TODO confirm the action_segment layout
actions_tensor = to_device(to_tensor(actions).unsqueeze(1), state_space_tensor.device)
# Embed every observation of the episode with initial_inference (no gradients needed).
with torch.no_grad():
    network_output = policy._model.initial_inference(real_state_tensor)
latent_state_represent_tensor, reward, v_trajectorys_tensor, policy_logits = mz_network_output_unpack(network_output) 
latent_state_represent = to_ndarray(latent_state_represent_tensor.cpu())
# NOTE(review): unlike embedding_manifold, no InverseScalarTransform is applied to this value output.
v_trajectorys = to_ndarray(v_trajectorys_tensor.cpu())

#   2 Get latent state trajectory: roll out from the first embedding using the recorded actions
latent_state_dynamics = get_latent_trajectory(latent_state_represent_tensor, actions_tensor, policy._model)

#   3 Project to PC-space with the PCA fitted on the grid's latent states;
#     each state is flattened to one row first.
pc_embedding_trajectory = pca.transform(latent_state_represent.reshape(len(latent_state_represent), -1))
pc_dynamics_trajectory = pca.transform(latent_state_dynamics.reshape(len(latent_state_dynamics), -1))

## Observation distribution

In [None]:
#   1 Latent state trajectory distribution: one violin per latent dimension,
#     first for the embedded observations, then for the dynamics roll-out.
plt.violinplot(
    latent_state_represent.reshape(len(latent_state_represent), -1),
    positions=np.arange(1, latent_state_space.shape[-1] + 1),
)
plt.violinplot(
    latent_state_dynamics.reshape(len(latent_state_dynamics), -1),
    positions=np.arange(1, latent_state_space.shape[-1] + 1),
)

# Empty scatters exist only to create the two legend entries.
plt.scatter([], [], label='embedding')
plt.scatter([], [], label='dynamics')

plt.xticks(
    range(1, latent_state_space.shape[-1] + 1),
    [f'dim {i}' for i in range(1, latent_state_space.shape[-1] + 1)],
    rotation=45,
)
plt.xlabel("Latent Dimension")
plt.ylabel("Values")
plt.title("Value Distributions within latent-space")

plt.legend()
plt.show()


In [None]:
#   2 Latent state trajectory distribution after PCA: one violin per component,
#     first for the embedded observations, then for the dynamics roll-out.
plt.violinplot(
    pc_embedding_trajectory,
    positions=np.arange(1, latent_state_space.shape[-1] + 1),
)
plt.violinplot(
    pc_dynamics_trajectory,
    positions=np.arange(1, latent_state_space.shape[-1] + 1),
)

# Empty scatters exist only to create the two legend entries.
plt.scatter([], [], label='embedding')
plt.scatter([], [], label='dynamics')

plt.xticks(
    range(1, latent_state_space.shape[-1] + 1),
    [f'dim {i}' for i in range(1, latent_state_space.shape[-1] + 1)],
    rotation=45,
)
plt.xlabel("Latent Dimension")
plt.ylabel("Values")
plt.title("Value Distributions within latent PC-space")

plt.legend()
plt.show()

## 3D trajectory visualization

This code segment provides functions for generating 3D visualizations using Plotly.

The `generate_3d_surface` function creates a 3D surface plot.

The `generate_3d_trajectory` function creates a 3D scatter plot for a trajectory.

The `generate_3d_valuefield` function creates a 3D scatter plot for a value field.

The code uses these functions to generate 3D visualizations.
It creates a 3D trajectory plot (`dynamics_trajectory`) and an embedding trajectory plot (`embedding_trajectory`) using the PC-space trajectories.
It also generates a 3D value-field scatter plot (stored in the variable `surface`) with `generate_3d_valuefield`, using the PCA-transformed latent state space and the corresponding values.

This code provides a way to visualize 3D trajectories, embedding trajectories, and value fields using Plotly.

In [14]:
def generate_3d_surface(x: np.ndarray, y: np.ndarray, z: np.ndarray, colors: np.ndarray, clim=None):
    """Build a Plotly ``Surface`` trace coloured by ``colors``.

    Args:
        x, y, z: Surface coordinates, passed straight to ``go.Surface``.
        colors: Values used as the surface colour.
        clim: Optional ``(min, max)`` pair for the colour range; when omitted
            the range of ``colors`` is used.

    Returns:
        The ``go.Surface`` trace (not yet added to a figure).
    """
    if clim is None:
        color_min = colors.min()
        color_max = colors.max()
    else:
        color_min = clim[0]
        color_max = clim[1]

    colorbar_cfg = dict(title=dict(text='V', side='top'), thickness=50, tickmode='array')
    return go.Surface(
        x=x,
        y=y,
        z=z,
        opacity=1,
        surfacecolor=colors,
        colorscale='Viridis',
        cmin=color_min,
        cmax=color_max,
        colorbar=colorbar_cfg,
    )

def generate_3d_trajectory(x: np.ndarray, y: np.ndarray, z: np.ndarray, color: str):
    """Build a Plotly ``Scatter3d`` line-and-marker trace for one trajectory.

    Each axis is shifted by its own random scalar in ``[0, 0.01)``; the shift
    is the same for every point on that axis, so the shape is unchanged.

    Args:
        x, y, z: Coordinates of the trajectory points.
        color: Colour used for both the markers and the line.

    Returns:
        The ``go.Scatter3d`` trace (not yet added to a figure).
    """
    # Draw the shifts in x, y, z order so the random stream is consumed
    # in the same sequence as the coordinates are listed.
    shift_x = np.random.rand() * 0.01
    shift_y = np.random.rand() * 0.01
    shift_z = np.random.rand() * 0.01

    marker_cfg = dict(size=3, symbol='x', color=color, opacity=1)
    line_cfg = dict(color=color, width=20)
    return go.Scatter3d(
        x=x + shift_x,
        y=y + shift_y,
        z=z + shift_z,
        mode='lines+markers',
        marker=marker_cfg,
        line=line_cfg,
    )

def generate_3d_valuefield(x: np.ndarray, y: np.ndarray, z: np.ndarray, colors: np.ndarray, clim=None):
    """Build a Plotly ``Scatter3d`` marker trace coloured by ``colors``.

    Args:
        x, y, z: Coordinates of the points.
        colors: Values used as marker colours.
            NOTE(review): presumably one scalar per point (1-D) -- confirm
            what happens for an (N, 1) array.
        clim: Optional ``(min, max)`` pair for the colour range; when omitted
            the range of ``colors`` is used.

    Returns:
        The ``go.Scatter3d`` trace (not yet added to a figure).
    """
    if clim is None:
        color_min = colors.min()
        color_max = colors.max()
    else:
        color_min = clim[0]
        color_max = clim[1]

    marker_cfg = dict(
        size=4,
        color=colors,
        colorscale='Viridis',
        cmin=color_min,
        cmax=color_max,
        opacity=1,
        colorbar=dict(title=dict(text='V', side='top'), thickness=50, tickmode='array'),
    )
    return go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=marker_cfg)

In [None]:
# 3D trajectory visualization
# NOTE(review): `x` is not read anywhere in this cell -- presumably a leftover; confirm before removing.
x = 3
# Roll-out trajectory (latent states from recurrent_inference) along the first
# three principal components, drawn in grey.
dynamics_trajectory =  generate_3d_trajectory(
    pc_dynamics_trajectory[:, 0].ravel(),
    pc_dynamics_trajectory[:, 1].ravel(), 
    pc_dynamics_trajectory[:, 2].ravel(), 'grey')

# Trajectory of the embedded real observations, drawn in black.
embedding_trajectory = generate_3d_trajectory(
    pc_embedding_trajectory[:, 0].ravel(), 
    pc_embedding_trajectory[:, 1].ravel(), 
    pc_embedding_trajectory[:, 2].ravel(), 'black')

# Despite its name, `surface` is a marker trace (generate_3d_valuefield): the grid's
# latent states in PC-space, coloured by their values.
surface = generate_3d_valuefield(pca_latent_state_space[:,0], pca_latent_state_space[:,1], pca_latent_state_space[:,2], v_state_space)

fig = go.Figure(data=[embedding_trajectory, dynamics_trajectory, surface])
# Tight layout: remove all figure margins.
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))

fig.show()