In [376]:
import argparse
import json
import logging
import os
import random
import sys
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

# Make the parent directory importable so the project-local packages below resolve.
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))

from decode_graphical_models import *
from decode_abstract_models import *
from seq2seq.ReaSCAN_dataset import *
from seq2seq.helpers import *
from torch.optim.lr_scheduler import LambdaLR

def isnotebook():
    """Report whether the code is executing inside a Jupyter notebook or qtconsole.

    Returns:
        bool: True only when the active IPython shell class is named
        ``ZMQInteractiveShell``. False for a terminal IPython session, for any
        other shell class, and when ``get_ipython`` is not defined at all.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython does not exist: probably the standard Python interpreter.
        return False
    # TerminalInteractiveShell and every other shell type count as "not a notebook".
    return shell_name == 'ZMQInteractiveShell'

In [None]:
# Location of the gSCAN-Simple compositional-splits file, relative to this notebook.
path_to_data = "../../../data-files/gSCAN-Simple/data-compositional-splits.txt"
# Parse the JSON inside a context manager so the file handle is closed as soon
# as loading finishes instead of being left open for the garbage collector.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)

In [None]:
# Upper bound on how many training examples go into the hand-built batch.
NUM = 200


def _flat_cell_index(position, grid_size):
    """Flatten a {"row", "column"} position into one index: row * grid_size + column."""
    return (int(position["row"]) * int(grid_size)) + int(position["column"])


agent_positions_batch = []
target_positions_batch = []
target_commands = []
for example in data_json["examples"]["train"]:
    target_commands.append(example["target_commands"])
    situation = example['situation']
    grid_size = situation["grid_size"]
    # Each position becomes a 1-element long tensor so stacking yields one row per example.
    agent_cell = torch.tensor(
        [_flat_cell_index(situation["agent_position"], grid_size)], dtype=torch.long
    )
    target_cell = torch.tensor(
        [_flat_cell_index(situation["target_object"]["position"], grid_size)], dtype=torch.long
    )
    agent_positions_batch.append(agent_cell)
    target_positions_batch.append(target_cell)
    # Stop as soon as NUM examples have been collected.
    if len(agent_positions_batch) == NUM:
        break
# Turn the per-example lists into one tensor per quantity (examples along dim 0).
agent_positions_batch = torch.stack(agent_positions_batch, dim=0)
target_positions_batch = torch.stack(target_positions_batch, dim=0)

In [319]:
# Instantiate the high-level model with its default constructor arguments.
hi_model = HighLevelModel()

In [320]:
# Encode the (agent, target) position pairs through the model's
# "situation_encode" entry point to obtain the starting hidden states.
hidden_states = hi_model(
    agent_positions_batch,
    target_positions_batch,
    tag="situation_encode",
)
# One all-zero long action per hidden-state row seeds the step-by-step rollout.
batch_size = hidden_states.size(0)
actions = torch.zeros(batch_size, 1, dtype=torch.long)

In [322]:
# Roll the model forward for a fixed 20 steps, keeping the action emitted at
# every step and counting, per row, how many of those actions are non-zero.
actions_sequence = []
actions_length = torch.zeros(hidden_states.size(0), 1).long()
for step_idx in range(20):
    hidden_states, actions = hi_model(
        hmm_states=hidden_states,
        hmm_actions=actions,
        tag="_hmm_step_fxn",
    )
    # NOTE(review): action id 0 is presumably "no action / padding" - confirm.
    is_real_action = (actions != 0)
    actions_length += is_real_action.long()
    actions_sequence.append(actions)

In [323]:
# Join the per-step action tensors along the last dimension so each row holds
# one rolled-out action sequence.
# NOTE(review): this rebinds `actions_sequence` from a list to a tensor, so
# re-running this cell without re-running the rollout cell will likely fail.
actions_sequence = torch.cat(actions_sequence, dim=-1)

In [325]:
# Sanity check: every rolled-out action sequence, cut to its counted length and
# converted by the model, must equal the dataset's target commands for that row.
for row in range(actions_sequence.size(0)):
    valid_length = actions_length[row]
    predicted_ids = actions_sequence[row, :valid_length].tolist()
    pred = hi_model.actions_list_to_sequence(predicted_ids)
    actual = target_commands[row]
    assert pred == actual

#### Try some interventions: swap part of the high-level hidden state with the dual example's

In [378]:
# Re-read the raw JSON inside a context manager so the file handle is closed
# as soon as loading finishes instead of being left open.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)
# Build the training split from the parsed JSON, reusing the existing vocabulary
# files (generate_vocabulary=False) found in the dataset directory.
training_set = ReaSCANDataset(
    data_json, 
    "../../../data-files/gSCAN-Simple/", split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False, k=0
)
# Only a small slice (at most 100 examples) is read for these experiments.
training_set.read_dataset(
    max_examples=100,
    simple_situation_representation=False
)

2021-08-21 02:39 Formulating the dataset from the passed in json file...
2021-08-21 02:39 Loading vocabularies...
2021-08-21 02:39 Done loading vocabularies.
2021-08-21 02:39 Converting dataset to tensors...


In [380]:
# get_dual_dataset() is unpacked into two values; only the first is used here.
# The second is bound to a named throwaway rather than `_`, because assigning
# to `_` in IPython/Jupyter shadows the built-in "last output" variable.
train_data, _unused_second_return = training_set.get_dual_dataset()
# Draw examples in random order, 50 per batch.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=50)

In [413]:
# Build a new high-level model instance (no constructor arguments) for the
# intervention cells that follow.
hi_model = HighLevelModel()

In [408]:
def _encode_start_state(agent_positions, target_positions):
    """Encode one batch of positions and pair it with all-zero starting actions.

    Both position tensors get a trailing dimension added before being passed to
    hi_model's "situation_encode" call. Returns the resulting hidden states and
    a long tensor of zeros with one row per hidden-state row and one column.
    """
    encoded = hi_model(
        agent_positions_batch=agent_positions.unsqueeze(dim=-1),
        target_positions_batch=target_positions.unsqueeze(dim=-1),
        tag="situation_encode",
    )
    start_actions = torch.zeros(encoded.size(0), 1).long()
    return encoded, start_actions


# Iterate the dataloader, but only the first batch it yields is used.
for step, batch in enumerate(train_dataloader):
    # Each batch carries seven "main" fields followed by their seven "dual" counterparts.
    (
        input_batch, target_batch, situation_batch,
        agent_positions_batch, target_positions_batch,
        input_lengths_batch, target_lengths_batch,
        dual_input_batch, dual_target_batch, dual_situation_batch,
        dual_agent_positions_batch, dual_target_positions_batch,
        dual_input_lengths_batch, dual_target_lengths_batch,
    ) = batch

    # Main example: starting hidden states and zero actions.
    high_hidden_states, high_actions = _encode_start_state(
        agent_positions_batch, target_positions_batch
    )
    # Dual example: same encoding applied to the dual positions.
    dual_high_hidden_states, dual_high_actions = _encode_start_state(
        dual_agent_positions_batch, dual_target_positions_batch
    )

    break  # just steal one batch

In [475]:
# Number of step-function calls applied to the dual state before the swap, and
# also the step index of the main rollout at which the swap happens.
intervene_time = 1
# Column of the hidden-state tensor that is copied over from the dual example.
intervene_attribute = 0

In [476]:
# Step the dual example forward `intervene_time` times so its hidden state is
# the one that will be read at the intervention point.
# Note: the dual state is overwritten in place of itself, so re-running this
# cell advances it by another `intervene_time` steps.
for warmup_step in range(intervene_time):
    dual_step_output = hi_model(
        hmm_states=dual_high_hidden_states,
        hmm_actions=dual_high_actions,
        tag="_hmm_step_fxn",
    )
    dual_high_hidden_states, dual_high_actions = dual_step_output

In [477]:
# Token ids used when assembling the intervened target tensor.
# NOTE(review): 1 = SOS and 2 = EOS follow this cell's original comments; 0 is
# treated as padding because only non-zero actions are counted - confirm these
# against the target vocabulary.
SOS_IDX, EOS_IDX, PAD_IDX = 1, 2, 0

train_max_decoding_steps = 20
batch_size = high_hidden_states.size(0)

# Main intervention loop: start the counterfactual rollout from the main example's state.
cf_high_hidden_states = high_hidden_states
cf_high_actions = high_actions
intervened_target_batch = [torch.full((batch_size, 1), SOS_IDX, dtype=torch.long)]  # SOS tokens
intervened_target_lengths_batch = torch.zeros(batch_size, 1).long()
# Two of the decoding steps are reserved for the SOS and EOS tokens.
for j in range(train_max_decoding_steps - 2):
    if j == intervene_time:
        # Clone before the in-place column write. Without the clone the write
        # lands on whatever tensor the name currently refers to: for
        # intervene_time == 0 that is `high_hidden_states` itself (or the step
        # function's returned tensor otherwise), which would silently corrupt
        # the source state and make this cell give different results on re-run.
        cf_high_hidden_states = cf_high_hidden_states.clone()
        # Only the selected attribute column is swapped in from the dual example.
        cf_high_hidden_states[:, intervene_attribute] = dual_high_hidden_states[:, intervene_attribute]
        # For testing only: uncomment the two lines below to replace the whole state instead.
        # cf_high_hidden_states = dual_high_hidden_states
        # cf_high_actions = dual_high_actions
    cf_high_hidden_states, cf_high_actions = hi_model(
        hmm_states=cf_high_hidden_states,
        hmm_actions=cf_high_actions,
        tag="_hmm_step_fxn"
    )
    # Record the emitted action for the loss calculation and count non-pad actions per row.
    intervened_target_batch.append(cf_high_actions)
    intervened_target_lengths_batch += (cf_high_actions != PAD_IDX).long()
# One extra pad column guarantees room for EOS even when every step emitted an action.
intervened_target_batch.append(torch.full((batch_size, 1), PAD_IDX, dtype=torch.long))
intervened_target_lengths_batch += 2  # account for SOS and EOS
intervened_target_batch = torch.cat(intervened_target_batch, dim=-1)
# Write EOS at index (length - 1) of every row in one vectorised call.
intervened_target_batch.scatter_(1, intervened_target_lengths_batch - 1, EOS_IDX)

In [478]:
# Per-row count of positions where the intervened targets differ from the
# original targets, comparing only the first target_batch.size(1) columns.
(intervened_target_batch[:,:target_batch.size(1)] != target_batch).sum(1)

tensor([7, 7, 4, 6, 3, 6, 5, 5, 6, 6, 7, 4, 3, 6, 0, 3, 4, 0, 6, 3, 6, 7, 3, 4,
        6, 4, 3, 3, 5, 3, 0, 5, 3, 5, 7, 7, 0, 4, 4, 5, 3, 3, 3, 0, 4, 3, 4, 7,
        5, 4])

In [473]:
# Show the intervened targets, truncated to the column width of target_batch.
intervened_target_batch[:,:target_batch.size(1)]

tensor([[1, 4, 3, 4, 4, 3, 4, 4, 4, 4, 4, 4, 2],
        [1, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 2],
        [1, 4, 4, 4, 5, 4, 4, 2, 0, 0, 0, 0, 0],
        [1, 4, 3, 4, 5, 4, 4, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 5, 4, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 4, 2],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 2, 0],
        [1, 4, 3, 4, 4, 4, 4, 4, 5, 4, 4, 2, 0],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 4, 2],
        [1, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4],
        [1, 4, 4, 3, 4, 4, 4, 4, 4, 2, 0, 0, 0],
        [1, 4, 5, 4, 4, 5, 4, 4, 4, 2, 0, 0, 0],
        [1, 4, 5, 4, 5, 4, 4, 2, 0, 0, 0, 0, 0],
        [1, 4, 4, 4, 4, 3, 4, 4, 4, 4, 2, 0, 0],
        [1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 4, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 3, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 5, 4, 2, 0, 0, 0, 0],
        [1, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4],
        [1, 3, 5, 4, 4, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 3, 5, 4,

In [479]:
# Scratch check: random 3x3 tensor used to try out the column-assignment
# pattern that the intervention applies to the hidden states.
a = torch.rand(3,3)

In [480]:
# Display a before any column is overwritten.
a 

tensor([[0.0017, 0.2133, 0.3266],
        [0.2905, 0.4855, 0.1845],
        [0.8512, 0.1628, 0.3731]])

In [481]:
# Second random 3x3 tensor; it is the source of the column copied into a.
b = torch.rand(3,3)

In [482]:
# Display b so its column 1 can be compared with a after the copy.
b

tensor([[0.7027, 0.4703, 0.8446],
        [0.9190, 0.3526, 0.1672],
        [0.5256, 0.4236, 0.5546]])

In [483]:
# Overwrite column 1 of a, in place, with column 1 of b.
a[:,1]=b[:,1]

In [484]:
# Display a again: column 1 now holds b's values, the other columns are unchanged.
a

tensor([[0.0017, 0.4703, 0.3266],
        [0.2905, 0.3526, 0.1845],
        [0.8512, 0.4236, 0.3731]])