From ff6b7d644b9a9e182684c8791c914983382589e4 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 20 Dec 2021 19:19:58 -0800 Subject: [PATCH] Made gym-microrts 0.5.0 work by downgrading SB3 to 1.1 --- README.md | 2 +- ppo_gridnet_diverse_encode_decode_sb3.py | 51 +++++------------------- requirements.txt | 4 +- 3 files changed, 12 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 29feb5f..4d965ab 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Original code: [gym-microrts-paper](https://github.com/vwxyzjn/gym-microrts-pape ## Install Prerequisites: -* Python 3.7+ +* Python 3.7.1+ * Java 8.0+ * FFmpeg (for video capturing) diff --git a/ppo_gridnet_diverse_encode_decode_sb3.py b/ppo_gridnet_diverse_encode_decode_sb3.py index 4a3f4eb..3aef546 100644 --- a/ppo_gridnet_diverse_encode_decode_sb3.py +++ b/ppo_gridnet_diverse_encode_decode_sb3.py @@ -8,48 +8,13 @@ from stable_baselines3.common.policies import ActorCriticPolicy, register_policy from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from stable_baselines3.common.utils import get_device -from stable_baselines3.common.vec_env import VecEnvWrapper +from stable_baselines3.common.vec_env import VecMonitor import gym import gym_microrts from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv from gym_microrts import microrts_ai -# code from here: -# https://github.com/vwxyzjn/gym-microrts-paper/blob/cf291b303c04e98be2f00acbbe6bbb2c23a8bac5/ppo_gridnet_diverse_encode_decode.py#L96 -class VecMonitor(VecEnvWrapper): - def __init__(self, venv): - VecEnvWrapper.__init__(self, venv) - self.eprets = None - self.eplens = None - self.epcount = 0 - self.tstart = time.time() - - def reset(self): - obs = self.venv.reset() - self.eprets = np.zeros(self.num_envs, 'f') - self.eplens = np.zeros(self.num_envs, 'i') - return obs - - def step_wait(self): - obs, rews, dones, infos = self.venv.step_wait() - self.eprets += rews - self.eplens += 1 - - newinfos = list(infos[:]) - for i in range(len(dones)): - if dones[i]: - info = infos[i].copy() - ret = self.eprets[i] - eplen = self.eplens[i] - epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} - info['episode'] = epinfo - self.epcount += 1 - self.eprets[i] = 0 - self.eplens[i] = 0 - newinfos[i] = info - return obs, rews, dones, newinfos - class CustomMicroRTSGridMode(MicroRTSGridModeVecEnv): @@ -59,9 +24,9 @@ def __init__(self, *args, **kwargs): kwargs['num_bot_envs'] = len(kwargs.get('ai2s', [])) super().__init__(*args, **kwargs) self.num_cells = self.height*self.width - self.action_space = gym.spaces.MultiDiscrete(np.array([ - [6, 4, 4, 4, 4, len(self.utt['unitTypes']), 7 * 7] - ] * self.height * self.width).flatten()) + # self.action_space = gym.spaces.MultiDiscrete(np.array([ + # [6, 4, 4, 4, 4, len(self.utt['unitTypes']), 7 * 7] + # ] * self.height * self.width).flatten()) self.observation_space = gym.spaces.Dict({ "obs": self.observation_space, "masks": gym.spaces.Box( @@ -73,8 +38,7 @@ def __init__(self, *args, **kwargs): }) def get_action_mask(self): - action_mask = np.array(self.vec_client.getMasks(0)) - return action_mask[:,:,:,1:].reshape(self.num_envs, -1) + return super().get_action_mask().reshape(self.num_envs, -1) def step_async(self, action): action = action.reshape(self.num_envs, self.num_cells, -1) @@ -218,6 +182,9 @@ def __init__(self, *args, **kwargs): self.value_net = nn.Identity() def _build_mlp_extractor(self) -> None: + # xxx(okachaiev): would be nice if SB3 provided configuration for + # MlpExtractor class. in this case I wouldn't need to reload + # "internal" function of the class self.mlp_extractor = MicroRTSExtractor( input_channels=27, # output_channels=self.action_space.nvec[1:].sum(), @@ -243,7 +210,7 @@ def extract_features(self, obs): # microrts_ai.workerRushAI, # microrts_ai.coacAI, ], - map_path="maps/16x16/basesWorkers16x16.xml", + map_paths=["maps/16x16/basesWorkers16x16.xml"], reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]) ) envs = VecMonitor(envs) diff --git a/requirements.txt b/requirements.txt index b5121ad..e128a59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -gym-microrts==0.4.3 -stable_baselines3==1.3.0 \ No newline at end of file +gym-microrts==0.5.0 +stable_baselines3==1.1.0 \ No newline at end of file