In [1]:
import os
# Number of host (CPU) devices XLA is asked to expose; the device list shown
# below is expected to have this many entries.
NUM_DEVICES = 20
# The flag is written to the environment before `import jax` so that it is in
# place when JAX starts its backend (presumably required for the flag to take
# effect -- confirm against the JAX docs).
# NOTE(review): this assignment replaces any XLA_FLAGS value already present.
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={NUM_DEVICES}'

import jax
# Show the devices JAX sees, as a check that the flag above was honored.
jax.devices()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11),
 CpuDevice(id=12),
 CpuDevice(id=13),
 CpuDevice(id=14),
 CpuDevice(id=15),
 CpuDevice(id=16),
 CpuDevice(id=17),
 CpuDevice(id=18),
 CpuDevice(id=19)]

In [2]:
# from finite.rcmdp import RCMDP
# from finite.garnet import create_rcmdp, compute_policy_worst_values, S, A, N, DISCOUNT, ITER_LENGTH, NUM_SEEDS, FIGNAME

from KL.rcmdp import RCMDP
from KL.garnet import create_rcmdp, compute_policy_worst_values, S, A, N, DISCOUNT, ITER_LENGTH, NUM_SEEDS, FIGNAME

# The seeds are mapped over devices with jax.pmap further down, so there must be
# at least as many devices as seeds; having exactly as many is fine.
assert NUM_DEVICES >= NUM_SEEDS, "NUM_DEVICES must be greater than or equal to NUM_SEEDS"

KeyboardInterrupt: 

In [3]:
import jax.numpy as jnp

@jax.jit
def projection_to_simplex(y):
    """Euclidean projection of a vector onto the probability simplex.

    Sort-based procedure, see: https://arxiv.org/pdf/1309.1541

    Args:
        y (jnp.ndarray): (A)-vector

    Returns:
        x (jnp.ndarray): (A)-vector, the projection of y
    """
    dim = len(y)
    ranks = jnp.arange(dim) + 1
    descending = jnp.sort(y)[::-1]
    prefix_sums = jnp.cumsum(descending)

    # Flag every rank whose candidate shift keeps the sorted entry positive.
    is_active = (descending + (1 - prefix_sums) / ranks) > 0
    # The running count of flags stops growing after the last True flag, and
    # argmax returns the first position holding the maximum, i.e. that flag.
    last_active = jnp.argmax(jnp.cumsum(is_active))

    shift = (1 - prefix_sums[last_active]) / (last_active + 1)
    return jnp.maximum(y + shift, 0)


# Batched version: applies the projection independently along the leading axis.
proj_to_Pi = jax.vmap(projection_to_simplex)

In [4]:
from functools import partial
from tqdm import tqdm
import chex


# Initial state of the Lagrangian baseline. Every array has a leading axis of
# length NUM_SEEDS so that it can be split across devices, one seed each.
policy = jnp.ones((NUM_SEEDS, S, A)) / A  # uniform policy for every seed
sum_policy = jnp.zeros((NUM_SEEDS, S, A))  # running sum of the policy iterates
lam = jnp.zeros((NUM_SEEDS, N))  # Lagrange multipliers, one per constraint

# Per-iteration histories (filled in place by index during the outer loop).
res_J0U_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))
vios_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))
res_J0U_avg_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))
vios_avg_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))

# Order matters: update_outer_Lagrange unpacks the tuple positionally.
InitLagArgs = res_J0U_list, vios_list, res_J0U_avg_list, vios_avg_list, policy, sum_policy, lam

@jax.jit
def solve_inner_Lagrange(lam: jnp.ndarray, rcmdp: RCMDP, init_policy: jnp.ndarray, num_iter: int, lr: float, tol: float = 1e-5):
    """Apply policy gradients to the inner minimization problem of the Lagrangian formulation.
    See Algorithm 3 in the paper.

    Args:
        lam (jnp.ndarray): Lagrangian variables, an (N)-vector (one entry per constraint)
        rcmdp (RCMDP)
        init_policy (jnp.ndarray): Initial policy
        num_iter (int): Maximum number of policy updates
        lr (float): learning rate to update policy
        tol (float): stop once the summed absolute change of the policy in one
            update is not larger than this value

    Returns:
        policy (jnp.ndarray): (SxA) array. The evaluated iterate with the smallest
            Lagrangian value. The policy produced by the very last update is not
            evaluated and therefore can never be returned.
    """
    chex.assert_shape(lam, (N, ))
    # Weight 1 for the objective followed by the multipliers of the N constraints.
    one_lam = jnp.hstack([jnp.array([1,]), lam])

    def condition_fn(loop_args):
        # break iteration if the policy does not change or k >= num_iter
        k, _, _, _, policy_diff = loop_args
        return (k < num_iter) & (policy_diff > tol)

    def loop_fn(loop_args):
        k, policy, best_policy, best_L_lam, _ = loop_args

        # evaluate current policy
        worst_P_Q, worst_P_occ, worst_P_J = compute_policy_worst_values(policy, rcmdp)
        L_lam = (one_lam.reshape(N+1) * worst_P_J.reshape(N+1)).sum()
        # keep the iterate with the smallest Lagrangian value seen so far
        best_policy = jax.lax.cond(L_lam < best_L_lam, lambda: policy, lambda: best_policy)
        best_L_lam = jnp.minimum(L_lam, best_L_lam)

        # multiplier-weighted sum over objective and constraints of (occupancy * Q)
        grad = jnp.sum(one_lam.reshape(N+1, 1, 1) * worst_P_occ.reshape(N+1, S, 1) * worst_P_Q, axis=0)
        # gradient step followed by a row-wise projection back onto the simplex
        new_policy = proj_to_Pi(policy - lr * grad)

        policy_diff = jnp.abs(new_policy - policy).sum()
        return k+1, new_policy, best_policy, best_L_lam, policy_diff

    best_policy = init_policy
    best_L_lam = jnp.inf
    # the initial policy_diff is inf so that the first iteration always runs when num_iter > 0
    _, _, best_policy, _, _ = jax.lax.while_loop(condition_fn, loop_fn, (0, init_policy, best_policy, best_L_lam, jnp.inf))
    return best_policy


@jax.jit
def update_outer_Lagrange(init_args, init_k: int, end_k: int, rcmdp: RCMDP, lam_lr: float = 0.02, inner_iter: int=1000, inner_lr: float = 0.001):
    """Run outer iterations init_k, ..., end_k - 1 of the Lagrangian method.
    See Algorithm 3 in the paper.

    Args:
        init_args: Loop state handed to jax's fori_loop. Same layout as InitLagArgs.
        init_k (int): first outer iteration index (inclusive)
        end_k (int): last outer iteration index (exclusive)
        rcmdp (RCMDP)
        lam_lr (float): step size of the multiplier update
        inner_iter (int): iteration budget passed to solve_inner_Lagrange
        inner_lr (float): learning rate passed to solve_inner_Lagrange

    Returns:
        args: Updated loop state. Same layout as InitLagArgs.
    """

    def _evaluate(pi):
        # worst-case returns of `pi` and the violation of every constraint threshold
        robust_returns = compute_policy_worst_values(pi, rcmdp)[2]
        violation = robust_returns[1:] - rcmdp.threshes
        chex.assert_shape(violation, (N, ))
        return robust_returns, violation

    def _step(k, carry):
        (J0_hist, vio_hist, J0_avg_hist, vio_avg_hist,
         prev_policy, policy_total, multipliers) = carry

        # inner problem, warm-started from the previous outer iterate
        current = solve_inner_Lagrange(multipliers, rcmdp, prev_policy, inner_iter, inner_lr)
        cur_returns, cur_vio = _evaluate(current)

        # record the performance of the current iterate
        J0_hist = J0_hist.at[k].set(cur_returns[0])
        vio_hist = vio_hist.at[k].set(cur_vio.max())

        # multiplier step, clipped at zero
        multipliers = jnp.maximum(multipliers + lam_lr * cur_vio, 0)

        # record the performance of the average of all iterates so far
        policy_total = policy_total + current
        avg_returns, avg_vio = _evaluate(policy_total / (k + 1))
        J0_avg_hist = J0_avg_hist.at[k].set(avg_returns[0])
        vio_avg_hist = vio_avg_hist.at[k].set(avg_vio.max())

        return (J0_hist, vio_hist, J0_avg_hist, vio_avg_hist,
                current, policy_total, multipliers)

    return jax.lax.fori_loop(init_k, end_k, _step, init_args)


In [5]:
from functools import partial
from tqdm import tqdm


# Initial state of the epigraph-form method; the leading axis indexes the seed.
# NOTE: `policy`, `res_J0U_list` and `vios_list` rebind names that were also used
# to build InitLagArgs; that tuple keeps referring to the earlier arrays.
policy = jnp.ones((NUM_SEEDS, S, A)) / A
# `i` and `j` are the two ends of the interval that update_outer_EF bisects
# (it uses b0 = (i + j) / 2 and then moves one of the ends to b0).
# NOTE(review): the driver loop further down uses `i` as its loop variable, which
# rebinds this module-level name after InitEFArgs has been built.
i = jnp.zeros((NUM_SEEDS))
j = jnp.ones((NUM_SEEDS)) * 1 / (1 - DISCOUNT)
# Per-iteration histories, written by index in update_outer_EF.
res_J0U_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))
vios_list = jnp.zeros((NUM_SEEDS, ITER_LENGTH))
# Order matters: update_outer_EF unpacks the tuple positionally.
InitEFArgs = res_J0U_list, vios_list, policy, i, j


@jax.jit
def solve_inner_EF(b0: float, rcmdp: RCMDP, init_policy: jnp.ndarray, num_iter: int, lr: float, tol: float = 1e-5):
    """Apply policy gradients to the auxiliary minimization problem of the epigraph form.
    See Algorithm 1 in the paper.

    Args:
        b0 (float): Threshold variable
        rcmdp (RCMDP)
        init_policy (jnp.ndarray): Initial policy
        num_iter (int): Maximum number of policy updates
        lr (float): learning rate to update policy
        tol (float): stop once the summed absolute change of the policy in one
            update is not larger than this value

    Returns:
        policy (jnp.ndarray): (SxA) array. The evaluated iterate with the smallest
            Delta (largest gap between a worst-case return and its threshold).
            The policy produced by the very last update is not evaluated and
            therefore can never be returned.
    """
    # b0 acts as the threshold of the objective, followed by the constraint thresholds
    b0_threshes = jnp.hstack([jnp.array([b0,]), rcmdp.threshes])

    def condition_fn(loop_args):
        # break iteration if Δ <= 0, the policy does not change, or k >= num_iter
        k, _, _, best_Delta, policy_diff = loop_args
        return (k < num_iter) & (best_Delta > 0) & (policy_diff > tol) # the k'th policy has not been evaluated yet

    def loop_fn(loop_args):
        k, policy, best_policy, best_Delta, _ = loop_args

        # evaluate current policy
        worst_P_Q, worst_P_occ, worst_P_J = compute_policy_worst_values(policy, rcmdp)
        gaps = worst_P_J - b0_threshes
        Delta = jnp.max(gaps)
        # keep the iterate with the smallest Delta seen so far
        best_policy = jax.lax.cond(Delta < best_Delta, lambda: policy, lambda: best_policy)
        best_Delta = jnp.minimum(Delta, best_Delta)

        # compute gradient with respect to the entry that attains Delta
        worst_vio_idx = jnp.argmax(gaps)
        worst_Q, worst_occ = worst_P_Q[worst_vio_idx], worst_P_occ[worst_vio_idx]
        chex.assert_shape(worst_occ, (S,))
        chex.assert_shape(worst_Q, (S, A))
        grad = worst_occ.reshape(-1, 1) * worst_Q

        # update to new policy (gradient step, then row-wise projection onto the simplex)
        new_policy = proj_to_Pi(policy - lr * grad)
        policy_diff = jnp.abs(new_policy - policy).sum()
        return k+1, new_policy, best_policy, best_Delta, policy_diff

    best_policy = init_policy
    best_Delta = jnp.inf
    # the initial best_Delta and policy_diff are inf so that the first iteration always runs when num_iter > 0
    _, _, best_policy, _, _ = jax.lax.while_loop(condition_fn, loop_fn, (0, init_policy, best_policy, best_Delta, jnp.inf))

    return best_policy


@jax.jit
def update_outer_EF(args, init_k: int, end_k: int, rcmdp: RCMDP, inner_iter: int=1000, inner_lr: float = 0.001):
    """Run outer iterations init_k, ..., end_k - 1 of the threshold search.
    See Algorithm 2 in the paper.

    Args:
        args: Loop state handed to jax's fori_loop. Same layout as InitEFArgs.
        init_k (int): first outer iteration index (inclusive)
        end_k (int): last outer iteration index (exclusive)
        rcmdp (RCMDP)
        inner_iter (int): iteration budget passed to solve_inner_EF
        inner_lr (float): learning rate passed to solve_inner_EF

    Returns:
        args: Updated loop state. Same layout as InitEFArgs.
    """
    def _bisect_step(k, carry):
        J0_hist, vio_hist, prev_policy, lower, upper = carry

        # try the midpoint of the current interval as threshold for the objective
        midpoint = (lower + upper) / 2
        current = solve_inner_EF(midpoint, rcmdp, prev_policy, inner_iter, inner_lr)
        robust_returns = compute_policy_worst_values(current, rcmdp)[2]
        targets = jnp.hstack([jnp.array([midpoint,]), rcmdp.threshes])
        Delta = jnp.max(robust_returns - targets)

        # record the performance of the current policy
        J0_hist = J0_hist.at[k].set(robust_returns[0])
        constraint_gap = robust_returns[1:] - rcmdp.threshes
        vio_hist = vio_hist.at[k].set(constraint_gap.max())

        # Delta > 0 moves the lower end to the midpoint, Delta <= 0 the upper end
        lower = jnp.where(Delta > 0, midpoint, lower)
        upper = jnp.where(Delta <= 0, midpoint, upper)
        return J0_hist, vio_hist, current, lower, upper

    return jax.lax.fori_loop(init_k, end_k, _bisect_step, args)


In [6]:
@partial(jax.pmap, in_axes=(0, None, None, 0, 0))
def update_Args(seed, init_k, end_k, LagArgs, EFArgs):
    """Advance both methods over outer iterations [init_k, end_k) for one seed.

    Mapped with jax.pmap: `seed`, `LagArgs` and `EFArgs` are split along their
    leading axis, while `init_k` and `end_k` are shared by all devices.

    Returns:
        The updated Lagrangian state, the updated epigraph-form state, and the
        worst-case objective value and largest constraint violation of the
        uniform policy on the same environment.
    """
    env = create_rcmdp(seed)
    lag_state = update_outer_Lagrange(LagArgs, init_k, end_k, env)
    ef_state = update_outer_EF(EFArgs, init_k, end_k, env)

    # reference values: the uniform policy evaluated on this environment
    baseline_returns = compute_policy_worst_values(jnp.ones((S, A)) / A, env)[2]
    baseline_violation = jnp.max(baseline_returns[1:] - env.threshes)
    return lag_state, ef_state, baseline_returns[0], baseline_violation


In [7]:
from copy import deepcopy
from tqdm import tqdm

# Number of outer iterations executed per call of the pmapped update_Args.
UNROLL_ITER = 5
seeds = jnp.arange(NUM_SEEDS)
LagArgs, EFArgs = deepcopy(InitLagArgs), deepcopy(InitEFArgs)
# Walk over [0, ITER_LENGTH) in chunks of UNROLL_ITER. The last chunk is clipped
# so that every history entry is filled even when ITER_LENGTH is not a multiple
# of UNROLL_ITER. The loop variable is deliberately not called `i`, because `i`
# is also the name of a module-level array of the epigraph-form initial state.
for chunk_start in tqdm(range(0, ITER_LENGTH, UNROLL_ITER)):
    chunk_end = min(chunk_start + UNROLL_ITER, ITER_LENGTH)
    LagArgs, EFArgs, Uni_J0U_list, Uni_vio_list = update_Args(seeds, chunk_start, chunk_end, LagArgs, EFArgs)
# Only the history arrays are needed below; the remaining state is discarded.
Lag_J0U_list, Lag_vio_list, Lag_J0U_avg_list, Lag_vio_avg_list, *_ = LagArgs
EF_J0U_list, EF_vio_list, *_ = EFArgs

100%|██████████| 200/200 [02:36<00:00,  1.28it/s]


In [11]:
# Uniform-policy objective per seed as a column; subtracted from every curve below.
J_baseval = Uni_J0U_list.reshape(-1, 1)
# Stretch the uniform-policy values along the iteration axis so they can be
# handled like the histories of the other methods.
Uni_J0U_list_rep = jnp.repeat(J_baseval, ITER_LENGTH, axis=1)
Uni_vio_list_rep = jnp.repeat(Uni_vio_list.reshape(-1, 1), ITER_LENGTH, axis=1)

# Legend labels.
Unif = r"Uniform policy ($\pi_{\mathrm{unif}}$)"
LF = "LF-PGS"
LFavg = "LF-PGS-avg"
EF = r"$\mathbf{EpiRC\operatorname{-}PGS\;(Ours)}$"

# label -> (objective history relative to the uniform policy, violation history)
_curves = [
    (Unif, Uni_J0U_list_rep, Uni_vio_list_rep),
    (LF, Lag_J0U_list, Lag_vio_list),
    (LFavg, Lag_J0U_avg_list, Lag_vio_avg_list),
    (EF, EF_J0U_list, EF_vio_list),
]
algos = {label: (objective - J_baseval, violation) for label, objective, violation in _curves}


In [12]:
import pickle

# Make sure the target directory exists; open(..., "wb") does not create it.
out_path = f"results/{FIGNAME}.pkl"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(algos, f)

In [14]:
# Display the collected results: label -> (objective history, violation history).
algos

{'Uniform policy ($\\pi_{\\mathrm{unif}}$)': (Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  Array([[0.19898933, 0.19898933, 0.19898933, ..., 0.19898933, 0.19898933,
          0.19898933],
         [0.22319311, 0.22319311, 0.22319311, ..., 0.22319311, 0.22319311,
          0.22319311],
         [0.01922429, 0.01922429, 0.01922429, ..., 0.01922429, 0.01922429,
          0.01922429]], dtype=float32)),
 'LF-PGS': (Array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [-0.03162563, -0.03162563, -0.03162563, ..., -0.03162563,
          -0.03162563, -0.03162563],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]], dtype=float32),
  Array([[0.19898933, 0.19898933, 0.19898933, ..., 0.19898933, 0.19898933,
          0.19898933],
         [0.24164563, 0.24164563, 0.24164563, ..., 0.24164563, 0.24164