In [1]:
import ray
from ray.rllib import agents
from ray import tune
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models import ModelCatalog
from ray.rllib.utils import try_import_tf
import gym
from gym import spaces
import or_gym
from or_gym.utils.env_config import *
from or_gym.algos.rl_utils import create_env
import numpy as np
import time
from copy import deepcopy
import matplotlib.pyplot as plt
%matplotlib inline
ray.init(ignore_reinit_error=True)
tf = try_import_tf()

2020-05-08 12:40:32,703	INFO resource_spec.py:216 -- Starting Ray with 1.76 GiB memory available for workers and up to 0.88 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).


In [2]:
env = or_gym.make('SCSched-v3')
s = env.reset()
s

{'action_mask': array([1, 0, 1]),
 'avail_actions': array([1., 1., 1.]),
 'state': array([0.        , 0.        , 4.58033246, 0.        , 0.        ,
        0.        , 5.        , 0.        , 0.        ])}

In [3]:
config = agents.a3c.DEFAULT_CONFIG

In [18]:
def test(s):
    avail_actions = s['avail_actions']
    avail_actions[0] = 0
    action_mask = s['action_mask']
    action_embed_model = FullyConnectedNetwork(
        spaces.Box(0, 1, shape=(9,)), env.action_space, 3, config['model'], 'scsched3')
    action_embedding, _ = action_embed_model({'obs': s['state'].reshape(1, -1)})
    intent_vector = tf.expand_dims(action_embedding, 1)
    action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=1)
    inf_mask = tf.maximum(tf.log(action_mask.astype('float')), tf.float32.min)
#     output = tf.reduce_sum(action_embedding + inf_mask, 1)
    return action_embedding, intent_vector, action_logits, inf_mask#, output

In [19]:
with tf.Session() as sess:
    a, b, c, d = test(s)
    init = tf.global_variables_initializer()
    sess.run(init)
    out = (a.eval(), b.eval(), c.eval(), d.eval())
out

(array([[ 0.00544533, -0.00944633, -0.0045386 ]], dtype=float32),
 array([[[ 0.00544533, -0.00944633, -0.0045386 ]]], dtype=float32),
 array([[ 0.        , -0.00944633, -0.0045386 ]], dtype=float32),
 array([ 0.00000000e+00, -3.40282347e+38,  0.00000000e+00]))

In [6]:
out[2] + out[3]

array([[ 0.00000000e+00, -3.40282347e+38, -1.61058793e-03]])

In [20]:
out[0] + out[3]

array([[ 5.44532668e-03, -3.40282347e+38, -4.53860499e-03]])

In [30]:
class SCSchedMaskModel2(TFModelV2):
    
    def __init__(self, obs_space, action_space, num_outputs,
        model_config, name, true_obs_shape=(9,), action_embed_size=3,
        *args, **kwargs):
        super(SCSchedMaskModel2, self).__init__(obs_space,
            action_space, num_outputs, model_config, name, *args, **kwargs)
        self.action_embed_model = FullyConnectedNetwork(
            spaces.Box(0, 100, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_action_embedding")
        self.register_variables(self.action_embed_model.variables())
        
    def forward(self, state):
        action_mask = state["action_mask"]
        action_embedding, _ = self.action_embed_model({
            "obs": state["state"].reshape(1, -1)
        })
        intent_vector = tf.expand_dims(action_embedding, 1)
        action_logits = tf.reduce_sum(action_mask * intent_vector, axis=1)
        return action_logits, intent_vector
    
    def value_function(self):
        return self.action_embed_model.value_function()

In [31]:
env = or_gym.make('SCSched-v3')
state = env.reset()
print(state)
with tf.Session() as sess:
    model = SCSchedMaskModel2(
        env.observation_space, env.action_space, env.action_space.shape[0],
        config['model'], 'scsched')
    action_logits, intent_vector = model.forward(state)
    init = tf.global_variables_initializer()
    sess.run(init)
    out = (action_logits.eval(), intent_vector.eval())
    
out

{'action_mask': array([1, 0, 1]), 'avail_actions': array([1., 1., 1.]), 'state': array([0.        , 0.        , 6.64042218, 0.        , 0.        ,
       0.        , 5.        , 0.        , 0.        ])}


(array([[ 0.00182259, -0.        , -0.00247612]], dtype=float32),
 array([[[ 0.00182259, -0.0039382 , -0.00247612]]], dtype=float32))