In [128]:
import os
import pickle
import sys
import numpy as np
import pandas as pd
import scipy.io
import torch
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.linear_model import LogisticRegression

sys.path.append('..')
import data
from etm import ETM

## Train LDA

In [188]:
data_path = './../data/my_20ng_2'

def load(path, prefix):
    return scipy.io.loadmat(os.path.join(path, f'{prefix}_counts.mat'))['counts'].squeeze(), scipy.io.loadmat(os.path.join(path, f'{prefix}_tokens.mat'))['tokens'].squeeze()

In [189]:
def get_csr(counts, tokens): 
    indptr = np.zeros(len(tokens)+1, dtype=np.uint32)
    for i in range(len(tokens)):
        indptr[i+1] = len(tokens[i].squeeze()) + indptr[i]
    tokens_flat, counts_flat = [], []
    for i in range(len(tokens)):
        doc_tokens = tokens[i].squeeze()
        doc_counts = counts[i].squeeze()
        tokens_flat.extend(doc_tokens.tolist())
        counts_flat.extend(doc_counts.tolist())
    return scipy.sparse.csr_matrix((np.array(counts_flat), np.array(tokens_flat), indptr))

In [190]:
with open(os.path.join(data_path, 'vocab.pkl'), 'rb') as f:
    vocab = pickle.load(f)
print(len(vocab))

1901


In [191]:
train_mat = get_csr(*load(data_path, 'bow_tr'))
X = train_mat.todense()
print(X.shape)

(1114, 1901)


In [192]:
k = 4
lda = LatentDirichletAllocation(n_components=k,
                                learning_method='online',
                                learning_decay=0.85,
                                learning_offset=10.,
                                evaluate_every=10,
                                verbose=1,
                                random_state=5).fit(X)

iteration: 1 of max_iter: 10
iteration: 2 of max_iter: 10
iteration: 3 of max_iter: 10
iteration: 4 of max_iter: 10
iteration: 5 of max_iter: 10
iteration: 6 of max_iter: 10
iteration: 7 of max_iter: 10
iteration: 8 of max_iter: 10
iteration: 9 of max_iter: 10
iteration: 10 of max_iter: 10, perplexity: 915.7583


In [193]:
save_path = './../results'
with open(os.path.join(save_path, f'lda_{k}_{os.path.basename(data_path)}.pkl'), 'wb') as f:
    pickle.dump(lda, f)

In [194]:
def get_topic_words(dists, n_top_words):
    topics = []
    for dist in dists:
        top_word_idxs = np.argsort(dist)[::-1][:n_top_words]
        topics.append([vocab[i] for i in top_word_idxs])
    return topics

In [195]:
topics = get_topic_words(lda.components_, 20)
for i, t in enumerate(topics):
    print(f'[{i+1}] {t}\n')

[1] ['god', 'one', 'people', 'edu', 'writes', 'would', 'think', 'article', 'atheism', 'say', 'believe', 'like', 'religion', 'must', 'system', 'atheists', 'well', 'many', 'know', 'something']

[2] ['drive', 'scsi', 'controller', 'card', 'ide', 'system', 'one', 'bus', 'disk', 'drives', 'get', 'use', 'would', 'hard', 'pc', 'edu', 'know', 'like', 'problem', 'also']

[3] ['com', 'jesus', 'would', 'one', 'edu', 'writes', 'article', 'know', 'matthew', 'people', 'said', 'time', 'like', 'could', 'think', 'see', 'john', 'tek', 'really', 'vice']

[4] ['edu', 'com', 'people', 'writes', 'time', 'article', 'book', 'jesus', 'liar', 'would', 'christian', 'os', 'first', 'read', 'one', 'comp', 'ca', 'david', 'saturn', 'wwc']



## Evaluation

In [137]:
def get(data_path, mode='test'):
    if mode == 'test':
        prefix = 'ts'
    elif mode == 'train':
        prefix = 'tr'
    counts, tokens = load(data_path, f'bow_{prefix}')
    test_mat = get_csr(counts, tokens)
    X = test_mat.todense()
    with open(os.path.join(data_path, 'labels.pkl'), 'rb') as f:
        labels = pickle.load(f)[mode]
    with open(os.path.join(data_path, 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)
    print(X.shape)
    return X, counts, tokens, labels, vocab

# Squish all the sub-categories together (on the full dataset)
def collect_labels(labels):
    new_labels = []
    for label in labels:
        if label == 0:
            new_labels.append(0)  # religion
        elif label <= 5:
            new_labels.append(1)  # computers
        elif label == 6:
            new_labels.append(2)  # sale
        elif label <= 8:
            new_labels.append(3)  # cars
        elif label <= 10:
            new_labels.append(4)  # sports
        elif label <= 14:
            new_labels.append(5)  # science
        elif label == 15:
            new_labels.append(0)
        elif label <= 17:
            new_labels.append(6)  # politics
        else:
            new_labels.append(0)
    return new_labels

### Clustering

In [138]:
# LDA
def lda_doc_topic(model_path, X):
    with open(model_path, 'rb') as f:
        lda_model = pickle.load(f)
    return lda_model.transform(X)
    #return doc_topic_dists.argmax(axis=1)
    
 # ETM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def etm_doc_topic(model_path, counts, tokens, vocab, truncate_first=None):
    with open(model_path, 'rb') as f:
        etm_model = torch.load(f)
    etm_model.eval()
    with torch.no_grad():
        all_data = data.get_batch(tokens, counts, range(len(counts)), len(vocab), device)
        all_data_norm = all_data / (all_data.sum(1).unsqueeze(1))
        thetas, _ = etm_model.get_theta(all_data_norm)
    if truncate_first is not None:
        return thetas.numpy()[:, 7:]
        #return thetas.numpy()[:, 7:].argmax(axis=1)
    else:
        return thetas.numpy()
        #return thetas.numpy().argmax(axis=1)
        
def cluster(thetas):
    return thetas.argmax(axis=1)

In [167]:
lda_path1 = '../results/lda_7_my_20ng.pkl'
lda_path2 = '../results/lda_4_my_20ng_2.pkl'
lda_path3 = '../results/lda_20_my_20ng_rare.pkl'
etm_path1 = '../results/etm_20ng_K_7_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0'
etm_path2 = '../results/etm_20ng_K_7_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0_computers_cars_sports_science_sale_politics_religion_SeedLd_1.0'
etm_path3 = '../results/etm_20ng_K_7_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0_computers_cars_sports_science_sale_politics_religion_SeedLd_0.1'
etm_path4 = '../results/etm_20ng_K_14_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0_computers_cars_sports_science_sale_politics_religion_SeedLd_1.0'
etm_path5 = '../results/etm_20ng_K_4_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0'
etm_path6 = '../results/etm_20ng_K_4_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0_cars_religion_science_hardware_SeedLd_1.0'
etm_path7 = '../results/etm_20ng_rare_K_20_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0'
etm_path8 = '../results/etm_20ng_rare_K_20_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_0_cars_SeedLd_1.0'

### Metrics

In [79]:
# very slow n^2 operation...
def confusion(y, yhat):
    cf = {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}
    for i in range(len(y)):
        for j in range(i+1, len(y)):
            if y[i] != y[j] and yhat[i] != yhat[j]:
                cf['tn'] += 1
            elif y[i] != y[j] and yhat [i] == yhat[j]:
                cf['fp'] += 1
            elif y[i] == y[j] and yhat[i] != yhat[j]:
                cf['fn'] += 1
            else:
                cf['tp'] += 1
    print(cf)
    return cf

def f1_score(cf):
    precision = cf['tp'] / (cf['tp'] + cf['fp'])
    recall = cf['tp'] / (cf['tp'] + cf['fn'])
    F = 2 * (precision * recall) / (precision + recall)
    print(f'Precision: {round(precision, 4)}, recall: {round(recall, 4)}, F measure: {round(F, 4)}')
    return precision, recall, F

In [80]:
# 7 topics
X_test_7, test_counts_7, test_tokens_7, labels_7, vocab_7 = get('./../data/my_20ng', mode='test')
labels_7 = collect_labels(labels_7)

lda_dt1 = lda_doc_topic(lda_path1, X_test_7)
etm_dt1 = etm_doc_topic(etm_path1, test_counts_7, test_tokens_7, vocab_7)
etm_dt2 = etm_doc_topic(etm_path2, test_counts_7, test_tokens_7, vocab_7)
clust_lda1 = cluster(lda_dt1)
clust_etm1 = cluster(etm_dt1)
clust_etm2 = cluster(etm_dt2)

cf_lda1 = confusion(labels_7, clust_lda1)
_ = f1_score(cf_lda1)
cf_etm1 = confusion(labels_7, clust_etm1)
_ = f1_score(cf_etm1)
cf_etm2 = confusion(labels_7, clust_etm2)
_ = f1_score(cf_etm2)

(5379, 3455)
5379




{'tp': 1044484, 'fp': 2201984, 'tn': 9724623, 'fn': 1493040}
Precision: 0.3217, recall: 0.4116, F measure: 0.3612
{'tp': 993954, 'fp': 1436480, 'tn': 10490127, 'fn': 1543570}
Precision: 0.409, recall: 0.3917, F measure: 0.4001
{'tp': 1937099, 'fp': 8906576, 'tn': 3020031, 'fn': 600425}
Precision: 0.1786, recall: 0.7634, F measure: 0.2895


In [99]:
# 4 topics
X_test_4, test_counts_4, test_tokens_4, labels_4, vocab_4 = get('./../data/my_20ng_2', mode='test')

lda_dt2 = lda_doc_topic(lda_path2, X_test_4)
etm_dt5 = etm_doc_topic(etm_path5, test_counts_4, test_tokens_4, vocab_4)
etm_dt6 = etm_doc_topic(etm_path6, test_counts_4, test_tokens_4, vocab_4)
clust_lda2 = cluster(lda_dt2)
clust_etm5 = cluster(etm_dt5)
clust_etm6 = cluster(etm_dt6)

cf_lda2 = confusion(labels_4, clust_lda2)
_ = f1_score(cf_lda2)
cf_etm5 = confusion(labels_4, clust_etm5)
_ = f1_score(cf_etm5)
cf_etm6 = confusion(labels_4, clust_etm6)
_ = f1_score(cf_etm6)

(519, 1901)
519




{'tp': 53161, 'fp': 624, 'tn': 66116, 'fn': 14520}
Precision: 0.9884, recall: 0.7855, F measure: 0.8753
{'tp': 37829, 'fp': 4610, 'tn': 62130, 'fn': 29852}
Precision: 0.8914, recall: 0.5589, F measure: 0.6871
{'tp': 57409, 'fp': 3448, 'tn': 63292, 'fn': 10272}
Precision: 0.9433, recall: 0.8482, F measure: 0.8933


### Classification

In [177]:
X_train, train_counts, train_tokens, train_labels, vocab = get('./../data/my_20ng_rare', mode='train')
X_test, test_counts, test_tokens, test_labels, vocab = get('./../data/my_20ng_rare', mode='test')
train_labels = collect_labels(train_labels)
test_labels = collect_labels(test_labels)
rare_idxs = list(np.argwhere(np.array(test_labels) == 3).squeeze())  # cars only
test_labels_rare = [test_labels[i] for i in rare_idxs]
logit_params = {'solver': 'liblinear', 'multi_class': 'ovr', 'class_weight': 'balanced'}

(10744, 5350)
(4951, 5350)


In [182]:
lda_dt_train = lda_doc_topic(lda_path3, X_train)
lda_dt_test = lda_doc_topic(lda_path3, X_test)

In [183]:
logit_lda = LogisticRegression(**logit_params).fit(lda_dt_train, train_labels)
print(logit_lda.score(lda_dt_test, test_labels))

0.7612603514441527


In [168]:
etm_dt_train1 = etm_doc_topic(etm_path7, train_counts, train_tokens, vocab)
etm_dt_test1 = etm_doc_topic(etm_path7, test_counts, test_tokens, vocab)
etm_dt_train2 = etm_doc_topic(etm_path8, train_counts, train_tokens, vocab)
etm_dt_test2 = etm_doc_topic(etm_path8, test_counts, test_tokens, vocab)

In [181]:
logit_etm1 = LogisticRegression(**logit_params).fit(etm_dt_train1, train_labels)
logit_etm2 = LogisticRegression(**logit_params).fit(etm_dt_train2, train_labels)
print(logit_etm1.score(etm_dt_test1, test_labels))
print(logit_etm2.score(etm_dt_test2, test_labels))

0.6927893354877802
0.6802666128054938


In [184]:
lda_dt_test_rare = lda_doc_topic(lda_path3, X_test[rare_idxs])
logit_lda.score(lda_dt_test_rare, test_labels_rare)

0.6460176991150443

In [185]:
etm_dt_test_rare1 = etm_doc_topic(etm_path7, test_counts[rare_idxs], test_tokens[rare_idxs], vocab)
etm_dt_test_rare2 = etm_doc_topic(etm_path8, test_counts[rare_idxs], test_tokens[rare_idxs], vocab)
print(logit_etm1.score(etm_dt_test_rare1, test_labels_rare))
print(logit_etm2.score(etm_dt_test_rare2, test_labels_rare))



0.5663716814159292
0.6637168141592921


In [212]:
def get_topic_words_etm(model_path, num_topics=4, num_words=10):
    with open(model_path, 'rb') as f:
        model = torch.load(f)
    model.eval()
    with torch.no_grad():
        beta = model.get_beta()
        topic_indices = list(np.random.choice(num_topics, 10)) # 10 random topics
        for k in range(num_topics):#topic_indices:
            gamma = beta[k]
            top_words = list(gamma.cpu().numpy().argsort()[-num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            print('Topic {}: {}'.format(k, topic_words)) 

In [214]:
get_topic_words_etm(etm_path5)
print()
get_topic_words_etm(etm_path6)

Topic 0: ['scsi', 'drive', 'system', 'drives', 'controller', 'card', 'disk', 'bus', 'ide']
Topic 1: ['like', 'use', 'get', 'one', 'would', 'time', 'problem', 'please', 'could']
Topic 2: ['edu', 'writes', 'com', 'article', 'wrote', 'livesey', 'matthew', 'keith', 'alt']
Topic 3: ['god', 'say', 'people', 'believe', 'religion', 'would', 'atheists', 'atheism', 'thing']

Topic 0: ['car', 'drivers', 'buphy', 'livesey', 'uiuc', 'bobbe', 'mathew', 'bus', 'irq']
Topic 1: ['religion', 'religions', 'atheism', 'atheists', 'atheist', 'anybody', 'religious', 'islam', 'faith']
Topic 2: ['science', 'edu', 'physics', 'scientific', 'com', 'california', 'beauchaine', 'theology', 'kevin']
Topic 3: ['hardware', 'disk', 'disks', 'server', 'drives', 'cpu', 'software', 'config', 'motherboard']
