In [146]:
from Continuous_Language.Environments.Multi_Pong.multi_pong import PongEnv
from Continuous_Language.Reinforcement_Learning.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from Continuous_Language.Reinforcement_Learning.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy

In [147]:
def make_env(max_episode_steps = 1024, sequence_length = 1, vocab_size = 3):
    """Create the Pong environment used throughout this notebook.

    The environment is always built with ``width=20`` and ``height=20``; only
    the episode limit and the language settings are configurable.

    Args:
        max_episode_steps: Forwarded to ``PongEnv`` as ``max_episode_steps``.
        sequence_length: Forwarded to ``PongEnv`` as ``sequence_length``.
        vocab_size: Forwarded to ``PongEnv`` as ``vocab_size``.

    Returns:
        The freshly constructed ``PongEnv`` instance.
    """
    return PongEnv(
        width=20,
        height=20,
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        max_episode_steps=max_episode_steps,
    )

In [148]:
def load(path="models/checkpoints", sequence_length=1, vocab_size=3):
    """Load every Pong checkpoint found in ``path`` into a CPU agent.

    Only files whose name contains ``"pong"`` are considered. The integer key
    of each checkpoint is parsed from the last ``_``-separated field of the
    file name, up to its first ``.``.

    Args:
        path: Directory that holds the checkpoint files.
        sequence_length: Sequence length of the environment the agents are
            built for.
        vocab_size: Vocabulary size of the environment the agents are built for.

    Returns:
        dict mapping the parsed integer to a ``PPO_Multi_Agent_Centralized``
        whose ``agent`` module has the checkpoint weights loaded.

    Raises:
        ValueError: If a matching file name does not end in an integer field.
    """
    env = make_env(sequence_length=sequence_length, vocab_size=vocab_size)
    models = {}
    # Sorted so that, should two files parse to the same key, the winner does
    # not depend on the file system's listing order.
    for filename in sorted(os.listdir(path)):
        if "pong" not in filename:
            continue
        # The agent below lives on the CPU, so remap tensors that were saved
        # from another device (e.g. CUDA or MPS) instead of failing on hosts
        # that lack it.
        state_dict = torch.load(os.path.join(path, filename), map_location="cpu")
        timestamp = int(filename.split("_")[-1].split(".")[0])
        agent = PPO_Multi_Agent_Centralized(env, device="cpu")
        agent.agent.load_state_dict(state_dict)
        models[timestamp] = agent
    return models

In [149]:
def perturbation(inputs, model, vocab_size, sequence_length):
    """Score how strongly the model's output distribution depends on the language input.

    The trailing ``vocab_size * sequence_length`` columns of ``inputs`` are
    treated as the one-hot encoded utterance. For each token, these columns are
    replaced by an utterance that repeats that token ``sequence_length`` times,
    and the resulting output distribution is compared, row by row, with the
    distribution obtained from the unmodified input.

    Args:
        inputs: 2-D array-like of shape ``(batch, features)``.
        model: Callable that maps a float32 tensor of shape
            ``(batch, features)`` to logits of shape ``(batch, n_outputs)``.
        vocab_size: Number of distinct tokens.
        sequence_length: Number of tokens per utterance.

    Returns:
        np.ndarray of shape ``(batch,)`` and dtype float64. Entry ``i`` is the
        maximum over tokens of ``KL(perturbed_i || original_i)``.

    Raises:
        ValueError: If ``vocab_size * sequence_length`` is not positive.
    """
    language_width = vocab_size * sequence_length
    if language_width <= 0:
        raise ValueError("vocab_size * sequence_length must be positive")

    inputs = np.asarray(inputs, dtype=np.float32)
    batch_size, feature_count = inputs.shape
    # Explicit end index: a negative-stop slice would silently select nothing
    # if the language width were ever zero.
    environment_inputs = inputs[:, :feature_count - language_width]

    one_hot = np.eye(vocab_size, dtype=np.float32)
    divergences = []
    with torch.no_grad():
        # Log-probabilities come straight from the logits; softmax is applied
        # exactly once on each side of the comparison.
        original_log_probs = F.log_softmax(
            model(torch.tensor(inputs, dtype=torch.float32)), dim=1
        )

        for token in range(vocab_size):
            # One-hot encoded utterance consisting of `token` repeated.
            utterance = np.tile(one_hot[token], sequence_length)
            utterances = np.tile(utterance, (batch_size, 1))

            perturbation_inputs = np.concatenate((environment_inputs, utterances), axis=1)
            perturbed_probs = F.softmax(
                model(torch.tensor(perturbation_inputs, dtype=torch.float32)), dim=1
            )

            # kl_div(log p, q) yields q * (log q - log p) element-wise; summing
            # over the output dimension gives one divergence per row, so every
            # row is compared only with its own unperturbed distribution.
            row_divergence = F.kl_div(
                original_log_probs, perturbed_probs, reduction="none"
            ).sum(dim=1)
            divergences.append(row_divergence)

    max_divergences = torch.stack(divergences).max(dim=0).values
    return max_divergences.cpu().numpy().astype(np.float64)

In [150]:
from tqdm import tqdm

def test_perturbation(env, agent, threshold=0.02, number_samples=30000):
    """Roll out ``agent`` in ``env`` and record a language perturbation score for every step.

    The rollout continues, resetting the environment whenever an episode ends,
    until ``number_samples`` steps have been seen in which at least one score
    exceeds ``threshold``. Scores and snapshots are recorded for every step,
    not only for the steps that exceed the threshold.

    Args:
        env: Environment exposing ``reset``, ``step``, ``state``, ``agents``,
            ``paddles``, ``balls``, ``vocab_size``, ``sequence_length`` and
            ``timestep``.
        agent: Wrapper whose ``agent.actor`` is scored by ``perturbation`` and
            whose ``agent.get_action_and_value`` picks the actions.
        threshold: Score above which a step counts towards ``number_samples``.
        number_samples: Number of above-threshold steps to collect.

    Returns:
        Tuple ``(language_importances, data, average_length)``:
        ``language_importances`` is an array with one row of scores per step,
        ``data`` holds per-step copies of the paddles, the observations and
        the ball positions, and ``average_length`` lists ``env.timestep`` at
        the end of every finished episode.
    """
    language_importances = []
    average_length = []
    data = {"paddle_1": [], "paddle_2": [], "paddle_1_obs": [], "paddle_2_obs": [], "ball_1_pos": [], "ball_2_pos": []}

    obs, _ = env.reset()
    state = env.state()

    with tqdm(total=number_samples) as pbar:
        timestep = 0
        while True:
            # Snapshot the world before `obs` is flattened into an array below;
            # deep copies guard against the environment reusing its containers.
            for cur_agent in env.agents:
                data[cur_agent].append(copy.deepcopy(env.paddles[cur_agent]))
                data[cur_agent + "_obs"].append(copy.deepcopy(obs[cur_agent]))
            for cur_ball in env.balls.keys():
                data[cur_ball + "_pos"].append(copy.deepcopy(env.balls[cur_ball]["position"]))

            obs = np.array(flatten_list([obs]))
            state = np.array(flatten_list([state]))

            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            if np.any(language_perturbation > threshold):
                pbar.update(1)
                timestep += 1

            # `>=` stops right after the requested number of above-threshold
            # steps, which keeps the progress bar within its total.
            if timestep >= number_samples:
                break

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            # The helper returns one entry per stacked environment; only one
            # environment was stacked above.
            actions = actions[0]
            actions = {name: action.cpu().numpy() for name, action in actions.items()}

            obs, _, truncations, terminations, _ = env.step(actions)
            state = env.state()

            # Both flag dicts are OR-ed together, so their relative order in
            # the tuple returned by `env.step` does not matter here.
            if any(truncations[name] or terminations[name] for name in env.agents):
                average_length.append(env.timestep)
                obs, _ = env.reset()
                state = env.state()
    return np.array(language_importances), data, average_length

In [151]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Step 2: Define the dataset and dataloader
class PositionDataset(Dataset):
    """Dataset of float32 feature rows paired with integer class labels.

    Both arrays are converted to tensors on ``device`` once, at construction
    time, so that indexing only slices tensors that already exist.
    """

    def __init__(self, data, labels, device):
        features = torch.tensor(data, dtype=torch.float32, device=device)
        targets = torch.tensor(labels, dtype=torch.long, device=device)
        self.data = features
        self.labels = targets

    def __len__(self):
        # Number of samples equals the length of the feature tensor.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Step 3: Define the model architecture
class SimpleClassifier(nn.Module):
    """Small MLP classifier: ``input_size -> 64 -> 32 -> 2`` with ReLU activations.

    Dropout with probability 0.2 is applied after the first hidden layer. The
    forward pass returns raw scores for the two classes, without a softmax.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 2)

    def forward(self, x):
        hidden = self.dropout(torch.relu(self.fc1(x)))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [152]:
def train(num_epochs, dataloader, input_size, learning_rate, device):
    """Train a fresh ``SimpleClassifier`` with Adam and cross-entropy loss.

    The learning rate stays constant for the whole run. After every epoch the
    mean batch loss and the training accuracy are printed.

    Args:
        num_epochs: Number of passes over ``dataloader``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        input_size: Number of input features per sample.
        learning_rate: Learning rate passed to Adam.
        device: Device the model is created on; every batch is moved there.

    Returns:
        Tuple ``(epoch_accuracy, model)``: the training accuracy of the last
        epoch and the trained model. The accuracy is ``nan`` when no epoch was
        run or the last epoch saw no batch.
    """
    model = SimpleClassifier(input_size).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    # Defined up front so that `num_epochs == 0` still has a value to return.
    epoch_accuracy = float("nan")

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        num_batches = 0

        for inputs, targets in dataloader:
            # No-op when the batch already lives on `device`.
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)

        # An empty dataloader would otherwise divide by zero.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_accuracy = correct_predictions / total_predictions if total_predictions else float("nan")

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

    print("Training complete!")
    return epoch_accuracy, model

In [153]:
def test(model, dataloader, device):
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)
    
    accuracy = correct_predictions / total_predictions
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy

In [154]:
# Start with an empty result store; the experiment cells further down rebind it before filling it.
results = dict()

In [155]:
def calculate_chance_level(labels):
    """Return the share of the most frequent label in ``labels``.

    This is the accuracy obtained by always predicting the majority class.

    Args:
        labels: Non-empty sequence or array of class labels.

    Returns:
        Count of the most common label divided by ``len(labels)``.
    """
    _, class_counts = np.unique(labels, return_counts=True)
    return class_counts.max() / len(labels)

In [156]:
import math
from tqdm import tqdm


def _stack_language_windows(observations, indices, window, language_width):
    """Cut ``[i - window, i + window)`` out of ``observations`` for each index and flatten the language columns."""
    observations = np.array(observations)
    windows = np.stack([observations[index - window:index + window] for index in indices])
    language = windows[:, :, -1 * language_width:]
    return language.reshape(language.shape[0], -1)


def generate_training_data(language_importances_indices, data, sequence_length=1, vocab_size=3, window=3):
    """Build classifier inputs from the language columns around selected steps.

    Steps are selected where the second index array equals 0. For each
    selected step, the last ``sequence_length * vocab_size`` columns of the
    ``2 * window`` observations in ``[step - window, step + window)`` are taken
    from both ``data["paddle_1_obs"]`` and ``data["paddle_2_obs"]`` and
    concatenated into one feature row.

    Args:
        language_importances_indices: Pair ``(row_indices, column_indices)``
            as returned by ``np.where`` on a 2-D array.
        data: Dict with per-step lists under ``"paddle_1"``, ``"paddle_2"``,
            ``"paddle_1_obs"`` and ``"paddle_2_obs"``.
        sequence_length: Number of tokens per utterance.
        vocab_size: Number of distinct tokens.
        window: Number of steps taken before, and from the selected step
            onwards, around each selected step.

    Returns:
        Tuple ``(inputs, labels)``. ``labels`` is 1 where
        ``data["paddle_1"]`` is greater than ``data["paddle_2"]`` at the
        selected step and 0 otherwise. When no step survives the trimming,
        both arrays are empty and ``inputs`` has shape
        ``(0, 4 * window * sequence_length * vocab_size)``.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be at least 1")

    rows = np.asarray(language_importances_indices[0])
    columns = np.asarray(language_importances_indices[1])
    paddle_1_indices = rows[np.flatnonzero(columns == 0)]
    # Dropping `window` entries at both ends keeps every remaining window
    # inside the recording, assuming the indices arrive sorted and unique, as
    # np.where returns them.
    paddle_1_indices = paddle_1_indices[window:len(paddle_1_indices) - window]

    language_width = sequence_length * vocab_size
    if len(paddle_1_indices) == 0:
        # Nothing to stack: hand back well-formed empty arrays instead of
        # failing while indexing an empty window array.
        return np.empty((0, 4 * window * language_width)), np.empty((0,), dtype=int)

    labels = np.where(np.array(data["paddle_1"])[paddle_1_indices] > np.array(data["paddle_2"])[paddle_1_indices], 1, 0)

    player_1_lang = _stack_language_windows(data["paddle_1_obs"], paddle_1_indices, window, language_width)
    player_2_lang = _stack_language_windows(data["paddle_2_obs"], paddle_1_indices, window, language_width)
    inputs = np.concatenate((player_1_lang, player_2_lang), axis=1)

    return inputs, labels

def generate_training_data_obs(language_importances_indices, data, sequence_length=1, vocab_size=3, length=20, height=20):
    """Build classifier inputs from the full ``paddle_1`` observation at selected steps.

    Steps are selected where the second index array equals 0.

    Args:
        language_importances_indices: Pair ``(row_indices, column_indices)``
            as returned by ``np.where`` on a 2-D array.
        data: Dict with per-step lists under ``"paddle_1"``, ``"paddle_2"``
            and ``"paddle_1_obs"``.
        sequence_length: Not used by this function.
        vocab_size: Not used by this function.
        length: Not used by this function.
        height: Not used by this function.

    Returns:
        Tuple ``(inputs, labels)``: the ``paddle_1`` observations at the
        selected steps, and labels that are 1 where ``data["paddle_1"]`` is
        greater than ``data["paddle_2"]`` at that step and 0 otherwise.
    """
    step_indices = language_importances_indices[0]
    agent_indices = language_importances_indices[1]
    selected_steps = step_indices[np.flatnonzero(agent_indices == 0)]

    paddle_1_values = np.array(data["paddle_1"])[selected_steps]
    paddle_2_values = np.array(data["paddle_2"])[selected_steps]
    labels = np.where(paddle_1_values > paddle_2_values, 1, 0)

    inputs = np.array(data["paddle_1_obs"])[selected_steps]
    return inputs, labels

In [157]:
import os
# Root folder holding one sub-directory of checkpoints per trained configuration.
MULTI_PONG_ROOT = "/Users/cowolff/Documents/GitHub/ma.pong_rl/Plotting/saliencies_live/Multi_Pong/"

# Skip the macOS metadata entry; every other entry is treated as a run directory.
directories = [entry for entry in os.listdir(MULTI_PONG_ROOT) if ".DS_Store" not in entry]

# Maps sequence length -> agent restored from the checkpoint with the largest key.
models = {}
for directory in directories:
    # The directory name is expected to end in "_<sequence_length>".
    sequence_length = int(directory.rsplit("_", 1)[-1])
    agents = load(os.path.join(MULTI_PONG_ROOT, directory), sequence_length=sequence_length)
    agent_indizes = sorted(agents)
    models[sequence_length] = agents[agent_indizes[-1]]

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [158]:
# Prefer Apple's Metal backend, then CUDA, then the CPU, so that the notebook
# also runs on machines where MPS is not available.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Probe experiment on language features: for each sequence length, train a
# classifier on the language columns around high-importance steps, then
# evaluate it on a fresh rollout.
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    # Training rollout and the steps whose perturbation score exceeds 0.02.
    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)

    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data(language_importances_larger, data, sequence_length=seq, vocab_size=3)
    dataset = PositionDataset(larger_inputs, larger_labels, device)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy, model = train(120, dataloader, input_size, 0.001, device)
    # Accuracy on the training set during the last epoch.
    results[seq]["above threshold"] = accuracy

    # Independent rollout used as held-out data.
    test_importances, test_data, test_average_length = test_perturbation(env, agent, threshold=0.02, number_samples=10000)
    test_importances_larger = np.where(test_importances > 0.02)
    test_inputs, test_labels = generate_training_data(test_importances_larger, test_data, sequence_length=seq, vocab_size=3)
    test_dataset = PositionDataset(test_inputs, test_labels, device)
    # Accuracy does not depend on batch order, so no shuffling is needed here.
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    accuracy = test(model, test_dataloader, device)
    chance_level = calculate_chance_level(test_labels)
    print(chance_level)

    # Keep the held-out numbers next to the training accuracy instead of only
    # printing them.
    results[seq]["test"] = accuracy
    results[seq]["chance level"] = chance_level
    

In [None]:
# Show the results gathered per sequence length by the preceding experiment loop.
print(results)

In [None]:
# Probe experiment on raw observations: for each sequence length, train a
# classifier on the full paddle_1 observation at high-importance steps, then
# evaluate it on a fresh rollout. Everything runs on the CPU.
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    # Training rollout and the steps whose perturbation score exceeds 0.02.
    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)
    language_importances_larger = np.where(language_importances > 0.02)
    # Take the vocabulary size from the environment so that the two can never
    # disagree.
    larger_inputs, larger_labels = generate_training_data_obs(language_importances_larger, data, sequence_length=seq, vocab_size=env.vocab_size)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    accuracy, model = train(120, dataloader, input_size, 0.001, "cpu")
    # Accuracy on the training set during the last epoch.
    results[seq]["train"] = accuracy

    # Independent rollout used as held-out data.
    test_language_importances, test_data, test_average_length = test_perturbation(env, agent, number_samples=30000)
    test_language_importances_larger = np.where(test_language_importances > 0.02)
    test_larger_inputs, test_larger_labels = generate_training_data_obs(test_language_importances_larger, test_data, sequence_length=seq, vocab_size=env.vocab_size)
    test_dataset = PositionDataset(test_larger_inputs, test_larger_labels, "cpu")
    # Accuracy does not depend on batch order, so no shuffling is needed here.
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    test_accuracy = test(model, test_dataloader, "cpu")
    results[seq]["test"] = test_accuracy
    results[seq]["chance level"] = calculate_chance_level(test_larger_labels)

In [None]:
# Show the results gathered per sequence length by the preceding experiment loop.
print(results)