#### This script works on transform the seq2seq model to a graphical model using antra

In [1]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import json
from model import *
from ReaSCAN_dataset import *
import torch.nn.functional as F
import torch
from antra.antra import *
from decode_graphical_models import *
from torch.utils.data.sampler import RandomSampler, SequentialSampler

def isnotebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter
if isnotebook():
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
# setting up the seeds.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

Initialize ReaSCAN dataset to load config of the model

In [2]:
data_directory = "../../../data-files/gSCAN-Simple/"
data_file = "data-compositional-splits.txt"
input_vocab_file = "input_vocabulary.txt"
target_vocab_file = "target_vocabulary.txt"
dataset = ReaSCANDataset(
    json.load(open(os.path.join(data_directory, data_file), "r")), 
    data_directory, split="train",
    input_vocabulary_file=input_vocab_file,
    target_vocabulary_file=target_vocab_file,
    generate_vocabulary=False,
    k=0,
)
# Loading a couple of example from ReaSCAN.
dataset.read_dataset(
    max_examples=100,
    simple_situation_representation=True
)

2021-08-14 01:25 Formulating the dataset from the passed in json file...
2021-08-14 01:25 Loading vocabularies...
2021-08-14 01:25 Done loading vocabularies.
2021-08-14 01:25 Converting dataset to tensors...


In [3]:
# dataset.save_vocabularies(
#     input_vocabulary_file="input_vocabulary.txt", 
#     target_vocabulary_file="target_vocabulary.txt"
# )

Loading model to the computational graph

In [4]:
model = Model(
    input_vocabulary_size=dataset.input_vocabulary_size,
    target_vocabulary_size=dataset.target_vocabulary_size,
    num_cnn_channels=dataset.image_channels,
    input_padding_idx=dataset.input_vocabulary.pad_idx,
    target_pad_idx=dataset.target_vocabulary.pad_idx,
    target_eos_idx=dataset.target_vocabulary.eos_idx,
    # language encoder config
    embedding_dimension=25,
    encoder_hidden_size=100,
    num_encoder_layers=1,
    encoder_dropout_p=0.3,
    encoder_bidirectional=True,
    # world encoder config
    simple_situation_representation=True,
    cnn_hidden_num_channels=50,
    cnn_kernel_size=7,
    cnn_dropout_p=0.1,
    auxiliary_task=False,
    # decoder config
    num_decoder_layers=1,
    attention_type="bahdanau",
    decoder_dropout_p=0.3,
    decoder_hidden_size=100,
    conditional_attention=True,
    output_directory="../../../saved_models/gSCAN-Simple/"
)
model.eval()
model.to(device)
G = ReaSCANMultiModalLSTMCompGraph(
    model=model,
    max_decode_step=13,
    is_cf=True
)

  "num_layers={}".format(dropout, num_layers))


In [5]:
train_data, _ = dataset.get_dual_dataset()
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=1)

Loading some examples to verify

In [6]:
for step, batch in enumerate(train_dataloader):
    # just using this loop to get a pair of examples
    input_batch, target_batch, situation_batch, \
        agent_positions_batch, target_positions_batch, \
        input_lengths_batch, target_lengths_batch, \
        dual_input_batch, dual_target_batch, dual_situation_batch, \
        dual_agent_positions_batch, dual_target_positions_batch, \
        dual_input_lengths_batch, dual_target_lengths_batch = batch
    break

In [7]:
low1 = {
    "commands_input": input_batch, 
    "commands_lengths": input_lengths_batch,
    "situations_input": situation_batch,
    "target_batch": target_batch,
    "target_lengths": target_lengths_batch,
}
low1 = GraphInput(low1, batched=True, batch_dim=0, cache_results=False)

low2 = {
    "commands_input": dual_input_batch, 
    "commands_lengths": dual_input_lengths_batch,
    "situations_input": dual_situation_batch,
    "target_batch": dual_target_batch,
    "target_lengths": dual_target_lengths_batch,
}
low2 = GraphInput(low2, batched=True, batch_dim=0, cache_results=False)

In [8]:
dual_target_batch.shape

torch.Size([1, 13])

In [9]:
low2_hidden = G.compute_node('lstm_step_2', low2)

In [10]:
low2_hidden.shape

torch.Size([1, 4638])

In [52]:
intervention_dict = {"lstm_step_2[:,0:2]": torch.rand(2,2)}
low_interv = Intervention(
    low1, intervention_dict, 
    cache_results=False,
    cache_base_results=False,
)

{'lstm_step_2[::,0:2:]': tensor([[0.7991, 0.7719],
        [0.5270, 0.2030]])}


In [53]:
G.intervene_node("lstm_step_0", low_interv)

(tensor([[-0.1271,  0.1078,  0.0014,  ...,  0.0408,  0.0281,  0.0707]],
        grad_fn=<CatBackward>),
 tensor([[-0.1271,  0.1078,  0.0014,  ...,  0.0408,  0.0281,  0.0707]],
        grad_fn=<CatBackward>))

Setting up training loop for this model in antra

In [45]:
low_interv

{'base': [(('commands_input', (1, 3, 4, 5, 6, 7, 8, 2)),
           ('commands_lengths', (8,)),
           ('situations_input',
            (((0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
               0.0, 0.0, 0.0, 0.0)),
             ((0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
               0.0, 0.0, 0.0, 0.0),
              (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,