Skip to content

Commit

Permalink
rm firezap, add nearby observation, add build principal
Browse files Browse the repository at this point in the history
  • Loading branch information
ezhang7423 committed Mar 13, 2024
1 parent d6847dc commit 567986c
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
}
},
"python.formatting.provider": "none",
Expand Down
28 changes: 11 additions & 17 deletions meltingpot/configs/substrates/commons_harvest__open.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,16 @@
]
}

# Primitive action components.
# Primitive action components (fireZap removed).
# pylint: disable=bad-whitespace
# pyformat: disable
NOOP = {"move": 0, "turn": 0, "fireZap": 0}
FORWARD = {"move": 1, "turn": 0, "fireZap": 0}
STEP_RIGHT = {"move": 2, "turn": 0, "fireZap": 0}
BACKWARD = {"move": 3, "turn": 0, "fireZap": 0}
STEP_LEFT = {"move": 4, "turn": 0, "fireZap": 0}
TURN_LEFT = {"move": 0, "turn": -1, "fireZap": 0}
TURN_RIGHT = {"move": 0, "turn": 1, "fireZap": 0}
FIRE_ZAP = {"move": 0, "turn": 0, "fireZap": 1}
NOOP = {"move": 0, "turn": 0}
FORWARD = {"move": 1, "turn": 0}
STEP_RIGHT = {"move": 2, "turn": 0}
BACKWARD = {"move": 3, "turn": 0}
STEP_LEFT = {"move": 4, "turn": 0}
TURN_LEFT = {"move": 0, "turn": -1}
TURN_RIGHT = {"move": 0, "turn": 1}
# pyformat: enable
# pylint: enable=bad-whitespace

Expand All @@ -269,7 +268,6 @@
STEP_RIGHT,
TURN_LEFT,
TURN_RIGHT,
FIRE_ZAP,
)

TARGET_SPRITE_SELF = {
Expand Down Expand Up @@ -473,11 +471,10 @@ def create_avatar_object(player_idx: int,
"speed": 1.0,
"spawnGroup": spawn_group,
"postInitialSpawnGroup": "spawnPoints",
"actionOrder": ["move", "turn", "fireZap"],
"actionOrder": ["move", "turn"],
"actionSpec": {
"move": {"default": 0, "min": 0, "max": len(_COMPASS)},
"turn": {"default": 0, "min": -1, "max": 1},
"fireZap": {"default": 0, "min": 0, "max": 1},
},
"view": {
"left": 5,
Expand All @@ -500,9 +497,6 @@ def create_avatar_object(player_idx: int,
"rewardForZapping": 0,
}
},
{
"component": "ReadyToShootObservation",
},
]
}
if _ENABLE_DEBUG_OBSERVATIONS:
Expand Down Expand Up @@ -540,7 +534,7 @@ def get_config():
# Observation format configuration.
config.individual_observation_names = [
"RGB",
"READY_TO_SHOOT",
"NEARBY"
]
config.global_observation_names = [
"WORLD.RGB",
Expand All @@ -550,9 +544,9 @@ def get_config():
config.action_spec = specs.action(len(ACTION_SET))
config.timestep_spec = specs.timestep({
"RGB": specs.OBSERVATION["RGB"],
"READY_TO_SHOOT": specs.OBSERVATION["READY_TO_SHOOT"],
# Debug only (do not use the following observations in policies).
"WORLD.RGB": specs.rgb(144, 192),
".NEARBY": specs.int32(10)
})

# The roles assigned to each player.
Expand Down
25 changes: 25 additions & 0 deletions meltingpot/lua/modules/avatar_library.lua
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,31 @@ function Avatar:addObservations(tileSet, world, observations)
end
}
observations[#observations + 1] = spec

do
  -- Per-avatar NEARBY observation: an Int32 vector with one slot per player,
  -- where slot `i` holds the number of objects carrying an 'Avatar' component
  -- with index `i` that were returned by this avatar's partial observation
  -- window query on the "upperPhysical" layer.
  --
  -- NOTE(review): assumes `self.gameObject.simulation` is already available
  -- when `addObservations` runs, so the player count can be read once here
  -- and reused for both the declared shape and the returned tensor -- confirm.
  -- NOTE(review): presumably the observing avatar itself is also returned by
  -- the window query and therefore counted in its own slot -- verify.
  local numPlayers = self.gameObject.simulation:getNumPlayers()

  observations[#observations + 1] = {
    name = stringId .. '.NEARBY',
    type = 'tensor.Int32Tensor',
    -- The declared shape must agree with the tensor returned by `func`
    -- (length `numPlayers`); an empty shape would advertise a scalar.
    shape = {numPlayers},
    func = function(grid)
      -- Start from all zeros so players outside the window report 0.
      local nearby = tensor.Int32Tensor(numPlayers):fill(0)
      local objectsOnLayer = self:queryPartialObservationWindow("upperPhysical")
      for _, object in ipairs(objectsOnLayer) do
        -- Only objects that carry an Avatar component contribute.
        if object:hasComponent('Avatar') then
          local index = object:getComponent('Avatar'):getIndex()
          nearby(index):add(1)
        end
      end
      return nearby
    end
  }
end

end

function Avatar:reset()
Expand Down
3 changes: 2 additions & 1 deletion meltingpot/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _scenarios_by_substrate() -> Mapping[str, Collection[str]]:
'STAMINA',
'VOTING',
# An extra observation that is never necessary but could perhaps help.
'COLLECTIVE_REWARD'
'COLLECTIVE_REWARD',
'NEARBY'
})


Expand Down
22 changes: 22 additions & 0 deletions meltingpot/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from meltingpot.utils.substrates import substrate_factory
from ml_collections import config_dict

from SocialEnvDesign import principal_substrate
from SocialEnvDesign.principal import Principal

SUBSTRATES = substrate_configs.SUBSTRATES


Expand Down Expand Up @@ -59,6 +62,25 @@ def build_from_config(
"""
return get_factory_from_config(config).build(roles)

def build_principal_from_config(
    config: config_dict.ConfigDict,
    *,
    roles: Sequence[str],
    principal: Principal
) -> principal_substrate.PrincipalSubstrate:
  """Builds a principal substrate from the provided config.

  Args:
    config: config resulting from `get_config`.
    roles: sequence of strings defining each player's role. The length of
      this sequence determines the number of players.
    principal: the principal, forwarded unchanged to the factory's
      `build_principal`.

  Returns:
    Whatever the factory's `build_principal` returns for these roles and
    this principal.
  """
  # Resolve the factory first, then delegate construction to it.
  factory = get_factory_from_config(config)
  return factory.build_principal(roles, principal)


def get_factory(name: str) -> substrate_factory.SubstrateFactory:
"""Returns the factory for the specified substrate."""
Expand Down
2 changes: 0 additions & 2 deletions meltingpot/utils/substrates/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
shape=(), dtype=np.float64, minimum=0, maximum=1, name='discount')
REWARD = dm_env.specs.Array(shape=(), dtype=np.float64, name='reward')
OBSERVATION = immutabledict.immutabledict({
'READY_TO_SHOOT': dm_env.specs.Array(
shape=(), dtype=np.float64, name='READY_TO_SHOOT'),
'RGB': dm_env.specs.Array(shape=(88, 88, 3), dtype=np.uint8, name='RGB'),
'POSITION': dm_env.specs.Array(shape=(2,), dtype=np.int32, name='POSITION'),
'ORIENTATION': dm_env.specs.Array(
Expand Down
20 changes: 20 additions & 0 deletions meltingpot/utils/substrates/substrate_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from meltingpot.utils.substrates import builder
from meltingpot.utils.substrates import substrate

from SocialEnvDesign import principal_substrate
from SocialEnvDesign.principal import Principal


class SubstrateFactory:
"""Factory for building specific substrates."""
Expand Down Expand Up @@ -93,3 +96,20 @@ def build(self, roles: Sequence[str]) -> substrate.Substrate:
individual_observations=self._individual_observations,
global_observations=self._global_observations,
action_table=self._action_table)

def build_principal(
    self,
    roles: Sequence[str],
    principal: Principal,
) -> principal_substrate.PrincipalSubstrate:
  """Builds the principal substrate.

  Args:
    roles: the role each player will take.
    principal: the principal, passed through to
      `principal_substrate.build_substrate`.

  Returns:
    The substrate constructed by `principal_substrate.build_substrate`.
  """
  # Lab2D settings are derived from the roles; everything else is taken from
  # the factory's stored configuration.
  settings = self._lab2d_settings_builder(roles)
  return principal_substrate.build_substrate(
      lab2d_settings=settings,
      individual_observations=self._individual_observations,
      global_observations=self._global_observations,
      action_table=self._action_table,
      principal=principal,
  )

0 comments on commit 567986c

Please sign in to comment.