# To do:
- adam optimizer within

In [1]:
from absl.testing import absltest
from absl.testing import parameterized

import copy
import hypothesis.extra.numpy

import jax.test_util
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax.experimental import optimizers
from jax.scipy.special import logsumexp
from jax.experimental.stax import softmax
from jax.config import config
from jax.random import bernoulli
from jax import jacfwd

from jax import lax
from fax import converge
from fax import test_util
from fax.constrained import make_lagrangian
from fax.constrained import cga_lagrange_min
from fax.constrained import cga_ecp
from fax.constrained import slsqp_ecp
from fax.constrained import implicit_ecp


# check device
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [2]:
true_transition = jnp.array([[[0.9, 0.1], [0.2, 0.8]],
                             [[0.8, 0.2], [0.99, 0.01]]])
temperature = 1e-2

true_discount = 0.9

traj_len = 100

initial_distribution = jnp.ones(2) / 2

policy_expert = jnp.array(([[0.4, 0.6],
                            [0.4,  0.6]]))
key = random.PRNGKey(0)


In [3]:
def get_new_key():
    global key
    mykey, subkey = random.split(key)
    key = subkey


def roll_out(last_state, last_action, p, model):
    global key
    get_new_key()
    s = bernoulli(key, p=p[last_action][last_state][0]).astype(int)
    get_new_key()
    a = bernoulli(key, p=model[s][0]).astype(int)
    return (s, a)


def sample_trajectory(policy):
    get_new_key()
    s = bernoulli(key, p=initial_distribution[0]).astype(int)
    get_new_key()
    a = bernoulli(key, p=policy[s][0]).astype(int)
    traj = []
    traj.append((s, a))
    for i in range(traj_len-1):
        s, a = roll_out(s, a, true_transition, policy)
        traj.append((s, a))
    return jnp.array(copy.deepcopy(traj))


# @jax.jit
def ratio_loss(discriminator_logits, traj_model, traj_expert):
    discriminator = softmax((1. / temperature) * discriminator_logits)
    loss = 0
    for i in range(traj_len):
        s_expert, a_expert = traj_expert[i]
        s_model, a_model = traj_model[i]
        loss += -jnp.log(discriminator[s_expert][a_expert]) - jnp.log(1 - discriminator[s_model][a_model])
    return loss/traj_len

def generator(discriminator_logits, traj_model):
    gen_losses = []
    discriminator = softmax((1. / temperature) * discriminator_logits)
    loss = 0
    for i in range(traj_len):
        s_model, a_model = traj_model[i]
        gen_losses.append(jnp.log(discriminator[s_model][a_model]))
    return gen_losses

# initialize parameters

In [15]:
discriminator_logits = jnp.ones((2,2))

model_logits = jnp.ones((2,2))

opt_init, opt_update, get_params = optimizers.adam(step_size=0.001)
opt_state = opt_init(discriminator_logits)

opt_init2, opt_update2, get_params2 = optimizers.adam(step_size=0.001)
opt_state2 = opt_init2(model_logits)

# update discriminator

In [16]:
def update_discriminator(i, discriminator_logits, traj_model, traj_expert, opt_state, opt_update, get_params):
    ratio_loss_grad = jax.grad(ratio_loss, 0)(discriminator_logits, traj_model, traj_expert)
    opt_state = opt_update(i, ratio_loss_grad, opt_state)
    discriminator_logits = get_params(opt_state)
    return opt_state, discriminator_logits

In [6]:
discriminator_logits

DeviceArray([[1., 1.],
             [1., 1.]], dtype=float32)

# update ratio

In [17]:
def get_log_policy(model_logits, s, a):
    policy_model = softmax((1. / temperature) * model_logits)
    return jnp.log(policy_model[s][a])

policy_grad = jax.grad(get_log_policy)


In [4]:
def discounted_reward(t, gen_losses, gamma = 0.9):
    discounted = [ gamma**(i-t) * gen_losses[i] for i in range(t, len(gen_losses))]
    G = jnp.array(discounted).sum()
    return G

def advantage(t, gen_losses, k, gamma=0.9):
    reward = discounted_reward(0, gen_losses[t:t+k], gamma=gamma)
#     print (reward)
    A = reward + (gamma ** k) * discounted_reward(t + k, gen_losses, gamma = gamma) - discounted_reward(t, gen_losses, gamma = gamma) 
#     print (discounted_reward(t + k, gen_losses, gamma = gamma))
#     print (discounted_reward(t, gen_losses, gamma = gamma))
    return A

# gen_losses = [1,2,3,4,5,6,7,8]
# advantage(0, gen_losses, 1, gamma = 0.1)

def GAE(t, gen_losses, gamma=0.9, lamb=0.9):
    gae = 0 
    for k in range(1,len(gen_losses)-t):
#         print (k)
#         print (advantage (t, gen_losses, k, gamma))
        gae += (lamb ** (k-t)) * advantage (t, gen_losses, k, gamma)
    return gae * (1-lamb)

    

In [28]:
def update_generator(i, model_logits, discriminator_logits, traj_model, opt_state2, opt_update2, get_params2):
    generator_losses = generator(discriminator_logits, traj_model)
    generator_grad = 0
    
    for t in range(traj_len):
        #grad of log policy
        s_model, a_model = traj_model[i]
        grad_log_policy = policy_grad(model_logits, s_model, a_model) 
        reward = discounted_reward(t, generator_losses)
        generator_grad += grad_log_policy * reward
        
    generator_grad = generator_grad/traj_len
    opt_state2 = opt_update2(t, generator_grad, opt_state2)
    model_logits = get_params2(opt_state2)

    return opt_state2, model_logits

In [9]:
print ("model_logits:", model_logits)
print ("policy", softmax((1. / temperature) * model_logits))


model_logits: [[1. 1.]
 [1. 1.]]
policy [[0.5 0.5]
 [0.5 0.5]]


In [29]:
for i in range(50):
    policy_model = softmax((1. / temperature) * model_logits)
    traj_model = sample_trajectory(policy_model)
    traj_expert = sample_trajectory(policy_expert)
    opt_state, discriminator_logits = update_discriminator(i, discriminator_logits, traj_model, traj_expert, opt_state, opt_update, get_params)
   
    print ("discriminator_logits:", discriminator_logits.flatten())
    opt_state2, model_logits = update_generator(i, model_logits, discriminator_logits, traj_model, opt_state2, opt_update2, get_params2)
    
    print ("model_logits:", model_logits.flatten())
    print ("policy", softmax((1. / temperature) * model_logits).flatten())
    print ("")
    
    


discriminator_logits: [0.999842   1.0001581  1.002008   0.99799204]
model_logits: [0.9990242 1.0009757 1.        1.       ]
policy [0.4513662 0.5486338 0.5       0.5      ]

discriminator_logits: [1.0004891  0.999511   1.0023987  0.99760133]
model_logits: [0.9981456 1.0018544 1.0009757 0.9990242]
policy [0.40832838 0.59167165 0.5486338  0.4513662 ]

discriminator_logits: [1.0005208  0.99947923 1.0022682  0.99773186]
model_logits: [0.998383  1.0016171 1.0018544 0.9981456]
policy [0.4198449  0.5801551  0.59167165 0.40832838]

discriminator_logits: [1.0005476  0.99945235 1.0021287  0.99787134]
model_logits: [0.9985967 1.0014033 1.0015652 0.9984348]
policy [0.43029132 0.56970865 0.5776294  0.42237058]

discriminator_logits: [1.0004758  0.99952424 1.0018859  0.9981142 ]
model_logits: [0.99830127 1.0016989  1.0013049  0.9986952 ]
policy [0.41586784 0.5841322  0.5648731  0.43512687]

discriminator_logits: [1.0005394 0.9994605 1.0016824 0.9983178]
model_logits: [0.9980352  1.0019649  1.0005482

KeyboardInterrupt: 

- baseline in reinforce

- implicit differentiation

- romina: reinforce generalize advantage estimator in jax