In [1]:
import pandas as pd

# Reproducible subsample of the Hansard corpus, split 50/50 into train and held-out test sets.
DATA_PATH = '../data/hansard_speeches_processed.csv'
N_SPEECHES = 50000   # size of the working subsample
TRAIN_FRAC = 0.5     # share of the subsample used for training
SEED = 42

df = pd.read_csv(DATA_PATH)
# DataFrame.sample(n=...) raises ValueError if n exceeds the number of rows, so cap it.
df = df.sample(n=min(N_SPEECHES, len(df)), random_state=SEED)

# Rows not drawn into `train` form the test set. drop() works on index labels, so this relies on
# the index being unique (read_csv without index_col gives a RangeIndex); the assert checks it.
train = df.sample(frac=TRAIN_FRAC, random_state=SEED)
test = df.drop(train.index).reset_index(drop=True)
train = train.reset_index(drop=True)

assert len(train) + len(test) == len(df), "train/test split lost or duplicated rows"

In [2]:
import os
import sys

# Make the local GTM package importable. Guarded so re-running the cell does not append the
# same path to sys.path again.
GTM_SRC = os.path.abspath('../gtm/')
if GTM_SRC not in sys.path:
    sys.path.append(GTM_SRC)

from corpus import GTMCorpus
from gtm import GTM

# Formula for the label(s) to predict: 'party' without an intercept term ("-1").
# For a continuous target stored in a column 'gdp' this would be "~ gdp".
LABEL_FORMULA = "~party-1"

# Build two corpora: the model is fit on train_dataset and evaluated on the held-out
# test_dataset, which lets us detect overfitting of the supervised prediction head.
train_dataset = GTMCorpus(
    train,  # must contain a column 'doc' (raw text) and a column 'doc_clean' (cleaned text) per document
    labels=LABEL_FORMULA,
)

# Reuse the training vectorizer so both document-term matrices have the same vocabulary and
# therefore the same number of columns.
# NOTE(review): this cell's output shows a warning raised at
# np.log(self.M_bow.sum(axis=0)) in corpus.py; presumably some vocabulary terms have zero counts
# (e.g. training-vocabulary words absent from the test split) -> log(0). Verify.
test_dataset = GTMCorpus(
    test,
    labels=LABEL_FORMULA,
    vectorizer=train_dataset.vectorizer,
)

  from .autonotebook import tqdm as notebook_tqdm
  self.log_word_frequencies = torch.FloatTensor(np.log(np.array(self.M_bow.sum(axis=0)).flatten()))


In [3]:
import numpy as np

# Remove one column from the label matrix of both datasets.
# Index 0 is the FIRST column (note: np.delete with -0 is also 0, not the last column).
# TODO(review): confirm which party category this is meant to drop and why.
LABEL_COLUMN_TO_DROP = 0

# np.delete returns a new array. Keep the untouched label matrix on each dataset and always
# delete from it, so re-running this cell drops the same single column instead of one more
# column on every execution.
for dataset in (train_dataset, test_dataset):
    if not hasattr(dataset, 'M_labels_all'):
        dataset.M_labels_all = dataset.M_labels
    dataset.M_labels = np.delete(dataset.M_labels_all, LABEL_COLUMN_TO_DROP, axis=1)

In [4]:
# Train the supervised topic model on train_dataset; test_dataset is used for validation.
# NOTE(review): the saved output of this cell does not match these settings. It reports epochs
# up to 18 and progress lines at iterations 20/40, whereas num_epochs=10 and print_every=50
# here. Restart & Run All so the output reflects the code.
# NOTE(review): in the saved output, the top words of every topic converge to the same list
# ('hon', 'make', 'say', 'people', ...) and the mean validation loss rises after epoch 5
# (13.01 -> 13.97 -> 15.67). This looks like topic collapse / overfitting; consider removing
# corpus-wide frequent words or tuning n_topics / w_pred_loss / num_epochs.
tm = GTM(
    train_dataset,
    test_dataset,
    n_topics=20,
    doc_topic_prior='dirichlet',      # other option is "logistic_normal"
    update_prior=False,               # no prevalence covariates, so no need to update the prior
    encoder_hidden_layers=[128, 64],  # hidden-layer sizes of the encoder network
    decoder_hidden_layers=[64],       # hidden-layer sizes of the decoder network
    predictor_type='classifier',      # 'regressor' for continuous targets such as GDP
    num_epochs=10,                    # 10 worked well here; the training split is 25,000 speeches (half of the 50,000 sample)
    w_pred_loss=10,                   # weight of the prediction loss relative to the other loss terms
    print_every=50,                   # print progress every x batches
    log_every=1,                      # print the topic-word distributions every x epochs
    batch_size=256,
)

Epoch   1	Iter   20	Training Loss:12.6886930
Rec Loss:2.8161604
MMD Loss:9.8454857
Sparsity Loss:0.0000000
Pred Loss:0.0270473

Epoch   1	Iter   40	Training Loss:164.9443359
Rec Loss:35.8425064
MMD Loss:128.6754303
Sparsity Loss:0.0000000
Pred Loss:0.4263945


Epoch   1	Mean Training Loss:17.8338630

Epoch   1	Iter   20	Validation Loss:12.5746689
Rec Loss:2.7410650
MMD Loss:9.8066006
Sparsity Loss:0.0000000
Pred Loss:0.0270033

Epoch   1	Iter   40	Validation Loss:240.2220764
Rec Loss:60.7605705
MMD Loss:179.0244293
Sparsity Loss:0.0000000
Pred Loss:0.4370740


Epoch   1	Mean Validation Loss:18.9860128

['leave', 'subject', 'try', 'time', 'newspaper', 'death', 'cover', 'argument']
['humankind', 'rearrangement', 'assign', 'practice', 'gps', 'coat', 'scrounger', 'tracing']
['act', 'reform', 'stage', 'seat', 'achieve', 'evidence', 'party', 'pressure']
['leader', 'reflect', 'join', 'outcome', 'colleague', 'review', 'ability', 'care']
['time', 'call', 'speech', 'agree', 'committee', 'decline

Epoch   5	Iter   20	Training Loss:2.6751425
Rec Loss:2.3678026
MMD Loss:0.2803771
Sparsity Loss:0.0000000
Pred Loss:0.0269627

Epoch   5	Iter   40	Training Loss:40.5419769
Rec Loss:37.5501823
MMD Loss:2.5345263
Sparsity Loss:0.0000000
Pred Loss:0.4572691


Epoch   5	Mean Training Loss:3.6289726

Epoch   5	Iter   20	Validation Loss:10.7160759
Rec Loss:2.6898460
MMD Loss:7.9994407
Sparsity Loss:0.0000000
Pred Loss:0.0267889

Epoch   5	Iter   40	Validation Loss:135.1471252
Rec Loss:30.4597969
MMD Loss:104.2501678
Sparsity Loss:0.0000000
Pred Loss:0.4371687


Epoch   5	Mean Validation Loss:13.0126077

['hon', 'work', 'time', 'people', 'give', 'need', 'friend', 'make']
['hon', 'friend', 'work', 'people', 'make', 'give', 'time', 'point']
['make', 'hon', 'people', 'give', 'year', 'friend', 'work', 'know']
['hon', 'make', 'friend', 'people', 'work', 'year', 'give', 'time']
['make', 'time', 'hon', 'friend', 'work', 'people', 'need', 'say']
['hon', 'people', 'make', 'time', 'give', 'friend', 'wo

Epoch   9	Iter   40	Validation Loss:143.7682190
Rec Loss:27.5829430
MMD Loss:115.7283173
Sparsity Loss:0.0000000
Pred Loss:0.4569491


Epoch   9	Mean Validation Loss:13.9695221

['hon', 'say', 'make', 'people', 'friend', 'year', 'take', 'work']
['hon', 'make', 'say', 'friend', 'people', 'take', 'year', 'give']
['hon', 'make', 'say', 'people', 'year', 'friend', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'friend', 'people', 'year', 'take', 'time']
['hon', 'make', 'people', 'say', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'friend', 'people', 'year', 'take', 'time']
['hon', 'make', 'friend', 'say', 'people', 'take', 'year', 'give']
['hon', 'make', 'people', 'friend', 'say', 'year', 'time', 'take']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'friend', 'people', 'year', 'take', 'time']
['hon', 'make', 'say', 'friend', 'people', 'take', 'year', 'give']
['hon', 'make', 's

Epoch  14	Iter   20	Training Loss:2.4976757
Rec Loss:2.3976569
MMD Loss:0.0733324
Sparsity Loss:0.0000000
Pred Loss:0.0266863

Epoch  14	Iter   40	Training Loss:23.1820641
Rec Loss:27.1865768
MMD Loss:-4.4332676
Sparsity Loss:0.0000000
Pred Loss:0.4287551


Epoch  14	Mean Training Loss:3.0777507

Epoch  14	Iter   20	Validation Loss:11.7473583
Rec Loss:2.5225539
MMD Loss:9.1979599
Sparsity Loss:0.0000000
Pred Loss:0.0268438

Epoch  14	Iter   40	Validation Loss:194.2173004
Rec Loss:49.0739822
MMD Loss:144.7139587
Sparsity Loss:0.0000000
Pred Loss:0.4293629


Epoch  14	Mean Validation Loss:15.6720382

['hon', 'say', 'make', 'people', 'friend', 'year', 'take', 'work']
['hon', 'make', 'friend', 'say', 'people', 'take', 'year', 'work']
['hon', 'make', 'say', 'people', 'year', 'friend', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'friend', 'people', 'year', 'take', 'work']
['hon', 'make', 'people', 'say', 'friend', 'year', 'take', 

Epoch  18	Iter   40	Validation Loss:187.6553497
Rec Loss:47.9478226
MMD Loss:139.2696075
Sparsity Loss:0.0000000
Pred Loss:0.4379174


Epoch  18	Mean Validation Loss:15.4237269

['hon', 'say', 'make', 'people', 'year', 'friend', 'take', 'give']
['hon', 'make', 'friend', 'people', 'say', 'take', 'year', 'work']
['hon', 'make', 'say', 'people', 'year', 'friend', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'work']
['hon', 'make', 'say', 'friend', 'people', 'year', 'take', 'work']
['hon', 'make', 'people', 'say', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'time']
['hon', 'make', 'friend', 'say', 'people', 'take', 'year', 'give']
['hon', 'make', 'people', 'say', 'friend', 'year', 'take', 'time']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'year', 'take', 'give']
['hon', 'make', 'say', 'people', 'friend', 'take', 'year', 'give']
['hon', 'make', 's