[codecraft] Load DCC policies and add ENHANCED task (#174)
cswinter committed Feb 21, 2022
1 parent c3b61a1 commit fb1a92b
Showing 18 changed files with 1,865 additions and 961 deletions.
enn_ppo/enn_ppo/train.py: 51 changes (29 additions, 22 deletions)
@@ -26,7 +26,7 @@
 import numpy.typing as npt
 from entity_gym.examples import ENV_REGISTRY
 from enn_zoo.griddly import GRIDDLY_ENVS, create_env
-from enn_zoo.codecraft.cc_vec_env import CodeCraftEnv, CodeCraftVecEnv
+from enn_zoo.codecraft.cc_vec_env import codecraft_env_class, CodeCraftVecEnv
 from enn_zoo.codecraft.codecraftnet.adapter import CCNetAdapter
 from entity_gym.serialization import SampleRecordingVecEnv
 from enn_ppo.simple_trace import Tracer
@@ -96,6 +96,7 @@ def parse_args(override_args: Optional[List[str]] = None) -> argparse.Namespace:
     parser.add_argument('--eval-capture-videos', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                         help='If --eval-render-videos is set, videos will be recorded of the environments during evaluation')
     parser.add_argument('--codecraft-eval', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True)
+    parser.add_argument('--codecraft-eval-opponent', type=str, default=None)
 
     # Network architecture
     parser.add_argument('--d-model', type=int, default=64,
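
For context, the new flag rides on the existing override mechanism visible in the parse_args signature above; a minimal usage sketch, not part of the commit (the opponent checkpoint path is a made-up placeholder):

    # Sketch only: exercise the new flag via parse_args(override_args=...).
    # The policy path below is hypothetical.
    args = parse_args([
        "--gym-id", "CodeCraft",
        "--codecraft-eval",
        "--codecraft-eval-opponent", "checkpoints/dcc-policy",
    ])
    assert args.codecraft_eval is True
    assert args.codecraft_eval_opponent == "checkpoints/dcc-policy"
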
@@ -482,6 +483,7 @@ def run_eval(
     global_step: int,
     capture_videos: bool = False,
     codecraft_eval: bool = False,
+    eval_opponent: Optional[str] = None,
 ) -> None:
     # TODO: metrics are biased towards short episodes
     eval_envs: VecEnv
Expand All @@ -497,19 +499,22 @@ def run_eval(
     else:
         eval_envs = EnvList(env_cls, args.eval_env_kwargs or env_kwargs, num_envs)
     if codecraft_eval:
-        random_agent = PPOActor(
-            obs_space,
-            dict(action_space),
-            d_model=args.d_model,
-            n_head=args.n_head,
-            n_layer=args.n_layer,
-            pooling_op=args.pooling_op,
-        ).to(device)
+        if eval_opponent is None:
+            opponent = PPOActor(
+                obs_space,
+                dict(action_space),
+                d_model=args.d_model,
+                n_head=args.n_head,
+                n_layer=args.n_layer,
+                pooling_op=args.pooling_op,
+            ).to(device)
+        else:
+            opponent = CCNetAdapter(device, load_from=eval_opponent)  # type: ignore
         agents: Union[PPOActor, List[Tuple[npt.NDArray[np.int64], PPOActor]]] = [
             (np.array([2 * i for i in range(num_envs // 2)]), agent),
             (
                 np.array([2 * i + 1 for i in range(num_envs // 2)]),
-                random_agent,
+                opponent,
             ),
         ]
     else:
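
The index arrays above split evaluation so the learner plays the even-numbered environments and the opponent (a freshly initialized PPOActor, or a DCC policy loaded through CCNetAdapter) plays the odd-numbered ones; a standalone sketch of the split, with an illustrative num_envs:

    # Standalone illustration of the even/odd env split, not part of the commit.
    import numpy as np

    num_envs = 8  # illustrative value
    learner_envs = np.array([2 * i for i in range(num_envs // 2)])       # [0 2 4 6]
    opponent_envs = np.array([2 * i + 1 for i in range(num_envs // 2)])  # [1 3 5 7]
    assert set(learner_envs) | set(opponent_envs) == set(range(num_envs))
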
@@ -610,23 +615,24 @@ def train(args: argparse.Namespace) -> float:
     device = torch.device("cuda" if cuda else "cpu")
     tracer = Tracer(cuda=cuda)
 
+    env_kwargs = json.loads(args.env_kwargs)
+    if args.eval_env_kwargs is not None:
+        eval_env_kwargs = json.loads(args.eval_env_kwargs)
+    else:
+        eval_env_kwargs = env_kwargs
+
     if args.gym_id in ENV_REGISTRY:
         env_cls = ENV_REGISTRY[args.gym_id]
     elif args.gym_id in GRIDDLY_ENVS:
         env_cls = create_env(**GRIDDLY_ENVS[args.gym_id])
     elif args.gym_id == "CodeCraft":
-        env_cls = CodeCraftEnv
+        env_cls = codecraft_env_class(env_kwargs.get("objective", "ALLIED_WEALTH"))
     else:
         raise KeyError(
             f"Unknown gym_id: {args.gym_id}\nAvailable environments: {list(ENV_REGISTRY.keys()) + list(GRIDDLY_ENVS.keys())}"
         )
 
     # env setup
-    env_kwargs = json.loads(args.env_kwargs)
-    if args.eval_env_kwargs is not None:
-        eval_env_kwargs = json.loads(args.eval_env_kwargs)
-    else:
-        eval_env_kwargs = env_kwargs
     envs: VecEnv
     if args.gym_id == "CodeCraft":
         envs = CodeCraftVecEnv(args.num_envs, **env_kwargs)
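
A plausible reading of the new codecraft_env_class call above is a class factory keyed on the objective string; the sketch below is hypothetical (the base class and attribute names are invented here, only the factory name and the objective strings come from the diff and commit title):

    # Hypothetical sketch, not the actual enn_zoo implementation.
    from typing import Type

    class CodeCraftEnvBase:
        """Stand-in for the real CodeCraft Environment class."""
        objective: str = "ALLIED_WEALTH"

    def codecraft_env_class(objective: str) -> Type[CodeCraftEnvBase]:
        # Specialize the env class so the objective drives task setup,
        # e.g. "ALLIED_WEALTH" or the ENHANCED task added in this commit.
        class CodeCraftObjectiveEnv(CodeCraftEnvBase):
            pass

        CodeCraftObjectiveEnv.objective = objective
        return CodeCraftObjectiveEnv
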
@@ -712,6 +718,7 @@ def _run_eval() -> None:
             rollout.global_step,
             args.eval_capture_videos,
             args.codecraft_eval,
+            args.codecraft_eval_opponent,
         )
 
     start_time = time.time()
@@ -858,7 +865,7 @@ def _run_eval() -> None:
                 ]
 
                 # TODO: not invariant to microbatch size, should be normalizing full batch or minibatch instead
-                mb_advantages = b_advantages[mb_inds]
+                mb_advantages = b_advantages[mb_inds]  # type: ignore
                 if args.norm_adv:
                     assert (
                         len(mb_advantages) > 1
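
The TODO above flags that per-microbatch advantage normalization is not invariant to microbatch size; a self-contained illustration of why (tensor sizes are arbitrary):

    # Not part of the commit: per-microbatch statistics differ from
    # full-batch statistics, so results depend on the microbatch size.
    import torch

    torch.manual_seed(0)
    b_advantages = torch.randn(1024)
    mb = b_advantages[:256]  # one microbatch

    per_microbatch = (mb - mb.mean()) / (mb.std() + 1e-8)
    full_batch = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)
    print(torch.allclose(per_microbatch, full_batch[:256]))  # False in general
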
@@ -910,17 +917,17 @@ def _run_eval() -> None:
                 with tracer.span("value_loss"):
                     newvalue = newvalue.view(-1)
                     if args.clip_vloss:
-                        v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
-                        v_clipped = b_values[mb_inds] + torch.clamp(
-                            newvalue - b_values[mb_inds],
+                        v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2  # type: ignore
+                        v_clipped = b_values[mb_inds] + torch.clamp(  # type: ignore
+                            newvalue - b_values[mb_inds],  # type: ignore
                             -args.clip_coef,
                             args.clip_coef,
                         )
-                        v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
+                        v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2  # type: ignore
                         v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                         v_loss = 0.5 * v_loss_max.mean()
                     else:
-                        v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
+                        v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()  # type: ignore
 
                 # TODO: what's correct way of combining entropy loss from multiple actions/actors on the same timestep?
                 if args.anneal_entropy:
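
For reference, the clip_vloss branch above is the standard PPO clipped value loss; with eps = args.clip_coef, V_old the rollout-time values (b_values), and R_hat the returns (b_returns), it computes

    v_loss = 0.5 * mean( max( (V_theta - R_hat)^2,
                              (V_old + clip(V_theta - V_old, -eps, eps) - R_hat)^2 ) )

which keeps the value update from moving predictions more than eps away from the values recorded during the rollout.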
