Skip to content

Commit

Permalink
Made gym-microrts 0.5.0 work by downgrading SB3 to 1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleksii Kachaiev committed Dec 21, 2021
1 parent 6f04306 commit ff6b7d6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -11,7 +11,7 @@ Original code: [gym-microrts-paper](https://github.com/vwxyzjn/gym-microrts-pape
## Install

Prerequisites:
* Python 3.7+
* Python 3.7.1+
* Java 8.0+
* FFmpeg (for video capturing)

Expand Down
51 changes: 9 additions & 42 deletions ppo_gridnet_diverse_encode_decode_sb3.py
Expand Up @@ -8,48 +8,13 @@
from stable_baselines3.common.policies import ActorCriticPolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env import VecMonitor

import gym
import gym_microrts
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
from gym_microrts import microrts_ai

# code from here:
# https://github.com/vwxyzjn/gym-microrts-paper/blob/cf291b303c04e98be2f00acbbe6bbb2c23a8bac5/ppo_gridnet_diverse_encode_decode.py#L96
class VecMonitor(VecEnvWrapper):
    """Wrap a vectorized env and report per-episode statistics.

    A running reward sum and a step count are kept for every
    sub-environment. Whenever a sub-environment signals ``done``, a copy of
    its info dict gains an ``'episode'`` entry (``'r'``: accumulated reward,
    ``'l'``: number of steps, ``'t'``: seconds elapsed since this wrapper was
    constructed) and that sub-environment's counters start over from zero.
    """

    def __init__(self, venv):
        VecEnvWrapper.__init__(self, venv)
        # The per-env counters are only allocated in reset().
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()

    def reset(self):
        """Reset the wrapped env and zero both per-env counters."""
        first_obs = self.venv.reset()
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return first_obs

    def step_wait(self):
        """Collect the step result and attach stats for finished episodes."""
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1

        # Work on a fresh list so the container returned by the wrapped env
        # is not modified; only finished envs get their entry replaced.
        patched = list(infos[:])
        for idx, finished in enumerate(dones):
            if not finished:
                continue
            episode_info = infos[idx].copy()
            episode_info['episode'] = {
                'r': self.eprets[idx],
                'l': self.eplens[idx],
                't': round(time.time() - self.tstart, 6),
            }
            self.epcount += 1
            # Restart the counters for the episode that begins next.
            self.eprets[idx] = 0
            self.eplens[idx] = 0
            patched[idx] = episode_info
        return obs, rews, dones, patched


class CustomMicroRTSGridMode(MicroRTSGridModeVecEnv):

Expand All @@ -59,9 +24,9 @@ def __init__(self, *args, **kwargs):
kwargs['num_bot_envs'] = len(kwargs.get('ai2s', []))
super().__init__(*args, **kwargs)
self.num_cells = self.height*self.width
self.action_space = gym.spaces.MultiDiscrete(np.array([
[6, 4, 4, 4, 4, len(self.utt['unitTypes']), 7 * 7]
] * self.height * self.width).flatten())
# self.action_space = gym.spaces.MultiDiscrete(np.array([
# [6, 4, 4, 4, 4, len(self.utt['unitTypes']), 7 * 7]
# ] * self.height * self.width).flatten())
self.observation_space = gym.spaces.Dict({
"obs": self.observation_space,
"masks": gym.spaces.Box(
Expand All @@ -73,8 +38,7 @@ def __init__(self, *args, **kwargs):
})

def get_action_mask(self):
action_mask = np.array(self.vec_client.getMasks(0))
return action_mask[:,:,:,1:].reshape(self.num_envs, -1)
return super().get_action_mask().reshape(self.num_envs, -1)

def step_async(self, action):
action = action.reshape(self.num_envs, self.num_cells, -1)
Expand Down Expand Up @@ -218,6 +182,9 @@ def __init__(self, *args, **kwargs):
self.value_net = nn.Identity()

def _build_mlp_extractor(self) -> None:
# XXX(okachaiev): it would be nice if SB3 exposed a configuration option
# for the MlpExtractor class; then I wouldn't need to override this
# "internal" method of the policy class
self.mlp_extractor = MicroRTSExtractor(
input_channels=27,
# output_channels=self.action_space.nvec[1:].sum(),
Expand All @@ -243,7 +210,7 @@ def extract_features(self, obs):
# microrts_ai.workerRushAI,
# microrts_ai.coacAI,
],
map_path="maps/16x16/basesWorkers16x16.xml",
map_paths=["maps/16x16/basesWorkers16x16.xml"],
reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0])
)
envs = VecMonitor(envs)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,2 +1,2 @@
gym-microrts==0.4.3
stable_baselines3==1.3.0
gym-microrts==0.5.0
stable_baselines3==1.1.0

0 comments on commit ff6b7d6

Please sign in to comment.