In [61]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm

In [59]:
class CommNet:
    """Communicating multi-agent controller trained with a supervised loss.

    Every one of the J agents receives an integer identity, embeds it, runs two
    communication steps in which it sees the mean hidden state of the *other*
    agents, and finally produces logits over J actions. All agents share the
    same weights.

    Graph handles available after construction:
        inputs:           int32 placeholder, shape (batch, J), agent identities.
        targets:          int32 placeholder, shape (batch, J), target actions.
        policy_logit_seq: list of J logit tensors, one per agent.
        supervised_loss:  scalar cross-entropy, summed over agents and batch.
        actions:          list of J one-hot action samples, one per agent.
        reward:           scalar; for each example the fraction of the J actions
                          that were picked by at least one agent, summed over
                          the batch.
    """

    def __init__(self, sess, N, J, K, embedding_size = 128, lr = 1e-3):
        """Build the graph in the current default graph and initialise variables.

        Args:
            sess: TensorFlow session used for initialisation and by `train`.
            N: number of rows of the identity embedding table.
            J: number of agents (and number of actions per agent); must be >= 2.
            K: stored on the instance but not used by the graph built here.
            embedding_size: width of embeddings and hidden states.
            lr: learning rate of the Adam optimizer.

        Raises:
            ValueError: if J < 2 (an agent needs at least one peer to average).
        """
        if J < 2:
            raise ValueError("CommNet needs at least 2 agents, got J=%d" % J)

        self.N = N
        self.J = J
        self.K = K
        self.embedding_size = embedding_size

        self.build_controler()
        with tf.variable_scope('Optimizer'):
            self.train_op = tf.train.AdamOptimizer(lr).minimize(self.supervised_loss)

        # Dump every variable (model and optimizer slots) for inspection.
        for var in tf.global_variables():
            print(var)

        self.sess = sess
        self.sess.run(tf.global_variables_initializer())

    def encode(self, inputs):
        """Embed agent identities and return a list with one tensor per agent.

        `inputs` is indexed into a trainable (N, embedding_size) table and the
        result is unstacked along axis 1 (the agent axis).
        """
        with tf.variable_scope('Encoder'):
            identity_embeddings = tf.get_variable("identity_embeddings",
                                                  [self.N, self.embedding_size])
            embedded_identities = tf.nn.embedding_lookup(identity_embeddings, inputs)

        return tf.unstack(embedded_identities, axis = 1)

    def build_f(self, name, h, c, h0 = None):
        """Apply one communication step: two bias-free ReLU layers.

        Args:
            name: variable scope; calls sharing a name share W1/W2 (AUTO_REUSE).
            h: hidden state of the agent.
            c: communication vector received by the agent.
            h0: optional skip connection; when given it is concatenated as a
                third input and W1 has 3 * embedding_size rows instead of 2.

        Returns:
            The new hidden state, with embedding_size columns.
        """
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            parts = [h, c] if h0 is None else [h, c, h0]

            W1 = tf.get_variable('W1', shape = (len(parts) * self.embedding_size,
                                                self.embedding_size))
            W2 = tf.get_variable('W2', shape = (self.embedding_size,
                                                self.embedding_size))

            concat = tf.concat(parts, axis = 1)
            dense1 = tf.nn.relu(tf.matmul(concat, W1))
            dense2 = tf.nn.relu(tf.matmul(dense1, W2))

            return dense2

    def decode(self, h):
        """Project a hidden state to J policy logits with a shared matrix."""
        with tf.variable_scope('Decoder', reuse = tf.AUTO_REUSE):
            W = tf.get_variable('W', shape = (self.embedding_size, self.J))
            policy_logits = tf.matmul(h, W)

            return policy_logits

    def communicate(self, h_seq):
        """Return the element-wise mean of the tensors in `h_seq`.

        The divisor is the number of tensors actually passed in, so the result
        is a true mean regardless of how many peers the caller selects.
        """
        return tf.add_n(h_seq) / len(h_seq)

    def sample_one_hot_actions(self, policy_logit):
        """Sample one action per row of `policy_logit`, one-hot encoded (depth J)."""
        action = tf.multinomial(policy_logit, num_samples = 1)
        one_hot_action = tf.one_hot(tf.reshape(action, [-1]), depth = self.J)

        return one_hot_action

    def build_controler(self):
        """Create placeholders, the two-step controller, the loss and the reward."""
        self.inputs = tf.placeholder(tf.int32, shape = (None, self.J))
        agents = range(self.J)

        def others_mean(h_seq):
            # For agent i, average the hidden states of every agent except i.
            return [self.communicate([h_seq[j] for j in agents if j != i])
                    for i in agents]

        h0_seq = self.encode(self.inputs)
        c0_seq = others_mean(h0_seq)

        h1_seq = [self.build_f("Comm_step_1", h0_seq[j], c0_seq[j], None) for j in agents]
        c1_seq = others_mean(h1_seq)

        # The second step also receives the initial embedding as a skip input.
        h2_seq = [self.build_f("Comm_step_2", h1_seq[j], c1_seq[j], h0_seq[j]) for j in agents]

        self.policy_logit_seq = [self.decode(h2) for h2 in h2_seq]

        self.targets = tf.placeholder(tf.int32, shape = (None, self.J))
        unstacked_targets = tf.unstack(self.targets, axis = 1)

        supervised_loss_seq = [
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels = unstacked_targets[j],
                                                           logits = self.policy_logit_seq[j])
            for j in agents
        ]

        # Summed (not averaged) over agents and batch; `train` divides by n.
        self.supervised_loss = tf.reduce_sum(tf.add_n(supervised_loss_seq))

        self.actions = [self.sample_one_hot_actions(policy_logit)
                        for policy_logit in self.policy_logit_seq]

        # add_n of the one-hot samples counts how often each action was picked;
        # the non-zero entries are therefore the distinct actions chosen.
        self.reward = tf.reduce_sum(tf.count_nonzero(tf.add_n(self.actions), axis = 1) / self.J)

    def train(self, X, y, val_X, val_y, batch_size = 32, epochs = 1):
        """Run mini-batch training and print per-example metrics every epoch.

        Args:
            X, y: training inputs and targets, indexed along axis 0.
            val_X, val_y: validation inputs and targets, evaluated in one pass.
            batch_size: examples per optimisation step (last batch may be smaller).
            epochs: number of passes over the training data.

        Returns:
            A list with one dict per epoch holding the printed values under the
            keys 'loss', 'reward', 'val_loss' and 'val_reward'.
        """
        n = X.shape[0]
        val_n = val_X.shape[0]

        history = []
        data_inds = np.arange(n)
        for ep in range(1, epochs + 1):
            np.random.shuffle(data_inds)
            loss_sum = 0
            reward_sum = 0
            for i in tqdm(range(0, n, batch_size), "Epoch: %d" % ep):
                inds_batch = data_inds[i:i+batch_size]
                feed = {self.inputs: X[inds_batch], self.targets: y[inds_batch]}
                # Use the session handed to the constructor, not a notebook global.
                _, loss, reward = self.sess.run(
                    [self.train_op, self.supervised_loss, self.reward], feed_dict = feed)
                loss_sum += loss
                reward_sum += reward

            print("loss = %f" % (loss_sum / n))
            print("reward = %f" % (reward_sum / n))
            print()

            val_supervised_loss, val_reward = self.sess.run(
                [self.supervised_loss, self.reward],
                feed_dict = {self.inputs: val_X, self.targets: val_y})
            print('val loss = %f' % (val_supervised_loss / val_n))
            print('val reward = %f' % (val_reward / val_n))

            history.append({'loss': loss_sum / n,
                            'reward': reward_sum / n,
                            'val_loss': val_supervised_loss / val_n,
                            'val_reward': val_reward / val_n})

        return history
        

In [49]:
def generate_data(n, N, J, seed = None):
    """Generate n examples of J distinct, sorted identities with their targets.

    Args:
        n: number of examples.
        N: identities are drawn without replacement from range(N).
        J: number of identities per example.
        seed: optional integer; when given, a private RandomState seeded with
            it is used so the result is reproducible and the global NumPy
            random state is left untouched. When None, the global
            `np.random` state is used, as before.

    Returns:
        X: int array of shape (n, J); each row is sorted in increasing order.
        y: int array of shape (n, J); every row is 0, 1, ..., J-1, i.e. the
           target for a column is its position in the sorted row.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    X = np.empty((n, J), dtype = int)
    for i in range(n):
        X[i] = np.sort(rng.choice(N, size = J, replace = False))

    y = np.tile(np.arange(J), (n, 1))

    return X, y


In [50]:
# Experiment configuration.
SEED = 0                    # seed for NumPy's global RNG (data generation, shuffling)
np.random.seed(SEED)

N = 100                     # number of distinct agent identities
J = 5                       # agents per example (also the number of actions)
K = 2                       # passed to CommNet, which stores it
batch_size = 32
n = batch_size * 1000       # number of training examples

In [51]:
# Training split of n examples, plus a 1024-example validation split drawn the same way.
X, y = generate_data(n = n, N = N, J = J)
val_X, val_y = generate_data(n = 1024, N = N, J = J)

In [60]:
# Start from an empty default graph so this cell can be re-run.
tf.reset_default_graph()
# Graph-level seed: set after the reset and before any op is created, so that
# weight initialisation and action sampling are repeatable from run to run.
tf.set_random_seed(0)
with tf.Session() as sess:

    commNet = CommNet(sess, N, J, K, lr = 1e-4)
    commNet.train(X, y, val_X, val_y, batch_size = batch_size, epochs = 10)

<tf.Variable 'Encoder/identity_embeddings:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W1:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_1/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W1:0' shape=(384, 128) dtype=float32_ref>
<tf.Variable 'Comm_step_2/W2:0' shape=(128, 128) dtype=float32_ref>
<tf.Variable 'Decoder/W:0' shape=(128, 5) dtype=float32_ref>
<tf.Variable 'Optimizer/beta1_power:0' shape=() dtype=float32_ref>
<tf.Variable 'Optimizer/beta2_power:0' shape=() dtype=float32_ref>
<tf.Variable 'Optimizer/Encoder/identity_embeddings/Adam:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Optimizer/Encoder/identity_embeddings/Adam_1:0' shape=(100, 128) dtype=float32_ref>
<tf.Variable 'Optimizer/Comm_step_1/W1/Adam:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Optimizer/Comm_step_1/W1/Adam_1:0' shape=(256, 128) dtype=float32_ref>
<tf.Variable 'Optimizer/Comm_step_1/W2/Adam:0' shape=(128, 128) dtype=float32_ref>
<tf.Var

Epoch: 1: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 91.90it/s]


loss = 3.753242
reward = 0.742094

val loss = 2.017161
val reward = 0.805273


Epoch: 2: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 99.06it/s]


loss = 1.551881
reward = 0.838050

val loss = 1.128531
val reward = 0.872656


Epoch: 3: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 94.66it/s]


loss = 0.836161
reward = 0.903356

val loss = 0.618213
val reward = 0.927734


Epoch: 4: 100%|███████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 101.53it/s]


loss = 0.443674
reward = 0.946125

val loss = 0.335846
val reward = 0.954492


Epoch: 5: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 96.74it/s]


loss = 0.236337
reward = 0.969225

val loss = 0.200270
val reward = 0.973828


Epoch: 6: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:11<00:00, 89.58it/s]


loss = 0.133150
reward = 0.981781

val loss = 0.113380
val reward = 0.983789


Epoch: 7: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 96.02it/s]


loss = 0.075745
reward = 0.989094

val loss = 0.078622
val reward = 0.989844


Epoch: 8: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 99.31it/s]


loss = 0.044250
reward = 0.993100

val loss = 0.055486
val reward = 0.992578


Epoch: 9: 100%|███████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 100.82it/s]


loss = 0.026490
reward = 0.995512

val loss = 0.036524
val reward = 0.994531


Epoch: 10: 100%|███████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 98.31it/s]


loss = 0.015945
reward = 0.997437

val loss = 0.047365
val reward = 0.995508
