In [124]:
from Continuous_Language.Environments.Multi_Pong.multi_pong import PongEnv
from Continuous_Language.Reinforcement_Learning.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from Continuous_Language.Reinforcement_Learning.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy

In [125]:
def make_env(max_episode_steps = 1024, sequence_length = 1, vocab_size = 3):
    env = PongEnv(width=20, height=20, vocab_size=vocab_size, sequence_length=sequence_length, max_episode_steps=max_episode_steps)
    return env

In [126]:
def load(path="models/checkpoints", sequence_length=1, vocab_size=3):
    env = make_env(sequence_length=sequence_length, vocab_size=vocab_size)
    models = {}
    for model in os.listdir(path):
        if "pong" in model:
            state_dict = torch.load(os.path.join(path, model))
            timestamp = model.split("_")[-1]
            timestamp = int(timestamp.split(".")[0])
            agent = PPO_Multi_Agent_Centralized(env, device="cpu")
            agent.agent.load_state_dict(state_dict)
            models[timestamp] = agent
    return models

In [127]:
def perturbation(inputs, model, vocab_size, sequence_length):
    
    # Extract environment inputs
    environment_inputs = inputs[:, :-1 * vocab_size * sequence_length]

    # Extract original logits
    inputs = torch.tensor(inputs, dtype=torch.float32)
    original_logits = model(inputs)
    original_logits = F.softmax(original_logits, dim=1).detach().numpy()
    original_logits = F.log_softmax(torch.tensor(original_logits), dim=1).detach()

    perturbation_logits = []
    for token in range(vocab_size):
        # One-hot encoded sequence of tokens
        utterances = np.array([token for _ in range(sequence_length)])
        utterances = np.eye(vocab_size)[utterances].flatten()
        utterances = np.expand_dims(utterances, axis=0)
        utterances = np.repeat(utterances, inputs.shape[0], axis=0)

        # Concatenate environment inputs with utterances
        perturbation_inputs = np.concatenate((environment_inputs, utterances), axis=1)
        perturbation_inputs = torch.tensor(perturbation_inputs, dtype=torch.float32)

        # Get logits for perturbed inputs
        current_logits = model(perturbation_inputs).detach().numpy()
        current_logits = F.softmax(torch.tensor(current_logits), dim=1).detach().numpy()

        perturbation_logits.append(current_logits)

    divergences = []
    for input_array in perturbation_logits:
        kl_divergences = []
        for i in range(len(input_array)):
            q = F.softmax(torch.tensor(input_array[i]), dim=0)
            kl_div = F.kl_div(original_logits, q, reduction='batchmean').item()
            kl_divergences.append(kl_div)

        divergences.append(kl_divergences)
    max_divergences = np.max(divergences, axis=0)
    return max_divergences

In [128]:
def test_perturbation(env, agent, epochs=1, tracking_agent="paddle_1", threshold=0.002, use_perturbation=False, above_threshold=False):
    language_importances = []
    obs, info = env.reset()
    state = env.state()
    average_length = []
    tokens = []

    for i in range(epochs):
        timestep = 0
        while True:
            timestep += 1
            tokens.append(obs[tracking_agent][-1 * env.sequence_length:])
            obs = [obs]
            state = [state]
            obs = np.array(flatten_list(obs))
            state = np.array(flatten_list(state))
            
            # integrated_grads = smoothgrad(obs_track, agent.agent.actor, 0, sigma=1.0, steps=30)
            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            # If any of language_importances is higher thnan 0.002
            if any([importance >= threshold for importance in language_perturbation]) and use_perturbation and above_threshold:
                language_observations = obs[:, -1 * env.vocab_size * env.sequence_length:]
                random_language = np.random.randint(0, env.vocab_size, (language_observations.shape[0], env.sequence_length))
                random_language = np.eye(env.vocab_size)[random_language]
                random_language = random_language.reshape(language_observations.shape[0], env.sequence_length * env.vocab_size)
                obs[:, -1 * env.vocab_size * env.sequence_length:] = random_language

            if all([importance <= threshold for importance in language_perturbation]) and use_perturbation and not above_threshold:
                language_observations = obs[:, -1 * env.vocab_size * env.sequence_length:]
                random_language = np.random.randint(0, env.vocab_size, (language_observations.shape[0], env.sequence_length))
                random_language = np.eye(env.vocab_size)[random_language]
                random_language = random_language.reshape(language_observations.shape[0], env.sequence_length * env.vocab_size)
                obs[:, -1 * env.vocab_size * env.sequence_length:] = random_language

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            actions = actions[0]
            actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

            obs, _, truncations, terminations, infos = env.step(actions)
            state = env.state()

            if any([truncations[agent] or terminations[agent] for agent in env.agents]):
                average_length.append(timestep)
                obs, info = env.reset()
                state = env.state()
                break
    return np.array(language_importances), average_length

In [129]:
import os
directories = os.listdir("Plotting/saliencies_live/Multi_Pong/")
directories = [model for model in directories if ".DS_Store" not in model]
models = {}
for directory in directories:
    sequence_length = int(directory.split("_")[-1])
    agents = load(f"Plotting/saliencies_live/Multi_Pong/{directory}", sequence_length=sequence_length)
    agent_indizes = list(agents.keys())
    agent_indizes.sort()
    models[sequence_length] = agents[agent_indizes[-1]]

print(models)

{1: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x38584d600>, 2: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x33e600700>, 3: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x38584d990>}


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [130]:
def share_above_threshold(arr, threshold=0.15):
    # Check if each element is greater than the threshold
    condition = arr > threshold
    # Count rows where at least one element satisfies the condition
    count = np.any(condition, axis=1).sum()
    # Calculate the share
    share = count / arr.shape[0]
    return share

In [131]:
num_epochs = 100
threshold = 0.002

In [133]:
results = {}
for sequence_length, model in models.items():
    print(f"Testing sequence length {sequence_length}")
    results[sequence_length] = {}
    env = make_env(sequence_length=sequence_length)
    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=False)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["no noise"] = {"lengths": np.mean(lengths), "share_above_threshold": noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=True, above_threshold=True)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["above threshold"] = {"lengths": np.mean(lengths), "share_above_threshold": noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=True, above_threshold=False)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["below threshold"] = {"lengths": np.mean(lengths), "share_below_threshold": 1 - noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=0.00, use_perturbation=True, above_threshold=True)
    results[sequence_length]["all noise"] = {"lengths": np.mean(lengths), "share_above_threshold": 1}

Testing sequence length 1


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Testing sequence length 2
Testing sequence length 3


In [135]:
print(results)

{1: {'no noise': {'lengths': 843.99, 'share_above_threshold': 0.49174753255370324}, 'above threshold': {'lengths': 443.88, 'share_above_threshold': 0.4770658736595476}, 'below threshold': {'lengths': 803.13, 'share_below_threshold': 0.5113244431163074}, 'all noise': {'lengths': 460.4, 'share_above_threshold': 1}}, 2: {'no noise': {'lengths': 814.92, 'share_above_threshold': 0.2752171992342807}, 'above threshold': {'lengths': 438.84, 'share_above_threshold': 0.258249020144016}, 'below threshold': {'lengths': 809.63, 'share_below_threshold': 0.7205266603263218}, 'all noise': {'lengths': 427.4, 'share_above_threshold': 1}}, 3: {'no noise': {'lengths': 876.28, 'share_above_threshold': 0.2770689733875017}, 'above threshold': {'lengths': 379.26, 'share_above_threshold': 0.2593735168485999}, 'below threshold': {'lengths': 845.86, 'share_below_threshold': 0.7317995885843993}, 'all noise': {'lengths': 436.68, 'share_above_threshold': 1}}}
