In [1]:
import glob
import itertools
import os
import time
import sys

import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
sys.path.insert(0, '../OptimalNumberOfTopics/')     # topnum

In [3]:
%load_ext autoreload
%autoreload 2 

In [4]:
from topicnet.cooking_machine.models import TopicModel
from topicnet.cooking_machine.dataset import Dataset

from topnum.data.vowpal_wabbit_text_collection import VowpalWabbitTextCollection
from topnum.search_methods.optimize_scores_method import OptimizeScoresMethod
from topnum.utils import (
    read_corpus_config, split_into_train_test, 
    build_every_score, monotonity_and_std_analysis, 
    trim_config, classify_curve, SCORES_DIRECTION, load_models_from_disk
)
from topnum.model_constructor import KnownModel, PARAMS_EXPLORED
from topnum.utils import estimate_num_iterations_for_convergence

from collections import defaultdict

In [5]:


EXPERIMENTS_DICT = {
    "20NewsGroups": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/20NG_20NG_NEW",
    # "RuWikiGood": 
    "StackOverflow": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/SO_SO_NEW",
    "WikiRef220": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/WRef_NEW/",
    "PostNauka": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/PN_PN_NEW",
    # "Reuters": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/"
    "Brown": "/data/_tmp_alekseev/OptNumExperiments/AllDatasets/Brown_Brown_NEW",
}


In [8]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [9]:
EXPERIMENT_NAME_TEMPLATE = "_{mfv}_{param_id}_{seed}"

configs_dir = os.path.join('..', 'OptimalNumberOfTopics', 'topnum', 'configs')
configs_mask = os.path.join(configs_dir, '*.yml')

data_results = []
optimum_tolerance = 0.07

for config_file in glob.glob(configs_mask):
    config = read_corpus_config(config_file)

    if config['name'] in EXPERIMENTS_DICT:
        print(config['name'])
        experiment_directory = EXPERIMENTS_DICT[config['name']]
        
        for model_family in KnownModel:
            #if model_family != KnownModel.ARTM:
            #    continue
            print(model_family, end=", ")
            tmp = "WRef_test" if config['name'] == "WikiRef220" else config['batches_prefix']
            template = tmp + EXPERIMENT_NAME_TEMPLATE.format(
                mfv=model_family.value, param_id="{}", seed="{}"
            )

            details = defaultdict(dict)

            all_subexperems_mask = os.path.join(
                experiment_directory, template.format("*", "*")
            )

            for entry in glob.glob(all_subexperems_mask):
                experiment_name = entry.split("/")[-1]

                result, detailed_result = load_models_from_disk(
                    experiment_directory, experiment_name
                )

                for score in detailed_result.keys():
                    if SCORES_DIRECTION[score] is not None:
                        details[score][experiment_name] = detailed_result[score].T
            for score in details.keys():
                for experiment_name, data in details[score].items():
                    
                    *name_base, param_id, seed = experiment_name.split("_")
                    seed = int(seed)
                    my_data = data.T.mean(axis=0)

                    score_direction = SCORES_DIRECTION[score]
                    colored_values, curve_type = classify_curve(my_data, optimum_tolerance, score_direction)
                    
                    data_results.append(
                        [
                            config['name'], model_family.value, param_id, seed, score, 
                            str(curve_type).split(".")[1], 
                            list(colored_values[colored_values.notna()].index)
                        ]
                    )
            print()



20NewsGroups
KnownModel.LDA, 

  return np.nanmean(a, axis, out=out, keepdims=keepdims)



KnownModel.PLSA, 
KnownModel.SPARSE, 
KnownModel.TLESS, 
KnownModel.DECORRELATION, 
KnownModel.ARTM, 
StackOverflow
KnownModel.LDA, 
KnownModel.PLSA, 
KnownModel.SPARSE, 
KnownModel.TLESS, 
KnownModel.DECORRELATION, 
KnownModel.ARTM, 
WikiRef220
KnownModel.LDA, 
KnownModel.PLSA, 
KnownModel.SPARSE, 
KnownModel.TLESS, 
KnownModel.DECORRELATION, 
KnownModel.ARTM, 
PostNauka
KnownModel.LDA, 
KnownModel.PLSA, 
KnownModel.SPARSE, 
KnownModel.TLESS, 
KnownModel.DECORRELATION, 
KnownModel.ARTM, 
Brown
KnownModel.LDA, 
KnownModel.PLSA, 
KnownModel.SPARSE, 
KnownModel.TLESS, 
KnownModel.DECORRELATION, 
KnownModel.ARTM, 


In [10]:
df = pd.DataFrame(data=data_results, columns=["corpus", "model_family", "parameters_id", "seed", "score", "curve_type", "optimums"])

df

Unnamed: 0,corpus,model_family,parameters_id,seed,score,curve_type,optimums
0,20NewsGroups,LDA,1,2,SparsityThetaScore,EMPTY,"[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38]"
1,20NewsGroups,LDA,0,1,SparsityThetaScore,EMPTY,"[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38]"
2,20NewsGroups,LDA,2,2,SparsityThetaScore,EMPTY,"[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38]"
3,20NewsGroups,LDA,0,0,SparsityThetaScore,EMPTY,"[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38]"
4,20NewsGroups,LDA,1,1,SparsityThetaScore,EMPTY,"[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38]"
...,...,...,...,...,...,...,...
9931,Brown,ARTM,0,2,toptok1,PEAK,[9]
9932,Brown,ARTM,8,1,toptok1,PEAK,[12]
9933,Brown,ARTM,7,0,toptok1,OUTSIDE,[6]
9934,Brown,ARTM,11,2,toptok1,PEAK,[15]


In [11]:
df.curve_type.value_counts()

OUTSIDE         6767
INTERVAL         953
PEAK             843
JUMP_OUTSIDE     556
EMPTY            512
JUMPING          305
Name: curve_type, dtype: int64

In [13]:
df.query("corpus == 'WikiRef220' and curve_type != 'EMPTY'")

Unnamed: 0,corpus,model_family,parameters_id,seed,score,curve_type,optimums
3906,WikiRef220,LDA,2,0,TopicKernel@lemmatized.average_contrast,OUTSIDE,[2]
3907,WikiRef220,LDA,1,2,TopicKernel@lemmatized.average_contrast,OUTSIDE,[2]
3908,WikiRef220,LDA,2,2,TopicKernel@lemmatized.average_contrast,OUTSIDE,[2]
3909,WikiRef220,LDA,1,0,TopicKernel@lemmatized.average_contrast,OUTSIDE,[2]
3910,WikiRef220,LDA,0,2,TopicKernel@lemmatized.average_contrast,OUTSIDE,[2]
...,...,...,...,...,...,...,...
5899,WikiRef220,ARTM,0,0,toptok1,PEAK,[5]
5900,WikiRef220,ARTM,1,0,toptok1,PEAK,[5]
5901,WikiRef220,ARTM,8,0,toptok1,PEAK,[13]
5902,WikiRef220,ARTM,8,2,toptok1,PEAK,[19]


In [51]:
# df.pivot_table(index=['score', 'curve_type'], aggfunc='count')
table = df.query("model_family != 'ARTM'").pivot_table(
    values="seed", 
    index=['score'], columns=['curve_type'], aggfunc='count', 
    fill_value=0
)



In [52]:
table.sort_values(by=['INTERVAL'])

curve_type,EMPTY,INTERVAL,JUMPING,JUMP_OUTSIDE,OUTSIDE,PEAK
score,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
perp,0,0,0,3,177,0
intra,0,0,0,0,135,9
holdout_perp,0,0,0,3,177,0
SparsityPhiScore@lemmatized,18,0,0,0,54,0
SparsityPhiScore@word,27,0,0,0,81,0
calhar,0,0,0,0,157,23
TopicKernel@lemmatized.average_purity,0,1,0,1,68,2
TopicKernel@word.average_contrast,0,1,1,3,100,3
SparsityThetaScore,74,2,1,7,91,5
diversity_jensenshannon_False,11,2,1,8,145,13


In [53]:
table.sort_values(by=['PEAK'])

curve_type,EMPTY,INTERVAL,JUMPING,JUMP_OUTSIDE,OUTSIDE,PEAK
score,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
BIC_sparsity_False,0,14,0,0,166,0
holdout_perp,0,0,0,3,177,0
MDL_sparsity_False,0,23,0,0,157,0
perp,0,0,0,3,177,0
SparsityPhiScore@lemmatized,18,0,0,0,54,0
SparsityPhiScore@word,27,0,0,0,81,0
renyi_2,0,63,3,15,99,0
AIC_sparsity_False,0,72,4,1,102,1
AIC_sparsity_True,0,22,0,1,156,1
TopicKernel@lemmatized.average_contrast,0,2,0,7,61,2


In [54]:
df.query("curve_type == 'PEAK' and score == 'toptok1'")



Unnamed: 0,corpus,model_family,parameters_id,seed,score,curve_type,optimums,interval_length
244,20NewsGroups,LDA,0,1,toptok1,PEAK,[8],1
246,20NewsGroups,LDA,0,0,toptok1,PEAK,[5],1
247,20NewsGroups,LDA,1,1,toptok1,PEAK,[5],1
333,20NewsGroups,PLSA,0,0,toptok1,PEAK,[8],1
335,20NewsGroups,PLSA,0,1,toptok1,PEAK,[8],1
...,...,...,...,...,...,...,...,...
9928,Brown,ARTM,2,0,toptok1,PEAK,[12],1
9930,Brown,ARTM,11,1,toptok1,PEAK,[21],1
9931,Brown,ARTM,0,2,toptok1,PEAK,[9],1
9932,Brown,ARTM,8,1,toptok1,PEAK,[12],1


In [97]:
from functools import reduce

df2 = df.query("curve_type != 'OUTSIDE'").groupby(["corpus", "model_family", 'parameters_id', 'score']).agg(
    random_intersection=pd.NamedAgg(
        column='optimums', 
        aggfunc=lambda data: reduce(lambda a, b: set(a) & set(b), data))
)

df2

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,random_intersection
corpus,model_family,parameters_id,score,Unnamed: 4_level_1
20NewsGroups,ARTM,0,MDL_sparsity_False,"{12, 15, 18, 21, 24, 27}"
20NewsGroups,ARTM,0,SparsityThetaScore,"[30, 36, 39]"
20NewsGroups,ARTM,0,TopicKernel@word.average_contrast,{6}
20NewsGroups,ARTM,0,TopicKernel@word.average_purity,{6}
20NewsGroups,ARTM,0,diversity_cosine_True,"{33, 3, 36, 6, 39, 9, 12, 15, 18, 21, 24, 27, 30}"
...,...,...,...,...
WikiRef220,sparse,3,diversity_jensenshannon_True,[4]
WikiRef220,sparse,3,intra,{}
WikiRef220,sparse,3,renyi_0.5,{14}
WikiRef220,sparse,3,renyi_1,"{17, 19}"


In [98]:
df2.loc[('WikiRef220',)]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,random_intersection
model_family,parameters_id,score,Unnamed: 3_level_1
ARTM,0,AIC_sparsity_False,{7}
ARTM,0,BIC_sparsity_False,"[4, 5]"
ARTM,0,BIC_sparsity_True,"{11, 13, 14, 15, 17, 19, 20, 21}"
ARTM,0,MDL_sparsity_True,"{8, 7}"
ARTM,0,TopicKernel@lemmatized.average_contrast,[4]
...,...,...,...
sparse,3,diversity_jensenshannon_True,[4]
sparse,3,intra,{}
sparse,3,renyi_0.5,{14}
sparse,3,renyi_1,"{17, 19}"


In [100]:
from collections import Counter

def combine_counters(a, b):
    ca = Counter(a)
    ca.update(Counter(b))
    return ca

df3 = df2.groupby(["corpus", "model_family", 'parameters_id']).agg(
    all_scores_intersection=pd.NamedAgg(
        column='random_intersection', 
        aggfunc=lambda data: reduce(combine_counters, data))
)



In [101]:
df3.loc['20NewsGroups'].all_scores_intersection.apply(lambda x: x.most_common(3))

model_family   parameters_id
ARTM           0                   [(18, 9), (6, 8), (12, 7)]
               1                  [(12, 8), (15, 8), (21, 7)]
               10                 [(15, 9), (18, 9), (21, 9)]
               11                   [(6, 8), (33, 7), (9, 7)]
               2                    [(6, 8), (33, 7), (9, 7)]
               3                  [(12, 6), (15, 6), (18, 6)]
               4                [(18, 12), (15, 11), (21, 9)]
               5                    [(6, 8), (33, 7), (9, 7)]
               6                   [(6, 9), (18, 8), (21, 8)]
               7                  [(15, 6), (18, 6), (21, 6)]
               8                    [(6, 8), (33, 7), (9, 7)]
               9                  [(18, 7), (21, 7), (12, 7)]
LDA            0                  [(17, 8), (20, 7), (23, 7)]
               1                    [(5, 6), (8, 5), (11, 5)]
               2                  [(17, 7), (20, 5), (23, 5)]
PLSA           0                   [(17, 

In [102]:
df3.loc['WikiRef220'].all_scores_intersection.apply(lambda x: x.most_common(3))

model_family   parameters_id
ARTM           0                 [(17, 5), (4, 4), (18, 4)]
               1                   [(5, 8), (6, 7), (7, 7)]
               10                [(7, 8), (17, 7), (18, 7)]
               11               [(14, 9), (16, 7), (17, 7)]
               2                [(12, 8), (10, 5), (11, 5)]
               3                  [(4, 10), (6, 6), (7, 6)]
               4                  [(7, 10), (9, 7), (6, 6)]
               5                [(13, 9), (14, 7), (16, 5)]
               6                 [(7, 8), (13, 8), (14, 8)]
               7                  [(6, 9), (7, 8), (11, 8)]
               8                [(12, 9), (13, 6), (10, 5)]
               9                [(10, 9), (14, 7), (15, 7)]
LDA            0                   [(7, 9), (5, 6), (6, 4)]
               1                  [(7, 10), (5, 6), (2, 5)]
               2                   [(7, 8), (5, 7), (6, 7)]
PLSA           0                  [(7, 5), (5, 3), (18, 3)]
TARTM      

In [103]:
df3.loc['PostNauka'].all_scores_intersection.apply(lambda x: x.most_common(3))

model_family   parameters_id
ARTM           0                [(48, 5), (51, 3), (33, 2)]
               1                 [(9, 7), (15, 7), (45, 6)]
               10                [(9, 8), (12, 7), (15, 6)]
               11               [(30, 7), (21, 6), (27, 6)]
               2                [(15, 8), (18, 6), (33, 6)]
               3                 [(51, 2), (6, 2), (27, 2)]
               4                 [(9, 9), (12, 7), (15, 5)]
               5                [(21, 7), (24, 7), (18, 6)]
               6                [(48, 5), (36, 2), (39, 2)]
               7                 [(9, 7), (15, 7), (45, 6)]
               8                [(39, 9), (18, 9), (24, 8)]
               9                [(27, 3), (30, 3), (33, 2)]
LDA            0                [(38, 6), (44, 6), (47, 6)]
               1                [(38, 7), (44, 7), (47, 7)]
               2                [(41, 9), (44, 9), (38, 7)]
PLSA           0                [(47, 3), (38, 2), (44, 2)]
TARTM      

In [106]:
df3.loc['StackOverflow'].all_scores_intersection.apply(lambda x: x.most_common(3))

model_family   parameters_id
ARTM           0                [(151, 4), (146, 3), (136, 2)]
               1                [(151, 4), (131, 3), (126, 3)]
               10               [(146, 6), (151, 6), (131, 3)]
               11               [(146, 7), (151, 7), (126, 6)]
               2                [(146, 5), (151, 5), (136, 4)]
               3                [(146, 3), (126, 3), (151, 3)]
               4                [(136, 2), (146, 2), (126, 2)]
               5                [(131, 5), (146, 5), (151, 5)]
               6                [(146, 5), (151, 5), (131, 4)]
               7                [(146, 5), (151, 5), (136, 5)]
               8                [(151, 5), (136, 5), (141, 4)]
               9                [(146, 6), (151, 6), (141, 4)]
LDA            0                  [(10, 4), (15, 4), (130, 3)]
               1                [(130, 5), (135, 5), (140, 5)]
               2                 [(130, 6), (95, 5), (100, 5)]
PLSA           0          