In [1]:
import numpy as np
import tensorflow as tf

In [2]:
import utils
from dataset import Dataset

In [3]:
word_emb_size = 64
emb_size = 256
cell_size = emb_size
output_size = emb_size
word_size = 64
words_num = 128
read_heads = 4
interface_size = word_size * read_heads + word_size * 3 + 5 * read_heads + 3

In [4]:
batch_size = 64
learning_rate = 0.001
epoches = 10000 // batch_size

In [5]:
DataPath = "/notebooks/Share/dataset/babi/processed/"
SummariesDir = "/notebooks/WorkDir/tensorlog"
CheckPtDir = "./chpts"

In [6]:
filenames = [
    'qa1_single-supporting-fact',
    'qa20_agents-motivations',
    'qa15_basic-deduction',
    'qa16_basic-induction',
    'qa9_simple-negation',
    'qa4_two-arg-relations',
    'qa6_yes-no-questions',
    'qa10_indefinite-knowledge',
    'qa11_basic-coreference',
    'qa12_conjunction',
    'qa13_compound-coreference',
    'qa14_time-reasoning',
    'qa17_positional-reasoning',
    'qa18_size-reasoning',
    'qa19_path-finding',
    'qa7_counting',
    'qa8_lists-sets']

In [7]:
def parse_interface_vector(interface_vector):
    """
        pasres the flat interface_vector into its various components with their
        correct shapes
        Parameters:
        ----------
        interface_vector: Tensor (batch_size, interface_vector_size)
            the flattened inetrface vector to be parsed
        Returns: dict
            a dictionary with the components of the interface_vector parsed
    """

    parsed = {}

    r_keys_end = word_size * read_heads
    r_strengths_end = r_keys_end + read_heads
    w_key_end = r_strengths_end + word_size
    erase_end = w_key_end + 1 + word_size
    write_end = erase_end + word_size
    free_end = write_end + read_heads

    r_keys_shape = (-1, word_size, read_heads)
    r_strengths_shape = (-1,read_heads)
    w_key_shape = (-1, word_size, 1)
    write_shape = erase_shape = (-1,word_size)
    free_shape = (-1, read_heads)
    modes_shape = (-1, 3, read_heads)

    # parsing the vector into its individual components
    parsed['read_keys'] = tf.reshape(interface_vector[:, :r_keys_end], r_keys_shape)
    parsed['read_strengths'] = tf.reshape(interface_vector[:, r_keys_end:r_strengths_end], r_strengths_shape)
    parsed['write_key'] = tf.reshape(interface_vector[:, r_strengths_end:w_key_end], w_key_shape)
    parsed['write_strength'] = tf.reshape(interface_vector[:, w_key_end], (-1, 1))
    parsed['erase_vector'] = tf.reshape(interface_vector[:, w_key_end + 1:erase_end], erase_shape)
    parsed['write_vector'] = tf.reshape(interface_vector[:, erase_end:write_end], write_shape)
    parsed['free_gates'] = tf.reshape(interface_vector[:, write_end:free_end], free_shape)
    parsed['allocation_gate'] = tf.expand_dims(interface_vector[:, free_end], 1)
    parsed['write_gate'] = tf.expand_dims(interface_vector[:, free_end + 1], 1)
    parsed['read_modes'] = tf.reshape(interface_vector[:, free_end + 2:], modes_shape)

    # transforming the components to ensure they're in the right ranges
    parsed['read_strengths'] = 1 + tf.nn.softplus(parsed['read_strengths'])
    parsed['write_strength'] = 1 + tf.nn.softplus(parsed['write_strength'])
    parsed['erase_vector'] = tf.nn.sigmoid(parsed['erase_vector'])
    parsed['free_gates'] = tf.nn.sigmoid(parsed['free_gates'])
    parsed['allocation_gate'] = tf.nn.sigmoid(parsed['allocation_gate'])
    parsed['write_gate'] = tf.nn.sigmoid(parsed['write_gate'])
    parsed['read_modes'] = tf.nn.softmax(parsed['read_modes'], 1)

    return parsed

In [8]:
def __input_stage__(op, last_read_vectors):
    """
        processes input data through the controller network and returns the
        pre-output and interface_vector
        Parameters:
        ----------
        X: Tensor (batch_size, input_size)
            the input data batch
        last_read_vectors: (batch_size, word_size, read_heads)
            the last batch of read vectors from memory
        state: Tuple
            state vectors if the network is recurrent
        Returns: Tuple
            pre-output: Tensor (batch_size, output_size)
            parsed_interface_vector: dict
    """

    flat_read_vectors = tf.reshape(
        last_read_vectors,
        (-1, word_size * read_heads))
    cell_input = tf.concat(1, [op, flat_read_vectors])
    
    input_weights = tf.get_variable(
        "input_weights", [word_size * read_heads + emb_size, cell_size])
 
    return tf.matmul(cell_input, input_weights)

In [9]:
def __rnn_stage__(cell_input, state, cell):
    return cell(cell_input, state)

In [10]:
def __interface_stage__(cell_output):
    interface_weights = tf.get_variable(
        "interface_weights", [cell_size, interface_size])
    
    interface = tf.matmul(cell_output, interface_weights)
   
    return parse_interface_vector(interface)

In [11]:
def __output_stage__(cell_output, new_read):
    output_weights = tf.get_variable(
        "output_weights", [cell_size, output_size])
    
    mem_output_weights = tf.get_variable(
        "mem_output_weights", [word_size * read_heads, output_size])
    
    flat_read = tf.reshape(new_read, (-1, word_size * read_heads))
    
    controller_output = tf.matmul(cell_output, output_weights)
    memory_output = tf.matmul(flat_read, mem_output_weights)
    return controller_output + memory_output

In [12]:
def DNC_controller_exe_query(cur_inp, last_state, last_reads, cell):
    
    with tf.variable_scope("controller_input"):
        cell_input = __input_stage__(cur_inp, last_reads)
        
    with tf.variable_scope("controller_cell"):
        cell_output, cell_state = __rnn_stage__(cell_input, last_state, cell)
        
    with tf.variable_scope("controller_interface"):
        interface_dict = __interface_stage__(cell_output)
        
    return cell_output, cell_state, interface_dict

In [13]:
def DNC_controller_wb_output(cell_output, mem_reads):
    with tf.variable_scope("controller_output"):
        DNC_output = __output_stage__(cell_output, mem_reads)
        
    return DNC_output

In [14]:
def __get_memory__():
    with tf.variable_scope("memory_banks"):
        return (
            tf.get_variable("memory_matrix", [batch_size, words_num, word_size],
                initializer=tf.constant_initializer(1e-6), trainable=False),
            tf.get_variable("usage_vector", [batch_size, words_num],
                initializer=tf.constant_initializer(0), trainable=False),
            tf.get_variable("precedence_vector", [batch_size, words_num],
                initializer=tf.constant_initializer(0), trainable=False),
            tf.get_variable("link_matrix", [batch_size, words_num, words_num],
                initializer=tf.constant_initializer(0), trainable=False),
            tf.get_variable("write_weights", [batch_size, words_num],
                initializer=tf.constant_initializer(1e-6), trainable=False),
            tf.get_variable("read_weights", [batch_size, words_num, read_heads],
                initializer=tf.constant_initializer(1e-6), trainable=False),
            tf.get_variable("read_vector", [batch_size, word_size, read_heads],
                initializer=tf.constant_initializer(1e-6), trainable=False),
        )

In [15]:
def __get_lookup_weighting__(mem_mat, keys, strengths):
    normalized_memory = tf.nn.l2_normalize(mem_mat, 2)
    normalized_keys = tf.nn.l2_normalize(keys, 1)
    
    correlation = tf.matmul(normalized_memory, normalized_keys)
    strengths = tf.expand_dims(strengths, 1)
    
    return tf.nn.softmax(correlation * strengths, 1)

In [16]:
def __update_usage_vector__(usage_vec, read_weightings, write_weighting, free_gates):
    free_gates = tf.expand_dims(free_gates, 1)
    
    retention_vector = tf.reduce_prod(1 - read_weightings * free_gates, 2)
    updated_usage = (usage_vec + write_weighting - usage_vec * write_weighting) * retention_vector
    
    return updated_usage

In [17]:
def __get_allocation_weighting__(sorted_usage, free_list):
    shifted_cumprod = tf.cumprod(sorted_usage, axis = 1, exclusive=True)
    unordered_allocation_weighting = (1 - sorted_usage) * shifted_cumprod

    mapped_free_list = free_list + index_mapper
    flat_unordered_allocation_weighting = tf.reshape(unordered_allocation_weighting, (-1,))
    flat_mapped_free_list = tf.reshape(mapped_free_list, (-1,))
    flat_container = tf.TensorArray(tf.float32, batch_size *words_num)

    flat_ordered_weightings = flat_container.scatter(
        flat_mapped_free_list,
        flat_unordered_allocation_weighting
    )

    packed_wightings = flat_ordered_weightings.pack()
    return tf.reshape(packed_wightings, (batch_size, words_num))

In [18]:
def __update_write_weighting__(lookup_weighting, alloc_weighting, write_gate, alloc_gate):
    lookup_weighting = tf.squeeze(lookup_weighting)
    
    updated_write_weighting = write_gate * (alloc_gate * alloc_weighting + (1 - alloc_gate) * lookup_weighting)
    
    return updated_write_weighting

In [19]:
def __update_memory__(mem_mat, write_weighting, write_vec, erase_vec):
    write_weighting = tf.expand_dims(write_weighting, 2)
    write_vec = tf.expand_dims(write_vec, 1)
    erase_vec = tf.expand_dims(erase_vec, 1)
    
    erasing = mem_mat * (1 - tf.matmul(write_weighting, erase_vec))
    writing = tf.matmul(write_weighting, write_vec)
    updated_memory = erasing + writing
    
    return updated_memory

In [20]:
def __update_precedence_vector__(last_precedence, write_weighting):
    reset_factor = 1 - tf.reduce_sum(write_weighting, 1, keep_dims=True)
    updated_precedence = reset_factor * last_precedence + write_weighting
    
    return updated_precedence

In [21]:
def __update_link_matrix__(precedence_vec, link_matrix, write_weighting):
    write_weighting = tf.expand_dims(write_weighting, 2)
    precedence_vec = tf.expand_dims(precedence_vec, 1)
    
    reset_factor = 1 - utils.pairwise_add(write_weighting, is_batch=True)
    updated_link_matrix = reset_factor * link_matrix + tf.matmul(write_weighting, precedence_vec)
    updated_link_matrix = (1 - IMat) * updated_link_matrix
    
    return updated_link_matrix

In [22]:
def __get_directional_weightings__(read_weightings, link_matrix):
    forward_weighting = tf.matmul(link_matrix, read_weightings)
    backward_weighting = tf.batch_matmul(link_matrix, read_weightings, adj_x=True)
    
    return forward_weighting, backward_weighting

In [23]:
def __update_read_weightings__(lookup_weightings, forward_weighting, backward_weighting, read_mode):
    backward_mode = tf.expand_dims(read_mode[:, 0, :], 1) * backward_weighting
    lookup_mode = tf.expand_dims(read_mode[:, 1, :], 1) * lookup_weightings
    forward_mode = tf.expand_dims(read_mode[:, 2, :], 1) * forward_weighting
    updated_read_weightings = backward_mode + lookup_mode + forward_mode

    return updated_read_weightings

In [24]:
def __update_read_vectors__(mem_mat, read_weightings):
    updated_read_vectors = tf.batch_matmul(mem_mat, read_weightings, adj_x=True)
    return updated_read_vectors

In [25]:
def write_mem(mem_mat, usage_vector, read_weightings, write_weighting, precedence_vector,
             link_matrix, key, strength, free_gates, alloc_gate, write_gate, write_vec, erase_vec):
    
    lookup_weighting = __get_lookup_weighting__(mem_mat, key, strength)
    new_usage_vector = __update_usage_vector__(usage_vector, read_weightings, write_weighting, free_gates)

    sorted_usage, free_list = tf.nn.top_k(-1 * new_usage_vector, words_num)
    sorted_usage = -1 * sorted_usage

    alloc_weighting = __get_allocation_weighting__(sorted_usage, free_list)
    new_write_weighting = __update_write_weighting__(lookup_weighting, alloc_weighting, write_gate, alloc_gate)
    new_memory_matrix = __update_memory__(mem_mat, new_write_weighting, write_vec, erase_vec)
    new_link_matrix = __update_link_matrix__(precedence_vector, link_matrix, new_write_weighting)
    new_precedence_vector = __update_precedence_vector__(precedence_vector, new_write_weighting)

    return new_usage_vector, new_write_weighting, new_memory_matrix, new_link_matrix, new_precedence_vector

In [26]:
def read_mem(memory_matrix, read_weightings, keys, strengths, link_matrix, read_modes):
    lookup_weighting = __get_lookup_weighting__(memory_matrix, keys, strengths)
    forward_weighting, backward_weighting = __get_directional_weightings__(read_weightings, link_matrix)
    new_read_weightings = __update_read_weightings__(lookup_weighting, forward_weighting, backward_weighting, read_modes)
    new_read_vectors = __update_read_vectors__(memory_matrix, new_read_weightings)

    return new_read_weightings, new_read_vectors

In [27]:
def graph_step(cur_inp, memory_bank, controller_state, cell, time):
    
    last_read = memory_bank[6]
    cell_output, cell_state, interface = DNC_controller_exe_query(
        cur_inp, controller_state, last_read, cell)
    
    mem_mat, usage_vec, pred_vec, link_mat, wrt_wght, rd_wght, rd_vec = memory_bank
    
    usage_vec, wrt_wght, mem_mat, link_mat, pred_vec = write_mem(
    mem_mat, usage_vec, rd_wght, wrt_wght, pred_vec, link_mat,
    interface['write_key'], interface['write_strength'], interface['free_gates'],
    interface['allocation_gate'], interface['write_gate'], interface['write_vector'],
    interface['erase_vector'])
    
    rd_wght, rd_vec = read_mem(
    mem_mat, rd_wght, interface['read_keys'], interface['read_strengths'],
    link_mat, interface['read_modes'])
    
    unit_output = DNC_controller_wb_output(cell_output, rd_vec)
    
    return (
        (mem_mat, usage_vec, pred_vec, link_mat, wrt_wght, rd_wght, rd_vec),
        interface['free_gates'], interface['allocation_gate'], interface['write_gate'],
        rd_wght, wrt_wght, usage_vec, unit_output, cell_state, time + 1)

In [28]:
def build_graph(stories, query, decoder_inputs, loss_labels, meta, keep_prob, prob_list):
    
    with tf.variable_scope("embedding"):
        embedding_weights = tf.get_variable(
            "embedding", [meta["vocab_size"], word_emb_size])
    
    embedded_stories = tf.nn.embedding_lookup(embedding_weights, stories)
    embedded_query = tf.nn.embedding_lookup(embedding_weights, query)
    
    sentence_emb = []
        
    # sentence embedding
    with tf.variable_scope("sentence_encoder") as scope:
        cell = tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.LSTMCell(word_emb_size, use_peepholes=True),
            output_keep_prob=keep_prob)
        
        init_state = cell.zero_state(batch_size, tf.float32)
        
        output_mat = tf.get_variable("output_mat", [word_emb_size, emb_size])
        
        for sub_id in range(meta['max_story_length']):
            if sub_id > 0: scope.reuse_variables()
            sen_emb_out, sen_emb_state = tf.nn.dynamic_rnn(
                cell, embedded_stories[:, sub_id, :, :], 
                sequence_length=[meta["max_sentence_length"]] * batch_size, 
                initial_state=init_state)
            
            sentence_emb.append(
                tf.matmul(tf.squeeze(sen_emb_out[:, -1, :]), output_mat))
        
        scope.reuse_variables()
            
        query_emb_out, query_emb_state = tf.nn.dynamic_rnn(
                cell, embedded_query, 
                sequence_length=[meta["max_query_length"]] * batch_size, 
                initial_state=init_state)
        
        query_emb = tf.matmul(tf.squeeze(query_emb_out[:, -1, :]), output_mat)


    DNC_out = []
        
    with tf.variable_scope("DNC") as scope:
        
        controller_cell = tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.LSTMCell(emb_size, use_peepholes=True),
            output_keep_prob=keep_prob)
        
        # encode stories
        for time in range(meta["max_story_length"]):
            if time == 0:
                memory_bank = __get_memory__()
                controller_state = controller_cell.zero_state(
                    batch_size, dtype=tf.float32)
            else:
                scope.reuse_variables()
            
            cur_inp = sentence_emb[time]

            DNC_out.append(
                graph_step(
                    cur_inp, memory_bank, controller_state, 
                    controller_cell, time))
            
            memory_bank = DNC_out[-1][0]
            controller_state = DNC_out[-1][-2]
         
        # encode query
        scope.reuse_variables()
        query_out = graph_step(query_emb, memory_bank, controller_state, controller_cell, time)
        
        
    # decoder answers
    with tf.variable_scope("decoder") as scope:
        
        decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(
            tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(
                    word_emb_size, use_peepholes=True),
                output_keep_prob=keep_prob),
            meta["vocab_size"])
            

        proj_to_c = tf.get_variable("proj_to_c", [output_size, word_emb_size])
        proj_to_h = tf.get_variable("proj_to_h", [output_size, word_emb_size])

        init_state = tf.nn.rnn_cell.LSTMStateTuple(
            tf.matmul(query_out[-3], proj_to_c),
            tf.matmul(query_out[-3], proj_to_h))
        
        prob_list.append(decoder_inputs)
        prob_list.append(init_state)
        
        decode_outputs, decode_states = tf.nn.seq2seq.embedding_rnn_decoder(
            decoder_inputs, init_state, decoder_cell, meta["vocab_size"], word_emb_size,
            output_projection=None, feed_previous=False, update_embedding_for_previous=True,
            scope=None)
                
        scope.reuse_variables()
                
        decode_outputs_with_exposure, decode_states_with_exposure = tf.nn.seq2seq.embedding_rnn_decoder(
            decoder_inputs, init_state, decoder_cell, meta["vocab_size"], word_emb_size,
            output_projection=None, feed_previous=True, update_embedding_for_previous=True,
            scope=None)

    prob_list.append(decode_outputs)    
    
    loss_weights = [ 
        tf.ones_like([batch_size], dtype=tf.float32)
    ] * meta["max_ans_length"]
    loss = tf.nn.seq2seq.sequence_loss(
        decode_outputs, loss_labels, loss_weights)
    train_op = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(loss)
    
    tf.summary.scalar("loss", loss)
    test_output = tf.argmax(decode_outputs_with_exposure, axis=2)
            
    merged = tf.summary.merge_all()
    
    return train_op, loss, test_output, merged

In [29]:
def define_placeholders(meta):
    return {
        
        # story shape: [batch, story_length, sentence_length]
        "stories": tf.placeholder(tf.int32,
            shape=[batch_size, meta['max_story_length'], meta['max_sentence_length']]),
        
        # query shape: [batch, query_length]
        "query": tf.placeholder(tf.int32,
            shape=[batch_size, meta['max_query_length']]),
        
        # decoder input shape: [answer_length][batch]
        "decoder_inputs": [
            tf.zeros_like(tf.zeros([batch_size]), dtype=tf.int32, name="GO") 
        ] + [
            tf.placeholder(tf.int32, shape=[batch_size], name="decoder_inputs_{}".format(t))
            for t in range(meta["max_ans_length"] - 1)
        ],
        
        # loss_labels shape: [answer_length][batch]
        "loss_labels": [
            tf.placeholder(tf.int32, shape=[batch_size], name="loss_label_{}".format(t))
            for t in range(meta["max_ans_length"])
        ],
        
        # keep_prob: keep probability for LSTM cell dropout wrappers
        "keep_prob": tf.placeholder(tf.float32)
    }

In [30]:
def build_feeds(ph_dict, dataset, train_or_test="train"):
    batch = dataset.next_batch(train_or_test)
    meta = dataset.metadata
    story = np.concatenate(batch[:, 0].tolist(), axis=0).reshape(batch_size, meta['max_story_length'], meta['max_sentence_length'])
    query = np.concatenate(batch[:, 1].tolist(), axis=0).reshape(batch_size, -1)
    ans = [ np.squeeze(t) for t in np.split(np.transpose(np.concatenate(
        batch[:, 2].tolist(), axis=0).reshape(batch_size, -1)), dataset.metadata["max_ans_length"], axis=0)]
    dec_inp = ans[:-1]
    los_lbl = ans
    #sup = np.concatenate(batch[:, 3].tolist(), axis=0).reshape(batch_size, -1)
    keep_prob = 0.8 if train_or_test=="train" else 1.0
    
    feed_dict = {
        ph_dict["stories"]: story,
        ph_dict["query"]: query,
        ph_dict["keep_prob"]: keep_prob
    }
    
    for time_step in range(len(los_lbl)):
        if time_step < len(los_lbl) - 1:
            feed_dict[ph_dict['decoder_inputs'][time_step + 1]] = ans[time_step]
        feed_dict[ph_dict['loss_labels'][time_step]] = ans[time_step]
        
    return feed_dict

In [34]:
epoch_counter = 1
train_loss_curve = []
train_loss_curve_by_cat = {}
test_loss_curve = []
test_loss_curve_by_cat = {}

#filename = 'qa6_yes-no-questions'

for filename in filenames:    
    
    dataset = Dataset(DataPath + filename, batch_size)
    dataset.load_dataset()
    meta = dataset.metadata
    
    prob_list = []
    
    tf.reset_default_graph()
    graph = tf.Graph()
    
    print("graph reseted")
    
    with graph.as_default():
               
        # define constants
        index_mapper = tf.constant(
            np.cumsum([0] + [words_num] * (batch_size - 1), dtype=np.int32)[:, np.newaxis])

        IMat = tf.constant(np.identity(words_num, dtype=np.float32))
    
        # define placeholders
        placeholders = define_placeholders(meta)
    
        # define graph
        train_op, loss, test_output, merged = build_graph(
            placeholders['stories'],
            placeholders['query'],
            placeholders['decoder_inputs'],
            placeholders['loss_labels'],
            dataset.metadata,
            placeholders['keep_prob'],
            prob_list)
        
        saver = tf.train.Saver(tf.trainable_variables())
        
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            print(shape)
            variable_parametes = 1
            for dim in shape:
                variable_parametes *= dim.value
            print(variable_parametes)
            total_parameters += variable_parametes
        print(total_parameters)
        
        raise KeyboardInterrupt
        
        print("graph built")
    
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 1.0
        
        with tf.Session(graph=graph, config=config) as sess:
            sess.run(tf.global_variables_initializer())
        
            #if not filename == 'qa1_single-supporting-fact':
            saver.restore(sess, "./chkpts/DNC_bAbI_full.ckpt")
            epoch_counter = np.load("./counter.npy")[0]

            train_writer = tf.summary.FileWriter(
                SummariesDir + '/train')
            test_writer = tf.summary.FileWriter(
                SummariesDir + '/test')
                
            print("parameters loaded")
            
            # training loop
            for epoch in range(epoches):
                
                feed_dict = build_feeds(placeholders, dataset, "train")
                _, loss_val, merged_val = sess.run([train_op, loss, merged], feed_dict=feed_dict)
                
                train_loss_curve.append(loss_val)
                
                train_writer.add_summary(merged_val, epoch_counter)
                
                epoch_counter += 1
                
                if epoch % 20==0:
                    
                    feed_dict = build_feeds(placeholders, dataset, "test")
                    
                    loss_val, merged_val = sess.run([loss, merged], feed_dict=feed_dict)
                    
                    test_loss_curve.append(loss_val)
                    test_writer.add_summary(merged_val, epoch_counter)
                
                    print(test_loss_curve[-1])
                
                
            print(loss_val)   
            saver.save(sess, "./chkpts/DNC_bAbI_full.ckpt")
            
            np.save("./{}_train_curve_full_1.npy".format(filename), np.array(train_loss_curve))
            np.save("./{}_test_curve_full_1.npy".format(filename), np.array(test_loss_curve))
            train_loss_curve = []
            test_loss_curve = []
            np.save("./counter.npy", np.array([epoch_counter]))
            
            print("parameter saved")  
            print

graph reseted
(163, 64)
10432
(64, 256)
16384
(128, 256)
32768
(256,)
256
(64,)
64
(64,)
64
(64,)
64
(512, 256)
131072
(512, 1024)
524288
(1024,)
1024
(256,)
256
(256,)
256
(256,)
256
(256, 471)
120576
(256, 256)
65536
(256, 256)
65536
(256, 64)
16384
(256, 64)
16384
(163, 64)
10432
(128, 256)
32768
(256,)
256
(64,)
64
(64,)
64
(64,)
64
(64, 163)
10432
(163,)
163
1055843


KeyboardInterrupt: 

In [33]:
total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    print(shape)
    variable_parametes = 1
    for dim in shape:
        print(dim)
        variable_parametes *= dim.value
    print(variable_parametes)
    total_parameters += variable_parametes
print(total_parameters)

0


In [None]:
tf.