### Test compute log probabilities

In [1]:
import torch 
import torch.nn as nn 
from torch.autograd import Variable
import torch.functional as F

from collections import namedtuple
import numpy as np


def make_var_id_actions(actions):
    """Wrap sampled action ids as ``Variable``s holding ``LongTensor``s.

    Args:
        actions: Pair ``(fn_ids, arg_ids)``. ``fn_ids`` is any data accepted
            by ``torch.LongTensor``; ``arg_ids`` is a dict whose values are
            likewise accepted by ``torch.LongTensor``.

    Returns:
        Tuple ``(fn_id_var, args_var)``: the wrapped function ids and a new
        dict with the same keys as ``arg_ids`` and wrapped values.
    """
    fn_ids, ids_by_arg = actions

    def as_long_var(values):
        return Variable(torch.LongTensor(values))

    fn_id_var = as_long_var(fn_ids)
    return fn_id_var, {key: as_long_var(ids) for key, ids in ids_by_arg.items()}


def make_var_probs_actions(actions):
    """Wrap policy probabilities as ``Variable``s holding ``FloatTensor``s.

    Args:
        actions: Pair ``(fn_probs, arg_probs)``. ``fn_probs`` is any data
            accepted by ``torch.FloatTensor``; ``arg_probs`` is a dict whose
            values are likewise accepted by ``torch.FloatTensor``.

    Returns:
        Tuple ``(fn_var, args_var)``: the wrapped function probabilities and
        a new dict with the same keys as ``arg_probs`` and wrapped values.
    """
    fn_probs, probs_by_arg = actions

    def as_float_var(values):
        return Variable(torch.FloatTensor(values))

    fn_var = as_float_var(fn_probs)
    return fn_var, {key: as_float_var(p) for key, p in probs_by_arg.items()}


def mask_unavailable_actions(available_actions, fn_pi):
    """Multiply ``fn_pi`` by the availability mask and renormalise along dim 1.

    Args:
        available_actions: Mask multiplied element-wise into ``fn_pi``.
        fn_pi: Function-id probabilities; must have a dim 1 to sum over.

    Returns:
        The masked probabilities divided by their per-row (dim 1) sum. A row
        whose masked sum is 0 is divided by 0 (no guard is applied).
    """
    masked_pi = fn_pi * available_actions
    row_totals = masked_pi.sum(1, keepdim=True)
    return masked_pi / row_totals


def compute_policy_log_probs(available_actions, policy, actions_var):
    """Sum the log-probabilities of the chosen function id and its arguments.

    Args:
        available_actions: Mask passed to ``mask_unavailable_actions`` to
            renormalise ``fn_pi`` before the function-id lookup.
        policy: Pair ``(fn_pi, arg_pis)``; ``arg_pis`` is indexed with the
            keys of ``arg_ids``.
        actions_var: Pair ``(fn_id, arg_ids)``; ``arg_ids`` maps an argument
            type to ids, where ``-1`` marks an argument that is not used.

    Returns:
        1-D tensor: the function-id log-probability plus the log-probability
        of every argument whose id is not ``-1``.
    """
    def log_prob_of(probs, labels):
        # Negative labels are redirected to column 0 so ``gather`` stays in
        # range; the caller masks those entries out afterwards.
        columns = torch.clamp(labels, min=0).unsqueeze(1)
        picked = probs.gather(1, columns)
        log_picked = torch.log(torch.clamp(picked, 1e-12, 1.0))
        # A selected probability of exactly 0 yields 0, not log(1e-12).
        log_picked[picked == 0] = 0
        return log_picked.view(-1)

    fn_id, arg_ids = actions_var
    fn_pi, arg_pis = policy

    masked_fn_pi = mask_unavailable_actions(available_actions, fn_pi)
    log_prob = log_prob_of(masked_fn_pi, fn_id)

    for arg_type, arg_id in arg_ids.items():
        arg_log_prob = log_prob_of(arg_pis[arg_type], arg_id)
        # 1.0 where the argument is used, 0.0 where its id is -1.
        used = (arg_id != -1).float()
        log_prob = log_prob + arg_log_prob * used
    return log_prob

In [2]:
# Fixture: 3 samples, 3 function ids, and one argument type with 2 values.
TestArgType = namedtuple('ArgType', ['name'])
arg_type = TestArgType('arg')
A = np.array

# Availability mask that is multiplied into fn_pi (0 removes a function id).
available_actions = A(
    [[1, 0, 1],
     [1, 0, 0],
     [1, 1, 1]],
    dtype=np.float32,
)

# Function-id probabilities, one row per sample.
fn_pi = A(
    [[0.2, 0.0, 0.8],
     [1.0, 0.0, 0.0],
     [0.2, 0.7, 0.1]],
    dtype=np.float32,
)

# Sampled function id for each sample.
fn_ids = A([2, 0, 1], dtype=np.int32)

# Argument probabilities and sampled argument ids; -1 marks an unused argument.
arg_pi = {
    arg_type: A(
        [[0.0, 1.0],
         [0.0, 1.0],
         [0.5, 0.5]],
        dtype=np.float32,
    )
}
arg_ids = {arg_type: A([0, 1, -1], dtype=np.int32)}

# Convert the NumPy fixture into torch Variables and evaluate.
policy_var = make_var_probs_actions((fn_pi, arg_pi))
actions_var = make_var_id_actions((fn_ids, arg_ids))
available_actions = Variable(torch.Tensor(available_actions))

log_probs = compute_policy_log_probs(available_actions, policy_var, actions_var)

In [5]:
# Show the per-sample log-probabilities computed by the previous cell.
print(log_probs)

Variable containing:
-0.2231
 0.0000
-0.3567
[torch.FloatTensor of size (3,)]



### Test compute log probabilities in TensorFlow

In [4]:
import tensorflow as tf
from collections import namedtuple
import numpy as np


def safe_log(x):
    """Element-wise log that returns 0 where ``x`` is exactly 0.

    Elsewhere the result is ``log(max(1e-12, x))``, so inputs below
    ``1e-12`` are floored before the logarithm is taken.
    """
    is_zero = tf.equal(x, 0)
    zeros = tf.zeros_like(x)
    floored_log = tf.log(tf.maximum(1e-12, x))
    return tf.where(is_zero, zeros, floored_log)

def mask_unavailable_actions(available_actions, fn_pi):
    """Multiply ``fn_pi`` by the availability mask and renormalise along axis 1.

    Args:
        available_actions: Mask multiplied element-wise into ``fn_pi``.
        fn_pi: Function-id probabilities; must have an axis 1 to sum over.

    Returns:
        The masked probabilities divided by their per-row (axis 1) sum.
        ``fn_pi`` itself is left untouched.
    """
    # Use a fresh name instead of ``fn_pi *= ...``: when the caller passes
    # NumPy arrays, the augmented assignment would overwrite the caller's
    # array in place.
    masked_pi = fn_pi * available_actions
    return masked_pi / tf.reduce_sum(masked_pi, axis=1, keep_dims=True)

def compute_policy_log_probs(available_actions, policy, actions):
    """Build the log-probability of each sampled action in TensorFlow.

    Args:
        available_actions: Mask passed to ``mask_unavailable_actions`` to
            renormalise ``fn_pi`` before the function-id lookup.
        policy: Pair ``(fn_pi, arg_pis)``; ``arg_pis`` is indexed with the
            keys of ``arg_ids``.
        actions: Pair ``(fn_id, arg_ids)``; ``arg_ids`` maps an argument type
            to ids, where ``-1`` marks an argument that is not used.

    Returns:
        Tuple ``(log_prob, fn_log_prob, arg_log_prob, arg_log_prob_masked)``.
        ``log_prob`` is the function-id log-probability plus every masked
        argument log-probability. The last two entries are the unmasked and
        masked argument log-probabilities of the last argument type iterated,
        or zeros shaped like ``fn_log_prob`` when ``arg_ids`` is empty.
    """
    def compute_log_probs(probs, labels):
        # Select arbitrary element for unused arguments (log probs will be masked)
        labels = tf.maximum(labels, 0)
        indices = tf.stack([tf.range(tf.shape(labels)[0]), labels], axis=1)
        return safe_log(tf.gather_nd(probs, indices)) # TODO tf.log should suffice

    fn_id, arg_ids = actions
    fn_pi, arg_pis = policy
    fn_pi = mask_unavailable_actions(available_actions, fn_pi) # TODO: this should be unneccessary
    fn_log_prob = compute_log_probs(fn_pi, fn_id)

    log_prob = fn_log_prob
    # Bind the per-argument outputs up front so the return statement does not
    # raise UnboundLocalError when ``arg_ids`` is empty.
    arg_log_prob = arg_log_prob_masked = tf.zeros_like(fn_log_prob)
    for arg_type, arg_id in arg_ids.items():
        arg_pi = arg_pis[arg_type]
        arg_log_prob = compute_log_probs(arg_pi, arg_id)
        # Zero the contribution of arguments that were not used (id == -1).
        arg_log_prob_masked = arg_log_prob * tf.to_float(tf.not_equal(arg_id, -1))
        log_prob = log_prob + arg_log_prob_masked

    return log_prob, fn_log_prob, arg_log_prob, arg_log_prob_masked

  return f(*args, **kwds)


In [5]:
# Same fixture as the PyTorch test: 3 samples, 3 function ids, one argument type.
TestArgType = namedtuple('ArgType', ['name'])
arg_type = TestArgType('arg')
A = np.array

# Availability mask that is multiplied into fn_pi (0 removes a function id).
available_actions = A(
    [[1, 0, 1],
     [1, 0, 0],
     [1, 1, 1]],
    dtype=np.float32,
)

# Function-id probabilities, one row per sample.
fn_pi = A(
    [[0.2, 0.0, 0.8],
     [1.0, 0.0, 0.0],
     [0.2, 0.7, 0.1]],
    dtype=np.float32,
)

# Sampled function id for each sample.
fn_ids = A([2, 0, 1], dtype=np.int32)

# Argument probabilities and sampled argument ids; -1 marks an unused argument.
arg_pi = {
    arg_type: A(
        [[0.0, 1.0],
         [0.0, 1.0],
         [0.5, 0.5]],
        dtype=np.float32,
    )
}
arg_ids = {arg_type: A([0, 1, -1], dtype=np.int32)}

# Build the graph; the NumPy arrays are fed in directly as constants.
(log_probs,
 fn_log_probs,
 arg_log_probs,
 arg_log_probs_masked) = compute_policy_log_probs(
    available_actions, (fn_pi, arg_pi), (fn_ids, arg_ids)
)

# Evaluate all four tensors in a single run.
with tf.Session() as sess:
    (log_probs_out,
     fn_log_probs_out,
     arg_log_probs_out,
     arg_log_probs_masked_out) = sess.run(
        [log_probs, fn_log_probs, arg_log_probs, arg_log_probs_masked]
    )

print(log_probs_out)
print(fn_log_probs_out)
print(arg_log_probs_out)
print(arg_log_probs_masked_out)

[-0.22314353  0.         -0.35667497]
[-0.22314353  0.         -0.35667497]
[ 0.          0.         -0.69314718]
[ 0.  0. -0.]


### Test Compute Entropy

In [7]:
import torch 
import torch.nn as nn 
from torch.autograd import Variable
import torch.functional as F

from collections import namedtuple
import numpy as np


def make_var_id_actions(actions):
    """Wrap sampled action ids as ``Variable``s holding ``LongTensor``s.

    Args:
        actions: Pair ``(fn_ids, arg_ids)``. ``fn_ids`` is any data accepted
            by ``torch.LongTensor``; ``arg_ids`` is a dict whose values are
            likewise accepted by ``torch.LongTensor``.

    Returns:
        Tuple ``(fn_id_var, args_var)``: the wrapped function ids and a new
        dict with the same keys as ``arg_ids`` and wrapped values.
    """
    fn_ids, ids_by_arg = actions

    def as_long_var(values):
        return Variable(torch.LongTensor(values))

    fn_id_var = as_long_var(fn_ids)
    return fn_id_var, {key: as_long_var(ids) for key, ids in ids_by_arg.items()}


def make_var_probs_actions(actions):
    """Wrap policy probabilities as ``Variable``s holding ``FloatTensor``s.

    Args:
        actions: Pair ``(fn_probs, arg_probs)``. ``fn_probs`` is any data
            accepted by ``torch.FloatTensor``; ``arg_probs`` is a dict whose
            values are likewise accepted by ``torch.FloatTensor``.

    Returns:
        Tuple ``(fn_var, args_var)``: the wrapped function probabilities and
        a new dict with the same keys as ``arg_probs`` and wrapped values.
    """
    fn_probs, probs_by_arg = actions

    def as_float_var(values):
        return Variable(torch.FloatTensor(values))

    fn_var = as_float_var(fn_probs)
    return fn_var, {key: as_float_var(p) for key, p in probs_by_arg.items()}


def mask_unavailable_actions(available_actions, fn_pi):
    """Multiply ``fn_pi`` by the availability mask and renormalise along dim 1.

    Args:
        available_actions: Mask multiplied element-wise into ``fn_pi``.
        fn_pi: Function-id probabilities; must have a dim 1 to sum over.

    Returns:
        The masked probabilities divided by their per-row (dim 1) sum. A row
        whose masked sum is 0 is divided by 0 (no guard is applied).
    """
    masked_pi = fn_pi * available_actions
    row_totals = masked_pi.sum(1, keepdim=True)
    return masked_pi / row_totals


def compute_policy_entropy(available_actions, policy, actions_var):
    """Entropy of the function-id policy plus the used-argument policies.

    Args:
        available_actions: Mask passed to ``mask_unavailable_actions`` to
            renormalise ``fn_pi`` before its entropy is taken.
        policy: Pair ``(fn_pi, arg_pis)``; ``arg_pis`` is indexed with the
            keys of ``arg_ids``.
        actions_var: Pair ``(fn_id, arg_ids)``; only ``arg_ids`` is read.
            An id of ``-1`` marks a sample where the argument is not used.

    Returns:
        The mean entropy of the masked ``fn_pi`` plus, for each argument
        type, the entropy averaged over the samples whose id is not ``-1``
        (0 when no sample uses that argument).
    """
    def entropy_of(probs):
        # Probabilities are clamped to [1e-12, 1] before the log.
        log_probs = torch.log(torch.clamp(probs, 1e-12, 1.0))
        return -(log_probs * probs).sum(-1)

    _, arg_ids = actions_var
    fn_pi, arg_pis = policy

    masked_fn_pi = mask_unavailable_actions(available_actions, fn_pi)
    entropy = entropy_of(masked_fn_pi).mean()

    for arg_type, arg_id in arg_ids.items():
        per_sample = entropy_of(arg_pis[arg_type])
        # Reference: https://discuss.pytorch.org/t/how-to-use-condition-flow/644/4
        if (arg_id == -1).all():
            # No sample uses this argument: contribute an explicit zero
            # instead of dividing by an empty mask.
            arg_entropy = (per_sample * 0.0).sum()
        else:
            used = (arg_id != -1).float()
            arg_entropy = (per_sample * used).sum() / used.sum()
        entropy = entropy + arg_entropy
    return entropy

In [8]:
# Define the fixture helpers here so this section does not depend on names
# (`A`, `arg_type`) that are only created by cells of other sections.
TestArgType = namedtuple('ArgType', ['name'])
arg_type = TestArgType('arg')
A = np.array

# Availability mask that is multiplied into fn_pi (0 removes a function id).
available_actions = A([[1, 0, 1],
                       [1, 0, 0],
                       [1, 1, 1]], dtype=np.float32)

# Function-id probabilities, one row per sample.
fn_pi = A([[0.2, 0.0, 0.8],
           [1.0, 0.0, 0.0],
           [0.2, 0.7, 0.1]], dtype=np.float32)

# Sampled function ids (not read by the entropy computation).
fn_ids = A([2, 0, 1], dtype=np.int32)

# Argument probabilities; the third sample's argument is unused (id -1).
arg_pi = {arg_type: A([[0.8, 0.2],
                       [0.0, 1.0],
                       [0.5, 0.5]], dtype=np.float32)}

arg_ids = {arg_type: A([0, 1, -1], dtype=np.int32)}

# Convert the NumPy fixture into torch Variables and evaluate.
policy_var = make_var_probs_actions((fn_pi, arg_pi))
actions_var = make_var_id_actions((fn_ids, arg_ids))
available_actions = Variable(torch.Tensor(available_actions))

entropy = compute_policy_entropy(
    available_actions, policy_var, actions_var
)

In [9]:
# Show the policy entropy computed by the previous cell.
print(entropy)

Variable containing:
 0.6843
[torch.FloatTensor of size (1,)]



In [11]:
# Hand-computed reference for the entropy fixture above.
entropy_02_08 = 0.50040245      # entropy of a [0.2, 0.8] (or [0.8, 0.2]) row
entropy_02_07_01 = 0.80181855   # entropy of the [0.2, 0.7, 0.1] row

# Function ids: mean over 3 samples (the [1, 0, 0] row contributes 0).
# Arguments: mean over the 2 samples that use the argument ([0, 1] contributes 0).
expected_entropy = (entropy_02_08 + entropy_02_07_01) / 3.0 + (entropy_02_08) / 2
print(expected_entropy)

0.6842748916666668
