In [1]:
%load_ext watermark
%watermark  -a Filippo_Valle -v -m -g -r -v -p nltk,pandas,numpy,graph_tool,cloudpickle,topicpy,matplotlib,plotly

Author: Filippo_Valle

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.22.0

nltk       : 3.6.1
pandas     : 1.2.3
numpy      : 1.19.0
graph_tool : 2.37 (commit afba9459, )
cloudpickle: 1.6.0
topicpy    : 0.2.1
matplotlib : 3.4.1
plotly     : 4.14.3

Compiler    : GCC 9.3.0
OS          : Linux
Release     : 5.8.0-50-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 12
Architecture: 64bit

Git hash: aec0cb27d23e921cf53771b154b07fadbbd6854a

Git repo: git@github.com:fvalle1/epj.git



In [2]:
import pandas as pd
import numpy as np
import os,sys
sys.path.append("../")
from nlp import process_phrase
import nltk
nltk.download("reuters")
nltk.download("stopwords")
from nltk.corpus import reuters

import logging
# Notebook-wide logger. The handler is attached only once: re-running this
# cell would otherwise stack another StreamHandler on the same named logger
# and every message would be printed multiple times.
log = logging.getLogger("plos")
if not log.handlers:
    log.addHandler(logging.StreamHandler())
log.setLevel(logging.DEBUG)

[nltk_data] Downloading package reuters to /home/jovyan/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package stopwords to /home/jovyan/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
# Build the working sample:
#   1. take every file id of the "test/" split,
#   2. count how often each category label occurs over those documents,
#   3. keep the ten most frequent labels,
#   4. keep documents whose FIRST listed category is one of those ten,
#   5. draw a sample of them without replacement.
# The generator is seeded so that Restart & Run All selects the same documents.
SAMPLE_SEED = 42
N_ARTICLES = 1000

test_ids = [d for d in reuters.fileids() if d.startswith("test/")]
cat, count = np.unique(np.concatenate([reuters.categories(art) for art in test_ids]), return_counts=True)
common_cat = cat[np.argsort(count)[::-1]][:10]

eligible_ids = [d for d in test_ids if reuters.categories(d)[0] in common_cat]
rng = np.random.default_rng(SAMPLE_SEED)
# Cap the sample size: sampling without replacement raises if more items are
# requested than are available.
articles = rng.choice(eligible_ids, min(N_ARTICLES, len(eligible_ids)), replace=False)

In [None]:
# Sanity check: number of sampled documents, then the first sampled file id.
print(len(articles))
articles[0]

In [None]:
def get_article(article):
    """Load one Reuters document and run its text through ``process_phrase``.

    The raw document is split on newlines: the first line is treated as the
    title, everything after it is re-joined with spaces and treated as the body.

    Parameters
    ----------
    article : file id understood by ``reuters.raw`` / ``reuters.categories``.

    Returns
    -------
    tuple
        ``(article, processed_title, processed_body, categories)``.
    """
    headline, *body_lines = reuters.raw(article).split("\n")
    categories = reuters.categories(article)
    # The body is processed before the title, as in the original call order.
    body = process_phrase(" ".join(body_lines))
    title = process_phrase(headline)
    return article, title, body, categories

In [None]:
# Inspect the (id, title, text, labels) tuple for one arbitrary sampled document.
article = get_article(articles[10])
article

In [None]:
def get_article_dfs(article):
    """Turn one article into three single-document frames.

    Parameters
    ----------
    article : file id forwarded to ``get_article``.

    Returns
    -------
    tuple or None
        ``(df, df_meta, df_files)`` where

        * ``df``       -- one column named after the document id; index = distinct
          body tokens, values = occurrence counts;
        * ``df_meta``  -- same layout, built from the title tokens;
        * ``df_files`` -- one row indexed by the document id with the first
          category label in column ``"category"``.

        ``None`` when the processed text is ``None`` (as before) or when the
        processed title is ``None`` (previously this raised on ``.split``).
    """
    docid, title, text, labels = get_article(article)
    if text is None or title is None:
        return None

    def _token_counts(tokens):
        # np.unique returns the sorted distinct tokens and their counts.
        vocab, counts = np.unique(tokens, return_counts=True)
        return pd.Series(counts, index=vocab, name=docid).to_frame()

    # Tokens are split on single spaces; empty strings may appear and are
    # left for the caller to filter out.
    df = _token_counts(text.split(" "))
    df_meta = _token_counts(title.split(" "))

    # Built directly instead of DataFrame.append, which no longer exists in
    # pandas >= 2.0.
    df_files = pd.DataFrame({"category": [labels[0]]}, index=[docid])

    return df, df_meta, df_files

def append_callback(x):
    """Merge one worker result into the global ``df`` / ``df_meta`` / ``df_files``.

    Parameters
    ----------
    x : tuple or None
        ``(df_j, df_meta_j, df_files_j)`` as returned by ``get_article_dfs``,
        or ``None`` for an article that was skipped.

    Returns
    -------
    None

    Notes
    -----
    The three merged frames are computed first and the globals are rebound
    only if all three merges succeed, so a failure can no longer leave ``df``
    updated while ``df_meta`` / ``df_files`` are not. A failed merge is logged
    on the "plos" logger instead of being silently discarded; the article is
    still skipped, as before.
    """
    global df
    global df_meta
    global df_files
    if x is None:
        return None

    df_j, df_meta_j, df_files_j = x
    try:
        merged_df = df.join(df_j, how="outer")  # add the new document column
        merged_meta = df_meta.join(df_meta_j, how="outer")
        # pd.concat replaces DataFrame.append (removed in pandas 2.0).
        merged_files = pd.concat([df_files, df_files_j])
    except Exception:
        logging.getLogger("plos").exception(
            "Could not merge article(s) %s; skipped", list(df_j.columns)
        )
        return None

    df, df_meta, df_files = merged_df, merged_meta, merged_files
    return None

In [None]:
# Smoke-test the per-article pipeline on one fixed document id.
get_article_dfs('test/14844')

In [None]:
import multiprocessing as mp

In [None]:
# Reset the global accumulators so that re-running this cell starts clean;
# append_callback grows them one article at a time.
df = pd.DataFrame()
df_meta = pd.DataFrame()
df_files = pd.DataFrame(columns=["category"])

# Fan the articles out to 12 worker processes. The pool is used as a context
# manager so the workers are torn down even if submitting a task raises.
# close() + join() must still run inside the block: leaving the `with` calls
# terminate(), which would otherwise kill unfinished tasks.
with mp.Pool(12) as pool:
    work = [
        pool.apply_async(
            get_article_dfs,
            args=(article,),
            callback=append_callback,   # merges the result into the globals
            error_callback=log.debug,   # worker exceptions are logged at DEBUG
        )
        for article in articles
    ]

    pool.close()
    pool.join()

In [None]:
# Clean the accumulated tables.
#   df       : word x document counts (columns are document ids)
#   df_meta  : title-word x document counts
#   df_files : document id -> category
# The steps below depend on each other's output, so their order matters.

# Keep only documents whose category is one of the retained top categories.
df_files = df_files[df_files["category"].isin(common_cat)] 

# Missing counts become 0; then rows identical to an earlier row are dropped.
# NOTE(review): drop_duplicates compares whole rows, so two different words with
# exactly the same per-document counts collapse into one -- confirm this is intended.
df = df.fillna(0).astype(int).drop_duplicates()

'''
do reindex in two steps to avoid undefined behaviour
the sum is made on the new index
'''

df = df.reindex(columns=df.columns[df.columns.isin(df_files.index.dropna())]) # keep only documents that have a row in df_files
df = df.reindex(index=list(filter(lambda x:len(x)>0,df.index))) # drop the empty-string token

# O[w] = number of documents in which word w occurs at least once.
O = df.apply(lambda x: (x>0).sum(), axis=1)
df = df.reindex(index = df.index[O>5]) # keep words occurring in more than 5 documents 
df = df.reindex(columns = df.columns[df.sum(0) > 10]) # keep documents with more than 10 remaining tokens (counted with repetition) 

df_meta = df_meta.fillna(0).astype(int).drop_duplicates()

# NOTE(review): reindexing on columns fills any document missing from df_meta with NaN.
df_meta = df_meta.reindex(columns=df.columns) # same documents, same order as df
df_meta = df_meta.reindex(index=list(filter(lambda x:len(x)>0,df_meta.index))) # drop the empty-string token
df_meta = df_meta.reindex(index=df_meta.index[df_meta.sum(1)>10]) # keep title words with more than 10 occurrences summed over documents

# Rebuild the label table so it has exactly one row per retained document
# (this replaces the filtered df_files computed at the top of the cell).
df_files = pd.DataFrame(index=df.columns, columns=["category"], data=[reuters.categories(art)[0] for art in df.columns])

In [None]:
# Label table for ALL sampled articles (first category of each).
# NOTE(review): this overrides the df_files built at the end of the previous
# cell, which was restricted to df.columns; after this line df_files may list
# documents that were filtered out of df -- confirm this is intended.
df_files = pd.DataFrame(index=articles, columns=["category"], data=[reuters.categories(art)[0] for art in articles])

In [None]:
# Show the ten most frequent categories selected when sampling.
common_cat

In [None]:
# Smallest column total, i.e. the token count of the shortest retained document.
df.sum(0).min()

In [None]:
# Report the dimensions of the three tables, in the order df, df_meta, df_files.
for frame in (df, df_meta, df_files):
    print(frame.shape)

In [None]:
# Row totals: occurrences of each title word summed over all documents.
df_meta.sum(1)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Rank-frequency (Zipf) plot: each document column is normalised to relative
# frequencies, averaged per word over documents, and sorted in decreasing order.
fig, ax = plt.subplots()
word_freq = df.divide(df.sum(0),1).mean(1).sort_values(ascending=False)

# Plot against an explicit 1-based rank. Plotting the string-indexed Series
# directly places the first word at x=0, which cannot be drawn on a log axis,
# so the most frequent word was missing from the figure.
ranks = np.arange(1, len(word_freq) + 1)
ax.plot(ranks, word_freq.values, label="words")

# Reference power law for visual comparison.
x = np.linspace(1,1e3)
ax.plot(x,1e-1*x**(-0.9), label=r"$0.1\,r^{-0.9}$")

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("rank")
ax.set_ylabel("mean relative frequency")
ax.legend()
fig.savefig("zipf.pdf")

# Make hSBM graph

In [3]:
import sys
sys.path.append("../../hSBM_Topicmodel/")

In [4]:
import graph_tool.all as gt
from sbmtm import sbmtm

In [None]:
# Build the hSBM model from the word x document table and store its graph.
# The name `sbmtm` is rebound from the class to the model instance (later
# cells rely on that). Importing the class under a private alias here keeps
# the cell re-runnable: calling `sbmtm()` a second time would otherwise try to
# call the instance.
from sbmtm import sbmtm as _sbmtm_class

sbmtm = _sbmtm_class()
sbmtm.make_graph_from_BoW_df(df)
sbmtm.save_graph("reuters.xml.gz")

In [None]:
# Keep a handle on the graph held by the model (its `g` attribute) and show it.
g = sbmtm.g
g

In [None]:
# Fit the hSBM model.
# NOTE(review): n_init / B_min / B_max are interpreted by sbmtm.fit; presumably
# the number of restarts and bounds on the number of blocks -- confirm there.
sbmtm.fit(n_init=5, verbose=False, B_min=10, B_max=100, parallel=True)

In [None]:
# Store the group assignments of hierarchy levels 0 and 1 on the model.
for level in (0, 1):
    sbmtm.groups[level] = sbmtm.get_groups(level)

In [None]:
# Export the hSBM results to a fresh ./reuters tree:
#   reuters/files.dat   -- document labels
#   reuters/topsbm/     -- whatever save_data() / save_graph() write
# save_data() and save_graph() are called without a path, so the working
# directory is switched for them. The starting directory is restored in a
# `finally` so a failure cannot leave the kernel inside reuters/.
import shutil

_start_dir = os.getcwd()
shutil.rmtree("reuters", ignore_errors=True)   # was: rm -rf reuters
os.makedirs("reuters/topsbm", exist_ok=True)   # was: mkdir -p (twice)
try:
    os.chdir("reuters/")
    df_files.to_csv("files.dat")
    os.chdir("topsbm/")
    sbmtm.save_data()
    sbmtm.save_graph()
finally:
    os.chdir(_start_dir)

In [None]:
# Draw the fitted hierarchy with a bipartite layout; vertices are styled by
# the graph's "kind" vertex property.
gt.draw_hierarchy(
    sbmtm.state,
    layout="bipartite",
    hedge_pen_width=8,
    hvertex_size=25,
    vertex_kind=sbmtm.g.vertex_properties["kind"],
)

## triSBM

In [5]:
sys.path.append("../../trisbm/")
from trisbm import trisbm

In [None]:
# Mark title words with a leading "#" so they stay distinct from body words
# once both tables are stacked. Labels that already carry the marker are left
# alone, so re-running this cell no longer produces "##word".
# NOTE(review): assumes no genuine title token starts with "#" -- confirm
# against process_phrase.
df_meta.index = [word if word.startswith("#") else "#" + word for word in df_meta.index]

In [None]:
# Build the triSBM model from body words and "#"-prefixed title words.
# The name `trisbm` is rebound from the class to the model instance (later
# cells rely on that). Importing the class under a private alias keeps the
# cell re-runnable.
from trisbm import trisbm as _trisbm_class

trisbm = _trisbm_class()
# pd.concat replaces DataFrame.append (removed in pandas 2.0). The callable
# passed second returns 2 for rows coming from df_meta and 1 for all others.
trisbm.make_graph(pd.concat([df, df_meta]), lambda word_keyword: 2 if word_keyword in df_meta.index else 1)

In [None]:
# Store the triSBM input graph next to the notebook.
trisbm.save_graph("reuters_keyword.xml.gz")

In [None]:
# Fit the triSBM model with the same n_init / B_min / B_max values used for hSBM.
# NOTE(review): parameter semantics are defined by trisbm.fit -- confirm there.
trisbm.fit(n_init=5, verbose=False, B_min=10, B_max = 100)

In [None]:
import os

In [None]:
# Export the triSBM results to a fresh ./reuters_key tree:
#   reuters_key/files.dat  -- document labels
#   reuters_key/trisbm/    -- whatever save_data() / save_graph() write
# save_data() and save_graph() are called without a path, so the working
# directory is switched for them. The starting directory is restored in a
# `finally` so a failure cannot leave the kernel inside reuters_key/.
import shutil

_start_dir = os.getcwd()
shutil.rmtree("reuters_key", ignore_errors=True)   # was: rm -r reuters_key
os.makedirs("reuters_key/trisbm", exist_ok=True)   # was: mkdir -p (twice)
try:
    os.chdir("reuters_key/")
    df_files.to_csv("files.dat")
    os.chdir("trisbm/")
    trisbm.save_data()
    trisbm.save_graph()
finally:
    os.chdir(_start_dir)

In [None]:
# Draw the hierarchy fitted by triSBM with the default layout.
gt.draw_hierarchy(trisbm.state, hedge_pen_width=8, hvertex_size=25)

## Benchmark

In [None]:
from topicpy.hsbmpy import get_scores, get_scores_shuffled, add_score_lines, normalise_score
import matplotlib.pyplot as plt
import pandas as pd
import os

In [None]:
# Collect the scores of both models against the "category" label, plus a
# shuffled baseline computed on the hSBM output, then normalise every entry
# by the baseline (element-wise ratio).
labels = ["category"]
label = labels[0]

scores = get_scores("reuters_key", labels, df_files, algorithm="trisbm", verbose=False)
scores["trisbm"] = scores[label]

hsbm_scores = get_scores("reuters", labels, df_files, algorithm="topsbm", verbose=False)
scores["hsbm"] = hsbm_scores[label]

scores["shuffle"] = get_scores_shuffled("reuters", df_files, label=label, algorithm="topsbm")

normalise_score(scores, base_algorithm="shuffle", operation=lambda x,y: x/y)

In [None]:
# Line plot of the normalised scores ("norm_V") against "xl" for hSBM,
# triSBM and the shuffled baseline.
fig = plt.figure(figsize=(18,15))
ax = fig.subplots(1)
add_score_lines(ax,scores,labels=["hsbm","trisbm", "shuffle"], V="norm_V", alpha=1)
ax.set_xscale('log')
ax.set_ylim(0, max(max(s["norm_V"]) for s in scores.values())*1.1)
# Only the right edge is pinned: a lower limit of 0 is invalid on a log axis
# (matplotlib ignored it with a warning), so the left edge stays as it was.
ax.set_xlim(right=max(max(s["xl"]) for s in scores.values())*1.1)

# Save before showing: on some backends the figure is no longer drawable
# once show() returns.
fig.savefig("metric_scores.pdf")
plt.show()

In [None]:
import plotly.graph_objects as go

In [None]:
# Bar chart comparing the two models by get_mdl() divided by the number of
# edges of each model's graph (log-scaled y axis).
titlefont = {"size": 30}
tickfont = {"size": 25}
axis_fonts = {"titlefont": titlefont, "tickfont": tickfont}

fig = go.Figure()
fig.add_traces([
    go.Bar(y=[model.get_mdl() / model.g.num_edges()], name=trace_name)
    for trace_name, model in (("hSBM", sbmtm), ("triSBM", trisbm))
])

layout = {
    "title": "Reuters dataset",
    "xaxis": {"title": "Resolution", **axis_fonts},
    # "range": [10e3,20e3] was left disabled on the y axis
    "yaxis": {"title": "∑/E", "type": "log", **axis_fonts},
    "legend": {"font_size": 35},
}

fig.update_layout(layout)
#fig.write_image("metric_entropies_bar.pdf")

In [None]:
# Bar chart of the normalised scores ("norm_V") of hSBM and triSBM.
fig = go.Figure()
fig.add_traces([
    go.Bar(y=scores["hsbm"]["norm_V"], name="hSBM"),
    go.Bar(y=scores["trisbm"]["norm_V"], name="triSBM")
])

titlefont = {
    "size": 30 
}

tickfont = {
    "size":25
}

layout = {
    # The title previously read "Plos dataset"; the data plotted here comes
    # from the Reuters corpus, matching the title of the other bar chart.
    "title":"Reuters dataset",
    "xaxis":{
        "title": "Resolution",
        "titlefont": titlefont,
        "tickfont": tickfont
    },
    "yaxis":{
        "title": "NMI/NMI*",
        "titlefont": titlefont,
        "tickfont": tickfont
    },
    "legend":{
        "font_size":35
    }
}

fig.update_layout(layout)
#fig.write_image("metric_scores_bar.pdf")

In [None]:
from topicpy.hsbmpy import clusteranalysis

In [None]:
# Run topicpy's cluster analysis on both exported result trees, against the
# "category" label. Paths are absolute, built from the current working
# directory, so this must run from the directory that contains reuters/ and
# reuters_key/.
clusteranalysis(os.getcwd()+"/reuters/", ["category"], algorithm="topsbm")
clusteranalysis(os.getcwd()+"/reuters_key/", ["category"], algorithm="trisbm")

In [None]:
import cloudpickle as pickle

# Persist both fitted model objects next to the notebook.
# Only load these files back from a trusted source: unpickling runs arbitrary code.
for pkl_path, model in (("sbmtm.pkl", sbmtm), ("trisbm.pkl", trisbm)):
    with open(pkl_path, "wb") as file:
        pickle.dump(model, file)