In [1]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
import sys
import subprocess
import pdb
import time
import datetime
import math
import random
import _pickle as cPickle
from collections import defaultdict

from six.moves import zip_longest
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import distributions as tfd
from tensorflow.keras.preprocessing.sequence import pad_sequences

from data_structure import get_batches, get_test_batches
from components import tf_log, sample_latents, compute_kl_loss, dynamic_rnn, dynamic_bi_rnn
from topic_model import TopicModel

from topic_beam_search_decoder import BeamSearchDecoder

In [2]:
PAD = '<pad>' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNK = '<unk>' # This has a vocab id, which is used to represent out-of-vocabulary words
BOS = '<p>' # This has a vocab id, which is used at the beginning of every decoder input sequence
EOS = '</p>' # This has a vocab id, which is used at the end of untruncated target sequences

# load data & set config

In [3]:
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()    
    keys_list = [keys for keys in flags_dict]    
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('gpu', '2', 'visible gpu')

flags.DEFINE_string('mode', 'train', 'set train or eval')

flags.DEFINE_string('data_path', 'data/bags/instances.pkl', 'path of data')
flags.DEFINE_string('modeldir', 'model/tglm_vae_tmp3', 'directory of model')
flags.DEFINE_string('modelname', 'bags', 'name of model')

flags.DEFINE_integer('epochs', 50, 'epochs')
flags.DEFINE_integer('batch_size', 64, 'number of sentences in each batch')
flags.DEFINE_integer('log_period', 500, 'valid period')

flags.DEFINE_string('opt', 'Adagrad', 'optimizer')
flags.DEFINE_float('lr', 0.5, 'lr')
flags.DEFINE_float('reg', 1., 'regularization term')
flags.DEFINE_float('grad_clip', 5., 'grad_clip')

flags.DEFINE_float('keep_prob', 0.8, 'dropout rate')
flags.DEFINE_float('word_keep_prob', 0.75, 'word dropout rate')

flags.DEFINE_bool('warmup', True, 'flg of warming up')
flags.DEFINE_integer('epochs_cycle', 5, 'number of epochs within a cycle')
flags.DEFINE_float('r_cycle', 0.5, 'proportion used to increase beta within a cycle')
flags.DEFINE_integer('warmup_topic', 0, 'warmup period for KL of topic')

flags.DEFINE_integer('beam_width', 2, 'beam_width')
flags.DEFINE_float('length_penalty_weight', 0.0, 'length_penalty_weight')

flags.DEFINE_integer('n_topic', 10, 'number of topic')
flags.DEFINE_integer('dim_hidden_bow', 256, 'dim of hidden bow')
flags.DEFINE_integer('dim_latent_bow', 32, 'dim of latent topic')
flags.DEFINE_integer('dim_emb', 256, 'dim_emb')
flags.DEFINE_integer('dim_hidden', 512, 'dim_hidden')
flags.DEFINE_integer('dim_hidden_topic', 512, 'dim_hidden_topic')
flags.DEFINE_integer('dim_latent', 32, 'dim_latent')
flags.DEFINE_bool('bidirectional', True, 'flg of bidirectional encoding')

# for evaluation
flags.DEFINE_string('refdir', 'ref', 'refdir')
flags.DEFINE_string('outdir', 'out', 'outdir')

flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_bool('logtostderr', True, 'kernel')
flags.DEFINE_bool('showprefixforinfo', False, '')
flags.DEFINE_bool('verbosity', False, '')
# flags.DEFINE_integer('stderrthreshold', 20, 'kernel')

config = flags.FLAGS

flags.DEFINE_string('modelpath', os.path.join(config.modeldir, config.modelname), 'path of model')

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [5]:
instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(open(config.data_path,'rb'))

In [6]:
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_test_batches(instances_test, config.batch_size)

In [7]:
flags.DEFINE_integer('PAD_IDX', word_to_idx[PAD], 'PAD_IDX')
flags.DEFINE_integer('UNK_IDX', word_to_idx[UNK], 'UNK_IDX')
flags.DEFINE_integer('BOS_IDX', word_to_idx[BOS], 'BOS_IDX')
flags.DEFINE_integer('EOS_IDX', word_to_idx[EOS], 'EOS_IDX')

flags.DEFINE_integer('n_vocab', len(word_to_idx), 'n_vocab')
flags.DEFINE_integer('dim_bow', len(bow_idxs), 'dim_bow')

maximum_iterations = max([max([instance.max_sent_l for instance in batch]) for ct, batch in dev_batches])
flags.DEFINE_integer('maximum_iterations', maximum_iterations, 'maximum_iterations')

flags.DEFINE_integer('cycle_steps', len(train_batches)*config.epochs_cycle, 'number of steps for each cycle')

# build language model 

## feed dict

In [8]:
tf.reset_default_graph()

t_variables = {}
t_variables['bow'] = tf.placeholder(tf.float32, [None, config.dim_bow], name='bow')
t_variables['input_token_idxs'] = tf.placeholder(tf.int32, [None, None], name='input_token_idxs')
t_variables['dec_input_idxs'] = tf.placeholder(tf.int32, [None, None], name='dec_input_idxs')
t_variables['dec_target_idxs'] = tf.placeholder(tf.int32, [None, None], name='dec_target_idxs')
t_variables['batch_l'] = tf.placeholder(tf.int32, name='batch_l')
t_variables['doc_l'] = tf.placeholder(tf.int32, [None], name='doc_l')
t_variables['sent_l'] = tf.placeholder(tf.int32, [None], name='sent_l')
t_variables['keep_prob'] = tf.placeholder(tf.float32, name='keep_prob')

In [9]:
def get_feed_dict(batch, mode='train', assertion=False):
    def token_dropout(sent_idxs):
        sent_idxs_dropout = np.asarray(sent_idxs)
        sent_idxs_dropout[np.random.rand(len(sent_idxs)) > config.word_keep_prob] = config.UNK_IDX
        return list(sent_idxs_dropout)

    bow = np.array([instance.bow for instance in batch]).astype(np.float32)
    
    doc_l = np.array([len(instance.token_idxs) for instance in batch])
    
    feed_input_token_idxs_list = [sent_idxs for instance in batch for sent_idxs in instance.token_idxs]
    feed_dec_input_idxs_list = [[config.BOS_IDX] + token_dropout(sent_idxs) for sent_idxs in feed_input_token_idxs_list]
    feed_dec_target_idxs_list = [sent_idxs + [config.EOS_IDX]  for sent_idxs in feed_input_token_idxs_list]
        
    sent_l = np.array([len(sent_idxs) for sent_idxs in feed_input_token_idxs_list], np.int32)
    batch_l = len(sent_l)
    
    feed_input_token_idxs = pad_sequences(feed_input_token_idxs_list, padding='post', value=config.PAD_IDX, dtype=np.int32)
    feed_dec_input_idxs = pad_sequences(feed_dec_input_idxs_list, padding='post', value=config.PAD_IDX, dtype=np.int32)
    feed_dec_target_idxs = pad_sequences(feed_dec_target_idxs_list, padding='post', value=config.PAD_IDX, dtype=np.int32)
    
    if assertion:
        index = 0
        for instance in batch:
            for line_idxs in instance.token_idxs:
                assert feed_input_token_idxs_list[index] == line_idxs
                index += 1
        assert feed_input_token_idxs.shape[1] == np.max(sent_l)
        assert feed_dec_input_idxs.shape[1] == np.max(sent_l) + 1
        assert feed_dec_target_idxs.shape[1] == np.max(sent_l) + 1
    
    keep_prob = config.keep_prob if mode == 'train' else 1.0

    feed_dict = {
                t_variables['bow']: bow, 
                t_variables['batch_l']: batch_l, t_variables['doc_l']: doc_l, t_variables['sent_l']: sent_l, 
                t_variables['input_token_idxs']: feed_input_token_idxs, t_variables['dec_input_idxs']: feed_dec_input_idxs, t_variables['dec_target_idxs']: feed_dec_target_idxs, 
                t_variables['keep_prob']: keep_prob
    }
    return  feed_dict

In [10]:
def debug_shape(variables):
    sample_batch = dev_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)

def debug_value(variables, return_value=False):
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)

    if return_value: 
        return _variables
    else:
        for _variable, variable in zip(_variables, variables):
            if hasattr(variable, 'name'):
                print(variable.name, ':', _variable)
            else:
                print(_variable)
                
def check_shape(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)
            
    sess.close()
    
def check_value(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable)
        else:
            print(_variable.shape)
            
    sess.close()    
    
# sent_loss_kl_categ_tmp = tf.reduce_mean(tf.reduce_sum(tf.multiply(prob_topic_infer, tf_log(prob_topic_infer/prob_topic_sents)), 1))
# debug_value([sent_loss_kl_categ, sent_loss_kl_categ_tmp])
# sent_loss_kl_gauss_tmp = 0.5 * tf.reduce_sum(tf.exp(logvars_topic_infer-logvars_topic) + tf.square(means_topic - means_topic_infer) / tf.exp(logvars_topic) - 1 + (logvars_topic - logvars_topic_infer), -1)
# sent_loss_kl_gmm_tmp = tf.reduce_mean(tf.reduce_sum(tf.multiply(prob_topic_infer, sent_loss_kl_gauss_tmp), -1))
# debug_value([sent_loss_kl_gmm_tmp, sent_loss_kl_gmm])    

## topic model

In [11]:
# encode bow
with tf.variable_scope('topic/enc', reuse=False):
    hidden_bow_ = tf.layers.Dense(units=config.dim_hidden_bow, activation=tf.nn.relu, name='hidden_bow')(t_variables['bow'])
    hidden_bow = tf.layers.Dropout(t_variables['keep_prob'])(hidden_bow_)
    means_bow = tf.layers.Dense(units=config.dim_latent_bow, name='mean_bow')(hidden_bow)
    logvars_bow = tf.layers.Dense(units=config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_topic')(hidden_bow)
    latents_bow = sample_latents(means_bow, logvars_bow) # sample latent vectors

    prob_topic = tf.layers.Dense(units=config.n_topic, activation=tf.nn.softmax, name='prob_topic')(latents_bow) # inference of topic probabilities

# decode bow
with tf.variable_scope('shared', reuse=False):
    embeddings = tf.get_variable('emb', [config.n_vocab, config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of vocab

bow_embeddings = tf.nn.embedding_lookup(embeddings, bow_idxs) # embeddings of each bow features

with tf.variable_scope('topic/dec', reuse=False):
    topic_embeddings = tf.get_variable('topic_emb', [config.n_topic, config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of topics

    topic_bow = tf.nn.softmax(tf.matmul(topic_embeddings, bow_embeddings, transpose_b=True), 1) # bow vectors for each topic
    logits_bow = tf_log(tf.matmul(prob_topic, topic_bow)) # predicted bow distribution

    # prior of each gaussian distribution (computed for each topic)
    hidden_topic = tf.layers.Dense(units=config.dim_hidden_topic, activation=tf.nn.relu, name='hidden_topic')(topic_bow)
    means_topic = tf.layers.Dense(units=config.dim_latent, name='mean_topic')(hidden_topic)
    logvars_topic = tf.layers.Dense(units=config.dim_latent, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_topic')(hidden_topic)
    sigma_topic = tf.exp(0.5 * logvars_topic)
    gauss_topic = tfd.Normal(loc=means_topic, scale=sigma_topic)    
    
# define losses
topic_losses_recon = -tf.reduce_sum(tf.multiply(t_variables['bow'], logits_bow), 1)
topic_loss_recon = tf.reduce_mean(topic_losses_recon) # negative log likelihood of each words

topic_loss_kl = compute_kl_loss(means_bow, logvars_bow) # KL divergence b/w latent dist & gaussian std

topic_bow_norm = topic_bow / tf.norm(topic_bow, axis=1, keepdims=True)
topic_dots = tf.clip_by_value(tf.matmul(topic_bow_norm, tf.transpose(topic_bow_norm)), -1., 1.)
topic_loss_reg = tf.reduce_mean(tf.square(topic_dots - tf.eye(config.n_topic)))
# topic_angles = tf.acos(topic_dots)
# topic_angles_mean = tf.reduce_mean(topic_angles)
# topic_angles_vars = tf.reduce_mean(tf.square(topic_angles - topic_angles_mean))
# topic_loss_reg = tf.exp(topic_angles_vars - topic_angles_mean)

# monitor
n_bow = tf.reduce_sum(t_variables['bow'], 1)
topic_ppls = tf.divide(topic_losses_recon, tf.maximum(1e-5, n_bow))
topics_freq_bow_indices = tf.nn.top_k(topic_bow, 10, name='topic_freq_bow').indices

## encoder

In [12]:
# input
input_token_idxs = t_variables['input_token_idxs']
batch_l = t_variables['batch_l']
sent_l = t_variables['sent_l']
max_sent_l = tf.reduce_max(sent_l)

with tf.variable_scope('sent/enc', reuse=False):
    # get word embedding
    enc_input = tf.nn.embedding_lookup(embeddings, input_token_idxs)

    # get sentence embedding
    _, enc_state = dynamic_bi_rnn(enc_input, sent_l, config.dim_hidden, t_variables['keep_prob'])

    # TODO House Holder flow
    hidden_topic_infer =  tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='hidden_topic_infer')(enc_state)
    prob_topic_infer = tf.layers.Dense(units=config.n_topic, activation=tf.nn.softmax, name='prob_topic_infer')(hidden_topic_infer)

    w_mean_topic_infer = tf.get_variable('mean_topic_infer/kernel', [config.n_topic, enc_state.shape[-1], config.dim_latent], dtype=tf.float32)
    b_mean_topic_infer = tf.get_variable('mean_topic_infer/bias', [1, config.n_topic, config.dim_latent], dtype=tf.float32)
    means_topic_infer = tf.tensordot(enc_state, w_mean_topic_infer, axes=[[1], [1]]) + b_mean_topic_infer
    
    w_logvar_topic_infer = tf.get_variable('logvar_topic_infer/kernel', [config.n_topic, enc_state.shape[-1], config.dim_latent], dtype=tf.float32, initializer=tf.constant_initializer(0))
    b_logvar_topic_infer = tf.get_variable('logvar_topic_infer/bias', [1, config.n_topic, config.dim_latent], dtype=tf.float32, initializer=tf.constant_initializer(0))
    logvars_topic_infer = tf.tensordot(enc_state, w_logvar_topic_infer, axes=[[1], [1]]) + b_logvar_topic_infer
    sigma_topic_infer = tf.exp(0.5 * logvars_topic_infer)
    gauss_topic_infer = tfd.Normal(loc=means_topic_infer, scale=sigma_topic_infer)
    
    # latent vectors from each gaussian dist.
    latents_topic_infer = sample_latents(means_topic_infer, logvars_topic_infer) 
    # latent vector from gaussian mixture
    latents_input = tf.matmul(tf.expand_dims(prob_topic_infer, -1), latents_topic_infer, transpose_a=True)
    
    # for beam search
    means_input = tf.matmul(tf.expand_dims(prob_topic_infer, -1), means_topic_infer, transpose_a=True)    

## decoder

In [13]:
# prepare for decoding
dec_sent_l = tf.add(sent_l, 1)
dec_input_idxs = t_variables['dec_input_idxs']
dec_input = tf.nn.embedding_lookup(embeddings, dec_input_idxs)

dec_latents_input = tf.tile(latents_input, [1, tf.shape(dec_input)[1], 1])
dec_concat_input = tf.concat([dec_input, dec_latents_input], -1)

# decode for training
with tf.variable_scope('sent/dec/rnn', initializer=tf.contrib.layers.xavier_initializer(), dtype = tf.float32, reuse=False):
    dec_cell = tf.contrib.rnn.GRUCell(config.dim_hidden)
    dec_cell = tf.contrib.rnn.DropoutWrapper(dec_cell, output_keep_prob = t_variables['keep_prob'])

    dec_initial_state = tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='init_state')(tf.squeeze(latents_input, 1))
    
    helper = tf.contrib.seq2seq.TrainingHelper(inputs=dec_concat_input, sequence_length=dec_sent_l)

    train_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=dec_cell,
        helper=helper,
        initial_state=dec_initial_state)

    dec_outputs, _, output_sent_l = tf.contrib.seq2seq.dynamic_decode(train_decoder)
    
    output_layer = tf.layers.Dense(config.n_vocab, use_bias=False, name='out')
    output_logits = output_layer(dec_outputs.rnn_output)
    
    output_token_idxs = tf.argmax(output_logits, 2)

In [14]:
start_tokens = tf.fill([batch_l], config.BOS_IDX)
end_token = config.EOS_IDX

with tf.variable_scope('sent/dec/rnn', reuse=True):
    infer_dec_initial_state = tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='init_state')(tf.squeeze(means_input, 1))
    beam_dec_initial_state = tf.contrib.seq2seq.tile_batch(infer_dec_initial_state, multiplier=config.beam_width)
    beam_latents_input = tf.contrib.seq2seq.tile_batch(tf.squeeze(means_input, 1), multiplier=config.beam_width) # added
    
    beam_decoder = BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=beam_dec_initial_state,
        beam_width=config.beam_width, 
        output_layer=output_layer,
        length_penalty_weight=config.length_penalty_weight,
        latents_input=beam_latents_input)

    beam_dec_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        beam_decoder,
        maximum_iterations = config.maximum_iterations)

    beam_output_token_idxs = beam_dec_outputs.predicted_ids[:, :, 0]

In [15]:
with tf.variable_scope('sent/dec/rnn', reuse=True):
    inter_means_input = tf.placeholder(tf.float32, [None, config.dim_latent])
    
    inter_dec_initial_state = tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='init_state')(inter_means_input)
    inter_beam_dec_initial_state = tf.contrib.seq2seq.tile_batch(inter_dec_initial_state, multiplier=config.beam_width)
    inter_beam_latents_input = tf.contrib.seq2seq.tile_batch(inter_means_input, multiplier=config.beam_width) # added
    
    inter_beam_decoder = BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=inter_beam_dec_initial_state,
        beam_width=config.beam_width, 
        output_layer=output_layer,
        length_penalty_weight=config.length_penalty_weight,
        latents_input=inter_beam_latents_input)

    inter_beam_dec_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        inter_beam_decoder,
        maximum_iterations = config.maximum_iterations)

    inter_beam_output_token_idxs = inter_beam_dec_outputs.predicted_ids[:, :, 0]

In [16]:
with tf.variable_scope('sent/dec/rnn', reuse=True):
    topic_dec_initial_state = tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='init_state')(means_topic)
    topic_beam_dec_initial_state = tf.contrib.seq2seq.tile_batch(topic_dec_initial_state, multiplier=config.beam_width)
    topic_beam_latents_input = tf.contrib.seq2seq.tile_batch(means_topic, multiplier=config.beam_width) # added
    
    topic_beam_decoder = BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=topic_beam_dec_initial_state,
        beam_width=config.beam_width, 
        output_layer=output_layer,
        length_penalty_weight=config.length_penalty_weight,
        latents_input=topic_beam_latents_input)

    topic_beam_dec_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        topic_beam_decoder,
        maximum_iterations = config.maximum_iterations)

    topic_beam_output_token_idxs = topic_beam_dec_outputs.predicted_ids[:, :, 0]

In [17]:
with tf.variable_scope('sent/dec/rnn', reuse=True):
    means_topic_summary = tf.reduce_mean(means_topic_infer, 0)
    
    summary_dec_initial_state = tf.layers.Dense(units=config.dim_hidden, activation=tf.nn.relu, name='init_state')(means_topic_summary)
    summary_beam_dec_initial_state = tf.contrib.seq2seq.tile_batch(summary_dec_initial_state, multiplier=config.beam_width)
    summary_beam_latents_input = tf.contrib.seq2seq.tile_batch(means_topic_summary, multiplier=config.beam_width) # added
    
    summary_beam_decoder = BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=summary_beam_dec_initial_state,
        beam_width=config.beam_width,
        output_layer=output_layer,
        length_penalty_weight=config.length_penalty_weight,
        latents_input=summary_beam_latents_input)

    summary_beam_dec_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        summary_beam_decoder,
        maximum_iterations = config.maximum_iterations)

    summary_beam_output_token_idxs = summary_beam_dec_outputs.predicted_ids[:, :, 0]

## language modeling cost

In [18]:
# target and mask
dec_target_idxs = t_variables['dec_target_idxs']
dec_mask_tokens = tf.sequence_mask(dec_sent_l, maxlen=max_sent_l+1, dtype=tf.float32)

# nll for each token (averaged over batch & sentence)
sent_loss_recon = tf.contrib.seq2seq.sequence_loss(output_logits, dec_target_idxs, dec_mask_tokens)

In [19]:
doc_l = t_variables['doc_l']
mask_sents = tf.sequence_mask(doc_l)
mask_sents_flatten = tf.reshape(mask_sents, [tf.shape(mask_sents)[0]*tf.shape(mask_sents)[1]])

prob_topic_tiled = tf.tile(tf.expand_dims(prob_topic, 1), [1, tf.shape(mask_sents)[1], 1])
prob_topic_flatten = tf.reshape(prob_topic_tiled, [tf.shape(mask_sents)[0]*tf.shape(mask_sents)[1], config.n_topic])
prob_topic_sents = tf.boolean_mask(prob_topic_flatten, mask_sents_flatten)

In [20]:
# inferred mixture probabilities (computed for each sentence)
categ_topic_infer = tfd.Categorical(probs=prob_topic_infer)

# prior of mixture probabilities (computed for each document, tiled for each sentence)
categ_topic = tfd.Categorical(probs=prob_topic_sents)

sent_loss_kl_categ = tf.reduce_mean(tfd.kl_divergence(categ_topic_infer, categ_topic))

# inference of each gaussian gaussribution (computed for each sentence)

sent_loss_kl_gauss = tf.reduce_sum(tfd.kl_divergence(gauss_topic_infer, gauss_topic), -1)
sent_loss_kl_gmm = tf.reduce_mean(tf.reduce_sum(tf.multiply(prob_topic_infer, sent_loss_kl_gauss), -1))

sent_loss_kl = sent_loss_kl_categ + sent_loss_kl_gmm

## optimizer

In [21]:
global_step = tf.Variable(0, name='global_step',trainable=False)
tau = tf.cast(tf.divide(tf.mod(global_step, tf.constant(config.cycle_steps)), tf.constant(config.cycle_steps)), dtype=tf.float32)
beta = tf.minimum(1., tau/config.r_cycle)

sent_loss = sent_loss_recon + beta * sent_loss_kl

topic_loss = topic_loss_recon + topic_loss_kl + config.reg * topic_loss_reg
loss = topic_loss + sent_loss

# define optimizer
if config.opt == 'Adam':
    optimizer = tf.train.AdamOptimizer(config.lr)
elif config.opt == 'Adagrad':
    optimizer = tf.train.AdagradOptimizer(config.lr)
    
grad_vars = optimizer.compute_gradients(loss)
clipped_grad_vars = [(tf.clip_by_value(grad, -config.grad_clip, config.grad_clip), var) for grad, var in grad_vars]

opt = optimizer.apply_gradients(clipped_grad_vars, global_step=global_step)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


# run model 

In [22]:
def idxs_to_sents(token_idxs, config, idx_to_word):
    sents = []
    for sent_idxs in token_idxs:
        tokens = []
        for idx in sent_idxs:
            if idx == config.EOS_IDX: break
            tokens.append(idx_to_word[idx])
        sent = ' '.join(tokens)
        sents.append(sent)
    return sents

In [23]:
def get_loss(sess, batches):
    losses = []
    ppl_list = []
    for ct, batch in batches:
        feed_dict = get_feed_dict(batch, mode='test')
        loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, sent_loss_recon_batch, sent_loss_kl_batch, ppls_batch \
            = sess.run([loss, topic_loss_recon, topic_loss_kl, topic_loss_reg, sent_loss_recon, sent_loss_kl, topic_ppls], feed_dict = feed_dict)
        losses += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, sent_loss_recon_batch, sent_loss_kl_batch]]
        ppl_list += list(ppls_batch)
    loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_reg_mean, sent_loss_recon_mean, sent_loss_kl_mean = np.mean(losses, 0)
    ppl_mean = np.exp(np.mean(ppl_list))
    return loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_reg_mean, sent_loss_recon_mean, sent_loss_kl_mean, ppl_mean

def get_all_losses(sess, batches):
    losses = []
    for ct, batch in batches:
        feed_dict = get_feed_dict(batch, mode='test')
        loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, sent_loss_recon_batch, sent_loss_kl_batch = \
        sess.run([loss, topic_loss_recon, topic_loss_kl, sent_loss_recon, sent_loss_kl], feed_dict = feed_dict)
        losses += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, sent_loss_recon_batch, sent_loss_kl_batch]]
    print('LOSS %.2f | TM NLL: %.2f, KL: %.4f | LM NLL: %.2f, KL: %.4f' %  np.mean(losses, 0))

In [24]:
def print_sample(sample_batch):
    feed_dict = get_feed_dict(sample_batch)
    pred_token_idxs = sess.run(output_token_idxs, feed_dict = feed_dict)
    true_token_idxs = [sent_idxs for instance in sample_batch for sent_idxs in instance.token_idxs]
    
    assert len(pred_token_idxs) == len(true_token_idxs)
    
    pred_sents = idxs_to_sents(pred_token_idxs, config, idx_to_word)
    true_sents = idxs_to_sents(true_token_idxs, config, idx_to_word)
    
    for i, (true_sent, pred_sent) in enumerate(zip(true_sents, pred_sents)):        
        print(i, 'TRUE: %s' % true_sent)
        print(i, 'PRED: %s' % pred_sent)

def print_topic_sample():
    pred_topics_freq_bow_indices, pred_topic_token_idxs = sess.run([topics_freq_bow_indices, topic_beam_output_token_idxs], 
                                                                                                           feed_dict={t_variables['batch_l']: config.n_topic, t_variables['keep_prob']: 1.,})
    pred_topic_sents = idxs_to_sents(pred_topic_token_idxs, config, idx_to_word)
    
    topics_freq_bow_idxs = bow_idxs[pred_topics_freq_bow_indices]
    
    print('-----------Topic Samples-----------')
    for i, (topic_freq_bow_idxs, pred_topic_sent) in enumerate(zip(topics_freq_bow_idxs, pred_topic_sents)):
        print(i, ' BOW:', ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
        print(i, ' SENTENCE:', pred_topic_sent)
        
def print_summary(test_batch):
    feed_dict = get_feed_dict(test_batch)
    feed_dict[t_variables['batch_l']] = config.n_topic
    feed_dict[t_variables['keep_prob']] = 1.
    pred_topics_freq_bow_indices, pred_summary_token_idxs = sess.run([topics_freq_bow_indices, summary_beam_output_token_idxs], feed_dict=feed_dict)
    pred_summary_sents = idxs_to_sents(pred_summary_token_idxs, config, idx_to_word)
    
    topics_freq_bow_idxs = bow_idxs[pred_topics_freq_bow_indices]
    
    print('-----------Output sentences for each topic-----------')
    print('Item idx:', test_batch[0].item_idx)
    for i, (topic_freq_bow_idxs, pred_summary_sent) in enumerate(zip(topics_freq_bow_idxs, pred_summary_sents)):
        print(i, ' BOW:', ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
        print(i, ' SENTENCE:', pred_summary_sent)
        
    print('-----------Summaries-----------')
    for i, summary in enumerate(test_batch[0].summaries):
        print('SUMMARY %i :'%i, '\n', summary)

In [25]:
if 'sess' in globals(): sess.close()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

losses_train = []
ppls_train = []
loss_min = np.inf
beta_eval = 1.
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)
saver = tf.train.Saver(max_to_keep=10)

log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','TM','','','','LM','','VALID:','TM','','','','LM','', ''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','NLL','KL','LOSS','PPL','NLL','KL','REG','NLL','KL', 'Beta']]))))

In [26]:
if len(log_df) == 0:
    cmd_rm = 'rm -r %s' % config.modeldir
    res = subprocess.call(cmd_rm.split())

    cmd_mk = 'mkdir %s' % config.modeldir
    res = subprocess.call(cmd_mk.split())

time_start = time.time()
while epoch < config.epochs:
    for ct, batch in train_batches:
        feed_dict = get_feed_dict(batch)

        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, sent_loss_recon_batch, sent_loss_kl_batch, sent_loss_kl_categ_batch, sent_loss_kl_gmm_batch, ppls_batch = \
        sess.run([opt, loss, topic_loss_recon, topic_loss_kl, topic_loss_reg, sent_loss_recon, sent_loss_kl, sent_loss_kl_categ, sent_loss_kl_gmm, topic_ppls], feed_dict = feed_dict)
   
        if sent_loss_kl_batch == np.inf:
            print('Nan occured')
            ckpt = tf.train.get_checkpoint_state(config.modeldir)
            model_checkpoint_path = ckpt.all_model_checkpoint_paths[-1]
            saver.restore(sess, model_checkpoint_path)            
            break
            
        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, sent_loss_recon_batch, sent_loss_kl_batch]]
        ppls_train += list(ppls_batch)

        if ct%config.log_period==0:
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train, sent_loss_recon_train, sent_loss_kl_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, sent_loss_recon_dev, sent_loss_kl_dev, ppl_dev = get_loss(sess, dev_batches)
            global_step_log, beta_eval = sess.run([tf.train.get_global_step(), beta])
            
            if loss_dev < loss_min:
                loss_min = loss_dev
                saver.save(sess, config.modelpath, global_step=global_step_log)

            clear_output()
    
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, '%.2f'%sent_loss_recon_train, '%.2f'%sent_loss_kl_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev, '%.2f'%sent_loss_recon_dev, '%.2f'%sent_loss_kl_dev,  '%.3f'%beta_eval],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)

            print_summary(test_batches[1][1])
            print_sample(batch)
            
            time_start = time.time()
            
    epoch += 1
    train_batches = get_batches(instances_train, config.batch_size, iterator=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,LM,Unnamed: 10_level_0,VALID:,TM,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,LM,Unnamed: 17_level_0,Unnamed: 18_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,NLL,KL,LOSS,PPL.1,NLL,KL,REG,NLL.1,KL,Beta
1,14,0,0,128.53,1036,118.02,0.48,0.90,9.12,1.49,128.62,1017,115.86,2.74,0.90,9.12,1.64,0.000
501,79,0,500,121.50,602,114.68,0.24,0.36,6.08,3.51,111.73,550,105.85,0.01,0.22,5.47,2.06,0.088
1001,63,0,1000,120.08,586,113.88,0.13,0.24,5.73,2.03,110.66,536,105.49,0.00,0.04,5.11,0.07,0.176
1501,65,0,1500,119.25,578,113.40,0.09,0.17,5.53,1.36,110.23,532,105.31,0.00,0.01,4.91,0.02,0.264
2001,62,0,2000,118.88,573,113.25,0.07,0.13,5.39,1.03,110.08,531,105.28,0.00,0.01,4.79,0.01,0.352
2276,29,1,0,118.93,572,113.38,0.06,0.11,5.33,0.90,110.12,537,105.38,0.00,0.01,4.73,0.01,0.400
2776,72,1,500,118.63,569,113.22,0.05,0.09,5.24,0.74,109.90,531,105.24,0.00,0.01,4.65,0.01,0.488
3276,71,1,1000,118.54,567,113.22,0.04,0.08,5.16,0.63,109.81,530,105.21,0.00,0.01,4.59,0.01,0.576
3776,55,1,1500,118.35,567,113.11,0.04,0.07,5.10,0.55,109.83,532,105.29,0.00,0.01,4.53,0.01,0.664
4276,65,1,2000,118.30,566,113.12,0.04,0.06,5.05,0.48,109.70,531,105.21,0.00,0.01,4.47,0.01,0.752


-----------Output sentences for each topic-----------
Item idx: B000VB7EFW
0  BOW: cover ! $ % & ' 'd 'll 'm 're
0  SENTENCE: 
1  BOW: cover ! $ % & ' 'd 'll 'm 're
1  SENTENCE: 
2  BOW: cover ; - ! $ % & ' 'd 'll
2  SENTENCE: 
3  BOW: cover ! $ % & ' 'd 'll 'm 're
3  SENTENCE: 
4  BOW: cover ! $ % & ' 'd 'll 'm 're
4  SENTENCE: 
5  BOW: cover ! $ % & ' 'd 'll 'm 're
5  SENTENCE: 
6  BOW: cover ! $ % & ' 'd 'll 'm 're
6  SENTENCE: 
7  BOW: cover ! $ % & ' 'd 'll 'm 're
7  SENTENCE: 
8  BOW: cover ! $ % & ' 'd 'll 'm 're
8  SENTENCE: 
9  BOW: cover ! $ % & ' 'd 'll 'm 're
9  SENTENCE: 
-----------Summaries-----------
SUMMARY 0 : 
 This is a very well made bag, nice construction, lots of pockets.
the straps are very comfortable.
and protects everything inside.
It says
it fits a 17inch notebook,
however it did not.
after using the pack for less than a month,
it is ripping out already.
SUMMARY 1 : 
 This is a very well made bag, nice construction, lots of pockets.
The quality is excellent


# confirm variables

In [27]:
_prob_topic, _prob_topic_sents, _prob_topic_infer, _means_topic_infer = debug_value([prob_topic, prob_topic_sents, prob_topic_infer, means_topic_infer], return_value=True)

In [28]:
batch_i = 4
_prob_topic_sents[batch_i], _prob_topic_infer[batch_i]

(array([2.6868487e-05, 8.6784212e-06, 9.9986529e-01, 1.4388333e-05,
        1.7386952e-05, 1.4713172e-05, 7.6274541e-06, 1.7465694e-05,
        1.1872702e-05, 1.5650594e-05], dtype=float32),
 array([1.4476807e-11, 1.0054337e-11, 1.0000000e+00, 1.1890730e-11,
        1.2296763e-11, 1.2778320e-11, 1.0013915e-11, 1.2678735e-11,
        1.2906035e-11, 1.3451982e-11], dtype=float32))

In [29]:
_means_topic_infer[0][:, :4]

array([[-1.8682381e+01, -1.7612494e+01, -1.7935837e+01, -1.7858501e+01],
       [-1.8674824e+01, -1.8183336e+01, -1.8666691e+01, -1.8586023e+01],
       [-3.4380314e-01, -2.0471608e-02,  1.2687216e-02,  9.8747285e-03],
       [-1.8649149e+01, -1.8267252e+01, -1.8531384e+01, -1.8472548e+01],
       [-1.8672371e+01, -1.7965401e+01, -1.8521042e+01, -1.8361164e+01],
       [-1.8543634e+01, -1.7945541e+01, -1.8386732e+01, -1.8298574e+01],
       [-1.8830236e+01, -1.8141018e+01, -1.8735767e+01, -1.8616648e+01],
       [-1.8482653e+01, -1.7899456e+01, -1.8424776e+01, -1.8251898e+01],
       [-1.8605921e+01, -1.8031254e+01, -1.8298433e+01, -1.8395790e+01],
       [-1.8475525e+01, -1.8080276e+01, -1.8305523e+01, -1.8298229e+01]],
      dtype=float32)

In [30]:
w_means_topic, b_means_topic = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "topic/dec/mean_topic")

pred_topic_embeddings, pred_topic_bow, pred_means_topic, pred_logvars_topic, pred_token_idxs, _w_means_topic, _b_means_topic, _w_mean_topic_infer = \
                                sess.run([topic_embeddings, topic_bow, means_topic, logvars_topic, topic_beam_output_token_idxs, w_means_topic, b_means_topic, w_mean_topic_infer], 
                                         feed_dict={t_variables['batch_l']: config.n_topic, t_variables['keep_prob']: 1.,})

pred_sents = idxs_to_sents(pred_token_idxs, config, idx_to_word)

pred_topics_freq_bow_indices = np.argsort(pred_topic_bow, 1)[:, ::-1][:, :10]
pred_topics_freq_bow_idxs = bow_idxs[pred_topics_freq_bow_indices]

In [31]:
for idxs in pred_topics_freq_bow_idxs:
    print([idx_to_word[idx] for idx in idxs])

['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', ';', '-', 'fix', 'fine', 'finger', 'fingerprints', 'finish', 'fitting', 'zips']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']
['cover', 'zips', 'floor', 'flexible', 'flaw', 'flat', 'flash', 'flap', 'fix', 'fitting']


In [32]:
pred_topic_embeddings[:, :10]

array([[-2.836687  , -2.4325309 , -2.962067  , -2.1210825 , -2.6124496 ,
        -1.8852538 , -2.4223897 , -2.7951393 , -2.2931848 , -2.1283157 ],
       [-1.6798272 , -2.2147717 , -1.8608545 , -1.6574732 , -1.8061484 ,
        -2.1520698 , -1.5335383 , -1.9913623 , -1.3796672 , -2.0616167 ],
       [-1.0150758 , -0.89505154, -1.0576841 , -1.0169678 , -0.9941106 ,
        -1.0553619 , -0.9833162 , -0.9620681 , -0.9419456 , -0.9740358 ],
       [-2.4011621 , -2.0774152 , -1.6468724 , -2.3599555 , -2.3003724 ,
        -2.1366048 , -2.1995428 , -2.0477817 , -2.3054452 , -2.083294  ],
       [-1.9486684 , -1.5919771 , -2.2330887 , -1.8404497 , -1.6795243 ,
        -1.7537042 , -1.8309035 , -2.172012  , -1.9229968 , -1.5811139 ],
       [-1.7979274 , -2.3900738 , -2.2278528 , -2.05863   , -1.8887968 ,
        -1.8150214 , -1.9893633 , -2.0921412 , -1.8070822 , -2.0015564 ],
       [-1.9307586 , -2.6086934 , -2.179174  , -2.474641  , -2.6866994 ,
        -3.1212592 , -1.8104733 , -1.8549055 

In [33]:
pred_topic_bow

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [34]:
_w_means_topic

array([[ 0.0739732 , -0.08790135,  0.03179525, ..., -0.06434678,
         0.0469898 , -0.01484643],
       [ 0.12066822, -0.07779175,  0.0811064 , ...,  0.00998305,
         0.02352253, -0.06851349],
       [-0.18138495,  0.8433418 ,  0.61981755, ..., -0.11070126,
        -0.24917184,  0.6417571 ],
       ...,
       [-0.07755507,  0.08136196, -0.03300147, ..., -0.03827026,
        -0.01169072,  0.04148794],
       [ 0.2552175 ,  0.77502584,  0.5159265 , ...,  0.26267377,
         0.12818944,  0.3866017 ],
       [ 0.06512748,  0.06922452, -0.01005569, ...,  0.10908195,
        -0.07925141,  0.00807684]], dtype=float32)

In [35]:
_b_means_topic

array([-0.34380314, -0.0219513 ,  0.01438296,  0.00777078, -0.34469175,
       -0.3049803 , -0.02120364, -0.41466707,  0.01747918, -0.00686976,
        0.01272351, -0.02183504,  0.04956156,  0.03889449,  0.01130024,
       -0.32727343,  0.01677949, -0.41357547,  0.01243312, -0.00792332,
       -0.34260193, -0.41951048,  0.05819419, -0.30284566, -0.31787518,
       -0.31931284, -0.32922003, -0.3255996 , -0.07046406, -0.29471713,
       -0.3527893 ,  0.00778683], dtype=float32)

In [36]:
pred_means_topic

array([[-0.34380314, -0.0219513 ,  0.01438296,  0.00777078, -0.34469175,
        -0.3049803 , -0.02120364, -0.41466707,  0.01747918, -0.00686976,
         0.01272351, -0.02183504,  0.04956156,  0.03889449,  0.01130024,
        -0.32727343,  0.01677949, -0.41357547,  0.01243312, -0.00792332,
        -0.34260193, -0.41951048,  0.05819419, -0.30284566, -0.31787518,
        -0.31931284, -0.32922003, -0.3255996 , -0.07046406, -0.29471713,
        -0.3527893 ,  0.00778683],
       [-0.34380314, -0.0219513 ,  0.01438296,  0.00777078, -0.34469175,
        -0.3049803 , -0.02120364, -0.41466707,  0.01747918, -0.00686976,
         0.01272351, -0.02183504,  0.04956156,  0.03889449,  0.01130024,
        -0.32727343,  0.01677949, -0.41357547,  0.01243312, -0.00792332,
        -0.34260193, -0.41951048,  0.05819419, -0.30284566, -0.31787518,
        -0.31931284, -0.32922003, -0.3255996 , -0.07046406, -0.29471713,
        -0.3527893 ,  0.00778683],
       [-0.34380314, -0.0219513 ,  0.01438296,  0.0077

In [37]:
_w_mean_topic_infer[:, :10, 0]

array([[-35.833862, -35.247192, -35.53736 , -35.848877, -35.871967,
        -35.88059 , -35.82204 , -35.93227 , -35.87375 , -35.87809 ],
       [-35.57662 , -35.536495, -35.588844, -35.550194, -35.547264,
        -35.56824 , -35.578804, -35.54832 , -35.5852  , -35.59411 ],
       [-35.929325, -35.76501 , -36.037468, -35.967644, -35.233234,
        -35.940414, -35.025246, -35.227554, -36.04773 , -36.170334],
       [-35.57632 , -35.555447, -35.552483, -35.575047, -35.55795 ,
        -35.547314, -35.56661 , -35.556694, -35.56401 , -35.54191 ],
       [-35.57836 , -35.568584, -35.568565, -35.581158, -35.557106,
        -35.572052, -35.54077 , -35.57486 , -35.5603  , -35.53741 ],
       [-35.577534, -35.566864, -35.581524, -35.540504, -35.580524,
        -35.55489 , -35.56971 , -35.58291 , -35.572197, -35.550014],
       [-35.55864 , -35.556496, -35.552464, -35.559177, -35.542934,
        -35.56165 , -35.583477, -35.56518 , -35.544052, -35.56943 ],
       [-35.533333, -35.554634, -35.58234

In [38]:
_b_means_topic

array([-0.34380314, -0.0219513 ,  0.01438296,  0.00777078, -0.34469175,
       -0.3049803 , -0.02120364, -0.41466707,  0.01747918, -0.00686976,
        0.01272351, -0.02183504,  0.04956156,  0.03889449,  0.01130024,
       -0.32727343,  0.01677949, -0.41357547,  0.01243312, -0.00792332,
       -0.34260193, -0.41951048,  0.05819419, -0.30284566, -0.31787518,
       -0.31931284, -0.32922003, -0.3255996 , -0.07046406, -0.29471713,
       -0.3527893 ,  0.00778683], dtype=float32)

In [39]:
_enc_state_infer, _means_topic_infer = debug_value([enc_state_infer, means_topic_infer], return_value=True)

NameError: name 'enc_state_infer' is not defined

In [None]:
_enc_state_infer.shape

In [None]:
_means_topic_infer[0]