In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import artm
import re
import pickle

import os
import glob

from collections import Counter

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [4]:
print(artm.version())

0.9.0


## Загрузка данных в коллекцию

In [10]:
batch_vectorizer = None
if len(glob.glob(os.path.join('kos', '*.batch'))) < 1:
    print('here')
    batch_vectorizer = artm.BatchVectorizer(data_path='kos', 
                                            data_format='bow_uci', 
                                            collection_name='kos', 
                                            target_folder='kos')
else:
    batch_vectorizer = artm.BatchVectorizer(data_path='kos', 
                                            data_format='batches')

here


In [13]:
dictionary = artm.Dictionary()

if not os.path.isfile('kos/dictionary.dict'):
    dictionary.gather(data_path=batch_vectorizer.data_path)
    dictionary.save(dictionary_path='kos/dictionary.dict')

dictionary.load(dictionary_path='kos/dictionary.dict')
dictionary.load(dictionary_path='kos/dictionary.dict')

In [16]:
dictionary = batch_vectorizer.dictionary

In [30]:
num_topics = 10
num_iter = 10
class_ids = {'@default_class': 1.0}

In [48]:
model = artm.ARTM(num_topics=num_topics,
                      dictionary=dictionary,
                      cache_theta=True,
                      reuse_theta=True,
                      theta_columns_naming='title',
                      theta_name='theta',
                      class_ids=class_ids)
    # scores without class_ids
model.scores.add(artm.PerplexityScore(name='PerplexityScore', dictionary=dictionary))
model.scores.add(artm.SparsityThetaScore(name='SparsityThetaScore',
                                             topic_names=model.topic_names))

## Прогон модели

In [49]:
model.fit_offline(batch_vectorizer, num_collection_passes=num_iter)

In [50]:
phi = model.get_phi()
theta = model.get_theta()

In [51]:
phi.head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
predebate,1.559367e-05,4.771524e-10,1.8e-05,3e-06,0.0,8.889098e-05,0.0,8.481417e-13,5.111464e-15,1.340616e-09
barbour,8.666757e-12,0.0,0.0,7e-06,5.805264e-07,0.000253644,1.983245e-05,1.508367e-12,1.180926e-12,1.898581e-07
bumblebums,0.0,9.774648e-14,0.000261,0.0,0.0004716354,0.0,0.0,0.0,4.017246e-13,0.0
spindizzy,5.515496e-16,6.895757e-14,0.000273,0.0,0.0004903445,0.0,0.0,0.0,2.716285e-14,0.0
mcentee,0.0,0.0,0.0,0.0,2.02694e-13,3.915129e-14,1.446391e-16,0.0,0.0005604493,1.046842e-15


In [52]:
theta.head()

Unnamed: 0,3001,3002,3003,3004,3005,3006,3007,3008,3009,3010,...,1991,1992,1993,1994,1995,1996,1997,1998,1999,2000
topic_0,0.317986,0.0,0.296328,0.07932479,0.006957777,0.0,0.1203053,0.0,6.219731e-14,2.655639e-14,...,0.405047,0.022242,0.08101609,2.016652e-16,0.0008619546,2.586948e-08,0.0,1.947319e-11,2.785038e-10,0.0
topic_1,0.357562,0.0,0.027244,1.060059e-13,0.000428216,0.0,0.03307888,0.0,0.0,0.0,...,0.00031,1.1e-05,6.015854e-07,0.0,0.0,0.307438,4.563063e-16,0.009136071,0.0,0.0
topic_2,0.0,0.032342,0.0,1.318244e-10,0.009497471,1.0,1.666639e-07,0.0,0.0,1.506624e-08,...,0.0,0.0,0.0,0.8237324,3.299588e-14,0.01418151,6.701723e-15,0.008544581,0.0,0.0
topic_3,0.000943,0.0,0.00446,0.2047598,0.1767253,1.004865e-12,0.0009707098,0.083137,0.02864749,0.004849271,...,0.242819,0.014091,0.01657027,8.908503e-09,0.3865769,0.1513948,0.1362417,0.0001340713,0.2210319,0.057991
topic_4,0.219647,0.000946,0.519475,0.0005952622,4.606396e-13,1.340456e-08,0.03212908,0.0,1.567544e-15,3.89103e-06,...,0.001198,0.173389,3.120962e-07,0.1184843,0.03170368,0.0,0.0,1.324158e-10,0.0,0.0


In [53]:
phi.shape, theta.shape

((6906, 10), (10, 3430))

## Подмена Фи

In [54]:
def init_custom_phi(model, class_id, class_id_phi):
    tm_info, phi_ref = model.master.attach_model(model=model.model_pwt)

    fields = tm_info.ListFields()
    token_order = np.array(fields[3][1])
    topic_order = np.array(fields[2][1])
    class_id_order = np.array(fields[5][1])

    new_phi = pd.DataFrame(data=phi_ref,
                           index=token_order,
                           columns=topic_order)

    mask = class_id_order == class_id
    current_phi = new_phi.iloc[mask].copy()
    current_phi.update(class_id_phi)
    new_phi.iloc[mask] = current_phi

    np.copyto(phi_ref, new_phi.values)
    return model

In [55]:
phi_new = pd.DataFrame(data=np.random.random(phi.shape),
                       columns=phi.columns,
                       index=phi.index)

In [56]:
phi_new.head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
predebate,0.600288,0.113186,0.164752,0.391664,0.205864,0.20996,0.128473,0.184792,0.279693,0.281616
barbour,0.873536,0.535326,0.359308,0.032112,0.136268,0.640242,0.014927,0.012221,0.672071,0.148203
bumblebums,0.961663,0.108372,0.098168,0.092543,0.547058,0.109496,0.692209,0.381485,0.256309,0.742678
spindizzy,0.513241,0.552035,0.614153,0.692317,0.204203,0.980742,0.163537,0.718826,0.972189,0.934545
mcentee,0.73033,0.80026,0.706824,0.383587,0.228841,0.74354,0.562666,0.277414,0.359004,0.298519


In [57]:
model_new = init_custom_phi(model, class_id='@default_class', class_id_phi=phi_new)

In [58]:
model_new.get_phi().head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
predebate,0.600288,0.113186,0.164752,0.391664,0.205864,0.20996,0.128473,0.184792,0.279693,0.281616
barbour,0.873536,0.535326,0.359308,0.032112,0.136268,0.640242,0.014927,0.012221,0.672071,0.148203
bumblebums,0.961663,0.108372,0.098168,0.092543,0.547058,0.109496,0.692209,0.381485,0.256309,0.742678
spindizzy,0.513241,0.552035,0.614153,0.692317,0.204203,0.980742,0.163537,0.718826,0.972189,0.934545
mcentee,0.73033,0.80026,0.706824,0.383587,0.228841,0.74354,0.562666,0.277414,0.359004,0.298519


In [59]:
model_new.fit_offline(batch_vectorizer, num_collection_passes=num_iter)

In [60]:
model_new.get_phi().head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
predebate,5.69842e-05,3.413876e-07,1.8e-05,1.11088e-05,0.0,4.518716e-06,0.0,5.516832e-06,2.570949e-05,1.198883e-06
barbour,3.04126e-11,0.0,0.0,1.60057e-07,2.052859e-09,0.0003810915,1e-06,2.18601e-15,3.417976e-14,5.736089e-10
bumblebums,1.114285e-15,1.264632e-15,0.000406,0.0,1.781738e-06,2.331677e-13,0.0,0.0,7.513892e-15,0.0
spindizzy,2.843821e-16,1.028806e-15,0.000424,0.0,4.995859e-07,1.880039e-13,0.0,0.0,3.983614e-15,0.0
mcentee,0.0,0.0,0.0,0.0,0.0,2.437497e-10,0.0,0.0,0.0004914535,0.0


## Подмена Фи

In [115]:
def init_custom_theta(model, class_id_phi):
    tm_info, phi_ref = model.master.attach_model(model=model.theta_name)

    fields = tm_info.ListFields()
    doc_order = np.array(fields[3][1])
    topic_order = np.array(fields[2][1])
    
    print(phi_ref.shape)

    new_phi = pd.DataFrame(data=phi_ref.T,
                           index=topic_order,
                           columns=doc_order)

    print(new_phi.shape)
    current_phi = new_phi.copy()
    current_phi.update(class_id_phi)
    new_phi = current_phi.copy()

    np.copyto(phi_ref, new_phi.values.T)
    return model

In [116]:
theta_new = pd.DataFrame(data=np.random.random(theta.shape),
                       columns=theta.columns,
                       index=theta.index)

In [117]:
theta_new.head()

Unnamed: 0,3001,3002,3003,3004,3005,3006,3007,3008,3009,3010,...,1991,1992,1993,1994,1995,1996,1997,1998,1999,2000
topic_0,0.253828,0.105636,0.522545,0.900549,0.711446,0.15762,0.954709,0.401875,0.148614,0.301425,...,0.952692,0.367139,0.990485,0.717603,0.442356,0.144267,0.915164,0.759758,0.424313,0.298619
topic_1,0.375843,0.814147,0.653294,0.862777,0.35588,0.215285,0.305289,0.526001,0.341793,0.818205,...,0.895196,0.388949,0.743757,0.78234,0.706907,0.342647,0.122608,0.434243,0.665976,0.880371
topic_2,0.5055,0.592421,0.090751,0.490024,0.341837,0.669631,0.934827,0.792667,0.761241,0.081708,...,0.327735,0.25379,0.106632,0.661703,0.710091,0.619764,0.626211,0.38918,0.242926,0.268956
topic_3,0.500565,0.540554,0.551953,0.307033,0.699627,0.347081,0.786685,0.135277,0.06295,0.301437,...,0.287431,0.529197,0.302734,0.00298,0.979638,0.948853,0.182953,0.841014,0.900338,0.808614
topic_4,0.925963,0.738398,0.649108,0.215154,0.046559,0.930696,0.01968,0.020679,0.615051,0.672985,...,0.555221,0.977524,0.459417,0.499594,0.009958,0.845913,0.861079,0.675615,0.13211,0.564682


In [118]:
theta_new.shape

(10, 3430)

In [119]:
model_new2 = init_custom_theta(model, class_id_phi=theta_new)

(3430, 10)
(10, 3430)


In [120]:
model_new2.get_phi().head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
predebate,3.503246e-44,3.413876e-07,6.854389000000001e-25,1.401298e-45,1.819909e-37,1.401298e-45,1.152318e-24,1.401298e-45,3.0127920000000002e-43,1.961818e-44
barbour,5.6164039999999995e-42,5.571563e-42,0.0,0.0,1.180216e-24,1.401298e-45,,,0.0,0.0
bumblebums,1.821259e-37,1.401298e-45,1.261169e-44,0.0,6.854282e-25,1.401298e-45,1.166505e-24,1.401298e-45,1.166704e-24,1.401298e-45
spindizzy,1.166682e-24,1.401298e-45,1.166693e-24,1.401298e-45,1.166698e-24,1.401298e-45,1.166687e-24,1.401298e-45,1.166665e-24,1.401298e-45
mcentee,1.166671e-24,1.401298e-45,1.722129e-25,1.401298e-45,1.5495650000000002e-25,1.401298e-45,,,1.401298e-45,0.0


In [121]:
model_new2.get_theta().head()

Unnamed: 0,3001,3002,3003,3004,3005,3006,3007,3008,3009,3010,...,1991,1992,1993,1994,1995,1996,1997,1998,1999,2000
topic_0,0.253828,0.105636,0.522545,0.900549,0.711446,0.15762,0.954709,0.401875,0.148614,0.301425,...,0.952692,0.367139,0.990485,0.717603,0.442356,0.144267,0.915164,0.759758,0.424313,0.298618
topic_1,0.375843,0.814147,0.653294,0.862777,0.35588,0.215285,0.305289,0.526001,0.341793,0.818205,...,0.895196,0.388949,0.743757,0.78234,0.706907,0.342647,0.122608,0.434243,0.665976,0.880371
topic_2,0.5055,0.592421,0.090751,0.490024,0.341837,0.669631,0.934827,0.792667,0.761241,0.081708,...,0.327735,0.25379,0.106632,0.661703,0.710091,0.619764,0.626211,0.38918,0.242926,0.268956
topic_3,0.500565,0.540554,0.551953,0.307033,0.699627,0.347081,0.786685,0.135277,0.06295,0.301437,...,0.287431,0.529197,0.302734,0.00298,0.979638,0.948853,0.182953,0.841014,0.900338,0.808614
topic_4,0.925963,0.738398,0.649108,0.215154,0.046559,0.930696,0.01968,0.020679,0.615051,0.672985,...,0.555221,0.977524,0.459417,0.499594,0.009958,0.845913,0.861079,0.675615,0.13211,0.564682
