In [146]:
from ThesisPackage.Environments.multi_pong_sender_receiver_ball_onehot import PongEnvSenderReceiverBallOneHot
from ThesisPackage.RL.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from ThesisPackage.RL.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
from ThesisPackage.Wrappers.vecWrapper import PettingZooVectorizationParallelWrapper
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy

In [147]:
def make_env(max_episode_steps = 1024, sequence_length = 1, vocab_size = 3):
    env = PongEnvSenderReceiverBallOneHot(width=20, height=20, vocab_size=vocab_size, sequence_length=sequence_length, max_episode_steps=max_episode_steps)
    # env = ParallelFrameStack(env, 4)
    return env

In [148]:
def load(path="models/checkpoints", sequence_length=1, vocab_size=3):
    env = make_env(sequence_length=sequence_length, vocab_size=vocab_size)
    models = {}
    for model in os.listdir(path):
        if "pong" in model:
            state_dict = torch.load(os.path.join(path, model))
            timestamp = model.split("_")[-1]
            timestamp = int(timestamp.split(".")[0])
            agent = PPO_Multi_Agent_Centralized(env, device="cpu")
            agent.agent.load_state_dict(state_dict)
            models[timestamp] = agent
    return models

In [149]:
def perturbation(inputs, model, vocab_size, sequence_length):
    
    # Extract environment inputs
    environment_inputs = inputs[:, :-1 * vocab_size * sequence_length]

    # Extract original logits
    inputs = torch.tensor(inputs, dtype=torch.float32)
    original_logits = model(inputs)
    original_logits = F.softmax(original_logits, dim=1).detach().numpy()
    original_logits = F.log_softmax(torch.tensor(original_logits), dim=1).detach()

    perturbation_logits = []
    for token in range(vocab_size):
        # One-hot encoded sequence of tokens
        utterances = np.array([token for _ in range(sequence_length)])
        utterances = np.eye(vocab_size)[utterances].flatten()
        utterances = np.expand_dims(utterances, axis=0)
        utterances = np.repeat(utterances, inputs.shape[0], axis=0)

        # Concatenate environment inputs with utterances
        perturbation_inputs = np.concatenate((environment_inputs, utterances), axis=1)
        perturbation_inputs = torch.tensor(perturbation_inputs, dtype=torch.float32)

        # Get logits for perturbed inputs
        current_logits = model(perturbation_inputs).detach().numpy()
        current_logits = F.softmax(torch.tensor(current_logits), dim=1).detach().numpy()

        perturbation_logits.append(current_logits)

    divergences = []
    for input_array in perturbation_logits:
        kl_divergences = []
        for i in range(len(input_array)):
            q = F.softmax(torch.tensor(input_array[i]), dim=0)
            kl_div = F.kl_div(original_logits, q, reduction='batchmean').item()
            kl_divergences.append(kl_div)

        divergences.append(kl_divergences)
    max_divergences = np.max(divergences, axis=0)
    return max_divergences

In [150]:
from tqdm import tqdm

def test_perturbation(env, agent, threshold=0.02, number_samples=30000):
    language_importances = []
    obs, info = env.reset()
    state = env.state()
    average_length = []
    tokens = []
    data = {"paddle_1": [], "paddle_2": [], "paddle_1_obs": [], "paddle_2_obs": [], "ball_1_pos": [], "ball_2_pos": []}

    with tqdm(total=number_samples) as pbar:
        timestep = 0
        while True:
            for cur_agent in env.agents:
                data[cur_agent].append(copy.deepcopy(env.paddles[cur_agent]))
                data[cur_agent + "_obs"].append(copy.deepcopy(obs[cur_agent]))
            for cur_ball in env.balls.keys():
                data[cur_ball + "_pos"].append(copy.deepcopy(env.balls[cur_ball]["position"]))
            obs = [obs]
            state = [state]
            obs = np.array(flatten_list(obs))
            state = np.array(flatten_list(state))
            
            # integrated_grads = smoothgrad(obs_track, agent.agent.actor, 0, sigma=1.0, steps=30)
            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            if np.any(language_perturbation > threshold):
                pbar.update(1)
                timestep += 1

            if timestep > number_samples:
                break

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            actions = actions[0]
            actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

            obs, _, truncations, terminations, infos = env.step(actions)
            state = env.state()

            if any([truncations[agent] or terminations[agent] for agent in env.agents]):
                average_length.append(env.timestep)
                obs, info = env.reset()
                state = env.state()
    return np.array(language_importances), data, average_length

In [151]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Step 2: Define the dataset and dataloader
class PositionDataset(Dataset):
    def __init__(self, data, labels, device):
        self.data = torch.tensor(data, dtype=torch.float32, device=device)
        self.labels = torch.tensor(labels, dtype=torch.long, device=device)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Step 3: Define the model architecture
class SimpleClassifier(nn.Module):
    def __init__(self, input_size):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 2)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [152]:
def train(num_epochs, dataloader, input_size, learning_rate, device):
    
    model = SimpleClassifier(input_size).to(device)

    # Step 4: Train the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.5)

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # Accumulate loss
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)
        
        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = correct_predictions / total_predictions

        # scheduler.step()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

    print("Training complete!")
    return epoch_accuracy, model

In [153]:
def test(model, dataloader, device):
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)
    
    accuracy = correct_predictions / total_predictions
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy

In [154]:
results = {}

In [155]:
def calculate_chance_level(labels):
    unique, counts = np.unique(labels, return_counts=True)
    index = np.argmax(counts)

    chance_level = counts[index] / len(labels)
    return chance_level

In [156]:
import math
from tqdm import tqdm


def generate_training_data(language_importances_indices, data, sequence_length=1, vocab_size=3):
    paddle_1_indices = np.where(language_importances_indices[1] == 0)
    paddle_1_indices = language_importances_indices[0][paddle_1_indices]
    
    inputs = []
    labels = []

    pos_indices = paddle_1_indices[3:-6]
    pos_2_indices = [index + 6 for index in pos_indices]

    targets = [np.array(data["ball_1_pos"]), np.array(data["ball_2_pos"])]

    distances = [abs(np.array(data["paddle_1"])[pos_2_indices] - target[pos_2_indices, 1]) for target in targets]

    labels = np.argmin(distances, axis=0)

    paddle_1_obs = np.array(data["paddle_1_obs"])
    paddle_1_obs = np.array([paddle_1_obs[index - 3:index + 3] for index in pos_indices])
    player_1_lang = paddle_1_obs[:, :, -1 * sequence_length * vocab_size:]
    shape = player_1_lang.shape
    new_shape = (shape[0], shape[1] * shape[2])
    player_1_lang = player_1_lang.reshape(new_shape)

    paddle_2_obs = np.array(data["paddle_2_obs"])
    paddle_2_obs = np.array([paddle_2_obs[index - 3:index + 3] for index in pos_indices])
    player_2_lang = paddle_2_obs[:, :, -1 * sequence_length * vocab_size:]
    shape = player_2_lang.shape
    new_shape = (shape[0], shape[1] * shape[2])
    player_2_lang = player_2_lang.reshape(new_shape)

    inputs = np.concatenate((player_1_lang, player_2_lang), axis=1)

    inputs = np.array(inputs)
    labels = np.array(labels)

    return inputs, labels

def generate_training_data_obs(language_importances_indices, data, sequence_length=1, vocab_size=3, length=20, height=20, noise=False):
    paddle_1_indices = np.where(language_importances_indices[1] == 0)
    paddle_1_indices = language_importances_indices[0][paddle_1_indices]

    labels = np.zeros_like(paddle_1_indices)
    labels = np.where(np.array(data["paddle_1"])[paddle_1_indices] > np.array(data["paddle_2"])[paddle_1_indices], 1, 0)

    inputs = np.array(data["paddle_1_obs"])[paddle_1_indices]

    if noise:
        random_utterances = np.random.randint(low=0, high=vocab_size, size=(inputs.shape[0], sequence_length))
        random_utterances = np.eye(vocab_size)[random_utterances]
        random_utterances = random_utterances.reshape((inputs.shape[0], sequence_length * vocab_size))
        inputs[:, -1 * sequence_length * vocab_size:] = random_utterances

    return inputs, labels

In [157]:
import os
directories = os.listdir("/Users/cowolff/Documents/GitHub/ma.pong_rl/Plotting/saliencies_live/Multi_Pong/")
directories = [model for model in directories if ".DS_Store" not in model]
models = {}
for directory in directories:
    sequence_length = int(directory.split("_")[-1])
    agents = load(f"/Users/cowolff/Documents/GitHub/ma.pong_rl/Plotting/saliencies_live/Multi_Pong/{directory}", sequence_length=sequence_length)
    agent_indizes = list(agents.keys())
    agent_indizes.sort()
    models[sequence_length] = agents[agent_indizes[-1]]

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [158]:
device = "mps"

In [131]:
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)

    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data(language_importances_larger, data, sequence_length=seq, vocab_size=3)
    dataset = PositionDataset(larger_inputs, larger_labels, device)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy, model = train(120, dataloader, input_size, 0.001, device)
    results[seq]["above threshold"] = accuracy

    test_importances, test_data, test_average_length = test_perturbation(env, agent, threshold=0.02, number_samples=10000)
    test_importances_larger = np.where(test_importances > 0.02)
    test_inputs, test_labels = generate_training_data(test_importances_larger, test_data, sequence_length=seq, vocab_size=3)
    test_dataset = PositionDataset(test_inputs, test_labels, device)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
    accuracy = test(model, test_dataloader, device)
    print(calculate_chance_level(test_labels))
    

Epoch 1/120, Loss: 0.6277, Accuracy: 0.6096
Epoch 2/120, Loss: 0.6246, Accuracy: 0.6104
Epoch 3/120, Loss: 0.6236, Accuracy: 0.6104
Epoch 4/120, Loss: 0.6239, Accuracy: 0.6107
Epoch 5/120, Loss: 0.6235, Accuracy: 0.6101
Epoch 6/120, Loss: 0.6238, Accuracy: 0.6115
Epoch 7/120, Loss: 0.6234, Accuracy: 0.6116
Epoch 8/120, Loss: 0.6230, Accuracy: 0.6108
Epoch 9/120, Loss: 0.6232, Accuracy: 0.6110
Epoch 10/120, Loss: 0.6231, Accuracy: 0.6090
Epoch 11/120, Loss: 0.6230, Accuracy: 0.6106
Epoch 12/120, Loss: 0.6233, Accuracy: 0.6095
Epoch 13/120, Loss: 0.6230, Accuracy: 0.6106
Epoch 14/120, Loss: 0.6232, Accuracy: 0.6106
Epoch 15/120, Loss: 0.6231, Accuracy: 0.6096
Epoch 16/120, Loss: 0.6234, Accuracy: 0.6074
Epoch 17/120, Loss: 0.6230, Accuracy: 0.6098
Epoch 18/120, Loss: 0.6230, Accuracy: 0.6087
Epoch 19/120, Loss: 0.6232, Accuracy: 0.6089
Epoch 20/120, Loss: 0.6231, Accuracy: 0.6099
Epoch 21/120, Loss: 0.6233, Accuracy: 0.6106


KeyboardInterrupt: 

In [None]:
print(results)

{1: {'above threshold': 0.7572159512557122}, 2: {'above threshold': 0.8025763319880967}, 3: {'above threshold': 0.905081998474447}}


In [159]:
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)
    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data_obs(language_importances_larger, data, sequence_length=seq, vocab_size=4)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy, model = train(120, dataloader, input_size, 0.001, "cpu")

    test_language_importances, test_data, test_average_length = test_perturbation(env, agent, number_samples=30000)
    test_language_importances_larger = np.where(test_language_importances > 0.02)
    test_larger_inputs, test_larger_labels = generate_training_data_obs(test_language_importances_larger, test_data, sequence_length=seq, vocab_size=4)
    test_dataset = PositionDataset(test_larger_inputs, test_larger_labels, "cpu")
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
    test_input_size = test_larger_inputs.shape[1]
    test_accuracy = test(model, test_dataloader, "cpu")
    results[seq]["test"] = test_accuracy

300001it [33:38, 148.59it/s]                             


Epoch 1/120, Loss: 0.3880, Accuracy: 0.8211
Epoch 2/120, Loss: 0.3186, Accuracy: 0.8601
Epoch 3/120, Loss: 0.3056, Accuracy: 0.8653
Epoch 4/120, Loss: 0.2986, Accuracy: 0.8693
Epoch 5/120, Loss: 0.2929, Accuracy: 0.8711
Epoch 6/120, Loss: 0.2904, Accuracy: 0.8721
Epoch 7/120, Loss: 0.2873, Accuracy: 0.8738
Epoch 8/120, Loss: 0.2853, Accuracy: 0.8745
Epoch 9/120, Loss: 0.2836, Accuracy: 0.8754
Epoch 10/120, Loss: 0.2823, Accuracy: 0.8760
Epoch 11/120, Loss: 0.2807, Accuracy: 0.8768
Epoch 12/120, Loss: 0.2801, Accuracy: 0.8774
Epoch 13/120, Loss: 0.2794, Accuracy: 0.8784
Epoch 14/120, Loss: 0.2783, Accuracy: 0.8781
Epoch 15/120, Loss: 0.2773, Accuracy: 0.8792
Epoch 16/120, Loss: 0.2768, Accuracy: 0.8786
Epoch 17/120, Loss: 0.2762, Accuracy: 0.8792
Epoch 18/120, Loss: 0.2754, Accuracy: 0.8799
Epoch 19/120, Loss: 0.2750, Accuracy: 0.8799
Epoch 20/120, Loss: 0.2752, Accuracy: 0.8795
Epoch 21/120, Loss: 0.2742, Accuracy: 0.8792
Epoch 22/120, Loss: 0.2736, Accuracy: 0.8804
Epoch 23/120, Loss:

30001it [04:02, 123.90it/s]                           


Accuracy: 0.8816


300001it [1:07:45, 73.80it/s]                             


Epoch 1/120, Loss: 0.2927, Accuracy: 0.8762
Epoch 2/120, Loss: 0.2220, Accuracy: 0.9111
Epoch 3/120, Loss: 0.2073, Accuracy: 0.9164
Epoch 4/120, Loss: 0.2001, Accuracy: 0.9198
Epoch 5/120, Loss: 0.1949, Accuracy: 0.9213
Epoch 6/120, Loss: 0.1914, Accuracy: 0.9233
Epoch 7/120, Loss: 0.1889, Accuracy: 0.9237
Epoch 8/120, Loss: 0.1867, Accuracy: 0.9247
Epoch 9/120, Loss: 0.1842, Accuracy: 0.9259
Epoch 10/120, Loss: 0.1824, Accuracy: 0.9269
Epoch 11/120, Loss: 0.1809, Accuracy: 0.9273
Epoch 12/120, Loss: 0.1797, Accuracy: 0.9279
Epoch 13/120, Loss: 0.1779, Accuracy: 0.9279
Epoch 14/120, Loss: 0.1771, Accuracy: 0.9285
Epoch 15/120, Loss: 0.1763, Accuracy: 0.9288
Epoch 16/120, Loss: 0.1758, Accuracy: 0.9290
Epoch 17/120, Loss: 0.1753, Accuracy: 0.9301
Epoch 18/120, Loss: 0.1733, Accuracy: 0.9301
Epoch 19/120, Loss: 0.1732, Accuracy: 0.9305
Epoch 20/120, Loss: 0.1724, Accuracy: 0.9311
Epoch 21/120, Loss: 0.1716, Accuracy: 0.9311
Epoch 22/120, Loss: 0.1711, Accuracy: 0.9308
Epoch 23/120, Loss:

30001it [06:03, 82.63it/s]                            


Accuracy: 0.9332


300001it [1:25:48, 58.27it/s]                              


Epoch 1/120, Loss: 0.3024, Accuracy: 0.8722
Epoch 2/120, Loss: 0.2284, Accuracy: 0.9115
Epoch 3/120, Loss: 0.2162, Accuracy: 0.9155
Epoch 4/120, Loss: 0.2112, Accuracy: 0.9176
Epoch 5/120, Loss: 0.2056, Accuracy: 0.9200
Epoch 6/120, Loss: 0.2033, Accuracy: 0.9212
Epoch 7/120, Loss: 0.1999, Accuracy: 0.9225
Epoch 8/120, Loss: 0.1978, Accuracy: 0.9227
Epoch 9/120, Loss: 0.1966, Accuracy: 0.9232
Epoch 10/120, Loss: 0.1937, Accuracy: 0.9241
Epoch 11/120, Loss: 0.1922, Accuracy: 0.9252
Epoch 12/120, Loss: 0.1916, Accuracy: 0.9258
Epoch 13/120, Loss: 0.1904, Accuracy: 0.9257
Epoch 14/120, Loss: 0.1896, Accuracy: 0.9263
Epoch 15/120, Loss: 0.1880, Accuracy: 0.9268
Epoch 16/120, Loss: 0.1873, Accuracy: 0.9267
Epoch 17/120, Loss: 0.1865, Accuracy: 0.9273
Epoch 18/120, Loss: 0.1862, Accuracy: 0.9276
Epoch 19/120, Loss: 0.1850, Accuracy: 0.9280
Epoch 20/120, Loss: 0.1842, Accuracy: 0.9283
Epoch 21/120, Loss: 0.1841, Accuracy: 0.9285
Epoch 22/120, Loss: 0.1827, Accuracy: 0.9286
Epoch 23/120, Loss:

30001it [09:00, 55.50it/s]                            


Accuracy: 0.9308


In [160]:
print(results)

{1: {'test': 0.8816442239546421}, 2: {'test': 0.9331573604060913}, 3: {'test': 0.9308098340387084}}
