In [79]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper

In [80]:
class BCAgent(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        return pi
    

class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
        

In [237]:

pi = distrax.Categorical(logits=[0,1,1.0000001,0])

In [240]:
pi.num_categories

4

# Behavioural Cloning Training Loop (WIP)

In [4]:
def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
#     # Keep minibatch in case needed for large datasets
#     config["MINIBATCH_SIZE"] = (
#         config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
#     )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    # Do I need a schedule on the LR for BC?
    def linear_schedule(count):
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(syth_data, rng):
        # 1. INIT NETWORK
        network = BCAgent(env.action_space(env_params).n, activation=config["ACTIVATION"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        
        # Setup optimizer
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
        
        # Train state carries everything needed for NN training
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )
        
        # 2. BC TRAIN LOOP
        def _bc_train(train_state, bc_rng):
            
            # shuffle dataset
            # get predictions
            # compute XENT loss
            # Get loss on data
            # 
            
        ...
        
        
        # INIT ENV (shouldn't need multiple eval environments, but no harm in keeping it)
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
        
        # 3. POLICY EVAL LOOP
        def _eval_ep(runner_state):
            # Environment stepper
            def _env_step(runner_state):
                train_state, env_state, last_obs, rng = runner_state
                
                # Select Action
                rng, _rng = jax.random.split(rng)
                pi = network.apply(train_state.params, last_obs)
                # TODO: CONSIDER ARGMAX INSTEAD OF SAMPLING
                if config["GREEDY_ACT"]:
                    action = pi.mode()  # if 2+ actions are equiprobable, returns first
                else:
                    action = pi.sample(seed=_rng)
#                 log_prob = pi.log_prob(action)

                # Step env
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0,None))(
                    rng_step, env_state, action, env_params
                )
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition
                
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )
            
            metric = traj_batch.info
            
            return runner_state, metric
        
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        
        runner_state, metric = _eval_ep(runner_state)
        
        # Why am I using scan here and not VMAP over a number of envs?
        # Actually, scan only makes sense if we are are having a (train, eval, train, eval...) loop.
#         runner_state, metric = jax.lax.scan(
#             _update_step, runner_state, None, config["NUM_UPDATES"]
#         )

        return {"runner_state": runner_state, "metrics": metric}

    return train

SyntaxError: invalid syntax (2084863037.py, line 42)

# Meta-learning the Dataset

In [34]:
from evosax import OpenES, ParameterReshaper
from evosax.problems import VisionFitness

In [25]:
config = {
    "LR": 5e-3,
    "NUM_ENVS": 1, #64,
    "NUM_STEPS": 128,
    "TOTAL_TIMESTEPS": 1e7,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 8,
#     "GAMMA": 0.99,
#     "GAE_LAMBDA": 0.95,
#     "CLIP_EPS": 0.2,
#     "ENT_COEF": 0.01,
#     "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "relu",
    "ENV_NAME": "Breakout-MinAtar",
    "ANNEAL_LR": True,
    "GREEDY_ACT": True,
}

In [98]:
env, env_params = gymnax.make(config["ENV_NAME"])
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

n_actions = env.action_space(env_params).n

es_config = {
    "popsize" : 100,
    "dataset_size" : n_actions * 5,
}

params = jnp.zeros((es_config["dataset_size"], *env.observation_space(env_params).shape))
param_reshaper = ParameterReshaper(params)

ParameterReshaper: 6000 parameters detected for optimization.


### Setup

In [161]:
# Initialize OpenES Strategy
rng = jax.random.PRNGKey(0)
rng, rng_init = jax.random.split(rng)

strategy = OpenES(popsize=es_config["popsize"], num_dims=param_reshaper.total_params, opt_name="adam", maximize=False)
state = strategy.initialize(rng_init)

# Set up vectorized fitness function
train_fn = make_train(config)
def single_seed_BC(rng_input, dataset):
    out = train_fn(dataset, rng_input)
    return out["metrics"]['returned_episode_returns'].mean()

multi_seed_BC = jax.vmap(single_seed_BC, in_axes=(0, None))    # Vectorize over seeds
train_and_eval = jax.jit(jax.vmap(multi_seed_BC, in_axes=(None, 0)))  # Vectorize over datasets


### Run OpenES loop

In [163]:
for gen in range(100):
    # Gen new dataset
    rng, rng_ask, rng_inner = jax.random.split(rng, 3)
    datasets, state = jax.jit(strategy.ask)(rng_ask, state)
    if gen == 0:
        print(data.mean())
    
    # Eval fitness [PLACEHOLDER. TODO: REPLACE WITH BC LOOP]
    batch_rng = jax.random.split(rng_inner, num_rollouts)
    fitness = train_and_eval(batch_rng, datasets).mean(axis=1)
    
    # Update ES strategy with fitness info
    state = jax.jit(strategy.tell)(data, fitness, state)
    print(f"Generation: {gen}, Fitness: {fitness.mean():.2f}, Best: {state.best_fitness:.2f}")

0.14906682
Generation: 0, Fitness: 0.86, Best: 0.86
Generation: 1, Fitness: 0.86, Best: 0.86
Generation: 2, Fitness: 0.85, Best: 0.85
Generation: 3, Fitness: 0.85, Best: 0.85
Generation: 4, Fitness: 0.85, Best: 0.84
Generation: 5, Fitness: 0.84, Best: 0.84
Generation: 6, Fitness: 0.84, Best: 0.84
Generation: 7, Fitness: 0.83, Best: 0.83
Generation: 8, Fitness: 0.83, Best: 0.83
Generation: 9, Fitness: 0.83, Best: 0.82
Generation: 10, Fitness: 0.82, Best: 0.82
Generation: 11, Fitness: 0.82, Best: 0.82
Generation: 12, Fitness: 0.82, Best: 0.81
Generation: 13, Fitness: 0.81, Best: 0.81
Generation: 14, Fitness: 0.81, Best: 0.81
Generation: 15, Fitness: 0.80, Best: 0.80
Generation: 16, Fitness: 0.80, Best: 0.80
Generation: 17, Fitness: 0.80, Best: 0.79
Generation: 18, Fitness: 0.79, Best: 0.79
Generation: 19, Fitness: 0.79, Best: 0.79
Generation: 20, Fitness: 0.79, Best: 0.78
Generation: 21, Fitness: 0.78, Best: 0.78
Generation: 22, Fitness: 0.78, Best: 0.78
Generation: 23, Fitness: 0.78, Be

In [160]:
state.mean

Array([1.0265108, 0.9039204, 0.8840454, ..., 1.2643888, 0.9027829,
       0.8671209], dtype=float32)