In [204]:
from ThesisPackage.Environments.collectors.collectors_env_discrete_onehot import Collectors
from ThesisPackage.RL.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from ThesisPackage.RL.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
from ThesisPackage.Wrappers.vecWrapper import PettingZooVectorizationParallelWrapper
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy

In [205]:
def make_env(sequence_length=0):
    vocab_size = 4
    max_episode_steps = 2048
    env = Collectors(width=20, height=20, vocab_size=vocab_size, sequence_length=sequence_length, max_timesteps=max_episode_steps, timestep_countdown=15)
    # env = ParallelFrameStack(env, 4)
    return env

In [206]:
def load(path="models/checkpoints", vocab_size=3):
    models = {}
    for model in os.listdir(path):
        if "pt" in model:
            sequence_length = model.split("_")[-1]
            sequence_length = int(sequence_length.split(".")[0])
            if sequence_length > 0:
                env = make_env(sequence_length=sequence_length)
                state_dict = torch.load(os.path.join(path, model))
                try:
                    agent = PPO_Multi_Agent_Centralized(env, device="cpu")
                    agent.agent.load_state_dict(state_dict)
                    models[sequence_length] = agent
                except:
                    continue
    return models

In [207]:
def perturbation(inputs, model, vocab_size, sequence_length):
    
    # Extract environment inputs
    environment_inputs = inputs[:, :-1 * vocab_size * sequence_length]

    # Extract original logits
    inputs = torch.tensor(inputs, dtype=torch.float32)
    original_logits = model(inputs)
    original_logits = F.softmax(original_logits, dim=1).detach().numpy()
    original_logits = F.log_softmax(torch.tensor(original_logits), dim=1).detach()

    perturbation_logits = []
    for token in range(vocab_size):
        # One-hot encoded sequence of tokens
        utterances = np.array([token for _ in range(sequence_length)])
        utterances = np.eye(vocab_size)[utterances].flatten()
        utterances = np.expand_dims(utterances, axis=0)
        utterances = np.repeat(utterances, inputs.shape[0], axis=0)

        # Concatenate environment inputs with utterances
        perturbation_inputs = np.concatenate((environment_inputs, utterances), axis=1)
        perturbation_inputs = torch.tensor(perturbation_inputs, dtype=torch.float32)

        # Get logits for perturbed inputs
        current_logits = model(perturbation_inputs).detach().numpy()
        current_logits = F.softmax(torch.tensor(current_logits), dim=1).detach().numpy()

        perturbation_logits.append(current_logits)

    divergences = []
    for input_array in perturbation_logits:
        kl_divergences = []
        for i in range(len(input_array)):
            q = F.softmax(torch.tensor(input_array[i]), dim=0)
            kl_div = F.kl_div(original_logits, q, reduction='batchmean').item()
            kl_divergences.append(kl_div)

        divergences.append(kl_divergences)
    max_divergences = np.max(divergences, axis=0)
    return max_divergences

In [208]:
from tqdm import tqdm

def test_perturbation(env, agent, threshold=0.02, number_samples=30000):
    language_importances = []
    obs, info = env.reset()
    state = env.state()
    average_length = []
    tokens = []
    data = {"player_1": [], "player_2": [], "player_1_obs": [], "player_2_obs":[], "player_1_direction": [], "player_2_direction": [], "target_1": [], "target_2": [], "target_3":[]}

    with tqdm(total=number_samples) as pbar:
        timestep = 0
        while True:
            for cur_agent in env.agents:
                data[cur_agent].append(copy.deepcopy(env.players[cur_agent].position))
                data[cur_agent + "_direction"].append(copy.deepcopy(env.players[cur_agent].direction))
                data[cur_agent + "_obs"].append(copy.deepcopy(obs[cur_agent]))
            for i in range(env.num_targets):
                if len(env.targets) > i:
                    target = env.targets[i].position
                else:
                    target = np.array([-1, -1])
                data["target_" + str(i+1)].append(copy.deepcopy(target))
            obs = [obs]
            state = [state]
            obs = np.array(flatten_list(obs))
            state = np.array(flatten_list(state))
            
            # integrated_grads = smoothgrad(obs_track, agent.agent.actor, 0, sigma=1.0, steps=30)
            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            if np.any(language_perturbation > threshold):
                pbar.update(1)
                timestep += 1

            if timestep > number_samples:
                break

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            actions = actions[0]
            actions = {agent: action.cpu().numpy() for agent, action in actions.items()}

            obs, _, truncations, terminations, infos = env.step(actions)
            state = env.state()

            if any([truncations[agent] or terminations[agent] for agent in env.agents]):
                average_length.append(env.timestep)
                obs, info = env.reset()
                state = env.state()
    return np.array(language_importances), data, average_length

In [209]:
import numpy as np
import math

def generate_training_data(language_importances_indices, data, sequence_length=1, vocab_size=3, length=20, height=20):
    paddle_1_indices = np.where(language_importances_indices[1] == 0)
    paddle_1_indices = language_importances_indices[0][paddle_1_indices]
    paddle_1_indices = np.array(paddle_1_indices)[3:-3]
    
    inputs = []
    labels = []
        
    player_1_data = np.array(data["player_1"])[paddle_1_indices]
        
    for index in paddle_1_indices:
        if data["player_1"][index][0] >= length / 2 and data["player_1"][index][1] < height / 2:
            labels.append(1)
        elif data["player_1"][index][0] < length / 2 and data["player_1"][index][1] >= height / 2:
            labels.append(2)
        elif data["player_1"][index][0] >= length / 2 and data["player_1"][index][1] >= height / 2:
            labels.append(3)
        else:
            labels.append(0)

    targets = np.zeros(len(paddle_1_indices))
    targets = np.where((player_1_data[:, 0] >= length / 2) & (player_1_data[:, 1] < height / 2), 1, targets)
    targets = np.where((player_1_data[:, 0] < length / 2) & (player_1_data[:, 1] >= height / 2), 2, targets)
    targets = np.where((player_1_data[:, 0] >= length / 2) & (player_1_data[:, 1] >= height / 2), 3, targets)
    
    player_1_obs = np.array(data["player_1_obs"])
    player_1_obs = np.array([player_1_obs[index - 3:index + 3] for index in paddle_1_indices])
    player_1_lang = player_1_obs[:, :, -1 * sequence_length * vocab_size:]
    shape = player_1_lang.shape
    new_shape = (shape[0], shape[1] * shape[2])
    player_1_lang = player_1_lang.reshape(new_shape)

    player_2_obs = np.array(data["player_2_obs"])
    player_2_obs = np.array([player_2_obs[index - 3:index + 3] for index in paddle_1_indices])
    player_2_lang = player_2_obs[:, :, -1 * sequence_length * vocab_size:]
    shape = player_2_lang.shape
    new_shape = (shape[0], shape[1] * shape[2])
    player_2_lang = player_2_lang.reshape(new_shape)

    inputs = np.concatenate((player_1_lang, player_2_lang), axis=1)

    inputs = np.array(inputs)
    labels = np.array(labels)

    return inputs, labels

def generate_training_data_obs(language_importances_indices, data, sequence_length=1, vocab_size=3, length=20, height=20, noise=False):
    paddle_1_indices = np.where(language_importances_indices[1] == 0)
    paddle_1_indices = language_importances_indices[0][paddle_1_indices]
    paddle_1_indices = np.array(paddle_1_indices)[3:-3]
    
    inputs = []
    labels = []
        
    player_1_data = np.array(data["player_1"])[paddle_1_indices]
        
    for index in paddle_1_indices:
        if data["player_2"][index][0] >= length / 2 and data["player_2"][index][1] < height / 2:
            labels.append(1)
        elif data["player_2"][index][0] < length / 2 and data["player_2"][index][1] >= height / 2:
            labels.append(2)
        elif data["player_2"][index][0] >= length / 2 and data["player_2"][index][1] >= height / 2:
            labels.append(3)
        else:
            labels.append(0)
    
    player_1_obs = np.array(data["player_1_obs"])
    inputs = np.array([player_1_obs[index] for index in paddle_1_indices])

    if noise:
        random_utterances = np.random.randint(low=0, high=vocab_size, size=(inputs.shape[0], sequence_length))
        random_utterances = np.eye(vocab_size)[random_utterances]
        random_utterances = random_utterances.reshape((inputs.shape[0], sequence_length * vocab_size))
        inputs[:, -1 * sequence_length * vocab_size:] = random_utterances

    inputs = np.array(inputs)
    labels = np.array(labels)

    return inputs, labels

In [210]:
models = load("/Users/cowolff/Documents/GitHub/ma.pong_rl/models/checkpoints_collectors_2/models")

In [211]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Step 2: Define the dataset and dataloader
class PositionDataset(Dataset):
    def __init__(self, data, labels, device):
        self.data = torch.tensor(data, dtype=torch.float32, device=device)
        self.labels = torch.tensor(labels, dtype=torch.long, device=device)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Step 3: Define the model architecture
class SimpleClassifier(nn.Module):
    def __init__(self, input_size):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 4)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [212]:
def train(num_epochs, dataloader, input_size, learning_rate, device):
    model = SimpleClassifier(input_size).to(device)

    # Step 4: Train the model
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.5)

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_predictions = [0, 0, 0]
        total_predictions = [0, 0, 0]
        
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # Accumulate loss
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            for i in range(3):  # Assuming 3 classes
                correct_predictions[i] += ((predicted == i) & (targets == i)).sum().item()
                total_predictions[i] += (targets == i).sum().item()
        
        epoch_loss = running_loss / len(dataloader)
        epoch_accuracy = sum(correct_predictions) / sum(total_predictions)
        
        # Label-specific accuracy
        label_accuracies = [correct_predictions[i] / total_predictions[i] if total_predictions[i] > 0 else 0 for i in range(3)]
        
        scheduler.step()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Overall Accuracy: {epoch_accuracy:.4f}")
        for i in range(3):
            print(f"Label {i} Accuracy: {label_accuracies[i]:.4f}")

    print("Training complete!")
    return epoch_accuracy, model

In [213]:
def test(model, dataloader, device):
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)
    
    accuracy = correct_predictions / total_predictions
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy

In [214]:
def calculate_chance_level(labels):
    unique, counts = np.unique(labels, return_counts=True)
    index = np.argmax(counts)

    chance_level = counts[index] / len(labels)
    return chance_level

In [190]:
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    language_importances, data, average_length = test_perturbation(env, agent, number_samples=500000)

    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data(language_importances_larger, data, sequence_length=seq, vocab_size=4)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy = train(120, dataloader, input_size, 0.001, "cpu")
    results[seq]["above threshold"] = accuracy

  0%|          | 552/500000 [00:04<1:02:09, 133.91it/s]


KeyboardInterrupt: 

In [215]:
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)
    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data_obs(language_importances_larger, data, sequence_length=seq, vocab_size=4)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy, model = train(120, dataloader, input_size, 0.001, "cpu")

    language_importances_test, data_test, average_length_test = test_perturbation(env, agent, number_samples=30000)
    language_importances_larger_test = np.where(language_importances_test > 0.02)
    larger_inputs_test, larger_labels_test = generate_training_data_obs(language_importances_larger_test, data_test, sequence_length=seq, vocab_size=4)
    dataset_test = PositionDataset(larger_inputs_test, larger_labels_test, "cpu")
    dataloader_test = DataLoader(dataset_test, batch_size=32, shuffle=True)
    input_size = larger_inputs_test.shape[1]
    accuracy_test = test(model, dataloader_test, "cpu")

    language_importances_test_noise, data_test_noise, average_length_test_noise = test_perturbation(env, agent, number_samples=30000)
    language_importances_larger_test_noise = np.where(language_importances_test_noise > 0.02)
    larger_inputs_test_noise, larger_labels_test_noise = generate_training_data_obs(language_importances_larger_test_noise, data_test_noise, sequence_length=seq, vocab_size=4, noise=True)
    dataset_test_noise = PositionDataset(larger_inputs_test_noise, larger_labels_test_noise, "cpu")
    dataloader_test_noise = DataLoader(dataset_test_noise, batch_size=32, shuffle=True)
    input_size = larger_inputs_test_noise.shape[1]
    accuracy_test_noise = test(model, dataloader_test_noise, "cpu")

300001it [35:52, 139.40it/s]                            


Epoch 1/120, Loss: 1.1294, Overall Accuracy: 0.4921
Label 0 Accuracy: 0.5717
Label 1 Accuracy: 0.4653
Label 2 Accuracy: 0.4251
Epoch 2/120, Loss: 1.0970, Overall Accuracy: 0.5116
Label 0 Accuracy: 0.5946
Label 1 Accuracy: 0.4854
Label 2 Accuracy: 0.4399
Epoch 3/120, Loss: 1.0849, Overall Accuracy: 0.5174
Label 0 Accuracy: 0.5990
Label 1 Accuracy: 0.4916
Label 2 Accuracy: 0.4469
Epoch 4/120, Loss: 1.0766, Overall Accuracy: 0.5229
Label 0 Accuracy: 0.6037
Label 1 Accuracy: 0.4977
Label 2 Accuracy: 0.4529
Epoch 5/120, Loss: 1.0709, Overall Accuracy: 0.5250
Label 0 Accuracy: 0.6047
Label 1 Accuracy: 0.5009
Label 2 Accuracy: 0.4551
Epoch 6/120, Loss: 1.0675, Overall Accuracy: 0.5272
Label 0 Accuracy: 0.6018
Label 1 Accuracy: 0.5065
Label 2 Accuracy: 0.4596
Epoch 7/120, Loss: 1.0641, Overall Accuracy: 0.5277
Label 0 Accuracy: 0.6025
Label 1 Accuracy: 0.5086
Label 2 Accuracy: 0.4580
Epoch 8/120, Loss: 1.0615, Overall Accuracy: 0.5295
Label 0 Accuracy: 0.6054
Label 1 Accuracy: 0.5086
Label 2 A

30001it [03:42, 134.66it/s]                           


Accuracy: 0.5212


30001it [03:44, 133.52it/s]                           


Accuracy: 0.5007


300001it [04:51, 1029.19it/s]                            


Epoch 1/120, Loss: 0.9340, Overall Accuracy: 0.5609
Label 0 Accuracy: 0.5125
Label 1 Accuracy: 0.6056
Label 2 Accuracy: 0.5593
Epoch 2/120, Loss: 0.9009, Overall Accuracy: 0.5674
Label 0 Accuracy: 0.5392
Label 1 Accuracy: 0.6036
Label 2 Accuracy: 0.5554
Epoch 3/120, Loss: 0.8942, Overall Accuracy: 0.5702
Label 0 Accuracy: 0.5469
Label 1 Accuracy: 0.6101
Label 2 Accuracy: 0.5495
Epoch 4/120, Loss: 0.8903, Overall Accuracy: 0.5730
Label 0 Accuracy: 0.5523
Label 1 Accuracy: 0.6138
Label 2 Accuracy: 0.5487
Epoch 5/120, Loss: 0.8874, Overall Accuracy: 0.5750
Label 0 Accuracy: 0.5543
Label 1 Accuracy: 0.6145
Label 2 Accuracy: 0.5521
Epoch 6/120, Loss: 0.8849, Overall Accuracy: 0.5789
Label 0 Accuracy: 0.5538
Label 1 Accuracy: 0.6250
Label 2 Accuracy: 0.5532
Epoch 7/120, Loss: 0.8831, Overall Accuracy: 0.5792
Label 0 Accuracy: 0.5525
Label 1 Accuracy: 0.6229
Label 2 Accuracy: 0.5575
Epoch 8/120, Loss: 0.8815, Overall Accuracy: 0.5816
Label 0 Accuracy: 0.5499
Label 1 Accuracy: 0.6285
Label 2 A

30001it [00:52, 569.28it/s]                           


Accuracy: 0.5606


30001it [00:51, 577.58it/s]                           


Accuracy: 0.5027


300001it [14:30, 344.70it/s]                            


Epoch 1/120, Loss: 0.8559, Overall Accuracy: 0.5993
Label 0 Accuracy: 0.6209
Label 1 Accuracy: 0.5986
Label 2 Accuracy: 0.5769
Epoch 2/120, Loss: 0.8239, Overall Accuracy: 0.6102
Label 0 Accuracy: 0.6277
Label 1 Accuracy: 0.6164
Label 2 Accuracy: 0.5858
Epoch 3/120, Loss: 0.8169, Overall Accuracy: 0.6125
Label 0 Accuracy: 0.6320
Label 1 Accuracy: 0.6158
Label 2 Accuracy: 0.5887
Epoch 4/120, Loss: 0.8126, Overall Accuracy: 0.6138
Label 0 Accuracy: 0.6381
Label 1 Accuracy: 0.6192
Label 2 Accuracy: 0.5830
Epoch 5/120, Loss: 0.8098, Overall Accuracy: 0.6152
Label 0 Accuracy: 0.6376
Label 1 Accuracy: 0.6206
Label 2 Accuracy: 0.5862
Epoch 6/120, Loss: 0.8070, Overall Accuracy: 0.6156
Label 0 Accuracy: 0.6402
Label 1 Accuracy: 0.6215
Label 2 Accuracy: 0.5840
Epoch 7/120, Loss: 0.8051, Overall Accuracy: 0.6179
Label 0 Accuracy: 0.6444
Label 1 Accuracy: 0.6224
Label 2 Accuracy: 0.5856
Epoch 8/120, Loss: 0.8033, Overall Accuracy: 0.6186
Label 0 Accuracy: 0.6422
Label 1 Accuracy: 0.6225
Label 2 A

30001it [01:27, 340.99it/s]                           


Accuracy: 0.6153


30001it [01:28, 338.96it/s]                           


Accuracy: 0.5608


300001it [22:09, 225.73it/s]                            


Epoch 1/120, Loss: 0.9361, Overall Accuracy: 0.5597
Label 0 Accuracy: 0.7147
Label 1 Accuracy: 0.5400
Label 2 Accuracy: 0.3759
Epoch 2/120, Loss: 0.9012, Overall Accuracy: 0.5720
Label 0 Accuracy: 0.7250
Label 1 Accuracy: 0.5570
Label 2 Accuracy: 0.3855
Epoch 3/120, Loss: 0.8928, Overall Accuracy: 0.5757
Label 0 Accuracy: 0.7330
Label 1 Accuracy: 0.5621
Label 2 Accuracy: 0.3818
Epoch 4/120, Loss: 0.8874, Overall Accuracy: 0.5772
Label 0 Accuracy: 0.7313
Label 1 Accuracy: 0.5618
Label 2 Accuracy: 0.3897
Epoch 5/120, Loss: 0.8840, Overall Accuracy: 0.5778
Label 0 Accuracy: 0.7322
Label 1 Accuracy: 0.5607
Label 2 Accuracy: 0.3918
Epoch 6/120, Loss: 0.8810, Overall Accuracy: 0.5779
Label 0 Accuracy: 0.7318
Label 1 Accuracy: 0.5584
Label 2 Accuracy: 0.3953
Epoch 7/120, Loss: 0.8788, Overall Accuracy: 0.5790
Label 0 Accuracy: 0.7342
Label 1 Accuracy: 0.5598
Label 2 Accuracy: 0.3941
Epoch 8/120, Loss: 0.8767, Overall Accuracy: 0.5787
Label 0 Accuracy: 0.7323
Label 1 Accuracy: 0.5572
Label 2 A

30001it [02:16, 220.14it/s]                           


Accuracy: 0.5953


30001it [02:14, 223.11it/s]                           


Accuracy: 0.5182


In [None]:
print(results)

{4: {'above threshold': 0.7620285219245635, 'below threshold': 0.7815622161671208, 'all': 0.6368836291913215}, 1: {'above threshold': 0.6119611263393969, 'below threshold': 0.7461061798693687, 'all': 0.5867644799508372}, 2: {'above threshold': 0.712832077443485, 'below threshold': 0.7211507028440667, 'all': 0.6332010400985356}, 3: {'above threshold': 0.8062483929030599, 'below threshold': 0.8038863791076181, 'all': 0.6753586630931736}}


In [None]:
def generate_latex_table(data):
    latex_table = "\\begin{table}[h!]\n\\centering\n\\begin{tabular}{|c|c|c|c|}\n"
    latex_table += "\\hline\n"
    latex_table += "\\textbf{Seq} & \\textbf{All} & \\textbf{\\textless T=0.02} & \\textbf{\\textgreater T=0.02} \\\\\n"
    latex_table += "\\hline\n"

    for seq_len, values in data.items():
        above_threshold_len = values['below threshold']
        noised_share_len = values['above threshold']
        all_utterances = values['all']
        latex_table += f"{seq_len} & {all_utterances:.3f} & {above_threshold_len:.3f} & {noised_share_len:.3f} \\\\\n"
        latex_table += "\\hline\n"

    latex_table += "\\end{tabular}\n\\caption{Collectors Diagnostic classifier results}\n\\label{table:collectors_classifier}\n\\end{table}\n"
    return latex_table

# Generate and print the LaTeX table
latex_table = generate_latex_table(results)
print(latex_table)

\begin{table}[h!]
\centering
\begin{tabular}{|c|c|c|c|}
\hline
\textbf{Seq} & \textbf{All} & \textbf{\textless T=0.02} & \textbf{\textgreater T=0.02} \\
\hline
4 & 0.637 & 0.782 & 0.762 \\
\hline
1 & 0.587 & 0.746 & 0.612 \\
\hline
2 & 0.633 & 0.721 & 0.713 \\
\hline
3 & 0.675 & 0.804 & 0.806 \\
\hline
\end{tabular}
\caption{Collectors Diagnostic classifier results}
\label{table:collectors_classifier}
\end{table}

