In [1]:
import numpy as np
import torch

from src.utils import flatten
from src.wrapper import RestrictionWrapper
from examples.agents.td3 import TD3
from examples.envs.navigation import NavigationEnvironment
from examples.restrictors.navigation_restrictor import NavigationRestrictor
from examples.utils import ReplayBuffer

In [2]:
# Reward-related settings, nested under the 'REWARD' key of the environment config.
reward_config = dict(
    TIMESTEP_PENALTY_COEFFICIENT=0.05,
    REWARD_COEFFICIENT=0.5,
    GOAL=5,
    COLLISION=-1.0,
)

# Map layout: overall size plus the agent's start pose and the goal region.
map_config = dict(
    HEIGHT=15.0,
    WIDTH=15.0,
    AGENT=dict(x=1.0, y=1.0, angle=90.0, step_size=1.0, radius=0.4),
    GOAL=dict(x=12.0, y=12.0, radius=0.5),
)

# Full configuration dictionary handed to NavigationEnvironment.
env_config = dict(
    STEPS_PER_EPISODE=40,
    ACTION_RANGE=220,
    DT=1.0,
    REWARD=reward_config,
    MAP=map_config,
)

# NOTE(review): NavigationRestrictor is called with positional arguments only, so
# their meaning cannot be read off from this notebook. The trailing -110.0 / 110.0
# presumably bound the action interval (they match ACTION_RANGE = 220) - confirm
# against examples.restrictors.navigation_restrictor before changing any value.
restrictor = NavigationRestrictor(
    0,
    [[5.0, 0.0], [0.0, 5.0]],
    1.0,
    0.2,
    0.5,
    50,
    8,
    -110.0,
    110.0,
)
environment = NavigationEnvironment(env_config)

# Wrap the plain environment together with the restrictor; the training loop
# interacts only with this wrapped environment.
restricted_environment = RestrictionWrapper(environment, restrictor)

In [3]:
# Hyperparameters for the TD3 agent. The whole dictionary is unpacked into the
# TD3 constructor, so every key must exactly match the keyword name that TD3
# expects. The training loop additionally reads 'max_action', 'action_dim',
# 'exploration_noise', 'batch_size' and 'train_after_timesteps' from it.
td3_config = {
    'state_dim': 22,
    'action_dim': 1,
    'max_action': 110.0,
    'discount': 0.99,
    'tau': 0.005,
    'policy_noise': 0.2,
    # Key must be exactly 'noise_clip'; a stray trailing colon inside the string
    # means the value is never delivered under the intended keyword name.
    'noise_clip': 0.5,
    'policy_freq': 2,
    'exploration_noise': 0.05,
    'batch_size': 256,
    'train_after_timesteps': 2000
}

# Total number of agent decisions to collect before the training loop stops
# starting new episodes.
total_timesteps = 5000
td3 = TD3(**td3_config)
# The buffer dimensions are taken from the same config so they cannot drift
# apart from the agent's dimensions.
replay_buffer = ReplayBuffer(state_dim=td3_config['state_dim'], action_dim=td3_config['action_dim'])

In [4]:
# Seed the global random number generators of PyTorch and NumPy (in that order)
# with the same value.
# NOTE(review): this cell runs after the TD3 agent has already been constructed
# in the previous cell, so any random initialisation performed there is not
# covered by this seed - consider seeding before constructing the agent.
seed = 42
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(seed)

In [7]:
# Training loop: run episodes in the restricted environment until at least
# `total_timesteps` agent decisions have been made. The learning agent
# ('agent_0') and the restrictor take turns; `agent_iter()` yields whichever
# one is due to act next.
episode_num = 0
training_timesteps = 0

try:
    while training_timesteps < total_timesteps:
        restricted_environment.reset()
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1
        # Flattened observation on which the agent based its previous action;
        # None until the agent has acted once in the current episode.
        observation = None
        action = None
        last_td3_action = None

        for agent in restricted_environment.agent_iter():
            next_observation, reward, termination, truncation, info = restricted_environment.last()
            episode_over = termination or truncation

            # Turn of the agent
            if agent == 'agent_0':
                episode_reward += reward

                flattened_next_observation = flatten(restricted_environment.observation_space('agent_0'),
                                                     next_observation, max_len=8, raise_error=False)

                # Store the transition produced by the previous action. This is
                # also done on the agent's final turn, so the last reward of the
                # episode and a true `termination` flag reach the replay buffer.
                # Truncation alone is stored with the flag unset.
                if observation is not None:
                    replay_buffer.add(observation,
                                      last_td3_action,
                                      flattened_next_observation,
                                      reward,
                                      termination)

                if episode_over:
                    # None action if episode is done
                    observation = None
                    action = None
                else:
                    episode_timesteps += 1
                    training_timesteps += 1
                    observation = flattened_next_observation

                    if training_timesteps < td3_config['train_after_timesteps']:
                        # Warm-up phase: sample from the agent's action space
                        # instead of querying the policy.
                        action = restricted_environment.action_space(agent).sample()
                    else:
                        # Policy action plus Gaussian exploration noise, clipped
                        # to the symmetric range [-max_action, max_action].
                        det_action = td3.select_action(observation)
                        noise = np.random.normal(0,
                                                 td3_config['max_action'] * td3_config['exploration_noise'],
                                                 size=td3_config['action_dim'])
                        action = (det_action + noise).clip(-td3_config['max_action'], td3_config['max_action'])
                        # One training call per agent decision after warm-up.
                        td3.train(replay_buffer, td3_config['batch_size'])

                    last_td3_action = action
            # None action if episode is done
            elif episode_over:
                action = None
            # Or restrictor
            else:
                action = restrictor.act(next_observation)

            restricted_environment.step(action)

        print(f'Finished episode {episode_num} with reward {episode_reward}')
finally:
    # Close the environment even if training is interrupted (for example by a
    # KeyboardInterrupt) or fails with an exception.
    restricted_environment.close()

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


KeyboardInterrupt: 