#### This script works on transform the seq2seq model to a graphical model using antra

In [1]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import json
from model import *
from ReaSCAN_dataset import *
import torch.nn.functional as F
import torch
from antra.antra import *
from decode_graphical_models import *
from torch.utils.data.sampler import RandomSampler, SequentialSampler

def isnotebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter
if isnotebook():
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
# setting up the seeds.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

Initialize ReaSCAN dataset to load config of the model

In [2]:
data_directory = "../../../data-files/gSCAN-Simple/"
data_file = "data-compositional-splits.txt"
input_vocab_file = "input_vocabulary.txt"
target_vocab_file = "target_vocabulary.txt"
dataset = ReaSCANDataset(
    json.load(open(os.path.join(data_directory, data_file), "r")), 
    data_directory, split="train",
    input_vocabulary_file=input_vocab_file,
    target_vocabulary_file=target_vocab_file,
    generate_vocabulary=False,
    k=0,
)
# Loading a couple of example from ReaSCAN.
dataset.read_dataset(
    max_examples=100,
    simple_situation_representation=True
)

2021-08-12 02:10 Formulating the dataset from the passed in json file...
2021-08-12 02:10 Loading vocabularies...
2021-08-12 02:10 Done loading vocabularies.
2021-08-12 02:10 Converting dataset to tensors...


In [3]:
# dataset.save_vocabularies(
#     input_vocabulary_file="input_vocabulary.txt", 
#     target_vocabulary_file="target_vocabulary.txt"
# )

In [85]:
def isnotebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter

def _generate_lstm_step_fxn(step_module, i, max_decode_step, 
                            image_size=6, hidden_dim=100, 
                            vocab_size=6):
    """ 
    Generate a function for a layer in lstm.
    """

    def _lstm_step_fxn(hidden_states):
        if isnotebook():
            
            last_states = hidden_states
            batch_size = last_states.size(0)

            last_hidden = last_states[:,:hidden_dim].unsqueeze(dim=1).contiguous()
            last_cell = last_states[:,hidden_dim:hidden_dim*2].unsqueeze(dim=1).contiguous()
            input_tokens_sorted = last_states[:,hidden_dim*2:hidden_dim*2+max_decode_step].long().contiguous()
            
            commands_lengths = last_states[
                :,hidden_dim*2+max_decode_step:hidden_dim*2+max_decode_step+1
            ].long().contiguous()
            
            projected_keys_visual = last_states[
                :,hidden_dim*2+max_decode_step+1:hidden_dim*2+max_decode_step+1+image_size*image_size*hidden_dim
            ].reshape(
                batch_size, image_size*image_size, hidden_dim
            ).contiguous()
            
            _output = last_states[
                :,hidden_dim*2+max_decode_step+1+image_size*image_size*hidden_dim:hidden_dim*2+max_decode_step+1+image_size*image_size*hidden_dim+vocab_size
            ].contiguous()
            
            projected_keys_textual = last_states[
                :,hidden_dim*2+max_decode_step+1+image_size*image_size*hidden_dim+vocab_size:
            ].reshape(
                batch_size, -1, hidden_dim
            ).contiguous()
            
            (output, hidden) = step_module.forward(
                lstm_input_tokens_sorted=input_tokens_sorted[:, i], 
                lstm_hidden=(last_hidden, last_cell), 
                lstm_projected_keys_textual=projected_keys_textual, 
                lstm_commands_lengths=commands_lengths, 
                lstm_projected_keys_visual=projected_keys_visual,
                tag="_lstm_step_fxn"
            )
        else:
            (output, hidden) = step_module(
                lstm_input_tokens_sorted=hidden_states["input_tokens_sorted"][:, i], 
                lstm_hidden=hidden_states["hidden"], 
                lstm_projected_keys_textual=hidden_states["projected_keys_textual"], 
                lstm_commands_lengths=hidden_states["commands_lengths"], 
                lstm_projected_keys_visual=hidden_states["projected_keys_visual"],
                tag="_lstm_step_fxn"
            )
        
        last_states = torch.cat(
                [
                    hidden[0].squeeze(dim=1),
                    hidden[1].squeeze(dim=1),
                    input_tokens_sorted,
                    commands_lengths,
                    projected_keys_visual.reshape(batch_size, -1),
                    output,
                    projected_keys_textual.reshape(batch_size, -1),
                ], dim=-1
        )
        return last_states

    return _lstm_step_fxn

def generate_compute_graph(
    model, 
    max_decode_step,
    cache_results=False, 
    vocab_size=6,
    image_size=6, 
    hidden_dim=100
):
    
    
    ####################
    #
    # Input preparation.
    #
    ####################
    """
    Command Inputs.
    """
    command_world_inputs = ["commands_input", "commands_lengths"]
    command_world_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, default_value=None) 
        for name in command_world_inputs
    ]
    @GraphNode(*command_world_input_leaves, cache_results=False)
    def command_input_preparation(
        commands_input, commands_lengths,
    ):
        input_dict = {
            "commands_input": commands_input,
            "commands_lengths": commands_lengths,
        }
        # We may not need the following fields. But we leave it here in case we need these
        # to generate other inputs.
        batch_size = input_dict["commands_input"].shape[0]
        device = input_dict["commands_input"].device
        return input_dict
    
    """
    Situation Inputs.
    """
    situation_inputs = ["situations_input"]
    situation_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, 
                       default_value=None) 
        for name in situation_inputs
    ]
    @GraphNode(*situation_input_leaves, cache_results=cache_results)
    def situation_input_preparation(
        situations_input,
    ):
        return {
            "situations_input": situations_input,
        }
        
    """
    Target Inputs
    """
    target_sequence_inputs = ["target_batch", "target_lengths"]
    target_sequence_input_leaves = [
        GraphNode.leaf(name=name, use_default=True, 
                       default_value=None) 
        for name in target_sequence_inputs
    ]
    @GraphNode(*target_sequence_input_leaves, 
               cache_results=cache_results)
    def target_sequence_input_preparation(
        target_batch, target_lengths
    ):
        return {
            "target_batch": target_batch,
            "target_lengths": target_lengths,
        }
    
    ####################
    #
    # Input encoding.
    #
    ####################
    """
    Situation Encoding.
    """
    @GraphNode(situation_input_preparation, 
               cache_results=cache_results)
    def situation_encode(input_dict):
        if isnotebook():
            return model.forward(
                situations_input=input_dict["situations_input"],
                tag="situation_encode"
            )
        else:
            return model(
                situations_input=input_dict["situations_input"],
                tag="situation_encode"
            )
    
    """
    Language Encoding.
    """
    @GraphNode(command_input_preparation, 
               cache_results=cache_results)
    def command_input_encode(input_dict):
        if isnotebook():
            return model.forward(
                commands_input=input_dict["commands_input"], 
                commands_lengths=input_dict["commands_lengths"],
                tag="command_input_encode"
            )
        else:
            return model(
                commands_input=input_dict["commands_input"], 
                commands_lengths=input_dict["commands_lengths"],
                tag="command_input_encode"
            )
    
    ####################
    #
    # Decoding.
    #
    ####################
    """
    Preparation of Decoding Data structure.
    """
    @GraphNode(command_input_encode, situation_encode, 
               target_sequence_input_preparation, 
               cache_results=cache_results)
    def decode_input_preparation(c_encode, s_encode, target_sequence):
        if isnotebook():
            hidden_states = model.forward(
                target_batch=target_sequence["target_batch"],
                target_lengths=target_sequence["target_lengths"],
                command_hidden=c_encode["command_hidden"],
                command_encoder_outputs=c_encode["command_encoder_outputs"],
                command_sequence_lengths=c_encode["command_sequence_lengths"],
                encoded_situations=s_encode,
                tag="decode_input_preparation"
            )
        else:
            hidden_states = model(
                target_batch=target_sequence["target_batch"],
                target_lengths=target_sequence["target_lengths"],
                command_hidden=c_encode["command_hidden"],
                command_encoder_outputs=c_encode["command_encoder_outputs"],
                command_sequence_lengths=c_encode["command_sequence_lengths"],
                encoded_situations=s_encode,
                tag="decode_input_preparation"
            )
        # dummy output tensor for the first time.
        batch_size = hidden_states["input_tokens_sorted"].size(0)
        return torch.cat(
                [
                    hidden_states["hidden"][0].squeeze(dim=1),
                    hidden_states["hidden"][1].squeeze(dim=1),
                    hidden_states["input_tokens_sorted"],
                    hidden_states["commands_lengths"],
                    hidden_states["projected_keys_visual"].reshape(batch_size, -1),
                    torch.zeros(batch_size, vocab_size),
                    # we need the textual key to be at last since the dimension is not interpretable.
                    hidden_states["projected_keys_textual"].reshape(batch_size, -1)
                ], dim=-1
            )
        

    hidden_layer = decode_input_preparation
    """
    Here, we set to a static bound of decoding steps.
    """
    for i in range(max_decode_step):
        f = _generate_lstm_step_fxn(
            model, i, max_decode_step,
            vocab_size=vocab_size,
            image_size=image_size, 
            hidden_dim=hidden_dim
        )
        hidden_layer = GraphNode(hidden_layer,
                                 name=f"lstm_step_{i}",
                                 forward=f, cache_results=cache_results)
        
    # Do we really need this?
    # """
    # Formulating outputs.
    # """
    # @GraphNode(hidden_layer, cache_results=cache_results)
    # def output_preparation(hidden_states):
    #     hidden_states["return_lstm_output"] = torch.cat(
    #         hidden_states["return_lstm_output"], dim=0)
    #     hidden_states["return_attention_weights"] = torch.cat(
    #         hidden_states["return_attention_weights"], dim=0)
    #     
    #     decoder_output_batched = hidden_states["return_lstm_output"]
    #     context_situation = hidden_states["return_attention_weights"]
    #     decoder_output_batched = F.log_softmax(decoder_output_batched, dim=-1)
    #     
    #     # if model.module.auxiliary_task:
    #     if False:
    #         pass # Not implemented yet.
    #     else:
    #         target_position_scores = torch.zeros(1), torch.zeros(1)
    #        # We are not returning this as well, since it is not used...
    #    print(decoder_output_batched.shape)
    #    return decoder_output_batched.transpose(0, 1) # [batch_size, max_target_seq_length, target_vocabulary_size]
    # root = hidden_layer # TODO: removing this and continue.
    
    return hidden_layer
    
class ReaSCANMultiModalLSTMCompGraph(ComputationGraph):
    def __init__(self, model,
                 max_decode_step,
                 cache_results=False):
        self.model = model
        root = generate_compute_graph(
            model,
            max_decode_step,
        )

        super().__init__(root)

Loading model to the computational graph

In [86]:
model = Model(
    input_vocabulary_size=dataset.input_vocabulary_size,
    target_vocabulary_size=dataset.target_vocabulary_size,
    num_cnn_channels=dataset.image_channels,
    input_padding_idx=dataset.input_vocabulary.pad_idx,
    target_pad_idx=dataset.target_vocabulary.pad_idx,
    target_eos_idx=dataset.target_vocabulary.eos_idx,
    # language encoder config
    embedding_dimension=25,
    encoder_hidden_size=100,
    num_encoder_layers=1,
    encoder_dropout_p=0.3,
    encoder_bidirectional=True,
    # world encoder config
    simple_situation_representation=True,
    cnn_hidden_num_channels=50,
    cnn_kernel_size=7,
    cnn_dropout_p=0.1,
    auxiliary_task=False,
    # decoder config
    num_decoder_layers=1,
    attention_type="bahdanau",
    decoder_dropout_p=0.3,
    decoder_hidden_size=100,
    conditional_attention=True,
    output_directory="../../../saved_models/gSCAN-Simple/"
)
model.eval()
model.to(device)
G = ReaSCANMultiModalLSTMCompGraph(
    model=model,
    max_decode_step=13
)

In [87]:
train_data, _ = dataset.get_dual_dataset()
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)

Loading some examples to verify

In [88]:
for step, batch in enumerate(train_dataloader):
    # just using this loop to get a pair of examples
    input_batch, target_batch, situation_batch, \
        agent_positions_batch, target_positions_batch, \
        input_lengths_batch, target_lengths_batch, \
        dual_input_batch, dual_target_batch, dual_situation_batch, \
        dual_agent_positions_batch, dual_target_positions_batch, \
        dual_input_lengths_batch, dual_target_lengths_batch = batch
    break

In [89]:
low1 = {
    "commands_input": input_batch, 
    "commands_lengths": input_lengths_batch,
    "situations_input": situation_batch,
    "target_batch": target_batch,
    "target_lengths": target_lengths_batch,
}
low1 = GraphInput(low1, batched=True, batch_dim=0, cache_results=False)

low2 = {
    "commands_input": dual_input_batch, 
    "commands_lengths": dual_input_lengths_batch,
    "situations_input": dual_situation_batch,
    "target_batch": dual_target_batch,
    "target_lengths": dual_target_lengths_batch,
}
low2 = GraphInput(low2, batched=True, batch_dim=0, cache_results=False)

In [90]:
dual_target_batch.shape

torch.Size([2, 13])

In [93]:
low2_hidden = G.compute_node('lstm_step_2', low2)

In [94]:
low2_hidden

tensor([[-0.0345, -0.2683,  0.0763,  ...,  0.0531, -0.1266, -0.0160],
        [-0.0148, -0.2442, -0.0280,  ...,  0.0531, -0.1266, -0.0160]],
       grad_fn=<CatBackward>)

In [106]:
low2_hidden_select = {
    'hidden': low2_hidden['hidden']
}

In [107]:
from antra.antra.utils import serialize

In [108]:
keys = [serialize(x) for x in low2_hidden["projected_keys_textual"]]
low_interv_input = GraphInput.batched(
    values={"lstm_step_1": low2_hidden_select},
    batch_dim=0,
    keys=keys
)

In [109]:
low_interv = Intervention.batched(
    low1, low_interv_input, 
    cache_results=False,
    cache_base_results=False,
)

In [120]:
G.intervene_node("lstm_step_1", low_interv)

({'return_lstm_output': [tensor([[[ 0.1085,  0.3292, -0.3566,  0.1918,  0.0502,  0.2859, -0.1443,
              0.1870]]], grad_fn=<UnsqueezeBackward0>),
   tensor([[[ 0.4515, -0.1367, -0.0346,  0.1133, -0.0544, -0.4565, -0.1353,
             -0.0419]]], grad_fn=<UnsqueezeBackward0>)],
  'return_attention_weights': [tensor([[[0.0273, 0.0286, 0.0273, 0.0274, 0.0290, 0.0280, 0.0275, 0.0278,
             0.0283, 0.0273, 0.0280, 0.0276, 0.0273, 0.0274, 0.0280, 0.0288,
             0.0275, 0.0291, 0.0274, 0.0283, 0.0275, 0.0282, 0.0269, 0.0274,
             0.0269, 0.0284, 0.0276, 0.0277, 0.0281, 0.0274, 0.0275, 0.0294,
             0.0275, 0.0273, 0.0274, 0.0268]]], grad_fn=<UnsqueezeBackward0>),
   tensor([[[0.0273, 0.0286, 0.0273, 0.0274, 0.0290, 0.0280, 0.0275, 0.0278,
             0.0283, 0.0273, 0.0280, 0.0276, 0.0273, 0.0274, 0.0280, 0.0288,
             0.0275, 0.0291, 0.0274, 0.0283, 0.0275, 0.0282, 0.0269, 0.0274,
             0.0269, 0.0284, 0.0276, 0.0277, 0.0281, 0.0274, 0.0275

In [121]:
G.compute_node(
    "lstm_step_2", 
     GraphInput({
        "lstm_step_1":  low2_hidden
     })          
)

AttributeError: 'NoneType' object has no attribute 'shape'

In [118]:
low2_hidden

{'return_lstm_output': [tensor([[[ 0.1083,  0.3289, -0.3572,  0.1932,  0.0509,  0.2854, -0.1446,
             0.1894]]], grad_fn=<UnsqueezeBackward0>),
  tensor([[[ 0.4512, -0.1370, -0.0352,  0.1147, -0.0538, -0.4568, -0.1355,
            -0.0396]]], grad_fn=<UnsqueezeBackward0>)],
 'return_attention_weights': [tensor([[[0.0280, 0.0275, 0.0274, 0.0281, 0.0285, 0.0277, 0.0273, 0.0281,
            0.0282, 0.0272, 0.0273, 0.0276, 0.0273, 0.0278, 0.0277, 0.0290,
            0.0284, 0.0279, 0.0279, 0.0277, 0.0290, 0.0275, 0.0267, 0.0270,
            0.0278, 0.0282, 0.0272, 0.0273, 0.0274, 0.0289, 0.0282, 0.0273,
            0.0283, 0.0284, 0.0266, 0.0280]]], grad_fn=<UnsqueezeBackward0>),
  tensor([[[0.0280, 0.0275, 0.0274, 0.0281, 0.0285, 0.0277, 0.0273, 0.0281,
            0.0282, 0.0272, 0.0273, 0.0276, 0.0273, 0.0278, 0.0277, 0.0290,
            0.0284, 0.0279, 0.0279, 0.0277, 0.0290, 0.0275, 0.0267, 0.0270,
            0.0278, 0.0282, 0.0272, 0.0273, 0.0274, 0.0289, 0.0282, 0.0273,
   

In [11]:
input_dict = {
    "commands_input": input_batch, 
    "commands_lengths": input_lengths,
    "situations_input": situation_batch,
    "target_batch": target_batch,
    "target_lengths": target_lengths,
}
all_in = GraphInput(input_dict, batched=True, batch_dim=0)

In [12]:
g.compute_node(
    "command_input_encode", 
    all_in
)

{'command_hidden': tensor([[-0.0716,  0.0218, -0.0727, -0.1464,  0.2586,  0.0150,  0.1276, -0.1273,
           0.1017, -0.1471,  0.0721,  0.0877, -0.0730,  0.0089,  0.0756,  0.0308,
           0.0855,  0.1926, -0.0424, -0.1425, -0.0747,  0.2126,  0.1662, -0.0599,
          -0.1097, -0.0167, -0.0249,  0.1125,  0.1209, -0.0604,  0.1541, -0.1178,
          -0.2140,  0.2301, -0.1050, -0.1421, -0.0394,  0.0096,  0.1389, -0.0672,
           0.1375,  0.1194,  0.0805,  0.1265, -0.2354,  0.2100,  0.0325, -0.2087,
          -0.1465, -0.0817,  0.1489,  0.1670, -0.2006, -0.1258,  0.1037, -0.1835,
          -0.0210,  0.0387, -0.0457,  0.0012, -0.1044, -0.0470,  0.0509,  0.1522,
           0.0645, -0.1165,  0.1050, -0.0331, -0.0771,  0.2220, -0.0576,  0.1506,
           0.0650,  0.0381, -0.0998, -0.0811,  0.0228, -0.0296,  0.0171, -0.0422,
           0.0539, -0.0739, -0.0856, -0.0870, -0.0091, -0.0355, -0.2654, -0.0285,
          -0.1823,  0.0548,  0.1766, -0.0890,  0.3243, -0.0289, -0.1823,  0.1486

Setting up training loop for this model in antra