In [1]:
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from gensim.models import Word2Vec

In [None]:
import warnings
warnings.filterwarnings("ignore")

## Build Model

In [2]:
class VisualModel(nn.Module):
    
    def __init__(self):
        
        super(VisualModel, self).__init__()
        self.cnn = models.resnet18(pretrained=True)
        self.cnn.fc = nn.Identity()
        
    def forward(self, x):
        
        return self.cnn(x)
    
    
class TextModel(nn.Module):
    
    def __init__(self, pretrained_embedding, hidden_dim):
        
        super(TextModel, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, freeze=True)
        self.rnn = nn.LSTM(pretrained_embedding.shape[1], hidden_dim)
        
    def forward(self, x):
        
        embedded = self.embedding(x)
        output, _ = self.rnn(embedded)
        return output[:, :, -1]

In [3]:
class MultimodalDQN(nn.Module):
    
    def __init__(self, visual_model, text_model, action_space):
        
        super(MultimodalDQN, self).__init__()
        self.visual_model = visual_model
        self.text_model = text_model
        self.fc1 = nn.Linear(515, 4120)  
        self.fc2 = nn.Linear(4120, action_space)
        
    def forward(self, visual_input, text_input):
        
        visual_features = self.visual_model(visual_input)
        text_features = self.text_model(text_input)
        
        combined_features = torch.cat((visual_features, text_features), dim=1)
        combined_features = self.fc1(combined_features)
        q_values = self.fc2(combined_features)
        
        return q_values

## Build Agent

In [4]:
class ReplayBuffer:
    
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory.pop(0)
            self.memory.append(experience)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

In [5]:
def q_learning_update(model, target_model, optimizer, batch, gamma):
    
    visual_state, text_state, action, reward, next_state = zip(*batch)

    visual_state = torch.stack(visual_state)
    text_state = torch.stack(text_state)
    action = torch.tensor(action)
    reward = torch.tensor(reward)
    next_state = torch.stack(next_state)

    q_values = model(visual_state, text_state)
    next_q_values = target_model(visual_state, text_state).max(1).values.detach()
    expected_q_values = reward + gamma * next_q_values

    loss = nn.MSELoss()
    loss_value = loss(q_values.gather(1, action.unsqueeze(1)), expected_q_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

## Load Word2Vec

In [6]:
from gensim.models import Word2Vec, KeyedVectors
from gensim.scripts.glove2word2vec import glove2word2vec

In [7]:
glove_input_file = 'weights/glove.6B.100d.txt'
word2vec_output_file = 'weights/glove.6B.100d.txt.word2vec'

glove2word2vec(glove_input_file, word2vec_output_file)
word2vec_model = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)

  glove2word2vec(glove_input_file, word2vec_output_file)


In [8]:
pretrained_embeddings = torch.FloatTensor(word2vec_model.vectors)

## Init Model

In [9]:
visual_model = VisualModel()
text_model = TextModel(pretrained_embeddings, hidden_dim=64)

model = MultimodalDQN(visual_model, text_model, action_space=4)
target_model = MultimodalDQN(visual_model, text_model, action_space=4)



In [10]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
replay_buffer = ReplayBuffer(capacity=1000)
gamma = 0.9 # Discount factor

## Train Model

In [11]:
vocab = {"light": 0, "switch": 1, "go": 2, "find": 3}

images = [torch.randn(3, 64, 64) for _ in range(10)]
instructions = ["find light switch" for _ in range(10)]
actions = [0, 1, 1, 0, 0, 1, 0, 1, 0, 1] 

In [12]:
for epoch in range(100):
    
    print("--- Epoch:", epoch+1)
    visual_state = images[epoch % len(images)]
    text_state = instructions[epoch % len(instructions)]
    text_state = [vocab[word] for word in text_state.split()]
    text_state = torch.LongTensor(text_state)
    action = actions[epoch % len(actions)]
    next_state = images[(epoch + 1) % len(images)]
    
    replay_buffer.push((visual_state, text_state, action, 1.0, next_state))
    
    if len(replay_buffer.memory) > 32:
        batch = replay_buffer.sample(32)
        q_learning_update(model, target_model, optimizer, batch, gamma)
        
    if epoch % 100 == 0:
        target_model.load_state_dict(model.state_dict())

--- Epoch: 1
--- Epoch: 2
--- Epoch: 3
--- Epoch: 4
--- Epoch: 5
--- Epoch: 6
--- Epoch: 7
--- Epoch: 8
--- Epoch: 9
--- Epoch: 10
--- Epoch: 11
--- Epoch: 12
--- Epoch: 13
--- Epoch: 14
--- Epoch: 15
--- Epoch: 16
--- Epoch: 17
--- Epoch: 18
--- Epoch: 19
--- Epoch: 20
--- Epoch: 21
--- Epoch: 22
--- Epoch: 23
--- Epoch: 24
--- Epoch: 25
--- Epoch: 26
--- Epoch: 27
--- Epoch: 28
--- Epoch: 29
--- Epoch: 30
--- Epoch: 31
--- Epoch: 32
--- Epoch: 33
--- Epoch: 34
--- Epoch: 35
--- Epoch: 36
--- Epoch: 37
--- Epoch: 38
--- Epoch: 39
--- Epoch: 40
--- Epoch: 41
--- Epoch: 42
--- Epoch: 43
--- Epoch: 44
--- Epoch: 45
--- Epoch: 46
--- Epoch: 47
--- Epoch: 48
--- Epoch: 49
--- Epoch: 50
--- Epoch: 51
--- Epoch: 52
--- Epoch: 53
--- Epoch: 54
--- Epoch: 55
--- Epoch: 56
--- Epoch: 57
--- Epoch: 58
--- Epoch: 59
--- Epoch: 60
--- Epoch: 61
--- Epoch: 62
--- Epoch: 63
--- Epoch: 64
--- Epoch: 65
--- Epoch: 66
--- Epoch: 67
--- Epoch: 68
--- Epoch: 69
--- Epoch: 70
--- Epoch: 71
--- Epoch: 72
-

## Test Model

In [13]:
visual_input = images[0].unsqueeze(0)
text_input = instructions[0]
text_input = [vocab[word] for word in text_input.split()]
text_input = torch.LongTensor(text_input).unsqueeze(0)
    
q_values = model(visual_input, text_input)
action = q_values.argmax().item()

print(f"Predicted action: {action}")

Predicted action: 0


---