In [11]:
import random
from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable

In [35]:
# One step of experience. `action_logit` and `value` are the model outputs
# recorded when the transition was collected.
Transition = namedtuple('Transition',
                        ('state', 'action_logit', 'next_state', 'reward', 'value'))

# Observation split into its visual part and its instruction part.
State = namedtuple('State', ('visual', 'instruction'))


class ReplayMemory(object):
    """Fixed-capacity FIFO buffer of `Transition`s.

    Once more than `capacity` transitions have been pushed, the oldest ones
    are evicted. `sample` returns a *contiguous* run of transitions, so
    temporal order is preserved within a batch.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, *args):
        """Append `Transition(*args)` and evict the oldest entries beyond capacity."""
        self.memory.append(Transition(*args))

        # Use the instance's capacity (the bare name `capacity` is undefined here).
        overflow = len(self.memory) - self.capacity
        if overflow > 0:
            del self.memory[:overflow]

    def sample(self, batch_size):
        """Return `batch_size` consecutive transitions starting at a random offset.

        The offset is drawn from the transitions actually stored (not from
        `capacity`), so a partially filled buffer still yields full batches.
        Every valid window, including the last one, can be selected.

        Raises:
            ValueError: if fewer than `batch_size` transitions are stored.
        """
        max_start = len(self.memory) - batch_size
        if max_start < 0:
            raise ValueError('cannot sample %d transitions from a buffer holding %d'
                             % (batch_size, len(self.memory)))
        # randint's upper bound is exclusive, hence +1 to include the last window.
        start_index = np.random.randint(0, max_start + 1)
        return self.memory[start_index:start_index + batch_size]

    # NOTE(review): optimize_model calls `memory.skewed_sample(...)`, which is
    # not defined on this class — TODO implement or remove that call.

    def __len__(self):
        return len(self.memory)

In [None]:
# Roll out the current policy for up to ROLLOUT_STEPS steps, then run one
# optimisation step on the collected trajectory.
#
# NOTE(review): this cell relies on names created elsewhere: `model`, `env`,
# `state`, `memory`, `episode_length`, the lists `values`, `log_probs`,
# `rewards`, `entropies`, and `optimize_model`. `optimize_model` is defined in
# a LATER cell, so that cell must run first for Restart & Run All to work.
# TODO confirm whether the four lists are meant to be reset before each rollout.
ROLLOUT_STEPS = 50
MAX_EPISODE_LENGTH = 100

for _ in range(ROLLOUT_STEPS):
    # Advance the episode counter; without this the length cut-off never fires.
    episode_length += 1
    logit, value = model(state)

    # Entropy of the action distribution. Dim 1 is the action axis
    # (matching the `.sum(1)` / `.gather(1, ...)` below).
    prob = F.softmax(logit, dim=1)
    log_prob = F.log_softmax(logit, dim=1)
    entropy = -(log_prob * prob).sum(1)
    entropies.append(entropy)

    # Sample one action from the policy and keep its log-probability.
    action = prob.multinomial(num_samples=1).data
    log_prob = log_prob.gather(1, Variable(action))

    # Perform the action on the environment.
    next_state, reward, done, _ = env.step(action.numpy())

    values.append(value)
    log_probs.append(log_prob)
    rewards.append(reward)

    # Store the raw transition for the auxiliary losses. No terminal flag is
    # stored, and the buffer is never cleared between episodes.
    memory.push(state, logit, next_state, reward, value)

    timed_out = episode_length >= MAX_EPISODE_LENGTH
    if done:
        # Real terminal state: nothing to bootstrap from.
        final_value = Variable(torch.zeros(1, 1))
    elif timed_out:
        # Episode cut short: bootstrap from the value of the state reached.
        _, final_value = model(next_state)

    done = done or timed_out
    if done:
        episode_length = 0
        state = next_state = env.reset()
        break

    # Move to the next state.
    state = next_state
else:
    # The rollout filled up without the episode ending. Bootstrap from the
    # current state's value so that optimize_model sees
    # len(values) == len(rewards) + 1.
    _, final_value = model(state)

values.append(final_value)
optimize_model(values, log_probs, rewards)

In [37]:
def mse_loss(predicted, target):
    """Sum of squared element-wise differences between `predicted` and `target`.

    Note: despite the name, the squared errors are summed, not averaged.
    """
    residual = predicted - target
    return torch.sum(residual * residual)

In [38]:
def optimize_model(values, log_probs, rewards):
    """Run one optimizer step on the A3C loss plus the auxiliary losses.

    Args:
        values: value estimate at each rollout step, followed by one bootstrap
            value for the state after the last step
            (``len(values) == len(rewards) + 1``).
        log_probs: log-probability of the action taken at each step.
        rewards: reward received at each step.

    Relies on the module-level names ``gamma``, ``tau``, ``entropies``,
    ``memory``, ``model`` and ``optimizer``.
    """
    policy_loss, value_loss = _a3c_losses(values, log_probs, rewards)

    # TODO: language prediction auxiliary loss is not implemented yet.
    language_prediction_loss = 0

    # Non-skewed sampling: 11 consecutive transitions feed the TAE and
    # value-replay losses.
    auxiliary_batch = Transition(*zip(*memory.sample(11)))
    tae_loss = _tae_loss(auxiliary_batch)

    # Skewed sampling from the experience buffer.
    # NOTE(review): the ReplayMemory class shown in this notebook has no
    # `skewed_sample` method; this call fails until one is implemented.
    skewed_batch = Transition(*zip(*memory.skewed_sample(31)))
    reward_prediction_loss = _reward_prediction_loss(skewed_batch)

    value_replay_loss = _value_replay_loss(auxiliary_batch)

    # Back-propagation.
    optimizer.zero_grad()
    total_loss = (policy_loss + 0.5 * value_loss + reward_prediction_loss
                  + tae_loss + language_prediction_loss + value_replay_loss)
    total_loss.backward()
    torch.nn.utils.clip_grad_norm(model.parameters(), 40)

    # Apply updates.
    optimizer.step()


def _a3c_losses(values, log_probs, rewards):
    """Return (policy_loss, value_loss) for one rollout, using n-step returns and GAE."""
    # The bootstrap value is a target: it must not receive gradients.
    R = values[-1].detach()
    gae = torch.zeros(1, 1)
    policy_loss, value_loss = 0, 0

    for i in reversed(range(len(rewards))):
        # Value function loss against the discounted n-step return.
        R = gamma * R + rewards[i]
        value_loss = value_loss + 0.5 * (R - values[i]).pow(2)

        # Generalized Advantage Estimation.
        delta_t = rewards[i] + gamma * values[i + 1].data - values[i].data
        gae = gae * gamma * tau + delta_t

        # Policy gradient with an entropy bonus.
        # NOTE(review): assumes the global `entropies` is index-aligned with
        # `rewards` (i.e. reset together each rollout) — confirm.
        policy_loss = policy_loss - \
            log_probs[i] * Variable(gae) - 0.01 * entropies[i]

    return policy_loss, value_loss


def _tae_loss(auxiliary_batch):
    """Temporal-autoencoder loss: predict frame t+1's visual input from frame t and its action."""
    visual_input = [s.visual for s in auxiliary_batch.state[:10]]
    # NOTE(review): assumes each `visual` is a tensor with a leading batch dim
    # of 1, so concatenation yields a (10, ...) target — confirm.
    visual_target = torch.cat([s.visual for s in auxiliary_batch.state[1:11]])
    # Stored logits belong to graphs that an earlier backward() already freed;
    # detach them so they act as plain inputs.
    action_logit = [a.detach() for a in auxiliary_batch.action_logit[:10]]

    tae_output = model.tAE(visual_input, action_logit)
    return mse_loss(tae_output, visual_target)


def _reward_prediction_loss(skewed_batch):
    """Predict the reward following each 3-state window of the skewed batch."""
    batch_rp_input = [skewed_batch.state[i:i + 3] for i in range(10)]
    batch_rp_output = [skewed_batch.reward[i + 3] for i in range(10)]

    rp_predicted = model.reward_prediction(batch_rp_input)
    # NOTE(review): assumes scalar rewards. Reshaping to the prediction's
    # size avoids silent broadcasting, and fails loudly on a count mismatch.
    rp_target = Variable(torch.Tensor(batch_rp_output).view(rp_predicted.size()))
    return mse_loss(rp_predicted, rp_target)


def _value_replay_loss(auxiliary_batch):
    """One-step value-function replay on a random transition of the batch."""
    index = np.random.randint(0, 10)
    # Bootstrapped target from the stored next-step value (no gradient).
    target = auxiliary_batch.value[index + 1].detach() * gamma \
        + auxiliary_batch.reward[index]
    # Re-evaluate the stored state with the current parameters. The stored
    # value's graph was already consumed by an earlier backward().
    _, predicted = model(auxiliary_batch.state[index])
    return 0.5 * mse_loss(predicted, target)