In [1]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm

  '{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)


In [50]:
class CommNet:
    """CommNet-style multi-agent controller built as a TensorFlow 1.x static graph.

    Each sample holds ``J`` agents, each identified by an integer in ``[0, N)``.
    The identities are embedded, refined by two communication steps (each agent
    receives the mean hidden state of the *other* agents), and decoded into
    ``J`` action logits per agent. One action per agent is sampled from those
    logits, and ``self.reward`` scores how many distinct actions were chosen.

    Args:
        sess: ``tf.Session`` used to initialise variables and run the graph.
        N: number of distinct agent identities (embedding vocabulary size).
        J: number of agents per sample; also the number of possible actions.
        embedding_size: width of the identity embeddings and hidden states.
        lr: Adam learning rate.
        training_mode: ``'supervised'`` or ``'reinforce'``; selects which loss
            and ``self.train_op`` are built.
        verbose: if True, print every global variable after graph construction.

    Raises:
        ValueError: for an unknown ``training_mode`` or ``J < 2``.
    """

    def __init__(self, sess, N, J, embedding_size = 128, lr = 1e-3, training_mode = 'supervised', verbose = True):
        # Validate before building any graph ops.
        if training_mode not in ('supervised', 'reinforce'):
            raise ValueError("Unknown training mode: %s" % training_mode)
        if J < 2:
            # communicate() averages over the J - 1 *other* agents, so at least two are needed.
            raise ValueError("CommNet needs at least 2 agents, got J=%d" % J)

        self.N = N
        self.J = J
        self.embedding_size = embedding_size
        self.training_mode = training_mode

        self.build_controler()

        if training_mode == 'supervised':
            self.build_supervised()
            with tf.variable_scope('Supervised_optimizer'):
                self.train_op = tf.train.AdamOptimizer(lr).minimize(self.supervised_loss)
        else:
            self.build_reinforce()
            # Previously no optimizer was created, leaving 'reinforce' mode untrainable.
            with tf.variable_scope('Reinforce_optimizer'):
                self.train_op = tf.train.AdamOptimizer(lr).minimize(self.reinforce_loss)

        if verbose:
            print("All variables")
            for var in tf.global_variables():
                print(var)

        self.sess = sess
        self.sess.run(tf.global_variables_initializer())

    def encode(self, inputs):
        """Embed integer identities and split the result into one tensor per agent.

        Args:
            inputs: int tensor ``(batch, J)`` with values in ``[0, N)``.

        Returns:
            List of ``J`` float tensors, each ``(batch, embedding_size)``.
        """
        with tf.variable_scope('Encoder'):
            identity_embeddings = tf.get_variable("identity_embeddings",
                                                  [self.N, self.embedding_size])
            embedded_identities = tf.nn.embedding_lookup(identity_embeddings, inputs)

        return tf.unstack(embedded_identities, axis = 1)

    def build_f(self, name, h, c, h0 = None):
        """One communication step: two ReLU layers over ``[h, c]`` or ``[h, c, h0]``.

        Variables live under ``name`` with ``tf.AUTO_REUSE``, so every agent that
        calls this with the same ``name`` shares the same weights.

        Args:
            name: variable scope of this step.
            h: ``(batch, embedding_size)`` hidden state of the agent.
            c: ``(batch, embedding_size)`` message received by the agent.
            h0: optional ``(batch, embedding_size)`` skip connection (initial state).

        Returns:
            ``(batch, embedding_size)`` new hidden state.
        """
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            parts = [h, c] if h0 is None else [h, c, h0]

            W1 = tf.get_variable('W1', shape = (len(parts) * self.embedding_size,
                                                self.embedding_size))
            W2 = tf.get_variable('W2', shape = (self.embedding_size,
                                                self.embedding_size))

            concat = tf.concat(parts, axis = 1)
            dense1 = tf.nn.relu(tf.matmul(concat, W1))
            dense2 = tf.nn.relu(tf.matmul(dense1, W2))

            return dense2

    def decode(self, h):
        """Map a ``(batch, embedding_size)`` hidden state to ``(batch, J)`` action logits.

        The decoder weights are shared across agents via ``tf.AUTO_REUSE``.
        """
        with tf.variable_scope('Decoder', reuse = tf.AUTO_REUSE):
            W = tf.get_variable('W', shape = (self.embedding_size, self.J))
            return tf.matmul(h, W)

    def communicate(self, h_seq):
        """Return the mean of the hidden states in ``h_seq`` (the other agents' states)."""
        return tf.add_n(h_seq) / len(h_seq)

    def _messages(self, h_seq):
        """For each agent i, the mean hidden state of every agent except i."""
        return [self.communicate([h for j, h in enumerate(h_seq) if j != i])
                for i in range(len(h_seq))]

    def sample_actions(self, policy_logit):
        """Sample one action per row from ``(batch, J)`` logits; returns int64 ``(batch, 1)``."""
        return tf.multinomial(policy_logit, num_samples = 1)

    def build_controler(self):
        """Build inputs -> encoder -> two communication steps -> logits, sampled actions and reward.

        Sets:
            self.inputs: int32 placeholder ``(batch, J)`` of agent identities.
            self.policy_logit_seq: list of ``J`` ``(batch, J)`` logit tensors.
            self.actions: list of ``J`` sampled int64 ``(batch, 1)`` actions.
            self.reward: scalar; per sample, the number of distinct actions chosen
                divided by ``J``, summed over the batch.
        """
        self.inputs = tf.placeholder(tf.int32, shape = (None, self.J))

        h0_seq = self.encode(self.inputs)
        c0_seq = self._messages(h0_seq)

        h1_seq = [self.build_f("Comm_step_1", h0_seq[j], c0_seq[j], None) for j in range(self.J)]
        c1_seq = self._messages(h1_seq)

        # Second step also sees the initial embedding as a skip connection.
        h2_seq = [self.build_f("Comm_step_2", h1_seq[j], c1_seq[j], h0_seq[j]) for j in range(self.J)]

        self.policy_logit_seq = [self.decode(h2) for h2 in h2_seq]

        self.actions = [self.sample_actions(policy_logit) for policy_logit in self.policy_logit_seq]

        one_hot_actions = [tf.one_hot(tf.reshape(action, [-1]), depth = self.J) for action in self.actions]

        # Summing the agents' one-hots and counting non-zero columns gives the number of distinct actions.
        self.reward = tf.reduce_sum(tf.count_nonzero(tf.add_n(one_hot_actions), axis = 1) / self.J)

    def build_supervised(self):
        """Build ``self.targets`` (int32 ``(batch, J)``) and the summed cross-entropy ``self.supervised_loss``."""
        self.targets = tf.placeholder(tf.int32, shape = (None, self.J))
        unstacked_targets = tf.unstack(self.targets, axis = 1)

        supervised_loss_seq = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels = unstacked_targets[j],
                                                                              logits = self.policy_logit_seq[j])
                               for j in range(self.J)]

        # Summed over agents and over the batch.
        self.supervised_loss = tf.reduce_sum(tf.add_n(supervised_loss_seq))

    def supervised_train(self, X, y, val_X, val_y, batch_size = 32, epochs = 1):
        """Minibatch-train on ``(X, y)`` and report validation metrics after each epoch.

        Args:
            X, y: int arrays ``(n, J)`` of identities and target actions.
            val_X, val_y: validation arrays with the same layout; evaluated in one run.
            batch_size: minibatch size; the last batch may be smaller.
            epochs: number of passes over the shuffled training set.

        Returns:
            List with one dict per epoch holding the printed per-sample
            ``loss``, ``reward``, ``val_loss`` and ``val_reward``.

        Raises:
            ValueError: if the model was not built with ``training_mode='supervised'``.
        """
        if self.training_mode != 'supervised':
            raise ValueError("supervised_train requires training_mode='supervised', got %r"
                             % self.training_mode)

        n = X.shape[0]
        val_n = val_X.shape[0]
        history = []

        data_inds = np.arange(n)
        for ep in range(1, epochs + 1):
            np.random.shuffle(data_inds)
            supervised_loss_sum = 0.0
            reward_sum = 0.0
            for i in tqdm(range(0, n, batch_size), "Epoch: %d" % ep):
                inds_batch = data_inds[i:i + batch_size]
                # Use the session this model owns, not a notebook-global `sess`.
                _, supervised_loss, reward = self.sess.run(
                    [self.train_op, self.supervised_loss, self.reward],
                    feed_dict = {self.inputs: X[inds_batch], self.targets: y[inds_batch]})
                supervised_loss_sum += supervised_loss
                reward_sum += reward

            # Loss and reward are batch sums, so dividing by the sample count gives per-sample means.
            loss = supervised_loss_sum / max(n, 1)
            mean_reward = reward_sum / max(n, 1)
            print("loss = %f" % loss)
            print("reward = %f" % mean_reward)
            print()

            val_supervised_loss, val_reward = self.sess.run(
                [self.supervised_loss, self.reward],
                feed_dict = {self.inputs: val_X, self.targets: val_y})
            val_loss = val_supervised_loss / max(val_n, 1)
            val_mean_reward = val_reward / max(val_n, 1)
            print('val loss = %f' % val_loss)
            print('val reward = %f' % val_mean_reward)

            history.append({'loss': loss, 'reward': mean_reward,
                            'val_loss': val_loss, 'val_reward': val_mean_reward})

        return history

    def build_reinforce(self):
        """Build the REINFORCE surrogate loss ``self.reinforce_loss``.

        Sets:
            self.advantage: float32 placeholder ``(batch, J)``, advantage of each agent's action.
            self.actions_taken: int32 placeholder ``(batch, J)``, action taken by each agent.
            self.actions_taken_p_seq: list of ``J`` ``(batch,)`` tensors; the
                log-probability agent ``j`` assigned to its taken action.
            self.reinforce_loss: ``-sum_{batch, j} advantage * log p(action)``.
        """
        # log_softmax is the numerically stable form of log(softmax(x)).
        log_p_seq = [tf.nn.log_softmax(policy_logit) for policy_logit in self.policy_logit_seq]

        self.advantage = tf.placeholder(tf.float32, shape = (None, self.J))
        unstacked_advantage = tf.unstack(self.advantage, axis = 1)

        self.actions_taken = tf.placeholder(tf.int32, shape = (None, self.J))
        unstacked_actions_taken = tf.unstack(self.actions_taken, axis = 1)

        # Select log p(a) row by row. The previous tf.gather(..., axis=1) with (batch,)
        # indices produced a (batch, batch) matrix instead of one value per row.
        self.actions_taken_p_seq = [
            tf.reduce_sum(log_p_seq[j] * tf.one_hot(unstacked_actions_taken[j], depth = self.J), axis = 1)
            for j in range(self.J)]

        # Surrogate loss whose gradient is the REINFORCE policy gradient.
        self.reinforce_loss = - tf.add_n([tf.reduce_sum(unstacked_advantage[j] * self.actions_taken_p_seq[j])
                                          for j in range(self.J)])
        
        

In [3]:
def generate_data(n, N, J):
    """Draw ``n`` random groups of ``J`` distinct agent identities out of ``N``.

    Uses the global ``np.random`` stream (one ``choice`` call per row).

    Returns:
        X: int array ``(n, J)``; each row holds J distinct ids from ``range(N)``, ascending.
        y: array ``(n, J)``; every row is ``[0, 1, ..., J-1]``.
    """
    samples = np.empty((n, J), dtype = int)
    for row in range(n):
        chosen = np.random.choice(N, size = J, replace = False)
        samples[row] = np.sort(chosen)

    targets = np.tile(list(range(J)), (n, 1))
    return samples, targets


In [4]:
# Problem size: N possible agent identities, J agents per sample (J is also the number of actions).
N = 100
J = 5
batch_size = 32
# The training set is exactly one batch.
n = batch_size * 1

# Seed numpy's global RNG so generate_data() and the epoch shuffles repeat on Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [5]:
# Training set first, then a larger validation set, from the same random generator.
(X, y), (val_X, val_y) = (
    generate_data(n, N, J),
    generate_data(1024, N, J),
)

In [52]:
with tf.Graph().as_default(), tf.Session() as sess:
    # Graph-level seed (set before any op is created) so weight init and
    # action sampling repeat across runs.
    tf.set_random_seed(0)

    commNet = CommNet(sess, N, J, lr = 1e-4, training_mode = 'supervised')
    commNet.supervised_train(X, y, val_X, val_y, batch_size = batch_size, epochs = 10)


All variables
<tf.Variable 'Encoder/identity_embeddings:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W1:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W1:0' shape=(384, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Decoder/W:0' shape=(128, 5) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/beta1_power:0' shape=() dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/beta2_power:0' shape=() dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Encoder/identity_embeddings/Adam:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Encoder/identity_embeddings/Adam_1:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Comm_step_1/W1/Adam:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Supervised_optimizer/Comm_step_1/W1/Adam_1:0' shape=(256, 128) dtype=float32_ref>
<tf.Variab

Epoch: 1: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.27it/s]


loss = 8.050602
reward = 0.637500

val loss = 8.048738
val reward = 0.673828


Epoch: 2: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 47.65it/s]


loss = 8.038939
reward = 0.668750

val loss = 8.045919
val reward = 0.667383


Epoch: 3: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 76.97it/s]


loss = 8.027385
reward = 0.631250

val loss = 8.043117
val reward = 0.677734


Epoch: 4: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 71.48it/s]


loss = 8.015928
reward = 0.643750

val loss = 8.040318
val reward = 0.675000


Epoch: 5: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 66.70it/s]


loss = 8.004602
reward = 0.650000

val loss = 8.037531
val reward = 0.674609


Epoch: 6: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 66.72it/s]


loss = 7.993379
reward = 0.618750

val loss = 8.034752
val reward = 0.675586


Epoch: 7: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 43.51it/s]


loss = 7.982247
reward = 0.693750

val loss = 8.031969
val reward = 0.675195


Epoch: 8: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 52.67it/s]


loss = 7.971165
reward = 0.706250

val loss = 8.029185
val reward = 0.667773


Epoch: 9: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 66.72it/s]


loss = 7.960152
reward = 0.681250

val loss = 8.026405
val reward = 0.658789


Epoch: 10: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 58.88it/s]


loss = 7.949202
reward = 0.693750

val loss = 8.023636
val reward = 0.677344
