# Example Usage

This notebook showcases the abilities of the ARLIN library. All available methods are 
shown, but not all will be needed for making decisions. Each method gives insight into 
different aspects of the trained RL model's policy and the user can use the output
analysis to help influence decisions based on the adversary's goals.

In [None]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [None]:
import os
import gymnasium as gym
import numpy as np
import logging
import warnings

import arlin.dataset.loaders as loaders
from arlin.dataset import XRLDataset
from arlin.dataset.collectors import SB3PPODataCollector, SB3PPODatapoint

from arlin.generation import generate_clusters, generate_embeddings
import arlin.analysis.visualization as viz
from arlin.analysis import ClusterAnalyzer, LatentAnalyzer
from arlin.samdp import SAMDP

# Emit INFO-level log records; force=True replaces any handlers that were
# already attached to the root logger (e.g. when the cell is re-run).
logging.basicConfig(level=logging.INFO, force=True)
# Silence UserWarning messages so they do not clutter the notebook output.
warnings.filterwarnings("ignore", category=UserWarning) 

In [None]:
def create_dataset(env_id: str = "LunarLander-v2",
                   repo_id: str = "sb3/ppo-LunarLander-v2",
                   filename: str = "ppo-LunarLander-v2.zip",
                   num_datapoints: int = 50000) -> "XRLDataset":
    """Create an XRL Dataset from a trained model operating within an environment.

    The defaults reproduce the original example: a Huggingface-hosted SB3 PPO
    model for LunarLander-v2 and 50k datapoints.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        repo_id: Huggingface repository the SB3 model is loaded from.
        filename: Name of the model archive inside ``repo_id``.
        num_datapoints: Number of datapoints requested from ``XRLDataset.fill``.

    Returns:
        The XRLDataset after ``fill`` has been called on it.
    """
    # Create environment
    env = gym.make(env_id)

    # Load the SB3 model from Huggingface. The algorithm stays fixed to PPO
    # because the collector below is the PPO-specific one.
    model = loaders.load_hf_sb_model(repo_id=repo_id,
                                     filename=filename,
                                     algo_str="ppo")

    # Create the datapoint collector for SB3 PPO Datapoints with the model's policy
    collector = SB3PPODataCollector(datapoint_cls=SB3PPODatapoint,
                                    policy=model.policy)

    # Instantiate the XRL Dataset
    dataset = XRLDataset(env, collector=collector)

    # Fill the dataset with the requested number of datapoints
    dataset.fill(num_datapoints=num_datapoints)

    return dataset

In [None]:
def get_embeddings(dataset: XRLDataset):
    """Return 2D T-SNE embeddings of the dataset's "latent_actors" activations.

    Args:
        dataset: The XRLDataset to embed.

    Returns:
        Whatever ``generate_embeddings`` produces for the settings below.
    """
    # Fixed settings for the embedding run; the seed keeps the call repeatable.
    embedding_settings = {
        "activation_key": "latent_actors",
        "perplexity": 20,
        "n_train_iter": 300,
        "output_dim": 2,
        "seed": 12345,
    }
    return generate_embeddings(dataset=dataset, **embedding_settings)

def get_clusters(dataset: XRLDataset):
    """Cluster the dataset (K-Means and MeanShift) and return the clusters.

    Args:
        dataset: The XRLDataset to cluster.

    Returns:
        The first of the four values returned by ``generate_clusters``.
    """
    # Positional arguments exactly as generate_clusters expects them.
    # NOTE(review): presumably three groups of dataset keys followed by a
    # cluster count -- confirm against the generate_clusters signature.
    positional_args = (
        dataset,
        ["latent_actors", "critic_values"],
        ["latent_actors", "critic_values", "rewards"],
        ["rewards"],
        10,
    )

    # Only the first element of the returned 4-tuple is used here.
    labels, _, _, _ = generate_clusters(*positional_args, seed=1234)
    return labels

In [None]:
def graph_latent_analytics(embeddings: np.ndarray,
                           clusters: np.ndarray,
                           dataset: XRLDataset):
    """Save latent space analytics plots computed over the embeddings.

    Each analytic is written to its own image under
    ``./outputs/usage/latent_analytics`` and four of them are additionally
    combined into a single multi-plot figure.

    Args:
        embeddings: Latent space embeddings handed to the LatentAnalyzer.
        clusters: Cluster data used for the cluster graph.
        dataset: The XRLDataset the embeddings were generated from.
    """
    analyzer = LatentAnalyzer(embeddings, dataset)

    # Compute the graph data for every analytic up front.
    embedding_plot = analyzer.embeddings_graph_data()
    cluster_plot = analyzer.clusters_graph_data(clusters)
    boundary_plot = analyzer.decision_boundary_graph_data()  # action decision boundaries
    init_term_plot = analyzer.initial_terminal_state_data()  # initial/terminal states
    progression_plot = analyzer.episode_prog_graph_data()    # episode progression
    confidence_plot = analyzer.confidence_data()             # greedy action confidence

    out_dir = os.path.join(".", "outputs", "usage", "latent_analytics")

    # One image per analytic, keyed by output file name (insertion order kept).
    individual_plots = {
        "embeddings.png": embedding_plot,
        "clusters.png": cluster_plot,
        "decision_boundaries.png": boundary_plot,
        "initial_terminal.png": init_term_plot,
        "episode_progression.png": progression_plot,
        "confidence.png": confidence_plot,
    }
    for file_name, plot in individual_plots.items():
        viz.graph_individual_data(os.path.join(out_dir, file_name), plot)

    # Several analytics as subplots of one figure.
    viz.graph_multiple_data(
        file_path=os.path.join(out_dir, 'combined_analytics.png'),
        figure_title='Latent Analytics',
        graph_datas=[boundary_plot, confidence_plot, cluster_plot, progression_plot],
    )

In [None]:
def graph_cluster_analytics(dataset, clusters):
    """Save per-cluster analytics plots.

    Confidence, reward and value analytics are each written to their own image
    under ``./outputs/usage/cluster_analytics``; rewards and values are also
    combined into one multi-plot figure.

    Args:
        dataset: The XRLDataset handed to the ClusterAnalyzer.
        clusters: Cluster data handed to the ClusterAnalyzer.
    """
    analyzer = ClusterAnalyzer(dataset, clusters)

    confidence_plot = analyzer.cluster_confidence()  # mean confidence per cluster
    rewards_plot = analyzer.cluster_rewards()        # mean total reward per cluster
    values_plot = analyzer.cluster_values()          # mean value per cluster

    out_dir = os.path.join(".", "outputs", "usage", 'cluster_analytics')

    # One image per analytic.
    file_names = ('cluster_confidence.png', 'cluster_rewards.png', 'cluster_values.png')
    plots = (confidence_plot, rewards_plot, values_plot)
    for file_name, plot in zip(file_names, plots):
        viz.graph_individual_data(os.path.join(out_dir, file_name), plot)

    # Rewards and values side by side in a single figure.
    viz.graph_multiple_data(
        file_path=os.path.join(out_dir, 'combined_analytics.png'),
        figure_title='Cluster Analytics',
        graph_datas=[rewards_plot, values_plot],
    )

In [None]:
def samdp(clusters: np.ndarray,
          dataset: XRLDataset,
          start_cluster: int = 14,
          target_cluster: int = 12):
    """Generate a semi-aggregated Markov decision process and save its graphs.

    All outputs are written under ``./outputs/usage/samdp``. File names that
    refer to specific clusters are derived from ``start_cluster`` and
    ``target_cluster`` so that they always match the path actually plotted.

    Args:
        clusters: Cluster data handed to the SAMDP.
        dataset: The XRLDataset handed to the SAMDP.
        start_cluster: Cluster id the plotted paths start from.
        target_cluster: Cluster id the plotted paths lead to.
    """
    # Named `graph` so the local does not shadow this function's own name.
    graph = SAMDP(clusters, dataset)

    base_path = os.path.join(".", "outputs", "usage", "samdp")

    # Complete graph with all possible connections by action taken
    graph.save_complete_graph(os.path.join(base_path, "samdp_complete.png"))
    # Complete graph with only the most likely actions taken
    graph.save_likely_graph(os.path.join(base_path, "samdp_likely.png"))
    # Simplified graph with all possible connections (regardless of action taken)
    graph.save_simplified_graph(os.path.join(base_path, "samdp_simplified.png"))

    # Common file stem for the start -> target path plots.
    path_stem = os.path.join(base_path,
                             f"samdp_path_{start_cluster}_{target_cluster}")

    # Path from start_cluster to target_cluster
    # Action out of the start cluster shown, all other movements are simplified
    graph.save_paths(start_cluster,
                     target_cluster,
                     f"{path_stem}.png")

    # Same path, showing all of the possible actions you can take
    graph.save_paths(start_cluster,
                     target_cluster,
                     f"{path_stem}_verbose.png",
                     verbose=True)

    # Same path, only the most likely path is shown
    graph.save_paths(start_cluster,
                     target_cluster,
                     f"{path_stem}_bp.png",
                     best_path_only=True)

    paths_to_stem = os.path.join(base_path, f"samdp_paths_to_{target_cluster}")

    # Show all paths that lead to target_cluster
    # Action into the target cluster shown, rest is simplified
    graph.save_all_paths_to(target_cluster, f"{paths_to_stem}.png")

    # Show all paths that lead to target_cluster (all actions shown)
    graph.save_all_paths_to(target_cluster,
                            f"{paths_to_stem}_verbose.png",
                            verbose=True)

    # Show all connections into terminal nodes - all paths or only the best paths
    graph.save_terminal_paths(os.path.join(base_path, "samdp_terminal_paths.png"))
    graph.save_terminal_paths(os.path.join(base_path, "samdp_terminal_paths_bp.png"),
                              best_path=True)

    # Save a table representation of the SAMDP
    graph.save_txt(os.path.join(base_path, "samdp.txt"))

In [None]:
# Build the dataset once; every later step reuses it.
dataset = create_dataset()
# Embeddings and clusters are both derived from the same dataset.
embeddings = get_embeddings(dataset=dataset)
clusters = get_clusters(dataset=dataset)

# Write the latent-space plots, the per-cluster plots and the SAMDP graphs
# to their respective folders under ./outputs/usage.
graph_latent_analytics(embeddings=embeddings, clusters=clusters, dataset=dataset)
graph_cluster_analytics(dataset=dataset, clusters=clusters)
samdp(clusters=clusters, dataset=dataset)