In [1]:

# Import common dependencies
import pandas as pd  # noqa
import numpy as np
import matplotlib  # noqa
import matplotlib.pyplot as plt
import datetime  # noqa
import PIL  # noqa
import glob  # noqa
import pickle  # noqa
from pathlib import Path  # noqa
from scipy import misc  # noqa
import sys
import tensorflow as tf
import pdb
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
TRADE_COST_FRAC = .003
EPSILON = 1e-10
ADV_MULT = 1e-3

In [2]:
uni_tokens = set()
uni_commands = set()
uni_actions = set()
fname = 'tasks_with_length_tags.txt'
with open(fname) as f:
    content = f.readlines()
content2 = [c.split(' ') for c in content]
# you may also want to remove whitespace characters like `\n` at the end of each line
commands = []
actions = []
content = [l.replace('\n', '') for l in content]
commands = [x.split(':::')[1].split(' ')[1:-1] for x in content]
actions = [x.split(':::')[2].split(' ')[1:-2] for x in content]
structures = [x.split(':::')[3].split(' ')[2:] for x in content]

structures = [[int(l) for l in program] for program in structures]
#actions = [[wd.replace('\n', '') for wd in res] for res in actions]

In [3]:
max_actions_per_subprogram = max([max([s for s in struct]) for struct in structures]) + 1
max_num_subprograms = max([len(s) for s in structures]) + 1
max_cmd_len = max([len(s) for s in commands]) + 1
max_act_len = max([len(a) for a in actions]) + 1
cmd_lengths_list = [len(s)+1 for s in commands]
cmd_lengths_np = np.array(cmd_lengths_list)
max_num_subprograms, max_cmd_len, max_act_len, max_actions_per_subprogram


(7, 10, 49, 9)

In [4]:
def build_fmap_invmap(unique, num_unique):
    fmap = dict(zip(unique, range(num_unique)))
    invmap = dict(zip(range(num_unique), unique))
    return fmap, invmap


In [5]:
for li in content2:
    for wd in li:
        uni_tokens.add(wd)
for li in commands:
    for wd in li:
        uni_commands.add(wd)
for li in actions:
    for wd in li:
        uni_actions.add(wd)
uni_commands.add('end_command')
uni_actions.add('end_subprogram')
uni_actions.add('end_action')
num_cmd = len(uni_commands)
num_act = len(uni_actions)
command_map, command_invmap = build_fmap_invmap(uni_commands, num_cmd)
action_map, action_invmap = build_fmap_invmap(uni_actions, num_act)

In [6]:


def dense_scaled(prev_layer, layer_size, name=None, reuse=False, scale=1.0):
    output = tf.layers.dense(prev_layer, layer_size, reuse=reuse) * scale
    return output


def dense_relu(dense_input, layer_size, scale=1.0):
    dense = dense_scaled(dense_input, layer_size, scale=scale)
    output = tf.nn.leaky_relu(dense)

    return output

def get_grad_norm(opt_fcn, loss):
    gvs = opt_fcn.compute_gradients(loss)
    grad_norm = tf.sqrt(tf.reduce_sum(
        [tf.reduce_sum(tf.square(grad)) for grad, var in gvs if grad is not None]))
    return grad_norm


def apply_clipped_optimizer(opt_fcn, loss, clip_norm=.1, clip_single=.03, clip_global_norm=False):
    gvs = opt_fcn.compute_gradients(loss)

    if clip_global_norm:
        gs, vs = zip(*[(g, v) for g, v in gvs if g is not None])
        capped_gs, grad_norm_total = tf.clip_by_global_norm([g for g in gs], clip_norm)
        capped_gvs = list(zip(capped_gs, vs))
    else:
        grad_norm_total = tf.sqrt(
            tf.reduce_sum([tf.reduce_sum(tf.square(grad)) for grad, var in gvs if grad is not None]))
        capped_gvs = [(tf.clip_by_value(grad, -1 * clip_single, clip_single), var)
                      for grad, var in gvs if grad is not None]
        capped_gvs = [(tf.clip_by_norm(grad, clip_norm), var) for grad, var in capped_gvs if grad is not None]

    optimizer = opt_fcn.apply_gradients(capped_gvs)

    return optimizer, grad_norm_total


def mlp(x, hidden_sizes, output_size=None, name='', reuse=False):
    prev_layer = x

    for idx, l in enumerate(hidden_sizes):
        dense = dense_scaled(prev_layer, l, name='mlp' + name + '_' + str(idx))
        prev_layer = tf.nn.leaky_relu(dense)

    output = prev_layer

    if output_size is not None:
        output = dense_scaled(prev_layer, output_size, name='mlp' + name + 'final')

    return output

def mlp_with_adversaries(x, hidden_sizes, output_size=None, name='', reuse=False):
    prev_layer = x
    adv_phs = []
    for idx, l in enumerate(hidden_sizes):
        
        adversary = tf.placeholder_with_default(tf.zeros_like(prev_layer), prev_layer.shape)
        prev_layer = prev_layer + adversary
        adv_phs.append(adversary)
        
        dense = dense_scaled(prev_layer, l, name='mlp' + name + '_' + str(idx))
        prev_layer = tf.nn.leaky_relu(dense)

    output = prev_layer

    if output_size is not None:
        output = dense_scaled(prev_layer, output_size, name='mlp' + name + 'final')

    return output, adv_phs



In [7]:

commands_ind = [[command_map[c] for c in cmd] + [0] * (max_cmd_len - len(cmd)) for cmd in commands]
actions_ind = [[action_map[a] for a in act] + [0] * (max_act_len - len(act)) for act in actions]
cmd_np = np.array(commands_ind)
actions_structured = []
mask_structured = []
for row in range(len(structures)):
    mask_row = []
    action_row = []
    act = actions_ind[row]
    struct = structures[row]
    start = 0
    for step in struct:
        end = start + step
        a = act[start:end]
        padding = max_actions_per_subprogram - step - 1
        action_row.append(a + [action_map['end_action']] + [0] * padding)
        start = end
    actions_structured.append(
        action_row + [[action_map['end_subprogram']] + [0] * (max_actions_per_subprogram - 1)] +
        [[0] * max_actions_per_subprogram] * (max_num_subprograms - len(struct) - 1)
    )
act_np = np.array(actions_structured)
struct_padded = [[sa + 1 for sa in s] + [1] + [0] * (max_num_subprograms - len(s) - 1) for s in structures]
struct_np = np.array(struct_padded)

mask_list = [[np.concatenate((np.ones(st), np.zeros(max_actions_per_subprogram - st)), 0) 
              for st in s] for s in struct_np]
mask_np = np.array(mask_list)

In [8]:
tf.reset_default_graph()
default_sizes = 128
size_emb = 64
num_layers_encoder = 6
hidden_filters = 128
num_layers_subprogram = 3
hidden_filters_subprogram = 128
init_mag = 1e-3
cmd_mat = tf.Variable(init_mag*tf.random_normal([num_cmd, size_emb]))
act_mat = tf.Variable(init_mag*tf.random_normal([num_act, size_emb]))
act_st_emb = tf.Variable(init_mag*tf.random_normal([size_emb]))
global_bs = None
global_time_len = 7
action_lengths = None
max_num_actions= None
# global_bs = 8
global_time_len = 7
max_num_actions = 9
output_keep_prob = tf.placeholder_with_default(1.0, ())
state_keep_prob = tf.placeholder_with_default(1.0, ())
cmd_ind = tf.placeholder(tf.int32, shape=(global_bs, 10,))
act_ind = tf.placeholder(tf.int32, shape=(global_bs, global_time_len, 9))
mask_ph = tf.placeholder(tf.float32, shape=(global_bs, global_time_len, 9))
cmd_lengths = tf.placeholder(tf.int32, shape=(global_bs,))
act_lengths = tf.placeholder(tf.int32, shape=(global_bs, 7))
learning_rate = tf.placeholder(tf.float32, shape = (None))

cmd_emb = tf.nn.embedding_lookup(cmd_mat, cmd_ind)
act_emb = tf.nn.embedding_lookup(act_mat, act_ind)
tf_bs = tf.shape(act_ind)[0]
act_st_emb_expanded = tf.tile(tf.reshape(
    act_st_emb, [1, 1, 1, size_emb]), [tf_bs, global_time_len, 1, 1])
act_emb_with_st = tf.concat((act_st_emb_expanded, act_emb), 2)

first_cell_encoder = [tf.nn.rnn_cell.LSTMCell(
    hidden_filters, forget_bias=1., name = 'layer1_'+d) for d in ['f', 'b']]
hidden_cells_encoder = [[tf.nn.rnn_cell.LSTMCell(
    hidden_filters,forget_bias=1., name = 'layer' + str(lidx) + '_' + d)  for d in ['f', 'b']]
                        for lidx in range(num_layers_encoder - 1)]
hidden_cells_encoder = [[tf.nn.rnn_cell.DropoutWrapper(cell,
    output_keep_prob=output_keep_prob, state_keep_prob=state_keep_prob,
    variational_recurrent=True, dtype=tf.float32) for cell in cells] 
                        for cells in hidden_cells_encoder[:-1]] + [hidden_cells_encoder[-1]]
cells_encoder = [first_cell_encoder] + hidden_cells_encoder
c1, c2 = zip(*cells_encoder)
cells_encoder = [c1, c2]
def encode(x, num_layers, cells, initial_states, lengths, name='',):
    prev_layer = x
    shortcut = x
    hiddenlayers = []
    returncells = []
    cell_fw, cell_bw = cells
    bs = tf_bs
    for idx in range(num_layers):
        prev_layer, c = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cell_fw[idx],
                cell_bw = cell_bw[idx],
                inputs = prev_layer,
                sequence_length=lengths,
                initial_state_fw=None,
                initial_state_bw=None,
                dtype=tf.float32,
                scope='encoder'+str(idx)
            )
        if idx == num_layers - 1:
            fw = prev_layer[0]
            bw = prev_layer[1]
            stacked = tf.stack([tf.range(bs), lengths - 1], 1)
            fw_final = tf.gather_nd(fw,stacked,name=None)
            bw_final = bw[:,0,:]
            output = tf.concat((fw_final, bw_final), 1)
        prev_layer = tf.concat(prev_layer, 2)
        prev_layer = tf.nn.leaky_relu(prev_layer)
        returncells.append(c)
        hiddenlayers.append(prev_layer)
        if idx == num_layers - 1:
            #pdb.set_trace()
            #stacked = tf.stack([tf.range(bs), lengths - 1], 1)
            #output = tf.gather_nd(prev_layer,stacked,name=None)
            return prev_layer, returncells, hiddenlayers, output, fw, stacked
        prev_layer = tf.concat((prev_layer, shortcut), 2)
encoding_last_layer, encoding_final_cells, encoding_hidden_layers, encoding_last_timestep_raw, dbg1, dbg2 = encode(
    cmd_emb, num_layers_encoder, cells_encoder,None, lengths = cmd_lengths, name = 'encoder')
encoding_last_timestep = tf.tanh(encoding_last_timestep_raw) + encoding_last_timestep_raw/10
# encoding_last_timestep = encoding_last_layer[:,cmd_lengths, :]
hidden_filters_encoder = encoding_last_timestep.shape[-1].value
first_cell_subprogram = tf.nn.rnn_cell.LSTMCell(
    hidden_filters_subprogram, forget_bias=1., name = 'subpogramlayer1_')
hidden_cells_subprogram = [tf.nn.rnn_cell.LSTMCell(
    hidden_filters_subprogram,forget_bias=1., name = 'subpogramlayer' + str(lidx))
                        for lidx in range(num_layers_subprogram - 1)]

cells_subprogram_rnn = [first_cell_subprogram] + hidden_cells_subprogram

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=hidden_filters_encoder, memory=encoding_last_layer,
    memory_sequence_length=cmd_lengths)
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=hidden_filters_encoder//2, memory=encoding_last_layer,
    memory_sequence_length=cmd_lengths)
cells_subprogram = [
    tf.contrib.seq2seq.AttentionWrapper(
        cell, attention_mechanism, attention_layer_size = hidden_filters_subprogram) 
    for cell in cells_subprogram_rnn]

def subprogram(x, num_layers, cells, initial_states, lengths, reuse, name='',):
    prev_layer = x
    shortcut = x
    hiddenlayers = []
    returncells = []
    bs = tf.shape(x)[0]
    for idx in range(num_layers):
        print(idx)
        if idx == 0:
            num_past_units = hidden_filters
        else:
            num_past_units = hidden_filters_subprogram
        with tf.variable_scope(name + 'subprogram' + str(idx), reuse=reuse):
#             self_attention_mechanism = tf.contrib.seq2seq.LuongAttention(
#                 num_units=num_past_units, memory=prev_layer,
#                 memory_sequence_length=tf.expand_dims(tf.range(10), 0))
#             cell_with_selfattention = tf.contrib.seq2seq.AttentionWrapper(
#                     cells[idx], self_attention_mechanism, attention_layer_size = num_past_units)

            prev_layer, c = tf.nn.dynamic_rnn(
                    cell = cells[idx],
                    inputs = prev_layer,
                    sequence_length=lengths,
                    initial_state = None,
                    dtype=tf.float32,
                )
            prev_layer = tf.concat(prev_layer, 2)
            prev_layer = tf.nn.leaky_relu(prev_layer)
            returncells.append(c)
            hiddenlayers.append(prev_layer)
            if idx == num_layers - 1:
                output = tf.gather_nd(
                            prev_layer,
                            tf.stack([tf.range(bs), lengths], 1),
                            name=None
                        )
                return prev_layer, returncells, hiddenlayers, output
            prev_layer = tf.concat((prev_layer, shortcut), 2)
encodings = [encoding_last_timestep]
last_encoding = encoding_last_timestep
initial_cmb_encoding = last_encoding
loss = 0
action_probabilities_presoftmax = []
for sub_idx in range(max_num_subprograms): 
    from_last_layer = tf.tile(tf.expand_dims(tf.concat((
        initial_cmb_encoding, last_encoding), 1), 1), [1, max_num_actions + 1, 1])
    autoregressive = act_emb_with_st[:,sub_idx, :, :]
    x_input = tf.concat((from_last_layer, autoregressive), -1)
    subprogram_last_layer, _, subprogram_hidden_layers, subprogram_output = subprogram(
        x_input, num_layers_subprogram, cells_subprogram,None, 
        lengths = act_lengths[:, sub_idx], reuse = (sub_idx > 0), name = 'subprogram')
    action_prob_flat = mlp(
        tf.reshape(subprogram_last_layer, [-1, hidden_filters_subprogram]),
        [], output_size = num_act, name = 'action_choice_mlp', reuse = (sub_idx > 0))
    action_prob_expanded = tf.reshape(action_prob_flat, [-1, max_num_actions + 1, num_act])
    action_probabilities_layer = tf.nn.softmax(action_prob_expanded, axis=-1)
    action_probabilities_presoftmax.append(action_prob_expanded)
    delta1, delta2 = [mlp(
        subprogram_output, [256,], output_size = hidden_filters_encoder, name = 'global_transform' + str(idx),
        reuse = (sub_idx > 0)
    ) for idx in range(2)]
    remember = tf.sigmoid(delta1 + 1.0)
    insert = tf.tanh(delta2) + delta2/100
    last_encoding = last_encoding * remember + insert
    encodings.append(last_encoding)
act_presoftmax = tf.stack(action_probabilities_presoftmax, 1)[:, :, :-1, :]
#batch, subprogram, timestep, action_selection
logprobabilities = tf.nn.log_softmax(act_presoftmax, -1)
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])
mask_ph_flat = tf.reshape(mask_ph, [-1, max_actions_per_subprogram])
act_ind_flat = tf.reshape(act_ind, [-1, max_actions_per_subprogram])
ppl_loss = tf.contrib.seq2seq.sequence_loss(
    logits = act_presoftmax_flat,
    targets = act_ind_flat,
    weights = mask_ph_flat,
    average_across_timesteps=False,
    average_across_batch=False,
    softmax_loss_function=None,
    name=None
)
ppl_loss_avg = tf.reduce_mean(tf.pow(ppl_loss, 2.0)) * 10000 # + tf.reduce_mean(tf.pow(ppl_loss, 1.0)) * 100

tfvars = tf.trainable_variables()
weight_norm = tf.reduce_mean([tf.reduce_sum(tf.square(var)) for var in tfvars])*1e-3

action_taken = tf.argmax(act_presoftmax, -1, output_type=tf.int32)
correct_mat = tf.cast(tf.equal(action_taken, act_ind), tf.float32) * mask_ph
correct_percent = tf.reduce_sum(correct_mat, [1, 2])/tf.reduce_sum(mask_ph, [1, 2])
percent_correct = tf.reduce_mean(correct_percent)
percent_fully_correct = tf.reduce_mean(tf.cast(tf.equal(correct_percent, 1), tf.float32))

loss = ppl_loss_avg + weight_norm

opt_fcn = tf.train.AdamOptimizer(learning_rate=learning_rate)
#opt_fcn = tf.train.MomentumOptimizer(learning_rate=learning_rate, use_nesterov=True, momentum=.8)
optimizer, grad_norm_total = apply_clipped_optimizer(opt_fcn, loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [9]:
hidden_filters

128

In [10]:
encoding_last_layer

<tf.Tensor 'LeakyRelu_5/Maximum:0' shape=(?, 10, 256) dtype=float32>

In [None]:
np.random.seed(0)
trn_percent = .1
num_samples = mask_np.shape[0]
ordered_samples = np.arange(num_samples)
np.random.shuffle(ordered_samples)
trn_samples = ordered_samples[:int(np.ceil(num_samples*trn_percent))]
val_samples_all = ordered_samples[int(np.ceil(num_samples*trn_percent)):]
val_samples = val_samples_all
trn_samples.shape, val_samples.shape

((2091,), (18819,))

In [None]:
eval_itr = -1
bs = 32# trn_samples.shape[0]
for itr in range(1000000):
    if 1:#itr == 0:
        samples = np.random.choice(trn_samples, size = bs, replace = False)
        trn_feed_dict = {
            cmd_ind : cmd_np[samples],
            act_ind : act_np[samples],
            mask_ph : mask_np[samples],
            act_lengths : np.clip(struct_np[samples], a_min = 1, a_max = None),
            cmd_lengths : cmd_lengths_np[samples],
        }
        
    trn_feed_dict[learning_rate] = .1 / (np.power(itr + 10, .6))
    _, trn_loss, acc_trn_single, acc_trn = sess.run(
        [optimizer, loss, percent_correct, percent_fully_correct], trn_feed_dict)
    if itr == 0:
        trn_loss_avg = trn_loss
        acc_trn_avg = acc_trn
        acc_trn_single_avg = acc_trn_single
    else:
        trn_loss_avg = trn_loss_avg * .9 + trn_loss * .1
        acc_trn_avg = acc_trn_avg * .9 + acc_trn * .1
        acc_trn_single_avg = acc_trn_single_avg * .9 + acc_trn_single * .1
    if itr % 10 == 0 and itr > 0:
        # val_samples = np.random.choice(val_samples_all, size = bs, replace = False)
        eval_itr += 1
        val_feed_dict = {
            cmd_ind : cmd_np[val_samples],
            act_ind : act_np[val_samples],
            mask_ph : mask_np[val_samples],
            act_lengths : np.clip(struct_np[val_samples], a_min = 1, a_max = None),
            cmd_lengths : cmd_lengths_np[val_samples]
        }
        val_loss, acc_val = sess.run([loss, percent_fully_correct], val_feed_dict)
        if eval_itr == 0:
            val_loss_avg = val_loss
            acc_val_avg = acc_val
        else:
            val_loss_avg = val_loss_avg * .9 + val_loss * .1
            acc_val_avg = acc_val_avg * .9 + acc_val * .1
        print('itr:', itr, 'trn_loss', trn_loss_avg, 'val_loss', val_loss_avg)
        print('itr:', itr, 'trn_acc', acc_trn_avg, 
              'trn_single_acc', acc_trn_single_avg, 'val_acc', acc_val_avg)

itr: 10 trn_loss 17743.508590912777 val_loss 11770.555
itr: 10 trn_acc 0.0 trn_single_acc 0.174762902899793 val_acc 0.0
itr: 20 trn_loss 12820.023633354474 val_loss 11454.735546875001
itr: 20 trn_acc 0.0 trn_single_acc 0.25547205996127 val_acc 0.0
itr: 30 trn_loss 9559.54254694952 val_loss 11094.5252734375
itr: 30 trn_acc 0.0 trn_single_acc 0.33386064958977646 val_acc 0.0
itr: 40 trn_loss 8093.383663992886 val_loss 10642.9288984375
itr: 40 trn_acc 0.0 trn_single_acc 0.39448233638430213 val_acc 0.0
itr: 50 trn_loss 6663.4444227862305 val_loss 10177.647825
itr: 50 trn_acc 0.0 trn_single_acc 0.4401941991533718 val_acc 1.0627556912368164e-05
itr: 60 trn_loss 6026.818668639783 val_loss 9688.582603046874
itr: 60 trn_acc 0.0 trn_single_acc 0.48044688708532696 val_acc 9.564801221131347e-06
itr: 70 trn_loss 5541.321717133731 val_loss 9214.238600554687
itr: 70 trn_acc 0.0 trn_single_acc 0.522250518505039 val_acc 1.3922099555202295e-05
itr: 80 trn_loss 4890.6482315140875 val_loss 8740.05853932734

itr: 530 trn_loss 544.1613048121183 val_loss 868.0265767696999
itr: 530 trn_acc 0.11807339104528455 trn_single_acc 0.8997922664770637 val_acc 0.07354723737307946
itr: 540 trn_loss 574.1721532015056 val_loss 842.8251520028862
itr: 540 trn_acc 0.12387945114136216 trn_single_acc 0.904047049908251 val_acc 0.07904654436411167
itr: 550 trn_loss 548.160603085138 val_loss 819.9583960799414
itr: 550 trn_acc 0.13354034997191438 trn_single_acc 0.900837007596677 val_acc 0.08069074997098778
itr: 560 trn_loss 569.3252687721214 val_loss 798.4559768821035
itr: 560 trn_acc 0.11533737788424023 trn_single_acc 0.8976804676611965 val_acc 0.08216522126309456
itr: 570 trn_loss 517.8944013739538 val_loss 774.3009614692838
itr: 570 trn_acc 0.12770321268402615 trn_single_acc 0.8974881787058377 val_acc 0.08452311893325193
itr: 580 trn_loss 490.46257851890556 val_loss 754.1322605860273
itr: 580 trn_acc 0.1149204422944248 trn_single_acc 0.8989950458738746 val_acc 0.08656551978010961
itr: 590 trn_loss 471.888767340

itr: 1040 trn_loss 296.5005190098111 val_loss 492.79375273368066
itr: 1040 trn_acc 0.41181103893060106 trn_single_acc 0.9443199660931008 val_acc 0.24416097838926
itr: 1050 trn_loss 317.30186213016805 val_loss 492.20991579039077
itr: 1050 trn_acc 0.4508249491421574 trn_single_acc 0.9494390088006076 val_acc 0.25240867888254837
itr: 1060 trn_loss 333.4584274846123 val_loss 484.1726461351798
itr: 1060 trn_acc 0.4223046862531744 trn_single_acc 0.9466668622433769 val_acc 0.26226531763663513
itr: 1070 trn_loss 289.84472965992876 val_loss 477.4085461945134
itr: 1070 trn_acc 0.4576688572684269 trn_single_acc 0.9481250611816878 val_acc 0.2736922201495524
itr: 1080 trn_loss 284.22742723991263 val_loss 474.5040074930308
itr: 1080 trn_acc 0.4954376892690797 trn_single_acc 0.9543607823154435 val_acc 0.2883018490953287
itr: 1090 trn_loss 235.39375446741724 val_loss 462.3730739068137
itr: 1090 trn_acc 0.5575659852079663 trn_single_acc 0.961184163555607 val_acc 0.31148292755630974
itr: 1100 trn_loss 19

itr: 1550 trn_loss 25.34085316171879 val_loss 133.26936169211547
itr: 1550 trn_acc 0.9783266618705171 trn_single_acc 0.9984486046352892 val_acc 0.9283468169606675
itr: 1560 trn_loss 25.32514165968258 val_loss 130.0837669230504
itr: 1560 trn_acc 0.9764261714567523 trn_single_acc 0.9985495334451522 val_acc 0.9305490651104107
itr: 1570 trn_loss 17.271791903963894 val_loss 120.29458197268873
itr: 1570 trn_acc 0.9760220882363556 trn_single_acc 0.9986705165280308 val_acc 0.9354536651110579
itr: 1580 trn_loss 6.946173942163357 val_loss 117.10365755837886
itr: 1580 trn_acc 0.989794137879397 trn_single_acc 0.9994562086163674 val_acc 0.9396818266990121
itr: 1590 trn_loss 14.360810998861382 val_loss 111.33623842729683
itr: 1590 trn_acc 0.9724084518252875 trn_single_acc 0.998595722375754 val_acc 0.9437953698649112
itr: 1600 trn_loss 11.612144561009083 val_loss 107.02414015829763
itr: 1600 trn_acc 0.9770111566818723 trn_single_acc 0.9988933700176119 val_acc 0.9463391543639394
itr: 1610 trn_loss 12.

itr: 2060 trn_loss 1.4725715254361185 val_loss 55.724894522922895
itr: 2060 trn_acc 0.9985367881894217 trn_single_acc 0.9999202941496761 val_acc 0.9884108311955271
itr: 2070 trn_loss 1.1031310873634046 val_loss 65.42412900861889
itr: 2070 trn_acc 0.9972116845883517 trn_single_acc 0.9999010168821922 val_acc 0.9876036471448069
itr: 2080 trn_loss 3.7257447042332474 val_loss 62.95521426144352
itr: 2080 trn_acc 0.9976825645005097 trn_single_acc 0.999853387595472 val_acc 0.9885032021690957
itr: 2090 trn_loss 1.9357386537333825 val_loss 64.8663670296351
itr: 2090 trn_acc 0.9969138352050054 trn_single_acc 0.9998064966029898 val_acc 0.989249036641975
itr: 2100 trn_loss 1.2612650518202835 val_loss 66.23476649000168
itr: 2100 trn_acc 0.9989239208733897 trn_single_acc 0.9999325295373764 val_acc 0.9899840527165471
itr: 2110 trn_loss 8.915728499103032 val_loss 94.18369584686089
itr: 2110 trn_acc 0.9940889506587093 trn_single_acc 0.99973681474452 val_acc 0.9903745627452617
itr: 2120 trn_loss 3.454891

itr: 2570 trn_loss 0.43516324487524377 val_loss 45.97481873074741
itr: 2570 trn_acc 0.9999998971935498 trn_single_acc 0.9999999950739358 val_acc 0.9971619320601869
itr: 2580 trn_loss 0.43263125779589046 val_loss 45.059816716071104
itr: 2580 trn_acc 0.9999999641536073 trn_single_acc 0.9999999982823875 val_acc 0.9973288362643703
itr: 2590 trn_loss 0.43055216834289844 val_loss 44.214239631255985
itr: 2590 trn_acc 0.9999999875011355 trn_single_acc 0.9999999994011055 val_acc 0.9974843608219848
itr: 2600 trn_loss 0.42861339342565496 val_loss 43.43752697184132
itr: 2600 trn_acc 0.9999999956419154 trn_single_acc 0.9999999997911785 val_acc 0.9976296496581518
itr: 2610 trn_loss 0.42675390550475123 val_loss 42.72434983618063
itr: 2610 trn_acc 0.9999999984804299 trn_single_acc 0.9999999999271885 val_acc 0.9977604096107021
itr: 2620 trn_loss 0.4249317878783864 val_loss 42.070464199486395
itr: 2620 trn_acc 0.9999999994701586 trn_single_acc 0.9999999999746121 val_acc 0.9978780935679974
itr: 2630 trn_

itr: 3080 trn_loss 0.3566166399097213 val_loss 17.42930873783308
itr: 3080 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9991171819584724
itr: 3090 trn_loss 0.3553548758670053 val_loss 17.201120238469937
itr: 3090 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9991204436973176
itr: 3100 trn_loss 0.35410939217774834 val_loss 16.99204213109633
itr: 3100 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9991233792622782
itr: 3110 trn_loss 0.35286688567385277 val_loss 16.76929065007288
itr: 3110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9991260212707429
itr: 3120 trn_loss 0.3516242583315564 val_loss 16.545820561811198
itr: 3120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999128399078361
itr: 3130 trn_loss 0.3503941754087392 val_loss 16.351094729690136
itr: 3130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9991305391052174
itr: 3140 trn_loss

itr: 3590 trn_loss 0.5772132230882991 val_loss 82.03558507642941
itr: 3590 trn_acc 0.9978990302419479 trn_single_acc 0.9998082266755565 val_acc 0.9779232477499935
itr: 3600 trn_loss 49.27802059049265 val_loss 80.52709416522202
itr: 3600 trn_acc 0.9965770170795651 trn_single_acc 0.9997196081891293 val_acc 0.9793285431284213
itr: 3610 trn_loss 18.515538901820904 val_loss 73.26855922913805
itr: 3610 trn_acc 0.9939436671548139 trn_single_acc 0.9996617971469555 val_acc 0.9811990769343231
itr: 3620 trn_loss 9.640721211978807 val_loss 71.64356564142932
itr: 3620 trn_acc 0.9947632873108142 trn_single_acc 0.9995979882988347 val_acc 0.9824096302661288
itr: 3630 trn_loss 3.610190302438488 val_loss 124.60031503431765
itr: 3630 trn_acc 0.9981740711882828 trn_single_acc 0.9998598271871356 val_acc 0.9796732134049823
itr: 3640 trn_loss 76.07181324950511 val_loss 114.53590921753138
itr: 3640 trn_acc 0.9921888844901968 trn_single_acc 0.9996584605472925 val_acc 0.9811798274301609
itr: 3650 trn_loss 26.94

itr: 4100 trn_loss 0.303099483258903 val_loss 13.17892299816511
itr: 4100 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984024494606469
itr: 4110 trn_loss 0.3022705528566653 val_loss 13.112737870742398
itr: 4110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984240469084573
itr: 4120 trn_loss 0.30143725630492185 val_loss 13.055429461475775
itr: 4120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984434846114867
itr: 4130 trn_loss 0.3006118135845107 val_loss 13.003214361855298
itr: 4130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984609785442131
itr: 4140 trn_loss 0.2997826550007436 val_loss 12.955822994639496
itr: 4140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984767230836669
itr: 4150 trn_loss 0.29895854692499224 val_loss 12.917595240768808
itr: 4150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9984908931691754
itr: 4160 trn_lo

itr: 4600 trn_loss 0.263335990600064 val_loss 12.532429672926165
itr: 4600 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.998668595041512
itr: 4610 trn_loss 0.26257162273460277 val_loss 12.526567820593266
itr: 4610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9986688887050854
itr: 4620 trn_loss 0.2618147181874188 val_loss 12.508216134206549
itr: 4620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9986691530023015
itr: 4630 trn_loss 0.2610512197377943 val_loss 12.470501057651129
itr: 4630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.998669390869796
itr: 4640 trn_loss 0.2602888946235131 val_loss 12.431699853253205
itr: 4640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.998669604950541
itr: 4650 trn_loss 0.259540799734408 val_loss 12.392640013496733
itr: 4650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9986697976232116
itr: 4660 trn_loss 0

itr: 5100 trn_loss 0.22638854855531204 val_loss 11.31581887815644
itr: 5100 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9986001573654407
itr: 5110 trn_loss 0.2256697559994119 val_loss 11.289109026321022
itr: 5110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985966672884579
itr: 5120 trn_loss 0.22495432775051682 val_loss 11.265923412080033
itr: 5120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985935262191733
itr: 5130 trn_loss 0.22423886606301965 val_loss 11.233229081125936
itr: 5130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985906992568171
itr: 5140 trn_loss 0.22352954462193872 val_loss 11.196334338449134
itr: 5140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985881549906965
itr: 5150 trn_loss 0.2228179611390649 val_loss 11.16360733770969
itr: 5150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985805543773386
itr: 5160 trn_

itr: 5610 trn_loss 0.19115354662118822 val_loss 9.256845061836497
itr: 5610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983734262756471
itr: 5620 trn_loss 0.19048986383701833 val_loss 9.219134722522965
itr: 5620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983713542913166
itr: 5630 trn_loss 0.18981947042682765 val_loss 9.208533925746742
itr: 5630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983694895054193
itr: 5640 trn_loss 0.18915714214330578 val_loss 9.151503279448924
itr: 5640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983625004242621
itr: 5650 trn_loss 0.18849216673647312 val_loss 9.116622765041631
itr: 5650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983562102512207
itr: 5660 trn_loss 0.18783018092503256 val_loss 9.053172921398795
itr: 5660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9983505490954835
itr: 5670 trn_lo

itr: 6120 trn_loss 40.27255936573408 val_loss 144.85743033631283
itr: 6120 trn_acc 0.9754223208797405 trn_single_acc 0.9987131713482125 val_acc 0.940632971235525
itr: 6130 trn_loss 59.68939074835207 val_loss 130.85791843900358
itr: 6130 trn_acc 0.9535957330705697 trn_single_acc 0.9963126952780058 val_acc 0.9462189663425786
itr: 6140 trn_loss 21.056099045412427 val_loss 126.4796156087751
itr: 6140 trn_acc 0.9838198325930622 trn_single_acc 0.9987143163413617 val_acc 0.9503536452927325
itr: 6150 trn_loss 9.928209059822516 val_loss 123.17897231571986
itr: 6150 trn_acc 0.9899922045023671 trn_single_acc 0.99930566500481 val_acc 0.9547550228307964
itr: 6160 trn_loss 3.8265382511611707 val_loss 113.31448733085197
itr: 6160 trn_acc 0.9965104974770456 trn_single_acc 0.9997579003569703 val_acc 0.9586524975660731
itr: 6170 trn_loss 1.4685452989074603 val_loss 102.49560459202948
itr: 6170 trn_acc 0.9987832857035711 trn_single_acc 0.9999155850741196 val_acc 0.9626703452196679
itr: 6180 trn_loss 0.64

itr: 6630 trn_loss 0.17867395468126174 val_loss 6.400997053589068
itr: 6630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999351987227547
itr: 6640 trn_loss 0.17827642704905072 val_loss 6.285443160836973
itr: 6640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993902167541515
itr: 6650 trn_loss 0.17788486683988514 val_loss 6.179815304790508
itr: 6650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994246233280956
itr: 6660 trn_loss 0.17749018852893933 val_loss 6.082997905476009
itr: 6660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994555892446452
itr: 6670 trn_loss 0.17709912869668434 val_loss 5.99394264274518
itr: 6670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994834585695398
itr: 6680 trn_loss 0.17670952948552238 val_loss 5.912808619025472
itr: 6680 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999508540961945
itr: 6690 trn_loss 

itr: 7140 trn_loss 0.15881697467824687 val_loss 4.967254372178575
itr: 7140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997837831356118
itr: 7150 trn_loss 0.15843123309983348 val_loss 4.964591930871362
itr: 7150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997788330714098
itr: 7160 trn_loss 0.15804712358838616 val_loss 4.9616118942783665
itr: 7160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999774378013628
itr: 7170 trn_loss 0.157659624891614 val_loss 4.95876101330695
itr: 7170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997703684616244
itr: 7180 trn_loss 0.1572717407331115 val_loss 4.955888423405089
itr: 7180 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997667598648211
itr: 7190 trn_loss 0.1568850481103284 val_loss 4.951989406122563
itr: 7190 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997635121276982
itr: 7200 trn_loss 0.

itr: 7650 trn_loss 0.1391607282309594 val_loss 4.516751635936194
itr: 7650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997292013234058
itr: 7660 trn_loss 0.13877826346283068 val_loss 4.5008500834113
itr: 7660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997243986665749
itr: 7670 trn_loss 0.13839598802891628 val_loss 4.483680047375468
itr: 7670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999720076275427
itr: 7680 trn_loss 0.13801263648814333 val_loss 4.4683007816515925
itr: 7680 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997161861233941
itr: 7690 trn_loss 0.1376285149335129 val_loss 4.454451383952132
itr: 7690 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997126849865643
itr: 7700 trn_loss 0.13724450451479706 val_loss 4.4414262132082865
itr: 7700 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997095339634174
itr: 7710 trn_loss 

itr: 8160 trn_loss 0.14050043711850974 val_loss 10.299797383639287
itr: 8160 trn_acc 0.9999999932116344 trn_single_acc 0.9999999996437631 val_acc 0.9971902147973568
itr: 8170 trn_loss 0.14360609316310782 val_loss 9.451591438972715
itr: 8170 trn_acc 0.9999999976330431 trn_single_acc 0.9999999998757878 val_acc 0.9973436632196597
itr: 8180 trn_loss 0.15691368301235625 val_loss 8.640699482193973
itr: 8180 trn_acc 0.9999999991746932 trn_single_acc 0.9999999999566899 val_acc 0.9975774143732035
itr: 8190 trn_loss 0.14985985967427373 val_loss 8.021660676209379
itr: 8190 trn_acc 0.9999999997122333 trn_single_acc 0.9999999999848986 val_acc 0.9977771629032294
itr: 8200 trn_loss 0.14237188749872945 val_loss 7.362967911906464
itr: 8200 trn_acc 0.9999999998996618 trn_single_acc 0.9999999999947345 val_acc 0.9979728748622656
itr: 8210 trn_loss 0.13941150186821727 val_loss 6.791298303199491
itr: 8210 trn_acc 0.9999999999650142 trn_single_acc 0.9999999999981639 val_acc 0.9981383881172348
itr: 8220 trn_l

itr: 8660 trn_loss 0.1253963553641856 val_loss 2.0620536560857667
itr: 8660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996541734411117
itr: 8670 trn_loss 0.1251202658410493 val_loss 2.065035893474016
itr: 8670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996568735725102
itr: 8680 trn_loss 0.12483837344791868 val_loss 2.069205612499112
itr: 8680 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996539869564549
itr: 8690 trn_loss 0.12455705036955361 val_loss 2.074067482851679
itr: 8690 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996513890020051
itr: 8700 trn_loss 0.12428064380912507 val_loss 2.0789213146293775
itr: 8700 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996490508430003
itr: 8710 trn_loss 0.12399932525188186 val_loss 2.084304830329831
itr: 8710 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996469464998959
itr: 8720 trn_lo

itr: 9160 trn_loss 0.1113090929379194 val_loss 4.5753813128877585
itr: 9160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994113391440488
itr: 9170 trn_loss 0.11102633683719265 val_loss 4.846114427555404
itr: 9170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993958126724998
itr: 9180 trn_loss 0.1107459904399305 val_loss 5.1320715940375345
itr: 9180 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993818388481056
itr: 9190 trn_loss 0.11046332828920538 val_loss 5.429692071352531
itr: 9190 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999363945671837
itr: 9200 trn_loss 0.11017490991632675 val_loss 5.732900823659417
itr: 9200 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993425310393456
itr: 9210 trn_loss 0.10989485692051887 val_loss 6.041104214346454
itr: 9210 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993179411357894
itr: 9220 trn_los

In [None]:
correct_percent.shape, percent_fully_correct.shape

In [None]:
10, None, 7, 9, None

In [None]:
cmd_np.shape, act_np.shape, mask_np.shape, struct_np.shape, cmd_lengths_np.shape

In [None]:
cmd_ind.shape, act_ind.shape, mask_ph.shape, act_lengths.shape, cmd_lengths.shape

In [None]:
trn_samples.shape, val_samples.shape

In [None]:

act_presoftmax = tf.stack(action_probabilities_presoftmax, 1)[:, :, :-1, :]
#batch, subprogram, timestep, action_selection
logprobabilities = tf.nn.log_softmax(act_presoftmax, -1)
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])
mask_ph_flat = tf.reshape(mask_ph, [-1, max_actions_per_subprogram])
act_ind_flat = tf.reshape(act_ind, [-1, max_actions_per_subprogram])

In [None]:
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])

In [None]:
max_actions_per_subprogram

In [None]:
sess.run(act_presoftmax, feed_dict).shape

In [None]:
action_map

In [None]:
print(*actions_ind[2])

In [None]:
command_map

In [None]:
action_map

In [None]:
subprogram_output

In [None]:
subprogram_last_layer[:,cmd_lengths,:]

In [None]:
encoding_last_layer

In [None]:
tf.gather(
    encoding_last_layer,
    [1,2],
    axis=1
)

In [None]:
tf.gather_nd(
    encoding_last_layer,
    np.array([[0,1,2,3,4], [1,4,3,2,5]]).T,
    name=None
)

In [None]:
cmd_lengths

In [None]:
def generate_command(sub_cmd, num_repeat):
    return sub_cmd * num_repeat

In [None]:
def process_command(cmd):
    

In [None]:
uni_commands

In [None]:
uni_actions

In [None]:
uni_tokens

In [None]:
df.shape