# Example Attack

This notebook looks into how ARLIN can be used to create more effective adversarial attacks. The notebook will show the average reward gained and total number of attacks in various attack scenarios against the same trained RL model:

- Random action every step
- Worst-case action every step
- Worst-case action every 10 steps
- Least-preferred action based on threshold (https://arxiv.org/pdf/1703.06748.pdf)
- ARLIN-informed actions

In [1]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [2]:
import os
import gymnasium as gym
import numpy as np
import logging
import warnings

import arlin.dataset.loaders as loaders
from arlin.dataset import XRLDataset
from arlin.dataset.collectors import SB3PPODataCollector, SB3PPODatapoint

from arlin.generation import generate_clusters, generate_embeddings
import arlin.analysis.visualization as viz
from arlin.analysis import ClusterAnalyzer, LatentAnalyzer
from arlin.samdp import SAMDP
import arlin.utils.saving_loading as sl_utils

logging.basicConfig(level=logging.INFO, force=True)
warnings.filterwarnings("ignore", category=UserWarning) 

In [3]:
load_data = True
load_embeddings = True
load_clusters = True

# Create environment
env = gym.make("LunarLander-v2", render_mode='rgb_array')

# Load the SB3 model from Huggingface
model = loaders.load_hf_sb_model(repo_id="sb3/ppo-LunarLander-v2",
                                 filename="ppo-LunarLander-v2.zip",
                                 algo_str="ppo")

adv_model = loaders.load_sb_model('./models/adv_ppo_lunar.zip', 'ppo')

# Create the datapoint collector for SB3 PPO Datapoints with the model's policy
collector = SB3PPODataCollector(datapoint_cls=SB3PPODatapoint,
                                policy=model.policy)

# Instantiate the XRL Dataset
dataset = XRLDataset(env, collector=collector)

if load_data:
    # Load the dataset, embeddings, and clusters
    dataset.load('./data/LunarLander-50000.npz')
else:
    dataset.fill(num_datapoints=50000)
    dataset.save(file_path='./data/LunarLander-50000.npz')

if load_embeddings:
    embeddings = sl_utils.load_data(file_path='./data/LunarLander-50000-Embeddings.npy')
else:
    embeddings = generate_embeddings(dataset=dataset,
                                activation_key='latent_actors',
                                perplexity=500,
                                n_train_iter=2000,
                                output_dim=2,
                                seed=12345)
    sl_utils.save_data(embeddings, './data/LunarLander-50000-Embeddings.npy')

if load_clusters:
    clusters = sl_utils.load_data(file_path='./data/LunarLander-50000-Clusters.npy')
    [start_algo, term_algo, mid_algo] = sl_utils.load_data(file_path='./models/cluster_algos.npy', allow_pickle=True)
else:
    clusters, start_algo, term_algo, mid_algo = generate_clusters(dataset=dataset,
                                     num_clusters=30)

    sl_utils.save_data(clusters, './data/LunarLander-50000-Clusters.npy')
    sl_utils.save_data(data=[start_algo, term_algo, mid_algo], file_path='./models/cluster_algos.npy')

INFO:root:Loading model sb3/ppo-LunarLander-v2/ppo-LunarLander-v2.zip from huggingface...
INFO:root:Loading ppo model ppo-LunarLander-v2.zip with stable_baselines3...
INFO:root:Loading ppo model adv_ppo_lunar.zip with stable_baselines3...
INFO:root:Loading data from ./data/LunarLander-50000-Embeddings.npy...
INFO:root:	Data loaded successfully.
INFO:root:Loading data from ./data/LunarLander-50000-Clusters.npy...
INFO:root:	Data loaded successfully.
INFO:root:Loading data from ./models/cluster_algos.npy...
INFO:root:	Data loaded successfully.


## ARLIN Usage

Let's use the ARLIN toolkit to identify when we should be performing our adversarial
attack, and which actions we should target.

In [42]:
def graph_latent_analytics(embeddings: np.ndarray, 
                           clusters: np.ndarray, 
                           dataset: XRLDataset):
    """Graph visualizations of different latent space analytics over embeddings."""
    
    # Create a grapher to generate data used for analysis.
    grapher = LatentAnalyzer(embeddings, dataset)
    
    # Clusters
    cluster_data = grapher.clusters_graph_data(clusters)
    # Episode progression
    ep_prog_data = grapher.episode_prog_graph_data()
    # Greedy action confidence
    conf_data = grapher.confidence_data()
    
    base_path = os.path.join(".", "outputs", "attack", "latent_analytics")
    
    # Graph multiple analytics as subplots in one plot
    combined_path = os.path.join(base_path, 'combined_analytics.png')
    viz.graph_multiple_data(file_path=combined_path,
                                           figure_title='Latent Analytics', 
                                           graph_datas=[conf_data, 
                                                        cluster_data, 
                                                        ep_prog_data])

def graph_cluster_analytics(dataset, clusters):
    """Graph analytics for each cluster"""
    
    # Create grapher to graph cluster analytics
    grapher = ClusterAnalyzer(dataset, clusters)
    
    # grapher.cluster_state_analysis(7,
    #                                env, 
    #                                os.path.join(".", "outputs", "attack", "cluster_state_analysis"))
    
    # grapher.cluster_state_analysis(10,
    #                                env, 
    #                                os.path.join(".", "outputs", "attack", "cluster_state_analysis"))
    
    # grapher.cluster_state_analysis(4,
    #                                env, 
    #                                os.path.join(".", "outputs", "attack", "cluster_state_analysis"))
    
    # grapher.cluster_state_analysis(17,
    #                                env, 
    #                                os.path.join(".", "outputs", "attack", "cluster_state_analysis"))
    
    # Mean confidence per cluster
    cluster_conf = grapher.cluster_confidence()
    # Mean total reward per cluster
    cluster_rewards = grapher.cluster_rewards()
    # Mean value per cluster
    cluster_values = grapher.cluster_values()
    
    # Graph individual graphs per data
    base_path = os.path.join(".", "outputs", "attack", 'cluster_analytics')
    
    # Graph multiple subplots in one plot
    combined_path = os.path.join(base_path, 'combined_analytics.png')
    viz.graph_multiple_data(file_path=combined_path, 
                                           figure_title='Cluster Analytics', 
                                           graph_datas=[cluster_conf,
                                                        cluster_values,
                                                        cluster_rewards])

def samdp(clusters: np.ndarray,
          dataset: XRLDataset):
    """Generate a semi-aggregated Markov decision process."""
    
    # Create the SAMDP
    samdp = SAMDP(clusters, dataset)
    
    base_path = os.path.join(".", "outputs", "attack", 'samdp')
    
    # Simplified graph with all possible connections (regardless of action taken)
    # simplified_graph = samdp.save_simplified_graph(f'{base_path}/samdp_simplified.png')
    
    # path_path = os.path.join(base_path, f"samdp_path_15_18")
    
    # # Path from cluster 15 to cluster 18
    # # Only the most likely path is shown
    # samdp.save_paths(15, 
    #                  18,  
    #                  f'{os.path.join(base_path, f"samdp_path_15_18")}_bp.png', 
    #                  best_path_only=True)
    
    # Path from cluster 9 to cluster 18
    # Only the most likely path is shown
    # samdp.save_paths(13, 
    #                  18,  
    #                  f'{os.path.join(base_path, f"samdp_path_13_18")}_bp.png', 
    #                  best_path_only=True)
    
    # # Path from cluster 9 to cluster 18
    # # Only the most likely path is shown
    # samdp.save_paths(10, 
    #                  18,  
    #                  f'{os.path.join(base_path, f"samdp_path_10_18")}_bp.png', 
    #                  best_path_only=True)
    
    # Path from cluster 9 to cluster 18
    # Only the most likely path is shown
    samdp.save_paths(1, 
                     34,  
                     f'{os.path.join(base_path, f"samdp_path_1_34")}_bp.png', 
                     best_path_only=True)
    
    samdp.save_paths(1, 
                     33,  
                     f'{os.path.join(base_path, f"samdp_path_1_33")}_bp.png', 
                     best_path_only=True)

    samdp.save_paths(9, 
                     34,  
                     f'{os.path.join(base_path, f"samdp_path_9_34")}_bp.png', 
                     best_path_only=True)
    
    samdp.save_paths(9, 
                     33,  
                     f'{os.path.join(base_path, f"samdp_path_9_33")}_bp.png', 
                     best_path_only=True)
    
    # Show all paths that lead to cluster 18
    # # Action into cluster 18 shown, rest is simplified
    # samdp.save_all_paths_to(24, 
    #                         os.path.join(base_path, f"samdp_paths_to_24.png"))
    
    # # Show all paths that lead to cluster 18
    # # Action into cluster 18 shown, rest is simplified
    # samdp.save_all_paths_to(23, 
    #                         os.path.join(base_path, f"samdp_paths_to_23.png"))
    
    # # Show all paths that lead to cluster 18
    # # Action into cluster 18 shown, rest is simplified
    # samdp.save_all_paths_to(22, 
    #                         os.path.join(base_path, f"samdp_paths_to_22.png"))
    
    # samdp.save_all_paths_to(34, 
    #                         os.path.join(base_path, f"samdp_paths_to_34.png"))
    
    # samdp.save_txt('./outputs/attack/samdp/text.txt')

In [43]:
# graph_latent_analytics(embeddings, clusters, dataset)
# graph_cluster_analytics(dataset, clusters)
samdp(clusters, dataset)

INFO:root:Generating SAMDP.


INFO:root:Generating SAMDP Graph.
INFO:root:Finding paths from Cluster 1 to Cluster 34...
INFO:root:Highest probability of getting from Cluster 1 to Cluster 34:
INFO:root:	via Action 0: 0.41%
INFO:root:		Cluster 1 to Cluster 9 with 76.64%
INFO:root:		Cluster 9 to Cluster 34 with 0.53%
INFO:root:	via Action 1: 0.41%
INFO:root:		Cluster 1 to Cluster 9 with 76.67%
INFO:root:		Cluster 9 to Cluster 34 with 0.53%
INFO:root:	via Action 2: 0.32%
INFO:root:		Cluster 1 to Cluster 9 with 60.0%
INFO:root:		Cluster 9 to Cluster 34 with 0.53%
INFO:root:	via Action 3: 0.37%
INFO:root:		Cluster 1 to Cluster 9 with 69.57%
INFO:root:		Cluster 9 to Cluster 34 with 0.53%
INFO:root:	Best Option: Action 1 with 0.41%
INFO:root:	Best Path:
INFO:root:		Cluster 1 to Cluster 9 with 76.67%
INFO:root:		Cluster 9 to Cluster 34 with 0.53%
INFO:root:Saving SAMDP path from Cluster 1 to Cluster 34 png to ./outputs/attack/samdp/samdp_path_1_34_bp.png...
INFO:root:Finding paths from Cluster 1 to Cluster 33...
INFO:root:H

In [4]:
def should_attack(model_type: str, 
                  timestep: int, 
                  freq: int = 0,
                  preference: float = 0, 
                  threshold: float = 1.0) -> bool:
    """Check whether or not we should attack at the given timestep.

    Args:
        model_type (str): Type of model we want to run.
        timestep (int): Current timestep
        freq (int, optional): Frequency of attack. Defaults to 0.
        preference (float, optional): Delta between most and least preferred action.
            Defaults to 0.
        threshold (float, optional): Threshold for preference attack. Defaults to 1.0.

    Raises:
        ValueError: If invalid model type is given.

    Returns:
        bool: Whether or not to attack
    """
    
    if model_type == 'baseline':
        return False
    elif model_type == 'random' or model_type == 'adversarial':
        if timestep % freq == 0:
            return True
        else:
            return False
    elif model_type == 'preference':
        if preference > threshold:
            return True
        else:
            return False
    else:
        raise ValueError(f"Invalid model_type {model_type} given.")

def get_action(obs: np.ndarray,
               model_type: str, 
               timestep: int, 
               freq: int = 0,
               preference: float = 0, 
               threshold: float = 1.0) -> int:
    """Get the action to take at the given timestep.

    Args:
        obs (np.ndarray): Current observation from the agent.
        model_type (str): Type of model we want to run.
        timestep (int): Current timestep
        freq (int, optional): Frequency of attack. Defaults to 0.
        preference (float, optional): Delta between most and least preferred action.
            Defaults to 0.
        threshold (float, optional): Threshold for preference attack. Defaults to 1.0.

    Returns:
        int: Action value to take.
    """
    adv = False
    if should_attack(model_type, timestep, freq, preference, threshold):
        adv = True
        if model_type == 'random':
            rng = np.random.default_rng(12345)
            action = rng.integers(low=0, high=env.action_space.n, size=1).item()
        else:
            action, _ = adv_model.predict(obs, deterministic=True)
    else:
        action, _ = model.predict(obs, deterministic=True)

    return action, adv

In [29]:
def get_average_reward(model_type: str ='baseline', 
                       freq: int = 0,
                       threshold: int = 1) -> float:
    """Average reward over 10 episodes while the model is being attacked.
    
    Attacks happen at the given freq and come from the given model type. 
        - Baseline does not include any adversarial attacks.
        - Random chooses the action randomly.
        - Adversarial chooses the worst possible action at that point in time.
        - Preference chooses the least preferred action when the pref is above a threshold
    """
    
    episode_rewards = []
    episode_attacks = []
    
    dir_name = os.path.join("./outputs/attack/gifs/", model_type)
    
    if freq != 0:
        dir_name = dir_name + f"-{freq}_freq"
    
    if threshold != 1:
        dir_name = dir_name + f"-{threshold}_thresh"
    
    os.makedirs(dir_name, exist_ok=True)
    
    gif_lists = []
    
    for ep in range(10):
        obs, _ = env.reset(seed=1234 + ep)
        images = [Image.fromarray(env.render())]
        done = False
        step = 0
        ep_rew = 0
        adv_attacks = 0
        
        while not done:
            internal_data, _ = collector.collect_internal_data(obs)
            probs = internal_data.dist_probs
            preference = probs.max() - probs.min()
            
            action, adv = get_action(obs, model_type, step, freq, preference, threshold)
            if adv:
                adv_attacks += 1
            
            obs, reward, terminated, truncated, _ = env.step(action)
            images.append(Image.fromarray(env.render()))
            ep_rew += reward
            done = terminated or truncated
            step += 1
        
        gif_lists.append(images)
        episode_rewards.append(ep_rew)
        episode_attacks.append(adv_attacks)
    
    idx = episode_rewards.index(max(episode_rewards))
    save_path = os.path.join(dir_name, f'episode_{idx}-max.gif')
    gif_lists[idx][0].save(save_path, save_all=True, append_images=gif_lists[idx], duration=30)
    
    idx = episode_rewards.index(min(episode_rewards))

    save_path = os.path.join(dir_name, f'episode_{idx}-min.gif')
    gif_lists[idx][0].save(save_path, save_all=True, append_images=gif_lists[idx], duration=30)
    
    return sum(episode_rewards) / 10, sum(episode_attacks) / 10

In [30]:
baseline, base_n_attack = get_average_reward('baseline')
rand_every_1, rand1_n_attack = get_average_reward('random', freq=1)
rand_every_10, rand10_n_attack = get_average_reward('random', freq=10)
worst_every_1, worst1_n_attack = get_average_reward('adversarial', freq=1)
worst_every_10, worst10_n_attack = get_average_reward('adversarial', freq=10)
preference_50, pref50_n_attack = get_average_reward('preference', threshold=0.50)
preference_75, pref75_n_attack = get_average_reward('preference', threshold=0.75)
preference_90, pref90_n_attack = get_average_reward('preference', threshold=0.9)

print(f"Baseline Avg Reward: {baseline} with {base_n_attack} attacks")
print(f"Random Action Every 1 Avg Reward: {rand_every_1} with {rand1_n_attack} attacks")
print(f"Random Action Every 10 Avg Reward: {rand_every_10} with {rand10_n_attack} attacks")
print(f"Worst Action Every 1 Avg Reward: {worst_every_1} with {worst1_n_attack} attacks")
print(f"Worst Action Every 10 Avg Reward: {worst_every_10} with {worst10_n_attack} attacks")
print(f"Preference at .50 Avg Reward: {preference_50} with {pref50_n_attack} attacks")
print(f"Preference at .75 Avg Reward: {preference_75} with {pref75_n_attack} attacks")
print(f"Preference at .90 Avg Reward: {preference_90} with {pref90_n_attack} attacks")

Baseline Avg Reward: 247.46209077695022 with 0.0 attacks
Random Action Every 1 Avg Reward: -588.4916522579238 with 95.7 attacks
Random Action Every 10 Avg Reward: 138.2434657663849 with 100.0 attacks
Worst Action Every 1 Avg Reward: -594.7959537977795 with 63.0 attacks
Worst Action Every 10 Avg Reward: 139.18397864543132 with 100.0 attacks
Preference at .50 Avg Reward: -547.7386702780462 with 59.0 attacks
Preference at .75 Avg Reward: -25.699487103558965 with 18.4 attacks
Preference at .90 Avg Reward: 168.53780060602693 with 5.8 attacks


In [46]:
from PIL import Image

np.random.seed(1234)
episodes = []
attacks = []
gif_lists = []

dir_name = "./outputs/attack/gifs/arlin/"
os.makedirs(dir_name, exist_ok=True)

for ep in range(10):
    obs, _ = env.reset(seed=1234 + ep)
    images = [Image.fromarray(env.render())]
    
    done = False
    step = 0
    total_reward = 0
    last_pred = -1
    n_attacks = 0
    
    while not done:
        internal_data, _ = collector.collect_internal_data(obs)
        probs = internal_data.dist_probs
        preference = probs.max() - probs.min()
        
        if step == 0:
            prediction = start_algo.predict(np.array([internal_data.critic_values]).reshape(-1, 1))
            if last_pred != prediction + 30:
                # print(f'Initial: {prediction + 14}')
                last_pred = prediction
        else:
            latent = internal_data.latent_actors
            value = internal_data.critic_values
            confidence = np.amax(internal_data.dist_probs)
            
            data = np.concatenate([latent,
                                # np.expand_dims(action, axis=-1),
                                np.expand_dims(value, axis=-1),
                                np.expand_dims(reward, axis=-1),
                                # np.expand_dims(total_reward, axis=-1),
                                np.expand_dims(confidence, axis=-1)], axis=-1)
            prediction = mid_algo.predict(data.reshape(1, -1))
            if last_pred != prediction:
                # print(f'Mid: {prediction}')
                last_pred = prediction

        if prediction == 9:
            action = 0
            n_attacks += 1
        elif prediction == 1:
            action = 1
            n_attacks += 1
        else:
            action, _ = model.predict(obs, deterministic=True)
        
        obs, reward, terminated, truncated, _ = env.step(action)
        images.append(Image.fromarray(env.render()))
        total_reward += reward
        done = terminated or truncated
        step += 1

    gif_lists.append(images)
    prediction = start_algo.predict(np.array([total_reward]).reshape(-1, 1))
    print(f'Terminal: {prediction + 30 + 2} {total_reward} {n_attacks}')
    # render = env.render()
    # im = Image.fromarray(render)
    # im.save(f'./outputs/attack/images/{ep}_{prediction + 20}.png')
    episodes.append(total_reward)
    attacks.append(n_attacks)

idx = episodes.index(max(episodes))
save_path = os.path.join(dir_name, f'episode_{idx}-max.gif')
gif_lists[idx][0].save(save_path, save_all=True, append_images=gif_lists[idx], duration=30)

idx = episodes.index(min(episodes))
save_path = os.path.join(dir_name, f'episode_{idx}-min.gif')
gif_lists[idx][0].save(save_path, save_all=True, append_images=gif_lists[idx], duration=30)

print(f'Final: {sum(episodes) / 10} with {sum(attacks) / 10} attacks')
    

Terminal: [32] 268.6094182400691 65
Terminal: [32] 235.27491673455904 41
Terminal: [32] 223.95616732714305 42
Terminal: [32] 273.6596252490928 43
Terminal: [32] 223.4218360915082 49
Terminal: [33] 30.233268838799972 15
Terminal: [32] 280.1934542704997 58
Terminal: [32] 200.99772107592568 378
Terminal: [32] 277.72075190658995 31
Terminal: [32] 252.6146954932242 306
Final: 226.66818552274117 with 102.8 attacks
