In [3]:
import argparse
import logging
import os
import torch
import logging
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
import time
import random
import torch.nn.functional as F

from decode_abstract_models import *
from seq2seq.ReaSCAN_dataset import *
from seq2seq.helpers import *
from torch.optim.lr_scheduler import LambdaLR

def isnotebook():
    """Report whether the code is running inside a Jupyter kernel.

    Returns True only when the active IPython shell is a
    ``ZMQInteractiveShell`` (Jupyter notebook / qtconsole). A terminal
    IPython shell, any other shell type, or a plain Python interpreter
    (where ``get_ipython`` is not defined) all yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython only exists when running under IPython.
        return False
    return shell_name == 'ZMQInteractiveShell'

In [None]:
path_to_data = "../../../data-files/gSCAN-Simple/data-compositional-splits.txt"
# Parse the split file inside a context manager so the file handle is closed
# once parsing finishes (json.load(open(...)) leaks the handle).
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)

In [None]:
# For the first NUM "situational_1" examples, collect the gold command
# sequences and the flattened cell index (row * grid_size + column) of the
# agent and of the target object. Each position ends up as a (1,) long tensor,
# and the stacked result has shape (num_examples, 1).
NUM = 200
target_commands = []
agent_positions_batch = []
target_positions_batch = []
for sit_example in data_json["examples"]["situational_1"][:NUM]:
    sit_repr = sit_example['situation']
    sit_grid_w = int(sit_repr["grid_size"])
    sit_agent_rc = sit_repr["agent_position"]
    sit_target_rc = sit_repr["target_object"]["position"]
    target_commands.append(sit_example["target_commands"])
    agent_positions_batch.append(
        torch.tensor(
            int(sit_agent_rc["row"]) * sit_grid_w + int(sit_agent_rc["column"]),
            dtype=torch.long,
        ).unsqueeze(dim=0)
    )
    target_positions_batch.append(
        torch.tensor(
            int(sit_target_rc["row"]) * sit_grid_w + int(sit_target_rc["column"]),
            dtype=torch.long,
        ).unsqueeze(dim=0)
    )
agent_positions_batch = torch.stack(agent_positions_batch, dim=0)
target_positions_batch = torch.stack(target_positions_batch, dim=0)

In [None]:
# Instantiate the high-level model. HighLevelModel comes from one of the
# star imports at the top (presumably decode_abstract_models — confirm).
hi_model = HighLevelModel()

In [None]:
# Encode each (agent, target) position pair into the model's initial hidden
# state; every sequence starts from action id 0.
hidden_states = hi_model(
    agent_positions_batch,
    target_positions_batch,
    tag="situation_encode",
)
actions = torch.zeros(hidden_states.size(0), 1, dtype=torch.long)

In [None]:
# Step the model, recording each step's emitted actions and counting, per
# example, how many of them are non-zero.
# NOTE(review): only a single step is taken (range(1)), while the check further
# down compares against full gold command sequences — confirm the step count.
actions_sequence = []
actions_length = torch.zeros(hidden_states.size(0), 1).long()
for _step in range(1):
    hidden_states, actions = hi_model(
        hmm_states=hidden_states,
        hmm_actions=actions,
        tag="_hmm_step_fxn",
    )
    actions_sequence.append(actions)
    actions_length += actions.ne(0).long()

In [None]:
# One-hot encode hidden-state components 0 and 1 over 2*grid_size-1 bins.
# Presumably these components are signed offsets in [-(grid_size-1), grid_size-1]
# (implied by the bin count and the shift below) — confirm against HighLevelModel.
grid_size = 6
num_offset_bins = grid_size * 2 - 1
x_target = torch.zeros(hidden_states.shape[0], num_offset_bins).long()
y_target = torch.zeros(hidden_states.shape[0], num_offset_bins).long()
# Shift by grid_size - 1 so offsets become valid column indices. This was a
# hard-coded 5, which is only correct when grid_size == 6.
indices = hidden_states + (grid_size - 1)
row_index = torch.arange(hidden_states.shape[0])
x_target[row_index, indices[:, 0]] = 1
y_target[row_index, indices[:, 1]] = 1

In [None]:
# Concatenate the per-step action tensors along the last dim, giving one row of
# action ids per example.
# NOTE(review): this rebinds the list to a tensor, so the cell is not idempotent —
# re-run the rollout cell above before re-running this one.
actions_sequence = torch.cat(actions_sequence, dim=-1)

In [None]:
# Decode each example's leading actions_length[row] action ids back into a
# command sequence and require an exact match with the gold commands.
for row_idx in range(actions_sequence.size(0)):
    n_actions = actions_length[row_idx]
    decoded_commands = hi_model.actions_list_to_sequence(
        actions_sequence[row_idx, :n_actions].tolist()
    )
    assert decoded_commands == target_commands[row_idx]

#### Try some interventions

In [None]:
# Re-read the split file (closing the handle afterwards) and build the training
# split, reading at most 100 examples.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)
training_set = ReaSCANDataset(
    data_json, 
    "../../../data-files/gSCAN-Simple/", split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False, k=0
)
training_set.read_dataset(
    max_examples=100,
    simple_situation_representation=False
)

In [None]:
# Wrap the paired (main, dual) training data in a shuffled loader of 50-example batches.
train_data, _ = training_set.get_dual_dataset()
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=50,
    sampler=train_sampler,
)

In [None]:
# Fresh high-level model for the intervention experiments below.
hi_model = HighLevelModel()

In [None]:
# Shuffle the dataset and loop over it.
for step, batch in enumerate(train_dataloader):
    # main batch
    input_batch, target_batch, situation_batch, \
        agent_positions_batch, target_positions_batch, \
        input_lengths_batch, target_lengths_batch, \
        dual_input_batch, dual_target_batch, dual_situation_batch, \
        dual_agent_positions_batch, dual_target_positions_batch, \
        dual_input_lengths_batch, dual_target_lengths_batch = batch

    high_hidden_states = hi_model(
        agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1), 
        target_positions_batch=target_positions_batch.unsqueeze(dim=-1), 
        tag="situation_encode"
    )
    high_actions = torch.zeros(
        high_hidden_states.size(0), 1
    ).long()

    dual_high_hidden_states = hi_model(
        agent_positions_batch=dual_agent_positions_batch.unsqueeze(dim=-1), 
        target_positions_batch=dual_target_positions_batch.unsqueeze(dim=-1), 
        tag="situation_encode"
    )
    dual_high_actions = torch.zeros(
        dual_high_hidden_states.size(0), 1
    ).long()

    break # just steal one batch

In [None]:
# intervene_time: decoding step at which the swap is applied.
# intervene_attribute: hidden-state column that is swapped in from the dual batch.
intervene_time, intervene_attribute = 1, 0

In [None]:
# get the intercepted dual hidden states.
for j in range(intervene_time):
    dual_high_hidden_states, dual_high_actions = hi_model(
        hmm_states=dual_high_hidden_states, 
        hmm_actions=dual_high_actions, 
        tag="_hmm_step_fxn"
    )

In [None]:
# Roll the high-level model forward from the main batch's encoded state. At step
# `intervene_time`, column `intervene_attribute` of the hidden state is replaced
# with the dual batch's value. The emitted actions are assembled into a
# counterfactual target: [1 (SOS), actions..., token 2 at the end position, padding].
train_max_decoding_steps = 20
n_examples = high_hidden_states.size(0)
# Clone so the in-place column swap below can never write through to
# high_hidden_states / high_actions. With intervene_time == 0, the swap happens
# before any model step, while cf_* would otherwise still alias them.
cf_high_hidden_states = high_hidden_states.clone()
cf_high_actions = high_actions.clone()
intervened_target_batch = [torch.ones(n_examples, 1).long()]  # SOS tokens
intervened_target_lengths_batch = torch.zeros(n_examples, 1).long()
# We need to take care of the SOS and EOS tokens, hence the -2.
for j in range(train_max_decoding_steps - 2):
    # intercept like antra!
    if j == intervene_time:
        # only swap out this part.
        cf_high_hidden_states[:, intervene_attribute] = dual_high_hidden_states[:, intervene_attribute]
        # For testing, replace the whole state instead:
        # cf_high_hidden_states = dual_high_hidden_states
        # cf_high_actions = dual_high_actions
    cf_high_hidden_states, cf_high_actions = hi_model(
        hmm_states=cf_high_hidden_states, 
        hmm_actions=cf_high_actions, 
        tag="_hmm_step_fxn"
    )
    # record the output for loss calculation.
    intervened_target_batch.append(cf_high_actions)
    intervened_target_lengths_batch += (cf_high_actions != 0).long()
intervened_target_batch.append(torch.zeros(n_examples, 1).long())  # pad for extra eos
intervened_target_lengths_batch += 2  # count SOS and EOS
intervened_target_batch = torch.cat(intervened_target_batch, dim=-1)
# Write token id 2 (presumably EOS) at each example's last position (length - 1).
intervened_target_batch[torch.arange(n_examples), intervened_target_lengths_batch[:, 0] - 1] = 2

In [None]:
# Per-example number of positions where the intervened sequence differs from
# the original target sequence, over the original target's width.
target_width = target_batch.size(1)
intervened_target_batch[:, :target_width].ne(target_batch).sum(dim=1)

In [None]:
# Show the intervened target sequences, truncated to the original target width.
intervened_target_batch[:,:target_batch.size(1)]

#### Hidden states of the high-level model on the compositional-generalization split

In [None]:
# train hidden states
NUM = 200
agent_positions_batch = []
target_positions_batch = []
target_commands = []
for ex in data_json["examples"]["train"]:
    target_commands += [ex["target_commands"]]
    situation_repr = ex['situation']
    agent = torch.tensor(
        (int(situation_repr["agent_position"]["row"]) * int(situation_repr["grid_size"])) +
        int(situation_repr["agent_position"]["column"]), dtype=torch.long).unsqueeze(dim=0)
    target = torch.tensor(
        (int(situation_repr["target_object"]["position"]["row"]) * int(situation_repr["grid_size"])) +
        int(situation_repr["target_object"]["position"]["column"]), dtype=torch.long).unsqueeze(dim=0)
    agent_positions_batch.append(agent)
    target_positions_batch.append(target)
    if len(agent_positions_batch) == NUM:
        break
agent_positions_batch = torch.stack(agent_positions_batch, dim=0)
target_positions_batch = torch.stack(target_positions_batch, dim=0)

In [None]:
# Fresh model; encode the training-split situations.
hi_model = HighLevelModel()
hidden_states = hi_model(
    agent_positions_batch,
    target_positions_batch,
    tag="situation_encode",
)

In [None]:
# Display the encoded situation states for the training-split examples.
hidden_states

In [None]:
# Compositional-split ("situational_1") positions: for the first NUM examples,
# collect the gold command sequences and the flattened (row * grid_size + column)
# cell index of the agent and of the target object, stacked to shape (num_examples, 1).
NUM = 200
cg_target_commands = []
cg_agent_positions_batch = []
cg_target_positions_batch = []
for cg_example in data_json["examples"]["situational_1"][:NUM]:
    cg_repr = cg_example['situation']
    cg_grid_w = int(cg_repr["grid_size"])
    cg_agent_rc = cg_repr["agent_position"]
    cg_target_rc = cg_repr["target_object"]["position"]
    cg_target_commands.append(cg_example["target_commands"])
    cg_agent_positions_batch.append(
        torch.tensor(
            int(cg_agent_rc["row"]) * cg_grid_w + int(cg_agent_rc["column"]),
            dtype=torch.long,
        ).unsqueeze(dim=0)
    )
    cg_target_positions_batch.append(
        torch.tensor(
            int(cg_target_rc["row"]) * cg_grid_w + int(cg_target_rc["column"]),
            dtype=torch.long,
        ).unsqueeze(dim=0)
    )
cg_agent_positions_batch = torch.stack(cg_agent_positions_batch, dim=0)
cg_target_positions_batch = torch.stack(cg_target_positions_batch, dim=0)

In [None]:
# Fresh model; encode the compositional-split situations.
hi_model = HighLevelModel()
hidden_states = hi_model(
    cg_agent_positions_batch,
    cg_target_positions_batch,
    tag="situation_encode",
)

In [None]:
# Inspect the encoded state of the first compositional-split example.
hidden_states[0]

In [None]:
# Raw JSON of one compositional-split ("situational_1") example, for comparison
# with its encoded hidden state.
data_json["examples"]["situational_1"][10]

In [None]:
# the first should be positive and the second should be negative.
# Presumably "first"/"second" are hidden-state components 0 and 1; this matches the
# [0] > 0 and [1] < 0 condition used for cg_count in the intervention loop.

Let us see whether the intervention produces any examples similar to the one above.

In [None]:
# Re-read the split file (closing the handle afterwards) and build the training
# split, reading at most 1000 examples.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)
training_set = ReaSCANDataset(
    data_json, 
    "../../../data-files/gSCAN-Simple/", split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False, k=0
)
training_set.read_dataset(
    max_examples=1000,
    simple_situation_representation=False
)

In [None]:
# Wrap the paired (main, dual) training data in a shuffled loader of 50-example batches.
train_data, _ = training_set.get_dual_dataset()
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=50,
    sampler=train_sampler,
)

In [None]:
# Encode the main situations of one randomly sampled batch and display them.
# next(iter(...)) replaces the earlier for/break loop: an empty loader now raises
# instead of silently leaving batch variables from earlier cells in place.
batch = next(iter(train_dataloader), None)
if batch is None:
    raise RuntimeError("train_dataloader yielded no batches")

# main batch followed by its dual (counterfactual source) batch
input_batch, target_batch, situation_batch, \
    agent_positions_batch, target_positions_batch, \
    input_lengths_batch, target_lengths_batch, \
    dual_input_batch, dual_target_batch, dual_situation_batch, \
    dual_agent_positions_batch, dual_target_positions_batch, \
    dual_input_lengths_batch, dual_target_lengths_batch = batch

high_hidden_states = hi_model(
    agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1), 
    target_positions_batch=target_positions_batch.unsqueeze(dim=-1), 
    tag="situation_encode"
)
high_actions = torch.zeros(
    high_hidden_states.size(0), 1
).long()

high_hidden_states

In [None]:
# For every batch, encode the main and dual situations and advance the dual
# rollout by a random 1-3 steps. Then advance the main rollout the same number
# of steps and swap in column `intervene_attribute` from the dual state.
# Finally, count intervened states with component 0 > 0, component 1 < 0 and
# component 2 == 0.
train_max_decoding_steps = 20  # loop-invariant: hoisted out of the batch loop
intervene_attribute = 1
for step, batch in enumerate(train_dataloader):
    # main batch followed by its dual (counterfactual source) batch
    input_batch, target_batch, situation_batch, \
        agent_positions_batch, target_positions_batch, \
        input_lengths_batch, target_lengths_batch, \
        dual_input_batch, dual_target_batch, dual_situation_batch, \
        dual_agent_positions_batch, dual_target_positions_batch, \
        dual_input_lengths_batch, dual_target_lengths_batch = batch

    high_hidden_states = hi_model(
        agent_positions_batch=agent_positions_batch.unsqueeze(dim=-1), 
        target_positions_batch=target_positions_batch.unsqueeze(dim=-1), 
        tag="situation_encode"
    )
    high_actions = torch.zeros(
        high_hidden_states.size(0), 1
    ).long()

    dual_high_hidden_states = hi_model(
        agent_positions_batch=dual_agent_positions_batch.unsqueeze(dim=-1), 
        target_positions_batch=dual_target_positions_batch.unsqueeze(dim=-1), 
        tag="situation_encode"
    )
    dual_high_actions = torch.zeros(
        dual_high_hidden_states.size(0), 1
    ).long()

    intervene_time = random.choice([1, 2, 3])

    # get the intercepted dual hidden states.
    for _ in range(intervene_time):
        dual_high_hidden_states, dual_high_actions = hi_model(
            hmm_states=dual_high_hidden_states, 
            hmm_actions=dual_high_actions, 
            tag="_hmm_step_fxn"
        )

    # main intervene loop. Clone so the in-place column swap can never write
    # through to high_hidden_states / high_actions.
    cf_high_hidden_states = high_hidden_states.clone()
    cf_high_actions = high_actions.clone()
    # We need to take care of the SOS and EOS tokens.
    for j in range(train_max_decoding_steps - 1):
        # intercept like antra!
        if j == intervene_time:
            # only swap out this part.
            cf_high_hidden_states[:, intervene_attribute] = dual_high_hidden_states[:, intervene_attribute]
            print(cf_high_hidden_states)
            break
        cf_high_hidden_states, cf_high_actions = hi_model(
            hmm_states=cf_high_hidden_states, 
            hmm_actions=cf_high_actions, 
            tag="_hmm_step_fxn"
        )

    # Vectorized replacement for the per-row Python loop.
    is_cg_like = (
        (cf_high_hidden_states[:input_batch.size(0), 0] > 0)
        & (cf_high_hidden_states[:input_batch.size(0), 1] < 0)
        & (cf_high_hidden_states[:input_batch.size(0), 2] == 0)
    )
    cg_count = int(is_cg_like.sum().item())

    print(f"cg_count: {cg_count}/{input_batch.size(0)}")

The following sections set up counterfactual training for the novel-attribute split.

In [24]:
path_to_data = "../../../data-files/ReaSCAN-novel-attribute/data-compositional-splits.txt"
# Parse the split file inside a context manager so the file handle is closed.
with open(path_to_data, "r") as data_file:
    data_json = json.load(data_file)

In [81]:
# Build the novel-attribute training split, reading at most 1000 examples.
training_set = ReaSCANDataset(
    data_json,
    "../../../data-files/ReaSCAN-novel-attribute/",
    split="train",
    input_vocabulary_file="input_vocabulary.txt",
    target_vocabulary_file="target_vocabulary.txt",
    generate_vocabulary=False,
    k=0,
)
training_set.read_dataset(max_examples=1000, simple_situation_representation=False)

2021-09-20 00:52 Formulating the dataset from the passed in json file...
2021-09-20 00:52 Loading vocabularies...
2021-09-20 00:52 Done loading vocabularies.
2021-09-20 00:52 Converting dataset to tensors...


In [82]:
# Paired (main, dual) data for the novel-attribute split, in a shuffled loader
# of 50-example batches.
train_data, _ = training_set.get_dual_dataset(novel_attribute=True)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=50,
    sampler=train_sampler,
)

In [76]:
# Iterate over the novel-attribute dual batches and unpack their fields.
# NOTE(review): this unpacking raises "ValueError: too many values to unpack
# (expected 14)" (see the cell output), so batches from
# get_dual_dataset(novel_attribute=True) carry more than 14 tensors.
# TODO: check what get_dual_dataset returns for novel_attribute=True and extend
# the unpacking to match its field order.
for step, batch in enumerate(train_dataloader):
    input_sequence, target_sequence, situation, \
        agent_positions, target_positions, \
        input_lengths, target_lengths, \
        dual_input_sequence, dual_target_sequence, dual_situation, \
        dual_agent_positions, dual_target_positions, \
        dual_input_lengths, dual_target_lengths = batch

ValueError: too many values to unpack (expected 14)