### Topic modelling

See *Blei, Ng & Jordan (2003): Latent Dirichlet Allocation* [PDF](http://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf) for a description of LDA.
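
For orientation, the generative process described in that paper can be condensed as below. This is a brief restatement in the paper's notation (Dirichlet prior $\alpha$, topic-word parameters $\beta$, per-document topic mixture $\theta$, topic assignments $z_n$, words $w_n$), not a substitute for the full derivation.

```latex
% For each document with N words:
\begin{align*}
\theta &\sim \mathrm{Dir}(\alpha) \\
z_n \mid \theta &\sim \mathrm{Multinomial}(\theta), \qquad n = 1, \dots, N \\
w_n \mid z_n, \beta &\sim p(w_n \mid z_n, \beta) \\[4pt]
p(\theta, \mathbf{z}, \mathbf{w} \mid \alpha, \beta)
  &= p(\theta \mid \alpha) \prod_{n=1}^{N} p(z_n \mid \theta)\, p(w_n \mid z_n, \beta)
\end{align*}
```

Fitting a topic model amounts to inferring $\theta$ and $\beta$ from the observed words, which is what the LDA engines used later in this notebook approximate.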

### LDA Topic modelling



#### What Are Project Jupyter and Jupyter Notebook?

> Project [Jupyter](http://jupyter.org/) develops open-source software for **interactive and reproducible computing**.<br>
> The **open science movement** is a driving force for Jupyter's popularity.<br>
> It is in part a response to the **reproducibility crisis in science** and the **statistical crisis in science** (also known as data dredging or p-hacking).<br>
> Jupyter Notebooks contain **executable code, equations, visualizations and narrative text**.<br>
> It is a **web application** (can run locally) with a simple and easy to use web interface.
> <img src="./images/narrative_new.svg" style="width: 300px;padding: 0; margin: 0;"><br>
> Jupyter supports a large number of programming languages (50+ e.g. Python, R, JavaScript)

The project is sponsored by large companies such as Google and Microsoft, and by funders such as the Alfred P. Sloan Foundation. See [jupyter.org/about](http://jupyter.org/about) for all sponsors.

#### Brief Instructions on How to Use Notebooks
- **Menu Help -> User Interface Tour** gives an overview of the user interface.
- **Code cells** contain the script code and have **In [x]** in the left margin.
  - **In []** indicates that the code cell hasn't been executed yet.
  - **In [n]** indicates that the code has been executed (n is an integer).
  - **In [\*]** indicates that the code is executing, or waiting to be executed (i.e. other cells are executing).
- **The current cell** is highlighted with a blue border; click a cell to make it current.
- **SHIFT+ENTER** or **Play button** executes the current cell. Code cells aren't executed automatically.
- **Out[n]** indicates the output (or result) of a cell's execution and is directly below the executed cell.
- **SHIFT+ENTER** automatically selects the next code cell.
- **SHIFT+ENTER** can hence be used repeatedly to execute the code cells in sequence.
- **Menu Cell -> Run All** executes the entire notebook in a single step (this can take some time to finish; notice how the "In [\*]" indicators change to "In [n]").
- **Double-Click** on a cell to edit its content.
- **ESC key** leaves edit mode (or just click on any other cell).
- **Kernel -> Restart** restarts the server-side kernel (use if the notebook seems stuck).


### Risks
- The risk of using tools and methods **without fully understanding** them
- The risk of using tools and methods **for non-intended purposes or in new contexts**
- How to verify **performance** (correctness of result)
- Risk of **data dredging**, p-hacking, "the statistical crisis".
- The risk that the **engineer makes micro-decisions** the researcher doesn't know about
- The risk of **reading too much into visualizations** (networks, layouts, clusters).

### Challenges
- **What’s easy for humans can be extremely hard for computers**
- **Human-in-the-loop or supervised learning can be very expensive**
- Ambiguity and fuzziness of terms and phrases
- Poor data quality, errors in data, wrong data, missing data, ambiguous data
- Context, metadata, domain-specific data
- Data size (too much, too little)
- Computational methods require a structured internal representation
- Internal models are simplified views of the data
- etc...

### A sample high-level workflow

<img src="./images/text_analysis_workflow.svg" alt="" width="1200"/>

# Sample text corpus processing
### Extract Text From PDFs

In [1]:
import glob
import os
import zipfile

from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfpage import PDFTextExtractionNotAllowed, PDFPage
from pdfminer.pdfinterp import PDFPageInterpreter, PDFResourceManager
from pdfminer.converter import TextConverter
from pdfminer.layout import LAParams
from io import StringIO

def extract_pdf_text(filename):
    text_lines = []
    with open(filename, 'rb') as fp:
        
        parser = PDFParser(fp)
        document = PDFDocument(parser)

        if not document.is_extractable:
            raise PDFTextExtractionNotAllowed

        resource_manager = PDFResourceManager()

        result_buffer = StringIO()

        device = TextConverter(resource_manager, result_buffer, codec='utf-8', laparams=LAParams())

        interpreter = PDFPageInterpreter(resource_manager, device)

        for page in PDFPage.create_pages(document):
            interpreter.process_page(page)

        lines = result_buffer.getvalue().splitlines()
        for line in lines:
            text_lines.append(line)

    return text_lines


def extract_pdf_texts(source_folder, target_zip_filename):
    with zipfile.ZipFile(target_zip_filename, 'w', zipfile.ZIP_DEFLATED) as target_zip:

        for filename in glob.glob(os.path.join(source_folder,'*.pdf')):

            print('Processing: ' + filename)

            text_lines = extract_pdf_text(filename)

            target_filename = os.path.splitext(os.path.split(filename)[1])[0] + '.txt'
            target_filename = target_filename.lower().replace(' ', '_').replace(',','')

            target_zip.writestr(target_filename, '\n'.join(text_lines))
            
#source_folder = './data'
#target_zip_filename = 'data/texts.zip'
#extract_pdf_texts(source_folder, target_zip_filename)
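
The renaming rule inside `extract_pdf_texts` is easy to overlook, so here is a small standalone sketch of it. The helper name `normalized_txt_name` and the sample path are made up for illustration; only the string operations are taken from the cell above.

```python
import os


def normalized_txt_name(pdf_path):
    """Return the archive member name that the renaming rule would produce."""
    # Strip the folder and the '.pdf' extension, then append '.txt'.
    base = os.path.splitext(os.path.split(pdf_path)[1])[0]
    # Lowercase, replace spaces with underscores, drop commas.
    return (base + '.txt').lower().replace(' ', '_').replace(',', '')


# Hypothetical input path, used only to show the transformation.
example = normalized_txt_name('./data/Smith, Jones 2017 Topic Models.pdf')
print(example)  # smith_jones_2017_topic_models.txt
```

Other punctuation (parentheses, ampersands and so on) is left untouched by this rule, so file names containing it keep those characters in the zip archive.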

PDF Conversion to Text

### TRY IT: Text Cleanup

In [88]:
# folded code
# -*- coding: utf-8 -*-
%load_ext autoreload
%autoreload 2
import os
import nltk
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
import re
import string
import zipfile
import spacy
import textacy
import textacy.extract
import textacy.preprocess
import common.utility as utility
import warnings
import types
from gensim import corpora, models, matutils

from IPython.display import display, HTML
from spacy import displacy

logger = utility.getLogger(format="%(levelname)s;%(message)s")
warnings.filterwarnings('ignore')

def get_filenames(zip_filename, extension='.txt'):
    with zipfile.ZipFile(zip_filename, mode='r') as zf:
        return [ x for x in zf.namelist() if x.endswith(extension) ]
    
def get_text(zip_filename, filename):
    with zipfile.ZipFile(zip_filename, mode='r') as zf:
        return zf.read(filename).decode(encoding='utf-8')

DEFAULT_TERM_PARAMS = dict(
    args=dict(ngrams=1, named_entities=True, normalize='lemma', as_strings=True),
    kwargs=dict(filter_stops=True, filter_punct=True, filter_nums=True, min_freq=1, drop_determiners=True, include_pos=('NOUN', 'PROPN', ))
)
    
def filter_terms(doc, term_args, chunk_size=None):
    kwargs = utility.extend({}, DEFAULT_TERM_PARAMS['kwargs'], term_args['kwargs'])
    args = utility.extend({}, DEFAULT_TERM_PARAMS['args'], term_args['args'])
    terms = doc.to_terms_list(
        args['ngrams'],
        args['named_entities'],
        args['normalize'],
        args['as_strings'],
        **kwargs
    )
    return terms

LANGUAGE = 'en'
SOURCE_FOLDER = './data'
SOURCE_FILENAME = os.path.join(SOURCE_FOLDER, 'p1_paper_texts_edited.zip')
HYPHEN_REGEXP = re.compile(r'\b(\w+)-\s*\r?\n\s*(\w+)\b', re.UNICODE)
DF_TAGSET = pd.read_csv('./data/tagset.csv', sep='\t').fillna('')

logger.info('POS tag set: ' + ' '.join(list(DF_TAGSET.POS.unique())))

%matplotlib inline




INFO : POS tag set: PUNCT SYM X ADJ VERB CONJ NUM DET ADV ADP  NOUN PROPN PART PRON SPACE INTJ


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Preparation Step: Automatic Text Cleanup

In [2]:

def preprocess_text(source_filename, **args):
    filenames = get_filenames(source_filename)
    basename, extension = os.path.splitext(source_filename)
    target_filename = basename + '_preprocessed' + extension # '_'.join(list(args.keys())) + extension
    texts = ( (filename, get_text(source_filename, filename)) for filename in filenames )
    with zipfile.ZipFile(target_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
        for filename, text in texts:
            logger.info('Processing ' + filename)
            text = re.sub(HYPHEN_REGEXP, r"\1\2\n", text)
            text = textacy.preprocess.normalize_whitespace(text)   
            text = textacy.preprocess.fix_bad_unicode(text)   
            text = textacy.preprocess.replace_currency_symbols(text)
            text = textacy.preprocess.unpack_contractions(text)
            text = textacy.preprocess.replace_urls(text)
            text = textacy.preprocess.replace_emails(text)
            text = textacy.preprocess.replace_phone_numbers(text)
            text = textacy.preprocess.remove_accents(text)
            #text = preprocess.preprocess_text(text, **args)
            zf.writestr(filename, text)
            
#preprocess_text(SOURCE_FILENAME)
#preprocess_text(SOURCE_FILENAME, lowercase=True)
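
To make the first cleanup step concrete, the sketch below applies the same `HYPHEN_REGEXP` pattern to a short, made-up sample string. It only uses the standard `re` module; the function name `dehyphenate` is an illustrative assumption.

```python
import re

# Same pattern as HYPHEN_REGEXP in the notebook: a word fragment, a hyphen,
# a line break (with optional surrounding whitespace), and the rest of the word.
HYPHEN_REGEXP = re.compile(r'\b(\w+)-\s*\r?\n\s*(\w+)\b', re.UNICODE)


def dehyphenate(text):
    """Join words that were split across a line break by a hyphen."""
    # The two fragments are glued together and the line break is moved
    # to after the joined word.
    return HYPHEN_REGEXP.sub(r'\1\2\n', text)


sample = 'Topic mod-\nelling is a statis-\n  tical method.'
joined = dehyphenate(sample)
print(joined)
```

Note that the pattern cannot tell a line-break hyphen from a genuine one: a compound such as `state-of-` followed by `the-art` on the next line would be joined into `state-ofthe`, so the result is worth spot-checking on real text.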


## Preparation Step: Create Text Corpus 

In [89]:
from spacy.language import Language
from textacy.spacier.utils import merge_spans

def create_textacy_corpus(source_filename, language, preprocess_args):
    make_title = lambda filename: filename.replace('_', ' ').replace('.txt', '').title()
    filenames = get_filenames(source_filename)
    corpus = textacy.Corpus(language)
    text_stream = ( (filename, get_text(source_filename, filename)) for filename in filenames )
    for filename, text in text_stream:
        logger.info('Processing ' + filename)
        text = re.sub(HYPHEN_REGEXP, r"\1\2\n", text)
        text = textacy.preprocess.preprocess_text(text, **preprocess_args)
        corpus.add_text(text, dict(filename=filename, title=make_title(filename)))
    for doc in corpus:
        doc.spacy_doc.user_data['title'] = doc.metadata['title']
    return corpus

def remove_whitespace_entities(doc):
    doc.ents = [ e for e in doc.ents if not e.text.isspace() ]
    return doc

def generate_textacy_corpus(source_filename, language, corpus_args, preprocess_args, merge_named_entities=True, force=False):
    
    corpus_tag = '_'.join([ k for k in preprocess_args if preprocess_args[k] ]) + \
        '_disable(' + ','.join(corpus_args.get('disable', [])) +')'
    
    textacy_corpus_filename = os.path.join(SOURCE_FOLDER, 'corpus_{}_{}.pkl'.format(language, corpus_tag))
    
    Language.factories['remove_whitespace_entities'] = lambda nlp, **cfg: remove_whitespace_entities
    
    logger.info('Loading model: english...')
    nlp = textacy.load_spacy('en_core_web_sm', **corpus_args)
    pipeline = lambda: [ x[0] for x in nlp.pipeline ]
        
    logger.info('Using pipeline: ' + ' '.join(pipeline()))

    if force or not os.path.isfile(textacy_corpus_filename):
        logger.info('Working: Computing new corpus ' + textacy_corpus_filename + '...')
        corpus = create_textacy_corpus(source_filename, nlp, preprocess_args)
        corpus.save(textacy_corpus_filename)
    else:
        logger.info('Working: Loading corpus ' + textacy_corpus_filename + '...')
        corpus = textacy.Corpus.load(textacy_corpus_filename)
        
    if merge_named_entities:
        logger.info('Working: Merging named entities...')
        for doc in corpus:
            named_entities = textacy.extract.named_entities(doc)
            merge_spans(named_entities, doc.spacy_doc)
    else:
        logger.info('Note: named entities not merged')
        
    logger.info('Done!')
    return textacy_corpus_filename, corpus

def assign_document_titles(corpus):
    for doc in corpus:
        doc.spacy_doc.user_data['title'] = doc.metadata['title']
    
def get_corpus_documents(corpus):
    df_documents = pd.DataFrame([ (document_id, doc.metadata['title'], doc.metadata['filename']) for document_id, doc in enumerate(corpus) ], columns=['document_id', 'title', 'filename']).set_index('document_id')
    return df_documents

textacy_corpus_filename, corpus = generate_textacy_corpus(SOURCE_FILENAME, LANGUAGE, corpus_args=dict(), preprocess_args=dict(), merge_named_entities=True, force=False)
df_documents = get_corpus_documents(corpus)


INFO : Loading model: english...
INFO : Using pipeline: tagger parser ner
INFO : Working: Loading corpus ./data/corpus_en__disable().pkl...
INFO : Working: Merging named entities...
INFO : Done!
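
The cache file name in the log above (`corpus_en__disable().pkl`) follows from how `generate_textacy_corpus` builds its tag. The sketch below reproduces only that naming logic with plain Python; the function name `corpus_cache_filename` and the example arguments are illustrative assumptions, and the folder prefix is left out.

```python
def corpus_cache_filename(language, preprocess_args, corpus_args):
    """Rebuild the cache file name from the enabled options."""
    # Only preprocess options with a truthy value contribute to the tag.
    enabled = '_'.join(k for k in preprocess_args if preprocess_args[k])
    # Disabled pipeline components are listed inside 'disable(...)'.
    disabled = ','.join(corpus_args.get('disable', []))
    corpus_tag = enabled + '_disable(' + disabled + ')'
    return 'corpus_{}_{}.pkl'.format(language, corpus_tag)


# With empty option dicts, as in the call above:
print(corpus_cache_filename('en', {}, {}))
# With some options set (values chosen only for illustration):
print(corpus_cache_filename('en', {'lowercase': True, 'no_urls': False},
                            {'disable': ['parser']}))
```

Because the tag depends only on these options, changing anything else (for example the source zip file) will silently reuse an existing cache unless `force=True` is passed.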


In [90]:
def display_cleanup_text_gui(corpus, callback):
    
    document_options = {v: k for k, v in df_documents['title'].to_dict().items()}
    
    #pos_options = [ x for x in DF_TAGSET.POS.unique() if x not in ['PUNCT', '', 'DET', 'X', 'SPACE', 'PART', 'CONJ', 'SYM', 'INTJ', 'PRON']]  # groupby(['POS'])['DESCRIPTION'].apply(list).apply(lambda x: ', '.join(x)).to_dict()
    pos_tags = DF_TAGSET.groupby(['POS'])['DESCRIPTION'].apply(list).apply(lambda x: ', '.join(x[:1])).to_dict()
    pos_options = { k + ' (' + v + ')': k for k,v in pos_tags.items() }
    display_options = { 'Source text': 'source_text', 'Sanitized text': 'sanitized_text', 'Statistics': 'statistics'}
    gui = types.SimpleNamespace(
        document_id=widgets.Dropdown(description='Paper', options=document_options, value=0, layout=widgets.Layout(width='400px')),
        progress=widgets.IntProgress(value=0, min=0, max=5, step=1, description='', layout=widgets.Layout(width='90%')),
        min_freq=widgets.FloatSlider(value=0, min=0, max=1.0, step=0.01, description='Min frequency', layout=widgets.Layout(width='400px')),
        ngrams=widgets.Dropdown(description='n-grams', options=[1,2,3], value=1, layout=widgets.Layout(width='180px')),
        min_word=widgets.Dropdown(description='Min length', options=[1,2,3,4], value=1, layout=widgets.Layout(width='180px')),
        normalize=widgets.Dropdown(description='Normalize', options=[ False, 'lemma', 'lower' ], value=False, layout=widgets.Layout(width='180px')),
        filter_stops=widgets.ToggleButton(value=False, description='Filter stops',  tooltip='Filter out stopwords', icon='check'),
        filter_nums=widgets.ToggleButton(value=False, description='Filter nums',  tooltip='Filter out numbers', icon='check'),
        filter_punct=widgets.ToggleButton(value=False, description='Filter punct',  tooltip='Filter out punctuations', icon='check'),
        named_entities=widgets.ToggleButton(value=False, description='Merge entities',  tooltip='Merge entities', icon='check'),
        drop_determiners=widgets.ToggleButton(value=False, description='Drop determiners',  tooltip='Drop determiners', icon='check'),
        include_pos=widgets.SelectMultiple(description='POS', options=pos_options, value=list(), rows=10, layout=widgets.Layout(width='400px')),
        display_type=widgets.Dropdown(description='Show', value='statistics', options=display_options, layout=widgets.Layout(width='180px')),
        output_text=widgets.Output(layout={'height': '500px'}),
        output_statistics = widgets.Output(),
        boxes=None
    )
    
    uix = widgets.interactive(

        callback,

        corpus=widgets.fixed(corpus),
        gui=widgets.fixed(gui),
        display_type=gui.display_type,
        document_id=gui.document_id,
        
        ngrams=gui.ngrams,
        named_entities=gui.named_entities,
        normalize=gui.normalize,
        filter_stops=gui.filter_stops,
        filter_punct=gui.filter_punct,
        filter_nums=gui.filter_nums,
        include_pos=gui.include_pos,
        min_freq=gui.min_freq,
        drop_determiners=gui.drop_determiners
    )
    
    gui.boxes = widgets.VBox([
        gui.progress,
        widgets.HBox([
            widgets.VBox([
                gui.document_id,
                widgets.HBox([gui.display_type, gui.normalize]),
                widgets.HBox([gui.ngrams, gui.min_word]),
                gui.min_freq
            ]),
            widgets.VBox([
                gui.include_pos
            ]),
            widgets.VBox([
                gui.filter_stops,
                gui.filter_nums,
                gui.filter_punct,
                gui.named_entities,
                gui.drop_determiners
            ])
        ]),
        widgets.HBox([
            gui.output_text, gui.output_statistics
        ]),
        uix.children[-1]
    ])
    
    display(gui.boxes)
                                  
    uix.update()
    return gui, uix

def plot_xy_data(data, title='', xlabel='', ylabel='', **kwargs):
    x, y = list(data[0]), list(data[1])
    labels = x
    plt.figure(figsize=(10, 10 / 1.618))
    plt.plot(x, y, 'ro', **kwargs)
    plt.xticks(x, labels, rotation=45)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()
    
def display_cleaned_up_text(corpus, gui, display_type, document_id, **kwargs): # ngrams, named_entities, normalize, include_pos):
    
    gui.output_text.clear_output()
    gui.output_statistics.clear_output()
    
    #Additional candidates;
    #is_alpha	bool	Does the token consist of alphabetic characters? Equivalent to token.text.isalpha().
    #is_ascii	bool	Does the token consist of ASCII characters? Equivalent to all(ord(c) < 128 for c in token.text).
    #like_url	bool	Does the token resemble a URL?
    #like_email	bool	Does the token resemble an email address?
                                                                    
    terms = [ x for x in corpus[document_id].to_terms_list(as_strings=True, **kwargs) ]
    
    if display_type == 'source_text':
        # Print the extracted terms (the excerpt printout below is disabled):
        with gui.output_text:
            #print('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(document[:2500], document[-250:]))
            #print(doc)
            print(' '.join(list(terms)))
        return
        
    if display_type in ['sanitized_text', 'statistics']:

        if display_type == 'sanitized_text':
            with gui.output_text:
                #display('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(
                #    ' '.join(tokens[:word_count]),
                #    ' '.join(tokens[-word_count:])
                #))
                print(' '.join(list(terms)))
                return

        if display_type == 'statistics':

            wf = nltk.FreqDist(terms)

            with gui.output_text:

                print('Number of words (terms): {}'.format(wf.N()))
                print('Number of unique terms (vocabulary): {}'.format(wf.B()))
                print(' ')

                df = pd.DataFrame(wf.most_common(25), columns=['token','count'])
                display(df)

            with gui.output_statistics:

                data = list(zip(*wf.most_common(25)))
                plot_xy_data(data, title='Word distribution', xlabel='Word', ylabel='Word count')

                wf = nltk.FreqDist([len(x) for x in terms])
                data = list(zip(*wf.most_common(25)))
                plot_xy_data(data, title='Word length distribution', xlabel='Word length', ylabel='Word count')

xgui, xuix =display_cleanup_text_gui(corpus, display_cleaned_up_text)
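
The statistics view boils down to simple frequency counting. `nltk.FreqDist` is a subclass of `collections.Counter`, so the same numbers can be sketched with the standard library alone. The term list below is invented for illustration.

```python
from collections import Counter

# Made-up list of extracted terms standing in for one document.
terms = ['topic', 'model', 'topic', 'corpus', 'model', 'topic']

freq = Counter(terms)

n_terms = sum(freq.values())   # total number of terms, like FreqDist.N()
n_unique = len(freq)           # vocabulary size, like FreqDist.B()
top = freq.most_common(2)      # most frequent terms with their counts

# Distribution of term lengths, as plotted in the second chart.
length_freq = Counter(len(t) for t in terms)

print(n_terms, n_unique)
print(top)
print(sorted(length_freq.items()))
```

Comparing these counts before and after toggling a filter (stop words, POS tags, minimum frequency) is a quick way to see how much of the vocabulary each cleanup choice removes.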




## Compute LDA Topic Model


In [91]:
import types

# NOTE: see https://scikit-learn.org/stable/auto_examples/applications/plot_topics_extraction_with_nmf_lda.html
DEFAULT_VECTORIZE_PARAMS = dict(tf_type='linear', apply_idf=False, idf_type='smooth', norm='l2', min_df=1, max_df=0.95)

def compute_topic_model(corpus, tick=utility.noop, method='sklearn_lda', vec_args=None, term_args=None, tm_args=None, **args):
    
    tick()
    vec_args = utility.extend({}, DEFAULT_VECTORIZE_PARAMS, vec_args)
    
    terms_iter = lambda: (filter_terms(doc, term_args) for doc in corpus)

    vectorizer = textacy.Vectorizer(**vec_args)
    doc_term_matrix = vectorizer.fit_transform(terms_iter())
    tick()

    if method == 'sklearn_lda':
        model = textacy.TopicModel('lda', **tm_args)
        model.fit(doc_term_matrix)
        tick()
        doc_topic_matrix = model.transform(doc_term_matrix)
        tick()
        id2word = vectorizer.id_to_term
    else:
        dictionary = corpora.Dictionary(terms_iter())
        lda_corpus = [dictionary.doc2bow(text) for text in terms_iter()]
        id2word = dictionary
        #id2word = vectorizer.id_to_term
        #lda_corpus = matutils.Sparse2Corpus(doc_term_matrix, documents_columns=False)
        model = models.LdaModel(
            lda_corpus, # [doc for doc in terms_iter()],
            num_topics  =  tm_args.get('n_topics', 0),
            id2word     =  id2word,
            iterations  =  tm_args.get('max_iter', 0),
            passes      =  20,
            alpha       = 'asymmetric'
        )
    
    tm_model = types.SimpleNamespace(
        model=model,
        doc_term_matrix=doc_term_matrix,
        doc_topic_matrix=doc_topic_matrix if method == 'sklearn_lda' else None,
        id_to_term=id2word,
        id2term=id2word,
        vectorizer=vectorizer
    )
    
    tick(0)
    
    return tm_model

def get_doc_topic_weights(doc_topic_matrix, threshold=0.05):
    topic_ids = range(0,doc_topic_matrix.shape[1])
    for document_id in range(0,doc_topic_matrix.shape[0]):
        topic_weights = doc_topic_matrix[document_id, :]
        for topic_id in topic_ids:
            if topic_weights[topic_id] >= threshold:
                yield (document_id, topic_id, topic_weights[topic_id])

def get_df_doc_topic_weights(doc_topic_matrix, threshold=0.05):
    it = get_doc_topic_weights(doc_topic_matrix, threshold)
    df = pd.DataFrame(list(it), columns=['document_id', 'topic_id', 'weight']).set_index('document_id')
    return df

def display_topic_model_gui(corpus, compute_callback):
    
    pos_options = [ x for x in DF_TAGSET.POS.unique() if x not in ['PUNCT', '', 'DET', 'X', 'SPACE', 'PART', 'CONJ', 'SYM', 'INTJ', 'PRON']]  # groupby(['POS'])['DESCRIPTION'].apply(list).apply(lambda x: ', '.join(x)).to_dict()
    engine_options = {'gensim': 'gensim', 'sklearn_lda': 'sklearn_lda'}
    normalize_options = { 'None': False, 'Use lemma': 'lemma', 'Lowercase': 'lower'}
    ngrams_options = { '1': [1], '1, 2': [1, 2], '1,2,3': [1, 2, 3] }
    gui = types.SimpleNamespace(
        progress=widgets.IntProgress(value=0, min=0, max=5, step=1, description='', layout=widgets.Layout(width='90%')),
        n_topics=widgets.IntSlider(description='#topics', min=5, max=50, value=20, step=1),
        min_freq=widgets.IntSlider(description='Min word freq', min=0, max=10, value=2, step=1),
        max_iter=widgets.IntSlider(description='Max iterations', min=100, max=1000, value=100, step=10),
        ngrams=widgets.Dropdown(description='n-grams', options=ngrams_options, value=[1], layout=widgets.Layout(width='200px')),
        normalize=widgets.Dropdown(description='Normalize', options=normalize_options, value='lemma', layout=widgets.Layout(width='200px')),
        filter_stops=widgets.ToggleButton(value=True, description='Remove stopword',  tooltip='Filter out stopwords', icon='check'),
        filter_nums=widgets.ToggleButton(value=True, description='Remove nums',  tooltip='Filter out numbers', icon='check'),
        named_entities=widgets.ToggleButton(value=False, description='Merge entities',  tooltip='Merge entities', icon='check'),
        drop_determiners=widgets.ToggleButton(value=True, description='Drop determiners',  tooltip='Drop determiners', icon='check'),
        apply_idf=widgets.ToggleButton(value=False, description='Apply IDF',  tooltip='Apply TF-IDF', icon='check'),
        include_pos=widgets.SelectMultiple(description='POS', options=pos_options, value=['NOUN', 'PROPN'], rows=7, layout=widgets.Layout(width='200px')),
        method=widgets.Dropdown(description='Engine', options=engine_options, value='gensim', layout=widgets.Layout(width='200px')),
        compute=widgets.Button(description='Compute'),
        boxes=None,
        output = widgets.Output(layout={'height': '500px', 'border': '1px solid black'}),
        model=None
    )
    gui.boxes = widgets.VBox([
        gui.progress,
        widgets.HBox([
            widgets.VBox([
                gui.n_topics,
                gui.min_freq,
                gui.max_iter
            ]),
            widgets.VBox([
                gui.filter_stops,
                gui.filter_nums,
                gui.named_entities,
                gui.drop_determiners,
                gui.apply_idf
            ]),
            widgets.VBox([
                gui.normalize,
                gui.ngrams,
                gui.method
            ]),
            gui.include_pos,
            widgets.VBox([
                gui.compute
            ])
        ]),
        widgets.VBox([gui.output]), # ,layout=widgets.Layout(top='20px', height='500px',width='100%'))
    ])
    fx = lambda *args: compute_callback(gui, *args)
    gui.compute.on_click(fx)
    return gui
    

def compute_callback(gui, *args):
    
    def tick(x=None):
        gui.progress.value = gui.progress.value + 1 if x is None else x
        
    gui.output.clear_output()
    with gui.output:
        vec_args = dict(apply_idf=gui.apply_idf.value)
        term_args = dict(
            args=dict(
                ngrams=gui.ngrams.value,
                named_entities=gui.named_entities.value,
                normalize=gui.normalize.value,
                as_strings=True
            ),
            kwargs=dict(
                filter_nums=gui.filter_nums.value,
                drop_determiners=gui.drop_determiners.value,
                min_freq=gui.min_freq.value,
                include_pos=gui.include_pos.value,
                filter_stops=gui.filter_stops.value,
                filter_punct=True
            )
        )
        tm_args = dict(
            n_topics=gui.n_topics.value,
            max_iter=gui.max_iter.value,
            learning_method='online', 
            n_jobs=1
        )
        method = gui.method.value
        gui.model = compute_topic_model(
            corpus=corpus,
            tick=tick,
            method=method,
            vec_args=vec_args,
            term_args=term_args,
            tm_args=tm_args
        )
        
tm_gui = display_topic_model_gui(corpus, compute_callback)
display(tm_gui.boxes)
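
The `get_doc_topic_weights` helper turns a document-topic matrix into long-format rows and drops weak topic assignments. The sketch below shows the same idea on a tiny hand-written matrix using nested lists instead of a NumPy array; the function name `doc_topic_rows` and the numbers are made up for illustration.

```python
def doc_topic_rows(doc_topic_matrix, threshold=0.05):
    """Yield (document_id, topic_id, weight) for weights at or above threshold."""
    # Rows are documents, columns are topics.
    for document_id, topic_weights in enumerate(doc_topic_matrix):
        for topic_id, weight in enumerate(topic_weights):
            if weight >= threshold:
                yield (document_id, topic_id, weight)


# Two documents, three topics; each row sums to 1.
matrix = [
    [0.90, 0.08, 0.02],
    [0.03, 0.47, 0.50],
]

rows = list(doc_topic_rows(matrix))
for row in rows:
    print(row)
```

The threshold trades completeness for readability: a higher value keeps only the dominant topics per document, which makes later tables and network plots less cluttered but hides minor topic contributions.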



In [87]:

def compile_dictionary(lda):
    logger.info('Compiling dictionary...')
    token_ids, tokens = list(zip(*lda.id2word.items()))
    dfs = lda.id2word.dfs.values() if lda.id2word.dfs is not None else [0] * len(tokens)
    dictionary = pd.DataFrame({
        'token_id': token_ids,
        'token': tokens,
        'dfs': list(dfs)
    }).set_index('token_id')[['token', 'dfs']]
    return dictionary
    
def compile_topic_token_weights(lda, dictionary, num_words=200):
    logger.info('Compiling topic-tokens weights...')

    df_topic_weights = pd.DataFrame(
        [ (topic_id, token, weight)
            for topic_id, tokens in (lda.show_topics(lda.num_topics, num_words=num_words, formatted=False))
                for token, weight in tokens if weight > 0.0 ],
        columns=['topic_id', 'token', 'weight']
    )

    df = pd.merge(
        df_topic_weights.set_index('token'),
        dictionary.reset_index().set_index('token'),
        how='inner',
        left_index=True,
        right_index=True
    )
    return df.reset_index()[['topic_id', 'token_id', 'token', 'weight']]

lda = tm_gui.model.model
id2word = tm_gui.model.id2term

df_dictionary = compile_dictionary(lda)
df_topic_token_weights = compile_topic_token_weights(lda, df_dictionary, num_words=200)


INFO : Compiling dictionary...
INFO : Compiling topic-tokens weights...


| | topic_id | token_id | token | weight |
|---|---|---|---|---|
| 0 | 4 | 213 | ability | 0.004586 |
| 1 | 19 | 213 | ability | 0.000048 |
| 2 | 14 | 536 | academic | 0.000059 |
| 3 | 17 | 536 | academic | 0.008979 |
| 4 | 18 | 536 | academic | 0.000053 |
| 5 | 3 | 492 | acceptance | 0.007093 |
| 6 | 6 | 492 | acceptance | 0.001792 |
| 7 | 7 | 492 | acceptance | 0.000172 |
| 8 | 9 | 492 | acceptance | 0.001792 |
| 9 | 11 | 492 | acceptance | 0.001792 |


## Display Named Entities

In [None]:
def display_document_entities_gui(corpus):
    
    def display_document_entities(document_id, corpus):
        displacy.render(corpus[document_id].spacy_doc, style='ent', jupyter=True)
    
    document_widget = widgets.Dropdown(description='Paper', options={v: k for k, v in df_documents['title'].to_dict().items()}, value=0, layout=widgets.Layout(width='80%'))

    itw = widgets.interactive(display_document_entities,document_id=document_widget, corpus=widgets.fixed(corpus))

    display(widgets.VBox([document_widget, widgets.VBox([itw.children[-1]],layout=widgets.Layout(margin_top='20px', height='500px',width='100%'))]))

    itw.update()

display_document_entities_gui(corpus)


## Document Key Terms 
- [TextRank]	Mihalcea, R., & Tarau, P. (2004, July). TextRank: Bringing order into texts. Association for Computational Linguistics.
- [SingleRank]	Hasan, K. S., & Ng, V. (2010, August). Conundrums in unsupervised keyphrase extraction: making sense of the state-of-the-art. In Proceedings of the 23rd International Conference on Computational Linguistics: Posters (pp. 365-373). Association for Computational Linguistics.
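
To give a feel for what these graph-based methods do, here is a deliberately simplified TextRank-style sketch in plain Python. It is not the `textacy.keyterms` implementation: it ranks single words only, uses an unweighted co-occurrence graph, and uses the normalized PageRank update (scores sum to 1). The function name, parameters, and token list are illustrative assumptions.

```python
from collections import defaultdict


def textrank_keywords(tokens, window=2, damping=0.85, iterations=50, top_n=5):
    """Rank words by PageRank on an undirected co-occurrence graph."""
    # Link each word to the words that follow it within the window.
    neighbours = defaultdict(set)
    for i, word in enumerate(tokens):
        for other in tokens[i + 1:i + window]:
            if other != word:
                neighbours[word].add(other)
                neighbours[other].add(word)

    vocab = sorted(neighbours)
    if not vocab:
        return []

    # Start from a uniform score and repeat the PageRank update.
    n = len(vocab)
    scores = {w: 1.0 / n for w in vocab}
    for _ in range(iterations):
        scores = {
            w: (1 - damping) / n
               + damping * sum(scores[v] / len(neighbours[v]) for v in neighbours[w])
            for w in vocab
        }

    # Highest score first; ties are broken alphabetically.
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))[:top_n]


# Made-up, already filtered token sequence.
tokens = ['lda', 'topic', 'lda', 'corpus', 'lda', 'document']
ranked = textrank_keywords(tokens, top_n=10)
for word, score in ranked:
    print(word, round(score, 3))
```

Words that co-occur with many different neighbours collect the highest scores, which is why `lda` ranks first here. SingleRank differs mainly in weighting edges by co-occurrence counts and in scoring multi-word candidates.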


In [92]:
import textacy.keyterms

def display_document_key_terms_gui(corpus):
    
    methods = { 'SingleRank': textacy.keyterms.singlerank, 'TextRank': textacy.keyterms.textrank }
    document_options = {v: k for k, v in df_documents['title'].to_dict().items()}
    
    gui = types.SimpleNamespace(
        output=widgets.Output(layout={'border': '1px solid black'}),
        n_keyterms=widgets.IntSlider(description='#words', min=10, max=500, value=100, step=1, layout=widgets.Layout(width='240px')),
        document_id=widgets.Dropdown(description='Paper', options=document_options, value=0, layout=widgets.Layout(width='40%')),
        method=widgets.Dropdown(description='Algorithm', options=[ 'TextRank', 'SingleRank' ], value='TextRank', layout=widgets.Layout(width='180px')),
        normalize=widgets.Dropdown(description='Normalize', options=[ 'lemma', 'lower' ], value='lemma', layout=widgets.Layout(width='160px'))
    )
    
    def display_document_key_terms(corpus, method='TextRank', document_id=0, normalize='lemma', n_keyterms=10):
        keyterms = methods[method](corpus[document_id], normalize=normalize, n_keyterms=n_keyterms)
        terms = ' '.join([ x for x, y in keyterms ])
        gui.output.clear_output()
        with gui.output:
            display(terms)

    itw = widgets.interactive(
        display_document_key_terms,
        corpus=widgets.fixed(corpus),
        method=gui.method,
        document_id=gui.document_id,
        normalize=gui.normalize,
        n_keyterms=gui.n_keyterms,
    )

    display(widgets.VBox([
        widgets.HBox([gui.document_id, gui.method, gui.normalize, gui.n_keyterms]),
        gui.output
    ]))

    itw.update()

display_document_key_terms_gui(corpus)



### Basic Document Statistics

In [None]:
import common.utility as utility
import ipywidgets as widgets

df = pd.DataFrame([ utility.extend(dict(title=x.metadata['title']), textacy.TextStats(x).basic_counts) for x in corpus ])
df[['title', 'n_chars', 'n_words', 'n_unique_words', 'n_sents']]

In [None]:
punct_table = str.maketrans('', '', string.punctuation)

%matplotlib inline

hyphen_regexp = re.compile(r'\b(\w+)-\s*\r?\n\s*(\w+)\b', re.UNICODE)

def tokenize_and_sanitize(document, min_length=3, only_isalpha=True, remove_puncts=True, to_lower=True, remove_stop=True):

    #if de_hyphen:
    #    document = re.sub(hyphen_regexp, r"\1\2\n", document)

    # `document` is a plain string; word_tokenize also discards whitespace
    tokens = nltk.word_tokenize(document)

    # Drop words with fewer than min_length alphabetic characters
    if min_length > 0:
        tokens = [ x for x in tokens if len([w for w in x if w.isalpha()]) >= min_length ]

    # Keep only words that consist entirely of letters
    if only_isalpha:
        tokens = [ x for x in tokens if x.isalpha() ]

    if remove_puncts:
        tokens = [ x.translate(punct_table) for x in tokens ]

    # Convert to lowercase
    if to_lower:
        tokens = [ x.lower() for x in tokens ]

    # Remove the most common stopwords
    if remove_stop:
        stopwords = nltk.corpus.stopwords.words('swedish')
        tokens = [ x for x in tokens if x not in stopwords ]

    return [ x for x in tokens if len(x) > 0 ]

def plot_xy_data(data, title='', xlabel='', ylabel='', **kwargs):
    
    x = list(data[0])
    y = list(data[1])
    labels = x

    plt.figure(figsize=(8, 8 / 1.618))
    plt.plot(x, y, 'ro', **kwargs)
    plt.xticks(x, labels, rotation='45')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.show()
    
container=dict(
    display_type=widgets.Dropdown(
        description='Show',
        value='statistics',
        options={
            'Source text': 'source_text',
            'Sanitized text': 'sanitized_text',
            'Statistics': 'statistics'           
    }),
    min_length=widgets.IntSlider(value=0, min=0, max=5, step=1, description='Min alpha', tooltip='Min number of alphabetic characters'),
    de_hyphen=widgets.ToggleButton(value=False, description='Dehyphen', disabled=False, tooltip='Fix hyphens', icon=''),
    to_lower=widgets.ToggleButton(value=False, description='Lowercase', disabled=False, tooltip='Transform text to lowercase', icon=''),
    remove_stop=widgets.ToggleButton(value=False, description='No stopwords', disabled=False, tooltip='Remove stopwords', icon=''),
    only_isalpha=widgets.ToggleButton(value=False, description='Only alpha', disabled=False, tooltip='Keep only alphabetic words', icon=''),
    remove_puncts=widgets.ToggleButton(value=False, description='Remove puncts.', disabled=False, tooltip='Remove punctioations characters', icon=''),
    progress=widgets.IntProgress(value=0, min=0, max=5, step=1, description='' )
)

output1 = widgets.Output() #layout={'border': '1px solid black'})
output2 = widgets.Output() #layout={'border': '1px solid black'})
default_output = None

tokens = []

def display_document(display_type, to_lower, remove_stop, only_isalpha, remove_puncts, min_length, de_hyphen, word_count=500):

    global tokens
    
    p =  container['progress']
    p.value = 0
    try:
        output1.clear_output()
        output2.clear_output()
        default_output.clear_output()
        document = read_text_file('./data/urn-nbn-se-kb-digark-2106487.txt')
        p.value = p.value + 1

        if display_type == 'source_text':
            # Utskrift av de första och sista 250 tecknen:
            with output1:
                print('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(document[:2500], document[-250:]))
            p.value = p.value + 1
            return
        
        p.value = p.value + 1

        tokens = tokenize_and_sanitize(
            document,
            de_hyphen=de_hyphen,
            min_length=min_length,
            only_isalpha=only_isalpha,
            remove_puncts=remove_puncts,
            to_lower=to_lower,
            remove_stop=remove_stop
        )

        if display_type in ['sanitized_text', 'statistics']:
            
            p.value = p.value + 1
            
            if display_type == 'sanitized_text':
                with output1:
                    display('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(
                        ' '.join(tokens[:word_count]),
                        ' '.join(tokens[-word_count:])
                    ))
                p.value = p.value + 1
                return
            
            if display_type == 'statistics':

                wf = nltk.FreqDist(tokens)
                p.value = p.value + 1
                
                with output1:
                    
                    df = pd.DataFrame(wf.most_common(25), columns=['token','count'])
                    display(df)
                
                with output2:
                    
                    print('Antal ord (termer): {}'.format(wf.N()))
                    print('Antal unika termer (vokabulär): {}'.format(wf.B()))
                    print(' ')
                    
                    data = list(zip(*wf.most_common(25)))
                    plot_xy_data(data, title='Word distribution', xlabel='Word', ylabel='Word count')
                    
                    wf = nltk.FreqDist([len(x) for x in tokens])
                    data = list(zip(*wf.most_common(25)))
                    plot_xy_data(data, title='Word length distribution', xlabel='Word length', ylabel='Word count')
    
    except Exception as ex:
        raise
        
    finally:
        p.value = 0

i_widgets = widgets.interactive(display_document, **container)
default_output = i_widgets.children[-1]
display(widgets.VBox([
    widgets.HBox([
        container['display_type'],
        container['to_lower'],
        container['remove_stop'],
        container['de_hyphen'],
        container['only_isalpha'],
        container['remove_puncts']
    ]),
    widgets.HBox([container['min_length'], container['progress']]),
    widgets.HBox([output1, output2]),
    default_output
]))

i_widgets.update()

In [None]:
import textacy
import textacy.datasets

cw = textacy.datasets.SupremeCourt()
cw.download()
records = cw.records()

txt_strm, meta_strm = textacy.fileio.split_record_fields(records, 'text')
corpus = textacy.Corpus(u'en', texts=txt_strm, metadatas=meta_strm)
vectorizer = textacy.Vectorizer(
    weighting='tfidf',
    normalize=True,
    smooth_idf=True,
    min_df=2,
    max_df=0.95
)
doc_term_matrix = vectorizer.fit_transform((
    doc.to_terms_list(ngrams=1, named_entities=True, as_strings=True)
    for doc in corpus
))
print(repr(doc_term_matrix))
model = textacy.TopicModel('nmf', n_topics=10)
model.fit(doc_term_matrix)
doc_topic_matrix = model.transform(doc_term_matrix)
print(doc_topic_matrix.shape)

for t_idx, top_terms in model.top_topic_terms(vectorizer.id_to_term, top_n=10):
    print('topic', t_idx, ':', ' '.join(top_terms))

In [None]:
# folded code
# -*- coding: utf-8 -*-
import nltk
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
import re
import string

punct_table = str.maketrans('', '', string.punctuation)

%matplotlib inline

hyphen_regexp = re.compile(r'\b(\w+)-\s*\r?\n\s*(\w+)\b', re.UNICODE)

def read_text_file(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        document = f.read()
    return document

def tokenize_and_sanitize(document, de_hyphen=True, min_length=3, only_isalpha=True, remove_puncts=True, to_lower=True, remove_stop=True):

    # Rejoin words that were hyphenated across a line break
    if de_hyphen:
        document = re.sub(hyphen_regexp, r"\1\2\n", document)

    tokens = nltk.word_tokenize(document)

    # Remove tokens with fewer than min_length alphabetic characters
    if min_length > 0:
        tokens = [ x for x in tokens if sum(w.isalpha() for w in x) >= min_length ]

    # Keep only tokens that consist entirely of alphabetic characters
    if only_isalpha:
        tokens = [ x for x in tokens if x.isalpha() ]

    # Strip punctuation characters from each token
    if remove_puncts:
        tokens = [ x.translate(punct_table) for x in tokens ]

    # Convert to lowercase
    if to_lower:
        tokens = [ x.lower() for x in tokens ]

    # Remove the most common Swedish stopwords (case-insensitive match)
    if remove_stop:
        stopwords = set(nltk.corpus.stopwords.words('swedish'))
        tokens = [ x for x in tokens if x.lower() not in stopwords ]

    # Drop tokens that became empty after sanitizing
    return [ x for x in tokens if len(x) > 0 ]

def plot_xy_data(data, title='', xlabel='', ylabel='', **kwargs):
    
    x = list(data[0])
    y = list(data[1])
    labels = x

    plt.figure(figsize=(8, 8 / 1.618))
    plt.plot(x, y, 'ro', **kwargs)
    plt.xticks(x, labels, rotation=45)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.show()
    
container=dict(
    display_type=widgets.Dropdown(
        description='Show',
        value='statistics',
        options={
            'Source text': 'source_text',
            'Sanitized text': 'sanitized_text',
            'Statistics': 'statistics'           
    }),
    min_length=widgets.IntSlider(value=0, min=0, max=5, step=1, description='Min alpha', tooltip='Min number of alphabetic characters'),
    de_hyphen=widgets.ToggleButton(value=False, description='Dehyphen', disabled=False, tooltip='Fix hyphens', icon=''),
    to_lower=widgets.ToggleButton(value=False, description='Lowercase', disabled=False, tooltip='Transform text to lowercase', icon=''),
    remove_stop=widgets.ToggleButton(value=False, description='No stopwords', disabled=False, tooltip='Remove stopwords', icon=''),
    only_isalpha=widgets.ToggleButton(value=False, description='Only alpha', disabled=False, tooltip='Keep only alphabetic words', icon=''),
    remove_puncts=widgets.ToggleButton(value=False, description='Remove puncts.', disabled=False, tooltip='Remove punctuation characters', icon=''),
    progress=widgets.IntProgress(value=0, min=0, max=5, step=1, description='' )
)

output1 = widgets.Output() #layout={'border': '1px solid black'})
output2 = widgets.Output() #layout={'border': '1px solid black'})
default_output = None

tokens = []

def display_document(display_type, to_lower, remove_stop, only_isalpha, remove_puncts, min_length, de_hyphen, word_count=500):

    global tokens
    
    p =  container['progress']
    p.value = 0
    try:
        output1.clear_output()
        output2.clear_output()
        default_output.clear_output()
        document = read_text_file('./data/urn-nbn-se-kb-digark-2106487.txt')
        p.value = p.value + 1

        if display_type == 'source_text':
            # Print the first 2500 and the last 250 characters:
            with output1:
                print('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(document[:2500], document[-250:]))
            p.value = p.value + 1
            return
        
        p.value = p.value + 1

        tokens = tokenize_and_sanitize(
            document,
            de_hyphen=de_hyphen,
            min_length=min_length,
            only_isalpha=only_isalpha,
            remove_puncts=remove_puncts,
            to_lower=to_lower,
            remove_stop=remove_stop
        )

        if display_type in ['sanitized_text', 'statistics']:
            
            p.value = p.value + 1
            
            if display_type == 'sanitized_text':
                with output1:
                    display('{}\n.................\n(NOT SHOWN TEXT)\n.................\n{}'.format(
                        ' '.join(tokens[:word_count]),
                        ' '.join(tokens[-word_count:])
                    ))
                p.value = p.value + 1
                return
            
            if display_type == 'statistics':

                wf = nltk.FreqDist(tokens)
                p.value = p.value + 1
                
                with output1:
                    
                    df = pd.DataFrame(wf.most_common(25), columns=['token','count'])
                    display(df)
                
                with output2:
                    
                    print('Number of words (terms): {}'.format(wf.N()))
                    print('Number of unique terms (vocabulary): {}'.format(wf.B()))
                    print(' ')
                    
                    data = list(zip(*wf.most_common(25)))
                    plot_xy_data(data, title='Word distribution', xlabel='Word', ylabel='Word count')
                    
                    wf = nltk.FreqDist([len(x) for x in tokens])
                    data = list(zip(*wf.most_common(25)))
                    plot_xy_data(data, title='Word length distribution', xlabel='Word length', ylabel='Word count')
    
    except Exception as ex:
        raise
        
    finally:
        p.value = 0

i_widgets = widgets.interactive(display_document, **container)
default_output = i_widgets.children[-1]
display(widgets.VBox([
    widgets.HBox([
        container['display_type'],
        container['to_lower'],
        container['remove_stop'],
        container['de_hyphen'],
        container['only_isalpha'],
        container['remove_puncts']
    ]),
    widgets.HBox([container['min_length'], container['progress']]),
    widgets.HBox([output1, output2]),
    default_output
]))

i_widgets.update()


### TRY IT: Språkbanken NER Tagging of SOU

In [48]:
import nltk
import matplotlib.pyplot as plt
import ipywidgets as widgets
%matplotlib inline

import pandas as pd

entities = pd.read_csv('./data/SOU_1990_total_ner_extracted.csv', sep='\t',
                       names=['filename', 'year', 'location', 'categories', 'entity'])

entities['document_id'] = entities.filename.apply(lambda x: int(x.split('_')[1]))
entities['categories'] = entities.categories.str.replace('/', ' ')
entities['category'] = entities.categories.str.split(' ').str.get(0)
entities['sub_category'] = entities.categories.str.split(' ').str.get(1)

entities.drop(['location', 'categories'], inplace=True, axis=1)

document_names = pd.read_csv('./data/SOU_1990_index.csv',
                             sep='\t',
                             names=['year', 'sequence_id', 'report_name']).set_index('sequence_id')

def plot_freqdist(wf, n=25, **kwargs):
    data = list(zip(*wf.most_common(n)))
    x = list(data[0])
    y = list(data[1])
    labels = x

    plt.figure(figsize=(13, 13/1.618))
    plt.plot(x, y, '--ro', **kwargs)
    plt.xticks(x, labels, rotation=45)
    plt.show()

doc_names = { v: k for k, v in document_names.report_name.to_dict().items()}
doc_names['** ALL DOCUMENTS **'] = 0
@widgets.interact(category=entities.category.unique())
def display_most_frequent_pos_tags(document_id=doc_names, category='LOC', top=10):
    global entities
    locations = entities
    if document_id > 0:
        locations = locations.loc[locations.document_id==document_id]
    locations = locations.loc[locations.category==category]['entity']
    location_freqs = nltk.FreqDist(locations)
    #location_freqs.tabulate()
    plot_freqdist(location_freqs, n=top)

# display_most_frequent_pos_tags()

FileNotFoundError: File b'./data/SOU_1990_total_ner_extracted.csv' does not exist

### <span style='color:blue'>MANDATORY STEP</span> Setup and Initialize the Notebook
Use the **play** button, or press **Shift-Enter** to execute a code cell (select it first). The code imports Python libraries and frameworks, and initializes the notebook.

In [None]:
# Folded Code
%load_ext autoreload
%autoreload 2

import common.utility
from common.model_utility import ModelUtility
from common.plot_utility import layout_algorithms, PlotNetworkUtility
import common.widgets_utility as wf
from common.network_utility import NetworkUtility, DISTANCE_METRICS, NetworkMetricHelper
#import common.vectorspace_utility

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=FutureWarning) 

import os
import glob
import math
import types
import ipywidgets as widgets
import logging
import bokeh.models as bm
import bokeh.palettes
import pandas as pd
import numpy as np

from pivottablejs import pivot_ui
from IPython.display import display, HTML, clear_output, IFrame
from itertools import product
from bokeh.io import output_file, push_notebook
from bokeh.core.properties import value, expr
from bokeh.transform import transform, jitter
from bokeh.layouts import row, column, widgetbox
from bokeh.plotting import figure, show, output_notebook, output_file
from bokeh.models.widgets import DataTable, DateFormatter, TableColumn
from bokeh.models import ColumnDataSource, CustomJS

logger = logging.getLogger('explore-topic-models')
TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
AGGREGATES = { 'mean': np.mean, 'sum': np.sum, 'max': np.max, 'std': np.std }

output_notebook()

pd.set_option('display.precision', 10)

### <span style='color:blue'>MANDATORY STEP</span> Select LDA Topic Model
- Select one of the previously computed and prepared topic models that you want to use in subsequent steps.
- Models are computed in batch in accordance with the
<a href="./images/workflow-prepare.svg">process flow</a> used in the *Digitala modeller* project.
- Note that subsequent code cells are NOT updated (executed) automatically when a new model is selected.
- Use the **play** button, or press **Shift-Enter** to execute the selected cell.

In [None]:
# Hidden code: Select current model state
class ModelState:
    
    def __init__(self, data_folder):
        
        self.data_folder = data_folder
        self.basenames = ModelUtility.get_model_names(data_folder)
        self.basename = self.basenames[0]
        self.on_set_model_callback = None
        
    def set_model(self, basename=None):

        basename = basename or self.basename
        
        self.basename = basename
        self.topic_keys = ModelUtility.get_topic_keys(self.data_folder, basename)
        self.max_alpha = self.topic_keys.alpha.max()
        self.topic_overview = ModelUtility\
            .get_result_model_sheet(self.data_folder, basename, 'topic_tokens')
        self.document_topic_weights = ModelUtility\
            .get_result_model_sheet(self.data_folder, basename, 'doc_topic_weights')\
            .drop('Unnamed: 0', axis=1, errors='ignore')
        self.topic_token_weights = ModelUtility\
            .get_result_model_sheet(self.data_folder, basename, 'topic_token_weights')\
            .drop('Unnamed: 0', axis=1, errors='ignore')\
            .dropna(subset=['token'])
        self._years = list(range(
            self.document_topic_weights.year.min(), self.document_topic_weights.year.max() + 1))
        self.min_year = min(self._years)
        self.max_year = max(self._years)
        self.years = [None] + self._years
        self.n_topics = self.topic_overview.topic_id.max() + 1
        # https://stackoverflow.com/questions/44561609/how-does-mallet-set-its-default-hyperparameters-for-lda-i-e-alpha-and-beta
        self.initial_alpha = 0.0  # 5.0 / self.n_topics if 'mallet' in state.basename else 1.0 / self.n_topics
        self.initial_beta = 0.0  # 0.01 if 'mallet' in basename else 1.0 / self.n_topics
        self._lda = None
        self._topic_titles = None
        self.corpus_documents = ModelUtility.get_corpus_documents(self.data_folder, self.basename).set_index('document_id')
        print("Current model: " + self.basename.upper())
        
        if self.on_set_model_callback is not None:
            self.on_set_model_callback(self)
            
        # _fix_topictokens()
        return self
    
    #def get_document_topic_weights(self, year=None, topic_id=None):
    #    df = self.document_topic_weights
    #    if year is None and topic_id is None:
    #        return df
    #    if topic_id is None:
    #        return df[(df.year == year)]
    #    if year is None:
    #        return df[(df.topic_id == topic_id)]
    #    return df[(df.year == year)&(df.topic_id == topic_id)]
    
    def get_unique_topic_ids(self):
        return self.document_topic_weights['topic_id'].unique()
    
    #def get_topic_weight_by_year_or_document(self, key='mean', pivot_column=None):
    #    
    #    if pivot_column is None:
    #        pivot_column = 'year' if year is None else 'document_id'    
    #        
    #    df = self.document_topic_weights(year) \
    #        .groupby([pivot_column,'topic_id']) \
    #        .agg(AGGREGATES[key])[['weight']].reset_index()
    #    return df, pivot_column
    
    #return self.get_document_topic_weight_by_pivot_column(pivot_column, key, filter={'column': 'year', 'values': [year]})
    
    def get_document_topic_weight_by_filter(self, filters=None):
        df = self.document_topic_weights.query('weight > 0')
        for filter in (filters or []):
            if 'query' in filter.keys():
                df = df.query(filter['query'])
            elif isinstance(filter['value'], str):
                df = df[(df[filter['column']]==filter['value'])]
            elif isinstance(filter['value'], list):
                df = df[(df[filter['column']].isin(filter['value']))]
        return df
    
    def get_document_topic_weight_by_pivot_column(self, pivot_column, key='mean', filters=None):
        df = self.get_document_topic_weight_by_filter(filters)
        df = df.groupby([pivot_column, 'topic_id'])\
               .agg(AGGREGATES[key])[['weight']].reset_index()
        return df[df.weight > 0]
    
    def get_topic_tokens_dict(self, topic_id, n_top=200):
        return self.get_topic_tokens(topic_id)\
            .sort_values(['weight'], ascending=False)\
            .head(n_top)[['token', 'weight']]\
            .set_index('token').to_dict()['weight']

    def compute_topic_terms_vector_space(self, n_words=100):
        '''
        Create an aligned topic-term vector space from the top n_words of each topic
        '''
        unaligned_vector_dicts = ( self.get_topic_tokens_dict(topic_id, n_words) for topic_id in range(0, self.n_topics) )
        X, feature_names = ModelUtility.compute_and_align_vector_space(unaligned_vector_dicts)
        return X, feature_names

    def get_lda(self):
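        '''
        Get gensim model. Only used for pyLDAvis display
        '''
        # Loading the gensim model is disabled in this notebook, so the code below the raise is unreachable
        raise Exception("Use of LDA model disabled in this Notebook")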
        if self._lda is None:
            filename = os.path.join(self.data_folder, self.basename, 'gensim_model_{}.gensim.gz'.format(self.basename))
            if os.path.isfile(filename):
                self._lda = LdaModel.load(filename)
                print('LDA model loaded...')
            else:
                print('LDA not found on disk...')
        return self._lda 
    
    def get_topic_titles(self, n_words=100, cache=True):
        if cache and self._topic_titles is not None:
            return self._topic_titles
        _topic_titles = ModelUtility.get_topic_titles(self.topic_token_weights, n_words=n_words)
        self._topic_titles = _topic_titles if cache else None
        return _topic_titles
    
    def get_topic_tokens(self, topic_id, max_n_words=500):
        tokens = self.topic_token_weights\
            .loc[lambda x: x.topic_id == topic_id]\
            .sort_values('weight',ascending=False)[:max_n_words]
        return tokens
    
    def get_topic_alphas(self):
        # Per-topic alpha values as stored in the topic keys loaded by set_model
        return self.topic_keys.alpha
    
    def get_topic_year_aggregate_weights(self, fn, threshold):
        df = self.document_topic_weights[(self.document_topic_weights.weight > 0.001)]
        df = df.groupby(['year', 'topic_id']).agg(fn)['weight'].reset_index()
        df = df[(df.weight>=threshold)]
        return df
    
    def get_topic_proportions(self):
        corpus_documents = self.get_corpus_documents()
        document_topic_weights = self.document_topic_weights
        topic_proportion = ModelUtility.compute_topic_proportions(document_topic_weights, corpus_documents)
        return topic_proportion
    
    def get_corpus_documents(self):
        #if self.corpus_documents is None:
        #    self.corpus_documents = ModelUtility.get_corpus_documents(self.data_folder, self.basename)
        return self.corpus_documents

    def on_set_model(self, callback):
        self.on_set_model_callback = callback
        return self
        
def on_set_model_handler(state):

    if 'report_name' in state.corpus_documents:
        return
    
    state.source_documents = pd.read_csv('data/SOU_1990_index.csv', sep='\t', header=None, names=['year', 'report_id', 'report_name'])
    state.corpus_documents['report_id'] = state.corpus_documents.document.str.split('_').apply(lambda x: x[1]).astype(np.int64)
    state.corpus_documents['report_name'] = pd.merge(state.corpus_documents, state.source_documents, how='inner', on=['year', 'report_id']).report_name
    state.corpus_documents['report_name'] = state.corpus_documents.apply(lambda x: '{}-{} {}'.format(x['year'], x['report_id'], x['report_name'])[:50], axis=1)
    state.document_topic_weights['report_name'] = pd.merge(state.document_topic_weights, state.corpus_documents, left_on='document_id', right_index=True).report_name

def select_model_main(state):
    
    basename_widget = widgets.Dropdown(
        options=state.basenames,
        value=state.basename,
        description='Topic model',
        disabled=False,
        layout=widgets.Layout(width='75%')
    )
    
    w = widgets.interactive(state.set_model, basename=basename_widget)
    display(widgets.VBox((basename_widget,) + (w.children[-1],)))
    w.update()

state = ModelState('./data').on_set_model(on_set_model_handler)

select_model_main(state)


### Topic-Word Distribution - Wordcloud and Table

In [None]:
# Display LDA topic's token wordcloud
opts = { 'max_font_size': 100, 'background_color': 'white', 'width': 900, 'height': 600 }

import wordcloud
import matplotlib.pyplot as plt

def display_topic_distribution_widgets(callback, state, text_id, output_options=None, word_count=(1, 100, 50)):
    
    output_options = output_options or []
    wc = wf.BaseWidgetUtility(
        n_topics=state.n_topics,
        text_id=text_id,
        text=wf.create_text_widget(text_id),
        topic_id=widgets.IntSlider(
            description='Topic ID', min=0, max=state.n_topics - 1, step=1, value=0, continuous_update=False),
        word_count=widgets.IntSlider(
            description='#Words', min=word_count[0], max=word_count[1], step=1, value=word_count[2], continuous_update=False),
        output_format=wf.create_select_widget('Format', output_options, default=output_options[0], layout=widgets.Layout(width="200px")),
        progress = widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="95%"))
    )

    wc.prev_topic_id = wc.create_prev_id_button('topic_id', state.n_topics)
    wc.next_topic_id = wc.create_next_id_button('topic_id', state.n_topics)

    iw = widgets.interactive(
        callback,
        topic_id=wc.topic_id,
        n_words=wc.word_count,
        output_format=wc.output_format,
        widget_container=widgets.fixed(wc)
    )

    display(widgets.VBox([
        wc.text,
        widgets.HBox([wc.prev_topic_id, wc.next_topic_id, wc.topic_id, wc.word_count, wc.output_format]),
        wc.progress,
        iw.children[-1]
    ]))

    iw.update()

def plot_wordcloud(df_data, token='token', weight='weight', figsize=(14, 14/1.618), **args):
    token_weights = dict(zip(df_data[token], df_data[weight]))
    image = wordcloud.WordCloud(**args,)
    image.fit_words(token_weights)
    plt.figure(figsize=figsize) #, dpi=100)
    plt.imshow(image, interpolation='bilinear')
    plt.axis("off")
    plt.show()
    
def display_wordcloud(topic_id=0, n_words=100, output_format='Wordcloud', widget_container=None):
    widget_container.progress.value = 1
    df_temp = state.topic_token_weights.loc[(state.topic_token_weights.topic_id == topic_id)]
    tokens = state.get_topic_titles(n_words=n_words, cache=True).iloc[topic_id]
    widget_container.progress.value = 2
    widget_container.text.value = 'ID {}: {}'.format(topic_id, tokens)
    if output_format == 'Wordcloud':
        plot_wordcloud(df_temp, 'token', 'weight', max_words=n_words, **opts)
    elif output_format == 'Table':
        widget_container.progress.value = 3
        df_temp = state.get_topic_tokens(topic_id, n_words)
        widget_container.progress.value = 4
        display(HTML(df_temp.to_html()))
    else:
        display(pivot_ui(state.get_topic_tokens(topic_id, n_words)))
    widget_container.progress.value = 0

display_topic_distribution_widgets(display_wordcloud, state, 'tx02', ['Wordcloud', 'Table', 'Pivot'])


### Topic-Word Distribution - Chart
The following chart shows the word distribution for each selected topic. You can zoom in on the left chart. The distribution seems to follow [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law) as (perhaps) expected.
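
One rough way to check the Zipf-like shape is to fit a straight line to log(weight) against log(rank): a slope near -1 is the classic Zipf pattern. The sketch below is illustrative only. The helper name `zipf_slope` and the synthetic weights are assumptions and not part of the notebook; in practice the weights could come from the `weight` column returned by `state.get_topic_tokens(topic_id)`.

```python
import numpy as np

def zipf_slope(weights):
    """Return the slope of a least-squares line fitted to log(weight) vs. log(rank)."""
    w = np.sort(np.asarray(weights, dtype=float))[::-1]  # largest weight gets rank 1
    w = w[w > 0]                                         # log is undefined for zero weights
    ranks = np.arange(1, len(w) + 1)
    slope, _intercept = np.polyfit(np.log(ranks), np.log(w), 1)
    return slope

# Synthetic (hypothetical) weights that follow an exact 1/rank law
example_weights = 1.0 / np.arange(1, 101)
print(zipf_slope(example_weights))
```

A slope far from -1, or a clearly curved log-log plot, would suggest that the topic's word weights do not follow Zipf's law closely.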

In [None]:
# Display topic's word distribution

def plot_topic_word_distribution(tokens, **args):

    source = ColumnDataSource(tokens)

    p = figure(toolbar_location="right", **args)

    cr = p.circle(x='xs', y='ys', source=source)

    label_style = dict(level='overlay', text_font_size='8pt', angle=np.pi/6.0)

    text_aligns = ['left', 'right']
    for i in [0, 1]:
        label_source = ColumnDataSource(tokens.iloc[i::2])
        labels = bm.LabelSet(x='xs', y='ys', text_align=text_aligns[i], text='token', text_baseline='middle',
                          y_offset=5*(1 if i == 0 else -1),
                          x_offset=5*(1 if i == 0 else -1),
                          source=label_source, **label_style)
        p.add_layout(labels)

    p.xaxis[0].axis_label = 'Token #'
    p.yaxis[0].axis_label = 'Probability%'
    p.ygrid.grid_line_color = None
    p.xgrid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_text_font_size = "6pt"
    p.axis.major_label_standoff = 0
    return p

def plot_topic_tokens_charts(tokens, flag=True):

    if flag:
        left = plot_topic_word_distribution(tokens, plot_width=1000, plot_height=500, title='', tools='box_zoom,wheel_zoom,pan,reset')
        show(left)
        return

    left = plot_topic_word_distribution(tokens, plot_width=450, plot_height=500, title='', tools='box_zoom,wheel_zoom,pan,reset')
    right = plot_topic_word_distribution(tokens, plot_width=450, plot_height=500, title='', tools='pan')

    source = ColumnDataSource({'x':[], 'y':[], 'width':[], 'height':[]})
    left.x_range.callback = create_js_callback('x', 'width', source)
    left.y_range.callback = create_js_callback('y', 'height', source)

    rect = bm.Rect(x='x', y='y', width='width', height='height', fill_alpha=0.0, line_color='blue', line_alpha=0.4)
    right.add_glyph(source, rect)

    show(row(left, right))

def display_topic_tokens(topic_id=0, n_words=100, output_format='Chart', widget_container=None):
    widget_container.forward()
    tokens = state.get_topic_tokens(topic_id=topic_id)\
        .copy()\
        .drop('topic_id', axis=1)\
        .assign(weight=lambda x: 100.0 * x.weight)\
        .sort_values('weight', axis=0, ascending=False)\
        .reset_index()\
        .head(n_words)
    if output_format == 'Chart':
        widget_container.forward()
        tokens = tokens.assign(xs=tokens.index, ys=tokens.weight)
        plot_topic_tokens_charts(tokens)
        widget_container.forward()
    elif output_format == 'Table':
        #display(tokens)
        display(HTML(tokens.to_html()))
    else:
        display(pivot_ui(tokens))
    widget_container.reset()
        
display_topic_distribution_widgets(display_topic_tokens, state, 'wc01', ['Chart', 'Table'])


### Topic's Trend Over Time or Documents
- Displays a topic's share over documents or over time.
- Note that the source documents (i.e. SOU reports) are split into 1000-word chunks (LDA documents) by the topic modelling process.
- If "SOU Report" or "Year" is selected, the **max** or **mean** weight is computed over the corresponding LDA documents.
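
The chunking and aggregation described above can be sketched as follows. This is a minimal illustration under stated assumptions: the helper `split_into_chunks`, the report names and the weights are hypothetical and are not taken from the project's batch process.

```python
import pandas as pd

def split_into_chunks(tokens, chunk_size=1000):
    """Split a token list into consecutive chunks of at most chunk_size tokens."""
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]

# Hypothetical weights of one topic in five LDA documents (chunks) from two reports
chunk_weights = pd.DataFrame({
    'report_name': ['A', 'A', 'A', 'B', 'B'],
    'document_id': [0, 1, 2, 3, 4],
    'weight': [0.10, 0.30, 0.20, 0.05, 0.15],
})

# Collapse the chunk-level weights to one mean and one max value per report
per_report = chunk_weights.groupby('report_name')['weight'].agg(['mean', 'max']).reset_index()
print(per_report)
```

The **max** highlights reports where the topic dominates at least one chunk, while the **mean** reflects how prominent the topic is across the whole report.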

In [None]:
# Plot a topic's yearly weight over time in selected LDA topic model
import numpy as np
import math
import bokeh.plotting
from bokeh.models import ColumnDataSource, DataRange1d, Plot, LinearAxis, Grid
from bokeh.models.glyphs import VBar
from bokeh.io import curdoc, show

def plot_topic_trend(df, pivot_column, value_column, x_label=None, y_label=None):

    xs = df[pivot_column].astype(str).tolist()
    p = bokeh.plotting.figure(x_range=xs, plot_width=1000, plot_height=700, title='', tools=TOOLS, toolbar_location="right")

    glyph = p.vbar(x=xs, top=df[value_column], width=0.5, fill_color="#b3de69")
    p.xaxis.major_label_orientation = math.pi/4
    p.xgrid.grid_line_color = None
    p.xaxis[0].axis_label = (x_label or '').title()
    p.yaxis[0].axis_label = (y_label or '').title()
    p.y_range.start = 0.0
    #p.y_range.end = 1.0
    p.x_range.range_padding = 0.01
    return p

def display_topic_trend(topic_id, pivot_config, value_column, widgets_container, output_format='Chart', state=None, threshold=0.01):
    
    pivot_column = pivot_config['pivot_column']
    tokens = state.get_topic_titles(n_words=200, cache=True).iloc[topic_id]
    widgets_container.text.value = 'ID {}: {}'.format(topic_id, tokens)
    value_column = value_column if pivot_column is not None else 'weight'
    
    df = state.document_topic_weights[(state.document_topic_weights.topic_id==topic_id)]
    
    if pivot_column is not None:
        # Aggregate only the weight column over all LDA documents sharing the same pivot value
        df = df.groupby(pivot_column)['weight'].agg(['mean', 'max']).reset_index()
        df.columns = [pivot_column, 'mean', 'max' ]
        df = df[(df[value_column] > threshold)]
        
    if output_format == 'Table':
        display(df)
    else:
        x_label = pivot_column.title()
        y_label = value_column.title() + (' weight' if value_column != 'weight' else '')
        p = plot_topic_trend(df, pivot_column, value_column, x_label=x_label, y_label=y_label)
        show(p)

def create_topic_trend_widgets(state):
    pivot_options = {
        '': { 'pivot_column': None, 'filter': None },
        'SOU Report': { 'pivot_column': 'report_name', 'filter': None },
        'Year': { 'pivot_column': 'year', 'filter': None },
        'LDA Document': { 'pivot_column': 'document_id', 'filter': None }
    } 
    wc = wf.BaseWidgetUtility(
        n_topics=state.n_topics,
        text_id='topic_share_plot',
        text=wf.create_text_widget('topic_share_plot'),
        #year=wf.create_select_widget('Year', options=state.years, value=state.years[-1]),
        pivot_config=widgets.Dropdown(
            options=pivot_options,
            value=pivot_options['SOU Report'],
            description='Group by'
        ),
        threshold=widgets.FloatSlider(description='Threshold', min=0.0, max=0.25, step=0.01, value=0.10, continuous_update=False),
        topic_id=widgets.IntSlider(description='Topic ID', min=0, max=state.n_topics - 1, step=1, value=0, continuous_update=False),
        output_format=wf.create_select_widget('Format', ['Chart', 'Table'], default='Chart'),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="50%")),
        aggregate=widgets.Dropdown(options=['max', 'mean'], value='max', description='Aggregate')
    )

    wc.prev_topic_id = wc.create_prev_id_button('topic_id', state.n_topics)
    wc.next_topic_id = wc.create_next_id_button('topic_id', state.n_topics)

    iw = widgets.interactive(
        display_topic_trend,
        topic_id=wc.topic_id,
        pivot_config=wc.pivot_config,
        value_column=wc.aggregate,
        widgets_container=widgets.fixed(wc),
        output_format=wc.output_format,
        state=widgets.fixed(state),
        threshold=wc.threshold
    )
    display(widgets.VBox([
        wc.text,
        widgets.HBox([wc.prev_topic_id, wc.next_topic_id, wc.pivot_config, wc.aggregate, wc.output_format]),
        widgets.HBox([wc.topic_id, wc.threshold, wc.progress]),
        iw.children[-1]
    ]))
    
    iw.update()
    
create_topic_trend_widgets(state)

### Topic to Document Network
The green nodes are documents, and the blue nodes are topics. The edges (lines) indicate the strength of a topic in the connected document. The width of an edge is proportional to the strength of the connection. Note that only edges with a strength above the selected threshold are displayed.
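
The thresholding and edge-width rule described above can be sketched as follows. This is a minimal, library-free illustration: the triples, the `scale` factor and the helper `build_bipartite_edges` are hypothetical and do not reproduce the notebook's `NetworkUtility` helper.

```python
# Hypothetical (document, topic_id, weight) triples; values are invented for illustration.
weights = [
    ('report A', 0, 0.70),
    ('report A', 1, 0.20),
    ('report B', 1, 0.55),
    ('report B', 2, 0.60),
]

def build_bipartite_edges(triples, threshold=0.5, scale=6.0):
    """Keep edges above the threshold and derive a line width proportional to the weight."""
    edges = [
        {'document': document, 'topic': topic_id, 'weight': weight, 'width': scale * weight}
        for document, topic_id, weight in triples
        if weight > threshold
    ]
    # The two node sets of the bipartite graph: documents on one side, topics on the other
    document_nodes = sorted({edge['document'] for edge in edges})
    topic_nodes = sorted({edge['topic'] for edge in edges})
    return document_nodes, topic_nodes, edges

document_nodes, topic_nodes, edges = build_bipartite_edges(weights, threshold=0.5)
print(document_nodes, topic_nodes)
for edge in edges:
    print(edge)
```

Raising the threshold removes weak edges first, so nodes that lose all their edges disappear from the plot as well.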

In [None]:
# Visualize year-to-topic network by means of topic-document-weights
     
def plot_topic_year_network(network, layout, scale=1.0, titles=None):

    year_nodes, topic_nodes = NetworkUtility.get_bipartite_node_set(network, bipartite=0)  
    
    year_source = NetworkUtility.get_node_subset_source(network, layout, year_nodes)
    topic_source = NetworkUtility.get_node_subset_source(network, layout, topic_nodes)
    lines_source = NetworkUtility.get_edges_source(network, layout, scale=6.0, normalize=False)
    
    edges_alphas = NetworkMetricHelper.compute_alpha_vector(lines_source.data['weights'])
    
    lines_source.add(edges_alphas, 'alphas')
    
    p = figure(plot_width=1000, plot_height=600, x_axis_type=None, y_axis_type=None, tools=TOOLS)
    
    r_lines = p.multi_line(
        'xs', 'ys', line_width='weights', alpha='alphas', color='black', source=lines_source
    )
    r_years = p.circle(
        'x','y', size=40, source=year_source, color='lightgreen', level='overlay', line_width=1,alpha=1.0
    )
    
    r_topics = p.circle('x','y', size=25, source=topic_source, color='skyblue', level='overlay', alpha=1.00)
    
    hover_callback = wf.WidgetUtility.glyph_hover_callback(
        topic_source, 'node_id', text_ids=titles.index, text=titles, element_id='nx_id1'
    )
    p.add_tools(bm.HoverTool(renderers=[r_topics], tooltips=None, callback=hover_callback))

    text_opts = dict(
        x='x', y='y', text='name', level='overlay',
        x_offset=0, y_offset=0, text_font_size='8pt'
    )
    
    p.add_layout(
        bm.LabelSet(
            source=year_source, text_color='black', text_align='center', text_baseline='middle', **text_opts
        )
    )
    p.add_layout(
        bm.LabelSet(
            source=topic_source, text_color='black', text_align='center', text_baseline='middle', **text_opts
        )
    )
    
    return p

def main_topic_year_network(state):
    
    wc = wf.BaseWidgetUtility(
        n_topics=state.n_topics,
        text_id='nx_id1',
        text=wf.create_text_widget('nx_id1', style="display: inline; height: 400px"),
        year=widgets.IntSlider(description='Year', min=state.min_year, max=state.max_year, step=1, value=state.min_year, continuous_update=False),
        pivot_column=widgets.Dropdown(
            options={
                'SOU report': 'report_name',
                'Year': 'year'
            },
            value='report_name',
            description='Pivot'
        ),
        scale=widgets.FloatSlider(description='Scale', min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False),
        threshold=widgets.FloatSlider(description='Threshold', min=0.0, max=1.0, step=0.01, value=0.50, continuous_update=False),
        output_format=widgets.Dropdown(
            options={'Network': 'network', 'Table': 'table'},
            value='network',
            description='Output'
        ),
        layout=widgets.Dropdown(
            options=list(layout_algorithms.keys()),
            value='Fruchterman-Reingold',
            description='Layout'
        ),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="40%"))
    ) 
    
    wc.previous = wc.create_prev_id_button('year', 10000)
    wc.next = wc.create_next_id_button('year', 10000)    
    
    def display_topic_year_network(
        layout_algorithm,
        threshold=0.50,
        scale=1.0,
        pivot_column='report_name',
        year=None,
        output_format='network'
    ):
        wc.progress.value = 1
        
        titles = state.get_topic_titles()
        filters = []
        if year is not None:
            filters = [ { 'column': 'year', 'value': year }]
        filters = filters + [ { 'query': 'weight >= {}'.format(threshold) } ]
        df = state.get_document_topic_weight_by_pivot_column(pivot_column, key='max', filters=filters)
        df = df[df.weight > threshold]
        
        wc.progress.value = 2

        network = NetworkUtility.create_bipartite_network(df, pivot_column, 'topic_id')
        
        wc.progress.value = 3

        if output_format == 'network':
            
            args = PlotNetworkUtility.layout_args(layout_algorithm, network, scale)
            layout = (layout_algorithms[layout_algorithm])(network, **args)
            
            wc.progress.value = 4
            
            p = plot_topic_year_network(network, layout, scale=scale, titles=titles)
            show(p)

        elif output_format == 'table':
            print(df.shape)
            display(df)
        else:
            display(pivot_ui(df))

        wc.progress.value = 0

    iw = widgets.interactive(
        display_topic_year_network,
        layout_algorithm=wc.layout,
        threshold=wc.threshold,
        scale=wc.scale,
        pivot_column=wc.pivot_column,
        year=wc.year,
        output_format=wc.output_format
    )

    display(widgets.VBox([
        wc.text,
        widgets.HBox([wc.layout, wc.year, wc.previous, wc.next]),
        widgets.HBox([wc.pivot_column, wc.scale]),
        widgets.HBox([wc.output_format, wc.threshold, wc.progress]),
        iw.children[-1]
    ]))
    iw.update()
    
main_topic_year_network(state)


### Topic Trends - Heatmap
- The topic shares are displayed as a scattered heatmap, using a colour gradient based on each topic's weight in the document.
- [Stanford’s Termite software](http://vis.stanford.edu/papers/termite) uses a similar visualization.
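
As a rough illustration of how a linear colour gradient of this kind can work, the sketch below bins a weight into one of the palette colours. The palette is the one used in this notebook's heatmap code; the function `linear_color` is hypothetical and is not Bokeh's implementation, so `LinearColorMapper` may treat bin edges and out-of-range values differently.

```python
# Palette copied from the notebook's heatmap code (light = low weight, dark red = high weight).
PALETTE = ["#efefef", "#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2",
           "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]

def linear_color(weight, low, high, palette=PALETTE):
    """Map a weight in [low, high] to a palette colour using equal-width bins."""
    if high <= low:
        return palette[0]
    fraction = (weight - low) / (high - low)
    index = int(fraction * len(palette))
    # Clamp so that weight == high (and out-of-range values) still map to a valid colour
    index = max(0, min(index, len(palette) - 1))
    return palette[index]

weights = [0.0, 0.05, 0.5, 1.0]
low, high = min(weights), max(weights)
for w in weights:
    print(w, linear_color(w, low, high))
```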

In [None]:
# plot_topic_relevance_by_year

def setup_glyph_coloring(df):
    max_weight = df.weight.max()
    #colors = list(reversed(bokeh.palettes.Greens[9]))
    colors = ["#efefef", "#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878",
              "#933b41", "#550b1d"]
    mapper = bm.LinearColorMapper(palette=colors, low=df.weight.min(), high=max_weight)
    color_transform = transform('weight', mapper)
    color_bar = bm.ColorBar(color_mapper=mapper, location=(0, 0),
                         ticker=bm.BasicTicker(desired_num_ticks=len(colors)),
                         formatter=bm.PrintfTickFormatter(format=" %5.2f"))
    return color_transform, color_bar

def plot_topic_relevance_by_year(df, xs, ys, flip_axis, glyph, titles, text_id):

    line_height = 7
    if flip_axis is True:
        xs, ys = ys, xs
        line_height = 10
    
    # Set up axis categories
    x_range = list(map(str, df[xs].unique()))
    y_range = list(map(str, df[ys].unique()))
    
    # Set up coloring and color bar
    color_transform, color_bar = setup_glyph_coloring(df)
    
    source = ColumnDataSource(df)

    plot_height = max(len(y_range) * line_height, 500)
    
    p = figure(title="Topic heatmap", tools=TOOLS, toolbar_location="right", x_range=x_range,
           y_range=y_range, x_axis_location="above", plot_width=1000, plot_height=plot_height)

    args = dict(x=xs, y=ys, source=source, alpha=1.0, hover_color='red')
    
    if glyph == 'Circle':
        cr = p.circle(color=color_transform, **args)
    else:
        cr = p.rect(width=1, height=1, line_color=None, fill_color=color_transform, **args)

    p.x_range.range_padding = 0
    p.ygrid.grid_line_color = None
    p.xgrid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_text_font_size = "5pt"
    p.axis.major_label_standoff = 0
    p.xaxis.major_label_orientation = 1.0
    p.add_layout(color_bar, 'right')
    
    p.add_tools(bm.HoverTool(tooltips=None, callback=wf.WidgetUtility.glyph_hover_callback(
        source, 'topic_id', titles.index, titles, text_id), renderers=[cr]))
    
    return p

def topic_heatmap_main(state):
    
    def display_topic_relevance_by_year(state, key='max', pivot_column=None, year=None, flip_axis=False, glyph='Circle', wdgs=None):
        
        try:
            wdgs.reset()
            wdgs.forward()
            
            titles = ModelUtility.get_topic_titles(state.topic_token_weights, n_words=100)
            wdgs.forward()

            year = (year or 0)
            
            pivot_column = 'year' if year > 0 else (pivot_column or 'report_name')
            filters = [{'column': 'year', 'values': [year]}] if year > 0 else []
            
            df = state.get_document_topic_weight_by_pivot_column(pivot_column, key, filters=filters)
            
            wdgs.forward()
            
            df[pivot_column] = df[pivot_column].astype(str)
            df['topic_id'] = df.topic_id.astype(str)
            
            wdgs.forward()
            
            p = plot_topic_relevance_by_year(df, xs=pivot_column, ys='topic_id', flip_axis=flip_axis, glyph=glyph, titles=titles, text_id='topic_relevance')
            
            show(p)
            wdgs.reset()
        except Exception as ex:
            # Log the error, then re-raise so the failure stays visible in the notebook
            logger.error(ex)
            raise
        finally:
            wdgs.reset()

    wc = wf.BaseWidgetUtility(
        text_id='topic_relevance',
        text=wf.create_text_widget('topic_relevance'),
        year=widgets.Dropdown(options=state.years, value=None, description='Year', layout=widgets.Layout(width="140px")),
        pivot_column=widgets.Dropdown(
            options={
                'SOU report': 'report_name',
                # 'LDA document': 'document_id',
                'Year': 'year'
            },
            value='report_name',
            description='Pivot',
            layout=widgets.Layout(width="200px")
        ),
        aggregate=widgets.Dropdown(options=['max', 'mean'], value='max', description='Aggregate', layout=widgets.Layout(width="180px")),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="35%")),
        glyph=widgets.Dropdown(options=['Circle', 'Square'], value='Square', description='Glyph', layout=widgets.Layout(width="180px")),
        flip_axis=widgets.ToggleButton(value=True, description='Flip XY', tooltip='Flip X and Y axis', icon='', layout=widgets.Layout(width="80px"))
    )

    iw = widgets.interactive(
        display_topic_relevance_by_year,
        state=widgets.fixed(state),
        key=wc.aggregate,
        pivot_column=wc.pivot_column,
        year=wc.year,
        glyph=wc.glyph,
        flip_axis=wc.flip_axis,
        wdgs=widgets.fixed(wc)
    )

    display(widgets.VBox([
        widgets.HBox([wc.pivot_column, wc.year, wc.aggregate, wc.flip_axis, wc.glyph, wc.progress ]),
        wc.text,
        iw.children[-1]
    ]))

    iw.update()
            
topic_heatmap_main(state)