In [41]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
import random
import numpy as np

from util import gen_vocab, gen_data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Special vocabulary tokens; each one is assigned its own vocab id.

# Padding for the encoder input, decoder input and target sequence.
PAD = '<pad>'
# Stand-in for out-of-vocabulary words.
UNK = '<unk>'
# Placed at the beginning of every decoder input sequence.
BOS = '<p>'
# Placed at the end of untruncated target sequences.
EOS = '</p>'

In [44]:
class Config:
    """Container for the preprocessing options, training options and corpus paths.

    All settings are plain instance attributes assigned in ``__init__``.
    Where a comment ends with a short symbol (``m_1``, ``n_batch``, ``k``, ...)
    it is the shorthand name attached to that option.
    """

    def __init__(self):
        # ---- preprocessing options ----
        self.vocab_minfreq = 2  # minimum vocab frequency to filter
        self.vocab_maxfreq = 0.001  # proportion of most-frequent vocab to filter
        self.stopwords = 'data/apnews/stopwords_mallet.txt'  # stopword list file
        self.tm_sent_len = 3  # topic model sequence length (m_1)
        self.lm_sent_len = 30  # language model sequence length (m_2)
        self.doc_len = 300  # document max length (m_3)

        # ---- training options ----
        self.seed = 1  # random seed
        self.batch_size = 64  # n_batch
        self.rnn_layer_size = 1  # n_layer
        self.rnn_hidden_size = 60  # n_hidden
        self.epoch_size = 1  # n_epoch
        self.topic_number = 10  # k
        # e; this setting is ignored if word_embedding_model is provided
        self.word_embedding_size = 30
        # pre-trained word embedding (gensim format); None if no pre-trained model
        self.word_embedding_model = None
        self.word_embedding_update = True  # update word embedding for topic model
        self.filter_sizes = [2]  # h
        self.filter_number = 20  # a; topic input vector dimension
        # 'relu' or 'identity' (the identity function is used in the paper)
        self.conv_activation = 'identity'
        self.topic_embedding_size = 5  # b; topic output vector dimension
        self.learning_rate = 0.001  # l
        self.tm_keep_prob = 0.4  # p_1
        self.lm_keep_prob = 0.6  # p_2
        self.max_grad_norm = 5  # gradient clipping
        # additional loss to penalise similar topics; not used in paper (0.0)
        self.alpha = 0.0
        # sampled softmax to speed up training; not used in paper (0)
        self.num_samples = 0
        self.tag_embedding_size = 0  # tag embedding dimension; 0 to disable tags

        # ---- corpus locations ----
        self.train_corpus = 'data/apnews/apnews50k_train.txt'
        self.valid_corpus = 'data/apnews/apnews50k_valid.txt'

        # ---- logging ----
        self.verbose = True  # print progress


# Single configuration instance used by the cells below.
cf = Config()

# preprocess data 

In [45]:
# Seed NumPy's global RNG and the stdlib `random` module from the configured seed.
np.random.seed(cf.seed)
random.seed(cf.seed)

In [46]:
# First pass over the training corpus: collect vocabulary information.
# The special tokens are handed to gen_vocab in this fixed order.
dummy_tokens = [PAD, BOS, EOS, UNK]
idxvocab, vocabxid, tm_ignore = gen_vocab(
    dummy_tokens,
    cf.train_corpus,
    cf.stopwords,
    cf.vocab_minfreq,
    cf.vocab_maxfreq,
    cf.verbose,
)

49000 processed

In [56]:
# Lookup table from vocab id (position in idxvocab) to the word at that position.
idx_to_word = dict(enumerate(idxvocab))

In [47]:
# Second pass: build the topic-model / language-model data for the training
# corpus and then, with identical settings, for the validation corpus.
train_sents, train_docs, train_docids, train_stats = gen_data(
    vocabxid,
    dummy_tokens,
    tm_ignore,
    cf.train_corpus,
    cf.tm_sent_len,
    cf.lm_sent_len,
    cf.verbose,
    False,
)
valid_sents, valid_docs, valid_docids, valid_stats = gen_data(
    vocabxid,
    dummy_tokens,
    tm_ignore,
    cf.valid_corpus,
    cf.tm_sent_len,
    cf.lm_sent_len,
    cf.verbose,
    False,
)

1000 processedd

In [48]:
# Echo the validation sentence data returned by gen_data for inspection.
# NOTE(review): this displays the whole structure, so the output is very long;
# consider previewing a slice instead.
valid_sents

([(0, 0, [1, 3253, 3807, 1088]),
  (0, 1, [1088, 253, 12681, 925]),
  (0, 2, [925, 590, 4258, 1581]),
  (0, 3, [1581, 2547, 5418, 6515]),
  (0, 4, [6515, 51695, 5277, 4172]),
  (0, 5, [4172, 30405, 1903, 576]),
  (0, 6, [576, 3253, 1903, 544]),
  (0, 7, [544, 362, 613, 2619]),
  (0, 8, [2619, 1305, 6701, 4902]),
  (0, 9, [4902, 4524, 5642, 1864]),
  (0, 10, [1864, 6575, 5863, 952]),
  (0, 11, [952, 925, 590, 51695]),
  (0, 12, [51695, 1903, 15719, 974]),
  (0, 13, [974, 5381, 7334, 1243]),
  (0, 14, [1243, 1885, 903]),
  (1, 0, [1, 203, 519, 1107]),
  (1, 1, [1107, 1178, 519, 692]),
  (1, 2, [692, 1325, 8033, 1495]),
  (1, 3, [1495, 2814, 268, 3362]),
  (1, 4, [3362, 2945, 844, 629]),
  (1, 5, [629, 4524, 3556, 204]),
  (1, 6, [204, 3556, 178, 1752]),
  (1, 7, [1752, 213, 228, 2122]),
  (1, 8, [2122, 202, 190, 4162]),
  (1, 9, [4162, 1495, 2814, 268]),
  (1, 10, [268, 281, 2814, 453]),
  (1, 11, [453, 673, 481, 190]),
  (1, 12, [190, 678, 1793, 692]),
  (1, 13, [692, 268, 4711, 16482])

In [70]:
# Decode the first 12 entries of valid_sents[1] back into words; the token ids
# of an entry are stored in its last field.
for sent_idx in range(12):
    token_ids = valid_sents[1][sent_idx][-1]
    print([idx_to_word[token_id] for token_id in token_ids])

['<p>', 'a', 'richmond', 'developer', 'will', 'spend', 'more', 'than', '16', 'years', 'in', 'prison', 'for', 'defrauding', 'federal', 'and', 'state', 'tax', 'credit', 'programs', 'for', 'rehabilitation', 'of', 'historic', 'properties', '.', '</p>']
['<p>', 'u.s.', 'attorney', 'neil', 'h.', 'macbride', 'says', '40-year-old', 'justin', 'glynn', 'french', 'was', 'sentenced', 'tuesday', 'in', 'federal', 'court', 'in', 'richmond', '.', '</p>']
['<p>', 'french', 'had', 'pleaded', 'guilty', 'in', 'january', 'to', 'mail', 'fraud', 'and', 'engaging', 'in', 'unlawful', 'monetary', 'transactions', '.', '</p>']
['<p>', 'he', 'admitted', 'obtaining', '$', '7', 'million', 'to', '$', '20', 'million', 'more', 'than', 'he', 'was', 'entitled', 'to', 'receive', 'from', 'the', 'tax', 'credit', 'programs', '.', '</p>']
['<p>', 'macbride', 'says', 'french', 'defrauded', 'more', 'than', '100', 'investors', 'and', 'lined', 'his', 'pockets', 'with', 'millions', 'in', 'stolen', 'tax', 'dollars', '.', '</p>']
['

In [74]:
# Number of entries in the id -> word lookup (one per item of idxvocab).
len(idx_to_word)

93348

In [73]:
# Number of entries in the second component of train_sents.
# NOTE(review): presumably the language-model sentences, as with valid_sents[1] — confirm in util.gen_data.
len(train_sents[1])

808673

In [64]:
# Echo the second component of valid_docs for inspection.
# NOTE(review): the output looks like per-document lists of token-id lists — confirm in util.gen_data.
valid_docs[1]

[[[3253, 3807, 1088, 253, 12681, 925, 590, 4258, 1581, 2547],
  [5418, 6515, 51695, 5277, 4172, 30405, 1903, 576, 3253],
  [1903, 544, 362, 613, 2619, 1305, 6701, 4902, 4524, 5642],
  [1864, 6575, 5863, 952, 925, 590],
  [51695, 1903, 15719, 974, 5381, 7334, 1243, 1885, 903]],
 [[203, 519, 1107, 1178, 519, 692, 1325, 8033, 1495, 2814, 268],
  [3362, 2945, 844, 629, 4524, 3556, 204],
  [3556, 178, 1752, 213, 228, 2122, 202, 190, 4162, 1495, 2814, 268, 281],
  [2814, 453],
  [673, 481, 190, 678, 1793],
  [692, 268, 4711, 16482, 2548, 190, 2272, 8150, 20007, 2291, 3747],
  [35172],
  [18451, 4740, 2352, 1571, 545]],
 [[1467, 4819, 2093, 2664, 2412, 183, 216, 1432, 9636, 13947, 1404],
  [82792, 32053, 853, 2266, 1467, 616, 8986, 2532, 394],
  [1263, 1875, 5940, 2664, 298],
  [32053, 1467, 2546, 1037],
  [5940, 488, 2322, 446, 948, 2637, 8986],
  [2474, 4254, 948, 1065, 36749, 823, 4368],
  [4819, 1900, 488, 1523, 9990, 11988, 4254, 462]],
 [[1938,
   359,
   785,
   5760,
   17368,
   2371