In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper

In [2]:
class BCAgent(nn.Module):
    """Two-hidden-layer MLP policy used for behavioural cloning.

    The network maps an observation to one unnormalised logit per discrete
    action. Callers are expected to apply softmax / argmax / a categorical
    distribution themselves.

    Attributes:
        action_dim: Number of discrete actions (width of the output layer).
        activation: "relu" selects ReLU; any other value falls back to tanh.
    """

    # A single integer: it is passed straight to nn.Dense as the feature count.
    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return action logits for observation(s) ``x``."""
        activation = nn.relu if self.activation == "relu" else nn.tanh

        hidden = x
        # Two identical 64-unit hidden layers with orthogonal init.
        for _ in range(2):
            hidden = nn.Dense(
                64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(hidden)
            hidden = activation(hidden)

        # Small-gain init on the output layer keeps the initial logits close to zero.
        logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(hidden)
        return logits
    

class Transition(NamedTuple):
    """One environment step recorded during the policy evaluation rollout."""
    done: jnp.ndarray  # termination flag returned by env.step
    action: jnp.ndarray  # action that was passed to env.step
    value: jnp.ndarray  # make_train stores the constant -1 here (no value estimate is computed)
    reward: jnp.ndarray  # reward returned by env.step
    log_prob: jnp.ndarray  # make_train stores log_softmax of the policy output here
    obs: jnp.ndarray  # observation the action was selected from
    info: jnp.ndarray  # info returned by env.step; make_train returns this field as its metrics
        

# Behavioural Cloning Training Loop (WIP)

In [91]:
def make_train(config):
    """Build the behavioural-cloning train-then-evaluate function.

    The returned ``train`` callable fits a freshly initialised ``BCAgent`` to a
    dataset with a cross-entropy loss and then rolls the resulting policy out
    in the gymnax environment named by ``config["ENV_NAME"]``.

    Args:
        config: Dict of hyperparameters. Keys read here: "TOTAL_TIMESTEPS",
            "NUM_STEPS", "NUM_ENVS", "ENV_NAME", "NUM_MINIBATCHES",
            "UPDATE_EPOCHS", "LR", "ANNEAL_LR", "MAX_GRAD_NORM", "ACTIVATION",
            "BC_EPOCHS" and "GREEDY_ACT". The key "NUM_UPDATES" is written
            into the dict as a side effect.

    Returns:
        ``train(synth_data, action_labels, rng)``, which returns a dict with
        "runner_state" (final evaluation carry) and "metrics" (the ``info``
        field of every evaluation step, stacked by ``jax.lax.scan``).
    """
    # Only linear_schedule reads NUM_UPDATES.
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    # Minibatching is not implemented: every BC step uses the whole dataset.
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        """Learning rate as a function of the optimizer step ``count``."""
        # NOTE(review): the decay horizon uses NUM_MINIBATCHES, UPDATE_EPOCHS and
        # NUM_UPDATES rather than BC_EPOCHS, although the optimizer only takes
        # BC_EPOCHS steps. Confirm whether annealing over BC_EPOCHS is intended.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(synth_data, action_labels, rng):
        """Fit a policy to ``(synth_data, action_labels)`` and evaluate it.

        Args:
            synth_data: Dataset inputs, indexed along axis 0 in step with
                ``action_labels``.
            action_labels: Integer action label for each dataset row.
            rng: PRNG key for network init, shuffling and the evaluation rollout.

        Returns:
            Dict with "runner_state" and "metrics".
        """
        n_actions = env.action_space(env_params).n

        # 1. INIT NETWORK
        network = BCAgent(n_actions, activation=config["ACTIVATION"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)

        # Optimizer: global-norm gradient clipping followed by Adam.
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )

        # Train state carries everything needed for NN training
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # 2. BC TRAIN LOOP
        def _bc_train(train_state, _rng):
            """Run BC_EPOCHS full-batch gradient steps; return final state, losses, accuracies."""

            def _bc_update_step(bc_state, unused):
                train_state, rng = bc_state

                def _loss_and_acc(params, step_data, apply_fn, y_true, num_classes):
                    """Compute mean cross-entropy loss and accuracy."""
                    y_pred = apply_fn(params, step_data)

                    acc = jnp.mean(jnp.argmax(y_pred, axis=-1) == y_true)
                    labels = jax.nn.one_hot(y_true, num_classes)
                    loss = -jnp.sum(labels * jax.nn.log_softmax(y_pred))
                    loss /= labels.shape[0]
                    return loss, acc

                grad_fn = jax.value_and_grad(_loss_and_acc, has_aux=True)

                # Shuffling has no effect on a full-batch mean loss; it is kept
                # so that minibatching can be added later.
                rng, perm_rng = jax.random.split(rng)
                perm = jax.random.permutation(perm_rng, len(action_labels))
                step_data = synth_data[perm]
                y_true = action_labels[perm]

                loss_and_acc, grads = grad_fn(
                    train_state.params, step_data, train_state.apply_fn, y_true, n_actions
                )
                train_state = train_state.apply_gradients(grads=grads)

                return (train_state, rng), loss_and_acc

            bc_state = (train_state, _rng)
            # scan stacks every leaf of the per-step output along a new leading
            # axis, so loss and acc each have one entry per BC epoch.
            bc_state, (loss, acc) = jax.lax.scan(
                _bc_update_step, bc_state, None, config["BC_EPOCHS"]
            )
            return bc_state, loss, acc

        rng, _rng = jax.random.split(rng)
        bc_state, _bc_loss, _bc_acc = _bc_train(train_state, _rng)
        train_state = bc_state[0]

        # INIT ENV (shouldn't need multiple eval environments, but no harm in keeping it)
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # 3. POLICY EVAL LOOP
        def _eval_ep(runner_state):
            """Roll the trained policy out for NUM_STEPS steps and collect ``info``."""

            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # Select action
                rng, _rng = jax.random.split(rng)
                logits = train_state.apply_fn(train_state.params, last_obs)
                if config["GREEDY_ACT"]:
                    # If 2+ actions tie, argmax returns the first of them.
                    action = jnp.argmax(logits, axis=-1)
                else:
                    # Sample from the categorical distribution defined by the logits.
                    action = distrax.Categorical(logits=logits).sample(seed=_rng)

                # Step env
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, action, env_params
                )
                # There is no value head, so -1 is stored as a placeholder value;
                # log_prob holds the full log-softmax vector, not just the
                # entry for the chosen action.
                transition = Transition(
                    done, action, -1, reward, jax.nn.log_softmax(logits), last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )
            metric = traj_batch.info
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)

        # A single evaluation rollout is enough here; an outer scan would only
        # be needed for an alternating (train, eval, train, eval, ...) loop.
        runner_state, metric = _eval_ep(runner_state)

        return {"runner_state": runner_state, "metrics": metric}

    return train

# Meta-learning the Dataset

In [92]:
from evosax import OpenES, ParameterReshaper
from evosax.problems import VisionFitness

In [98]:
# Hyperparameters consumed by make_train.
# The PPO-only settings (GAMMA, GAE_LAMBDA, CLIP_EPS, ENT_COEF, VF_COEF) are
# intentionally absent: the behavioural-cloning loop does not read them.
config = {
    "LR": 5e-3,                     # Adam learning rate (initial value when annealing)
    "NUM_ENVS": 1,                  # parallel evaluation environments (was 64)
    "NUM_STEPS": 128,               # environment steps in the evaluation rollout
    "TOTAL_TIMESTEPS": 1e7,         # only used to derive NUM_UPDATES
    "UPDATE_EPOCHS": 4,             # only read by the LR schedule
    "NUM_MINIBATCHES": 8,           # only read by the LR schedule
    "MAX_GRAD_NORM": 0.5,           # global-norm gradient clipping threshold
    "ACTIVATION": "relu",           # hidden-layer activation of BCAgent
    "ENV_NAME": "Breakout-MinAtar", # gymnax environment id
    "ANNEAL_LR": True,              # use the linear LR schedule instead of a constant LR
    "GREEDY_ACT": True,             # argmax actions at evaluation time instead of sampling
    "BC_EPOCHS": 2,                 # number of full-batch BC gradient steps
}

In [102]:
# Build the environment with the same wrappers make_train applies, so the
# action count and observation shape used below match what training sees.
env, env_params = gymnax.make(config["ENV_NAME"])
env = LogWrapper(FlattenObservationWrapper(env))

n_actions = env.action_space(env_params).n
num_rollouts = 16  # number of PRNG keys (seeds) each candidate dataset is evaluated with

es_config = {
    "popsize": 10,                     # candidate datasets per generation
    "dataset_size": n_actions * 5,     # five dataset rows per action
}

# Placeholder with the shape of one candidate dataset; ParameterReshaper uses
# it to map between flat parameter vectors and this shape.
params = jnp.zeros(
    (es_config["dataset_size"],) + tuple(env.observation_space(env_params).shape)
)
param_reshaper = ParameterReshaper(params)

ParameterReshaper: 6000 parameters detected for optimization.


### Setup

In [107]:
# Initialize OpenES Strategy
# A fixed seed makes the whole meta-optimisation run reproducible.
rng = jax.random.PRNGKey(0)
rng, rng_init = jax.random.split(rng)

# One search dimension per entry of the flattened dataset; fitness is maximised.
strategy = OpenES(
    popsize=es_config["popsize"],
    num_dims=param_reshaper.total_params,
    opt_name="adam",
    maximize=True,
)
state = strategy.initialize(rng_init)

def get_action_labels(d_size, n_actions):
    """Return ``d_size`` sorted integer action labels cycling over ``n_actions``.

    Labels are ``i % n_actions`` for ``i`` in ``range(d_size)``, sorted
    ascending, so each action receives an (almost) equal share of the rows
    and rows with the same label are contiguous.

    Args:
        d_size: Number of labels to produce.
        n_actions: Number of distinct actions.

    Returns:
        1-D int32 array of length ``d_size`` (int32 even when ``d_size`` is 0).
    """
    # Built on-device from arange; the explicit dtype keeps the empty case integral.
    return jnp.sort(jnp.arange(d_size, dtype=jnp.int32) % n_actions)


# Set up vectorized fitness function
# The labels are fixed; only the dataset inputs are optimised by the ES.
train_fn = make_train(config)
action_labels = get_action_labels(es_config["dataset_size"], n_actions)


def single_seed_BC(rng_input, dataset):
    """Train on ``dataset`` with the fixed labels and evaluate, using one PRNG key."""
    return train_fn(dataset, action_labels, rng_input)


# Inner vmap: many seeds, one shared dataset.
multi_seed_BC = jax.vmap(single_seed_BC, in_axes=(0, None))
# Outer vmap: shared seeds, many datasets. Outputs get a leading dataset axis
# followed by a seed axis.
train_and_eval = jax.jit(jax.vmap(multi_seed_BC, in_axes=(None, 0)))


### Run OpenES loop

In [104]:
# Wrap ask/tell once instead of re-creating the jit wrappers every generation.
es_ask = jax.jit(strategy.ask)
es_tell = jax.jit(strategy.tell)

for gen in range(10):
    # Gen new dataset
    rng, rng_ask, rng_inner = jax.random.split(rng, 3)
    datasets, state = es_ask(rng_ask, state)

    # Eval fitness: every candidate dataset is trained and evaluated on the same seeds.
    # (Wrap the call below in `with jax.disable_jit():` only when debugging;
    # leaving it on bypasses the jit applied to train_and_eval.)
    batch_rng = jax.random.split(rng_inner, num_rollouts)
    shaped_datasets = param_reshaper.reshape(datasets)
    out = train_and_eval(batch_rng, shaped_datasets)

    # train_and_eval returns a dict, so reduce the metric array itself. The
    # leading axis indexes population members (outer vmap); average over all
    # remaining axes so that tell() receives one fitness value per member.
    returns = out["metrics"]["returned_episode_returns"]
    fitness = returns.reshape(returns.shape[0], -1).mean(axis=1)

    # Update ES strategy with fitness info
    state = es_tell(datasets, fitness, state)
    print(f"Generation: {gen}, Fitness: {fitness.mean():.2f}, Best: {state.best_fitness:.2f}")

Generation: 0, Fitness: 0.18, Best: 0.33
Generation: 1, Fitness: 0.17, Best: 0.33
Generation: 2, Fitness: 0.17, Best: 0.33
Generation: 3, Fitness: 0.24, Best: 0.38
Generation: 4, Fitness: 0.34, Best: 0.53
Generation: 5, Fitness: 0.31, Best: 0.53
Generation: 6, Fitness: 0.24, Best: 0.53
Generation: 7, Fitness: 0.20, Best: 0.53
Generation: 8, Fitness: 0.18, Best: 0.53
Generation: 9, Fitness: 0.29, Best: 0.53
