In [1]:
import numpy as np

import artm
from artm import hARTM

import sys
sys.path.append('utils/')
# you need sklearn for simple loading
from sklearn.datasets import fetch_20newsgroups

import glob 
import os

In [10]:
# Container for the topic hierarchy; levels are attached below via add_level().
hier = artm.hARTM()

In [11]:
# Root level with 3 topics; track the top 100 tokens of the 'text' modality.
level0 = hier.add_level(num_topics=3)
level0_top_tokens = artm.TopTokensScore(name='TopTokensScore', num_tokens=100, class_id='text')
level0.scores.add(level0_top_tokens)

In [4]:
from sklearn.datasets import fetch_20newsgroups
import os
import re

def load_20newsgroups(path='../data', fetcher=None):
    """
    Download the train part of the 20newsgroups collection, lightly preprocess
    it, convert it to Vowpal Wabbit format and save it as
    path/20newsgroups/20newsgroups_train.vw.

    This function does nothing if that file already exists.

    Every output line has the form ``<doc_index> |text <tokens> |class_id <label>``,
    where ``<tokens>`` is the lower-cased document with every character outside
    ``a-z`` replaced by a space and runs of whitespace collapsed to one space.

    Parameters:
    -----------
    path: str
        The folder for the collection saving; also passed to the fetcher as
        ``data_home``.
    fetcher: callable or None
        Function with the call signature of
        ``sklearn.datasets.fetch_20newsgroups`` returning a mapping with
        ``'data'`` and ``'target'`` entries. Defaults to that sklearn function;
        override it e.g. to work offline.

    Returns:
    --------
    None
    """
    out_dir = os.path.join(path, '20newsgroups')
    out_file = os.path.join(out_dir, '20newsgroups_train.vw')
    if os.path.isfile(out_file):
        return None

    if fetcher is None:
        fetcher = fetch_20newsgroups
    data = fetcher(data_home=path, subset='train',
                   remove=('headers', 'footers', 'quotes'))

    # A plain file occupying the target directory's name would make the
    # directory creation fail, so clear it first.
    if os.path.isfile(out_dir):
        os.remove(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    with open(out_file, 'w') as f_output:
        for i, (document, document_class) in enumerate(zip(data['data'], data['target'])):
            content = " ".join(re.sub('[^a-z]', ' ', document.lower()).split())
            f_output.write('{} |text {} |class_id {}\n'.format(i, content, document_class))

    # sklearn caches the parsed collection as a pickle in data_home; on
    # Python 3 the file name carries a "_py3" suffix. Remove whichever exists
    # instead of failing when the unsuffixed name is absent.
    for cache_name in ('20news-bydate.pkz', '20news-bydate_py3.pkz'):
        cache_file = os.path.join(path, cache_name)
        if os.path.isfile(cache_file):
            os.remove(cache_file)
    
    
# Prepare ../data/20newsgroups/20newsgroups_train.vw (no-op if it already exists).
# NOTE(review): the cells below train on data/lang_data.vw, not on this file.
load_20newsgroups(path='../data')

In [4]:
# Source collection in Vowpal Wabbit format and the folder for its converted batches.
data_path, batches_path = 'data/lang_data.vw', 'data/batches'

In [5]:
# Reuse already-converted batches when the folder contains any; otherwise parse
# the Vowpal Wabbit file once and write the batches into batches_path.
# The glob pattern is joined as a path component so it matches files *inside*
# the folder (concatenating strings would look for "data/batches*.batch").
existing_batches = glob.glob(os.path.join(batches_path, '*.batch'))
if existing_batches:
    batch_vectorizer = artm.BatchVectorizer(data_path=batches_path, data_format='batches')
else:
    batch_vectorizer = artm.BatchVectorizer(data_path=data_path, data_format='vowpal_wabbit',
                                            target_folder=batches_path)

In [6]:
# Gather token statistics from the batches, then filter the vocabulary with
# the thresholds below (passed straight through to artm's filter()).
MIN_DF = 100
MAX_TF = 17390

dictionary = artm.Dictionary('dictionary')
dictionary.gather(batches_path)
dictionary.filter(min_df=MIN_DF, max_tf=MAX_TF)

artm.Dictionary(name=dictionary, num_entries=29493)

In [12]:
# Initialise the root level from the filtered dictionary and train it offline.
level0_passes = 30
level0.initialize(dictionary=dictionary)
level0.fit_offline(batch_vectorizer, num_collection_passes=level0_passes)

Exception ignored in: <generator object tqdm_notebook.__iter__ at 0x7f32f78943b8>
Traceback (most recent call last):
  File "/home/mr9bit/bigartm/env/lib/python3.6/site-packages/tqdm/_tqdm_notebook.py", line 228, in __iter__
    self.sp(bar_style='danger')
AttributeError: 'tqdm_notebook' object has no attribute 'sp'


KeyboardInterrupt: 

In [8]:
# Show the comma-separated top tokens of every root-level topic.
last_tokens0 = level0.score_tracker['TopTokensScore'].last_tokens
for name in level0.topic_names:
    print(name + ': ')
    print(", ".join(last_tokens0[name]))

topic_0: 
"asoc, *mtd,, -1}, -1},, ##, 7,, loff_t, partition, shift,, nbits,, vreg,, tps6586x_regulator, ebit0,, ereg1,, ereg0,, soc_tplg, usb_mixer_elem_info, 32,, dev_err(tplg->dev,, errcode_t, unitid,, mtd, 7),, bb,, (!, _pname,, ene,, (ret, len), bits, ebit1,, 25000,, tps6586x_ldo0,, tps6586x_ldo,, from,, blk), *tplg,, mixer, pass, part->offset,, regulator, kcontrol, sndrv_ctl_elem_id_name_maxlen), p4_escr_emask_bit(p4_event_bsq_cache_reference,, min_uv,, supplyv2,, n_volt,, uv_step,, "vinldo678",, --
topic_1: 
""", inode, kctx);, js,, kctx,, extern, contexts, dentry, cores, ssize_t, none, kbasep_js_kctx_info, because, mali_false), ***, release, scheduled, queue, name, path, ;, affinity, following, (0x0000), error;, does, lock, msm_vidc_core, also, time, flags;, flags, call, kbase_debug_assert(kctx, kbase_debug_print_info(kbase_jm,, x, object, ctx, slot, printf, enum, (u32), change, core, *,, std, spin_unlock_irqrestore(&js_devdata->runpool_irq.lock,, returns, irq, might
topic_2: 


In [None]:
# Child level: 25 topics named child_topic_0..child_topic_24, sparsified with a
# HierarchySparsingThetaRegularizer and scored by its top 120 'text' tokens.
level1 = hier.add_level(
    num_topics=25,
    topic_names=['child_topic_{}'.format(i) for i in range(25)],
    parent_level_weight=1,
)
level1.regularizers.add(artm.HierarchySparsingThetaRegularizer(name="HierSp", tau=1.0))
level1_top_tokens = artm.TopTokensScore(name='TopTokensScore', num_tokens=120, class_id='text')
level1.scores.add(level1_top_tokens)

In [None]:
# Train the child level on the same dictionary and batches as the root level.
level1_passes = 30
level1.initialize(dictionary=dictionary)
level1.fit_offline(batch_vectorizer, num_collection_passes=level1_passes)

In [None]:
# Row counts of the psi matrices of both levels, displayed as a tuple.
tuple(len(level.get_psi()) for level in (level0, level1))


In [None]:
# Psi matrix of the child level (index 1 in the hierarchy, i.e. level1).
psi = hier.get_level(1).get_psi()

In [None]:
# For each row of psi take its largest entry; report the smallest of those maxima.
psi_row_max = psi.values.max(axis=1)
print("Psi support:", psi_row_max.min())

In [None]:
# Average number of parents per topic: for every non-root level count, per psi
# row, the entries above psi_threshold, then average over all rows of all levels.
psi_threshold = 0.01
per_level_counts = []
for level_idx in range(1, hier.num_levels):
    # Use a separate name: reusing `psi` here would replace the DataFrame from
    # the earlier cell with a bare ndarray and break `psi.values` on re-run.
    level_psi = hier.get_level(level_idx).get_psi().values
    per_level_counts.append((level_psi > psi_threshold).sum(axis=1))
# Leading empty float array keeps the result float and well-defined when
# there are no child levels.
parent_counts = np.hstack([np.zeros(0)] + per_level_counts)
print("Mean parents count:", parent_counts.mean())


In [None]:
# Preview of psi weighted by Ntw. Ntw is computed in the next cell, so guard the
# expression instead of stopping Restart & Run All with a NameError here.
np.array(psi) * Ntw if 'Ntw' in globals() else None

In [None]:
# Load the serialized phi batch (presumably written by hARTM -- confirm the
# file location) and accumulate each item's token weights into Ntw; item k
# goes to Ntw[k], so the batch is expected to hold one item per level-0 topic.
batch = artm.messages.Batch()
batch_name = 'phi1.batch'
with open(batch_name, "rb") as batch_file:
    raw_bytes = batch_file.read()
batch.ParseFromString(raw_bytes)

Ntw = np.zeros(len(level0.topic_names))
for item_idx, item in enumerate(batch.item):
    first_field = item.field[0]
    for _token_id, weight in zip(first_field.token_id, first_field.token_weight):
        Ntw[item_idx] += weight

# Scale each psi column by the matching Ntw entry, normalise every row to sum
# to one, and transpose so psi_bayes has one row per psi column.
Nt1t0 = np.array(psi) * Ntw
row_totals = Nt1t0.sum(axis=1)
psi_bayes = (Nt1t0 / row_totals[:, np.newaxis]).T

In [None]:
# For every column of psi_bayes, the row index holding its largest value.
indexes_child = psi_bayes.argmax(axis=0)

In [None]:
# Print one parent topic and the child topics whose most probable parent
# (per indexes_child) is that parent. The parent name is derived from the same
# index used for the lookup, so the two cannot disagree; level0 has only 3
# topics, so parent_idx must be < len(level0.topic_names).
# NOTE(review): assumes psi columns follow the order of level0.topic_names.
parent_idx = 0
topic_parent_name = level0.topic_names[parent_idx]
print(topic_parent_name + ':')
print(" ".join(level0.score_tracker['TopTokensScore'].last_tokens[topic_parent_name]))
print('')
child_tokens = level1.score_tracker['TopTokensScore'].last_tokens
for child in np.where(indexes_child == parent_idx)[0]:
    child_name = level1.topic_names[child]
    print('    ' + child_name + ': ')
    print(" ".join(child_tokens[child_name]))
    print('')

In [None]:
# Psi matrix of the child level (hierarchy index 1), displayed as the cell output.
psi1 = hier.get_level(1).get_psi()
psi1

In [None]:
# Define tokens0 here (the same value the printing cell below assigns) so this
# preview does not depend on that later cell having been executed first.
tokens0 = level0.score_tracker["TopTokensScore"].last_tokens
tokens0

In [None]:
# Print the hierarchy: every parent topic with its top tokens, then each child
# topic whose psi1 entry for that parent exceeds 0.05, with its top tokens.
tokens0 = level0.score_tracker["TopTokensScore"].last_tokens
tokens1 = level1.score_tracker["TopTokensScore"].last_tokens

for parent_name in level0.topic_names:
    print(parent_name + ': ')
    # Each token followed by one space, then a newline.
    print(''.join('{} '.format(word) for word in tokens0[parent_name]))
    parent_column = psi1[parent_name]
    for child_name in level1.topic_names:
        if parent_column[child_name] > 0.05:
            print("\t", child_name + ': ')
            print(''.join('{} '.format(word) for word in tokens1[child_name]))
    print("==" * 30)

In [None]:
# Single psi1 entry: row child_topic_0, column topic_0.
psi1.at["child_topic_0", "topic_0"]

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import matplotlib.ticker as ticker
import matplotlib.cm as cm
import matplotlib as mpl

# Heat map of psi1. Rows of psi1 are level-1 (child) topics and columns are
# level-0 (parent) topics -- the printing cell above reads it as
# psi1[parent][child] -- so tick labels come from psi1.index / psi1.columns
# rather than hard-coded 25-element lists.
fig, ax = plt.subplots(1, 1, figsize=(11, 20))
heatplot = ax.imshow(psi1, cmap='hot')

# Pin exactly one tick per cell before assigning labels so they line up.
ax.set_xticks(range(len(psi1.columns)))
ax.set_xticklabels(psi1.columns, rotation=40)
ax.set_yticks(range(len(psi1.index)))
ax.set_yticklabels(psi1.index)
ax.set_xlabel('level-0 (parent) topic')
ax.set_ylabel('level-1 (child) topic')
ax.set_title('psi1 (level 1)')
fig.colorbar(heatplot, ax=ax);
