In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
from collections import namedtuple

In [4]:
class Vision_M(nn.Module):
    '''
        Convolutional encoder for the environment image: three stacked
        convolutions, each followed by a ReLU.
    '''

    def __init__(self):
        '''
            Builds the three convolution layers. The hyperparameters are the
            ones the original author took from the paper.
        '''
        
        super(Vision_M, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        
        
    def forward(self, x):
        '''
            Argument:
                x: batch of input images, shape [batch_size, 3, 84, 84]

            Return:
                non-negative feature map; [batch_size, 64, 7, 7] for 84x84 inputs
        '''
        # A non-linearity after every convolution is required: without it the
        # three layers compose into a single linear map of the image
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        
        return out

In [5]:
class Language_M(nn.Module):
    '''
        Instruction encoder: a word-embedding table feeding a single-layer,
        batch-first LSTM.
    '''

    def __init__(self, vocab_size=100, embed_dim=128, hidden_size=128):
        '''
            Argument
                vocab_size: number of rows in the embedding table
                embed_dim: size of each word embedding (also the LSTM input size)
                hidden_size: size of the LSTM hidden state

            The defaults are the settings the original author took from the paper.
        '''
        super(Language_M, self).__init__()
        self.vocab_size, self.embed_dim, self.hidden_size = vocab_size, embed_dim, hidden_size

        # Lookup table with vocab_size entries of embed_dim dimensions each
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        '''
            Argument
                x: natural language instructions encoded in word indices, has shape
                    [batch_size, seq]

            Return
                the final hidden state `h` of the LSTM's `(h, c)` pair
        '''
        word_vectors = self.embeddings(x)
        # The per-step outputs and the final cell state are not needed
        _, (final_hidden, _final_cell) = self.lstm(word_vectors)

        return final_hidden
        

In [6]:
def test_language_module():
    '''
        Smoke test: push one random 10-token instruction through a freshly
        built Language_M, print the input and output sizes, and return the
        module output.
    '''
    word_indices = np.random.randint(1, 100, size=(1, 10))
    test_instruction = Variable(torch.LongTensor(word_indices))
    print('Input size', test_instruction.size())

    encoded = Language_M()(test_instruction)

    print('Output size', encoded.size())

    return encoded

l_out = test_language_module()

Input size torch.Size([1, 10])
Output size torch.Size([1, 1, 128])


In [7]:
def test_vision_module():
    '''
        Smoke test: push one random 84x84 RGB image through a freshly built
        Vision_M, print the input and output sizes, and return the features.
    '''
    frame = Variable(torch.randn(1, 3, 84, 84))
    print('Input size', frame.size())

    features = Vision_M()(frame)
    print('Output size', features.size())

    return features
    
v_out = test_vision_module()

Input size torch.Size([1, 3, 84, 84])
Output size torch.Size([1, 64, 7, 7])


In [8]:
class Mixing_M(nn.Module):
    '''
        Parameter-free fusion step: flattens both encodings per sample and
        concatenates them side by side.
    '''

    def __init__(self):
        super(Mixing_M, self).__init__()

    def forward(self, visual_encoded, instruction_encoded):
        '''
            Argument:
                visual_encoded: output of vision module; its first dimension
                    is taken as the batch size
                instruction_encoded: hidden state of language module; it is
                    reshaped to [batch_size, -1], so its element count must be
                    divisible by the batch size

            Return:
                tensor of shape [batch_size, visual_features + instruction_features],
                visual features first
        '''
        n_samples = visual_encoded.size(0)
        flattened = [part.view(n_samples, -1)
                     for part in (visual_encoded, instruction_encoded)]

        return torch.cat(flattened, dim=1)

In [9]:
def test_mixing_module(v_out, l_out):
    '''
        Smoke test: fuse the given vision and language outputs with a fresh
        Mixing_M, print the size of the result, and return it.
    '''
    mixed = Mixing_M()(v_out, l_out)
    print("Output size", mixed.size())

    return mixed

m_out = test_mixing_module(v_out, l_out)

Output size torch.Size([1, 3264])


In [10]:
class Action_M(nn.Module):
    '''
        Recurrent core: two stacked LSTM cells whose (h, c) states are kept on
        the module and carried over from one forward call to the next.
    '''

    def __init__(self, batch_size=1, hidden_size=256, input_size=3264):
        '''
            Argument:
                batch_size: number of rows in the stored hidden states; inputs
                    passed to forward must use the same batch size
                hidden_size: hidden size of both LSTM cells
                input_size: length of each input vector; the default matches
                    the [batch_size, 3264] output of the mixing module
        '''
        super(Action_M, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        # Size both cells from hidden_size so they always agree with the
        # stored hidden states created below
        self.lstm_1 = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.lstm_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        
        # Initial (h, c) pairs, drawn from a standard normal distribution
        self.hidden_1 = (Variable(torch.randn(batch_size, hidden_size)), 
                        Variable(torch.randn(batch_size, hidden_size))) 
        
        self.hidden_2 = (Variable(torch.randn(batch_size, hidden_size)), 
                        Variable(torch.randn(batch_size, hidden_size))) 
        
    def forward(self, x):
        '''
            Argument:
                x: output from the Mixing Module, shape [batch_size, input_size]

            Return:
                hidden state of the upper LSTM cell, shape [batch_size, hidden_size]

            Side effect: the stored hidden states are replaced by the new ones.
            They stay attached to the autograd graph, so a caller that runs
            backward more than once has to detach them between updates.
        '''
        # Feed forward
        h1, c1 = self.lstm_1(x, self.hidden_1)
        h2, c2 = self.lstm_2(h1, self.hidden_2)
        
        # Update current hidden state
        self.hidden_1 = (h1, c1)
        self.hidden_2 = (h2, c2)
        
        # Return the hidden state of the upper layer
        return h2

In [11]:
def test_action_module():
    '''
        Smoke test: feed the notebook-level `m_out` tensor to a fresh
        Action_M, print the output size, and return the output.
    '''
    core_out = Action_M()(m_out)
    print(core_out.size())

    return core_out

am_output = test_action_module()

torch.Size([1, 256])


In [12]:
# Record type pairing an action with a value
SavedAction = namedtuple('SavedAction', ['action', 'value'])

class Policy(nn.Module):
    '''
        Actor-critic head: one shared 256 -> 128 layer followed by separate
        linear heads for the action scores and the state value.
    '''

    def __init__(self, action_space):
        '''
            Argument:
                action_space: number of outputs of the action head
        '''
        super(Policy, self).__init__()
        self.action_space = action_space

        # Empty containers exposed on the module; nothing in this class fills them
        self.saved_actions, self.rewards = [], []

        self.affine1 = nn.Linear(256, 128)
        self.action_head = nn.Linear(128, action_space)
        self.value_head = nn.Linear(128, 1)

    def forward(self, x):
        '''
            Argument:
                x: input features whose last dimension is 256

            Return:
                (action_scores, state_values): raw outputs of the two heads,
                with last dimensions `action_space` and 1
        '''
        shared = F.relu(self.affine1(x))

        return self.action_head(shared), self.value_head(shared)

In [13]:
def test_policy():
    '''
        Smoke test: run the notebook-level `am_output` through a fresh
        10-action Policy, print both output sizes, and return the outputs.
    '''
    heads_out = Policy(action_space=10)(am_output)

    print('Output size', *(head.size() for head in heads_out))

    return heads_out

policy_output = test_policy()

Output size torch.Size([1, 10]) torch.Size([1, 1])


In [14]:
class Model(nn.Module):
    '''
        Full agent network: vision and language encoders, the mixing step,
        the recurrent action core, and a 10-action policy/value head.
    '''

    def __init__(self):
        super(Model, self).__init__()
        
        # Sub-modules, created in the order the forward pass uses them
        self.vision_m = Vision_M()
        self.language_m = Language_M()
        self.mixing_m = Mixing_M()
        self.action_m = Action_M()
        
        self.policy = Policy(action_space=10)
        
    def forward(self, img, instruction):
        '''
        Argument:
        
            img: environment image, shape [batch_size, 3, 84, 84]
            instruction: natural language instruction [batch_size, seq]

        Return:

            the pair produced by the policy module: action scores and
            state value
        '''
        
        mixed = self.mixing_m(self.vision_m(img), self.language_m(instruction))
        core_out = self.action_m(mixed)
        
        return self.policy(core_out)

In [22]:
def test_model():
    '''
        Smoke test: run one random image and one random 10-token instruction
        through a fresh Model, then print the sizes and the values of both
        outputs.
    '''
    model = Model()
    
    frame = Variable(torch.randn(1, 3, 84, 84))
    word_indices = np.random.randint(1, 100, size=(1, 10))
    tokens = Variable(torch.LongTensor(word_indices))
    
    outputs = model(frame, tokens)
    
    print(*(out.size() for out in outputs))
    print(*(out.data.numpy() for out in outputs))
    
test_model()

torch.Size([1, 10]) torch.Size([1, 1])
[[-0.07670651 -0.10970842  0.03437663 -0.10954983 -0.03560381 -0.11679479
   0.05895545  0.06801075  0.12713522 -0.1156317 ]] [[-0.14793287]]


In [29]:
class FakeEnvironment(object):
    '''
        Stand-in environment that returns random observations, a constant
        reward of 2, and never signals the end of an episode.
    '''

    def __init__(self):
        '''
            Nothing to set up: the environment keeps no state.
        '''
    
    def generate_state(self):
        '''
            Return a random (vision, instruction) pair: a [1, 3, 84, 84]
            normal-distributed image and a [1, 10] tensor of word indices
            drawn from 1..99.
        '''
        pixels = torch.randn(1, 3, 84, 84)
        word_indices = np.random.randint(1, 100, size=(1, 10))
        
        return (Variable(pixels), Variable(torch.LongTensor(word_indices)))
    
    def reset(self):
        '''
            Return a fresh random state.
        '''
        return self.generate_state()
    
    def step(self, action):
        '''
            Ignore `action` and return (state, reward, done, info) with the
            fixed values reward=2, done=False, info=1.
        '''
        next_state = self.generate_state()
        
        return next_state, 2, False, 1

In [30]:
# Actor-critic training script: collect a rollout of up to 50 steps from the
# fake environment, then do one update using a value loss, a policy loss
# weighted by generalized advantage estimates, and an entropy bonus.
torch.manual_seed(42)

gamma = 0.99  # discount factor for the returns
tau = 1.0     # decay factor applied to the advantage estimate
model = Model()

optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()

env = FakeEnvironment()
state = env.reset()
done = True

episode_length = 0
while True:
    values = []
    log_probs = []
    rewards = []
    entropies = []

    # Loop for number of unrolling steps
    for step in range(50):
        # Count environment steps (not updates), so that the cap below really
        # bounds the number of steps in an episode
        episode_length += 1

        logit, value = model(*state)
        
        # Calculate entropy from action probability distribution
        # (dim=1 is the action axis of the [batch, action] scores)
        prob = F.softmax(logit, dim=1)
        log_prob = F.log_softmax(logit, dim=1)
        entropy = -(log_prob * prob).sum(1)
        entropies.append(entropy)

        # Take the action from distribution (one sample per row)
        action = prob.multinomial(1).data
        log_prob = log_prob.gather(1, Variable(action))

        # Perform the action on the environment
        state, reward, done, _ = env.step(action.numpy())
        done = done or episode_length >= 100
        
        if done:
            # NOTE(review): the recurrent state stored on model.action_m is
            # not reset when an episode ends - confirm whether that is intended
            episode_length = 0
            state = env.reset()

        values.append(value)
        log_probs.append(log_prob)
        rewards.append(reward)

        if done:
            break

    # Terminal state or reached number of unrolling steps, whichever comes first
    R, gae = Variable(torch.zeros(1, 1)), torch.zeros(1, 1)
    
    # Get value estimate of current state
    if not done:
        # This forward pass advances the recurrent state stored on action_m;
        # put the old state back so the next rollout does not process the
        # same observation twice
        recurrent_state = (model.action_m.hidden_1, model.action_m.hidden_2)
        _, bootstrap_value = model(*state)
        model.action_m.hidden_1, model.action_m.hidden_2 = recurrent_state

        # The bootstrap value is a fixed target: no gradient flows through it
        R = bootstrap_value.detach()

    values.append(R)
    policy_loss, value_loss = 0, 0

    # Performing update
    for i in reversed(range(len(rewards))):
        # Value function loss
        R = gamma * R + rewards[i]
        value_loss = value_loss + 0.5 * (R - values[i]).pow(2)

        # Generalized Advantage Estimation
        delta_t = rewards[i] + gamma * \
                values[i + 1].data - values[i].data
        gae = gae * gamma * tau + delta_t

        # Computing policy loss
        policy_loss = policy_loss - \
            log_probs[i] * Variable(gae) - 0.01 * entropies[i]

    optimizer.zero_grad()

    # Back-propagation
    (policy_loss + 0.5 * value_loss).backward()
    # Newer PyTorch releases spell this clip_grad_norm_; the older name is
    # kept to match the API level used throughout this notebook
    torch.nn.utils.clip_grad_norm(model.parameters(), 40)

    optimizer.step()

    # Cut the autograd graph at the update boundary, so that a later update
    # does not try to back-propagate through this (already freed) rollout
    model.action_m.hidden_1 = tuple(h.detach() for h in model.action_m.hidden_1)
    model.action_m.hidden_2 = tuple(h.detach() for h in model.action_m.hidden_2)
    
    # This is only for testing
    break