#### This notebook transforms the seq2seq model into a graphical model (computational graph) using antra

In [None]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import json
from model import *
from ReaSCAN_dataset import *
import torch.nn.functional as F
import torch
from antra.antra import *
from decode_graphical_models import *

def isnotebook():
    """Tell whether the code is running inside a Jupyter notebook or qtconsole.

    Returns:
        bool: True only when ``get_ipython()`` exists and its class is named
        ``ZMQInteractiveShell``. Terminal IPython, any other shell type, and a
        plain Python interpreter (where ``get_ipython`` is undefined) all
        yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is not defined: probably the standard Python interpreter.
        return False
    # 'TerminalInteractiveShell' and every other shell name count as "not a notebook".
    return shell_name == 'ZMQInteractiveShell'
# Notebook sessions always run on the CPU; everywhere else CUDA is used
# whenever PyTorch reports it as available.
if not isnotebook() and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
# Seed every random number generator used here so runs are reproducible.
# `random` and `numpy` are imported explicitly: relying on the wildcard
# imports above to re-export them would break as soon as those modules
# change what they expose.
import random

import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    # Seed the generators of all visible CUDA devices as well.
    torch.cuda.manual_seed_all(seed)

Initialize the ReaSCAN dataset, which provides the configuration values needed to build the model

In [None]:
# Location of the ReaSCAN split file and the pre-built vocabularies.
data_directory = "../../../data-files/ReaSCAN-Simple/"
data_file = "data-compositional-splits.txt"
input_vocab_file = "input_vocabulary.txt"
target_vocab_file = "target_vocabulary.txt"

# Parse the split file inside a context manager so the file handle is closed
# as soon as the JSON has been read, instead of lingering until garbage
# collection.
with open(os.path.join(data_directory, data_file), "r") as data_fp:
    data_json = json.load(data_fp)

dataset = ReaSCANDataset(
    data_json,
    data_directory, split="train",
    input_vocabulary_file=input_vocab_file,
    target_vocabulary_file=target_vocab_file,
    # Vocabularies are read from the files named above rather than regenerated.
    generate_vocabulary=False,
    k=0,
)
# Load only a handful of examples from ReaSCAN; that is enough to verify the graph.
dataset.read_dataset(
    max_examples=10,
    simple_situation_representation=True
)

Build the model and wrap it in the computational graph

In [None]:
# Sizes and special-token indices are taken from the dataset created above.
dataset_derived_config = dict(
    input_vocabulary_size=dataset.input_vocabulary_size,
    target_vocabulary_size=dataset.target_vocabulary_size,
    num_cnn_channels=dataset.image_channels,
    input_padding_idx=dataset.input_vocabulary.pad_idx,
    target_pad_idx=dataset.target_vocabulary.pad_idx,
    target_eos_idx=dataset.target_vocabulary.eos_idx,
)

# language encoder config
language_encoder_config = dict(
    embedding_dimension=25,
    encoder_hidden_size=100,
    num_encoder_layers=1,
    encoder_dropout_p=0.3,
    encoder_bidirectional=True,
)

# world encoder config
world_encoder_config = dict(
    simple_situation_representation=True,
    cnn_hidden_num_channels=50,
    cnn_kernel_size=7,
    cnn_dropout_p=0.1,
    auxiliary_task=False,
)

# decoder config
decoder_config = dict(
    num_decoder_layers=1,
    attention_type="bahdanau",
    decoder_dropout_p=0.3,
    decoder_hidden_size=100,
    conditional_attention=True,
)

# NOTE(review): no checkpoint is loaded in this cell, so the weights are
# whatever the Model constructor initializes - confirm this is intended.
model = Model(
    **dataset_derived_config,
    **language_encoder_config,
    **world_encoder_config,
    **decoder_config,
    output_directory="../../../saved_models/ReaSCAN-Simple/",
)
# Put the model in evaluation mode and move it to the selected device.
model.eval()
model.to(device)

# Wrap the model in the computational graph used by the cells below.
g = ReaSCANMultiModalLSTMCompGraph(model=model)

Load a small batch of examples to verify the computational graph

In [None]:
# Take just the first mini-batch (two examples) from the data iterator.
# Using next() with a sentinel makes an empty iterator fail loudly here,
# rather than leaving the batch variables undefined for the cells below.
first_batch = next(iter(dataset.get_data_iterator(batch_size=2)), None)
if first_batch is None:
    raise ValueError(
        "dataset.get_data_iterator(batch_size=2) yielded no batches; "
        "check that dataset.read_dataset() loaded at least one example."
    )

# The two `_` positions are batch fields that this notebook does not use.
(input_batch, input_lengths, _, situation_batch, _, target_batch,
 target_lengths, agent_positions, target_positions) = first_batch

In [None]:
# Show the command batch taken from the data iterator (rendered as this cell's output).
input_batch

In [None]:
# Collect the pieces of the sampled batch under the names the graph expects.
# Presumably these keys mirror the keyword arguments of the model's forward
# call - confirm against Model.
input_dict = dict(
    commands_input=input_batch,
    commands_lengths=input_lengths,
    situations_input=situation_batch,
    target_batch=target_batch,
    target_lengths=target_lengths,
)
# Wrap the dictionary as a batched graph input, with the batch on dimension 0.
all_in = GraphInput(input_dict, batched=True, batch_dim=0)

In [None]:
# Compute the graph node named "command_input_encode" on the batched inputs in
# `all_in`; the returned value is rendered as this cell's output.
# NOTE(review): presumably this is the command-encoder stage - confirm in
# decode_graphical_models.
g.compute_node("command_input_encode", all_in)

Setting up training loop for this model in antra