In [174]:
from absl.testing import absltest
from absl.testing import parameterized

import copy
import collections 

import jax.test_util
import jax.numpy as jnp

from jax import lax, vjp, custom_vjp, grad, jacrev, jacfwd, random, tree_util, jacfwd
from jax.experimental import optimizers
from jax.scipy.special import logsumexp
from jax.experimental.stax import softmax
from jax.config import config
from jax.random import bernoulli
from jax.numpy.linalg import norm

from fax import converge, test_util
from fax.constrained import implicit_ecp
from fax.loop import fixed_point_iteration
from fax.implicit.twophase import make_adjoint_fixed_point_iteration
from fax.implicit.twophase import make_forward_fixed_point_iteration

# check device
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

FixedPointSolution = collections.namedtuple(
    "FixedPointSolution",
    "value converged iterations previous_value"
)

#just in time compilation
from jax import jit
config.update("jax_enable_x64", True)
config.update('jax_disable_jit', True)

gpu


In [187]:
true_transition = jnp.array([[[0.9, 0.1], [0.2, 0.8]],
                             [[0.8, 0.2], [0.99, 0.01]]])
temperature = 1e-2

true_discount = 0.9

traj_len = 75

initial_distribution = jnp.ones(2) / 2

policy_expert = jnp.array(([[0.45, 0.55],
                            [0.55,  0.45]]))
key = random.PRNGKey(0)


In [211]:
def get_new_key():
    global key
    mykey, subkey = random.split(key)
    key = subkey


def roll_out(last_state, last_action, p, model):
    global key
    get_new_key()
    s = bernoulli(key, p=p[last_action][last_state][0]).astype(int)
    get_new_key()
    a = bernoulli(key, p=model[s][0]).astype(int)
    return (s, a)


def sample_trajectory(policy):
    get_new_key()
    s = bernoulli(key, p=initial_distribution[0]).astype(int)
    get_new_key()
    a = bernoulli(key, p=policy[s][0]).astype(int)
    traj = []
    traj.append((s, a))
    for i in range(traj_len-1):
        s, a = roll_out(s, a, true_transition, policy)
        traj.append((s, a))
    return jnp.array(copy.deepcopy(traj))

#ratio_loss
def L(theta, w, traj_model, traj_expert):
    del theta
    discriminator = softmax((1. / temperature) * w)    
    loss = 0
    for i in range(traj_len):
        s_expert, a_expert = traj_expert[i]
        s_model, a_model = traj_model[i]
        loss += - jnp.log(discriminator[s_expert][a_expert]) - jnp.log(1 - discriminator[s_model][a_model])
    return loss/traj_len

F = grad(L, (1))

# generator loss
def J(theta, w, traj_model):
    del theta
    discriminator = softmax((1. / temperature) * w)
    loss = 0
    for i in range(traj_len):
        s_model, a_model = traj_model[i]
        loss += jnp.log(discriminator[s_model][a_model])
    return loss / traj_len

# initialize parameters

In [191]:
# w = jnp.array([[1.0, 1.1],[1.02,0.99]])
# theta = jnp.array([[1.1, 1.0],[0.99,1.0]])

w = jnp.ones((2,2))
theta = jnp.ones((2,2))

# find constraint solution (w*)

In [192]:
def constraints_solver(theta, w, max_iter=100, threshold = 1e-3):
    
    #initialize optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
    opt_state = opt_init(w)
    prev = w
    
    policy_model = softmax((1. / temperature) * theta)
    
    for i in range(max_iter):
        traj_model = sample_trajectory(policy_model)
        traj_expert = sample_trajectory(policy_expert)
        
        ratio_grad = F(theta, w, traj_model, traj_expert)
        opt_state = opt_update(i, ratio_grad, opt_state)
        w = get_params(opt_state)
    
        #check threshold
        if i > 0 and jnp.max(jnp.abs(w - prev)) <= threshold:
            return FixedPointSolution(
                value=w,
                converged=True,
                iterations=i,
                previous_value=prev,
            )
        if i < max_iter - 1:
            prev = w
    
    return FixedPointSolution(
        value=w,
        converged=False,
        iterations=max_iter,
        previous_value=prev,
    )


In [180]:
forward_solution = constraints_solver(theta, w)
forward_solution

FixedPointSolution(value=DeviceArray([[1.0063526 , 0.9936474 ],
             [0.99571118, 1.00428882]], dtype=float64), converged=True, iterations=8, previous_value=DeviceArray([[1.00623082, 0.99376918],
             [0.9957778 , 1.0042222 ]], dtype=float64))

In [193]:
policy_model = softmax((1. / temperature) * theta)
traj_model = sample_trajectory(policy_model)
traj_expert = sample_trajectory(policy_expert)

# find dJ/dw

In [194]:
theta = jnp.ones((2,2))
grad(J, (1))(theta, w, traj_model)

DeviceArray([[-2.66666667,  2.66666667],
             [-0.66666667,  0.66666667]], dtype=float64)

In [195]:
dJ_dw  = grad(J, argnums=1)(theta, w, traj_model)
dJ_dw

DeviceArray([[-2.66666667,  2.66666667],
             [-0.66666667,  0.66666667]], dtype=float64)

# find dJ/dtheta

In [209]:
def get_log_policy(theta, s, a):
    policy_model = softmax((1. / temperature) * theta)
    return jnp.log(policy_model[s][a])

policy_grad = jax.grad(get_log_policy, (0))


def discounted_reward(t, rewards, gamma = 0.9):
    discounted = [ gamma**(i-t) * rewards[i] for i in range(t, len(rewards))]
    G = jnp.array(discounted).sum()
    return G


def reinforce(theta, traj_model, rewards):
    estimator = 0
    for t in range(len(theta)):
        #grad of log policy
        s_model, a_model = traj_model[t]
        grad_log_policy = policy_grad(theta, s_model, a_model) 
        reward = discounted_reward(t, rewards)
        estimator += grad_log_policy * reward
    return estimator / traj_len


def get_dJ_dtheta(theta, w, traj_model):
    discriminator = softmax((1. / temperature) * w)
    rewards = []
    for i in range(len(traj_model)):
        s_model,a_model = traj_model[i]
        rewards.append(jnp.log(discriminator[s_model][a_model]))
    
    return reinforce(theta, traj_model, rewards)



In [12]:
dJ_dtheta = get_dJ_dtheta(theta, w, traj_model)
dJ_dtheta

DeviceArray([[ 3.46563362, -3.46563362],
             [-3.46564385,  3.46564385]], dtype=float64)

# find dF/dtheta

In [199]:
def dis_temp(w, s_model, a_model):
    discriminator = softmax((1. / temperature) * w)
    return jnp.log(1-discriminator[s_model][a_model])

grad_dis_w = jax.grad(dis_temp)

def get_dF_dtheta(theta, w, traj_model):
    rewards = []
    for i in range(len(traj_model)):
        s_model,a_model = traj_model[i]
        rewards.append(-grad_dis_w(w, s_model, a_model))
    return reinforce(theta, traj_model, rewards)


In [200]:
dF_dtheta = get_dF_dtheta(theta, w, traj_model)
dF_dtheta

DeviceArray([[0., 0.],
             [0., 0.]], dtype=float64)

# find dF/dw

In [201]:
get_dF_dw = jax.hessian(L, argnums=1)
dF_dw = get_dF_dw(theta, w, traj_model, traj_expert)

In [216]:
jnp.linalg.inv(dF_dw.reshape((4,4)))

DeviceArray([[ 2.75730589e+12,  2.75730589e+12,  0.00000000e+00,
               0.00000000e+00],
             [ 2.75730589e+12,  2.75730589e+12,  0.00000000e+00,
               0.00000000e+00],
             [ 0.00000000e+00,  0.00000000e+00, -1.09951163e+12,
              -1.09951163e+12],
             [-0.00000000e+00, -0.00000000e+00, -1.09951163e+12,
              -1.09951163e+12]], dtype=float64)

# find implicit gradient

In [46]:
dphi_dtheta = (jnp.linalg.inv(dF_dw.reshape((4,4))).dot(dF_dtheta.flatten()))
dphi_dtheta = dphi_dtheta.reshape(2,2)
implicit_grad = dJ_dw.dot(dphi_dtheta) + dJ_dtheta

# go

In [None]:
def implicit_diff(theta, w, max_iter=100, threshold = 1e-3):
    
    #initialize optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=0.001)
    opt_state = opt_init(theta)
    prev = theta
    
    for i in range(max_iter):
        
        prev_theta = theta
        print (i)
        # get converged discriminator logits
        forward_solution = constraints_solver(theta, w)
        assert forward_solution.converged == True 
        w = forward_solution.value
        print ("w", w)
        
        policy_model = softmax((1. / temperature) * theta)
        traj_model = sample_trajectory(policy_model)
        traj_expert = sample_trajectory(policy_expert)
        
        dJ_dtheta = get_dJ_dtheta(theta, w, traj_model)
        print ("dJ_dtheta", dJ_dtheta)
        dJ_dw  = grad(J, (1))(theta, w, traj_model)
        print ("dJ_dw",  dJ_dw)
        dF_dtheta = get_dF_dtheta(theta, w, traj_model)
        print ("dF_dtheta", dF_dtheta)
        dF_dw = get_dF_dw(theta, w, traj_model, traj_expert)
        print ("dF_dw", dF_dw.reshape((4,4)))
        print ("inverse: ", jnp.linalg.inv(dF_dw.reshape((4,4))))
        
        dphi_dtheta = (jnp.linalg.pinv(dF_dw.reshape((4,4))).dot(dF_dtheta.flatten()))
        dphi_dtheta = dphi_dtheta.reshape(2,2)
        implicit_grads = dJ_dw.dot(dphi_dtheta) + dJ_dtheta
        print ("implicit_grads", implicit_grads)
        
        opt_state = opt_update(i, implicit_grads, opt_state)
        theta = get_params(opt_state)
        policy_model = softmax((1. / temperature) * theta)
        print ("theta", theta)
        print ("policy", policy_model)
        print ("")
        #check threshold
        if i > 0 and jnp.max(jnp.abs(theta - prev)) <= threshold:
            return theta
        if i < max_iter - 1:
            prev = theta
    
    print ("not converged")
    return theta
    

# use LINALG.SOLVE

In [None]:
# w = jnp.array([[0.99, 0.998],[1.01, 1.1]])
# theta = jnp.array([[0.99, 1.01],[1.0, 0.9]])

#collect samples 

w = jnp.ones((2,2))
theta = jnp.ones((2,2))

implicit_diff(theta, w,threshold = 1e-5)

In [None]:
w = jnp.ones((2,2))
theta = jnp.ones((2,2))

initial_values = (theta, w)
opt_init, opt_update, get_params = optimizers.adam(step_size=0.001)

x0, init_params = initial_values
opt_state = opt_init(init_params)

def update(i, values):
    w, opt_state = values
    theta = get_params(opt_state)

    # get converged discriminator logits
    forward_solution = constraints_solver(theta, w)
    assert forward_solution.converged == True 
    w = forward_solution.value
    
    traj_model = sample_trajectory(policy_model)
    dJ_dtheta = get_dJ_dtheta(theta, w, traj_model)
    dJ_dw  = grad(J, (1))(theta, w, traj_model)
    dF_dtheta = get_dF_dtheta(theta, w, traj_model)
    dF_dw = get_dF_dw(theta, w, traj_model, traj_expert)
    
    dphi_dtheta = (jnp.linalg.inv(dF_dw.reshape((4,4))).dot(dF_dtheta.flatten()))
    dphi_dtheta = dphi_dtheta.reshape(2,2)
    implicit_grads = dJ_dw.dot(dphi_dtheta) + dJ_dtheta
    opt_state = opt_update(i, implicit_grads, opt_state)

    return forward_solution.value, opt_state

In [None]:
def convergence_test(x_new, x_old):
    min_type = converge.tree_smallest_float_dtype(x_new)
    rtol, atol = converge.adjust_tol_for_dtype(1e-10, 1e-10, min_type)
    return converge.max_diff_test(x_new, x_old, rtol, atol)

def _convergence_test(new_state, old_state):
    x_new, params_new = new_state[0], get_params(new_state[1])
    x_old, params_old = old_state[0], get_params(old_state[1])
    return convergence_test((x_new, params_new), (x_old, params_old))

solution = fixed_point_iteration(init_x=(x0, opt_state),
                                  func=update,
                                  convergence_test=_convergence_test,
                                  max_iter=50,
                                  batched_iter_size=1,
                                  unroll=False)
