In [None]:
import os
import time
import random
import math
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from PIL import Image
from IPython.display import clear_output

import ai2thor
import ai2thor_colab
from ai2thor_colab import plot_frames
from ai2thor.controller import Controller

from ai2thor.platform import CloudRendering
# Launch an AI2-THOR controller on the CloudRendering platform.
# NOTE(review): `controller` is assigned again in the "Set Environment" cell,
# so this instance looks superseded there — confirm whether it should be
# stopped before the second one is created.
controller = Controller(platform=CloudRendering)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.distributions import Categorical

In [None]:
from network import DQN
from utils import to_torchdim, frame2tensor, encode_feedback

In [None]:
import warnings
# Install a blanket "ignore" filter so no warning category is shown.
warnings.simplefilter('ignore')

# Render all matplotlib figures with the ggplot style sheet.
plt.style.use('ggplot')

## Set Environment

In [None]:
# The scene index is pinned. The random draw that used to precede this line was
# dead code: its value was overwritten immediately and never used.
floor_index = 20

# (Re)create the controller for the chosen scene with the settings used by the agent.
controller = Controller(
    agentMode = "default", # arm
    visibilityDistance = 0.75,
    scene = f"FloorPlan{floor_index}",

    # step sizes
    snapToGrid = True,
    gridSize = 0.25,
    rotateStepDegrees = 90,

    # image modalities
    renderInstanceSegmentation = False,
    renderDepthImage = False,
    renderSemanticSegmentation = False,
    renderNormalsImage = False,
    
    # camera properties
    width = 600,
    height = 420,
    fieldOfView = 120,
    
    # set seed for reproducibility
    seed = 90,
)

## Set Configs

In [None]:
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

In [None]:
# Episode budget (6000 is noted as an alternative value).
NUM_EPISODES = 10_000

# Frames are resized to a square of this many pixels per side before being
# passed to the agent.
SCREEN_HEIGHT = 100
SCREEN_WIDTH = SCREEN_HEIGHT

In [None]:
# Discrete action names; the list length is used as the agent's action count.
action_space = [
    "MoveAhead",
    "MoveLeft",
    "MoveRight",
    "MoveBack",
    "RotateLeft",
    "RotateRight",
]

## Load Word2Vec

In [None]:
from gensim.models import Word2Vec, KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec

In [None]:
glove_input_file = 'weights/glove.6B.100d.txt'
word2vec_output_file = 'weights/glove.6B.100d.txt.word2vec'

# Converting the GloVe text file is slow and its output does not change between
# runs, so only convert when the word2vec-format file is missing. Delete that
# file to force a fresh conversion.
if not os.path.exists(word2vec_output_file):
    glove2word2vec(glove_input_file, word2vec_output_file)

# Load the converted vectors (plain-text word2vec format).
word2vec_model = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)

In [None]:
# Copy the word-vector matrix into a float32 tensor; it seeds the text model's
# embedding layer.
pretrained_embeddings = torch.tensor(word2vec_model.vectors, dtype=torch.float32)

## Build Model

In [None]:
class VisualModel(nn.Module):
    """Image encoder: a pretrained ResNet-18 whose final `fc` layer is replaced by identity."""

    def __init__(self, seed):
        """Seed PyTorch's RNG and build the backbone.

        Args:
            seed: value handed to ``torch.manual_seed``.
        """
        super().__init__()

        # torch.manual_seed returns the generator object; it is kept on the instance.
        self.seed = torch.manual_seed(seed)

        backbone = models.resnet18(pretrained=True)
        # Drop the classification head so the network outputs its features directly.
        backbone.fc = nn.Identity()
        self.cnn = backbone

    def forward(self, x):
        """Return the backbone output for the image batch ``x``."""
        features = self.cnn(x)
        return features
    
    
class TextModel(nn.Module):
    """Instruction encoder: frozen pretrained word embeddings followed by an LSTM.

    The encoder returns one ``hidden_dim``-sized vector per instruction (the LSTM
    output at the last token), whatever the instruction length.
    """
    
    def __init__(self, pretrained_embedding, hidden_dim, seed):
        """Build the embedding table and the LSTM.

        Args:
            pretrained_embedding: 2-D float tensor (vocab_size, embedding_dim)
                used to initialise the frozen embedding layer.
            hidden_dim: size of the LSTM hidden state and of the returned vector.
            seed: value handed to ``torch.manual_seed``.
        """
        
        super(TextModel, self).__init__()
        
        self.seed = torch.manual_seed(seed)
        
        self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, freeze=True)
        # Token ids arrive as (batch, seq_len). Without batch_first=True the LSTM
        # would treat the batch axis as time and leak state between samples.
        # The flag does not alter parameter names, so saved weights still load.
        self.rnn = nn.LSTM(pretrained_embedding.shape[1], hidden_dim, batch_first=True)
        
    
    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: integer tensor of shape (batch, seq_len).

        Returns:
            Float tensor of shape (batch, hidden_dim).

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        
        if x.dim() != 2:
            raise ValueError(
                f"expected token ids of shape (batch, seq_len), got {tuple(x.shape)}"
            )
        
        x = x.long()
        
        embedded = self.embedding(x)        # (batch, seq_len, embedding_dim)
        output, _ = self.rnn(embedded)      # (batch, seq_len, hidden_dim)
        
        # Keep only the last time step so the feature width is independent of
        # the instruction length (flattening all steps broke multi-word input).
        return output[:, -1, :]

In [None]:
class MultimodalDQN(nn.Module):
    """Q-network that fuses image features and instruction features.

    The two encoders' outputs are concatenated and passed through three fully
    connected layers that produce one Q-value per action.
    """
    
    def __init__(self, visual_model, text_model, action_size, seed, feature_dim=576):
        """Store the encoders and build the fully connected head.

        Args:
            visual_model: module mapping the visual input to (batch, d_visual).
            text_model: module mapping the text input to (batch, d_text).
            action_size: number of Q-values to output.
            seed: value handed to ``torch.manual_seed``.
            feature_dim: d_visual + d_text, the width of the concatenated
                features. The default 576 is the width this network was
                originally hard-coded for.
        """
        
        super(MultimodalDQN, self).__init__()
        
        self.seed = torch.manual_seed(seed)
        
        self.visual_model = visual_model
        self.text_model = text_model
        
        # Define three fully connected layers
        self.fc1 = nn.Linear(feature_dim, 512) # 256, 512, 1024
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, action_size)
        
    def forward(self, visual_input, text_input):
        """Return Q-values of shape (batch, action_size) for the given inputs."""
        
        # Move inputs to wherever this module's weights live, rather than
        # relying on a notebook-level `device` global.
        target_device = self.fc1.weight.device
        
        visual_features = self.visual_model(visual_input.to(target_device))
        text_features = self.text_model(text_input.to(target_device))
        
        # Concatenate visual and text features
        combined_features = torch.cat((visual_features, text_features), dim=1)
        
        # Apply fully connected layers
        combined_features = F.relu(self.fc1(combined_features))
        combined_features = F.relu(self.fc2(combined_features))
        q_values = self.fc3(combined_features)
        
        return q_values

## Build Agent

In [None]:
class DQNAgent():
    """The agent interacting with and learning from the environment."""
    
    def __init__(self, screen_width, screen_height, action_size, seed):
        """Init Agent's models.

        Args:
            screen_width: width, in pixels, that frames are resized to in `watch`.
            screen_height: height, in pixels, that frames are resized to in `watch`.
            action_size: number of discrete actions (Q-values) of the network.
            seed: seed for Python's `random` module and for the sub-models.
        """
        
        self.action_size = action_size
        
        # Keep the frame size on the instance so `watch` does not have to read
        # notebook-level globals.
        self.screen_width = screen_width
        self.screen_height = screen_height
        
        # random.seed() returns None, so seed first and store the value itself.
        random.seed(seed)
        self.seed = seed
        
        # Multimodal DQN
        self.visual_model = VisualModel(seed=seed)
        self.text_model = TextModel(pretrained_embeddings, hidden_dim=64, seed=seed)

        self.dqn_net = MultimodalDQN(self.visual_model, self.text_model, action_size, seed).to(device)
    
    def visual_preprocess(self, visual_state, screen_width, screen_height):
        """Preprocess input frame before passing into agent.

        The frame is resized to (screen_width, screen_height), converted with
        `to_torchdim` / `frame2tensor`, cast to float32 and moved to `device`.
        """
        
        resized_screen = Image.fromarray(visual_state).resize((screen_width, screen_height))
        visual_state = frame2tensor(to_torchdim(resized_screen)).to(torch.float32).to(device)

        return visual_state
    
    def text_preprocess(self, instruction):
        """Preprocess instructions before passing into agent.

        Splits the instruction on whitespace and maps each word to its index in
        `word2vec_model`; returns a long tensor of shape (1, num_words).

        Raises:
            KeyError: if a word is not in the embedding vocabulary.
        """
        
        token_ids = [word2vec_model.key_to_index[word] for word in instruction.split()]
        text_state = torch.tensor(token_ids).long()
        text_state = text_state.unsqueeze(0)
        
        return text_state
    
    def randomize_agent(self, controller):
        """Teleport the agent to a randomly chosen reachable position."""

        positions = controller.step(
            action="GetReachablePositions"
        ).metadata["actionReturn"]

        position = random.choice(positions)
        controller.step(
            action="Teleport",
            position=position,
            rotation=dict(x=0, y=270, z=0),
            horizon=0,
            standing=True
        )
    
    @staticmethod
    def _as_instruction_list(instruction):
        """Normalise `instruction` to a list of (text, target_name) pairs."""
        
        is_single_pair = (
            isinstance(instruction, tuple)
            and len(instruction) == 2
            and all(isinstance(part, str) for part in instruction)
        )
        if is_single_pair:
            return [instruction]
        return list(instruction)
        
    def watch(self, controller, instruction, num_episodes=10, max_steps=49, step_delay=1):
        """Watch trained agent.

        Args:
            controller: AI2-THOR controller to act in.
            instruction: a (text, target_name) pair, or a sequence of such
                pairs that is cycled through across episodes.
            num_episodes: number of episodes to roll out.
            max_steps: maximum actions per episode (49 matches the former
                hard-coded ``range(1, 50)``).
            step_delay: seconds to sleep after each action so it can be followed.

        Raises:
            ValueError: if no instruction pair is given.
        """
        
        # Use the argument that was passed in, not a notebook-level global.
        instruction_pairs = self._as_instruction_list(instruction)
        if not instruction_pairs:
            raise ValueError("watch() needs at least one (instruction, target_name) pair")
        
        best_score = -np.inf
        action_space = ["MoveAhead", "MoveLeft", "MoveRight", "MoveBack", "RotateLeft", "RotateRight"]

        # Inference only: switch the network to evaluation mode.
        self.dqn_net.eval()

        for i_episode in range(1, num_episodes+1):

            # initialize the environment and state
            controller.reset(random=True)
            
            self.randomize_agent(controller)

            visual_state = self.visual_preprocess(controller.last_event.frame,
                                                  screen_width=self.screen_width,
                                                  screen_height=self.screen_height)
        
            instruction_text, target_name = instruction_pairs[i_episode % len(instruction_pairs)]
            text_state = self.text_preprocess(instruction_text)
            
            total_score = 0
            time_step = 0  # stays 0 only if max_steps < 1
                
            for time_step in range(1, max_steps + 1):
                
                # select an action using the trained dqn network
                with torch.no_grad():
                    action = self.dqn_net(visual_state, text_state).max(1)[1].view(1, 1)

                event = controller.step(action = action_space[action.item()])

                time.sleep(step_delay)
                
                _, reward, done, _ = encode_feedback(event, controller, target_name=target_name)

                total_score += reward
                if done:
                    break

                # observe a new state
                visual_state = self.visual_preprocess(controller.last_event.frame,
                                                      screen_width=self.screen_width,
                                                      screen_height=self.screen_height)

            if total_score > best_score: 
                best_score = total_score

            print(f'\rEpisode {i_episode}/{num_episodes}, Total Step: {time_step}, Total Score: {total_score}, Best Score: {best_score}', end='') 

In [None]:
# Build the agent from the configured frame size, with one network output per action.
agent = DQNAgent(
    screen_width=SCREEN_WIDTH,
    screen_height=SCREEN_HEIGHT,
    action_size=len(action_space),
    seed=90,
)

## Check The Result!

In [None]:
# load the weights of smart agent
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU also loads when the notebook falls back to the CPU.
agent.dqn_net.load_state_dict(torch.load(f'./agents/AI2THOR_MM_RL_3OBJ_R20.pth', map_location=device));

In [None]:
# Each entry pairs an instruction word with the target name used to score the episode.
# Alternatives:
# instructions = [("tomato", "Tomato_e65a6e2e")]
# instructions = [("garbage", "GarbageCan_d6916cf5")]
instructions = [
    ("switch", "LightSwitch_887b121a"),
]

# Roll out the trained agent for ten episodes.
agent.watch(controller, instructions, num_episodes=10)

---