In [122]:
from Continuous_Language.Environments.Collectors.collectors import Collectors
from Continuous_Language.Reinforcement_Learning.Centralized_PPO.multi_ppo import PPO_Multi_Agent_Centralized
from Continuous_Language.Reinforcement_Learning.Decentralized_PPO.util import flatten_list, reverse_flatten_list_with_agent_list
import torch
import numpy as np
import time
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy

In [123]:
def make_env(sequence_length=0, vocab_size=4, max_episode_steps=2048):
    """Build the 20x20 ``Collectors`` environment used throughout this notebook.

    Args:
        sequence_length: Passed to ``Collectors`` as ``sequence_length``.
        vocab_size: Passed to ``Collectors`` as ``vocab_size``. Defaults to 4,
            the value that used to be hard-coded here.
        max_episode_steps: Passed to ``Collectors`` as ``max_timesteps``.
            Defaults to 2048, the value that used to be hard-coded here.

    Returns:
        The freshly constructed ``Collectors`` instance.
    """
    return Collectors(
        width=20,
        height=20,
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        max_timesteps=max_episode_steps,
        timestep_countdown=15,
    )

In [124]:
def load(path="models/checkpoints", vocab_size=3):
    """Load every usable checkpoint in ``path`` into a CPU agent.

    A file is treated as a checkpoint when its name ends in ``.pt`` or ``.pth``.
    The sequence length is parsed from the text between the last ``_`` and the
    first following ``.`` (e.g. ``model_3.pt`` -> 3). Files whose name carries no
    integer there, or whose sequence length is not positive, are skipped.

    Args:
        path: Directory that contains the checkpoint files.
        vocab_size: Not used by this function; kept so existing calls keep working.

    Returns:
        dict mapping sequence length (int) to the ``PPO_Multi_Agent_Centralized``
        instance whose ``agent`` received the checkpoint's state dict. Files are
        visited in sorted order, so when two files share a sequence length the
        alphabetically last one wins.
    """
    models = {}
    for filename in sorted(os.listdir(path)):
        # Match on the extension only; a plain substring test would also pick up
        # unrelated files such as "script_notes.txt".
        if not filename.endswith((".pt", ".pth")):
            continue

        token = filename.split("_")[-1].split(".")[0]
        try:
            sequence_length = int(token)
        except ValueError:
            print(f"Skipping {filename}: no sequence length in the file name")
            continue
        if sequence_length <= 0:
            continue

        env = make_env(sequence_length=sequence_length)
        # The agent is created on the CPU, so map the stored tensors there too.
        state_dict = torch.load(os.path.join(path, filename), map_location="cpu")
        try:
            agent = PPO_Multi_Agent_Centralized(env, device="cpu")
            agent.agent.load_state_dict(state_dict)
        except Exception as error:
            # Best effort: an incompatible checkpoint is reported and skipped
            # rather than aborting the whole load.
            print(f"Skipping {filename}: {error!r}")
            continue
        models[sequence_length] = agent
    return models

In [125]:
def perturbation(inputs, model, vocab_size, sequence_length):
    """Score how strongly the language slice of each observation moves the policy.

    The last ``vocab_size * sequence_length`` features of every row are replaced
    by a constant utterance (the same token repeated ``sequence_length`` times,
    one-hot encoded). This is done once per token. For each row the function
    computes KL(perturbed || original) between the softmax of the model output on
    the perturbed row and on the unmodified row, and returns the maximum over
    tokens.

    Each row is compared only with its own original distribution, and the softmax
    is applied exactly once per side. Earlier revisions compared a row against
    every row of the batch and applied softmax twice, so absolute values differ
    from those revisions.

    Args:
        inputs: 2-D array-like of observations, one row per sample, with the
            language features in the trailing columns.
        model: Callable mapping a float32 tensor of shape (batch, features) to a
            (batch, outputs) tensor that is treated as logits over dim 1.
        vocab_size: Number of distinct tokens.
        sequence_length: Number of tokens per utterance.

    Returns:
        float64 numpy array of shape (batch,) holding the per-row maximum KL.

    Raises:
        ValueError: if ``vocab_size * sequence_length`` is not positive, because
            there is then no language slice to replace.
    """
    language_width = vocab_size * sequence_length
    if language_width <= 0:
        raise ValueError("vocab_size * sequence_length must be positive")

    observations = torch.as_tensor(inputs, dtype=torch.float32)
    environment_part = observations[:, :-language_width]
    batch_size = observations.shape[0]

    with torch.no_grad():
        original_log_probs = F.log_softmax(model(observations), dim=1)

        per_token_kl = []
        for token in range(vocab_size):
            # One-hot utterance made of `token` repeated, shared by every row.
            token_ids = torch.full((sequence_length,), token, dtype=torch.long)
            utterance = F.one_hot(token_ids, num_classes=vocab_size).flatten()
            utterance = utterance.to(dtype=torch.float32, device=observations.device)
            utterance = utterance.unsqueeze(0).expand(batch_size, -1)

            perturbed = torch.cat((environment_part, utterance), dim=1)
            perturbed_log_probs = F.log_softmax(model(perturbed), dim=1)

            # kl_div(input=log p, target=log q, log_target=True) is q * (log q - log p)
            # element-wise; summing over dim 1 gives one KL(q || p) per row.
            row_kl = F.kl_div(
                original_log_probs,
                perturbed_log_probs,
                reduction="none",
                log_target=True,
            ).sum(dim=1)
            per_token_kl.append(row_kl)

        max_divergences = torch.stack(per_token_kl).max(dim=0).values

    return max_divergences.cpu().numpy().astype(np.float64)

In [126]:
from tqdm import tqdm

def test_perturbation(env, agent, threshold=0.02, number_samples=30000):
    """Roll out ``agent`` in ``env`` and record a language-perturbation score per step.

    Every environment step is recorded (positions, directions, observations,
    target positions and the ``perturbation`` scores). The rollout stops as soon
    as ``number_samples`` steps have had at least one score above ``threshold``.
    The environment is reset whenever any agent is truncated or terminated.

    Args:
        env: Environment exposing ``reset``, ``state``, ``step``, ``agents``,
            ``players``, ``targets``, ``num_targets``, ``vocab_size`` and
            ``sequence_length``.
        agent: Wrapper exposing ``agent.actor``, ``agent.get_action_and_value``
            and ``agents``.
        threshold: A step counts towards ``number_samples`` when any of its
            scores exceeds this value.
        number_samples: Number of above-threshold steps to collect.

    Returns:
        Tuple ``(language_importances, data, average_length)``:
        ``language_importances`` is an array with one row of scores per recorded
        step; ``data`` maps the fixed keys below to per-step lists of the same
        length; ``average_length`` holds the value of the above-threshold counter
        at each episode end (a running count, not an episode length).
    """
    language_importances = []
    average_length = []
    data = {"player_1": [], "player_2": [], "player_1_obs": [], "player_2_obs":[], "player_1_direction": [], "player_2_direction": [], "target_1": [], "target_2": [], "target_3":[]}

    obs, info = env.reset()
    state = env.state()

    with tqdm(total=number_samples) as pbar:
        timestep = 0  # number of recorded steps whose score exceeded the threshold
        while True:
            # Snapshot the world before acting so index t in `data` lines up with
            # index t in `language_importances`.
            for cur_agent in env.agents:
                data[cur_agent].append(copy.deepcopy(env.players[cur_agent].position))
                data[cur_agent + "_direction"].append(copy.deepcopy(env.players[cur_agent].direction))
                data[cur_agent + "_obs"].append(copy.deepcopy(obs[cur_agent]))
            for i in range(env.num_targets):
                if len(env.targets) > i:
                    target = env.targets[i].position
                else:
                    # Placeholder for a target slot that is currently empty.
                    target = np.array([-1, -1])
                data["target_" + str(i+1)].append(copy.deepcopy(target))

            obs = np.array(flatten_list([obs]))
            state = np.array(flatten_list([state]))

            language_perturbation = perturbation(obs, agent.agent.actor, env.vocab_size, env.sequence_length)
            language_importances.append(language_perturbation)

            if np.any(language_perturbation > threshold):
                pbar.update(1)
                timestep += 1

            # `>=` stops at exactly `number_samples` hits; `>` collected one extra
            # and pushed the progress bar past its total.
            if timestep >= number_samples:
                break

            obs = torch.tensor(obs, dtype=torch.float32)
            state = torch.tensor(state, dtype=torch.float32)
            with torch.no_grad():
                actions, _, _, _ = agent.agent.get_action_and_value(obs, state)
                actions = reverse_flatten_list_with_agent_list(actions, agent.agents)

            actions = actions[0]
            actions = {name: action.cpu().numpy() for name, action in actions.items()}

            # The two flag dicts are OR-ed below, so their order in the returned
            # tuple does not affect the reset decision.
            obs, _, truncations, terminations, infos = env.step(actions)
            state = env.state()

            if any([truncations[name] or terminations[name] for name in env.agents]):
                average_length.append(timestep)
                obs, info = env.reset()
                state = env.state()
    return np.array(language_importances), data, average_length

In [127]:
import numpy as np
import math

def generate_training_data(language_importances_indices, data, sequence_length=1, vocab_size=3):
    """Build (language window, target label) pairs for the language-only probe.

    Args:
        language_importances_indices: Pair ``(step_ids, agent_ids)`` as returned by
            ``np.where`` on a 2-D (step, agent) array. Only entries whose agent id
            is 0 are used.
        data: Dict of per-step lists with the keys ``player_1``, ``player_1_obs``,
            ``player_2_obs``, ``target_1``, ``target_2`` and ``target_3``. Must
            contain at least one step.
        sequence_length: Tokens per utterance.
        vocab_size: Number of distinct tokens.

    Returns:
        Tuple ``(inputs, labels)``.
        ``inputs`` has one row per anchor step: the trailing
        ``sequence_length * vocab_size`` observation features of player 1 over the
        six frames ``[step - 3, step + 3)``, followed by the same slice for
        player 2.
        ``labels`` is, per anchor step, the argmin over the three targets of
        (distance now - distance six steps later), measured from player 1 to the
        target position at the anchor step. Targets stored as ``[-1, -1]``
        (coordinate sum -2) are excluded by setting their value to +inf.
        When too few samples survive the trimming, both arrays are returned with
        zero rows instead of raising.

    NOTE(review): argmin selects the target whose distance shrank the least (or
    grew the most); confirm this is the intended label rather than argmax.
    """
    step_ids = np.asarray(language_importances_indices[0], dtype=np.intp)
    agent_ids = np.asarray(language_importances_indices[1])

    # Dropping the first 3 and last 6 entries of the sorted, distinct step ids
    # guarantees step - 3 >= 0 and that step + 6 stays inside the recording.
    anchor_steps = step_ids[agent_ids == 0][3:-6]
    future_steps = anchor_steps + 6

    player_positions = np.array(data["player_1"])
    position_now = player_positions[anchor_steps]
    position_later = player_positions[future_steps]

    distance_deltas = []
    for key in ("target_1", "target_2", "target_3"):
        target_positions = np.array(data[key])[anchor_steps]
        delta = (
            np.linalg.norm(position_now - target_positions, axis=1)
            - np.linalg.norm(position_later - target_positions, axis=1)
        )
        # A coordinate sum of -2 marks the [-1, -1] placeholder of an absent target.
        is_absent = target_positions.sum(axis=1) == -2
        distance_deltas.append(np.where(is_absent, np.inf, delta))
    labels = np.argmin(np.array(distance_deltas), axis=0)

    # Six consecutive frames per anchor: three before it, the anchor, two after.
    frame_ids = anchor_steps[:, None] + np.arange(-3, 3)
    language_blocks = []
    for key in ("player_1_obs", "player_2_obs"):
        frames = np.array(data[key])[frame_ids]
        language = frames[:, :, -1 * sequence_length * vocab_size:]
        # Explicit target shape: reshape(-1) cannot be inferred for zero rows.
        language_blocks.append(
            language.reshape(language.shape[0], language.shape[1] * language.shape[2])
        )
    inputs = np.concatenate(language_blocks, axis=1)

    print("Label 0", np.count_nonzero(labels == 0))
    print("Label 1", np.count_nonzero(labels == 1))
    print("Label 2", np.count_nonzero(labels == 2))

    return inputs, labels

def generate_training_data_obs(language_importances_indices, data, sequence_length=1, vocab_size=3, length=20, height=20):
    """Build (observation, target label) pairs for the full-observation probe.

    ``language_importances_indices`` is a ``(step_ids, agent_ids)`` pair of arrays;
    only steps whose agent id is 0 are kept, and the first 3 and last 6 of those
    are dropped. The label of a kept step is the argmin over the three targets of
    (player-1 distance now - player-1 distance six steps later), with targets
    stored as ``[-1, -1]`` (coordinate sum -2) forced to +inf. The input is
    player 1's raw observation at the kept step.

    ``sequence_length``, ``vocab_size``, ``length`` and ``height`` are accepted
    but not used.

    Returns:
        Tuple ``(inputs, labels)`` of numpy arrays with one entry per kept step.
    """
    agent_zero = language_importances_indices[1] == 0
    anchor_steps = language_importances_indices[0][agent_zero][3:-6]
    future_steps = anchor_steps + 6

    player_track = np.array(data["player_1"])
    here = player_track[anchor_steps]
    there = player_track[future_steps]

    per_target = []
    for key in ("target_1", "target_2", "target_3"):
        spot = np.array(data[key])[anchor_steps]
        change = np.linalg.norm(here - spot, axis=1) - np.linalg.norm(there - spot, axis=1)
        # Coordinate sum -2 identifies the [-1, -1] placeholder.
        per_target.append(np.where(spot.sum(axis=1) != -2, change, np.inf))

    labels = np.argmin(np.array(per_target), axis=0)
    inputs = np.array(data["player_1_obs"])[anchor_steps]
    return inputs, labels

In [128]:
# Directory holding the Collectors checkpoints. Set COLLECTORS_CHECKPOINT_DIR to
# run the notebook on another machine; the fallback is the original local path.
CHECKPOINT_DIR = os.environ.get(
    "COLLECTORS_CHECKPOINT_DIR",
    "/Users/cowolff/Documents/GitHub/ma.pong_rl/models/checkpoints_collectors_2/models",
)
models = load(CHECKPOINT_DIR)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [129]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Step 2: Define the dataset and dataloader
class PositionDataset(Dataset):
    """In-memory dataset of float32 feature rows paired with int64 labels."""

    def __init__(self, data, labels, device):
        # Convert both arrays once, directly on the requested device, so that
        # item access later is a plain tensor lookup.
        self.data, self.labels = (
            torch.tensor(values, dtype=dtype, device=device)
            for values, dtype in ((data, torch.float32), (labels, torch.long))
        )

    def __len__(self):
        # Number of samples equals the number of feature rows.
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.labels[idx]
        return features, label

# Step 3: Define the model architecture
class SimpleClassifier(nn.Module):
    """Three-layer MLP probe: input -> 64 -> 32 -> ``num_classes`` logits.

    Args:
        input_size: Number of input features per sample.
        num_classes: Number of output logits. Defaults to 3, the width that used
            to be hard-coded.

    The layer attribute names (``fc1``, ``fc2``, ``fc3``) are unchanged, so
    previously saved state dicts still load.
    """

    def __init__(self, input_size, num_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch ``x``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the last layer: the loss expects logits.
        return self.fc3(x)

In [130]:
def train(num_epochs, dataloader, input_size, learning_rate, device, weights_array):
    """Train a fresh ``SimpleClassifier`` with class-weighted cross-entropy.

    Args:
        num_epochs: Number of passes over ``dataloader``.
        dataloader: Sized iterable yielding ``(inputs, targets)`` tensor batches.
        input_size: Number of input features, forwarded to ``SimpleClassifier``.
        learning_rate: Adam learning rate. It is halved every 80 epochs.
        device: Device for the model, the batches and the class weights.
        weights_array: Per-class loss weights, or ``None`` for an unweighted
            loss. If it has fewer entries than the model has outputs (which
            happens when the highest class never occurs in the labels it was
            derived from), it is zero-padded to the output width.

    Returns:
        Tuple ``(epoch_accuracy, model)`` where ``epoch_accuracy`` is the overall
        training accuracy of the last epoch (0.0 if no epoch ran or no sample was
        seen) and ``model`` is the trained classifier.
    """
    model = SimpleClassifier(input_size).to(device)
    num_classes = model.fc3.out_features

    class_weights = None
    if weights_array is not None:
        class_weights = torch.as_tensor(weights_array, dtype=torch.float32, device=device)
        missing = num_classes - class_weights.numel()
        if missing > 0:
            # Classes without a weight have no training samples; give them weight 0
            # so the loss accepts the vector.
            class_weights = torch.cat((class_weights, class_weights.new_zeros(missing)))

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.5)

    epoch_accuracy = 0.0  # defined up front so num_epochs == 0 still returns a value
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_predictions = [0] * num_classes
        total_predictions = [0] * num_classes

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Per-class hit counts for the label-specific accuracies below.
            _, predicted = torch.max(outputs, 1)
            for i in range(num_classes):
                correct_predictions[i] += ((predicted == i) & (targets == i)).sum().item()
                total_predictions[i] += (targets == i).sum().item()

        # Guard the divisions so an empty loader reports zeros instead of raising.
        epoch_loss = running_loss / max(len(dataloader), 1)
        samples_seen = sum(total_predictions)
        epoch_accuracy = sum(correct_predictions) / samples_seen if samples_seen else 0.0

        label_accuracies = [correct_predictions[i] / total_predictions[i] if total_predictions[i] > 0 else 0 for i in range(num_classes)]

        scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Overall Accuracy: {epoch_accuracy:.4f}")
        for i in range(num_classes):
            print(f"Label {i} Accuracy: {label_accuracies[i]:.4f}")

    print("Training complete!")
    return epoch_accuracy, model

In [131]:
def test(model, dataloader, device):
    correct_predictions = [0, 0, 0]
    total_predictions = [0, 0, 0]
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            for i in range(3):  # Assuming 3 classes
                correct_predictions[i] += ((predicted == i) & (targets == i)).sum().item()
                total_predictions[i] += (targets == i).sum().item()
    
    accuracy = sum(correct_predictions) / sum(total_predictions)
    
    label_accuracies = [correct_predictions[i] / total_predictions[i] if total_predictions[i] > 0 else 0 for i in range(3)]

    print(f"Accuracy: {accuracy:.4f}")
    for i in range(3):
        print(f"Label {i} Accuracy: {label_accuracies[i]:.4f}")
    return accuracy

In [132]:
def calculate_chance_level(labels):
    """Return the fraction of samples that belong to the most frequent label.

    This is the accuracy obtained by always predicting the majority class.
    """
    _, frequencies = np.unique(labels, return_counts=True)
    return frequencies.max() / len(labels)

In [133]:
from collections import Counter
def normalize_label_weights(labels):
    """Compute inverse-frequency class weights scaled so the rarest class gets 1.0.

    The result is a float32 tensor of length ``max(labels) + 1``. Entry ``k`` is
    ``min_count / count_k`` for every label ``k`` that occurs in ``labels`` and
    0 for label values that never occur.
    """
    present, frequencies = np.unique(labels, return_counts=True)

    # Size the vector by the largest label so it can be indexed by class id.
    weights = np.zeros(present.max() + 1)
    weights[present] = frequencies.min() / frequencies

    return torch.tensor(weights, dtype=torch.float32)

In [None]:
# Language-only probe. For every loaded agent: collect a rollout, keep the samples
# whose perturbation score exceeds 0.02, train a classifier on the language slices
# returned by generate_training_data, then score it on a second, freshly collected
# rollout. Only the test accuracy is stored in `results`.
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    # Training rollout. The 0.02 cut-off below is applied independently of
    # test_perturbation's own `threshold` default (also 0.02); keep them in sync.
    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)
    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data(language_importances_larger, data, sequence_length=seq, vocab_size=4)
    # Class weights are derived from the training labels only.
    weights_array = normalize_label_weights(larger_labels)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    # NOTE(review): the returned training accuracy is not stored anywhere.
    accuracy, model = train(120, dataloader, input_size, 0.001, "cpu", weights_array)

    # Held-out rollout from the same env and agent, ten times fewer samples.
    test_language_importances, test_data, test_average_length = test_perturbation(env, agent, number_samples=30000)
    test_language_importances_larger = np.where(test_language_importances > 0.02)
    test_larger_inputs, test_larger_labels = generate_training_data(test_language_importances_larger, test_data, sequence_length=seq, vocab_size=4)
    test_dataset = PositionDataset(test_larger_inputs, test_larger_labels, "cpu")
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
    # NOTE(review): test_input_size is computed but not used in this cell.
    test_input_size = test_larger_inputs.shape[1]
    test_accuracy = test(model, test_dataloader, "cpu")
    results[seq]["test"] = test_accuracy
    

In [None]:
# Test accuracy of the language-only probe, keyed by sequence length.
print(results)

In [None]:
# Full-observation probe. Same procedure as the language-only probe, but the
# classifier input is player 1's raw observation (generate_training_data_obs).
# This cell rebinds `results`, discarding the values of the previous probe.
results = {}
for seq, agent in models.items():
    env = make_env(sequence_length=seq)
    results[seq] = {}

    # Training rollout. The 0.02 cut-off below is applied independently of
    # test_perturbation's own `threshold` default (also 0.02); keep them in sync.
    language_importances, data, average_length = test_perturbation(env, agent, number_samples=300000)
    language_importances_larger = np.where(language_importances > 0.02)
    larger_inputs, larger_labels = generate_training_data_obs(language_importances_larger, data, sequence_length=seq, vocab_size=4)
    # Class weights are derived from the training labels only.
    weights_array = normalize_label_weights(larger_labels)
    dataset = PositionDataset(larger_inputs, larger_labels, "cpu")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    input_size = larger_inputs.shape[1]
    # NOTE(review): the returned training accuracy is not stored anywhere.
    accuracy, model = train(120, dataloader, input_size, 0.001, "cpu", weights_array)
    

    # Held-out rollout from the same env and agent, ten times fewer samples.
    language_importances_test, data_test, average_length_test = test_perturbation(env, agent, number_samples=30000)
    language_importances_larger_test = np.where(language_importances_test > 0.02)
    larger_inputs_test, larger_labels_test = generate_training_data_obs(language_importances_larger_test, data_test, sequence_length=seq, vocab_size=4)
    dataset_test = PositionDataset(larger_inputs_test, larger_labels_test, "cpu")
    dataloader_test = DataLoader(dataset_test, batch_size=32, shuffle=True)
    # NOTE(review): input_size is reassigned here but not used afterwards in this cell.
    input_size = larger_inputs_test.shape[1]
    accuracy_test = test(model, dataloader_test, "cpu")
    results[seq]["test"] = accuracy_test

In [None]:
# Test accuracy of the full-observation probe, keyed by sequence length.
print(results)