#### This notebook transforms the seq2seq model into a computational graph using antra

In [9]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import json
from model import *
from ReaSCAN_dataset import *
import torch.nn.functional as F
import torch
from antra.antra import *
def isnotebook():
    """Report whether the code is executing inside a Jupyter kernel.

    Returns True only for the ZMQ-based shell (Jupyter notebook / qtconsole).
    Terminal IPython, any other IPython front-end, and the plain Python
    interpreter (where `get_ipython` is not defined) all give False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # `get_ipython` is only injected when running under IPython.
        return False
    return shell_name == 'ZMQInteractiveShell'
# Inside a notebook always run on CPU; elsewhere prefer CUDA when available.
if isnotebook():
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setting up the seeds. `random` and `numpy` are imported explicitly so the
# seeding does not silently depend on the wildcard imports above exporting them.
import random
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

Initialize the ReaSCAN dataset; its vocabularies and image channels are needed to configure the model.

In [2]:
data_directory = "../../../data-files/ReaSCAN-Simple/"
data_file = "data-compositional-splits.txt"
input_vocab_file = "input_vocabulary.txt"
target_vocab_file = "target_vocabulary.txt"
# Parse the splits file inside a context manager so the handle is closed
# immediately instead of leaking until garbage collection.
with open(os.path.join(data_directory, data_file), "r") as splits_file:
    dataset = ReaSCANDataset(
        json.load(splits_file),
        data_directory, split="train",
        input_vocabulary_file=input_vocab_file,
        target_vocabulary_file=target_vocab_file,
        generate_vocabulary=False,
        k=0,
    )
# Load a handful of examples from ReaSCAN; only a few are needed for verification.
dataset.read_dataset(
    max_examples=10,
    simple_situation_representation=True
)

2021-07-21 11:07 Formulating the dataset from the passed in json file...
2021-07-21 11:07 Loading vocabularies...
2021-07-21 11:07 Done loading vocabularies.
2021-07-21 11:07 Converting dataset to tensors...


Define the computational graph of our model

In [10]:
def _generate_lstm_step_fxn(step_module, i):
    """ 
    Generate a function for a layer in lstm.
    """

    def _lstm_step_fxn(hidden_states):
        (output, hidden, context_situation, attention_weights_commands,
         attention_weights_situations) = step_module(
            hidden_states["input_tokens_sorted"][:, i], 
            hidden_states["hidden"], 
            hidden_states["projected_keys_textual"], 
            hidden_states["commands_lengths"], 
            hidden_states["projected_keys_visual"],
        )
        hidden_states["hidden"] = hidden
        hidden_states["return_lstm_output"] += [output.unsqueeze(0)]
        hidden_states["return_attention_weights"] += [attention_weights_situations.unsqueeze(0)]
        
        return hidden_states

    return _lstm_step_fxn

def generate_compute_graph(model, max_decoding_steps=4):
    """
    Build an antra computation graph that mirrors the forward pass of `model`.

    The graph has three stages:
      1. Input preparation: leaf nodes for the command, situation and target
         inputs, each grouped into a dict by a preparation node.
      2. Input encoding: the situation encoder and the command (language) encoder.
      3. Decoding: `decode_input_preparation` computes the static decoder context
         and initial hidden state, followed by one `lstm_step_{i}` node per
         decoding step and a final `output_preparation` node.

    Args:
        model: the seq2seq model whose submodules (`situation_encoder`, `encoder`,
            `enc_hidden_to_dec_hidden`, `tanh`, `attention_decoder`,
            `visual_attention`, `textual_attention`) are called by the nodes.
        max_decoding_steps: number of decoder steps unrolled into the (static)
            graph. Step i reads column i of the sorted target batch, so target
            sequences presumably need at least this many tokens. Defaults to 4,
            the previously hard-coded value.

    Returns:
        The root node `output_preparation`, whose value is
        `(decoder_log_probs, target_position_scores)`; decoder_log_probs is
        batch-first.

    Raises:
        ValueError: if `max_decoding_steps` is smaller than 1.
    """
    if max_decoding_steps < 1:
        raise ValueError(
            f"max_decoding_steps must be >= 1, got {max_decoding_steps}")

    ####################
    #
    # Input preparation.
    #
    ####################

    # Command inputs.
    command_world_inputs = ["commands_input", "commands_lengths"]
    command_world_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, default_value=None)
        for name in command_world_inputs
    ]

    @GraphNode(*command_world_input_leaves, cache_results=False)
    def command_input_preparation(
        commands_input, commands_lengths,
    ):
        return {
            "commands_input": commands_input,
            "commands_lengths": commands_lengths,
        }

    # Situation inputs.
    situation_inputs = ["situations_input"]
    situation_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, default_value=None)
        for name in situation_inputs
    ]

    @GraphNode(*situation_input_leaves, cache_results=False)
    def situation_input_preparation(
        situations_input,
    ):
        return {
            "situations_input": situations_input,
        }

    # Target inputs.
    target_sequence_inputs = ["target_batch", "target_lengths"]
    target_sequence_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, default_value=None)
        for name in target_sequence_inputs
    ]

    @GraphNode(*target_sequence_input_leaves, cache_results=False)
    def target_sequence_input_preparation(
        target_batch, target_lengths
    ):
        return {
            "target_batch": target_batch,
            "target_lengths": target_lengths,
        }

    ####################
    #
    # Input encoding.
    #
    ####################

    # Situation encoding.
    @GraphNode(situation_input_preparation)
    def situation_encode(input_dict):
        return model.situation_encoder(
            input_images=input_dict["situations_input"]
        )

    # Language encoding.
    @GraphNode(command_input_preparation)
    def command_input_encode(input_dict):
        hidden, encoder_outputs = model.encoder(
            input_batch=input_dict["commands_input"],
            input_lengths=input_dict["commands_lengths"],
        )
        return {
            "command_hidden": hidden,
            "command_encoder_outputs": encoder_outputs["encoder_outputs"],
            "command_sequence_lengths": encoder_outputs["sequence_lengths"],
        }

    ####################
    #
    # Decoding.
    #
    ####################

    @GraphNode(command_input_encode, situation_encode, target_sequence_input_preparation)
    def decode_input_preparation(c_encode, s_encode, target_sequence):
        """
        The decoding step can be written as h_t = f(h_{t-1}, C), where h_t is
        the recurrent hidden state and C is the static context.

        This node computes C (the projected textual/visual keys and command
        lengths) and the initial hidden state. It also sorts the batch by
        target length in descending order, keeping `perm_idx` so the order can
        be undone in `output_preparation`.
        """
        initial_hidden = model.attention_decoder.initialize_hidden(
            model.tanh(model.enc_hidden_to_dec_hidden(c_encode["command_hidden"])))

        input_tokens = target_sequence["target_batch"]
        encoded_commands = c_encode["command_encoder_outputs"]
        encoded_situations = s_encode

        # Sort the sequences by length in descending order. `as_tensor` avoids
        # an extra copy (and a warning) when the lengths are already a tensor.
        input_lengths = torch.as_tensor(
            target_sequence["target_lengths"], dtype=torch.long, device=device)
        input_lengths, perm_idx = torch.sort(input_lengths, descending=True)
        input_tokens_sorted = input_tokens.index_select(dim=0, index=perm_idx)
        initial_h, initial_c = initial_hidden
        hidden = (initial_h.index_select(dim=1, index=perm_idx),
                  initial_c.index_select(dim=1, index=perm_idx))
        encoded_commands = encoded_commands.index_select(dim=1, index=perm_idx)
        commands_lengths = torch.as_tensor(
            c_encode["command_sequence_lengths"], device=device)
        commands_lengths = commands_lengths.index_select(dim=0, index=perm_idx)
        encoded_situations = encoded_situations.index_select(dim=0, index=perm_idx)

        # Project the keys once here rather than at every decoding step.
        projected_keys_visual = model.visual_attention.key_layer(
            encoded_situations)  # [batch_size, situation_length, dec_hidden_dim]
        projected_keys_textual = model.textual_attention.key_layer(
            encoded_commands)  # [max_input_length, batch_size, dec_hidden_dim]

        return {
            "return_lstm_output": [],
            "return_attention_weights": [],
            "hidden": hidden,
            "input_tokens_sorted": input_tokens_sorted,
            "projected_keys_textual": projected_keys_textual,
            "commands_lengths": commands_lengths,
            "projected_keys_visual": projected_keys_visual,
            "perm_idx": perm_idx,
            "seq_lengths": input_lengths,
        }

    # The graph is static, so the number of decoding steps is fixed up front.
    hidden_layer = decode_input_preparation
    for i in range(max_decoding_steps):
        f = _generate_lstm_step_fxn(model.attention_decoder.forward_step, i)
        hidden_layer = GraphNode(hidden_layer,
                                 name=f"lstm_step_{i}",
                                 forward=f)

    # Formulating outputs.
    @GraphNode(hidden_layer)
    def output_preparation(hidden_states):
        if model.auxiliary_task:
            # Previously this branch was a bare `pass`, which left
            # `target_position_scores` unbound and crashed with UnboundLocalError.
            raise NotImplementedError(
                "auxiliary_task is not supported by this computation graph yet.")

        # Work on locals instead of writing back into `hidden_states`: that dict
        # is the value of the last lstm_step node, which antra may cache and reuse.
        # The per-step attention weights stay available on the lstm_step nodes.
        lstm_output = torch.cat(hidden_states["return_lstm_output"], dim=0)
        _, unperm_idx = hidden_states["perm_idx"].sort(0)
        # Undo the length sort; [max_time, batch_size, output_size].
        lstm_output = lstm_output.index_select(dim=1, index=unperm_idx)
        decoder_output_batched = F.log_softmax(lstm_output, dim=-1)

        # Placeholder returned when there is no auxiliary task (it is not used).
        target_position_scores = torch.zeros(1), torch.zeros(1)

        return (decoder_output_batched.transpose(0, 1),
                target_position_scores)  # [batch_size, max_target_seq_length, target_vocabulary_size]

    root = output_preparation  # TODO: removing this and continue.

    return root
    
class ReaSCANMultiModalLSTMCompGraph(ComputationGraph):
    """antra ComputationGraph for the ReaSCAN multi-modal LSTM model.

    The node structure is produced by `generate_compute_graph(model)`; the
    model itself is also stored on the instance as `self.model`.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        super().__init__(generate_compute_graph(model))

Load the model into the computational graph

In [15]:
# Model hyper-parameters; vocabulary and channel sizes come from the dataset.
model_config = dict(
    input_vocabulary_size=dataset.input_vocabulary_size,
    target_vocabulary_size=dataset.target_vocabulary_size,
    num_cnn_channels=dataset.image_channels,
    input_padding_idx=dataset.input_vocabulary.pad_idx,
    target_pad_idx=dataset.target_vocabulary.pad_idx,
    target_eos_idx=dataset.target_vocabulary.eos_idx,
    # Language encoder.
    embedding_dimension=25,
    encoder_hidden_size=100,
    num_encoder_layers=1,
    encoder_dropout_p=0.3,
    encoder_bidirectional=True,
    # World (situation) encoder.
    simple_situation_representation=True,
    cnn_hidden_num_channels=50,
    cnn_kernel_size=7,
    cnn_dropout_p=0.1,
    auxiliary_task=False,
    # Decoder.
    num_decoder_layers=1,
    attention_type="bahdanau",
    decoder_dropout_p=0.3,
    decoder_hidden_size=100,
    conditional_attention=True,
    output_directory="../../../saved_models/ReaSCAN-Simple/",
)
model = Model(**model_config)
# Evaluation mode (e.g. dropout disabled) on the selected device.
model.eval()
model.to(device)
g = ReaSCANMultiModalLSTMCompGraph(model=model)

Load an example batch to verify that the graph reproduces the model's output

In [16]:
# Take the first batch (batch_size=1) and run the eager model on it. The names
# bound here are reused by the cells below, so fail loudly if there is no batch
# instead of silently leaving them undefined.
first_batch = next(iter(dataset.get_data_iterator(batch_size=1)), None)
if first_batch is None:
    raise RuntimeError(
        "dataset.get_data_iterator(batch_size=1) yielded no batches; "
        "check that read_dataset loaded any examples.")
(input_batch, input_lengths, _, situation_batch, _, target_batch,
 target_lengths, agent_positions, target_positions) = first_batch
target_scores, target_position_scores = model(
    commands_input=input_batch, commands_lengths=input_lengths,
    situations_input=situation_batch, target_batch=target_batch,
    target_lengths=target_lengths
)
print(target_scores)

tensor([[[-2.1911, -1.9745, -2.0447, -1.8779, -1.9720, -1.7679, -1.8505],
         [-1.8861, -1.9734, -1.8465, -2.0182, -2.0061, -1.8882, -2.0189],
         [-1.8752, -1.9630, -1.8398, -2.0234, -2.0173, -1.9010, -2.0192],
         [-1.8743, -1.9563, -1.8381, -2.0251, -2.0205, -1.9070, -2.0178],
         [-1.8760, -1.9521, -1.8376, -2.0255, -2.0214, -1.9095, -2.0168],
         [-1.9016, -2.0647, -1.8947, -2.1291, -1.9441, -1.8685, -1.8512],
         [-1.8975, -1.9529, -1.8267, -2.0353, -2.0107, -1.9009, -2.0152],
         [-1.8850, -1.9511, -1.8303, -2.0352, -2.0170, -1.9037, -2.0176],
         [-1.8816, -1.9495, -1.8333, -2.0326, -2.0192, -1.9060, -2.0176],
         [-1.8812, -1.9481, -1.8350, -2.0303, -2.0201, -1.9076, -2.0170],
         [-1.9035, -2.0623, -1.8934, -2.1322, -1.9451, -1.8665, -1.8514],
         [-1.8994, -1.9508, -1.8257, -2.0376, -2.0103, -1.8992, -2.0168],
         [-1.8863, -1.9499, -1.8298, -2.0365, -2.0167, -1.9025, -2.0185],
         [-1.8825, -1.9487, -1.8330, -

In [17]:
# Map the batch onto the leaf-node names used by the computation graph.
input_dict = dict(
    commands_input=input_batch,
    commands_lengths=input_lengths,
    situations_input=situation_batch,
    target_batch=target_batch,
    target_lengths=target_lengths,
)
all_in = GraphInput(input_dict, batched=True, batch_dim=0)

In [18]:
# Run the computation graph on the same batch. Distinct names keep the eager
# model's `target_scores` from the previous cell intact, so this cell can be
# re-run without changing what it compares against.
graph_target_scores, graph_target_position_scores = g.compute(all_in)
print(graph_target_scores)

# The graph unrolls only a fixed number of decoding steps, so compare it with
# the matching prefix of the eager model's batch-first output.
num_graph_steps = graph_target_scores.shape[1]
assert torch.allclose(graph_target_scores, target_scores[:, :num_graph_steps]), \
    "computation graph output differs from the model's output"

tensor([[[-2.1911, -1.9745, -2.0447, -1.8779, -1.9720, -1.7679, -1.8505],
         [-1.8861, -1.9734, -1.8465, -2.0182, -2.0061, -1.8882, -2.0189],
         [-1.8752, -1.9630, -1.8398, -2.0234, -2.0173, -1.9010, -2.0192],
         [-1.8743, -1.9563, -1.8381, -2.0251, -2.0205, -1.9070, -2.0178]]],
       grad_fn=<TransposeBackward0>)
