In [12]:
from ThesisPackage.Environments.pong.multi_pong_language_continuous import PongEnv
from ThesisPackage.RL.Decentralized_PPO.multi_ppo import PPO_Multi_Agent
from ThesisPackage.RL.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
from ThesisPackage.Wrappers.vecWrapper import PettingZooVectorizationParallelWrapper
import torch
import numpy as np
import os
import time

In [13]:
def make_env():
    sequence_length = 2
    vocab_size = 3
    max_episode_steps = 512
    env = PongEnv(width=20, height=20, vocab_size=vocab_size, sequence_length=sequence_length, max_episode_steps=max_episode_steps)
    # env = ParallelFrameStack(env, 4)
    return env

In [14]:
def load():
    models = {}
    path = "/home/cowolff/Documents/GitHub/ma.pong_rl/models/checkpoints/"
    model_paths = os.listdir(path)
    env = make_env()
    for model in model_paths:
        timestep = model.split("_")[-1].split(".")[0]
        agent = PPO_Multi_Agent(env, device="cpu")
        state_dict = torch.load(path + model)
        agent.agent.load_state_dict(state_dict)
        models[timestep] = agent
    return models

In [15]:
num_steps = 10000
agents = load()
env = make_env()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [16]:
def integrated_gradients(inputs, model, target_label_idx, baseline=None, steps=100):
    if baseline is None:
        baseline = torch.zeros_like(inputs)
    assert baseline.shape == inputs.shape
    
    grads = []

    for i in range(num_steps):
        scaled_input = torch.tensor(baseline + (float(i) / steps) * (inputs - baseline), requires_grad=True)
        # scaled_input.requires_grad = True
        logits = model(scaled_input)
        loss = logits[0, target_label_idx]
        loss.backward()
        grads.append(scaled_input.grad.data.cpu().numpy())

    avg_grads = np.average(grads[:-1], axis=0)
    integrated_grad = (inputs.detach().cpu().numpy() - baseline.cpu().numpy()) * avg_grads

    return integrated_grad

In [27]:
def get_means(env, model, epochs):
    means = []
    tokens = {channel: {0: 0, 1: 1, 2: 2} for channel in [1, 2]}
    for i in range(epochs):
        obs, info = env.reset()
        while True:
            obs = [obs]
            obs = np.array(flatten_list(obs))
            means.append(obs)
            obs = torch.tensor(obs, dtype=torch.float32, requires_grad=True)
            with torch.no_grad():
                action, _, _, _ = model.agent.get_action_and_value(obs)
                action = reverse_flatten_list_with_agent_list(action, model.agents)
            obs, rewards, terminations, truncations, info = env.step(action[0])
            for channel in [1, 2]:
                for paddle in env.agents:
                    cur_token = int(obs[paddle][-1 * channel])
                    tokens[channel][cur_token] += 1
            if any([truncations[agent] or terminations[agent] for agent in env.agents]):
                break
    means = np.stack(means)
    return means, tokens

first_key = list(agents.keys())[0]
means, tokens = get_means(env, agents[first_key], 100)
means = means.reshape(-1, means.shape[-1])
means = np.mean(means, axis=0)
print(means, tokens)

[ 9.00000000e+00 -1.58129956e+00 -1.85712326e+00 -3.78558340e-01
  1.88339501e-02 -8.97281367e-05 -1.44373852e+00 -6.61838009e-01
  1.59951383e-02 -1.60171943e-03  9.81933770e-01  1.00353330e+00] {1: {0: 27320, 1: 22634, 2: 27597}, 2: {0: 26495, 1: 25959, 2: 25097}}


In [28]:
def test_integrated_gradients(agent, means, timesteps=4096, tracking_agent="paddle_1"):
    saliencies = []
    full_saliences = []
    obs, info = env.reset()
    average_length = []
    average_noise_share = []
    tokens = []
    for i in range(timesteps):
        timestep = 0
    
        timestep += 1
        traking_index = env.agents.index(tracking_agent)
        tokens.append(obs[tracking_agent][-1 * env.sequence_length:])
        obs = [obs]
        obs = np.array(flatten_list(obs))

        obs_track = torch.tensor(np.expand_dims(obs[traking_index], axis=0), dtype=torch.float32, requires_grad=True)
        
        baselines = torch.tensor(np.expand_dims(means, axis=0), dtype=torch.float32)
        baselines[0] = 9.0
        integrated_grads = integrated_gradients(obs_track, agent.agent.actor, 0, baseline=baselines, steps=20)

        integrated_grads = (integrated_grads - integrated_grads.min()) / (integrated_grads.max() - integrated_grads.min())
        language_saliences = np.sum(integrated_grads[0, -1 * env.sequence_length:])

        obs = torch.tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            actions, _, _, _ = agent.agent.get_action_and_value(obs)
            actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

        actions = actions[0]
        actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

        saliencies.append(language_saliences)
        full_saliences.append(integrated_grads)

        obs, _, truncations, terminations, infos = env.step(actions)

        if any([truncations[agent] or terminations[agent] for agent in env.agents]):
            average_length.append(timestep)
            obs, info = env.reset()
            break
        
    full_saliences = np.stack(full_saliences, axis=0)
    return saliencies, average_length, full_saliences, np.mean(average_noise_share), np.array(tokens)

In [29]:
def test_saliency(agent, timesteps=4096, tracking_agent="paddle_1"):
    saliencies = []
    full_saliences = []
    obs, info = env.reset()
    average_length = []
    average_noise_share = []
    tokens = []
    noises = []
    for i in range(timesteps):
        timestep = 0

        timestep += 1
        traking_index = env.agents.index(tracking_agent)
        tokens.append(obs[tracking_agent][-1 * env.sequence_length:])
        obs = [obs]
        obs = np.array(flatten_list(obs))

        obs_track = torch.tensor(np.expand_dims(obs[traking_index], axis=0), dtype=torch.float32, requires_grad=True)
        logits = agent.agent.actor(obs_track)

        grad_tensor = torch.zeros_like(logits)
        grad_tensor[:, :-6] = 1

        logits.backward(grad_tensor)

        saliency = obs_track.grad.data.abs()

        saliency = saliency.numpy()
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min())

        average_language = np.sum(saliency[:, -1 * env.sequence_length:])

        obs = torch.tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            actions, _, _, _ = agent.agent.get_action_and_value(obs)
            actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

        actions = actions[0]
        actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

        saliencies.append(average_language)
        full_saliences.append(saliency)

        obs, _, truncations, terminations, infos = env.step(actions)

        if any([truncations[agent] or terminations[agent] for agent in env.agents]):
            average_length.append(timestep)
            noises = []
            average_noise_share.append(np.mean(noises))
            obs, info = env.reset()
            break
    full_saliences = np.stack(full_saliences, axis=0)
    return saliencies, average_length, full_saliences, np.mean(average_noise_share), np.array(tokens)

In [30]:
import copy

means_saliencies = {}
for agent_name in list(agents.keys()):
    saliencies, average_length, full_saliences, _, _ = test_integrated_gradients(agents[agent_name], copy.deepcopy(means), timesteps=32768)
    means_saliencies[agent_name] = np.mean(saliencies)

print(means)

  scaled_input = torch.tensor(baseline + (float(i) / steps) * (inputs - baseline), requires_grad=True)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


[ 9.00000000e+00 -1.58129956e+00 -1.85712326e+00 -3.78558340e-01
  1.88339501e-02 -8.97281367e-05 -1.44373852e+00 -6.61838009e-01
  1.59951383e-02 -1.60171943e-03  9.81933770e-01  1.00353330e+00]


In [31]:
print(means_saliencies)

{'3048': 1.3496369, '1524': 1.27434, '2540': 1.2016377, '3683': 1.3716365, '3175': 1.3308553, '1778': 1.2708437, '1016': 1.0753932, '762': 1.240871, '3429': 1.310518, '2286': 1.268866, '1270': 1.41319, '2667': 1.3216597, '2032': 1.3135335, '3302': 1.305145, '635': 0.91906345, '2159': 1.2909157, '1397': 1.2998027, '381': 1.1033112, '508': 1.1053104, '127': 1.0444089, '2921': 1.3981344, '1905': 1.3251836, '254': 1.0210496, '1651': 1.2641459, '1143': 1.2825036, '3810': 1.3267655, '3556': 1.2999666, '2794': 1.3242552, '889': 1.1656408, '2413': 1.2779877}


In [19]:
saliencies, average_length, full_saliences, _, tokens = test(epochs=1)