# Example Attack

This notebook explores how ARLIN can be used to create more effective adversarial attacks. It compares the average reward gained and the total number of attacks across several attack scenarios, all run against the same trained RL model:

- Random action every step
- Random action every 10 steps
- Worst-case action every step
- Worst-case action every 10 steps
- Least-preferred action based on threshold (https://arxiv.org/pdf/1703.06748.pdf)
- ARLIN-informed actions

In [None]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [None]:
import os
import gymnasium as gym
import numpy as np
import logging
import warnings

import arlin.dataset.loaders as loaders
from arlin.dataset import XRLDataset
from arlin.dataset.collectors import SB3PPODataCollector
from arlin.dataset.collectors.datapoints import SB3PPODatapoint

from arlin.generation import generate_clusters, generate_embeddings
import arlin.analysis.visualization as viz
from arlin.analysis import ClusterAnalyzer, LatentAnalyzer
from arlin.samdp import SAMDP
import arlin.utils.saving_loading as sl_utils
from arlin.adversarial.attacks import run_baseline, run_adversarial, run_arlin
from arlin.adversarial.metrics import plot_cosine_sim, plot_divergences, plot_episode_rewards
from arlin.adversarial.utils import create_dirs

# Emit INFO-level (and above) log messages. force=True removes any handlers
# already attached to the root logger before configuring it, so this call takes
# effect even if logging was configured earlier in the session.
logging.basicConfig(level=logging.INFO, force=True)
# Suppress all warnings of category UserWarning.
warnings.filterwarnings("ignore", category=UserWarning) 

In [None]:
# Each flag selects between loading a cached artifact from disk (True) and
# regenerating it and saving the result (False).
# NOTE(review): the three flags are independent. If the dataset is regenerated
# while embeddings or clusters are still loaded from disk, the cached files may
# not correspond to the new dataset - confirm this is intended before mixing.
load_data = True
load_embeddings = True
load_clusters = True

# Create environment
env = gym.make("LunarLander-v2", render_mode='rgb_array')

# Load the SB3 model from Huggingface
model = loaders.load_hf_sb_model(repo_id="sb3/ppo-LunarLander-v2",
                                 filename="ppo-LunarLander-v2.zip",
                                 algo_str="ppo")

# Load a second PPO model from a local file; it is passed to run_adversarial
# as the adversarial model in the attack cell further down.
adv_model = loaders.load_sb_model('./models/adv_ppo_lunar.zip', 'ppo')

# Create the datapoint collector for SB3 PPO Datapoints with the model's policy
collector = SB3PPODataCollector(datapoint_cls=SB3PPODatapoint,
                                policy=model.policy)

# Instantiate the XRL Dataset
dataset = XRLDataset(env, collector=collector)

if load_data:
    # Load the previously saved dataset
    dataset.load('./data/LunarLander-50000.npz')
else:
    # Collect 50000 datapoints and cache them for later runs.
    # presumably `randomness` is the fraction of random actions taken during
    # collection - TODO confirm against XRLDataset.fill
    dataset.fill(num_datapoints=50000, randomness=0.2)
    dataset.save(file_path='./data/LunarLander-50000.npz')

if load_embeddings:
    # Load the cached embeddings
    embeddings = sl_utils.load_data(file_path='./data/LunarLander-50000-Embeddings.npy')
else:
    # Generate 2-dimensional embeddings from the 'latent_actors' activations
    # with a fixed seed, then cache them.
    embeddings = generate_embeddings(dataset=dataset,
                                activation_key='latent_actors',
                                perplexity=500,
                                n_train_iter=1500,
                                output_dim=2,
                                seed=12345)
    sl_utils.save_data(embeddings, './data/LunarLander-50000-Embeddings.npy')

if load_clusters:
    # Load the cached cluster assignments and the three fitted clustering algorithms.
    clusters = sl_utils.load_data(file_path='./data/LunarLander-50000-Clusters.npy')
    # The algorithms file stores Python objects, hence allow_pickle=True.
    # NOTE(review): pickle-based loading should only be used on trusted files.
    [start_algo, mid_algo, term_algo] = sl_utils.load_data(file_path='./models/cluster_algos.npy', allow_pickle=True)
else:
    # presumably the three key lists are the features used for start,
    # intermediate and terminal clustering, and 20 is the number of clusters
    # - TODO confirm against generate_clusters
    clusters, start_algo, mid_algo, term_algo = generate_clusters(
        dataset,
        ["latent_actors", "critic_values"],
        ["latent_actors", "critic_values"],
        ["latent_actors", "critic_values", "rewards"],
        20,
        seed=1234
        )

    # Cache the cluster assignments and the clustering algorithms for later runs.
    sl_utils.save_data(clusters, './data/LunarLander-50000-Clusters.npy')
    sl_utils.save_data(data=[start_algo, mid_algo, term_algo], file_path='./models/cluster_algos.npy')

## ARLIN Usage

Let's use the ARLIN toolkit to identify when we should perform our adversarial
attack and which actions we should target.

In [None]:
def graph_latent_analytics(
    embeddings: np.ndarray,
    clusters: np.ndarray,
    dataset: XRLDataset,
):
    """Save two combined figures of latent-space analytics for the embeddings.

    Args:
        embeddings: Embeddings handed to the LatentAnalyzer.
        clusters: Cluster assignments used for the cluster plot.
        dataset: Dataset handed to the LatentAnalyzer.

    Both figures are written under ./outputs/attack/latent_analytics.
    """
    analyzer = LatentAnalyzer(embeddings, dataset)

    # Build the plot data for every analytic up front.
    embedding_panel = analyzer.embeddings_graph_data()
    cluster_panel = analyzer.clusters_graph_data(clusters)
    boundary_panel = analyzer.decision_boundary_graph_data()
    progression_panel = analyzer.episode_prog_graph_data()
    confidence_panel = analyzer.confidence_data()

    out_dir = os.path.join(".", "outputs", "attack", "latent_analytics")

    # (file name, panels) for each combined figure, in the order they are saved.
    figures = (
        ('combined_analytics-total.png',
         [progression_panel, confidence_panel, boundary_panel]),
        ('combined_analytics-generate.png',
         [embedding_panel, cluster_panel]),
    )

    for file_name, panels in figures:
        # Each figure lays its panels out side by side as subplots.
        viz.graph_multiple_data(file_path=os.path.join(out_dir, file_name),
                                figure_title='Latent Analytics',
                                graph_datas=panels,
                                horizontal=True)

def graph_cluster_analytics(dataset, clusters):
    """Save one combined figure of per-cluster confidence, value and reward.

    Args:
        dataset: Dataset handed to the ClusterAnalyzer.
        clusters: Cluster assignments handed to the ClusterAnalyzer.

    The figure is written to
    ./outputs/attack/cluster_analytics/combined_analytics.png.
    """
    analyzer = ClusterAnalyzer(dataset, clusters)

    # Optional per-cluster state analysis, currently disabled. It relies on the
    # notebook-level `env`:
    #
    # state_dir = os.path.join(".", "outputs", "attack", "cluster_state_analysis")
    # for i in range(22, 25):
    #     analyzer.cluster_state_analysis(i, env, state_dir)
    # analyzer.cluster_state_analysis(9, env, state_dir)

    # Per-cluster means: confidence, total reward, value.
    confidence_panel = analyzer.cluster_confidence()
    reward_panel = analyzer.cluster_rewards()
    value_panel = analyzer.cluster_values()

    figure_path = os.path.join(".", "outputs", "attack", 'cluster_analytics',
                               'combined_analytics.png')

    # All three analytics share one figure as side-by-side subplots.
    viz.graph_multiple_data(
        file_path=figure_path,
        figure_title='Cluster Analytics',
        graph_datas=[confidence_panel, value_panel, reward_panel],
        horizontal=True,
    )

def samdp(clusters: np.ndarray,
          dataset: XRLDataset,
          term_cluster_id: int = 23):
    """Generate a semi-aggregated Markov decision process and save its graphs.

    Three images are written under ./outputs/attack/samdp: the simplified
    graph, the complete graph, and the best path(s) into one terminal cluster.

    Args:
        clusters: Cluster assignments used to build the SAMDP.
        dataset: Dataset used to build the SAMDP.
        term_cluster_id: ID of the terminal cluster whose paths are saved. It
            is also used in the output file name. Defaults to 23, the value
            that was previously hard-coded.
    """
    # Named differently from the function so the function name is not shadowed.
    samdp_model = SAMDP(clusters, dataset)

    base_path = os.path.join(".", "outputs", "attack", 'samdp')

    # Simplified graph with all possible connections (regardless of action taken)
    samdp_model.save_simplified_graph(f'{base_path}/samdp_simplified.png')
    samdp_model.save_complete_graph(f'{base_path}/samdp_complete.png')

    terminal_path = os.path.join(base_path,
                                 f"samdp_terminals_{term_cluster_id}.png")
    samdp_model.save_terminal_paths(terminal_path,
                                    best_path=True,
                                    term_cluster_id=term_cluster_id)

    # Optional text export, currently disabled:
    # samdp_model.save_txt('./outputs/attack/samdp/text.txt')

In [None]:
# The latent-space and cluster analytics plots are disabled here; uncomment
# the two calls below to regenerate them.
# graph_latent_analytics(embeddings, clusters, dataset)
# graph_cluster_analytics(dataset, clusters)
# Build the SAMDP and save its graphs.
samdp(clusters, dataset)

In [None]:
# Number of episodes run for every scenario below.
# NOTE(review): all runs share the same `env` instance and execute in sequence.
num_episodes = 25

# Baseline: only the environment and the trained model are passed in (no
# adversarial model). Unlike the attack runs, no divergence data is returned.
baseline_obs, baseline_rew, baseline_renders = run_baseline(env, 
                                                            model, 
                                                            num_episodes=num_episodes)

# 'random' attack with attack_freq=1 - per the notebook intro, presumably a
# random action at every step.
rand1_obs, rand1_rew, rand1_div, rand1_renders = run_adversarial('random',
                                                             collector,
                                                             env,
                                                             model,
                                                             adv_model,
                                                             attack_freq=1,
                                                             num_episodes=num_episodes
                                                             )

# 'random' attack with attack_freq=10 - presumably a random action every 10 steps.
rand10_obs, rand10_rew, rand10_div, rand10_renders = run_adversarial('random',
                                                             collector,
                                                             env,
                                                             model,
                                                             adv_model,
                                                             attack_freq=10,
                                                             num_episodes=num_episodes
                                                             )

# 'adversarial' attack with attack_freq=1 - per the notebook intro, presumably
# the worst-case action at every step.
adv1_obs, adv1_rew, adv1_div, adv1_renders = run_adversarial('adversarial',
                                                             collector,
                                                             env,
                                                             model,
                                                             adv_model,
                                                             attack_freq=1,
                                                             num_episodes=num_episodes
                                                             )

# 'adversarial' attack with attack_freq=10 - presumably the worst-case action
# every 10 steps.
adv10_obs, adv10_rew, adv10_div, adv10_renders = run_adversarial('adversarial',
                                                             collector,
                                                             env,
                                                             model,
                                                             adv_model,
                                                             attack_freq=10,
                                                             num_episodes=num_episodes
                                                             )

# 'preference' attacks at thresholds 0.50, 0.75 and 0.90 - per the notebook
# intro, presumably the least-preferred-action attack triggered by a
# preference threshold (TODO confirm the exact trigger in run_adversarial).
pref50_obs, pref50_rew, pref50_div, pref50_renders = run_adversarial('preference',
                                                                     collector,
                                                                     env,
                                                                     model,
                                                                     adv_model,
                                                                     pref_threshold=0.50,
                                                                     num_episodes=num_episodes
                                                                     )

pref75_obs, pref75_rew, pref75_div, pref75_renders = run_adversarial('preference',
                                                                     collector,
                                                                     env,
                                                                     model,
                                                                     adv_model,
                                                                     pref_threshold=0.75,
                                                                     num_episodes=num_episodes
                                                                     )

pref90_obs, pref90_rew, pref90_div, pref90_renders = run_adversarial('preference',
                                                                     collector,
                                                                     env,
                                                                     model,
                                                                     adv_model,
                                                                     pref_threshold=0.90,
                                                                     num_episodes=num_episodes
                                                                     )

# ARLIN-informed attack: uses the three clustering algorithms from the setup
# cell instead of the adversarial model.
arlin_obs, arlin_rew, arlin_div, arlin_renders = run_arlin(collector,
                                                           env,
                                                           model,
                                                           start_algo,
                                                           mid_algo,
                                                           term_algo,
                                                           num_episodes=num_episodes
                                                           )

In [None]:
# Scenarios included in the comparison; `obs` and `rewards` follow this order.
names = [
    'Baseline',
    'Adversarial 1',
    'Preference .75',
    'Preference .90',
    'ARLIN',
]
obs = [baseline_obs, adv1_obs, pref75_obs, pref90_obs, arlin_obs]
rewards = [baseline_rew, adv1_rew, pref75_rew, pref90_rew, arlin_rew]
# No baseline entry here, so this list lines up with names[1:].
divergences = [adv1_div, pref75_div, pref90_div, arlin_div]

create_dirs('./outputs/attack', names)
plot_divergences(divergences, names[1:], './outputs/attack/metrics/kl_divergence.png')

# Plot at most 10 episodes, or fewer if fewer baseline episodes were collected.
num_eval = min(10, len(baseline_obs))

for i in range(num_eval):
    cs_save_path = f'./outputs/attack/metrics/cosine_similarity/episode_{i}.png'
    rew_save_path = f'./outputs/attack/metrics/episode_rewards/episode_{i}.png'

    # Episode i of every scenario, compared against the baseline's episode i.
    cs_obs = [scenario_obs[i] for scenario_obs in obs]
    plot_cosine_sim(baseline_obs[i], cs_obs, names, cs_save_path)

    # Episode i rewards of every scenario.
    rews = [scenario_rew[i] for scenario_rew in rewards]
    plot_episode_rewards(rews, names, rew_save_path)
