In [1]:
import sys

import pandas as pd

# Make the local GTM sources importable. The membership check keeps repeated
# executions of this cell from appending the same entry to sys.path again.
GTM_SRC_DIR = '../gtm/'
if GTM_SRC_DIR not in sys.path:
    sys.path.append(GTM_SRC_DIR)
from corpus import GTMCorpus
from gtm import GTM

# Settings a reader might want to tune, gathered in one place.
DATA_PATH = '../data/hansard_speeches_processed.csv'
N_SAMPLES = 50000  # upper bound on the number of speeches used for training
SEED = 42  # fixes which rows are drawn by the subsample
N_TOPICS = 20
NUM_EPOCHS = 10  # No need to run many epochs. I found 10 to work well on 50 000 speeches.

# Load the processed speeches and draw a reproducible random subsample.
# min() caps the request at the number of available rows: DataFrame.sample
# raises ValueError when asked for more rows than exist (without replacement).
df = pd.read_csv(DATA_PATH)
df = df.sample(n=min(N_SAMPLES, len(df)), random_state=SEED)

# Create a GTMCorpus object
train_dataset = GTMCorpus(
    df,  # Must contain a column 'doc' with the text of each document and a column 'doc_clean' with the cleaned text of each document.
    labels="~ party",  # The features to predict. Would be "~ gdp" if the df has a column 'gdp'.
    content='~ 1',  # To absorb frequent/procedural words
)

# Train the model
tm = GTM(
    train_dataset,
    n_topics=N_TOPICS,
    doc_topic_prior='dirichlet',  # other option is logistic_normal
    update_prior=False,  # no prevalence covariates so no need to update the prior
    predictor_type='classifier',  # 'regressor' for continuous variables such as GDP
    num_epochs=NUM_EPOCHS,
)

  from .autonotebook import tqdm as notebook_tqdm


Epoch   1	Iter   10	Loss:12.4985991	Rec Loss:2.7515297	MMD:9.7431316	Sparsity_Loss:0.0000000	Pred_Loss:0.0039373
Epoch   1	Iter   20	Loss:10.2804003	Rec Loss:2.2792683	MMD:7.9973512	Sparsity_Loss:0.0000000	Pred_Loss:0.0037813
Epoch   1	Iter   30	Loss:9.7446136	Rec Loss:2.6023178	MMD:7.1383057	Sparsity_Loss:0.0000000	Pred_Loss:0.0039903
Epoch   1	Iter   40	Loss:7.8622828	Rec Loss:2.5463870	MMD:5.3119216	Sparsity_Loss:0.0000000	Pred_Loss:0.0039739
Epoch   1	Iter   50	Loss:4.5873442	Rec Loss:2.2960241	MMD:2.2873540	Sparsity_Loss:0.0000000	Pred_Loss:0.0039663
Epoch   1	Iter   60	Loss:3.6806417	Rec Loss:2.3312786	MMD:1.3455907	Sparsity_Loss:0.0000000	Pred_Loss:0.0037724
Epoch   1	Iter   70	Loss:2.9565632	Rec Loss:2.2007699	MMD:0.7518327	Sparsity_Loss:0.0000000	Pred_Loss:0.0039605
Epoch   1	Iter   80	Loss:2.7045636	Rec Loss:2.1932111	MMD:0.5075876	Sparsity_Loss:0.0000000	Pred_Loss:0.0037649
Epoch   1	Iter   90	Loss:2.8402884	Rec Loss:2.1366076	MMD:0.6999087	Sparsity_Loss:0.0000000	Pred_Loss:

Epoch   4	Iter  180	Loss:2.4311395	Rec Loss:2.3266883	MMD:0.1004048	Sparsity_Loss:0.0000000	Pred_Loss:0.0040463
Epoch   4	Iter  190	Loss:2.7160995	Rec Loss:2.4796121	MMD:0.2325526	Sparsity_Loss:0.0000000	Pred_Loss:0.0039349
Epoch   5	Iter   10	Loss:2.5003145	Rec Loss:2.3354883	MMD:0.1609896	Sparsity_Loss:0.0000000	Pred_Loss:0.0038365
Epoch   5	Iter   20	Loss:2.5280614	Rec Loss:2.4411931	MMD:0.0830270	Sparsity_Loss:0.0000000	Pred_Loss:0.0038414
Epoch   5	Iter   30	Loss:2.6254165	Rec Loss:2.4383011	MMD:0.1832000	Sparsity_Loss:0.0000000	Pred_Loss:0.0039155
Epoch   5	Iter   40	Loss:2.5431709	Rec Loss:2.4466226	MMD:0.0926466	Sparsity_Loss:0.0000000	Pred_Loss:0.0039018
Epoch   5	Iter   50	Loss:2.3943644	Rec Loss:2.2197537	MMD:0.1706268	Sparsity_Loss:0.0000000	Pred_Loss:0.0039838
Epoch   5	Iter   60	Loss:2.4402196	Rec Loss:2.1390157	MMD:0.2973271	Sparsity_Loss:0.0000000	Pred_Loss:0.0038769
Epoch   5	Iter   70	Loss:2.1655185	Rec Loss:2.0571136	MMD:0.1045344	Sparsity_Loss:0.0000000	Pred_Loss:0.

Epoch   8	Iter   20	Loss:2.7009394	Rec Loss:2.3967650	MMD:0.3000473	Sparsity_Loss:0.0000000	Pred_Loss:0.0041272
Epoch   8	Iter   30	Loss:2.8929253	Rec Loss:2.6027308	MMD:0.2862331	Sparsity_Loss:0.0000000	Pred_Loss:0.0039614
Epoch   8	Iter   40	Loss:2.2972202	Rec Loss:2.0731342	MMD:0.2201917	Sparsity_Loss:0.0000000	Pred_Loss:0.0038944
Epoch   8	Iter   50	Loss:2.8508446	Rec Loss:2.6113338	MMD:0.2355860	Sparsity_Loss:0.0000000	Pred_Loss:0.0039249
Epoch   8	Iter   60	Loss:2.5535820	Rec Loss:2.3623841	MMD:0.1872940	Sparsity_Loss:0.0000000	Pred_Loss:0.0039038
Epoch   8	Iter   70	Loss:2.7651854	Rec Loss:2.4919565	MMD:0.2693123	Sparsity_Loss:0.0000000	Pred_Loss:0.0039167
Epoch   8	Iter   80	Loss:2.5600955	Rec Loss:2.2926793	MMD:0.2634948	Sparsity_Loss:0.0000000	Pred_Loss:0.0039214
Epoch   8	Iter   90	Loss:2.6763868	Rec Loss:2.4935913	MMD:0.1788812	Sparsity_Loss:0.0000000	Pred_Loss:0.0039144
Epoch   8	Iter  100	Loss:2.8662815	Rec Loss:2.5595827	MMD:0.3027822	Sparsity_Loss:0.0000000	Pred_Loss:0.