In [1]:
import argparse
import logging
import os
import torch
import logging
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import time
import random
import torch.nn.functional as F

from decode_graphical_models import *
from decode_abstract_models import *
from seq2seq.ReaSCAN_dataset import *
from seq2seq.helpers import *
from torch.optim.lr_scheduler import LambdaLR

def isnotebook():
    """Report whether the code is running inside a Jupyter notebook kernel.

    Returns:
        True when the active IPython shell class is ``ZMQInteractiveShell``
        (Jupyter notebook or qtconsole); False for a terminal IPython session,
        any other shell type, or a plain Python interpreter where
        ``get_ipython`` is not defined.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # `get_ipython` only exists when running under IPython.
        return False
    return shell_name == 'ZMQInteractiveShell'

In [2]:
path_to_data = "../../../data-files/gSCAN-Simple/data-compositional-splits.txt"
# Parse all splits of the gSCAN-Simple compositional data file. A context manager closes
# the handle once parsing is done; `json.load(open(...))` left that to the garbage collector.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)

In [3]:
def encode_example_positions(examples, max_examples=None):
    """Collect flattened agent/target grid positions for the first examples of a split.

    Each position dict (``{"row": ..., "column": ...}``) is flattened to
    ``row * grid_size + column`` using the example's own ``situation["grid_size"]``.

    Args:
        examples: iterable of example dicts with ``"target_commands"`` and
            ``"situation"`` keys.
        max_examples: stop after this many examples; ``None`` uses all of them.

    Returns:
        ``(agent_positions, target_positions, target_commands)``: two ``torch.long``
        tensors of shape ``(n, 1)`` and the list of the examples' ``"target_commands"``.
    """
    agent_positions, target_positions, commands = [], [], []
    for example in examples:
        if max_examples is not None and len(commands) >= max_examples:
            break
        situation = example["situation"]
        grid_size = int(situation["grid_size"])
        agent = situation["agent_position"]
        target = situation["target_object"]["position"]
        commands.append(example["target_commands"])
        agent_positions.append(int(agent["row"]) * grid_size + int(agent["column"]))
        target_positions.append(int(target["row"]) * grid_size + int(target["column"]))
    # Building the tensors from flat lists (rather than stacking per-example tensors)
    # also yields a valid (0, 1) tensor when no examples are present.
    return (
        torch.tensor(agent_positions, dtype=torch.long).unsqueeze(dim=1),
        torch.tensor(target_positions, dtype=torch.long).unsqueeze(dim=1),
        commands,
    )


NUM = 200
agent_positions_batch, target_positions_batch, target_commands = encode_example_positions(
    data_json["examples"]["situational_1"], max_examples=NUM
)

In [4]:
# Instantiate the hand-written high-level model. HighLevelModel is not defined in this
# notebook; presumably it comes from one of the star imports at the top (TODO confirm).
hi_model = HighLevelModel()

In [16]:
# Encode the agent/target positions into the model's initial hidden states;
# every example starts from action 0.
hidden_states = hi_model(agent_positions_batch, target_positions_batch, tag="situation_encode")
actions = torch.zeros(hidden_states.size(0), 1, dtype=torch.long)

In [17]:
# Roll the high-level model forward from `hidden_states` / `actions`, keeping every
# step's action tensor and counting, per example, the steps whose action is non-zero.
# NOTE(review): only one step is taken here, yet the assertion cell below compares
# against complete target command sequences; confirm the intended number of steps.
n_rollout_steps = 1
actions_sequence = []
actions_length = torch.zeros(hidden_states.size(0), 1, dtype=torch.long)
for _ in range(n_rollout_steps):
    hidden_states, actions = hi_model(
        hmm_states=hidden_states,
        hmm_actions=actions,
        tag="_hmm_step_fxn",
    )
    actions_sequence.append(actions)
    actions_length += actions.ne(0).long()

In [41]:
# One-hot encode the first two hidden-state components into 2 * grid_size - 1 bins each.
# Assumes those components are signed offsets in [-(grid_size - 1), grid_size - 1]
# (the printed hidden states range from -5 to 5 for a 6x6 grid); TODO confirm.
grid_size = 6
max_offset = grid_size - 1  # replaces the hard-coded 5 so the encoding follows grid_size
num_bins = 2 * grid_size - 1
row_index = torch.arange(hidden_states.shape[0])
x_target = torch.zeros(hidden_states.shape[0], num_bins).long()
y_target = torch.zeros(hidden_states.shape[0], num_bins).long()
# Shift offsets so the most negative one maps to bin 0.
indices = hidden_states + max_offset
x_target[row_index, indices[:, 0]] = 1
y_target[row_index, indices[:, 1]] = 1

In [538]:
# Concatenate the per-step action tensors along the last dimension, presumably one
# column per rollout step (each step's actions assumed to be (batch, 1); TODO confirm).
# The guard keeps the cell idempotent: re-running it would otherwise pass an
# already-concatenated tensor to torch.cat and fail.
if isinstance(actions_sequence, list):
    actions_sequence = torch.cat(actions_sequence, dim=-1)

In [540]:
# Check that decoding each example's rolled-out actions (truncated to its counted
# length) reproduces the example's target command string.
for i in range(actions_sequence.size(0)):
    pred = hi_model.actions_list_to_sequence(actions_sequence[i, :actions_length[i]].tolist())
    actual = target_commands[i]
    # Report which example failed and how, instead of a bare AssertionError.
    assert pred == actual, f"example {i}: predicted {pred!r}, expected {actual!r}"

#### Try some interventions on the high-level model

In [44]:
# Reload the raw splits, closing the file once parsed. Then build a ReaSCAN dataset over
# the "train" split, loading the existing vocabulary files (generate_vocabulary=False).
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)
training_set = ReaSCANDataset(
    data_json,
    "../../../data-files/gSCAN-Simple/", split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False, k=0
)
training_set.read_dataset(
    max_examples=100,
    simple_situation_representation=False
)

2021-08-30 19:15 Formulating the dataset from the passed in json file...
2021-08-30 19:15 Loading vocabularies...
2021-08-30 19:15 Done loading vocabularies.
2021-08-30 19:15 Converting dataset to tensors...


In [380]:
# Pair every training example with a dual (source) example and batch them in random order.
train_data, _ = training_set.get_dual_dataset()
# Seed the sampler's generator so the shuffled batch order, and therefore the batch
# inspected below, is reproducible across kernel restarts.
train_sampler = RandomSampler(train_data, generator=torch.Generator().manual_seed(42))
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=50)

In [413]:
# Fresh high-level model for the intervention experiments (constructed with no arguments).
hi_model = HighLevelModel()

In [408]:
# Take the first shuffled batch to run the intervention experiments on.
# `next(iter(...))` states the intent directly, replacing the for/break. It also raises
# StopIteration if no batch is available, instead of silently skipping the body and
# leaving stale batch variables from earlier cells.
batch = next(iter(train_dataloader))
# Main batch followed by its dual (source) batch, as produced by get_dual_dataset().
(
    input_batch, target_batch, situation_batch,
    agent_positions_batch, target_positions_batch,
    input_lengths_batch, target_lengths_batch,
    dual_input_batch, dual_target_batch, dual_situation_batch,
    dual_agent_positions_batch, dual_target_positions_batch,
    dual_input_lengths_batch, dual_target_lengths_batch,
) = batch

# Encode both situations into initial high-level hidden states; rollouts start from action 0.
high_hidden_states = hi_model(
    agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1),
    target_positions_batch=target_positions_batch.unsqueeze(dim=-1),
    tag="situation_encode",
)
high_actions = torch.zeros(high_hidden_states.size(0), 1).long()

dual_high_hidden_states = hi_model(
    agent_positions_batch=dual_agent_positions_batch.unsqueeze(dim=-1),
    target_positions_batch=dual_target_positions_batch.unsqueeze(dim=-1),
    tag="situation_encode",
)
dual_high_actions = torch.zeros(dual_high_hidden_states.size(0), 1).long()

In [475]:
# Intervention settings: at rollout step `intervene_time`, column `intervene_attribute`
# of the main example's hidden state is replaced by the dual example's value.
intervene_time = 1
intervene_attribute = 0  # column index into the hidden state; meaning of each column TODO confirm

In [476]:
# Advance the dual (source) example `intervene_time` steps; the resulting hidden states
# are the values swapped into the main example in the next cell.
# NOTE: this reassigns dual_high_hidden_states / dual_high_actions in place, so re-running
# this cell without re-running the batch cell advances them further.
for _ in range(intervene_time):
    dual_high_hidden_states, dual_high_actions = hi_model(
        hmm_states=dual_high_hidden_states,
        hmm_actions=dual_high_actions,
        tag="_hmm_step_fxn",
    )

In [477]:
# Roll the main example forward. At step `intervene_time`, swap in one hidden-state
# attribute from the dual example. Record the counterfactual actions as a token batch:
# SOS, generated actions, then EOS written just after each row's last non-padding token.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = 1, 2, 0
train_max_decoding_steps = 20
n_examples = high_hidden_states.size(0)

# Clone so the in-place attribute swap below can never write through to
# `high_hidden_states`. Without it, intervene_time == 0 (or a model step that returns
# its input) would corrupt the base states for later re-runs.
cf_high_hidden_states = high_hidden_states.clone()
cf_high_actions = high_actions
intervened_target_batch = [torch.full((n_examples, 1), SOS_TOKEN, dtype=torch.long)]
intervened_target_lengths_batch = torch.zeros(n_examples, 1).long()
# Two of the decoding positions are reserved for the SOS and EOS tokens.
for j in range(train_max_decoding_steps - 2):
    if j == intervene_time:
        # Only swap out this attribute; all other columns keep the main example's values.
        cf_high_hidden_states[:, intervene_attribute] = dual_high_hidden_states[:, intervene_attribute]
    cf_high_hidden_states, cf_high_actions = hi_model(
        hmm_states=cf_high_hidden_states,
        hmm_actions=cf_high_actions,
        tag="_hmm_step_fxn",
    )
    # Record the output for loss calculation.
    intervened_target_batch.append(cf_high_actions)
    intervened_target_lengths_batch += (cf_high_actions != 0).long()
# Extra padding column so a maximal-length sequence still has a slot for its EOS.
intervened_target_batch.append(torch.full((n_examples, 1), PAD_TOKEN, dtype=torch.long))
intervened_target_lengths_batch += 2  # count SOS + EOS
intervened_target_batch = torch.cat(intervened_target_batch, dim=-1)
# Write EOS at the last position of every row (vectorized form of a per-row loop).
intervened_target_batch[torch.arange(n_examples), intervened_target_lengths_batch[:, 0] - 1] = EOS_TOKEN

In [478]:
# Per example: how many positions (over the original target's width) differ between
# the counterfactual token sequence and the original target sequence.
intervened_target_batch[:, : target_batch.size(1)].ne(target_batch).sum(dim=1)

tensor([7, 7, 4, 6, 3, 6, 5, 5, 6, 6, 7, 4, 3, 6, 0, 3, 4, 0, 6, 3, 6, 7, 3, 4,
        6, 4, 3, 3, 5, 3, 0, 5, 3, 5, 7, 7, 0, 4, 4, 5, 3, 3, 3, 0, 4, 3, 4, 7,
        5, 4])

In [473]:
# Show the counterfactual token sequences truncated to the original target width.
intervened_target_batch[:, : target_batch.shape[1]]

tensor([[1, 4, 3, 4, 4, 3, 4, 4, 4, 4, 4, 4, 2],
        [1, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 2],
        [1, 4, 4, 4, 5, 4, 4, 2, 0, 0, 0, 0, 0],
        [1, 4, 3, 4, 5, 4, 4, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 5, 4, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 4, 2],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 2, 0],
        [1, 4, 3, 4, 4, 4, 4, 4, 5, 4, 4, 2, 0],
        [1, 4, 5, 4, 4, 4, 4, 4, 3, 4, 4, 4, 2],
        [1, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4],
        [1, 4, 4, 3, 4, 4, 4, 4, 4, 2, 0, 0, 0],
        [1, 4, 5, 4, 4, 5, 4, 4, 4, 2, 0, 0, 0],
        [1, 4, 5, 4, 5, 4, 4, 2, 0, 0, 0, 0, 0],
        [1, 4, 4, 4, 4, 3, 4, 4, 4, 4, 2, 0, 0],
        [1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 4, 3, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 3, 4, 2, 0, 0, 0, 0],
        [1, 4, 5, 4, 4, 4, 5, 4, 2, 0, 0, 0, 0],
        [1, 4, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4],
        [1, 3, 5, 4, 4, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 3, 5, 4,

#### Hidden states of the high-level model: train split vs. the compositional-generalization split

In [508]:
# Train-split hidden states: flatten the agent and target grid positions
# (row * grid_size + column) of the first NUM training examples and keep their
# target command strings alongside.
NUM = 200
agent_positions_batch, target_positions_batch, target_commands = [], [], []
for ex in data_json["examples"]["train"]:
    target_commands.append(ex["target_commands"])
    situation_repr = ex['situation']
    width = int(situation_repr["grid_size"])
    agent_cell = situation_repr["agent_position"]
    target_cell = situation_repr["target_object"]["position"]
    agent_positions_batch.append(
        torch.tensor(int(agent_cell["row"]) * width + int(agent_cell["column"]), dtype=torch.long).unsqueeze(dim=0)
    )
    target_positions_batch.append(
        torch.tensor(int(target_cell["row"]) * width + int(target_cell["column"]), dtype=torch.long).unsqueeze(dim=0)
    )
    if len(agent_positions_batch) == NUM:
        break
agent_positions_batch = torch.stack(agent_positions_batch, dim=0)
target_positions_batch = torch.stack(target_positions_batch, dim=0)

In [509]:
# Fresh high-level model; encode the train-split positions collected above into
# initial hidden states.
hi_model = HighLevelModel()
hidden_states = hi_model(agent_positions_batch, target_positions_batch, tag="situation_encode")

In [516]:
# Display the train-split initial hidden states (one row per example).
hidden_states

tensor([[ 1,  0,  0],
        [ 2,  0,  0],
        [ 2,  0,  0],
        [ 2,  0,  0],
        [ 3,  0,  0],
        [ 3,  0,  0],
        [ 3,  0,  0],
        [ 3,  0,  0],
        [ 3,  0,  0],
        [ 4,  0,  0],
        [ 4,  0,  0],
        [ 4,  0,  0],
        [ 4,  0,  0],
        [ 4,  0,  0],
        [ 5,  0,  0],
        [-1,  1,  0],
        [-1,  1,  0],
        [-1,  1,  0],
        [-1,  1,  0],
        [-2,  1,  0],
        [-1,  2,  0],
        [-2,  1,  0],
        [-2,  1,  0],
        [-2,  2,  0],
        [-1,  3,  0],
        [-2,  2,  0],
        [-3,  1,  0],
        [-3,  2,  0],
        [-1,  4,  0],
        [-4,  1,  0],
        [-2,  3,  0],
        [-4,  2,  0],
        [-3,  3,  0],
        [-3,  3,  0],
        [-2,  4,  0],
        [-5,  2,  0],
        [-2,  5,  0],
        [-2,  5,  0],
        [-5,  2,  0],
        [-2,  5,  0],
        [-5,  3,  0],
        [-3,  5,  0],
        [-3,  5,  0],
        [-4,  4,  0],
        [-4,  5,  0],
        [-

In [491]:
# Compositional-generalization (situational_1) split: flattened agent/target grid
# positions and target commands for the first NUM examples, kept under cg_* names.
NUM = 200


def _flat_grid_index(position, grid_size):
    """Flatten a {'row', 'column'} position dict into row * grid_size + column."""
    return int(position["row"]) * grid_size + int(position["column"])


cg_agent_positions_batch, cg_target_positions_batch, cg_target_commands = [], [], []
for cg_example in data_json["examples"]["situational_1"]:
    cg_target_commands.append(cg_example["target_commands"])
    situation_repr = cg_example['situation']
    cg_grid_size = int(situation_repr["grid_size"])
    agent_index = _flat_grid_index(situation_repr["agent_position"], cg_grid_size)
    cg_agent_positions_batch.append(torch.tensor(agent_index, dtype=torch.long).unsqueeze(dim=0))
    target_index = _flat_grid_index(situation_repr["target_object"]["position"], cg_grid_size)
    cg_target_positions_batch.append(torch.tensor(target_index, dtype=torch.long).unsqueeze(dim=0))
    if len(cg_agent_positions_batch) == NUM:
        break
cg_agent_positions_batch = torch.stack(cg_agent_positions_batch, dim=0)
cg_target_positions_batch = torch.stack(cg_target_positions_batch, dim=0)

In [494]:
# Fresh high-level model; encode the situational_1 (compositional split) positions
# collected above into initial hidden states.
hi_model = HighLevelModel()
hidden_states = hi_model(cg_agent_positions_batch, cg_target_positions_batch, tag="situation_encode")

In [496]:
# Initial hidden state of the first situational_1 example.
hidden_states[0]

tensor([ 1, -1,  0])

In [528]:
# Raw JSON record of one situational_1 example, to relate its situation to the hidden states above.
data_json["examples"]["situational_1"][10]

{'command': 'walk,to,a,yellow,small,cylinder',
 'meaning': 'walk,to,a,yellow,small,cylinder',
 'derivation': "NP -> NN,NP -> JJ NP,NP -> JJ NP,DP -> 'a' NP,VP -> VV_intrans 'to' DP,ROOT -> VP;T:walk,NT:VV_intransitive -> walk,T:to,T:a,T:yellow,NT:JJ -> small:JJ -> yellow,T:small,T:cylinder,NT:NN -> cylinder",
 'situation': {'grid_size': 6,
  'agent_position': {'row': '2', 'column': '3'},
  'agent_direction': 0,
  'target_object': {'vector': '10000101000',
   'position': {'row': '5', 'column': '2'},
   'object': {'shape': 'cylinder', 'color': 'yellow', 'size': '1'}},
  'distance_to_target': '4',
  'direction_to_target': 'sw',
  'placed_objects': {'0': {'vector': '10000101000',
    'position': {'row': '5', 'column': '2'},
    'object': {'shape': 'cylinder', 'color': 'yellow', 'size': '1'}},
   '1': {'vector': '01000101000',
    'position': {'row': '3', 'column': '5'},
    'object': {'shape': 'cylinder', 'color': 'yellow', 'size': '2'}},
   '2': {'vector': '00010010010',
    'position': {

In [497]:
# For situational_1 examples the first hidden-state component is expected to be positive and
# the second negative (e.g. hidden_states[0] above); the counting cell below treats
# (first > 0, second < 0, third == 0) as this region.

Let us see whether our interventions produce any examples similar to the one above.

In [520]:
# Reload the raw splits, closing the file once parsed. Then build a ReaSCAN dataset over
# the first 1000 "train" examples, loading the existing vocabulary files
# (generate_vocabulary=False).
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)
training_set = ReaSCANDataset(
    data_json,
    "../../../data-files/gSCAN-Simple/", split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False, k=0
)
training_set.read_dataset(
    max_examples=1000,
    simple_situation_representation=False
)

2021-08-23 17:44 Formulating the dataset from the passed in json file...
2021-08-23 17:44 Loading vocabularies...
2021-08-23 17:44 Done loading vocabularies.
2021-08-23 17:44 Converting dataset to tensors...


In [521]:
# Pair every training example with a dual (source) example and batch them in random order.
train_data, _ = training_set.get_dual_dataset()
# Seed the sampler's generator so the shuffled batch order, and therefore the cg_count
# statistics below, is reproducible across kernel restarts.
train_sampler = RandomSampler(train_data, generator=torch.Generator().manual_seed(42))
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=50)

In [522]:
# Inspect the initial high-level hidden states for the first shuffled batch.
# `next(iter(...))` replaces the for/print/break pattern; the final expression displays the result.
batch = next(iter(train_dataloader))
# Main batch followed by its dual (source) batch, as produced by get_dual_dataset().
(
    input_batch, target_batch, situation_batch,
    agent_positions_batch, target_positions_batch,
    input_lengths_batch, target_lengths_batch,
    dual_input_batch, dual_target_batch, dual_situation_batch,
    dual_agent_positions_batch, dual_target_positions_batch,
    dual_input_lengths_batch, dual_target_lengths_batch,
) = batch

high_hidden_states = hi_model(
    agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1),
    target_positions_batch=target_positions_batch.unsqueeze(dim=-1),
    tag="situation_encode",
)
high_actions = torch.zeros(high_hidden_states.size(0), 1).long()

high_hidden_states

tensor([[ 5,  4,  0],
        [ 2,  0,  0],
        [ 3,  5,  0],
        [ 1,  5,  0],
        [ 5,  5,  0],
        [-1,  4,  0],
        [-1,  3,  0],
        [ 0,  2,  0],
        [ 4,  0,  0],
        [-5, -2,  0],
        [-1,  3,  0],
        [ 2,  1,  0],
        [ 0,  2,  0],
        [-3,  5,  0],
        [-5,  0,  0],
        [ 4,  5,  0],
        [ 0, -2,  0],
        [ 3,  0,  0],
        [-1,  4,  0],
        [-3,  1,  0],
        [ 0, -2,  0],
        [-3,  0,  0],
        [ 2,  0,  0],
        [ 0, -2,  0],
        [-2, -2,  0],
        [-1, -4,  0],
        [ 5,  0,  0],
        [-4,  4,  0],
        [-1, -1,  0],
        [-2,  0,  0],
        [-3,  2,  0],
        [ 0,  4,  0],
        [ 4,  2,  0],
        [-4,  0,  0],
        [ 4,  0,  0],
        [-3,  0,  0],
        [-5, -4,  0],
        [-1,  5,  0],
        [ 2,  0,  0],
        [ 4,  0,  0],
        [ 0,  2,  0],
        [-5,  0,  0],
        [-1,  0,  0],
        [-5, -2,  0],
        [-3, -4,  0],
        [-

In [526]:
# For every shuffled batch: advance both the main and dual examples a random number of
# steps, swap one hidden-state attribute from the dual into the main example, and count
# how many counterfactual states land in the region seen for the situational_1 split
# (first component > 0, second < 0, third == 0).
intervene_attribute = 1
train_max_decoding_steps = 20
# Seeded so the sampled intervention times are reproducible across runs.
intervention_rng = random.Random(42)

for step, batch in enumerate(train_dataloader):
    # Main batch followed by its dual (source) batch.
    (
        input_batch, target_batch, situation_batch,
        agent_positions_batch, target_positions_batch,
        input_lengths_batch, target_lengths_batch,
        dual_input_batch, dual_target_batch, dual_situation_batch,
        dual_agent_positions_batch, dual_target_positions_batch,
        dual_input_lengths_batch, dual_target_lengths_batch,
    ) = batch

    high_hidden_states = hi_model(
        agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1),
        target_positions_batch=target_positions_batch.unsqueeze(dim=-1),
        tag="situation_encode",
    )
    high_actions = torch.zeros(high_hidden_states.size(0), 1).long()

    dual_high_hidden_states = hi_model(
        agent_positions_batch=dual_agent_positions_batch.unsqueeze(dim=-1),
        target_positions_batch=dual_target_positions_batch.unsqueeze(dim=-1),
        tag="situation_encode",
    )
    dual_high_actions = torch.zeros(dual_high_hidden_states.size(0), 1).long()

    intervene_time = intervention_rng.choice([1, 2, 3])
    # The swap must happen within the decoding horizon.
    assert intervene_time < train_max_decoding_steps - 1

    # Get the intercepted dual hidden states.
    for _ in range(intervene_time):
        dual_high_hidden_states, dual_high_actions = hi_model(
            hmm_states=dual_high_hidden_states,
            hmm_actions=dual_high_actions,
            tag="_hmm_step_fxn",
        )

    # Advance the main example by the same number of steps. Start from a clone so the
    # in-place swap below cannot write through to `high_hidden_states`.
    cf_high_hidden_states = high_hidden_states.clone()
    cf_high_actions = high_actions
    for _ in range(intervene_time):
        cf_high_hidden_states, cf_high_actions = hi_model(
            hmm_states=cf_high_hidden_states,
            hmm_actions=cf_high_actions,
            tag="_hmm_step_fxn",
        )

    # Only swap out this attribute.
    cf_high_hidden_states[:, intervene_attribute] = dual_high_hidden_states[:, intervene_attribute]
    print(cf_high_hidden_states)

    # Vectorized count of counterfactual states in the situational_1-like region.
    in_cg_region = (
        (cf_high_hidden_states[:, 0] > 0)
        & (cf_high_hidden_states[:, 1] < 0)
        & (cf_high_hidden_states[:, 2] == 0)
    )
    cg_count = int(in_cg_region.sum())

    print(f"cg_count: {cg_count}/{input_batch.size(0)}")

tensor([[ 2,  0,  0],
        [ 4,  0,  2],
        [-4,  0,  0],
        [ 0, -5,  3],
        [-2,  0,  0],
        [ 3, -1,  0],
        [ 0,  4,  3],
        [-5,  2,  3],
        [ 2, -3,  2],
        [ 3,  3,  0],
        [-4,  4,  0],
        [ 4,  0,  2],
        [-5, -1,  3],
        [ 5,  0,  0],
        [ 3,  2,  0],
        [-5, -1,  3],
        [-2,  0,  0],
        [-1, -1,  3],
        [ 2, -1,  0],
        [-4,  0,  0],
        [ 1,  0,  0],
        [-1,  3,  0],
        [ 4,  4,  0],
        [ 0, -3,  0],
        [-5,  0,  3],
        [-1,  2,  0],
        [-3,  0,  3],
        [-3,  0,  3],
        [ 4, -4,  2],
        [-4, -4,  0],
        [ 3,  1,  0],
        [ 1,  0,  0],
        [ 0,  2,  0],
        [-3,  4,  0],
        [ 4, -2,  0],
        [-1, -1,  3],
        [-4,  3,  0],
        [ 2,  1,  2],
        [-3,  0,  3],
        [ 5, -2,  0],
        [ 1,  0,  0],
        [ 4, -4,  2],
        [ 4,  0,  2],
        [ 0,  4,  0],
        [-3,  0,  3],
        [-

tensor([[-4, -4,  0],
        [ 1,  5,  0],
        [-5,  3,  0],
        [-1, -4,  0],
        [ 2,  4,  0],
        [ 0, -3,  0],
        [-1,  1,  0],
        [-2,  2,  0],
        [-1,  2,  0],
        [-1,  3,  0],
        [-4,  2,  0],
        [-2,  0,  0],
        [ 0, -2,  0],
        [-3,  2,  0],
        [-1,  1,  0],
        [ 3,  4,  0],
        [ 5, -3,  0],
        [ 3,  1,  0],
        [-1, -5,  0],
        [ 1,  0,  0],
        [-3,  3,  0],
        [-1,  2,  0],
        [-5,  3,  0],
        [ 0,  1,  0],
        [ 4,  4,  0],
        [-1,  0,  0],
        [ 5,  5,  0],
        [-3,  1,  0],
        [-4,  1,  0],
        [-2,  0,  0],
        [-1,  0,  0],
        [-5, -4,  0],
        [ 0,  4,  0],
        [-1, -4,  0],
        [ 4,  0,  0],
        [-1,  4,  0],
        [ 0, -5,  0],
        [ 0, -3,  0],
        [ 2,  5,  0],
        [-2,  3,  0],
        [-2,  0,  0],
        [ 5,  1,  0],
        [ 5,  1,  0],
        [-5, -3,  0],
        [-4, -3,  0],
        [-

tensor([[ 0,  2,  3],
        [-4, -5,  3],
        [-4,  4,  3],
        [-1,  4,  3],
        [-1,  0,  3],
        [ 4,  0,  2],
        [-2,  1,  0],
        [-5,  2,  3],
        [ 4,  0,  0],
        [-5,  2,  3],
        [ 1, -3,  2],
        [-2,  0,  0],
        [ 0, -4,  0],
        [-2,  2,  3],
        [ 2, -2,  0],
        [-2,  0,  3],
        [-4,  4,  0],
        [ 0, -3,  0],
        [ 3,  2,  2],
        [ 5, -5,  0],
        [-1, -5,  3],
        [-2,  3,  0],
        [ 0,  0,  3],
        [ 0,  1,  3],
        [-4,  1,  0],
        [-2, -4,  0],
        [-5, -1,  3],
        [-1,  2,  3],
        [ 2,  4,  0],
        [-5,  1,  0],
        [-1,  1,  3],
        [ 5,  3,  0],
        [-2, -3,  0],
        [-5, -2,  3],
        [-2,  2,  0],
        [-1,  3,  0],
        [ 2, -4,  0],
        [-3,  4,  3],
        [-1,  0,  0],
        [-1,  0,  3],
        [-2, -4,  3],
        [-3,  0,  0],
        [-4,  0,  3],
        [-1,  1,  3],
        [ 0,  4,  0],
        [-