In [None]:
import pathlib
import sys

# Make the parent of the current working directory importable so that the
# `src` package used below can be resolved (presumably the notebook lives one
# level below the repository root -- confirm against the project layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
_parent_dir = str(pathlib.Path().absolute().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

from src.replay_memory import FastReplayBuffer
from src.env import FourRoomEnvWithTagging
from src.featurizers import PerspectiveFeaturizer, GlobalFeaturizer
from src.visualize import StateSequenceVisualizer

import torch

# Compact tensor printing: 3 decimals, no scientific notation, and wide lines
# so that printed rows are less likely to wrap.
torch.set_printoptions(precision=3, sci_mode=False, linewidth=200)


In [None]:
# --- Experiment configuration ------------------------------------------------
DEBUG = True

# Counts forwarded to the environment constructor.
N_CREW = 4
N_IMPOSTERS = 2
N_JOBS = 5

# Replay-buffer sizing: capacity and the trajectory length it is built with.
BUF_SIZE = 3000
SEQUENCE_SIZE = 2

# --- Environment and replay buffer --------------------------------------------
env = FourRoomEnvWithTagging(
    n_imposters=N_IMPOSTERS,
    n_crew=N_CREW,
    n_jobs=N_JOBS,
    debug=DEBUG,
)

# The buffer's state size and agent count are read from the environment so the
# two objects stay consistent with each other.
m = FastReplayBuffer(
    max_size=BUF_SIZE,
    state_size=env.flattened_state_size,
    trajectory_size=SEQUENCE_SIZE,
    n_agents=env.n_agents,
    n_imposters=N_IMPOSTERS,
)

# Fill the buffer from the environment before anything is sampled from it.
# NOTE(review): the meaning of 1000 (steps? episodes?) is defined by
# FastReplayBuffer.populate -- confirm there.
m.populate(env, 1000)

In [None]:
# Display the environment's `state_fields` attribute as this cell's output.
# NOTE(review): presumably this describes the layout of the flattened state
# vector -- confirm in src.env.
env.state_fields

In [None]:
# Draw one sample from the replay buffer. squeeze(0) removes axis 0 when it has
# size 1 (presumably the batch axis, since a single sample is requested).
batch = m.sample(1)
seq_states, seq_imposters = batch.states.squeeze(0), batch.imposters.squeeze(0)

# Fit a global featurizer on the sampled sequence ...
f = GlobalFeaturizer(env, SEQUENCE_SIZE)
f.fit(seq_states, seq_imposters)

# ... and hand it to the visualizer to render the sequence.
v = StateSequenceVisualizer(f)
v.visualize_perspectives()

In [None]:
from src.models.dqn import SpatialDQN
from torch import nn
import torchinfo

# Build the Q-network. Keyword arguments are grouped by the part of the
# network their names refer to; values are unchanged.
model = SpatialDQN(
    # -- inputs --
    input_image_size=env.n_cols,
    # NOTE(review): 5 is hard-coded; confirm it matches the width of the
    # non-spatial features produced by the featurizer.
    non_spatial_input_size=5,
    # -- per-layer spatial settings (three entries each) --
    n_channels=[7, 9, 9],
    kernel_sizes=[2, 3, 2],
    strides=[1, 3, 3],
    paddings=[1, 1, 1],
    # -- recurrent part --
    rnn_layers=3,
    rnn_hidden_dim=64,
    rnn_dropout=0.2,
    # -- output head --
    mlp_hidden_layer_dims=[16, 16],
    n_actions=env.n_imposter_actions,
)

# Summarize the model; left as the last expression so the notebook renders it.
torchinfo.summary(model)

In [None]:
# A bare `model` expression that is not the last statement of a cell is never
# displayed by Jupyter, so print the module representation explicitly.
print(model)

# Number of trainable parameters (those with requires_grad set); as the last
# expression of the cell, this value is shown as the cell output.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Draw a fresh sample and fit a new featurizer on it (same steps as the
# visualization cell above, repeated so this cell runs on its own sample).
batch = m.sample(1)

seq_states = batch.states.squeeze(0)
seq_imposters = batch.imposters.squeeze(0)

f = GlobalFeaturizer(env, SEQUENCE_SIZE)
f.fit(seq_states, seq_imposters)

# Smoke-test the forward pass on every (spatial, non-spatial) pair yielded by
# the featurizer. The model output is intentionally discarded.
#
# NOTE(review): the spatial input `s` is presumably laid out as
# (batch, sequence, channels, x, y) -- confirm against GlobalFeaturizer.
#
# The second loop variable was previously called `np`, which shadows the
# conventional NumPy alias in the notebook namespace; it is now `non_spatial`.
# torch.no_grad() avoids building an autograd graph for a pass whose result
# is never used for a backward step.
with torch.no_grad():
    for s, non_spatial in f.generator():
        print(s.shape)
        print(non_spatial.shape)

        model(s, non_spatial)
