In [None]:
import pathlib
import sys
from gymnasium import spaces

sys.path.append(str(pathlib.Path().absolute().parent))

from src.replay_memory import FastReplayBuffer
from src.env import FourRoomEnvWithTagging


In [None]:
# Environment / replay-buffer configuration. N_IMPOSTERS is passed to both the
# environment and the buffer from one constant so the two cannot silently disagree.
BUF_SIZE = 3000
N_IMPOSTERS = 2
N_CREW = 4
N_JOBS = 5
TRAJECTORY_SIZE = 3
# Second argument to FastReplayBuffer.populate(); NOTE(review): presumably the
# number of transitions/steps collected up front — confirm against src.replay_memory.
N_POPULATE = 1000

env = FourRoomEnvWithTagging(n_imposters=N_IMPOSTERS, n_crew=N_CREW, n_jobs=N_JOBS)
m = FastReplayBuffer(
    max_size=BUF_SIZE,
    state_size=env.flattened_state_size,
    trajectory_size=TRAJECTORY_SIZE,
    n_agents=env.n_agents,
    n_imposters=N_IMPOSTERS,
)
m.populate(env, N_POPULATE)

In [None]:
# Inspect the environment's state fields; as the cell's last expression the value
# is rendered as this cell's output.
env.state_fields

In [None]:
from src.featurizers import SequenceStateFeaturizer
from src.visualize import SequenceStateVisualizer

# Draw one trajectory from the replay buffer. sample(1) returns a batch of size 1,
# so the leading batch axis is dropped before the tensors are featurized.
batch = m.sample(1)
seq_states, seq_imposters = batch.states.squeeze(0), batch.imposters.squeeze(0)

# Build the sequence featurizer and render the whole sequence with the project visualizer.
f = SequenceStateFeaturizer(env, seq_states, seq_imposters)
v = SequenceStateVisualizer(f)
v.visualize_sequence()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per-agent feature views of the sequence featurized above. NOTE(review): presumably
# a generator — the plotting loop advances it with next(); it is not reset between
# runs, so re-running the consuming cell alone continues from where it stopped.
feature_gen = f.generator()

def visualize_sequence(spatial_features, agent_idx):
    n_seq, n_agents, _, __ = spatial_features.size()
    fig, ax = plt.subplots(n_seq, n_agents, figsize=(n_agents * 5, n_seq * 5))

    agent_labels = np.arange(n_agents)
    agent_labels[0] = agent_idx
    if agent_idx > 0:
        agent_labels[agent_idx] = agent_idx - 1

    for seq in range(n_seq):
        for i, rep in enumerate(torch.unbind(spatial_features[seq, ...], dim=0)):
            rep = np.flipud(rep.t().numpy())
            ax[seq][i].imshow(rep)
            ax[seq][i].set_title(f"Agent {agent_labels[i]}")
        
    fig.suptitle(f"Agent {agent_idx}'s Perspective", fontsize=22)
    plt.tight_layout()
    plt.show()


# Create the generator inside this cell so the cell is re-runnable on its own. A
# generator created in an earlier cell keeps its position, so re-running the loop
# would continue where the previous run stopped (or raise StopIteration) instead
# of starting again from the first agent.
feature_gen = f.generator()
for i in range(env.n_agents):
    # NOTE(review): assumes generator() yields one (spatial, non-spatial) pair per
    # agent, in agent-id order — confirm against SequenceStateFeaturizer.
    spatial_features, non_spacial_features = next(feature_gen)
    visualize_sequence(spatial_features, i)
    print(non_spacial_features)
