In [101]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt

In [109]:
class CommNet:
    
    def __init__(self, sess, N, J, embedding_size = 128, lr = 1e-3, training_mode = 'supervised', alpha = 0.03):
        
        self.N = N
        self.J = J
        self.embedding_size = embedding_size
        
        self.build_controler()
        
        self.training_mode = training_mode
        
        if training_mode == 'supervised':
            self.build_supervised()
            with tf.variable_scope('Supervised_optimizer'):
                self.train_op = tf.train.RMSPropOptimizer(lr).minimize(self.supervised_loss)
                
        elif training_mode == 'reinforce':
            self.alpha = 0.03
            self.build_reinforce()
            with tf.variable_scope('Reinforce_optimizer'):
                self.train_op =  tf.train.RMSPropOptimizer(lr).minimize(self.reinforce_loss)
            
        else:
            raise(ValueError("Unknown training mode: %s" % training_mode))
        
        print("All variables")
        for var in tf.global_variables():
            print(var)
            
        
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        
    def encode(self, inputs):
        
        with tf.variable_scope('Encoder'):
        
            identity_embeddings = tf.get_variable("identity_embeddings",
                                             [self.N, self.embedding_size])
            
            self.embedded_identities = tf.nn.embedding_lookup(identity_embeddings, inputs)
        
            
        return tf.unstack(self.embedded_identities, axis = 1)
    
    def build_f(self, name, h, c, h0 = None):
        
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            
            if h0 is not None:
                
                b1 = tf.get_variable('b1', shape = (1, self.embedding_size))
                W1 = tf.get_variable('W1', shape = (3 * self.embedding_size,
                                                  self.embedding_size))
                
                W2 = tf.get_variable('W2', shape = (self.embedding_size,
                                                  self.embedding_size))
                
                concat = tf.concat([h, c, h0], axis = 1)
            
            else:
                b1 = tf.get_variable('b1', shape = (1, self.embedding_size))
                
                W1 = tf.get_variable('W1', shape = (2 * self.embedding_size,
                                                  self.embedding_size))
                
                W2 = tf.get_variable('W2', shape = (self.embedding_size,
                                                  self.embedding_size))
                
                concat = tf.concat([h, c], axis = 1)
            
            b2 = tf.get_variable('b2', shape = (1, self.embedding_size))
            
            dense1 =tf.nn.relu(tf.einsum("ij,jk->ik", concat, W1) + b1)
            dense2 = tf.nn.relu(tf.einsum("ij,jk->ik", dense1, W2) + b2)
            
            return dense2
        
    def decode(self, h):
        
        with tf.variable_scope('Decoder', reuse = tf.AUTO_REUSE):
            
            W = tf.get_variable('W', shape = (self.embedding_size,
                                                  self.J))
            
            b = tf.get_variable('b', shape = (1, self.J))
            
            policy_logit = tf.einsum("ij,jk->ik", h, W) + b
        
            return policy_logit
    
    
    def communicate(self, h_seq):
        
        return tf.add_n(h_seq) / (self.J - 1)
    
    def sample_actions(self, policy_logit):
        
        action = tf.multinomial(policy_logit, num_samples = 1)
        
        return action
    
        
    def build_controler(self):
        
        self.inputs = tf.placeholder(tf.int32, shape = (None, self.J))
        
        h0_seq = self.encode(self.inputs)
        c0_seq = [self.communicate([h0_seq[j] for j in range(self.J) if j != i]) for i in range(self.J)]
        
        h1_seq = [self.build_f("Comm_step_1", h0_seq[j], c0_seq[j], None) for j in range(self.J)]
        c1_seq = [self.communicate([h1_seq[j] for j in range(self.J) if j != i]) for i in range(self.J)]
        
        h2_seq = [self.build_f("Comm_step_2", h1_seq[j], c1_seq[j], h0_seq[j]) for j in range(self.J)]
        
        self.layers = {'h0_seq': h0_seq, 'c0_seq': c0_seq, 'h1_seq': h1_seq, 'c1_seq':c1_seq, 'h2_seq': h2_seq}
        
        
        self.policy_logit_seq = [self.decode(h2) for h2 in h2_seq]
        
        self.proba_seq = [tf.nn.softmax(policy_logit, axis = 1) for policy_logit in self.policy_logit_seq]
        
        self.action_seq = [self.sample_actions(policy_logit) for policy_logit in self.policy_logit_seq]
        
        self.one_hot_action_seq = [tf.one_hot(action, depth = self.J) for action in self.action_seq]
        
        
        
    def build_supervised(self):
        
        assert self.training_mode == 'supervised', 'Wrong training mode'
        
        self.targets = tf.placeholder(tf.int32, shape = (None, self.J))
        unstacked_targets = tf.unstack(self.targets, axis = 1)
        
        supervised_loss_seq = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=unstacked_targets[j],
                                                                                   logits=self.policy_logit_seq[j])
                                    for j in range(self.J)]
        
        self.supervised_loss = tf.reduce_mean(supervised_loss_seq)
        
        
    def supervised_train(self, X, y, val_X, val_y, env, batch_size = 32, epochs = 1):
        
        assert self.training_mode == 'supervised', 'Wrong training mode'
        
        n = X.shape[0]
        
        val_n = val_X.shape[0]
        
        data_inds = np.array(range(n))
        for ep in range(1, epochs + 1):
            np.random.shuffle(data_inds)
            supervised_loss_sum = 0
            reward_sum = 0
            for i in tqdm_notebook(range(0, n, batch_size), "Epoch: %d" % ep):
                inds_batch = data_inds[i:i+batch_size]
                X_batch = X[inds_batch]
                y_batch = y[inds_batch]
                _, supervised_loss, one_hot_action_seq = sess.run([self.train_op, self.supervised_loss, self.one_hot_action_seq], feed_dict={self.inputs: X_batch, self.targets: y_batch})
                supervised_loss_sum += supervised_loss * batch_size
                reward_sum += env.get_reward(one_hot_action_seq)
            
            print("loss = %f" % (supervised_loss_sum / n))
            print("reward = %f" % (reward_sum / n))
            print()
            
            val_supervised_loss, val_one_hot_action_seq = sess.run([self.supervised_loss, self.one_hot_action_seq], feed_dict={self.inputs: val_X, self.targets: val_y})
            print('val loss = %f' % (val_supervised_loss))
            print('val reward = %f' % (env.get_reward(val_one_hot_action_seq) / val_n))
            

    def build_reinforce(self):
        
        assert self.training_mode == 'reinforce', 'Wrong training mode'
        
        
        self.state_q_val_seq = tf.placeholder(tf.float32, shape = (None, self.J, self.J))
        
        self.reward_sum_values = tf.placeholder(tf.float32, shape = (None,))
        self.advantage_values = tf.placeholder(tf.float32, shape = (None,))
  
        
        self.action_taken = tf.placeholder(tf.int32, shape = (None, self.J))
        unstacked_action_taken = tf.unstack(self.action_taken, axis = 1)
        one_hot_action_taken_seq = [tf.one_hot(action, depth = self.J, dtype = tf.int32) for action in unstacked_action_taken]
        
        self.neg_log_p_seq = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=unstacked_action_taken[j],
                                                    logits=self.policy_logit_seq[j]) for j in range(self.J)]
        
        neg_log_p_sums = tf.reduce_sum(self.neg_log_p_seq, axis = 0)
        
        
        stacked_proba_seq = tf.stack(self.proba_seq, axis = 1)
        
        baseline = tf.einsum('ijk, ijk-> ij', self.state_q_val_seq, stacked_proba_seq)
        baseline = tf.reduce_mean(baseline, axis = 1)
        
        #surrogate loss (- dtheta)
        advantage = self.reward_sum_values - baseline
    
        self.reinforce_loss =  tf.multiply(neg_log_p_sums, self.advantage_values)
        self.reinforce_loss += self.alpha * tf.square(advantage)
        self.reinforce_loss = tf.reduce_sum(self.reinforce_loss, axis = 0)
        
        
    def take_action(self, state):
        
        assert self.training_mode == 'reinforce', 'Wrong training mode'
        
        action_seq, proba_seq = self.sess.run([self.action_seq, self.proba_seq], {self.inputs: [state]})
        
        return [a[0,0] for a in action_seq], np.array(proba_seq)
    
    def reinforce_train(self, env, n_episodes, T):
        
        assert self.training_mode == 'reinforce', 'Wrong training mode'
        
        
        history = {'reward' : [],  'loss': []}
        
        q_reward_sum = np.zeros((self.N, self.J))
        q_state_action_count = np.zeros((self.N, self.J))
        q_val = np.zeros((self.N, self.J))
    
        
        for _ in tqdm(range(n_episodes), "REINFORCE"):
            
            # todo: change code to avoid this seq_seq name (sequence of sequence)
            state_seq, action_seq, reward_seq, proba_seq = policy_rollout(T, env, commNet)
            episode_len = reward_seq.shape[0]
            
            history['reward'].append(np.mean(reward_seq))
        
            state_q_val_seq = np.array([q_val[state] for state in state_seq])
            
            baseline = np.einsum('ijk, ijk-> ij', state_q_val_seq, proba_seq)
            
            baseline = baseline.mean(axis = 1)
            
            reward_sum_values = np.array([reward_seq[t:].sum() for t in range(episode_len)])
            advantage_values = np.array([reward_sum_values[t] - baseline[t] for t in range(episode_len)])          
            
            
            feed_dict = {}
            feed_dict[self.inputs] = state_seq
            feed_dict[self.state_q_val_seq] = state_q_val_seq
            feed_dict[self.reward_sum_values] = reward_sum_values
            feed_dict[self.advantage_values] = advantage_values
            feed_dict[self.action_taken] = action_seq
            
            _, loss = self.sess.run([self.train_op, self.reinforce_loss], feed_dict = feed_dict)
            
            history['loss'].append(loss)  
            
            # udpate_q_val
            for i in range(episode_len):
                
                state = state_seq[i]
                action = action_seq[i]
                
                cummul_reward = reward_seq[i:].sum()
                q_state_action_count[state, action] += 1
                q_reward_sum[state, action] += cummul_reward
            
            q_val = q_reward_sum/np.maximum(1,q_state_action_count)
            
            
        return history, q_val
            
            
            
            
            
            
            
        

In [110]:
class LeverEnv:
    
    def __init__(self, N, J):
        
        self.J = J
        self.N = N
        
    def reset(self):
        
        state = np.random.choice(self.N, size = self.J, replace = False)
        
        terminal_state = False
        
        return state, terminal_state
    
    def get_reward(self, one_hot_action_seq):        
        
        reward = np.sum(np.sum(one_hot_action_seq, axis = 0) > 0) /self.J
        
        return reward
        
    def step(self, state, action):
        
        next_state = np.random.choice(self.N, size = self.J, replace = False)
        
        one_hot_action_seq = np.zeros((self.J, self.J))
        one_hot_action_seq[range(self.J), action] = 1
        reward = self.get_reward(one_hot_action_seq)
        
        terminal_state = False
        
        return next_state, reward, terminal_state
        

In [160]:
# data generation for supervised learning
def generate_data(n, N, J):
    
    X = np.empty((n, J), dtype = int)
    y= np.empty((n,J), dtype = int)
    
    for i in range(n):
        
        X[i] = np.random.choice(N, size = J, replace = False)
        sorted_args = np.argsort(X[i])
        y[i] = np.argsort(sorted_args)
        
    return X, y

In [112]:
# episode generation for reinforcement learning
def policy_rollout(T, env, agent):
    
    state_seq = []
    action_seq = []
    reward_seq = []
    proba_seq = []
    
    
    state, terminal_state = env.reset()
    
    t = 0
    
    while not terminal_state and t < T:
        t +=1
        
        state_seq.append(state)
        action, proba = agent.take_action(state)
        
        state, reward, terminal_state = env.step(state, action)
        
        
        action_seq.append(action)
        reward_seq.append(reward)
        proba_seq.append(proba)
        
    return np.array(state_seq), np.array(action_seq), np.array(reward_seq), np.squeeze(np.array(proba_seq))

In [113]:
N = 500
J = 3
batch_size = 32
n = batch_size * 1000

In [114]:
X, y = generate_data(n, N, J)
val_X, val_y = generate_data(500, N, J)

In [115]:
tf.reset_default_graph()
with tf.Session() as sess:
    commNet = CommNet(sess, N, J, lr = 1e-3, embedding_size= 128, training_mode = 'supervised', alpha = 0.1)
    env = LeverEnv(N, J)
    commNet.supervised_train(X, y, val_X, val_y, env, batch_size = batch_size, epochs = 30)
    
    #history, q_val = commNet.reinforce_train(env, n_episodes = 5000, T = 7)
    
    rv = sess.run([commNet.embedded_identities, commNet.layers, commNet.policy_logit_seq, commNet.supervised_loss, commNet.action_seq, commNet.one_hot_action_seq], feed_dict={commNet.inputs: val_X[0:1], commNet.targets: val_y[0:1]})


All variables
<tf.Variable 'Encoder/identity_embeddings:0' shape=(500, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/b1:0' shape=(1, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W1:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/b2:0' shape=(1, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/b1:0' shape=(1, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W1:0' shape=(384, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/b2:0' shape=(1, 128) dtype=float32_ref>
<tf.Variable 'Decoder/W:0' shape=(128, 3) dtype=float32_ref>
<tf.Variable 'Decoder/b:0' shape=(1, 3) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Encoder/identity_embeddings/RMSProp:0' shape=(500, 128) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Encoder/identity_embeddings/RMSProp_1:0' shape=(500, 128) dtype=float32_ref>
<tf.Variable 'Supervised_op


loss = 1.046343
reward = 0.713635

val loss = 0.967683
val reward = 0.748000



loss = 0.926915
reward = 0.748479

val loss = 0.964891
val reward = 0.739333



loss = 0.905397
reward = 0.756250

val loss = 0.956542
val reward = 0.749333



loss = 0.883013
reward = 0.761406

val loss = 0.963858
val reward = 0.787333



loss = 0.855678
reward = 0.770073

val loss = 0.983786
val reward = 0.766000



loss = 0.820209
reward = 0.778333

val loss = 0.994558
val reward = 0.764000



loss = 0.780347
reward = 0.788104

val loss = 1.037243
val reward = 0.782667



loss = 0.738372
reward = 0.798771

val loss = 1.208681
val reward = 0.826000



loss = 0.692366
reward = 0.808854

val loss = 1.162292
val reward = 0.803333



loss = 0.646388
reward = 0.818448

val loss = 1.208937
val reward = 0.793333



loss = 0.600537
reward = 0.829250

val loss = 1.410644
val reward = 0.798667



loss = 0.555645
reward = 0.839823

val loss = 1.450270
val reward = 0.810000



loss = 0.512064
reward = 0.850708

val loss = 1.754566
val reward = 0.837333



loss = 0.472376
reward = 0.858802

val loss = 1.727940
val reward = 0.826667



loss = 0.432218
reward = 0.868906

val loss = 2.000592
val reward = 0.835333



loss = 0.391994
reward = 0.879292

val loss = 2.011152
val reward = 0.846667



loss = 0.357823
reward = 0.887865

val loss = 2.370091
val reward = 0.865333



loss = 0.329082
reward = 0.895135

val loss = 2.578494
val reward = 0.870667



loss = 0.297596
reward = 0.904938

val loss = 2.667833
val reward = 0.865333



loss = 0.273147
reward = 0.911979

val loss = 3.050503
val reward = 0.884667



loss = 0.248165
reward = 0.918094

val loss = 3.110831
val reward = 0.882667



loss = 0.229260
reward = 0.923979

val loss = 3.457235
val reward = 0.898667



loss = 0.205809
reward = 0.930781

val loss = 3.335571
val reward = 0.875333



loss = 0.193066
reward = 0.935865

val loss = 3.507915
val reward = 0.890000



loss = 0.178722
reward = 0.939813

val loss = 3.516766
val reward = 0.895333



loss = 0.159585
reward = 0.945094

val loss = 3.851482
val reward = 0.880000



loss = 0.151104
reward = 0.948302

val loss = 4.248665
val reward = 0.898667



loss = 0.140939
reward = 0.952708

val loss = 4.308733
val reward = 0.888667



loss = 0.127657
reward = 0.955219

val loss = 4.675236
val reward = 0.888667



loss = 0.121728
reward = 0.958302

val loss = 4.386762
val reward = 0.884000


In [116]:
rv[0]

array([[[-0.2084906 ,  0.02374804, -0.00896381, -0.01363329,
         -0.01539735, -0.09960759,  0.00879658,  0.02616302,
         -0.03708452,  0.11100191, -0.09006239,  0.12490398,
         -0.1830035 , -0.01612187, -0.04865932,  0.11230216,
         -0.08000849, -0.01113982,  0.0566847 ,  0.10024989,
         -0.08895486,  0.1845117 ,  0.22353396,  0.000383  ,
          0.07099406,  0.05801363,  0.18270227, -0.13345873,
         -0.10249335,  0.091339  ,  0.03591219, -0.0090279 ,
          0.10476949,  0.21619219, -0.03079632,  0.12995882,
          0.02130782,  0.03541506,  0.07031996,  0.0742506 ,
         -0.00989546, -0.05115569,  0.2525144 ,  0.03887724,
         -0.01513926,  0.07744752,  0.05067027,  0.08630602,
          0.1428998 ,  0.18778773,  0.07030316,  0.16404508,
          0.06847449, -0.09552611,  0.07967106,  0.07937705,
          0.07148624,  0.12983364,  0.18970098,  0.05084126,
          0.02083551,  0.03974245,  0.01596422,  0.12931432,
         -0.07611421, -0

In [124]:
for k,v in rv[1].items():
    print()
    print(k)
    print(v)


h0_seq
[array([[-0.2084906 ,  0.02374804, -0.00896381, -0.01363329, -0.01539735,
        -0.09960759,  0.00879658,  0.02616302, -0.03708452,  0.11100191,
        -0.09006239,  0.12490398, -0.1830035 , -0.01612187, -0.04865932,
         0.11230216, -0.08000849, -0.01113982,  0.0566847 ,  0.10024989,
        -0.08895486,  0.1845117 ,  0.22353396,  0.000383  ,  0.07099406,
         0.05801363,  0.18270227, -0.13345873, -0.10249335,  0.091339  ,
         0.03591219, -0.0090279 ,  0.10476949,  0.21619219, -0.03079632,
         0.12995882,  0.02130782,  0.03541506,  0.07031996,  0.0742506 ,
        -0.00989546, -0.05115569,  0.2525144 ,  0.03887724, -0.01513926,
         0.07744752,  0.05067027,  0.08630602,  0.1428998 ,  0.18778773,
         0.07030316,  0.16404508,  0.06847449, -0.09552611,  0.07967106,
         0.07937705,  0.07148624,  0.12983364,  0.18970098,  0.05084126,
         0.02083551,  0.03974245,  0.01596422,  0.12931432, -0.07611421,
        -0.10948015, -0.12645112, -0.00573

         0.05166071,  0.11984118, -0.02007674]], dtype=float32)]

h1_seq
[array([[0.        , 0.6125327 , 0.06860468, 0.01899664, 0.        ,
        0.        , 0.        , 0.28293478, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.5098363 , 0.        ,
        0.        , 0.02267392, 0.        , 0.        , 0.6906497 ,
        0.        , 0.5174747 , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.10795878, 0.        ,
        0.        , 0.        , 0.39895058, 0.        , 0.        ,
        0.        , 0.00454222, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.10060106, 0.        , 0.        

        0.        , 0.06264307, 0.2215037 ]], dtype=float32)]

h2_seq
[array([[2.3113449 , 0.        , 0.        , 0.53008044, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        3.842878  , 0.3507939 , 0.        , 1.6253008 , 0.        ,
        0.        , 0.        , 0.        , 0.        , 2.0973408 ,
        0.        , 0.        , 0.        , 0.        , 1.2628742 ,
        0.        , 0.        , 0.        , 0.6609036 , 0.        ,
        0.        , 4.836417  , 0.        , 3.2073948 , 0.        ,
        0.        , 0.        , 1.5880854 , 0.        , 0.70575255,
        0.35698032, 0.        , 0.        , 0.68910795, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.25589725,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        2.9065537 , 0.        , 0.        , 2.5989685 , 0.        ,
        0.94587994, 0.        , 0.        , 0.        , 4.037211  ,
        0.        , 0.        , 0.        , 0

In [125]:
print(val_X[0:1])
print(val_y[0:1])
print(rv[2])

[[309 398 193]]
[[2 0 1]]
[array([[-11.986955,   2.68754 , -12.87569 ]], dtype=float32), array([[-20.87973   , -14.591713  ,   0.43009904]], dtype=float32), array([[  6.545301,  -9.575035, -17.306194]], dtype=float32)]


In [131]:
for x in rv[2]:
    for xx in x:
        print(np.exp(xx)/np.sum(np.exp(x)))

[4.2359187e-07 9.9999934e-01 1.7417049e-07]
[5.5623539e-10 2.9930214e-07 9.9999970e-01]
[9.9999994e-01 9.9776194e-08 4.3795266e-11]


In [127]:
rv[3]

17.664467

In [128]:
rv[4]

[array([[1]], dtype=int64),
 array([[2]], dtype=int64),
 array([[0]], dtype=int64)]

In [129]:
rv[5]

[array([[[0., 1., 0.]]], dtype=float32),
 array([[[0., 0., 1.]]], dtype=float32),
 array([[[1., 0., 0.]]], dtype=float32)]

In [130]:
plt.plot(history['reward'], '*')
plt.show()
plt.plot(history['loss'], '*')
plt.show()

NameError: name 'history' is not defined

In [None]:
q_val.mean(axis = 1)