In [None]:
import pathlib
import json
import sys
import ipywidgets as widgets
import torch
# Make the project's `src` package importable for the imports below.
# NOTE(review): assumes the kernel's working directory is one level below the project root (e.g. notebooks/) — confirm.
root_path = pathlib.Path().absolute().parent
if str(root_path) not in sys.path:  # keep re-runs of this cell from piling up duplicate entries
    sys.path.append(str(root_path))

from src.train import run_experiment, OptimizerType, run_game
from src.models.dqn import ModelType
from src.environment.pred_prey import ImposterTrainingGround
from src.features.model_ready import FeaturizerType
from src.features.component import CoordinateAgentPositionsFeaturizer, CompositeFeaturizer, OneHotAgentPositionFeaturizer
from src.visualize import plot_experiment_metrics


torch.set_printoptions(precision=3, sci_mode=False, linewidth=200)

In [None]:
BUF_SIZE = 300_000
N_IMPOSTERS = 1  # NOTE(review): not passed to either environment below — confirm it is still needed
N_JOBS = 0
N_CREW = 1
SEQUENCE_SIZE = 1

model_registry_path = root_path / 'model_registry'
model_registry_path.mkdir(exist_ok=True)

# Every config below saves its runs under model_registry/1v1/<experiment name>/.
one_on_one_path = model_registry_path / '1v1'
one_on_one_path.mkdir(exist_ok=True)

featurizer_type = FeaturizerType.FLAT

# All experiments that have an environment, featurizer and gamma sweep defined below.
experiments = [
    'no_wall_coord_features',
    'wall_coord_features',
    'one_hot_no_wall',
    'one_hot_wall',
]

# Settings shared by both environments; they differ only in `include_walls`.
env_kwargs = dict(
    n_crew=N_CREW,
    n_jobs=N_JOBS,
    debug=False,
    kill_reward=-3,
    sabotage_reward=0,
    end_of_game_reward=0,
    time_step_reward=0,
)
envs = {
    'Wall': ImposterTrainingGround(**env_kwargs),
    'No Wall': ImposterTrainingGround(**env_kwargs, include_walls=False),
}

# Discount factors to sweep; one config is produced per (experiment, gamma) pair.
gammas = {
    'no_wall_coord_features': [0.9],
    'wall_coord_features': [0.9],
    'one_hot_no_wall': [0.9],
    'one_hot_wall': [0.99, 0.9, 0.8],
}


def build_featurizer(env_name, component_cls):
    """Build a `featurizer_type` featurizer for `envs[env_name]` wrapping one component featurizer.

    Args:
        env_name: Key into `envs` ('Wall' or 'No Wall').
        component_cls: Component featurizer class; it is instantiated with the chosen environment.

    Returns:
        Whatever `FeaturizerType.build` returns for `featurizer_type`.
    """
    env = envs[env_name]
    return FeaturizerType.build(featurizer_type, env, featurizers=CompositeFeaturizer([component_cls(env)]))


featurizers = {
    'no_wall_coord_features': build_featurizer('No Wall', CoordinateAgentPositionsFeaturizer),
    'wall_coord_features': build_featurizer('Wall', CoordinateAgentPositionsFeaturizer),
    'one_hot_no_wall': build_featurizer('No Wall', OneHotAgentPositionFeaturizer),
    'one_hot_wall': build_featurizer('Wall', OneHotAgentPositionFeaturizer),
}

# Subset of `experiments` to build (and later run) configs for.
experiments_to_run = [
    'wall_coord_features',
    'one_hot_wall',
]
unknown_experiments = set(experiments_to_run) - set(experiments)
assert not unknown_experiments, f'Unknown experiments: {sorted(unknown_experiments)}'

configs = []

for experiment in experiments_to_run:
    experiment_featurizer = featurizers[experiment]
    experiment_env = experiment_featurizer.env
    # Input width of the imposter MLP; index 1 is taken as the feature dimension — TODO confirm against the featurizer.
    input_dim = experiment_featurizer.featurized_shape[1].item()
    for gamma in gammas[experiment]:
        configs.append({
            'env': experiment_env,
            'num_steps': 1_500_000,
            'imposter_model_args': {
                'layer_dims': [input_dim] + [256, 128, 64, 16] + [experiment_env.n_imposter_actions],
            },
            'crew_model_args': {'n_actions': experiment_env.n_crew_actions},
            'imposter_model_type': ModelType.MLP,
            'crew_model_type': ModelType.RANDOM,
            'featurizer': experiment_featurizer,
            'sequence_length': SEQUENCE_SIZE,
            'replay_buffer_size': BUF_SIZE,
            'replay_prepopulate_steps': 50_000,
            'batch_size': 8,
            'gamma': gamma,
            'scheduler_start_eps': 1,
            'scheduler_end_eps': 0.05,
            'scheduler_time_steps': 1_000_000,
            'train_imposter': True,
            'train_crew': False,
            'experiment_base_dir': one_on_one_path / experiment,
            'optimizer_type': OptimizerType.ADAM,
            'learning_rate': 0.001,
            'train_step_interval': 5,
            'num_checkpoint_saves': 5,
        })


In [None]:
# One config per line for a quick sanity check before launching runs.
print(*configs, sep='\n')

In [None]:
# total = len(configs)
# for i, config in enumerate(configs):
#     print(f'Running experiment {i+1}/{total}')
#     run_experiment(**config)

In [None]:
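# NOTE: `env` is not defined in this notebook; assign one (e.g. `env = envs['No Wall']`) before uncommenting this cell.
# imposter_model_config = {
#     # "layer_dims": [input_dim, 256,16, 16, 16, env.n_imposter_actions],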
#     'pretrained_model_path': './model_registry/1v1_imposter_no_walls/2024-04-24_11-19-34/imposter_mlp_100%.pt'
# }

# imposter_model = ModelType.build(ModelType.MLP, **imposter_model_config)
# crew_model = ModelType.build(ModelType.RANDOM, **{'n_actions': env.n_crew_actions})

# featurizer = FeaturizerType.build(FeaturizerType.FLAT, env)

# run_game(env, imposter_model, crew_model, featurizer, sequence_length=1)



In [None]:
# plot_experiment_metrics(one_on_one_path / 'one_hot_wall', label_attr='gamma', label_name="$\\gamma$")
# plot_experiment_metrics(one_on_one_path / 'one_hot_no_wall', label_attr='gamma', label_name="$\\gamma$")
# plot_experiment_metrics(one_on_one_path / 'wall_coord_features', label_attr='gamma', label_name="$\\gamma$")
# plot_experiment_metrics(one_on_one_path / 'no_wall_coord_features', label_attr='gamma', label_name="$\\gamma$")

In [None]:
import ipywidgets as widgets
from IPython.display import display

def setup_experiment_buttons(base_path, identifier_attribute):
    """Display one row of buttons per experiment directory under `base_path`.

    Each subdirectory of `base_path` is treated as an experiment. Each of its
    subdirectories that contains a `config.json` becomes a button labelled
    `<identifier_attribute>=<value from that config>` ('Unknown' if the key is
    missing). Clicking a button disables every button in the same row and logs
    a message to an Output widget displayed below the rows.

    Args:
        base_path: pathlib.Path whose subdirectories are experiments.
        identifier_attribute: Config key used to label the buttons.

    Returns:
        None. Widgets are shown via `display`; nothing is shown when
        `base_path` has no subdirectories.
    """
    # Sorted so the rows (and their indices) are stable across re-runs.
    experiment_dirs = sorted((p for p in base_path.iterdir() if p.is_dir()), key=lambda p: p.name)
    if not experiment_dirs:
        # Nothing to show; also avoids max() on an empty sequence below.
        return

    # Give every experiment label the same width so the button rows line up.
    max_exp_name_length = max(len(exp.name) for exp in experiment_dirs)
    min_width = max_exp_name_length * 8  # approximate pixels per character

    # Some frontends (e.g. JupyterLab) do not show print() output from widget
    # callbacks in the cell, so route callback output through an Output widget.
    log_output = widgets.Output()

    def button_callback(button, buttons_container):
        # Disable every button in the clicked row to prevent multiple runs.
        for b in buttons_container.children:
            b.disabled = True
        exp_idx = int(button.exp_idx)  # custom attribute set when the button was created
        with log_output:
            print(f"Running experiment configuration {button.description} of experiment {exp_idx}")

    for exp_idx, exp in enumerate(experiment_dirs):
        buttons = []
        for version_dir in sorted(exp.iterdir(), key=lambda x: x.name):
            config_path = version_dir / 'config.json'
            if not config_path.exists():
                continue

            with open(config_path, 'r') as file:
                config = json.load(file)

            button_label = f"{identifier_attribute}={config.get(identifier_attribute, 'Unknown')}"
            button = widgets.Button(description=button_label)
            button.exp_idx = exp_idx  # read back by button_callback
            buttons.append(button)

        if not buttons:
            continue

        exp_label = widgets.Label(value=f"Experiment: {exp.name}",
                                  layout=widgets.Layout(min_width=f'{min_width}px'))
        buttons_box = widgets.HBox(buttons)
        for button in buttons:
            # Bind this row's container now via a default argument; a plain closure
            # would look up `buttons_box` at click time and always see the last row.
            button.on_click(lambda b, box=buttons_box: button_callback(b, box))
        row = widgets.HBox([exp_label, buttons_box], layout=widgets.Layout(align_items='center'))
        display(row)

    display(log_output)

# Browse the saved 1v1 runs: one row per experiment, one button per run labelled by its gamma.
setup_experiment_buttons(one_on_one_path, identifier_attribute='gamma')
