In [None]:
import os
import time
import random
import math
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from PIL import Image
from IPython.display import clear_output

import ai2thor
import ai2thor_colab
from ai2thor_colab import plot_frames
from ai2thor.controller import Controller

# Launch an AI2-THOR controller on the CloudRendering platform.
# The name `controller` is rebound to a fully configured controller in the "Set Environment" cell.
from ai2thor.platform import CloudRendering
controller = Controller(platform=CloudRendering)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.distributions import Categorical

In [None]:
from network import DQN
from utils import to_torchdim, frame2tensor, plot_durations, encode_feedback

In [None]:
import warnings
# Suppress every Python warning for the remainder of the notebook.
warnings.filterwarnings('ignore')

# Draw all matplotlib figures with the 'ggplot' style sheet.
plt.style.use('ggplot')

## Set Environment

In [None]:
# The scene is pinned to FloorPlan20. A random draw used to precede this line, but it was
# overwritten immediately and therefore never had any effect.
floor_index = 20

# A controller was already started in the first cell. Stop it before rebinding the name,
# otherwise that simulator instance keeps running with nothing left to shut it down.
controller.stop()

# NOTE(review): unlike the first cell, no `platform` argument is passed here — confirm that
# the default platform is what this notebook is meant to run on.
controller = Controller(
    agentMode = "default", # arm
    visibilityDistance = 1.5,
    scene = f"FloorPlan{floor_index}",

    # step sizes
    snapToGrid = True,
    gridSize = 0.25,
    rotateStepDegrees = 90,

    # image modalities
    renderInstanceSegmentation = False,
    renderDepthImage = False,
    renderSemanticSegmentation = False,
    renderNormalsImage = False,

    # camera properties
    width = 600,
    height = 420,
    fieldOfView = 120,
)

# Show the first frame rendered by the freshly configured controller.
plot_frames(controller.last_event)

In [None]:
# Text instructions used for training; train_network picks one per episode by cycling
# through this list (instructions[i_episode % len(instructions)]).
instructions = ["switch"]
# instructions = ["find light switch", "locate power switch", "seek light control", "discover light control", "identify switch light", "pinpoint light switch"]

## Set Configs

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

In [None]:
BUFFER_SIZE = int(1e4)  # maximum number of experiences kept in the replay buffer
BATCH_SIZE = 64  # experiences drawn per learning step
GAMMA = 0.999  # discount factor applied to the bootstrapped next-state value
EPS_START = 0.9  # epsilon at the start of training
EPS_END = 0.05  # lower bound that epsilon decays towards
NUM_EPISODES = 3000  # training episodes; also the horizon of the linear epsilon decay
TARGET_UPDATE = 4  # DQNAgent.memorize runs a learning step every TARGET_UPDATE stored transitions

TAU = 1e-4  # interpolation factor of the soft target-network update
LR = 2.5e-4  # learning rate of the RMSprop optimizer

# Frames are resized to SCREEN_WIDTH x SCREEN_HEIGHT before being fed to the network.
SCREEN_WIDTH = SCREEN_HEIGHT = 100

# Passed to encode_feedback as `target_name`; the trailing comment lists the candidate ids.
AGENT_TARGET = "LightSwitch_bf8119ce" # LightSwitch_887b121a, LightSwitch_bf8119ce

In [None]:
# Discrete action set. The integer action chosen by the agent indexes into this list to get
# the action name sent to controller.step, so the order must stay fixed.
action_space = ["MoveAhead", "MoveLeft", "MoveRight", "MoveBack", "RotateLeft", "RotateRight"]

## Load Word2Vec

In [None]:
from gensim.models import Word2Vec, KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec

In [None]:
glove_input_file = 'weights/glove.6B.100d.txt'
word2vec_output_file = 'weights/glove.6B.100d.txt.word2vec'

# Converting the GloVe text file only needs to happen once, so reuse the converted file
# when it is already on disk instead of redoing the conversion on every run.
if not os.path.exists(word2vec_output_file):
    glove2word2vec(glove_input_file, word2vec_output_file)

word2vec_model = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)

In [None]:
# Embedding matrix built from the loaded word vectors. DQNAgent.text_preprocess maps words to
# rows through word2vec_model.key_to_index, so rows are assumed to follow that index order.
pretrained_embeddings = torch.FloatTensor(word2vec_model.vectors)

## Set Replay Memory

In [None]:
class ReplayBuffer(object):
    """Fixed-size FIFO buffer of experience tuples used for experience replay.

    Once ``buffer_size`` experiences are stored, adding a new one discards the oldest.
    """

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Create an empty buffer.

        Args:
            action_size: Number of discrete actions (stored; not used by the buffer itself).
            buffer_size: Maximum number of experiences retained.
            batch_size: Number of experiences returned by each ``sample`` call.
            seed: Seed applied to Python's global ``random`` module.
        """

        self.experience = namedtuple("Experience", field_names=["visual_state", "text_state", "action", "reward", "next_state", "done"])

        # random.seed() returns None, so keep the seed value itself for later inspection.
        self.seed = seed
        random.seed(seed)

        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def add(self, visual_state, text_state, action, reward, next_state, done):
        """Append one experience, evicting the oldest entry when the buffer is full."""

        self.memory.append(self.experience(visual_state, text_state, action, reward, next_state, done))

    def sample(self):
        """Draw ``batch_size`` experiences uniformly at random, without replacement.

        Returns:
            Tuple ``(visual_states, text_states, actions, rewards, next_states, dones)``.
            Each element stacks the sampled values along the first axis and lives on the
            global ``device``. ``actions`` is a long tensor; the others are float tensors.

        Raises:
            ValueError: If fewer than ``batch_size`` experiences are stored.
        """

        experiences = random.sample(self.memory, k=self.batch_size)

        def stack(field):
            # Stored tensor fields are moved to the CPU and stacked along the first axis.
            return np.vstack([getattr(exp, field).cpu().numpy() for exp in experiences])

        visual_states = torch.from_numpy(stack("visual_state")).float().to(device)
        text_states = torch.from_numpy(stack("text_state")).float().to(device)
        actions = torch.from_numpy(stack("action")).long().to(device)
        rewards = torch.from_numpy(stack("reward")).float().to(device)
        next_states = torch.from_numpy(stack("next_state")).float().to(device)

        # `done` is stored as a plain flag rather than a tensor, so it is stacked directly.
        done_flags = np.vstack([exp.done for exp in experiences]).astype(np.uint8)
        dones = torch.from_numpy(done_flags).float().to(device)

        return (visual_states, text_states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the number of experiences currently stored."""

        return len(self.memory)

## Build Model

In [None]:
class CustomCNN(nn.Module):
    """Three-stage convolutional feature extractor for 3-channel images.

    Every stage is Conv2d (kernel 5, stride 2) -> BatchNorm2d -> ReLU. The stages
    produce 16, 32 and 32 channels and are registered as conv1/bn1 ... conv3/bn3.
    """

    def __init__(self):
        super().__init__()

        # Channel widths from the input image through the three stages.
        widths = (3, 16, 32, 32)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=5, stride=2))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))

    def forward(self, state):
        """Return the feature maps produced by running ``state`` through all three stages."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )

        features = state
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))

        return features

In [None]:
class VisualModel(nn.Module):
    """Visual encoder: runs CustomCNN and flattens its output to one vector per sample."""

    def __init__(self):
        super().__init__()
        self.cnn = CustomCNN()

    def forward(self, x):
        """Return the CNN features of ``x`` reshaped to (batch, features)."""
        feature_maps = self.cnn(x)
        batch_size = feature_maps.size(0)

        return feature_maps.reshape(batch_size, -1)
    
    
class TextModel(nn.Module):
    """Encode a tokenised instruction into a fixed-size feature vector.

    Token indices are looked up in a frozen pretrained embedding table and passed
    through an LSTM. The LSTM output at the last token is returned, so the feature
    size is always ``hidden_dim`` regardless of how many tokens the instruction has.
    """

    def __init__(self, pretrained_embedding, hidden_dim):
        """
        Args:
            pretrained_embedding: 2-D float tensor (vocabulary size x embedding dim).
            hidden_dim: Size of the LSTM hidden state, and of the returned features.
        """

        super(TextModel, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, freeze=True)

        # Inputs arrive as (batch, tokens). Without batch_first=True the LSTM would read the
        # batch axis as the time axis and carry hidden state from one sample into the next.
        self.rnn = nn.LSTM(pretrained_embedding.shape[1], hidden_dim, batch_first=True)

    def forward(self, x):
        """Return features of shape (batch, hidden_dim) for token indices ``x`` of shape (batch, tokens)."""

        # Replayed token indices come back as floats; the embedding lookup needs integers.
        x = x.long()

        embedded = self.embedding(x)    # (batch, tokens, embedding_dim)
        output, _ = self.rnn(embedded)  # (batch, tokens, hidden_dim)

        # Keep only the last time step, which summarises the whole instruction.
        return output[:, -1, :]

In [None]:
class MultimodalDQN(nn.Module):
    """Q-network that concatenates visual and text features and maps them to action values."""

    def __init__(self, visual_model, text_model, action_size, seed):
        """
        Args:
            visual_model: Module turning an image batch into (batch, features).
            text_model: Module turning an instruction batch into (batch, features).
            action_size: Number of Q-values produced per sample.
            seed: Value passed to torch.manual_seed before the linear head is created.
        """
        super().__init__()

        # torch.manual_seed returns the global generator, which is what ends up stored here.
        self.seed = torch.manual_seed(seed)

        self.visual_model, self.text_model = visual_model, text_model

        # NOTE(review): 2656 presumably equals 2592 CNN features (32 x 9 x 9 for a 100 x 100
        # frame) plus 64 text features — confirm before changing frame size or hidden_dim.
        self.fc = nn.Linear(2656, action_size)

    def forward(self, visual_input, text_input):
        """Return Q-values of shape (batch, action_size) for the given frame/instruction batch."""
        image_feats = self.visual_model(visual_input.to(device))
        instr_feats = self.text_model(text_input.to(device))

        fused = torch.cat([image_feats, instr_feats], dim=1)

        return self.fc(fused)

## Build Agent

In [None]:
class DQNAgent():
    """Double-DQN agent that acts on an image frame plus a text instruction.

    The online network (``dqn_net``) selects actions and is trained from replayed
    experience. A separate target network (``target_net``) supplies the bootstrap
    values and follows the online network through soft updates.
    """

    def __init__(self, screen_width, screen_height, action_size, seed):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            screen_width: Accepted for interface compatibility; not used here.
            screen_height: Accepted for interface compatibility; not used here.
            action_size: Number of discrete actions.
            seed: Seed for Python's ``random`` module, the networks and the buffer.
        """

        self.action_size = action_size

        # random.seed() returns None, so store the seed value itself.
        self.seed = seed
        random.seed(seed)

        # Online network.
        self.visual_model = VisualModel()
        self.text_model = TextModel(pretrained_embeddings, hidden_dim=64)
        self.dqn_net = MultimodalDQN(self.visual_model, self.text_model, action_size, seed).to(device)

        # The target network gets its OWN encoders. Sharing the online encoders would leave
        # only the linear head independent and make the soft update a no-op for everything
        # else. It starts as an exact copy of the online network.
        self.target_net = MultimodalDQN(
            VisualModel(),
            TextModel(pretrained_embeddings, hidden_dim=64),
            action_size,
            seed,
        ).to(device)
        self.target_net.load_state_dict(self.dqn_net.state_dict())

        self.optimizer = optim.RMSprop(self.dqn_net.parameters(), lr=LR, alpha=0.95, eps=0.01)

        # Replay Buffer
        self.buffer = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        self.time_step = 0

    def memorize(self, visual_state, text_state, action, reward, next_state, done):
        """Store one transition and periodically run a learning step.

        A learning step happens on every TARGET_UPDATE-th stored transition, provided
        the buffer already holds more than BATCH_SIZE experiences.
        """

        self.buffer.add(visual_state, text_state, action, reward, next_state, done)

        self.time_step = (self.time_step + 1) % TARGET_UPDATE
        if self.time_step == 0 and len(self.buffer) > BATCH_SIZE:
            experiences = self.buffer.sample()
            self.learn(experiences, GAMMA)

    def visual_preprocess(self, visual_state, screen_width, screen_height):
        """Resize a raw frame and convert it to a float32 tensor on ``device``.

        Args:
            visual_state: Frame array accepted by ``PIL.Image.fromarray``.
            screen_width: Target width in pixels.
            screen_height: Target height in pixels.
        """

        resized_screen = Image.fromarray(visual_state).resize((screen_width, screen_height))
        visual_state = frame2tensor(to_torchdim(resized_screen)).to(torch.float32).to(device)

        return visual_state

    def text_preprocess(self, instruction):
        """Convert a whitespace-separated instruction into a (1, tokens) long tensor of vocabulary indices.

        Raises:
            KeyError: If a word is missing from ``word2vec_model.key_to_index``.
        """

        token_ids = [word2vec_model.key_to_index[word] for word in instruction.split()]
        text_state = torch.tensor(token_ids).long()
        text_state = text_state.unsqueeze(0)

        return text_state

    def act(self, visual_state, text_state, epsilon=0.):
        """Pick an action epsilon-greedily and return it as a (1, 1) long tensor.

        With probability ``epsilon`` a uniformly random action is chosen; otherwise the
        action with the highest Q-value under the online network.
        """

        if random.random() > epsilon:
            # Greedy branch: evaluate in eval mode so BatchNorm uses its running statistics.
            self.dqn_net.eval()
            with torch.no_grad():
                action = self.dqn_net(visual_state, text_state).max(1)[1].view(1, 1)
            self.dqn_net.train()

        else:
            action = torch.tensor([[random.randrange(self.action_size)]], dtype=torch.long, device=device)

        return action

    def learn(self, experiences, gamma):
        """Run one Double-DQN update on a sampled batch, then soft-update the target network.

        Args:
            experiences: Tuple ``(visual_states, text_states, actions, rewards,
                next_states, dones)`` as returned by ``ReplayBuffer.sample``.
            gamma: Discount factor.
        """

        visual_states, text_states, actions, rewards, next_states, dones = experiences

        with torch.no_grad():
            # Double DQN: the online network chooses the next action ...
            _, action_max = self.dqn_net(next_states, text_states).max(1)

            # ... and the target network evaluates it.
            Q_target_next = self.target_net(next_states, text_states).gather(1, action_max.unsqueeze(1))

            # Terminal transitions (dones == 1) contribute the reward only.
            Q_target = rewards + (gamma * Q_target_next * (1 - dones))

        # Q-values the online network currently assigns to the actions actually taken.
        Q_expected = self.dqn_net(visual_states, text_states).gather(1, actions)

        loss = F.mse_loss(Q_expected, Q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target network
        self.soft_update(self.target_net, self.dqn_net, TAU)

    def soft_update(self, target_net, dqn_net, tau):
        """Move each target parameter towards its online counterpart.

        target = tau * online + (1 - tau) * target
        """

        for target_param, dqn_param in zip(target_net.parameters(), dqn_net.parameters()):
            target_param.data.copy_(tau*dqn_param.data + (1.0-tau) * target_param.data)

    def watch(self, controller, instruction, num_episodes=10):
        """Run the trained agent for ``num_episodes`` episodes and print the scores.

        The first three actions of every episode are random; after that the agent
        acts greedily. Each episode lasts at most 99 steps, with a one-second pause
        after every step.
        """
        best_score = -np.inf

        # The instruction is the same for the whole run, so encode it once.
        text_state = self.text_preprocess(instruction)

        # Evaluate in eval mode (as `act` does) so BatchNorm uses its running statistics;
        # training mode is restored afterwards even if the run is interrupted.
        self.dqn_net.eval()
        try:
            for i_episode in range(1, num_episodes+1):

                # initialize the environment and state
                controller.reset(random=True)
                visual_state = self.visual_preprocess(controller.last_event.frame,
                                                      screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT)

                total_score = 0

                for time_step in range(1, 100):

                    if time_step <= 3:
                        action = torch.tensor([[random.randint(0, self.action_size-1)]]).to(device)
                    else:
                        with torch.no_grad():
                            action = self.dqn_net(visual_state, text_state).max(1)[1].view(1, 1)

                    event = controller.step(action = action_space[action.item()])

                    time.sleep(1)

                    _, reward, done, _ = encode_feedback(event, controller, target_name=AGENT_TARGET)
                    total_score += reward

                    if done:
                        break

                    # observe the new state
                    visual_state = self.visual_preprocess(controller.last_event.frame,
                                                          screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT)

                if total_score > best_score:
                    best_score = total_score

                print(f'\rEpisode {i_episode}/{num_episodes}, Total Score: {total_score}, Best Score: {best_score}', end='')
        finally:
            self.dqn_net.train()

In [None]:
# Create the agent with one output per entry in action_space. Training and evaluation below
# use this global `agent`.
agent = DQNAgent(screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT, action_size=len(action_space), seed=90)

## Train Agent

In [None]:
# define linear decay
def calculate_epsilon(episode):
    """Return the exploration rate for ``episode`` under a linear decay schedule.

    Epsilon falls from EPS_START towards EPS_END over NUM_EPISODES episodes and is
    clamped so that it never drops below EPS_END.
    """
    decay_per_episode = (EPS_END - EPS_START) / NUM_EPISODES

    return max(EPS_START + decay_per_episode * episode, EPS_END)

In [None]:
def train_network(num_episodes, max_time):
    """Train the global ``agent`` in the global ``controller`` and save its weights.

    Every episode resets the environment, picks an instruction from ``instructions``
    and lets the agent act epsilon-greedily for at most ``max_time`` steps, storing
    each transition through ``agent.memorize``. Epsilon is recomputed after every
    episode with ``calculate_epsilon``. When training ends, the online network's
    weights are written to ``./agents/AI2THOR_MM_RL.pth``.

    Args:
        num_episodes: Number of episodes to run.
        max_time: Maximum number of steps per episode.
    """

    epsilon = EPS_START

    for i_episode in range(1, num_episodes+1):

        # initialize the environment and state
        controller.reset(random=True)

        visual_state = agent.visual_preprocess(controller.last_event.frame,
                                               screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT)

        instruction = instructions[i_episode % len(instructions)]
        text_state = agent.text_preprocess(instruction)

        total_score = 0

        for time_step in range(1, max_time+1):

            # select and perform an action using dqn network
            action = agent.act(visual_state, text_state, epsilon)
            event = controller.step(action = action_space[action.item()])

            _, reward, done, _ = encode_feedback(event, controller, target_name=AGENT_TARGET)
            total_score += reward
            reward = torch.tensor([reward], device=device)

            next_state = agent.visual_preprocess(controller.last_event.frame,
                                                 screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT)

            agent.memorize(visual_state, text_state, action, reward, next_state, done)

            # Move to the next state. This must rebind `visual_state`, otherwise the agent
            # keeps acting on (and storing) the first frame of the episode.
            visual_state = next_state

            if done or time_step == max_time:
                plot_durations(total_score, i_episode, num_episodes)
                break

        epsilon = calculate_epsilon(i_episode)

    os.makedirs('./agents/', exist_ok=True)
    torch.save(agent.dqn_net.state_dict(), './agents/AI2THOR_MM_RL.pth')

    print('Training completed.')
    plt.ioff()
    plt.show()

In [None]:
print('Training the network...')
# Full training run: NUM_EPISODES episodes, each capped at 400 steps.
train_network(num_episodes=NUM_EPISODES, max_time=400)

## Check The Result!

In [None]:
# Load the trained agent's weights. map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU machine still loads on a CPU-only one.
agent.dqn_net.load_state_dict(torch.load('./agents/AI2THOR_MM_RL.pth', map_location=device));

In [None]:
# Watch the agent with an instruction taken from the training set (`instructions`). A phrase
# it was never trained on, such as "find light switch", is not something the trained network
# has learned to handle.
instruction = instructions[0]
agent.watch(controller, instruction, num_episodes=10)

In [None]:
# Shut down the AI2-THOR controller now that evaluation is finished.
controller.stop()

In [None]:
# 1. try learning with single word
# 2. try taking out the hidden layer on text model (only embedding then)
# 3. if first & second work, then try three words with embedding layer only
# 4. increase num_episodes to 4000, target_update to 10 and buffer_size to 2e4
# 5. think more about Concatenation technique, Model capacity, Data balance, Loss function and Reward shaping

---