In [54]:
from Continuous_Language.Environments.Collectors.collectors import Collectors
from Continuous_Language.Reinforcement_Learning.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from Continuous_Language.Reinforcement_Learning.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [55]:
def make_env(sequence_length=0):
    vocab_size = 4
    max_episode_steps = 2048
    env = Collectors(width=20, height=20, vocab_size=vocab_size, sequence_length=sequence_length, max_timesteps=max_episode_steps, timestep_countdown=15)
    # env = ParallelFrameStack(env, 4)
    return env

In [56]:
def load(path="models/checkpoints", vocab_size=3):
    models = {}
    for model in os.listdir(path):
        if "pt" in model:
            sequence_length = model.split("_")[-1]
            sequence_length = int(sequence_length.split(".")[0])
            if sequence_length > 0:
                env = make_env(sequence_length=sequence_length)
                state_dict = torch.load(os.path.join(path, model))
                try:
                    agent = PPO_Multi_Agent_Centralized(env, device="cpu")
                    agent.agent.load_state_dict(state_dict)
                    models[sequence_length] = agent
                except:
                    continue
    return models

In [57]:
def perturbation(inputs, model, vocab_size, sequence_length):
    
    # Extract environment inputs
    environment_inputs = inputs[:, :-1 * vocab_size * sequence_length]

    # Extract original logits
    inputs = torch.tensor(inputs, dtype=torch.float32)
    original_logits = model(inputs)
    original_logits = F.softmax(original_logits, dim=1).detach().numpy()
    original_logits = F.log_softmax(torch.tensor(original_logits), dim=1).detach()

    perturbation_logits = []
    for token in range(vocab_size):
        # One-hot encoded sequence of tokens
        utterances = np.array([token for _ in range(sequence_length)])
        utterances = np.eye(vocab_size)[utterances].flatten()
        utterances = np.expand_dims(utterances, axis=0)
        utterances = np.repeat(utterances, inputs.shape[0], axis=0)

        # Concatenate environment inputs with utterances
        perturbation_inputs = np.concatenate((environment_inputs, utterances), axis=1)
        perturbation_inputs = torch.tensor(perturbation_inputs, dtype=torch.float32)

        # Get logits for perturbed inputs
        current_logits = model(perturbation_inputs).detach().numpy()
        current_logits = F.softmax(torch.tensor(current_logits), dim=1).detach().numpy()

        perturbation_logits.append(current_logits)

    divergences = []
    for input_array in perturbation_logits:
        kl_divergences = []
        for i in range(len(input_array)):
            q = F.softmax(torch.tensor(input_array[i]), dim=0)
            kl_div = F.kl_div(original_logits, q, reduction='batchmean').item()
            kl_divergences.append(kl_div)

        divergences.append(kl_divergences)
    max_divergences = np.max(divergences, axis=0)
    return max_divergences

In [58]:
def test_perturbation(env, agent, epochs=1, tracking_agent="player_1", threshold=0.002, use_perturbation=False, above_threshold=False):
    language_importances = []
    obs, info = env.reset()
    state = env.state()
    average_length = []
    tokens = []

    for i in range(epochs):
        timestep = 0
        while True:
            timestep += 1
            tokens.append(obs[tracking_agent][-1 * env.sequence_length:])
            obs = [obs]
            state = [state]
            obs = np.array(flatten_list(obs))
            state = np.array(flatten_list(state))
            
            # integrated_grads = smoothgrad(obs_track, agent.agent.actor, 0, sigma=1.0, steps=30)
            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            # If any of language_importances is higher thnan 0.002
            if any([importance >= threshold for importance in language_perturbation]) and use_perturbation and above_threshold:
                language_observations = obs[:, -1 * env.vocab_size * env.sequence_length:]
                random_language = np.random.randint(0, env.vocab_size, (language_observations.shape[0], env.sequence_length))
                random_language = np.eye(env.vocab_size)[random_language]
                random_language = random_language.reshape(language_observations.shape[0], env.sequence_length * env.vocab_size)
                obs[:, -1 * env.vocab_size * env.sequence_length:] = random_language

            if all([importance <= threshold for importance in language_perturbation]) and use_perturbation and not above_threshold:
                language_observations = obs[:, -1 * env.vocab_size * env.sequence_length:]
                random_language = np.random.randint(0, env.vocab_size, (language_observations.shape[0], env.sequence_length))
                random_language = np.eye(env.vocab_size)[random_language]
                random_language = random_language.reshape(language_observations.shape[0], env.sequence_length * env.vocab_size)
                obs[:, -1 * env.vocab_size * env.sequence_length:] = random_language

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            actions = actions[0]
            actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

            obs, _, truncations, terminations, infos = env.step(actions)
            state = env.state()

            if any([truncations[agent] or terminations[agent] for agent in env.agents]):
                average_length.append(timestep)
                obs, info = env.reset()
                state = env.state()
                break
    return np.array(language_importances), average_length

In [59]:
models = load("models/checkpoints_collectors_2/models")
print(models)

{4: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x365254e80>, 1: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x36534e110>, 2: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x36534ee90>, 3: <ThesisPackage.RL.Centralized_PPO.multi_ppo.PPO_Multi_Agent_Centralized object at 0x3654bde10>}


In [60]:
def share_above_threshold(arr, threshold=0.15):
    # Check if each element is greater than the threshold
    condition = arr > threshold
    # Count rows where at least one element satisfies the condition
    count = np.any(condition, axis=1).sum()
    # Calculate the share
    share = count / arr.shape[0]
    return share

In [61]:
num_epochs = 1000
threshold = 0.002

In [62]:
results = {}
for sequence_length, model in models.items():
    print(f"Testing sequence length {sequence_length}")
    results[sequence_length] = {}
    env = make_env(sequence_length=sequence_length)
    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=False)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["no noise"] = {"lengths": np.mean(lengths), "share_above_threshold": noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=True, above_threshold=True)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["above threshold"] = {"lengths": np.mean(lengths), "share_above_threshold": noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=threshold, use_perturbation=True, above_threshold=False)
    noise_share = share_above_threshold(importances, threshold=threshold)
    results[sequence_length]["below threshold"] = {"lengths": np.mean(lengths), "share_below_threshold": 1 - noise_share}

    importances, lengths = test_perturbation(env, model, epochs=num_epochs, threshold=0.00, use_perturbation=True, above_threshold=True)
    results[sequence_length]["all noise"] = {"lengths": np.mean(lengths), "share_above_threshold": 1}

Testing sequence length 4
Testing sequence length 1
Testing sequence length 2
Testing sequence length 3


In [65]:
print(results)

{4: {'no noise': {'lengths': 99.361, 'share_above_threshold': 0.7516128058292489}, 'above threshold': {'lengths': 83.439, 'share_above_threshold': 0.718536895216865}, 'below threshold': {'lengths': 84.903, 'share_below_threshold': 0.25743495518415127}, 'all noise': {'lengths': 76.167, 'share_above_threshold': 1}}, 1: {'no noise': {'lengths': 231.083, 'share_above_threshold': 0.9836941704928532}, 'above threshold': {'lengths': 47.004, 'share_above_threshold': 0.9106033529061357}, 'below threshold': {'lengths': 256.397, 'share_below_threshold': 0.016135134186437416}, 'all noise': {'lengths': 47.001, 'share_above_threshold': 1}}, 2: {'no noise': {'lengths': 186.954, 'share_above_threshold': 0.9545610150090397}, 'above threshold': {'lengths': 55.013, 'share_above_threshold': 0.8570156144911203}, 'below threshold': {'lengths': 180.497, 'share_below_threshold': 0.04769608359141708}, 'all noise': {'lengths': 52.779, 'share_above_threshold': 1}}, 3: {'no noise': {'lengths': 144.749, 'share_abo