In [25]:
import numpy as np
import matplotlib.pyplot as plt

# Network
import torch
from torch import autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

# Optimizer
import torch.optim as optim

In [26]:
# Game setup
num_agents = 2   # Number of agents playing the game
num_types = 3    # Number of item types
max_item = 5     # Maximum number of each item in a pool
max_utility = 10 # Maximum utility value for agents
num_games = 128  # Number of games per episode

# Turn sampling
lam = 7     # Poisson parameter
max_N = 10  # Maximum number of turns
min_N = 4   # Minimum number of turns

# Linguistic channel
num_vocab = 10   # Symbol vocabulary size for linguistic channel
len_message = 6  # Linguistic message length

# Appendix
lambda1 = 0.05  # Entropy regularizer for pi_term, pi_prop
lambda2 = 0.001 # Entropy regularizer for pi_utt
smoothing_const = 0.7 # Smoothing constant for the exponential moving average baseline

In [27]:
# Sample an item pool for a game
def create_item_pool(num_types, max_item, batch_size):
    # Possible to have zero items?
    pool = np.random.randint(0, max_item+1, (batch_size,num_types))
    return torch.from_numpy(pool).long()
        
# Sample agent utility
def create_agent_utility(num_types, max_utility, batch_size):
    utility = np.zeros((batch_size,num_types)) # Initialize zero vector
    
    while 0 in np.sum(utility,1): # At least one item has to have non-zero utility
        utility = np.random.randint(0, max_utility+1, [batch_size, num_types])

    return torch.from_numpy(utility).long()

# Calculate reward
def reward(share, utility):
    return np.dot(utility, share)

In [28]:
class combined_policy(nn.Module):
    def __init__(self, embedding_dim = 100, batch_size = 128, num_layers = 1, bias = True, batch_first = False, dropout = 0, bidirectional = False):
        super(combined_policy, self).__init__()
        # Save variables
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        
        # Encoding -------------------------------------------------------------
        
        # Numerical encoder
        self.encoder1 = nn.Embedding(max_utility, embedding_dim)
        # Linguistic encoder
        self.encoder2 = nn.Embedding(num_vocab, embedding_dim)
        
        # Item context LSTM
        self.lstm1 = nn.LSTM(embedding_dim, embedding_dim, num_layers, bias, batch_first, dropout, bidirectional)
        # Linguistic LSTM
        self.lstm2 = nn.LSTM(embedding_dim, embedding_dim, num_layers, bias, batch_first, dropout, bidirectional)
        # Proposal LSTM
        self.lstm3 = nn.LSTM(embedding_dim, embedding_dim, num_layers, bias, batch_first, dropout, bidirectional)
        
        # Feed-forward
        self.ff = nn.Linear(3*embedding_dim, embedding_dim)
        
        # Policy ---------------------------------------------------------------
        
        # Termination policy
        self.policy_term = nn.Linear(embedding_dim, 1)
        # Linguistic policy
        self.policy_ling = nn.LSTM(embedding_dim, embedding_dim, num_layers, bias, batch_first, dropout, bidirectional)
        self.ff_ling = nn.Linear(embedding_dim, num_vocab)
        # Proposal policies
        self.policy_prop = []
        for i in range(num_types):
            ff = nn.Linear(embedding_dim, max_item)
            self.policy_prop.append(ff)
        
    def forward(self, x, test):
        # Extract inputs ------------------------------------------------------------
        
        # Item context
        x1 = x[0]
        # Previous linguistic message
        x2 = x[1]
        # Previous proposal
        x3 = x[2]  

        # Encoding ------------------------------------------------------------------
        
        # Initial embedding
        x1 = self.encoder1(x1).transpose(0,1)
        x2 = self.encoder2(x2).transpose(0,1)
        x3 = self.encoder1(x3).transpose(0,1)
            
        # LSTM for item context
        h = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial hidden
        c = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial cell

        for i in range(x1.size()[0]):
            _, (h,c) = self.lstm1(x1[i].view(1,self.batch_size,self.embedding_dim),(h,c))
        x1 = h
        
        # LSTM for linguistic
        h = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial hidden
        c = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial cell

        for i in range(x2.size()[0]):
            _, (h,c) = self.lstm2(x2[i].view(1,self.batch_size,self.embedding_dim),(h,c))
        x2 = h
        
        # LSTM for proposal
        h = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial hidden
        c = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial cell

        for i in range(x3.size()[0]):
            _, (h,c) = self.lstm2(x3[i].view(1,self.batch_size,self.embedding_dim),(h,c))
        x3 = h

        # Concatenate side-by-side
        x = torch.cat([x1,x2,x3],2)

        # Feedforward
        h = self.ff(x)
        h = F.relu(h) # Hidden layer input for policy networks
        
        # Policy ------------------------------------------------------------------

        # Termination -----------------------------------------------
        p_term = F.sigmoid(self.policy_term(h)).view(self.batch_size,1)
        
        entropy_term = -(p_term * p_term.log2()) - (torch.ones(128,1)-p_term * (torch.ones(128,1)-p_term.log2()))
    
        if test:
            # Greedy
            term = torch.round(p_term).long()
        else:
            # Sample
            term = torch.bernoulli(p_term).long()
        
        # Linguistic construction ----------------------------------
        h = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial hidden state
        c = torch.zeros(1,self.batch_size,self.embedding_dim) # Initial cell state
        letter = torch.zeros(self.batch_size,1).long() # Initial letter (dummy)
        entropy_letter = torch.zeros([self.batch_size,num_vocab])
        
        message = torch.zeros(self.batch_size,len_message) # Message
        for i in range(len_message):
            embedded_letter = self.encoder2(letter)

            _, (h,c) = self.policy_ling(embedded_letter.view(1,self.batch_size,self.embedding_dim),(h,c))
            logit = self.ff_ling(h)
            p_letter = F.softmax(logit,dim=2).view(self.batch_size,num_vocab)
            
            entropy_letter[:,i] = -1*(torch.sum(p_letter[i],0,keepdim=True) * torch.sum(p_letter[i],0,keepdim=True).log2())
            
            if test:
                # Greedy
                letter = p_letter.argmax(dim=1).view(self.batch_size,1)
            else:
                # Sample
                letter = torch.polynomial(p_letter,1)
            message[:,i] = letter.squeeze()
            
        message = message.long()
        entropy_letter = torch.sum(letter,1)     
   
        # Proposal ----------------------------------------------
        p_prop = torch.zeros(num_types,self.batch_size,max_item)
        prop = torch.zeros([self.batch_size,num_types]).long()
        entropy_prop= torch.zeros([self.batch_size,num_types])
        
        for i in range(num_types):
            blah = F.sigmoid(self.policy_prop[i](h))
            p_prop[i] = F.sigmoid(self.policy_prop[i](h))
            
            entropy_prop[:,i] = -1*(torch.sum(p_prop[i],1) * torch.sum(p_prop[i],1).log2())
            if test:
                # Greedy
                prop[:,i] = p_prop[i].argmax(dim=1)
            else:
                # Sample
                prop[0][i] = torch.multinomial(p_prop,1)
            
        entropy_prop = torch.sum(entropy_prop,1) # Entropy for exploration        

        loss = torch.zeros(1,requires_grad=True)
        
        return (term,message,prop,loss)
    
net = combined_policy()

In [29]:
x = torch.randint(0,max_item,[128,6]).long()
y = torch.randint(0,num_vocab,[128,6]).long()
z = torch.randint(0,max_item,[128,3]).long()

In [30]:
blah = net([x,y,z],True)

print(blah[2].size())

torch.Size([128, 3])


In [42]:
# Agents
Agents = []
for i in range(num_agents):
    Agents.append(combined_policy())
    
# Train REINFORCE
alpha = 0.001     # learning rate
N_ep = 50   # Number of episodes
num_games = 128 # Number of games per episode (batch size)

# Initialize optimizer to update the DQN
optimizers = []
for i in range(num_agents):
    optimizers.append(optim.Adam(Agents[i].parameters(), alpha))

# Loop over episodes
for i_ep in range(N_ep):
    # Setting up games -----------------------------------------------------------------------
    
    # Truncated Poisson sampling
    N = np.random.poisson(lam,num_games) 
    N = np.minimum(N,max_N)
    N = np.maximum(N,min_N)

    # Setting
    pool = create_item_pool(num_types, max_item, num_games) # Item pool
    item_contexts = []
    for i in range(num_agents):
        utility = create_agent_utility(num_types, max_utility, num_games)
        item_contexts.append(torch.cat([pool,utility],1))
    
    # Initialization
    survivors = torch.ones(num_games).nonzero() # Everyone alive initially
    prev_messages = torch.zeros(num_games,len_message).long() # Previous linguistic messages
    prev_proposals = torch.zeros(num_games,num_types).long()   # Previous proposals
    num_alive = len(survivors)
    # Play the games -------------------------------------------------------------------------
    for i_turn in range(max_N):
        # Sieve
        pool = pool[survivors]
        prev_messages = prev_messages[survivors]
        prev_proposals = prev_proposals[survivors]
        for j in range(num_agents):
            item_contexts[j] = item_contexts[j][survivors]
        
        # Agent that is playing
        Agent = Agents[i_turn % 2]             
        item_context = item_contexts[i_turn % 2]
        
        # Actually play the game
        term,prev_messages,prev_proposals,loss = Agent([item_context,prev_messages,prev_proposals], True)

        # optimize
        optimizer = optimizers[i_turn % 2]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Remove finished games
        # In term, element = 1 means die
        term_N = torch.from_numpy(1*(N <= i_turn)).view(num_alive,1)
        # In survivors, element = 1 means live
        survivors = (((term+term_N)) == 0).nonzero().view(num_alive)
        num_alive = len(survivors)



RuntimeError: index out of range at /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/TH/generic/THTensorMath.c:343

In [36]:
prev_messages.size()

torch.Size([128, 6])

In [39]:
survivors

tensor([[   0],
        [   1],
        [   2],
        [   3],
        [   4],
        [   5],
        [   6],
        [   7],
        [   8],
        [   9],
        [  10],
        [  11],
        [  12],
        [  13],
        [  14],
        [  15],
        [  16],
        [  17],
        [  18],
        [  19],
        [  20],
        [  21],
        [  22],
        [  23],
        [  24],
        [  25],
        [  26],
        [  27],
        [  28],
        [  29],
        [  30],
        [  31],
        [  32],
        [  33],
        [  34],
        [  35],
        [  36],
        [  37],
        [  38],
        [  39],
        [  40],
        [  41],
        [  42],
        [  43],
        [  44],
        [  45],
        [  46],
        [  47],
        [  48],
        [  49],
        [  50],
        [  51],
        [  52],
        [  53],
        [  54],
        [  55],
        [  56],
        [  57],
        [  58],
        [  59],
        [  60],
        [  61],
        