In [None]:
import pathlib
import sys
import torch
# Parent of the current working directory; added to sys.path so the `src.*`
# imports that follow can resolve (presumably `src` lives there — confirm).
root_path = pathlib.Path().absolute().parent
# Guard the append so re-running this cell does not stack duplicate entries.
if str(root_path) not in sys.path:
    sys.path.append(str(root_path))

from src.train import run_experiment, OptimizerType
from src.models.dqn import ModelType
from src.environment.pred_prey import ImposterTrainingGround
from src.features.model_ready import FeaturizerType
from src.features.component import CoordinateAgentPositionsFeaturizer, CompositeFeaturizer, OneHotAgentPositionFeaturizer
from src.visualize import plot_experiment_metrics, setup_experiment_buttons


# Compact tensor printing for notebook output: three decimals, no scientific
# notation, and lines up to 200 characters wide.
torch.set_printoptions(
    linewidth=200,
    sci_mode=False,
    precision=3,
)

In [None]:
# --- Player / job counts -----------------------------------------------------
N_CREW = 1
N_IMPOSTERS = 1
N_JOBS = 0

# --- Replay / sequence sizing ------------------------------------------------
SEQUENCE_SIZE = 1
BUF_SIZE = 300_000

featurizer_type = FeaturizerType.FLAT

# --- Output directories ------------------------------------------------------
# <root>/model_registry holds all runs; the '1v1' sub-directory holds the runs
# configured in this notebook. Create the parent before the child.
model_registry_path = root_path / 'model_registry'
one_on_one_path = model_registry_path / '1v1'
for _registry_dir in (model_registry_path, one_on_one_path):
    _registry_dir.mkdir(exist_ok=True)

# Experiment names; each one is a key of both `gammas` and `featurizers` below.
experiments = [
    'no_wall_coord_features',
    'wall_coord_features',
    'one_hot_no_wall',
    'one_hot_wall',
]

# Keyword arguments shared by both environments; they differ only in
# `include_walls`.
_shared_env_kwargs = dict(
    n_crew=N_CREW,
    n_jobs=N_JOBS,
    debug=False,
    kill_reward=-3,
    sabotage_reward=0,
    end_of_game_reward=0,
    time_step_reward=0,
)

envs = {
    'Wall': ImposterTrainingGround(**_shared_env_kwargs),
    'No Wall': ImposterTrainingGround(**_shared_env_kwargs, include_walls=False),
}

# Gamma values per experiment; one config is generated for every value listed.
gammas = {
    'no_wall_coord_features': [0.9],
    'wall_coord_features': [0.9],
    'one_hot_no_wall': [0.9],
    'one_hot_wall': [0.99, 0.9, 0.8],
}


def _flat_featurizer(env, component_cls):
    """Build a FLAT featurizer for `env` from a single `component_cls(env)` component."""
    return FeaturizerType.build(
        FeaturizerType.FLAT,
        env,
        featurizers=CompositeFeaturizer([component_cls(env)]),
    )


# One featurizer per experiment: coordinate vs. one-hot position features,
# each paired with the walled or wall-free environment.
featurizers = {
    'no_wall_coord_features': _flat_featurizer(envs['No Wall'], CoordinateAgentPositionsFeaturizer),
    'wall_coord_features': _flat_featurizer(envs['Wall'], CoordinateAgentPositionsFeaturizer),
    'one_hot_no_wall': _flat_featurizer(envs['No Wall'], OneHotAgentPositionFeaturizer),
    'one_hot_wall': _flat_featurizer(envs['Wall'], OneHotAgentPositionFeaturizer),
}

# Experiments to generate configs for (currently every entry of `experiments`).
EXPERIMENTS_TO_RUN = experiments

# One keyword-argument dict per (experiment, gamma) pair, in the form expected
# by `run_experiment(**config)`.
configs = []

for experiment in EXPERIMENTS_TO_RUN:
    # Resolve the per-experiment objects once instead of on every dict entry.
    experiment_featurizer = featurizers[experiment]
    experiment_env = experiment_featurizer.env
    # Width of the first MLP layer.
    # NOTE(review): assumes `featurized_shape` is tensor-like with the feature
    # width at index 1 — confirm against the featurizer implementation.
    input_dim = experiment_featurizer.featurized_shape[1].item()

    for gamma in gammas[experiment]:
        config = {
            'env': experiment_env,
            'num_steps': 1_500_000,
            'imposter_model_args': {
                # input width -> hidden layers -> one output per imposter action
                'layer_dims': [input_dim, 256, 128, 64, 16, experiment_env.n_imposter_actions],
            },
            'crew_model_args': {'n_actions': experiment_env.n_crew_actions},
            'imposter_model_type': ModelType.MLP,
            'crew_model_type': ModelType.RANDOM,
            'featurizer': experiment_featurizer,
            'sequence_length': SEQUENCE_SIZE,
            'replay_buffer_size': BUF_SIZE,
            'replay_prepopulate_steps': 50_000,
            'batch_size': 8,
            'gamma': gamma,
            'scheduler_start_eps': 1,
            'scheduler_end_eps': 0.05,
            'scheduler_time_steps': 1_000_000,
            'train_imposter': True,
            'train_crew': False,
            # All gamma variants of one experiment share a base directory.
            'experiment_base_dir': one_on_one_path / experiment,
            'optimizer_type': OptimizerType.ADAM,
            'learning_rate': 0.001,
            'train_step_interval': 5,
            'num_checkpoint_saves': 5,
        }
        configs.append(config)


In [None]:
# Show every generated config, one per line.
print(*configs, sep='\n')

In [None]:
# total = len(configs)
# for i, config in enumerate(configs):
#     print(f'Running experiment {i+1}/{total}')
#     run_experiment(**config)

In [None]:
from src.visualize import plot_episode_lengths


# Separator string -> plot title, in the order they are passed to the plotter.
# NOTE(review): 'wall' is a substring of 'no_wall'; confirm that
# plot_episode_lengths does not also assign the no-wall runs to 'Env Wall'.
_episode_length_groups = {
    'no_wall': 'Env No Wall',
    'wall': 'Env Wall',
}
separator_strings = list(_episode_length_groups.keys())
titles = list(_episode_length_groups.values())
plot_episode_lengths(one_on_one_path, separator_strings, titles)

In [None]:
# for experiment in experiments:
#     path = one_on_one_path / experiment
#     plot_experiment_metrics(path, label_attr='gamma', label_name="$\\gamma$")

In [None]:
# Reuse the '1v1' registry directory defined with the other paths rather than
# rebuilding it here, so the location stays in sync with the configs above.
setup_experiment_buttons(one_on_one_path, 'gamma', experiments, featurizers)
