In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## MMDQN with 3 FC Layers

In [2]:
class MultimodalDQN(nn.Module):
    """Multimodal DQN with one linear Q-value head per object.

    The outputs of ``visual_model`` and ``text_model`` are concatenated along
    dim 1 and fed to three independent linear heads (switch, fridge, door).

    Args:
        visual_model: Module applied to ``visual_input``. Must expose
            ``fc.out_features``, which is read to size the heads.
        text_model: Module applied to ``text_input``. Must expose
            ``fc.out_features``.
        action_size: Output width of every head.
        num_objects: Stored on the instance only; the three heads below are
            created regardless of this value.
        seed: Passed to ``torch.manual_seed``.
    """

    def __init__(self, visual_model, text_model, action_size, num_objects, seed):
        super().__init__()

        # torch.manual_seed reseeds the global RNG and returns the generator.
        # The encoders are built by the caller before this runs, so only the
        # heads created below are initialised under this seed.
        self.seed = torch.manual_seed(seed)
        self.num_objects = num_objects

        self.visual_model = visual_model
        self.text_model = text_model

        # Width of the concatenated (visual + text) feature vector.
        fused_features = visual_model.fc.out_features + text_model.fc.out_features

        # Separate fully connected head for each object
        self.fc_switch = nn.Linear(fused_features, action_size)
        self.fc_fridge = nn.Linear(fused_features, action_size)
        self.fc_door = nn.Linear(fused_features, action_size)

    def forward(self, visual_input, text_input):
        """Return the Q-values of the three heads.

        Args:
            visual_input: Tensor passed to ``visual_model``.
            text_input: Tensor passed to ``text_model``.

        Returns:
            Tuple ``(q_values_switch, q_values_fridge, q_values_door)``.
        """
        # Follow the module's own placement instead of relying on a global
        # `device` variable, which is not defined by this notebook's cells.
        device = next(self.parameters()).device

        visual_features = self.visual_model(visual_input.to(device))
        text_features = self.text_model(text_input.to(device))

        # Combine visual and text features.
        # Assumes both encoders return (batch, features) -- TODO confirm.
        combined_features = torch.cat((visual_features, text_features), dim=1)

        # Forward pass for each object head
        q_values_switch = self.fc_switch(combined_features)
        q_values_fridge = self.fc_fridge(combined_features)
        q_values_door = self.fc_door(combined_features)

        return q_values_switch, q_values_fridge, q_values_door

## MMDQN with Attention Weights

In [3]:
class MultimodalDQN(nn.Module):
    """Multimodal DQN whose per-object Q-values are mixed by instruction attention.

    The outputs of ``visual_model`` and ``text_model`` are concatenated along
    dim 1 and fed to three linear heads (switch, fridge, door). A learned
    embedding of the instruction yields one weight per head, and the returned
    Q-values are the weighted sum of the three heads.

    NOTE: an earlier cell of this notebook defines a class with the same name;
    whichever cell is executed last determines what ``MultimodalDQN`` refers to.

    Args:
        visual_model: Module applied to ``visual_input``. Must expose
            ``fc.out_features``, which is read to size the heads.
        text_model: Module applied to ``text_input``. Must expose
            ``fc.out_features``.
        action_size: Output width of every head.
        num_objects: Number of rows in the instruction embedding table, i.e.
            the number of distinct instruction ids.
        seed: Passed to ``torch.manual_seed``.
    """

    # One attention weight is needed per Q-value head (switch, fridge, door).
    _NUM_HEADS = 3

    def __init__(self, visual_model, text_model, action_size, num_objects, seed):
        super().__init__()

        # torch.manual_seed reseeds the global RNG and returns the generator.
        self.seed = torch.manual_seed(seed)
        self.num_objects = num_objects

        self.visual_model = visual_model
        self.text_model = text_model

        # Width of the concatenated (visual + text) feature vector.
        fused_features = visual_model.fc.out_features + text_model.fc.out_features

        # Separate fully connected head for each object
        self.fc_switch = nn.Linear(fused_features, action_size)
        self.fc_fridge = nn.Linear(fused_features, action_size)
        self.fc_door = nn.Linear(fused_features, action_size)

        # Embedding layer for the instruction. Its width must equal the number
        # of heads because forward() reads columns 0..2 as the head weights.
        # (Previously this used an undefined name `embedding_dim`.)
        self.embedding = nn.Embedding(num_objects, self._NUM_HEADS)

    def forward(self, visual_input, text_input, instruction):
        """Return the attention-weighted combination of the three heads' Q-values.

        Args:
            visual_input: Tensor passed to ``visual_model``.
            text_input: Tensor passed to ``text_model``.
            instruction: Integer index tensor passed to the embedding layer;
                expected to hold one id per batch element (shape ``(batch,)``).

        Returns:
            Tensor of combined Q-values with the same shape as one head's output.
        """
        # Follow the module's own placement instead of relying on a global
        # `device` variable, which is not defined by this notebook's cells.
        device = next(self.parameters()).device

        visual_features = self.visual_model(visual_input.to(device))
        text_features = self.text_model(text_input.to(device))

        # Combine visual and text features.
        # Assumes both encoders return (batch, features) -- TODO confirm.
        combined_features = torch.cat((visual_features, text_features), dim=1)

        # Forward pass for each object head
        q_values_switch = self.fc_switch(combined_features)
        q_values_fridge = self.fc_fridge(combined_features)
        q_values_door = self.fc_door(combined_features)

        # get_attention_weights() already returns a softmax distribution.
        # Applying softmax a second time (as before) squashes the weights
        # towards uniform, so no head could ever dominate.
        attention_weights = self.get_attention_weights(instruction.to(device))

        # Combine the Q-values with attention
        q_values_combined = (
            attention_weights[:, 0].view(-1, 1) * q_values_switch
            + attention_weights[:, 1].view(-1, 1) * q_values_fridge
            + attention_weights[:, 2].view(-1, 1) * q_values_door
        )

        return q_values_combined

    def get_attention_weights(self, instruction):
        """Map instruction ids to normalised per-head attention weights.

        Args:
            instruction: Integer index tensor. ``nn.Embedding`` looks rows up
                by index, so this must contain ids in ``[0, num_objects)``,
                not one-hot vectors.

        Returns:
            Softmax (over dim 1) of the embedded instruction; for a ``(batch,)``
            input this is ``(batch, 3)`` with each row summing to 1.
        """
        embedded_instruction = self.embedding(instruction)
        attention_weights = F.softmax(embedded_instruction, dim=1)

        return attention_weights

---