In [None]:
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from gensim.models import Word2Vec

In [None]:
import warnings
# NOTE(review): this silences *every* warning, including deprecation notices
# from torchvision/gensim; narrow the filter once the noisy source is known.
warnings.filterwarnings("ignore")

# Seed every RNG the notebook draws from so that Restart & Run All reproduces
# the same weight initialisation, toy images and replay-buffer samples.
SEED = 42
random.seed(SEED)        # drives ReplayBuffer.sample
torch.manual_seed(SEED)  # drives layer initialisation and torch.randn

## Build Model

In [None]:
class VisualModel(nn.Module):
    """Image encoder: a pretrained torchvision ResNet-18 with its classifier removed.

    The final fully connected layer is swapped for ``nn.Identity`` so the
    forward pass returns the backbone's pooled features (512 per image for
    ResNet-18) instead of class logits.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Drop the ImageNet classification head; keep only the feature extractor.
        backbone.fc = nn.Identity()
        self.cnn = backbone

    def forward(self, x):
        """Return the backbone features for a batch of images ``x``."""
        features = self.cnn(x)
        return features
    
    
class TextModel(nn.Module):
    """Instruction encoder: frozen pretrained embeddings followed by an LSTM.

    Args:
        pretrained_embedding: float tensor of shape (vocab_size, embed_dim)
            used as the (frozen) embedding table.
        hidden_dim: size of the LSTM hidden state.
    """

    def __init__(self, pretrained_embedding, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, freeze=True)
        # batch_first=True: inputs arrive as (batch, seq_len) token ids, so the
        # embedded tensor is (batch, seq_len, embed_dim). With the default
        # (batch_first=False) the LSTM would treat the batch axis as time and
        # leak hidden state from one sample into the next.
        self.rnn = nn.LSTM(pretrained_embedding.shape[1], hidden_dim, batch_first=True)

    def forward(self, x):
        """Encode token ids ``x`` of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len * hidden_dim): the per-token LSTM
            outputs of each sample concatenated along the feature axis.
        """
        embedded = self.embedding(x)    # (batch, seq_len, embed_dim)
        output, _ = self.rnn(embedded)  # (batch, seq_len, hidden_dim)
        # flatten() copes with non-contiguous LSTM outputs, unlike view().
        return output.flatten(start_dim=1)

In [None]:
class MultimodalDQN(nn.Module):
    """Q-network that fuses image and instruction features.

    Args:
        visual_model: module mapping an image batch to (batch, visual_dim) features.
        text_model: module mapping a token batch to (batch, text_dim) features.
        action_space: number of discrete actions (size of the Q-value output).
        feature_dim: visual_dim + text_dim. The default 704 matches this
            notebook's setup: 512 ResNet-18 features + 192 text features
            (3 tokens x 64 LSTM hidden units).
        hidden_dim: width of the hidden layer of the Q-value head.
    """

    def __init__(self, visual_model, text_model, action_space,
                 feature_dim=704, hidden_dim=4120):
        super().__init__()
        self.visual_model = visual_model
        self.text_model = text_model
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_space)

    def forward(self, visual_input, text_input):
        """Return Q-values of shape (batch, action_space)."""
        visual_features = self.visual_model(visual_input)  # (batch, visual_dim)
        text_features = self.text_model(text_input)        # (batch, text_dim)

        combined_features = torch.cat((visual_features, text_features), dim=1)
        # Without a non-linearity fc1 and fc2 would collapse into one affine
        # map, making the wide hidden layer pointless.
        hidden = torch.relu(self.fc1(combined_features))
        q_values = self.fc2(hidden)

        return q_values

## Build Agent

In [None]:
class ReplayBuffer:
    """Bounded first-in-first-out store of experiences with uniform sampling."""

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity

    def push(self, experience):
        """Store ``experience``, evicting the oldest entry once the buffer is full."""
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
        self.memory.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random."""
        drawn = random.sample(self.memory, batch_size)
        return drawn

In [None]:
def q_learning_update(model, target_model, optimizer, batch, gamma):
    """Run one DQN gradient step on a batch of transitions.

    Args:
        model: online Q-network, called as ``model(visual, text)``.
        target_model: target Q-network used for the bootstrap estimate.
        optimizer: optimizer over ``model``'s parameters.
        batch: iterable of ``(visual_state, text_state, action, reward,
            next_state)`` tuples; tensors within a field must share a shape
            so they can be stacked.
        gamma: discount factor.

    Returns:
        The scalar MSE loss of this step as a Python float.
    """
    visual_state, text_state, action, reward, next_state = zip(*batch)

    visual_state = torch.stack(visual_state)
    text_state = torch.stack(text_state)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float32)
    next_state = torch.stack(next_state)

    # Q(s, a) for the actions that were actually taken.
    q_values = model(visual_state, text_state)
    predicted_q = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

    # Bootstrap from the *next* observation. A transition only stores the next
    # image, so the instruction is assumed unchanged across the step.
    with torch.no_grad():
        next_q_values = target_model(next_state, text_state).max(dim=1).values
    expected_q_values = reward + gamma * next_q_values

    loss_value = nn.functional.mse_loss(predicted_q, expected_q_values)

    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

    return loss_value.item()

## Load GloVe Embeddings

In [None]:
from gensim.models import Word2Vec, KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec

In [None]:
glove_input_file = 'weights/glove.6B.100d.txt'

# GloVe text files lack the "<vocab_size> <dim>" header line of the word2vec
# text format. gensim >= 4.0 reads them directly with no_header=True, which
# replaces the deprecated glove2word2vec conversion and avoids rewriting a
# full copy of the vectors to disk on every run.
word2vec_model = KeyedVectors.load_word2vec_format(glove_input_file, binary=False, no_header=True)

In [None]:
# Copy the embedding matrix into a float32 tensor of shape (vocab_size, embed_dim);
# roughly 400k x 100 for the GloVe file loaded in this notebook.
pretrained_embeddings = torch.tensor(word2vec_model.vectors, dtype=torch.float32)

## Init Model

In [None]:
visual_model = VisualModel()
text_model = TextModel(pretrained_embeddings, hidden_dim=64)

model = MultimodalDQN(visual_model, text_model, action_space=4)

# The target network needs its *own* encoder instances. Re-using visual_model /
# text_model would make both networks share those weights, so the target would
# move with every optimizer step instead of staying fixed between syncs.
target_model = MultimodalDQN(
    VisualModel(),
    TextModel(pretrained_embeddings, hidden_dim=64),
    action_space=4,
)
target_model.load_state_dict(model.state_dict())  # start from identical weights
target_model.eval()  # target is only evaluated: keep BatchNorm statistics fixed

In [None]:
# Training state: discount factor, experience store and optimizer for the
# online network.
gamma = 0.9
replay_buffer = ReplayBuffer(capacity=1000)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Train Model

In [None]:
# Toy dataset: six random 3x64x64 "images", each paired with a three-word
# instruction and the action label associated with it.
images = [torch.randn(3, 64, 64) for _ in range(6)]

labelled_instructions = [
    ("find light switch", 0),
    ("locate power switch", 1),
    ("seek light control", 1),
    ("discover light control", 2),
    ("identify switch light", 0),
    ("pinpoint light switch", 3),
]
instructions = [text for text, _ in labelled_instructions]
actions = [label for _, label in labelled_instructions]

In [None]:
NUM_EPOCHS = 50
BATCH_SIZE = 32
# Must be smaller than NUM_EPOCHS, otherwise the target network is only synced
# once (at epoch 0) and never follows the online network afterwards.
TARGET_SYNC_EVERY = 10

model.train()  # make the mode explicit in case an evaluation cell ran before

for epoch in range(NUM_EPOCHS):

    print("--- Epoch:", epoch+1)

    # Cycle through the toy dataset; the "next state" is simply the next image.
    visual_state = images[epoch % len(images)]
    instruction = instructions[epoch % len(instructions)]
    # Raises KeyError for words missing from the embedding vocabulary.
    token_ids = [word2vec_model.key_to_index[word] for word in instruction.split()]
    text_state = torch.LongTensor(token_ids)
    action = actions[epoch % len(actions)]
    next_state = images[(epoch + 1) % len(images)]

    replay_buffer.push((visual_state, text_state, action, 1.0, next_state))

    # Start learning as soon as one full batch is available (>=, not >).
    if len(replay_buffer.memory) >= BATCH_SIZE:
        batch = replay_buffer.sample(BATCH_SIZE)
        q_learning_update(model, target_model, optimizer, batch, gamma)

    if epoch % TARGET_SYNC_EVERY == 0:
        target_model.load_state_dict(model.state_dict())

## Test Model

In [None]:
# Build a single-sample batch: (1, 3, 64, 64) image and (1, seq_len) token ids.
visual_input = images[0].unsqueeze(0)
text_input = instructions[0]
text_input = [word2vec_model.key_to_index[word] for word in text_input.split()]
text_input = torch.LongTensor(text_input).unsqueeze(0)

# Inference: eval mode makes BatchNorm use its running statistics rather than
# those of this single sample, and no_grad skips building the autograd graph.
model.eval()
with torch.no_grad():
    q_values = model(visual_input, text_input)
action = q_values.argmax().item()

print(f"Predicted action: {action}")

---