In [None]:
# Disable the jedi-based tab completer for this kernel session.
%config Completer.use_jedi = False
# Load the autoreload extension and use mode 2, so that imported modules are
# re-imported before each cell execution (edits to library code are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [None]:
import os
import gymnasium as gym
import numpy as np
import logging
import warnings

import arlin.dataset.loaders as loaders
from arlin.dataset import XRLDataset
from arlin.dataset.collectors import SB3PPODataCollector, SB3PPODatapoint

from arlin.generation import generate_clusters, generate_embeddings
import arlin.analysis.visualization as viz
from arlin.analysis import ClusterAnalyzer, LatentAnalyzer
from arlin.samdp import SAMDP

# Show INFO-level log records; force=True replaces any handlers already attached
# to the root logger so this configuration takes effect on cell re-runs too.
logging.basicConfig(level=logging.INFO, force=True)
# Silence warnings of category UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning) 

In [None]:
def create_dataset(num_datapoints: int = 100):
    """Build and fill an XRLDataset from a pretrained SB3 PPO LunarLander agent.

    Args:
        num_datapoints: Number of datapoints passed to ``XRLDataset.fill``.
            Defaults to 100, the value this notebook previously hard-coded
            (the old comment claimed 50k, which did not match the code).

    Returns:
        The filled ``XRLDataset`` instance.
    """
    # Create environment
    env = gym.make("LunarLander-v2")

    # Load the SB3 model from Huggingface
    model = loaders.load_hf_sb_model(
        repo_id="sb3/ppo-LunarLander-v2",
        filename="ppo-LunarLander-v2.zip",
        algo_str="ppo",
    )

    # Create the datapoint collector for SB3 PPO Datapoints with the model's policy
    collector = SB3PPODataCollector(
        datapoint_cls=SB3PPODatapoint,
        policy=model.policy,
    )

    # Instantiate the XRL Dataset
    dataset = XRLDataset(env, collector=collector)

    # Fill the dataset with the requested number of datapoints
    dataset.fill(num_datapoints=num_datapoints)

    return dataset

In [None]:
def get_embeddings(dataset: XRLDataset):
    """Generate embeddings for ``dataset`` with this notebook's fixed settings.

    Args:
        dataset: The XRLDataset forwarded to ``generate_embeddings``.

    Returns:
        Whatever ``generate_embeddings`` returns for the settings below
        (``get_clusters`` annotates its ``embeddings`` input as ``np.ndarray``).
    """
    # Fixed generation settings; the seed keeps the result repeatable across runs.
    embedding_settings = {
        "activation_key": "latent_actors",
        "perplexity": 20,
        "n_train_iter": 300,
        "output_dim": 2,
        "seed": 12345,
    }
    return generate_embeddings(dataset=dataset, **embedding_settings)

def get_clusters(dataset: XRLDataset, embeddings: np.ndarray):
    """Cluster ``embeddings`` of ``dataset`` using ``generate_clusters``.

    Args:
        dataset: The XRLDataset the embeddings were generated from.
        embeddings: Embeddings forwarded to ``generate_clusters``.

    Returns:
        The result of ``generate_clusters`` called with ``num_clusters=14``.
    """
    return generate_clusters(
        dataset=dataset,
        embeddings=embeddings,
        num_clusters=14,
    )

In [None]:
def graph_latent_analytics(embeddings, clusters, dataset):
    """Render the latent-space analytics graphs under ./outputs/latent_analytics.

    Six graph-data objects are requested from a ``LatentAnalyzer``; each is saved
    as its own image, and four of them are additionally combined into a single
    'Latent Analytics' figure.

    Args:
        embeddings: Embeddings passed to ``LatentAnalyzer``.
        clusters: Cluster assignments passed to ``clusters_graph_data``.
        dataset: Dataset passed to ``LatentAnalyzer``.
    """
    analyzer = LatentAnalyzer(embeddings, dataset)

    # Build every graph-data object up front, in a fixed order.
    embeddings_plot = analyzer.embeddings_graph_data()
    clusters_plot = analyzer.clusters_graph_data(clusters)
    boundaries_plot = analyzer.decision_boundary_graph_data()
    init_term_plot = analyzer.initial_terminal_state_data()
    progression_plot = analyzer.episode_prog_graph_data()
    confidence_plot = analyzer.confidence_data()

    output_dir = os.path.join("./outputs/", "latent_analytics")

    # File name -> graph data; dict insertion order fixes the save order.
    individual_plots = {
        "embeddings.png": embeddings_plot,
        "clusters.png": clusters_plot,
        "decision_boundaries.png": boundaries_plot,
        "initial_terminal.png": init_term_plot,
        "episode_progression.png": progression_plot,
        "confidence.png": confidence_plot,
    }
    for file_name, graph_data in individual_plots.items():
        viz.graph_individual_data(os.path.join(output_dir, file_name), graph_data)

    # One combined figure holding a subset of the graphs above.
    viz.graph_multiple_data(
        file_path=os.path.join(output_dir, 'combined_analytics.png'),
        figure_title='Latent Analytics',
        graph_datas=[boundaries_plot, confidence_plot, clusters_plot, progression_plot],
    )

In [None]:
def graph_cluster_analytics(dataset, clusters):
    """Render the per-cluster analytics graphs under ./outputs/cluster_analytics.

    Confidence, rewards and values graph data are requested from a
    ``ClusterAnalyzer`` and saved individually; rewards and values are also
    combined into a single 'Cluster Analytics' figure.

    Args:
        dataset: Dataset passed to ``ClusterAnalyzer``.
        clusters: Cluster assignments passed to ``ClusterAnalyzer``.
    """
    analyzer = ClusterAnalyzer(dataset, clusters)

    confidence_plot = analyzer.cluster_confidence()
    rewards_plot = analyzer.cluster_rewards()
    values_plot = analyzer.cluster_values()

    output_dir = os.path.join("./outputs", 'cluster_analytics')

    # Pair each graph with its output file name and save them one by one.
    file_names = ('cluster_confidence.png', 'cluster_rewards.png', 'cluster_values.png')
    plots = (confidence_plot, rewards_plot, values_plot)
    for file_name, graph_data in zip(file_names, plots):
        viz.graph_individual_data(os.path.join(output_dir, file_name), graph_data)

    # Combined figure: rewards and values only (confidence is saved individually).
    viz.graph_multiple_data(
        file_path=os.path.join(output_dir, 'combined_analytics.png'),
        figure_title='Cluster Analytics',
        graph_datas=[rewards_plot, values_plot],
    )

In [None]:
def samdp(clusters, dataset, from_cluster=14, to_cluster=12):
    """Build a SAMDP from the clusters and save its graphs under ./outputs/samdp.

    Saves the complete, likely-paths and simplified graphs, three renderings of
    the paths from ``from_cluster`` to ``to_cluster`` (default, verbose and
    best-path-only), two renderings of all paths leading to ``to_cluster``, and
    a text dump of the SAMDP.

    Args:
        clusters: Cluster assignments passed to ``SAMDP``.
        dataset: Dataset passed to ``SAMDP``.
        from_cluster: Id of the starting cluster for ``save_paths``. Defaults to
            14, the value previously hard-coded here.
        to_cluster: Id of the target cluster for ``save_paths`` and
            ``save_all_paths_to``. Defaults to 12, the value previously
            hard-coded here.
    """
    # Named differently from the enclosing function to avoid shadowing it locally.
    samdp_model = SAMDP(clusters, dataset)

    base_path = os.path.join("./outputs", 'samdp')

    # The save_* return values were previously bound to unused locals; they are
    # simply discarded now.
    samdp_model.save_complete_graph(f'{base_path}/samdp_complete.png')
    samdp_model.save_likely_paths(f'{base_path}/samdp_likely.png')
    samdp_model.save_simplified_graph(f'{base_path}/samdp_simplified.png')

    # File prefix is derived from the cluster ids actually used. The old
    # hard-coded prefix was "samdp_path_14_9" although the paths were 14 -> 12.
    path_prefix = os.path.join(base_path, f"samdp_path_{from_cluster}_{to_cluster}")

    samdp_model.save_paths(from_cluster,
                           to_cluster,
                           f'{path_prefix}.png')

    samdp_model.save_paths(from_cluster,
                           to_cluster,
                           f'{path_prefix}_verbose.png',
                           verbose=True)

    samdp_model.save_paths(from_cluster,
                           to_cluster,
                           f'{path_prefix}_bp.png',
                           best_path_only=True)

    samdp_model.save_all_paths_to(
        to_cluster,
        os.path.join(base_path, f"samdp_paths_to_{to_cluster}_verbose.png"),
        verbose=True,
    )

    samdp_model.save_all_paths_to(
        to_cluster,
        os.path.join(base_path, f"samdp_paths_to_{to_cluster}.png"),
    )

    samdp_model.save_txt(f'{base_path}/samdp.txt')

In [None]:
# Run the pipeline end to end. Order matters: the embeddings are generated from
# the dataset, and the clusters are generated from the dataset and embeddings.
dataset = create_dataset()
embeddings = get_embeddings(dataset=dataset)
clusters = get_clusters(embeddings=embeddings, dataset=dataset)

# Produce the three groups of output artifacts under ./outputs.
graph_latent_analytics(embeddings=embeddings, clusters=clusters, dataset=dataset)
graph_cluster_analytics(dataset=dataset, clusters=clusters)
samdp(clusters=clusters, dataset=dataset)