In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax

from jux.env import JuxEnv, JuxEnvBatch
from jux.config import JuxBufferConfig, EnvConfig

from models import NaiveActorCritic
from preprocess import get_feature, batch_get_feature
from constants import *
from space import ObsSpace, ActionSpace
from utils import get_seeds

In [4]:
from importlib import reload

from agent import naive_bid_agent, random_factory_agent_batched
import models
import ppo
import preprocess as pp

# Integer seed from which the root JAX PRNG key is created (via jax.random.PRNGKey).
seed = 42

In [32]:
# Re-import the local modules so edits made on disk since the first import are
# picked up; everything below goes through the reloaded module objects
# (`models.`, `ppo.`, `pp.`) rather than names imported earlier.
reload(models)
reload(ppo)
reload(pp)

# Environment, buffer and PPO hyper-parameters.
env_config = EnvConfig()
buf_config = JuxBufferConfig(MAX_N_UNITS=1000)
ppo_config = ppo.PPOConfig(
    LR=1e-4,
    MAX_GRAD_NORM=0.5,
    N_ENVS=4,
    N_UPDATES=1000,
    N_EPISODES_PER_ENV=16,
    UPDATE_EPOCHS=4,
    NUM_MINIBATCHES=32,

    GAMMA=0.99,
    GAE_LAMBDA=0.96,
    CLIP_EPS=0.2,
    ENT_COEF=0.01,  # Entropy loss coefficient
    VF_COEF=0.5,  # Critic loss coefficient
)

# Root PRNG key. The subkey produced by this first split is not consumed; the
# split is kept so the `rng` stream seen by later cells is unchanged.
rng = jax.random.PRNGKey(seed)
rng, _rng = jax.random.split(rng)

batch_env = JuxEnvBatch(env_config, buf_config)
num_envs = ppo_config.N_ENVS

# Initialize network
rng, _rng = jax.random.split(rng)
# A JAX key must be consumed only once, so derive independent subkeys for the
# dummy environment seeds and for the parameter initialisation instead of
# passing the same `_rng` to both.
dummy_seed_key, init_key = jax.random.split(_rng)

# A single-environment batch is reset only to obtain a feature example that
# `network.init` can use as its input.
dummy_seeds = get_seeds(dummy_seed_key, (1,))
dummy_state = batch_env.reset(dummy_seeds)
feature = pp.batch_get_feature(dummy_state)
network = models.NaiveActorCritic(
    env_config=env_config,
    buf_config=buf_config,
)
network_params = network.init(init_key, feature)

# Optimizer - clip gradients by global norm, then apply Adam.
tx = optax.chain(
    optax.clip_by_global_norm(ppo_config.MAX_GRAD_NORM),
    optax.adam(ppo_config.LR, eps=1e-5),
)

train_state = TrainState.create(
    apply_fn=network.apply,
    params=network_params,
    tx=tx,
)

In [39]:
from jax.tree_util import tree_map, tree_reduce

# Total number of scalar parameters: add up the element count of every leaf
# array in the parameter pytree.
sum(leaf.size for leaf in jax.tree_util.tree_leaves(network_params))

1120385

In [41]:
# Advance the carried key and hand a fresh subkey, together with the train
# state, to ppo.UpdateState.
rng, _rng = jax.random.split(rng)
update_state = ppo.UpdateState(train_state, _rng)

### UPDATE_STEP

In [43]:
# Initialize env_state
# A fresh subkey is passed to get_seeds with shape (num_envs,); the result is
# used to reset the batched environment.
# NOTE(review): presumably get_seeds returns one integer seed per environment — confirm in utils.
rng, _rng = jax.random.split(rng)
seeds = get_seeds(_rng, (num_envs,))
env_state = batch_env.reset(seeds)

In [45]:
# Bidding step
# Split a fresh subkey into num_envs keys so that the vmapped bid agent gets
# its own key for each entry along the mapped axis.
rng, _rng = jax.random.split(rng)
_rng = jax.random.split(_rng, num=num_envs)
bid, faction = jax.vmap(naive_bid_agent)(env_state, _rng)
# Only the first value returned by step_bid (the new state) is kept.
env_state, _ = batch_env.step_bid(env_state, bid, faction)

In [None]:
# Factory placement step
# Loop count is taken from index 0 of `factories_per_team`.
# NOTE(review): this assumes the value at index 0 is valid for every
# environment in the batch — confirm, since each env is reset from its own seed.
n_factories = env_state.board.factories_per_team[0].astype(jnp.int32)

def _factory_placement_step(i, env_state_rng):
    """Run one factory-placement round on the batched environment.

    Written as a `jax.lax.fori_loop` body.

    Args:
        i: Loop index supplied by `fori_loop`; unused.
        env_state_rng: Carry tuple `(env_state, rng)`.

    Returns:
        The updated carry `(env_state, rng)`, where `rng` has been advanced
        by one split.
    """
    state, key = env_state_rng
    key, agent_key = jax.random.split(key)
    # One key per environment for the vmapped agent.
    agent_keys = jax.random.split(agent_key, num=num_envs)
    # NOTE(review): `random_factory_agent` is not among the names imported in
    # the visible cells (only `random_factory_agent_batched` is) — confirm it
    # is defined, or switch to the batched agent.
    factory_placement = jax.vmap(random_factory_agent)(state, agent_keys)
    state, _ = batch_env.step_factory_placement(state, *factory_placement)
    return state, key

rng, _rng = jax.random.split(rng)
# Seed the loop with the fresh subkey so the carried `rng` is not consumed both
# inside the loop and by later cells, and unpack the returned carry so that
# `env_state` is the environment state rather than the (state, key) tuple.
env_state, _ = jax.lax.fori_loop(0, n_factories, _factory_placement_step, (env_state, _rng))
