In [1]:

# Import common dependencies
import pandas as pd  # noqa
import numpy as np
import matplotlib  # noqa
import matplotlib.pyplot as plt
import datetime  # noqa
import PIL  # noqa
import glob  # noqa
import pickle  # noqa
from pathlib import Path  # noqa
from scipy import misc  # noqa
import sys
import tensorflow as tf
import pdb
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
TRADE_COST_FRAC = .003
EPSILON = 1e-10
ADV_MULT = 1e-3

In [2]:
uni_tokens = set()
uni_commands = set()
uni_actions = set()
fname = 'tasks_with_length_tags.txt'
with open(fname) as f:
    content = f.readlines()
content2 = [c.split(' ') for c in content]
# you may also want to remove whitespace characters like `\n` at the end of each line
commands = []
actions = []
content = [l.replace('\n', '') for l in content]
commands = [x.split(':::')[1].split(' ')[1:-1] for x in content]
actions = [x.split(':::')[2].split(' ')[1:-2] for x in content]
structures = [x.split(':::')[3].split(' ')[2:] for x in content]

structures = [[int(l) for l in program] for program in structures]
#actions = [[wd.replace('\n', '') for wd in res] for res in actions]

In [3]:
max_actions_per_subprogram = max([max([s for s in struct]) for struct in structures]) + 1
max_num_subprograms = max([len(s) for s in structures]) + 1
max_cmd_len = max([len(s) for s in commands]) + 1
max_act_len = max([len(a) for a in actions]) + 1
cmd_lengths_list = [len(s)+1 for s in commands]
cmd_lengths_np = np.array(cmd_lengths_list)
max_num_subprograms, max_cmd_len, max_act_len, max_actions_per_subprogram


(7, 10, 49, 9)

In [4]:
def build_fmap_invmap(unique, num_unique):
    fmap = dict(zip(unique, range(num_unique)))
    invmap = dict(zip(range(num_unique), unique))
    return fmap, invmap


In [5]:
for li in content2:
    for wd in li:
        uni_tokens.add(wd)
for li in commands:
    for wd in li:
        uni_commands.add(wd)
for li in actions:
    for wd in li:
        uni_actions.add(wd)
uni_commands.add('end_command')
uni_actions.add('end_subprogram')
uni_actions.add('end_action')
num_cmd = len(uni_commands)
num_act = len(uni_actions)
command_map, command_invmap = build_fmap_invmap(uni_commands, num_cmd)
action_map, action_invmap = build_fmap_invmap(uni_actions, num_act)

In [6]:


def dense_scaled(prev_layer, layer_size, name=None, reuse=False, scale=1.0):
    output = tf.layers.dense(prev_layer, layer_size, reuse=reuse) * scale
    return output


def dense_relu(dense_input, layer_size, scale=1.0):
    dense = dense_scaled(dense_input, layer_size, scale=scale)
    output = tf.nn.leaky_relu(dense)

    return output

def get_grad_norm(opt_fcn, loss):
    gvs = opt_fcn.compute_gradients(loss)
    grad_norm = tf.sqrt(tf.reduce_sum(
        [tf.reduce_sum(tf.square(grad)) for grad, var in gvs if grad is not None]))
    return grad_norm


def apply_clipped_optimizer(opt_fcn, loss, clip_norm=.1, clip_single=.03, clip_global_norm=False):
    gvs = opt_fcn.compute_gradients(loss)

    if clip_global_norm:
        gs, vs = zip(*[(g, v) for g, v in gvs if g is not None])
        capped_gs, grad_norm_total = tf.clip_by_global_norm([g for g in gs], clip_norm)
        capped_gvs = list(zip(capped_gs, vs))
    else:
        grad_norm_total = tf.sqrt(
            tf.reduce_sum([tf.reduce_sum(tf.square(grad)) for grad, var in gvs if grad is not None]))
        capped_gvs = [(tf.clip_by_value(grad, -1 * clip_single, clip_single), var)
                      for grad, var in gvs if grad is not None]
        capped_gvs = [(tf.clip_by_norm(grad, clip_norm), var) for grad, var in capped_gvs if grad is not None]

    optimizer = opt_fcn.apply_gradients(capped_gvs)

    return optimizer, grad_norm_total


def mlp(x, hidden_sizes, output_size=None, name='', reuse=False):
    prev_layer = x

    for idx, l in enumerate(hidden_sizes):
        dense = dense_scaled(prev_layer, l, name='mlp' + name + '_' + str(idx))
        prev_layer = tf.nn.leaky_relu(dense)

    output = prev_layer

    if output_size is not None:
        output = dense_scaled(prev_layer, output_size, name='mlp' + name + 'final')

    return output

def mlp_with_adversaries(x, hidden_sizes, output_size=None, name='', reuse=False):
    prev_layer = x
    adv_phs = []
    for idx, l in enumerate(hidden_sizes):
        
        adversary = tf.placeholder_with_default(tf.zeros_like(prev_layer), prev_layer.shape)
        prev_layer = prev_layer + adversary
        adv_phs.append(adversary)
        
        dense = dense_scaled(prev_layer, l, name='mlp' + name + '_' + str(idx))
        prev_layer = tf.nn.leaky_relu(dense)

    output = prev_layer

    if output_size is not None:
        output = dense_scaled(prev_layer, output_size, name='mlp' + name + 'final')

    return output, adv_phs



In [7]:

commands_ind = [[command_map[c] for c in cmd] + [0] * (max_cmd_len - len(cmd)) for cmd in commands]
actions_ind = [[action_map[a] for a in act] + [0] * (max_act_len - len(act)) for act in actions]
cmd_np = np.array(commands_ind)
actions_structured = []
mask_structured = []
for row in range(len(structures)):
    mask_row = []
    action_row = []
    act = actions_ind[row]
    struct = structures[row]
    start = 0
    for step in struct:
        end = start + step
        a = act[start:end]
        padding = max_actions_per_subprogram - step - 1
        action_row.append(a + [action_map['end_action']] + [0] * padding)
        start = end
    actions_structured.append(
        action_row + [[action_map['end_subprogram']] + [0] * (max_actions_per_subprogram - 1)] +
        [[0] * max_actions_per_subprogram] * (max_num_subprograms - len(struct) - 1)
    )
act_np = np.array(actions_structured)
struct_padded = [[sa + 1 for sa in s] + [1] + [0] * (max_num_subprograms - len(s) - 1) for s in structures]
struct_np = np.array(struct_padded)

mask_list = [[np.concatenate((np.ones(st), np.zeros(max_actions_per_subprogram - st)), 0) 
              for st in s] for s in struct_np]
mask_np = np.array(mask_list)

In [8]:
tf.reset_default_graph()
default_sizes = 128
size_emb = 64
num_layers_encoder = 6
hidden_filters = 128
num_layers_subprogram = 3
hidden_filters_subprogram = 128
init_mag = 1e-3
cmd_mat = tf.Variable(init_mag*tf.random_normal([num_cmd, size_emb]))
act_mat = tf.Variable(init_mag*tf.random_normal([num_act, size_emb]))
act_st_emb = tf.Variable(init_mag*tf.random_normal([size_emb]))
global_bs = None
global_time_len = 7
action_lengths = None
max_num_actions= None
# global_bs = 8
global_time_len = 7
max_num_actions = 9
output_keep_prob = tf.placeholder_with_default(1.0, ())
state_keep_prob = tf.placeholder_with_default(1.0, ())
cmd_ind = tf.placeholder(tf.int32, shape=(global_bs, 10,))
act_ind = tf.placeholder(tf.int32, shape=(global_bs, global_time_len, 9))
mask_ph = tf.placeholder(tf.float32, shape=(global_bs, global_time_len, 9))
cmd_lengths = tf.placeholder(tf.int32, shape=(global_bs,))
act_lengths = tf.placeholder(tf.int32, shape=(global_bs, 7))
learning_rate = tf.placeholder(tf.float32, shape = (None))

cmd_emb = tf.nn.embedding_lookup(cmd_mat, cmd_ind)
act_emb = tf.nn.embedding_lookup(act_mat, act_ind)
tf_bs = tf.shape(act_ind)[0]
act_st_emb_expanded = tf.tile(tf.reshape(
    act_st_emb, [1, 1, 1, size_emb]), [tf_bs, global_time_len, 1, 1])
act_emb_with_st = tf.concat((act_st_emb_expanded, act_emb), 2)

first_cell_encoder = [tf.nn.rnn_cell.LSTMCell(
    hidden_filters, forget_bias=1., name = 'layer1_'+d) for d in ['f', 'b']]
hidden_cells_encoder = [[tf.nn.rnn_cell.LSTMCell(
    hidden_filters,forget_bias=1., name = 'layer' + str(lidx) + '_' + d)  for d in ['f', 'b']]
                        for lidx in range(num_layers_encoder - 1)]
hidden_cells_encoder = [[tf.nn.rnn_cell.DropoutWrapper(cell,
    output_keep_prob=output_keep_prob, state_keep_prob=state_keep_prob,
    variational_recurrent=True, dtype=tf.float32) for cell in cells] 
                        for cells in hidden_cells_encoder[:-1]] + [hidden_cells_encoder[-1]]
cells_encoder = [first_cell_encoder] + hidden_cells_encoder
c1, c2 = zip(*cells_encoder)
cells_encoder = [c1, c2]
def encode(x, num_layers, cells, initial_states, lengths, name='',):
    prev_layer = x
    shortcut = x
    hiddenlayers = []
    returncells = []
    cell_fw, cell_bw = cells
    bs = tf_bs
    for idx in range(num_layers):
        prev_layer, c = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cell_fw[idx],
                cell_bw = cell_bw[idx],
                inputs = prev_layer,
                sequence_length=lengths,
                initial_state_fw=None,
                initial_state_bw=None,
                dtype=tf.float32,
                scope='encoder'+str(idx)
            )
        if idx == num_layers - 1:
            fw = prev_layer[0]
            bw = prev_layer[1]
            stacked = tf.stack([tf.range(bs), lengths - 1], 1)
            fw_final = tf.gather_nd(fw,stacked,name=None)
            bw_final = bw[:,0,:]
            output = tf.concat((fw_final, bw_final), 1)
        prev_layer = tf.concat(prev_layer, 2)
        prev_layer = tf.nn.leaky_relu(prev_layer)
        returncells.append(c)
        hiddenlayers.append(prev_layer)
        if idx == num_layers - 1:
            #pdb.set_trace()
            #stacked = tf.stack([tf.range(bs), lengths - 1], 1)
            #output = tf.gather_nd(prev_layer,stacked,name=None)
            return prev_layer, returncells, hiddenlayers, output, fw, stacked
        prev_layer = tf.concat((prev_layer, shortcut), 2)
encoding_last_layer, encoding_final_cells, encoding_hidden_layers, encoding_last_timestep, dbg1, dbg2 = encode(
    cmd_emb, num_layers_encoder, cells_encoder,None, lengths = cmd_lengths, name = 'encoder')
# encoding_last_timestep = encoding_last_layer[:,cmd_lengths, :]
hidden_filters_encoder = encoding_last_timestep.shape[-1].value
first_cell_subprogram = tf.nn.rnn_cell.LSTMCell(
    hidden_filters_subprogram, forget_bias=1., name = 'subpogramlayer1_')
hidden_cells_subprogram = [tf.nn.rnn_cell.LSTMCell(
    hidden_filters_subprogram,forget_bias=1., name = 'subpogramlayer' + str(lidx))
                        for lidx in range(num_layers_subprogram - 1)]

cells_subprogram_rnn = [first_cell_subprogram] + hidden_cells_subprogram

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=hidden_filters_encoder, memory=encoding_last_layer,
    memory_sequence_length=cmd_lengths)
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=hidden_filters_encoder//2, memory=encoding_last_layer,
    memory_sequence_length=cmd_lengths)
cells_subprogram = [
    tf.contrib.seq2seq.AttentionWrapper(
        cell, attention_mechanism, attention_layer_size = hidden_filters_subprogram) 
    for cell in cells_subprogram_rnn]

def subprogram(x, num_layers, cells, initial_states, lengths, reuse, name='',):
    prev_layer = x
    shortcut = x
    hiddenlayers = []
    returncells = []
    bs = tf.shape(x)[0]
    for idx in range(num_layers):
        print(idx)
        if idx == 0:
            num_past_units = hidden_filters
        else:
            num_past_units = hidden_filters_subprogram
        with tf.variable_scope(name + 'subprogram' + str(idx), reuse=reuse):
#             self_attention_mechanism = tf.contrib.seq2seq.LuongAttention(
#                 num_units=num_past_units, memory=prev_layer,
#                 memory_sequence_length=tf.expand_dims(tf.range(10), 0))
#             cell_with_selfattention = tf.contrib.seq2seq.AttentionWrapper(
#                     cells[idx], self_attention_mechanism, attention_layer_size = num_past_units)

            prev_layer, c = tf.nn.dynamic_rnn(
                    cell = cells[idx],
                    inputs = prev_layer,
                    sequence_length=lengths,
                    initial_state = None,
                    dtype=tf.float32,
                )
            prev_layer = tf.concat(prev_layer, 2)
            prev_layer = tf.nn.leaky_relu(prev_layer)
            returncells.append(c)
            hiddenlayers.append(prev_layer)
            if idx == num_layers - 1:
                output = tf.gather_nd(
                            prev_layer,
                            tf.stack([tf.range(bs), lengths], 1),
                            name=None
                        )
                return prev_layer, returncells, hiddenlayers, output
            prev_layer = tf.concat((prev_layer, shortcut), 2)
encodings = [encoding_last_timestep]
last_encoding = encoding_last_timestep
initial_cmb_encoding = last_encoding
loss = 0
action_probabilities_presoftmax = []
for sub_idx in range(max_num_subprograms): 
    from_last_layer = tf.tile(tf.expand_dims(tf.concat((
        initial_cmb_encoding, last_encoding), 1), 1), [1, max_num_actions + 1, 1])
    autoregressive = act_emb_with_st[:,sub_idx, :, :]
    x_input = tf.concat((from_last_layer, autoregressive), -1)
    subprogram_last_layer, _, subprogram_hidden_layers, subprogram_output = subprogram(
        x_input, num_layers_subprogram, cells_subprogram,None, 
        lengths = act_lengths[:, sub_idx], reuse = (sub_idx > 0), name = 'subprogram')
    action_prob_flat = mlp(
        tf.reshape(subprogram_last_layer, [-1, hidden_filters_subprogram]),
        [], output_size = num_act, name = 'action_choice_mlp', reuse = (sub_idx > 0))
    action_prob_expanded = tf.reshape(action_prob_flat, [-1, max_num_actions + 1, num_act])
    action_probabilities_layer = tf.nn.softmax(action_prob_expanded, axis=-1)
    action_probabilities_presoftmax.append(action_prob_expanded)
    delta1, delta2 = [mlp(
        subprogram_output, [256,], output_size = hidden_filters_encoder, name = 'global_transform' + str(idx),
        reuse = (sub_idx > 0)
    ) for idx in range(2)]
    remember = tf.sigmoid(delta1)
    insert = tf.tanh(delta2) + delta2/100
    last_encoding = last_encoding * remember + insert
    encodings.append(last_encoding)
    encodings.append(last_encoding)
act_presoftmax = tf.stack(action_probabilities_presoftmax, 1)[:, :, :-1, :]
#batch, subprogram, timestep, action_selection
logprobabilities = tf.nn.log_softmax(act_presoftmax, -1)
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])
mask_ph_flat = tf.reshape(mask_ph, [-1, max_actions_per_subprogram])
act_ind_flat = tf.reshape(act_ind, [-1, max_actions_per_subprogram])
ppl_loss = tf.contrib.seq2seq.sequence_loss(
    logits = act_presoftmax_flat,
    targets = act_ind_flat,
    weights = mask_ph_flat,
    average_across_timesteps=False,
    average_across_batch=False,
    softmax_loss_function=None,
    name=None
)
ppl_loss_avg = tf.reduce_mean(tf.pow(ppl_loss, 2.0)) * 10000 # + tf.reduce_mean(tf.pow(ppl_loss, 1.0)) * 100

tfvars = tf.trainable_variables()
weight_norm = tf.reduce_mean([tf.reduce_sum(tf.square(var)) for var in tfvars])*1e-3

action_taken = tf.argmax(act_presoftmax, -1, output_type=tf.int32)
correct_mat = tf.cast(tf.equal(action_taken, act_ind), tf.float32) * mask_ph
correct_percent = tf.reduce_sum(correct_mat, [1, 2])/tf.reduce_sum(mask_ph, [1, 2])
percent_correct = tf.reduce_mean(correct_percent)
percent_fully_correct = tf.reduce_mean(tf.cast(tf.equal(correct_percent, 1), tf.float32))

loss = ppl_loss_avg + weight_norm

opt_fcn = tf.train.AdamOptimizer(learning_rate=learning_rate)
#opt_fcn = tf.train.MomentumOptimizer(learning_rate=learning_rate, use_nesterov=True, momentum=.8)
optimizer, grad_norm_total = apply_clipped_optimizer(opt_fcn, loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
0
1
2
Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [9]:
hidden_filters

128

In [10]:
encoding_last_layer

<tf.Tensor 'LeakyRelu_5/Maximum:0' shape=(?, 10, 256) dtype=float32>

In [None]:
np.random.seed(0)
trn_percent = .1
num_samples = mask_np.shape[0]
ordered_samples = np.arange(num_samples)
np.random.shuffle(ordered_samples)
trn_samples = ordered_samples[:int(np.ceil(num_samples*trn_percent))]
val_samples_all = ordered_samples[int(np.ceil(num_samples*trn_percent)):]
val_samples = val_samples_all
trn_samples.shape, val_samples.shape

((2091,), (18819,))

In [None]:
eval_itr = -1
bs = 32# trn_samples.shape[0]
for itr in range(1000000):
    if 1:#itr == 0:
        samples = np.random.choice(trn_samples, size = bs, replace = False)
        trn_feed_dict = {
            cmd_ind : cmd_np[samples],
            act_ind : act_np[samples],
            mask_ph : mask_np[samples],
            act_lengths : np.clip(struct_np[samples], a_min = 1, a_max = None),
            cmd_lengths : cmd_lengths_np[samples],
        }
        
    trn_feed_dict[learning_rate] = .1 / (np.power(itr + 10, .6))
    _, trn_loss, acc_trn_single, acc_trn = sess.run(
        [optimizer, loss, percent_correct, percent_fully_correct], trn_feed_dict)
    if itr == 0:
        trn_loss_avg = trn_loss
        acc_trn_avg = acc_trn
        acc_trn_single_avg = acc_trn_single
    else:
        trn_loss_avg = trn_loss_avg * .9 + trn_loss * .1
        acc_trn_avg = acc_trn_avg * .9 + acc_trn * .1
        acc_trn_single_avg = acc_trn_single_avg * .9 + acc_trn_single * .1
    if itr % 10 == 0 and itr > 0:
        # val_samples = np.random.choice(val_samples_all, size = bs, replace = False)
        eval_itr += 1
        val_feed_dict = {
            cmd_ind : cmd_np[val_samples],
            act_ind : act_np[val_samples],
            mask_ph : mask_np[val_samples],
            act_lengths : np.clip(struct_np[val_samples], a_min = 1, a_max = None),
            cmd_lengths : cmd_lengths_np[val_samples]
        }
        val_loss, acc_val = sess.run([loss, percent_fully_correct], val_feed_dict)
        if eval_itr == 0:
            val_loss_avg = val_loss
            acc_val_avg = acc_val
        else:
            val_loss_avg = val_loss_avg * .9 + val_loss * .1
            acc_val_avg = acc_val_avg * .9 + acc_val * .1
        print('itr:', itr, 'trn_loss', trn_loss_avg, 'val_loss', val_loss_avg)
        print('itr:', itr, 'trn_acc', acc_trn_avg, 
              'trn_single_acc', acc_trn_single_avg, 'val_acc', acc_val_avg)

itr: 10 trn_loss 16630.384048821194 val_loss 12018.526
itr: 10 trn_acc 0.0 trn_single_acc 0.16103598996132829 val_acc 0.0
itr: 20 trn_loss 13392.964377365557 val_loss 11845.1580078125
itr: 20 trn_acc 0.0 trn_single_acc 0.25108245340846247 val_acc 0.0
itr: 30 trn_loss 10869.12145663787 val_loss 11602.02326171875
itr: 30 trn_acc 0.0 trn_single_acc 0.3183858962893595 val_acc 0.0
itr: 40 trn_loss 8814.813520986525 val_loss 11170.865564453125
itr: 40 trn_acc 0.0 trn_single_acc 0.3589312781092003 val_acc 0.0
itr: 50 trn_loss 6834.988671566404 val_loss 10679.984427929687
itr: 50 trn_acc 0.0 trn_single_acc 0.429301544099401 val_acc 5.313778456184082e-06
itr: 60 trn_loss 6111.141023412533 val_loss 10141.669725371094
itr: 60 trn_acc 0.0 trn_single_acc 0.48764815344230383 val_acc 4.782400610565674e-06
itr: 70 trn_loss 5426.891519935889 val_loss 9594.882635646485
itr: 70 trn_acc 0.0 trn_single_acc 0.5313174842620938 val_acc 3.08730531942274e-05
itr: 80 trn_loss 4811.951671645115 val_loss 9089.2522

itr: 540 trn_loss 518.4194455228023 val_loss 407.58694709475407
itr: 540 trn_acc 0.638443374447829 trn_single_acc 0.9754300913201217 val_acc 0.7825579692856883
itr: 550 trn_loss 296.250253621785 val_loss 371.55852620424844
itr: 550 trn_acc 0.7644459801383993 trn_single_acc 0.9849190089232011 val_acc 0.7993869289698973
itr: 560 trn_loss 234.7399438609749 val_loss 346.87777352889196
itr: 560 trn_acc 0.8921897139234978 trn_single_acc 0.9929274504418617 val_acc 0.8119664345630504
itr: 570 trn_loss 128.387076369661 val_loss 314.9952989866717
itr: 570 trn_acc 0.9149270046334856 trn_single_acc 0.9945011695872368 val_acc 0.8290162462993113
itr: 580 trn_loss 54.187113276873056 val_loss 285.7136195823893
itr: 580 trn_acc 0.9537715403059692 trn_single_acc 0.9971327150729316 val_acc 0.8448446314636153
itr: 590 trn_loss 63.27112749288171 val_loss 270.7552886300098
itr: 590 trn_acc 0.9448753293169099 trn_single_acc 0.9967476366841903 val_acc 0.8500036109480217
itr: 600 trn_loss 75.7824826232668 val_

itr: 1060 trn_loss 18.47589411817309 val_loss 142.92142605302269
itr: 1060 trn_acc 0.9941139031094104 trn_single_acc 0.9995860306614497 val_acc 0.9794609989166377
itr: 1070 trn_loss 8.852100679114708 val_loss 136.64480850265204
itr: 1070 trn_acc 0.9962868917929117 trn_single_acc 0.9998068128819639 val_acc 0.9807444017028915
itr: 1080 trn_loss 15.425051165335566 val_loss 129.92077793925208
itr: 1080 trn_acc 0.9890258801943053 trn_single_acc 0.9994773677546641 val_acc 0.9820748210754491
itr: 1090 trn_loss 5.739136252289459 val_loss 127.30046375616672
itr: 1090 trn_acc 0.9961735610246798 trn_single_acc 0.9998177694039504 val_acc 0.9828630364662256
itr: 1100 trn_loss 127.19792260270466 val_loss 135.4211864235188
itr: 1100 trn_acc 0.9907550501019475 trn_single_acc 0.9995244090307203 val_acc 0.9831632682733995
itr: 1110 trn_loss 117.01314612755964 val_loss 136.448174836831
itr: 1110 trn_acc 0.9870315549813696 trn_single_acc 0.9992949270471595 val_acc 0.9832900025594171
itr: 1120 trn_loss 50.

itr: 1570 trn_loss 57.38798131951699 val_loss 22.601793708055528
itr: 1570 trn_acc 0.9960360027530036 trn_single_acc 0.9998523954062404 val_acc 0.9946640880892093
itr: 1580 trn_loss 20.47700044900031 val_loss 20.700480432792883
itr: 1580 trn_acc 0.9986178396233567 trn_single_acc 0.9999485334604963 val_acc 0.9950488941660001
itr: 1590 trn_loss 7.412791542309354 val_loss 18.939658258692365
itr: 1590 trn_acc 0.9995180704759038 trn_single_acc 0.9999820547272885 val_acc 0.9954271021596022
itr: 1600 trn_loss 2.8558480723293735 val_loss 17.31541448921198
itr: 1600 trn_acc 0.9998319615653001 trn_single_acc 0.9999937428703037 val_acc 0.9957674893538441
itr: 1610 trn_loss 1.2656639822957543 val_loss 16.012928588005018
itr: 1610 trn_acc 0.9999414086207119 trn_single_acc 0.9999978182737779 val_acc 0.9960313277960079
itr: 1620 trn_loss 0.7147779177571 val_loss 14.565954014192918
itr: 1620 trn_acc 0.9999795704492666 trn_single_acc 0.9999992392791041 val_acc 0.9963644299674265
itr: 1630 trn_loss 0.52

itr: 2080 trn_loss 0.3369230999942918 val_loss 13.999696853277097
itr: 2080 trn_acc 0.9999999999999497 trn_single_acc 0.9999999999999976 val_acc 0.9984872756359457
itr: 2090 trn_loss 0.3354359069916078 val_loss 13.982544340953538
itr: 2090 trn_acc 0.9999999999999825 trn_single_acc 0.9999999999999991 val_acc 0.9985003904662264
itr: 2100 trn_loss 0.33395898442462923 val_loss 13.99284732186673
itr: 2100 trn_acc 0.9999999999999938 trn_single_acc 0.9999999999999994 val_acc 0.9985121938134789
itr: 2110 trn_loss 0.33249311611151994 val_loss 14.024265275989872
itr: 2110 trn_acc 0.9999999999999978 trn_single_acc 0.9999999999999994 val_acc 0.9985228168260061
itr: 2120 trn_loss 0.3310377971554547 val_loss 14.075074183876726
itr: 2120 trn_acc 0.9999999999999992 trn_single_acc 0.9999999999999994 val_acc 0.9985323775372805
itr: 2130 trn_loss 0.3295942260269996 val_loss 14.142529953660441
itr: 2130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985409821774276
itr: 2140 trn_l

itr: 2580 trn_loss 0.27303652609214935 val_loss 18.44792107312207
itr: 2580 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985780951120112
itr: 2590 trn_loss 0.27192869596622493 val_loss 18.381887014821096
itr: 2590 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985768112603712
itr: 2600 trn_loss 0.27082643254063343 val_loss 18.297974779708614
itr: 2600 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985756557938953
itr: 2610 trn_loss 0.26972807497377677 val_loss 18.2146609137983
itr: 2610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985746158740669
itr: 2620 trn_loss 0.2686325290106458 val_loss 18.141759733307143
itr: 2620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985736799462214
itr: 2630 trn_loss 0.2675457801165735 val_loss 18.077488774014515
itr: 2630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985728376111604
itr: 2640 trn_l

itr: 3080 trn_loss 0.22302722103827538 val_loss 12.22570236339406
itr: 3080 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985640959968152
itr: 3090 trn_loss 0.2221226782459449 val_loss 12.051784203775846
itr: 3090 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985642120566949
itr: 3100 trn_loss 0.22122336538184129 val_loss 11.866509157116523
itr: 3100 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985643165105865
itr: 3110 trn_loss 0.2203328447359417 val_loss 11.71573289426742
itr: 3110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985644105190891
itr: 3120 trn_loss 0.21943552838402003 val_loss 11.56894693546812
itr: 3120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985644951267413
itr: 3130 trn_loss 0.21854131852651765 val_loss 11.418662635445477
itr: 3130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985645712736283
itr: 3140 trn_l

itr: 3590 trn_loss 0.18062980239364795 val_loss 3.832119235271194
itr: 3590 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9989481126341653
itr: 3600 trn_loss 0.17986713866776996 val_loss 3.673575552291255
itr: 3600 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9989735920792907
itr: 3610 trn_loss 0.17910455739255032 val_loss 3.5283404040216633
itr: 3610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9989912128060541
itr: 3620 trn_loss 0.17834682278396743 val_loss 3.375417814959219
itr: 3620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9990123822339906
itr: 3630 trn_loss 0.1775909364169483 val_loss 3.2452289101265053
itr: 3630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9990314347191335
itr: 3640 trn_loss 0.17683699875319955 val_loss 3.143787464807825
itr: 3640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9990379544475986
itr: 3650 trn_l

itr: 4100 trn_loss 0.32807753453123584 val_loss 37.7262859903892
itr: 4100 trn_acc 0.999884951640765 trn_single_acc 0.9999939952267757 val_acc 0.9906018324012822
itr: 4110 trn_loss 2.5957111540017737 val_loss 38.7028433349538
itr: 4110 trn_acc 0.9981146038675658 trn_single_acc 0.9998661009642669 val_acc 0.9906542492492735
itr: 4120 trn_loss 3.096912005943724 val_loss 35.44067000151946
itr: 4120 trn_acc 0.9981319139894471 trn_single_acc 0.9998432509137362 val_acc 0.9914612942263847
itr: 4130 trn_loss 1.2108092413487561 val_loss 35.56919287563509
itr: 4130 trn_acc 0.9993486386838677 trn_single_acc 0.9999453449731144 val_acc 0.9919060027592211
itr: 4140 trn_loss 0.555924049703387 val_loss 34.12395285595488
itr: 4140 trn_acc 0.9997728843523496 trn_single_acc 0.999980942970482 val_acc 0.9923487504714277
itr: 4150 trn_loss 0.3653399836505904 val_loss 36.23881739701955
itr: 4150 trn_acc 0.9999208096702549 trn_single_acc 0.9999933552246746 val_acc 0.9925665557736348
itr: 4160 trn_loss 0.256630

itr: 4610 trn_loss 0.17343204854309435 val_loss 3.675760669775147
itr: 4610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9992796513623595
itr: 4620 trn_loss 0.17288227976773857 val_loss 3.6440957337027835
itr: 4620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9992879211771429
itr: 4630 trn_loss 0.1723289788527032 val_loss 3.615999050385301
itr: 4630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9992953640104479
itr: 4640 trn_loss 0.17177598978944622 val_loss 3.5881939298926695
itr: 4640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993020625604225
itr: 4650 trn_loss 0.17123253355167553 val_loss 3.560290917785666
itr: 4650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993080912553997
itr: 4660 trn_loss 0.170677601244779 val_loss 3.5302212223694043
itr: 4660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999313517080879
itr: 4670 trn_los

itr: 5120 trn_loss 0.14663303054137544 val_loss 2.786480251334358
itr: 5120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994649567369969
itr: 5130 trn_loss 0.14613286580877127 val_loss 2.7822368950089054
itr: 5130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.99946532352248
itr: 5140 trn_loss 0.14563687206401413 val_loss 2.7754722127711986
itr: 5140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994656536294148
itr: 5150 trn_loss 0.14513592761497487 val_loss 2.7673866709526966
itr: 5150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999465950725656
itr: 5160 trn_loss 0.14464245176581894 val_loss 2.7603988373687796
itr: 5160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994662181122732
itr: 5170 trn_loss 0.14414881035660324 val_loss 2.7555230847964527
itr: 5170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994664587602287
itr: 5180 trn_

itr: 5620 trn_loss 0.12282789313783563 val_loss 2.566508849569568
itr: 5620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994942251294505
itr: 5630 trn_loss 0.12237057733183662 val_loss 2.5427392960563977
itr: 5630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994969758495378
itr: 5640 trn_loss 0.12191491399908896 val_loss 2.5211196757235244
itr: 5640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994994514976163
itr: 5650 trn_loss 0.12146711691141704 val_loss 2.503734852392505
itr: 5650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9995016795808869
itr: 5660 trn_loss 0.12101534872208121 val_loss 2.493372224735347
itr: 5660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9995036848558305
itr: 5670 trn_loss 0.12056436238760111 val_loss 2.487140747577364
itr: 5670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9995054896032798
itr: 5680 trn_

itr: 6120 trn_loss 0.1012069479149661 val_loss 2.4779312358911394
itr: 6120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994735976866949
itr: 6130 trn_loss 0.1008115372546184 val_loss 2.478788377042016
itr: 6130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994731003772082
itr: 6140 trn_loss 0.100405394306258 val_loss 2.4828623543936614
itr: 6140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.99947265279867
itr: 6150 trn_loss 0.09998467559475623 val_loss 2.4992662558637804
itr: 6150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994722499779858
itr: 6160 trn_loss 0.0995722657544399 val_loss 2.5136882087304664
itr: 6160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994665707050561
itr: 6170 trn_loss 0.09916847948677895 val_loss 2.5215610403461284
itr: 6170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994667760937332
itr: 6180 trn_loss

itr: 6620 trn_loss 0.08183594937897196 val_loss 2.3493908351816466
itr: 6620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999445882665985
itr: 6630 trn_loss 0.08149790333999418 val_loss 2.33184360022763
itr: 6630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994481568585692
itr: 6640 trn_loss 0.08112564382924818 val_loss 2.3492663731165493
itr: 6640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994395761237317
itr: 6650 trn_loss 0.08075122069960551 val_loss 2.362536933719018
itr: 6650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994318534623778
itr: 6660 trn_loss 0.08039616863345564 val_loss 2.319762420683115
itr: 6660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994355305753228
itr: 6670 trn_loss 0.08003282114476107 val_loss 2.3284600305816614
itr: 6670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994335232426593
itr: 6680 trn_l

itr: 7120 trn_loss 0.0648932932676293 val_loss 2.052195954107476
itr: 7120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999413327091487
itr: 7130 trn_loss 0.06457684764460335 val_loss 2.0643864701240355
itr: 7130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9994029125990436
itr: 7140 trn_loss 0.0642651533954653 val_loss 2.093355390135863
itr: 7140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993935395558445
itr: 7150 trn_loss 0.06395873174033856 val_loss 2.132561873827355
itr: 7150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999379793043116
itr: 7160 trn_loss 0.06366958618495595 val_loss 2.1399604601567774
itr: 7160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993727319555097
itr: 7170 trn_loss 0.06334687860404567 val_loss 2.1561150422310043
itr: 7170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9993610662028146
itr: 7180 trn_los

itr: 7630 trn_loss 11.974597729862467 val_loss 52.14012199705934
itr: 7630 trn_acc 0.9985783000863143 trn_single_acc 0.9999278765645989 val_acc 0.9897930187396278
itr: 7640 trn_loss 4.232118018120615 val_loss 47.61274815260853
itr: 7640 trn_acc 0.9995042838918057 trn_single_acc 0.9999748521130496 val_acc 0.990601166702396
itr: 7650 trn_loss 7.302914154401743 val_loss 42.8637736826493
itr: 7650 trn_acc 0.9944834044806624 trn_single_acc 0.9996840166178587 val_acc 0.9915304225239929
itr: 7660 trn_loss 2.604174836881281 val_loss 38.610177808048164
itr: 7660 trn_acc 0.9980764820796546 trn_single_acc 0.9998898234072173 val_acc 0.9923508085209528
itr: 7670 trn_loss 0.9662171370001308 val_loss 34.99706747505829
itr: 7670 trn_acc 0.9993293107720295 trn_single_acc 0.999961583797493 val_acc 0.9929456875382423
itr: 7680 trn_loss 0.3940053861478063 val_loss 31.531567511323974
itr: 7680 trn_acc 0.9997661451261992 trn_single_acc 0.9999866050984353 val_acc 0.9936192362599277
itr: 7690 trn_loss 0.19362

itr: 8130 trn_loss 0.08315488170541095 val_loss 2.247345849431753
itr: 8130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9985185123261725
itr: 8140 trn_loss 0.08323088777718037 val_loss 2.0464865844463107
itr: 8140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9986507168510779
itr: 8150 trn_loss 0.11670736108669844 val_loss 1.9138534320308087
itr: 8150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9987165633826754
itr: 8160 trn_loss 0.09247834179147922 val_loss 1.7536273407560574
itr: 8160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.998823652028081
itr: 8170 trn_loss 0.0840126013265152 val_loss 1.6095762170215238
itr: 8170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.998914715074632
itr: 8180 trn_loss 0.08109130731356376 val_loss 1.5394635781818438
itr: 8180 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9989382235018613
itr: 8190 trn_

itr: 8630 trn_loss 0.0720778494690358 val_loss 0.26403586706774923
itr: 8630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997610954323297
itr: 8640 trn_loss 0.07191247625747725 val_loss 0.2748247233741655
itr: 8640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997584141384559
itr: 8650 trn_loss 0.07175175226384933 val_loss 0.2866805357676602
itr: 8650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997560009739695
itr: 8660 trn_loss 0.07158559525092574 val_loss 0.29934603836612617
itr: 8660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997485183520822
itr: 8670 trn_loss 0.07142621296356884 val_loss 0.313934539554805
itr: 8670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997417839923837
itr: 8680 trn_loss 0.0712595751970176 val_loss 0.3289382739051793
itr: 8680 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999735723068655
itr: 8690 tr

itr: 9130 trn_loss 0.06391603591798431 val_loss 0.8550290807277187
itr: 9130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996249313472925
itr: 9140 trn_loss 0.06375104370422992 val_loss 0.8538425038030945
itr: 9140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996305556880729
itr: 9150 trn_loss 0.06358616106150364 val_loss 0.8497118118200782
itr: 9150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996356175947754
itr: 9160 trn_loss 0.06341770636387654 val_loss 0.83975255752749
itr: 9160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996401733108076
itr: 9170 trn_loss 0.06325657629817866 val_loss 0.8269544744355687
itr: 9170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996442734552364
itr: 9180 trn_loss 0.06309730052947866 val_loss 0.8149072322723158
itr: 9180 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9996479635852225
itr: 9190 tr

itr: 9630 trn_loss 0.055818992723356615 val_loss 0.3332144316458648
itr: 9630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997572908325728
itr: 9640 trn_loss 0.055668460427499296 val_loss 0.3234798487803887
itr: 9640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997603067329887
itr: 9650 trn_loss 0.055493429667487854 val_loss 0.31487379462554016
itr: 9650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997630210433629
itr: 9660 trn_loss 0.05532324435782561 val_loss 0.30705012969106993
itr: 9660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997654639226997
itr: 9670 trn_loss 0.055177475802964615 val_loss 0.3003026314848414
itr: 9670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997676625141029
itr: 9680 trn_loss 0.055014362514786704 val_loss 0.29365971847835937
itr: 9680 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997696412463657
it

itr: 10130 trn_loss 0.047964878594471884 val_loss 0.2264492963720994
itr: 10130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997941888015275
itr: 10140 trn_loss 0.04780573309651904 val_loss 0.22627333000030847
itr: 10140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997935149050479
itr: 10150 trn_loss 24.570524878956352 val_loss 105.87980566496904
itr: 10150 trn_acc 0.9878719374999997 trn_single_acc 0.9992512362893222 val_acc 0.9882248150983505
itr: 10160 trn_loss 77.02347779685749 val_loss 102.67343596883347
itr: 10160 trn_acc 0.9825787163423148 trn_single_acc 0.9985880698931898 val_acc 0.9874468601655113
itr: 10170 trn_loss 51.636313947453694 val_loss 97.78285247326848
itr: 10170 trn_acc 0.9513831333646989 trn_single_acc 0.9971419663637657 val_acc 0.9870655319313242
itr: 10180 trn_loss 324.76050675307215 val_loss 147.172432216176
itr: 10180 trn_acc 0.9434139390009288 trn_single_acc 0.9956495098989643 val_acc 0.9644894881147177
itr:

itr: 10630 trn_loss 0.058480791044520586 val_loss 7.810820319043499
itr: 10630 trn_acc 0.9999999956440793 trn_single_acc 0.9999999997707028 val_acc 0.9980785220393651
itr: 10640 trn_loss 0.05782773205831561 val_loss 7.245140890288563
itr: 10640 trn_acc 0.9999999984811844 trn_single_acc 0.999999999920049 val_acc 0.9981006297048134
itr: 10650 trn_loss 0.05752506268655263 val_loss 6.813116616018007
itr: 10650 trn_acc 0.9999999994704217 trn_single_acc 0.9999999999721227 val_acc 0.9981205266037169
itr: 10660 trn_loss 0.05735772526968731 val_loss 6.435092055208138
itr: 10660 trn_acc 0.9999999998153474 trn_single_acc 0.9999999999902799 val_acc 0.99813843381273
itr: 10670 trn_loss 0.0572254900127309 val_loss 6.092316988176093
itr: 10670 trn_acc 0.9999999999356155 trn_single_acc 0.9999999999966108 val_acc 0.9981545503008419
itr: 10680 trn_loss 0.9322537846893737 val_loss 5.499603944777833
itr: 10680 trn_acc 0.9971874999775506 trn_single_acc 0.999783653019677 val_acc 0.9983072127462673
itr: 1069

itr: 11130 trn_loss 0.05413372348420498 val_loss 0.2597277219170441
itr: 11130 trn_acc 0.9999999999999988 trn_single_acc 0.9999999999999994 val_acc 0.9999051757261478
itr: 11140 trn_loss 0.05402709455715593 val_loss 0.2404207882351377
itr: 11140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999909341419219
itr: 11150 trn_loss 0.053923574911199396 val_loss 0.2230424423873453
itr: 11150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999130905429833
itr: 11160 trn_loss 0.0538179087838758 val_loss 0.2074091201895501
itr: 11160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999916464754371
itr: 11170 trn_loss 0.05371126759570034 val_loss 0.19335027926033965
itr: 11170 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.99991950154462
itr: 11180 trn_loss 0.05360838430763549 val_loss 0.1807058600443563
itr: 11180 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999222346558441
i

itr: 11620 trn_loss 0.04901381777633289 val_loss 0.06618394741417911
itr: 11620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999465941119089
itr: 11630 trn_loss 0.04891068911187585 val_loss 0.06601036374868703
itr: 11630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999946617966404
itr: 11640 trn_loss 0.048798538866604735 val_loss 0.06585189284475224
itr: 11640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999466394354497
itr: 11650 trn_loss 0.04869970090566717 val_loss 0.06571318654648879
itr: 11650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999466587575908
itr: 11660 trn_loss 0.048595625894710906 val_loss 0.06557358457236664
itr: 11660 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999466761475178
itr: 11670 trn_loss 0.048486770084458515 val_loss 0.06542122799386839
itr: 11670 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999466

itr: 12110 trn_loss 0.04387430136267254 val_loss 0.06409139686154304
itr: 12110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999165846923774
itr: 12120 trn_loss 0.043769624664531805 val_loss 0.06427392441028636
itr: 12120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999142987149763
itr: 12130 trn_loss 0.04366580743333123 val_loss 0.06445601205224646
itr: 12130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999122413353153
itr: 12140 trn_loss 0.04356650074121359 val_loss 0.06464368659432101
itr: 12140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999103896936203
itr: 12150 trn_loss 0.043452399712386805 val_loss 0.06483291439570817
itr: 12150 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9999087232160949
itr: 12160 trn_loss 0.043350324778308455 val_loss 0.06503966616712063
itr: 12160 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999907

itr: 12600 trn_loss 75.45805332068186 val_loss 41.528633970194285
itr: 12600 trn_acc 0.9673602844664062 trn_single_acc 0.9979188842129969 val_acc 0.9873547070421035
itr: 12610 trn_loss 26.524947911589972 val_loss 37.84268308118572
itr: 12610 trn_acc 0.9886192349024387 trn_single_acc 0.9992743597937201 val_acc 0.9884970170137812
itr: 12620 trn_loss 27.478160181467082 val_loss 36.656120423396736
itr: 12620 trn_acc 0.9838442725786378 trn_single_acc 0.9991621852574691 val_acc 0.985274074841517
itr: 12630 trn_loss 14.433818628606506 val_loss 33.238123742270744
itr: 12630 trn_acc 0.9913108758859137 trn_single_acc 0.9995416176152834 val_acc 0.9865128621777693
itr: 12640 trn_loss 5.541571217455349 val_loss 30.330755240495453
itr: 12640 trn_acc 0.994157789758065 trn_single_acc 0.9997522813201277 val_acc 0.9876915358293772
itr: 12650 trn_loss 1.9634491135268737 val_loss 27.618333131172168
itr: 12650 trn_acc 0.9979629472461059 trn_single_acc 0.9999136258371184 val_acc 0.9887523421158243
itr: 1266

itr: 13100 trn_loss 0.056960116115193714 val_loss 6.060400041228839
itr: 13100 trn_acc 0.9999998345556865 trn_single_acc 0.9999999913054889 val_acc 0.9982475243338371
itr: 13110 trn_loss 0.056150431958993 val_loss 5.497039078730024
itr: 13110 trn_acc 0.9999999423131349 trn_single_acc 0.9999999969684115 val_acc 0.9983696343596361
itr: 13120 trn_loss 0.05580048418981307 val_loss 4.9848508707204555
itr: 13120 trn_acc 0.999999979885834 trn_single_acc 0.9999999989429503 val_acc 0.9984848441567047
itr: 13130 trn_loss 0.055640679229718146 val_loss 4.518278634982028
itr: 13130 trn_acc 0.9999999929866239 trn_single_acc 0.9999999996314295 val_acc 0.9985885329740665
itr: 13140 trn_loss 0.05552833588667504 val_loss 4.090863618049072
itr: 13140 trn_acc 0.9999999975545869 trn_single_acc 0.9999999998714872 val_acc 0.9986871696440061
itr: 13150 trn_loss 0.05543634290733911 val_loss 3.7022386248049557
itr: 13150 trn_acc 0.9999999991473371 trn_single_acc 0.9999999999551903 val_acc 0.9987759426469517
itr

itr: 13600 trn_loss 0.05180480235794887 val_loss 0.1967202887155703
itr: 13600 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997666187332789
itr: 13610 trn_loss 0.05171996845015467 val_loss 0.19399091588420492
itr: 13610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997687018436241
itr: 13620 trn_loss 0.051639866667965356 val_loss 0.19161934244897505
itr: 13620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997705766429348
itr: 13630 trn_loss 0.051556257452759054 val_loss 0.189597920372485
itr: 13630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997722639623144
itr: 13640 trn_loss 0.05147630457516833 val_loss 0.1878840200254907
itr: 13640 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997737825497561
itr: 13650 trn_loss 0.05139183030084355 val_loss 0.1863723312750092
itr: 13650 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999775149278

itr: 14090 trn_loss 0.047670185110839876 val_loss 0.18700580231605696
itr: 14090 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997547613240869
itr: 14100 trn_loss 0.04758209581199567 val_loss 0.1874215134882217
itr: 14100 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997527134410374
itr: 14110 trn_loss 0.04749814284590103 val_loss 0.1878800596806364
itr: 14110 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997508703462928
itr: 14120 trn_loss 0.04741380948947826 val_loss 0.18848916941997024
itr: 14120 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997492115610227
itr: 14130 trn_loss 0.04732613308387101 val_loss 0.18918408352836918
itr: 14130 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997477186542796
itr: 14140 trn_loss 0.047242007682456316 val_loss 0.18975766955648318
itr: 14140 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999746375

itr: 14580 trn_loss 0.04335788943692678 val_loss 0.18226250879733114
itr: 14580 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997898144183719
itr: 14590 trn_loss 0.04326848100775481 val_loss 0.1803943935582192
itr: 14590 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997948887340573
itr: 14600 trn_loss 0.04318715700730577 val_loss 0.17864209922097488
itr: 14600 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997994556181742
itr: 14610 trn_loss 0.04309394357599637 val_loss 0.17707102756448714
itr: 14610 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997982550400298
itr: 14620 trn_loss 0.04300951227196584 val_loss 0.1757812717027795
itr: 14620 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.9997971745197
itr: 14630 trn_loss 0.04291511372148821 val_loss 0.1747554318726657
itr: 14630 trn_acc 0.9999999999999994 trn_single_acc 0.9999999999999994 val_acc 0.999796202051403

In [None]:
correct_percent.shape, percent_fully_correct.shape

In [None]:
10, None, 7, 9, None

In [None]:
cmd_np.shape, act_np.shape, mask_np.shape, struct_np.shape, cmd_lengths_np.shape

In [None]:
cmd_ind.shape, act_ind.shape, mask_ph.shape, act_lengths.shape, cmd_lengths.shape

In [None]:
trn_samples.shape, val_samples.shape

In [None]:

act_presoftmax = tf.stack(action_probabilities_presoftmax, 1)[:, :, :-1, :]
#batch, subprogram, timestep, action_selection
logprobabilities = tf.nn.log_softmax(act_presoftmax, -1)
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])
mask_ph_flat = tf.reshape(mask_ph, [-1, max_actions_per_subprogram])
act_ind_flat = tf.reshape(act_ind, [-1, max_actions_per_subprogram])

In [None]:
act_presoftmax_flat = tf.reshape(act_presoftmax, [-1, 9, num_act])

In [None]:
max_actions_per_subprogram

In [None]:
sess.run(act_presoftmax, feed_dict).shape

In [None]:
action_map

In [None]:
print(*actions_ind[2])

In [None]:
command_map

In [None]:
action_map

In [None]:
subprogram_output

In [None]:
subprogram_last_layer[:,cmd_lengths,:]

In [None]:
encoding_last_layer

In [None]:
tf.gather(
    encoding_last_layer,
    [1,2],
    axis=1
)

In [None]:
tf.gather_nd(
    encoding_last_layer,
    np.array([[0,1,2,3,4], [1,4,3,2,5]]).T,
    name=None
)

In [None]:
cmd_lengths

In [None]:
def generate_command(sub_cmd, num_repeat):
    return sub_cmd * num_repeat

In [None]:
def process_command(cmd):
    

In [None]:
uni_commands

In [None]:
uni_actions

In [None]:
uni_tokens

In [None]:
df.shape