# <span style='color:red'>THIS NOTEBOOK IS DEPRECATED!</span>

### The Culture of International Relations

#### About this project
Cultural treaties are the bilateral and multilateral agreements among states that promote and regulate cooperation and exchange in the fields of life generally called cultural or intellectual. Although it was only invented in the early twentieth century, this treaty type came to be the fourth most common bilateral treaty in the period 1900-1980 (Poast et al., 2010). In this project, we seek to use several (mostly European) states’ cultural treaties as a historical source with which to explore the emergence of a global concept of culture in the twentieth century. Specifically, the project will investigate the hypothesis that the culture concept, in contrast to earlier ideas of civilization, played a key role in the consolidation of the post-World War II international order.

The central questions that interest me here can be divided into two groups: 
- First, what is the story of the cultural treaty, as a specific tool of international relations, in the twentieth century? What was the historical curve of cultural treaty-making? For example, in which political or ideological constellations do we find (the most) use of cultural treaties? Among which countries, in which historical periods? What networks of relations were thereby created, reinforced, or challenged? 
- Second, what is the "culture" addressed in these treaties? That is, what do the two signatories seem to mean by "culture" in these documents, and what does that tell us about the role that concept played in the international system? How can quantitative work on this dataset advance research questions about the history of concepts?

In this notebook, we deal with these treaties in three ways:
1) quantitative analysis of "metadata" about all bilateral cultural treaties signed between 1919 and 1972, as found in the World Treaty Index or WTI (Poast et al., 2010).
    For more on how exactly we define a "cultural treaty" here, and on other principles of selection, see... [add this, using text now in "WTI quality assurance"].
2) network analysis of the system of international relationships created by these treaties (using data from WTI, as above).
3) Text analysis of the complete texts of selected treaties. 

After some set-up sections, the discussion of the material begins at "Part 1," below.

### Brief Instructions on Jupyter Notebooks
Please see [this tutorial](https://www.youtube.com/watch?v=h9S4kN4l5Is) for an introduction to what Jupyter notebooks are and how to use them. There are lots of other Jupyter tutorials on YouTube (and elsewhere) as well. In short, a notebook is a document with embedded executable code, presented in a simple and easy-to-use web interface. The most important things to note are:
- Click on the menu Help -> User Interface Tour for an overview of the Jupyter Notebook App user interface.
- The **code cells** contain the script code (Python in this case, although other languages are also supported) and are the sections marked by **In [x]** in the left margin. A cell is marked **In []** if it hasn't been executed, and **In [n]** once it has been executed (n is an integer). A cell marked **In [\*]** is either executing or waiting to be executed (i.e. other cells are executing).
- The **current cell** is highlighted with a blue (or green if in "edit" mode) border. You make a cell current by clicking on it.
- Code cells aren't executed automatically. Instead you execute the current cell by either pressing **shift+enter** or the **play** button in the toolbar. The output (or result) of a cell's execution is presented directly below the cell prefixed by **Out[n]**.
- The next cell will automatically be selected (made current) after a cell has been executed. Repeatedly pressing **shift+enter** or the play button hence executes the cells in sequence.
- You can run the entire notebook in a single step by clicking on the menu Cell -> Run All. Note that this can take some time to finish. You can see how cells are executed in sequence via the indicator in the margin (i.e. "In [\*]" changes to "In [n]" where n is an integer).
- The cells can be edited if they are double-clicked, in which case the cell border turns green. Use the ESC key to escape edit mode (or click on any other cell).

To restart the kernel (i.e. the computational engine assigned to your session), click on the menu Kernel -> Restart. 


### <span style='color:green'>**Optional Prepare Step**</span>: Update WTI data from Google Drive
The statistics computed on this page depend on a recent version of the WTI treaties master list. This file is stored on Google Drive, and the script "./google_drive.py" can be used to download and update the data. Please note that the load script below reads CSV files with specific names, so a manual download of the master list must be followed by saving each sheet as a CSV file. The script ./google_drive.py does this automatically.


In [1]:
# Code: Update WTI master data from Google Drive
%run ./common/google_drive
%run ./common/widgets_utility

import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)

files_to_download = {
    'WTI Master Index': {
        'file_id': '1V8KPeghLQ2iOMWkbPqff480zDSLa5YDX',
        'destination': './data/Treaties_Master_List.xlsx',
        'sheets': [ 'Treaties' ]
    },
    'Curated Parties': {
        'file_id': '1k4dOPuqR7oi4K8SazoGN6R40jOBWOdWp',
        'destination': './data/parties_curated.xlsx',
        'sheets': ['parties', 'group', 'continent']
    },
    'Country & Continent': {
        'file_id': '19lEmVPu7hNmr1MaMpU0VvKL7muu-OKg9',
        'destination': './data/country_continent.csv',
        'sheets': [ ]
    }
}

def update_file(file, confirm):
    global upw
    if file is None:
        return
    if confirm is False:
        print('Please confirm update by checking the CONFIRM button!')
        return
    upw.confirm.value = False
    print('Updating Google file with ID: {}'.format(file['file_id']))
    process_file(file, overwrite=confirm)
    
upw = BaseWidgetUtility(
    file=widgets.Dropdown(
        options=files_to_download,
        value=None,
        description='File:',
    ),
    confirm=widgets.ToggleButton(
        description='Confirm',
        button_style='',
        icon='check',
        value=False
    ),)
iupw = widgets.interactive(update_file, file=upw.file, confirm=upw.confirm)
display(widgets.VBox([widgets.HBox([upw.file, upw.confirm]), iupw.children[-1]]))
# iupw.update()


In [2]:
%%html
<style>
.jupyter-widgets {
    font-size: 8pt;
}
.widget-label {
    font-size: 8pt;
}
.widget-dropdown > select {
    font-size: 8pt;
}
</style>

### <span style='color:blue'>**Mandatory Prepare Step**</span>: Setup Notebook
The following code cell must be executed once for each user session. The step loads utility Python code stored in separate files, and imports the external libraries the notebook depends on. The following external libraries are used:
<table>
    <tr><td>[NLTK](https://www.nltk.org/)</td><td>NLP framework. *Natural Language Toolkit*</td><td>*Bird, Steven, Edward Loper and Ewan Klein (2009),<br/>Natural Language Processing with Python.<br/>O’Reilly Media Inc.*</td><td></td></tr>
    <tr><td> [gensim](https://radimrehurek.com/gensim/index.html)</td><td>NLP framework. *Topic Modelling for Humans*</td><td>[Google scholar](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=9vG_kV0AAAAJ&citation_for_view=9vG_kV0AAAAJ:NaGl4SEjCO4C)</td><td></td></tr>
    <tr><td>[python_louvain](https://github.com/taynaud/python-louvain)</td><td>Louvain Community Detection</td><td> https://python-louvain.readthedocs.io/</td><td></td></tr>
    <tr><td>[wordcloud](https://github.com/amueller/word_cloud)</td><td>Wordcloud generator</td><td>https://github.com/amueller/word_cloud</td><td></td></tr>
    <tr><td>[Graphviz](https://www.graphviz.org/)</td><td>Graphviz is an open source graph visualization software.</td><td>https://www.graphviz.org/</td><td></td></tr>
    <tr><td>[bokeh](http://bokeh.pydata.org/en/latest/)</td><td>Bokeh is an interactive visualization library.</td><td>http://bokeh.pydata.org/en/latest/</td><td></td></tr>
    <tr><td>[pandas](https://pandas.pydata.org/)</td><td>Data structures and data analysis tools</td><td>https://pandas.pydata.org/</td><td></td></tr>
    <tr><td>[NetworkX](https://networkx.github.io/)</td><td>Package for the creation, manipulation, and study of complex networks.</td><td>https://networkx.github.io/</td><td></td></tr>
    <tr><td>[graph_tool](https://graph-tool.skewed.de/)</td><td>Module for manipulation and statistical analysis of graphs (networks).</td><td>https://graph-tool.skewed.de/</td><td></td></tr>
    <tr><td>[PyTables](http://www.pytables.org/)</td><td></td><td></td><td></td></tr>
    <tr><td>[scipy](http://www.scipy.org/), [numpy](http://www.numpy.org/)</td><td></td><td></td><td></td></tr>
</table>

Pandas, bokeh, Jupyter, numpy and PyTables are all sponsored by [NumFOCUS](https://www.numfocus.org/sponsored-projects).

In [3]:
# Setup
%run ./common/file_utility
%run ./common/network_utility
%run ./common/widgets_utility

import os
import re
import glob
import logging
import fnmatch
import datetime
import wordcloud
import warnings
import pandas as pd
import numpy as np
import networkx as nx
import bokeh.plotting as bp
import bokeh.palettes
import bokeh.models as bm
import bokeh.io
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib
import zipfile
import nltk.tokenize
import nltk.corpus
import gensim.models

from pivottablejs import pivot_ui
from math import sqrt
from bokeh.io import push_notebook
from gensim.corpora.textcorpus import TextCorpus

from IPython.display import display, HTML #, clear_output, IFrame
from IPython.core.interactiveshell import InteractiveShell

logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s", level=logging.ERROR)
logger = logging.getLogger()
logger.setLevel(logging.ERROR)

TOOLS = "pan,wheel_zoom,box_zoom,reset,hover,previewsave"

InteractiveShell.ast_node_interactivity = "all"
warnings.filterwarnings('ignore')
bp.output_notebook()

%run ./common/utility
#%autosave 120
%config IPCompleter.greedy=True

matplotlib_plot_styles =[
    'ggplot',
    'bmh',
    'seaborn-notebook',
    'seaborn-whitegrid',
    '_classic_test',
    'seaborn',
    'fivethirtyeight',
    'seaborn-white',
    'seaborn-dark',
    'seaborn-talk',
    'seaborn-colorblind',
    'seaborn-ticks',
    'seaborn-poster',
    'seaborn-pastel',
    'fast',
    'seaborn-darkgrid',
    'seaborn-bright',
    'Solarize_Light2',
    'seaborn-dark-palette',
    'grayscale',
    'seaborn-muted',
    'dark_background',
    'seaborn-deep',
    'seaborn-paper',
    'classic'
]

output_formats = {
    'Plot vertical bar': 'plot_bar',
    'Plot horizontal bar': 'plot_barh',
    'Plot vertical bar, stacked': 'plot_bar_stacked',
    'Plot horizontal bar, stacked': 'plot_barh_stacked',
    'Plot line': 'plot_line',
    'Plot stacked line': 'plot_line_stacked',
    # 'Chart ': 'chart',
    'Table': 'table',
    'Pivot': 'pivot'
}



toggle_style = dict(icon='', layout=widgets.Layout(width='100px', left='0'))
drop_style = dict(layout=widgets.Layout(width='260px'))

### <span style='color:blue'>**Mandatory Prepare Step**</span>: Configuration elements
The following code cell must be executed once for each user session.

In [4]:
# Settings

period_divisions = [
    [ (1919, 1939), (1940, 1944), (1945, 1955), (1956, 1966), (1967, 1972) ],
    [ (1919, 1944), (1945, 1955), (1956, 1966), (1967, 1972) ]
]

parties_of_interest = ['FRANCE', 'GERMU', 'ITALY', 'GERMAN', 'UK', 'GERME', 'GERMW', 'INDIA', 'GERMA' ]


period_group_options = {
    'Year': 'signed_year',
    'Default division': 'signed_period',
    'Alt. division': 'signed_period_alt'
}

default_party_options = {
    'Top #n parties': None,
    'PartyOf5': parties_of_interest,
    'France': [ 'FRANCE' ],
    'France vs UK': [ 'FRANCE', 'UK' ],
    'France vs ALL': [ 'FRANCE', None ],
    'Italy': [ 'ITALY' ],
    'UK': [ 'UK' ],
    'India': [ 'INDIA' ],
    'Germany, after 1991': [ 'GERMU' ],
    'Germany, before 1945': [ 'GERMAN' ],
    'East Germany': [ 'GERME' ],
    'West Germany': [ 'GERMW' ],
    'Germany, allied occupation': [ 'GERMA' ],
    'Germany (all)': [ 'GERMU', 'GERMAN', 'GERME', 'GERMW', 'GERMA' ],
    'China': [ 'CHINA' ]
}

category_group_settings = {
    '7CULT, 7SCIEN, and 7EDUC': {
        '7CULT': ['7CULT'],
        '7SCIEN': ['7SCIEN'],
        '7EDUC': ['7EDUC']
    },
    '7CULT, 7SCI, and 7EDUC+4EDUC': {
        '7CULT': ['7CULT'],
        '7SCIEN': ['7SCIEN'],
        '7EDUC+4EDUC': ['7EDUC', '4EDUC']
    },
    '7CULT + 1AMITY': {
        '7CULT': ['7CULT'],
        '1AMITY': ['1AMITY']
    },
    '7CULT + 1ALLY': {
        '7CULT': ['7CULT'],
        '1ALLY': ['1ALLY']
    },
    '7CULT + 1DIPLOMACY': {
        '7CULT': ['7CULT'],
        'DIPLOMACY': ['1ALLY', '1AMITY', '1ARMCO', '1CHART', '1DISPU', '1ESTAB', '1HEAD', '1OCCUP', '1OPTC', '1PEACE', '1RECOG', '1REPAR', '1STATU', '1TERRI', '1TRUST']
    },
    '7CULT + 2WELFARE': {
        '7CULT': ['7CULT'],
        '2WELFARE': [ '2HEW', '2HUMAN','2LABOR', '2NARK', '2REFUG', '2SANIT', '2SECUR', '2WOMEN' ]
    },
    '7CULT + 3ECONOMIC': {
        '7CULT': ['7CULT'],
        'ECONOMIC': ['3CLAIM', '3COMMO', '3CUSTO', '3ECON', '3INDUS', '3INVES', '3MOSTF', '3PATEN', '3PAYMT', '3PROD', '3TAXAT', '3TECH', '3TOUR','3TRADE','3TRAPA']
    },
    '7CULT + 4AID': {
        '7CULT': ['7CULT'],
        '4AID': ['4AGRIC','4AID', '4ATOM', '4EDUC', '4LOAN', '4MEDIC', '4MILIT', '4PCOR', '4RESOU', '4TECA', '4UNICE']
    },
    '7CULT + 5TRANSPORT': {
        '7CULT': ['7CULT'],
        '5TRANSPORT': ['5AIR', '5LAND', '5TRANS', '5WATER']
    },
    '7CULT + 6COMMUNICATIONS': {
        '7CULT': ['7CULT'],
        '6COMMUNICATIONS': ['6COMMU', '6MEDIA', '6POST', '6TELCO']
    },
    '7CULTURE': {
        '7CULT': ['7CULT'],
        '7EDUC': ['7EDUC'],
        '7RELIG': ['7RELIG'],
        '7SCIEN': ['7SCIEN'],
        '7SEMIN': ['7SEMIN'],
        '7SPACE': ['7SPACE']
    },
    '7CULT': {
        '7CULT': ['7CULT']
    },
    '7CULT + 8RESOURCES': {
        '7CULT': ['7CULT'],
        '8RESOURCES': ['8AGRIC', '8CATTL', '8ENERG', '8ENVIR', '8FISH', '8METAL', '8WATER', '8WOOD']
    },
    '7CULT + 9ADMINISTRATION': {
        '7CULT': ['7CULT'],
        '9ADMINISTRATION': ['9ADMIN', '9BOUND', '9CITIZ', '9CONSU', '9LEGAL', '9MILIT', '9MILMI', '9PRIVI', '9VISAS', '9XTRAD' ]
    },
}

category_group_maps = { 
    category_name: { v: k for k in category_group_settings[category_name].keys() for v in category_group_settings[category_name][k]  }
        for category_name in category_group_settings.keys()
}
    

### <span style='color:blue'>**Mandatory Prepare Step**</span>: Load and Process Treaty Master Index
The following code cell must be executed once for each user session. The code loads the WTI master index (and some related data files), and prepares the data for subsequent use.

The treaty data is processed as follows:
- All the treaty data are loaded.
- The year in which each treaty was signed is extracted as a separate field.
- New fields are added for the specified signed-period divisions.
- The fields 'group1' and 'group2' are ignored (many missing values). Instead, groups are looked up by party code in the "groups" table.
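
The period step can be illustrated with a small, pandas-free sketch. The period boundaries are copied from the notebook's `period_divisions` setting, and the `'other'` label from the filter used in the chart cells. The function name `period_label`, the constant names, and the `'start-end'` label format are assumptions for illustration only; the actual logic lives in `./common/treaty_state`, which is not shown here.

```python
# Period boundaries copied from the notebook's `period_divisions` setting.
# The constant names are hypothetical.
FIVE_PERIODS = [(1919, 1939), (1940, 1944), (1945, 1955), (1956, 1966), (1967, 1972)]
FOUR_PERIODS = [(1919, 1944), (1945, 1955), (1956, 1966), (1967, 1972)]


def period_label(year, division):
    """Return a 'start-end' label for the period containing `year`.

    Years outside every period are labelled 'other' (the label format
    is an assumption; only 'other' appears in the notebook code).
    """
    for start, end in division:
        if start <= year <= end:
            return "{}-{}".format(start, end)
    return "other"


# The same signing year falls into different periods depending on the division.
for year in (1925, 1942, 1960, 1980):
    print(year, period_label(year, FIVE_PERIODS), period_label(year, FOUR_PERIODS))
```

Keeping the boundaries in a plain list of `(start, end)` pairs makes it easy to add a further division without touching the labelling function.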

In [6]:
# Load and process treaties master index

%run ./common/treaty_state

def load_treaty_state():
    global state
    try:
        state = TreatyState()
        print("Data loaded!")
    except Exception as ex:
        logger.error(ex)
        print('Load failed! Have you run setup cell above?')

load_treaty_state()


Data loaded!


### Step Sanity Checks: Per field (pair) value counts

In [7]:
# Code

treaty_fields = [
    '', 'is_cultural_yesno', 'source', 'party1', 'party2', 'laterality',
    'headnote', 'topic', 'topic1', 'topic2', 'title', 'signed_year', 'signed_period', 'signed_period_alt', 'is_cultural'
]    

def display_variable_stats(field1, field2, crosstab):
    
    columns = [ x for x in set([field1, field2 ]) if x != '' ]
    
    if len(columns) > 0:
        df = state.treaties.groupby(columns).size().reset_index()\
            .rename(columns={0: 'Count'})\
            .sort_values(['Count'], ascending=False)
            
        if crosstab is True:
            if len(columns) == 2:
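                # Cross-tabulate the two selected fields on the full treaty table
                # (the original crossed field1 with itself)
                display(pd.crosstab(state.treaties[field1], state.treaties[field2]))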
            else:
                print('Both fields are needed for crosstab')
        else:
            display(df)
        # df.set_index(columns).plot.bar(figsize=(16,8))

def sanity_check_main():
    
    sw = BaseWidgetUtility(
        field1=wf.create_select_widget('Field 1:', treaty_fields, default=''),
        field2=wf.create_select_widget('Field 2:', treaty_fields, default=''),
        crosstab=widgets.ToggleButton(
            description='Crosstab',
            button_style='',
            icon='check'
        ),
    )

    isw = widgets.interactive(display_variable_stats, field1=sw.field1, field2=sw.field2, crosstab=sw.crosstab)

    display(widgets.VBox([widgets.HBox([sw.field1, sw.field2, sw.crosstab]), isw.children[-1]]))

    isw.update()
    
sanity_check_main()



### Chart: Treaty Quantities by Selected Parties 
This chart displays the number of treaties per party, or group of parties, by year or period division. The default division has the periods 1919-1944, 1945-1955, 1956-1966, and 1967-1972, and the alternative division has 1940-1944 as an additional period. Use the "Top #n" slider to select how many parties to display for each period, ranked by the number of treaties they signed; the slider applies only when "Top #n parties" is selected in the "Parties" list. The "Only Cultural" button restricts the chart to treaties whose "is_cultural" field is set to "yes".
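
The "Top #n" ranking can be sketched in plain Python as follows. This is only an illustration of the idea (count treaties per party within each period, then keep the n largest counts per period); the notebook itself does this with pandas in `get_top_parties`. The function name `top_parties_per_period` and the sample rows are made up for the example.

```python
from collections import Counter, defaultdict


def top_parties_per_period(rows, n_top=3):
    """Return the `n_top` most frequent parties for each period.

    `rows` is an iterable of (period, party) pairs, one pair per
    treaty that the party signed in that period.
    """
    counts = defaultdict(Counter)
    for period, party in rows:
        counts[period][party] += 1
    # most_common() sorts each period's parties by descending count
    return {period: counter.most_common(n_top) for period, counter in counts.items()}


# Invented sample data: three periods' worth of (period, party) pairs
sample_rows = (
    [("1919-1944", "FRANCE")] * 3
    + [("1919-1944", "ITALY")] * 2
    + [("1919-1944", "UK")]
    + [("1945-1955", "UK")] * 2
    + [("1945-1955", "FRANCE")]
)

for period, ranking in top_parties_per_period(sample_rows, n_top=2).items():
    print(period, ranking)
```

Because the ranking is computed separately for each period, the set of parties shown can differ from one period to the next.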

<span style='color: green;'>**DONE: Top #n should show only n countries per selected period **</span><br>
<span style='color: green;'>**DONE: Labels for each year not visible when "Plot line" is selected **</span><br>
<span style='color: green;'>**DONE: Add Country vs Country selection ALL included - (two listboxes) **</span><br>
<span style='color: green;'>**BUG: Review 7CULT recode **</span><br>
<span style='color: green;'>**BUG: ALL doesn't work! (France vs ALL) **</span><br>
<span style='color: green;'>**TODO: Add ALL as separate baseline (inclusive and exclusive made selection) **</span><br>

In [8]:
# Code
%matplotlib inline
import matplotlib.pyplot as plt
#colors = hsv(np.linspace(0, 1.0, 16))
colors = bokeh.palettes.Category20[20] #plt.get_cmap('jet')(np.linspace(0, 1.0, 16))

def plot_treaties_per_period(data, output_format, plot_style, figsize=(12,6), xlabel='', ylabel='', xticks=None):
    matplotlib.style.use(plot_style)
    stacked = 'stacked' in output_format
    kind = output_format.split('_')[1]
    ax = data.plot(kind=kind, stacked=stacked, figsize=figsize, color=colors)
    
    if xticks is not None:
        ax.set_xticks(xticks)
    
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    
    # Put a legend to the right of the current axis
    legend = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    legend.get_frame().set_linewidth(0.0)

    for tick in ax.get_xticklabels():        
        tick.set_rotation(45)
    
def get_top_parties(data, period, party_name, n_top=5):
    xd = data.groupby([period, party_name]).size().rename('TopCount').reset_index()
    top_list = xd.groupby([period]).apply(lambda x: x.nlargest(n_top, 'TopCount'))\
        .reset_index(level=0, drop=True)\
        .set_index([period, party_name])
    return top_list

def display_treaties_per_period(
    period,
    party_name,
    parties_selection,
    only_is_cultural=False,
    normalize_values=False,
    output_format='chart',
    plot_style='classic',
    top_n_parties=5
    ):
        
    try:

        data = state.stacked_treaties.copy()
                        
        # if only_within_period_of_interest:
        data = data.loc[(data.signed_period!='other')]

        if only_is_cultural:
            data = data.loc[(data.is_cultural==True)]

        if isinstance(parties_selection, list):
            data = data.loc[(data.party.isin(parties_selection))]
                            
        data = data.merge(state.parties, how='left', left_on='party', right_index=True)
        
        n_top_list = get_top_parties(data, period, party_name, n_top=top_n_parties)
               
        data = data.groupby([period, party_name])\
                .size()\
                .reset_index()\
                .rename(columns={ period: 'Period', party_name: 'Party', 0: 'Count' })

        if parties_selection is None:
            data = data.merge(n_top_list, how='inner', left_on=['Period', 'Party'], right_index=True)

        pivot = pd.pivot_table(data, index=['Period'], values=["Count"], columns=['Party'], fill_value=0)
        pivot.columns = [ x[-1] for x in pivot.columns ]
    
        if period == 'signed_year':
            missing_years = [ x for x in range(data.Period.min(), data.Period.max() + 1) if x not in pivot.index ]
            for year in missing_years:
                pivot.loc[year] = len(pivot.columns) * [0]
            pivot.sort_index(axis=0, inplace=True)
    
        if normalize_values is True:
            pivot = pivot.div(0.01 * pivot.sum(1), axis=0)

        if output_format.startswith('plot'):

            label = 'Number of treaties' if not normalize_values else 'Share%'

            ylabel = label if 'barh' not in output_format else ''
            xlabel = label if 'barh' in output_format else ''

            height = 10 if 'barh' in output_format and period == 'signed_year' else 6

            xticks = list(range(data.Period.min(), data.Period.max() + 1)) if 'line' in output_format and period == 'signed_year' else None
            
            plot_treaties_per_period(pivot, output_format, plot_style, figsize=(18, height), xlabel=xlabel, ylabel=ylabel, xticks=xticks)

        elif output_format == 'table':
            display(data)
            # display(HTML(data.to_html()))
        else:
            display(pivot)
            
    except Exception as ex:
        logger.error(ex)
        raise

def treaty_quantities_by_selected_parties_main():
    
    tw = BaseWidgetUtility(
        period=widgets.Dropdown(
            options=period_group_options,
            value='signed_period',
            description='Period:',
            layout=widgets.Layout(width='250px')
        ),
        party_name=widgets.Dropdown(
            options={
                'WTI Code': 'party',
                'WTI Name': 'party_name',
                'WTI Short': 'short_name',
                'Country': 'party_country'
            },
            value='party_name',
            description='Name:',
            layout=widgets.Layout(width='250px')
        ),
        country1=widgets.Dropdown(
            options=[None] + state.get_countries_list(),
            value=None,
            description='Country#1:',
            layout=widgets.Layout(width='250px')
        ),
        country2=widgets.Dropdown(
            options=[None] + state.get_countries_list(),
            value=None,
            description='Country#2:',
            layout=widgets.Layout(width='250px')
        ),
        parties_selection=widgets.Dropdown(
            options=default_party_options,  # https://stackoverflow.com/questions/35023744/how-to-order-entries-in-ipywidgets-dropdown-or-select
            value=default_party_options['PartyOf5'],
            description='Parties:',
            layout=widgets.Layout(width='250px')
        ),

        only_is_cultural=widgets.ToggleButton(
            description='Only Cultural', value=True, **toggle_style
        ),
        normalize_values=widgets.ToggleButton(
            description='Share%', **toggle_style
        ),
        output_format=widgets.Dropdown(
            description='Output', options=output_formats, value='plot_bar_stacked', layout=widgets.Layout(width='300px')
        ),
        plot_style=widgets.Dropdown(
            options=matplotlib_plot_styles, value='seaborn-pastel',
            description='Style:', layout=widgets.Layout(width='300px')
        ),
        top_n_parties=widgets.IntSlider(
            value=3, min=1, max=10, step=1,
            description='Top #:',
            continuous_update=True,
            layout=widgets.Layout(width='220px')
        )
    )

    itw = widgets.interactive(
        display_treaties_per_period,
        period=tw.period,
        party_name=tw.party_name,
        parties_selection=tw.parties_selection,
        only_is_cultural=tw.only_is_cultural,
        normalize_values=tw.normalize_values,
        output_format=tw.output_format,
        plot_style=tw.plot_style,
        top_n_parties=tw.top_n_parties
    )

    def on_country_change(change):
        if tw.country1.value is None or tw.country2.value is None:
            return
        name = (tw.country1.value or '') + ' vs ' + (tw.country2.value or '')
        options = dict(tw.parties_selection.options)
        if name not in options:
            # Add the new country pair to the dropdown before selecting it
            tw.parties_selection.index = None
            options = extend(options, {name: [tw.country1.value, tw.country2.value]})
            tw.parties_selection.options = options
        # `options` is always bound here, also when the pair already existed
        tw.parties_selection.value = options[name]

    tw.country1.observe(on_country_change, names='value')
    tw.country2.observe(on_country_change, names='value')

    def on_parties_change(change):
        try:
            # print(change)
            # tw.top_n_parties.value = 0
            tw.top_n_parties.disabled = change['new'] is not None
        except Exception as ex:
            logger.info(ex)

    tw.parties_selection.observe(on_parties_change, names='value')

    first_column_box = widgets.VBox([tw.period, tw.party_name,])
    second_column_box = widgets.VBox([ tw.parties_selection, tw.top_n_parties ])
    another_column_box = widgets.VBox([ tw.country1, tw.country2 ])
    third_column_box = widgets.VBox([ tw.only_is_cultural, tw.normalize_values])
    fourth_column_box = widgets.VBox([ tw.output_format, tw.plot_style ])
    boxes = widgets.HBox([first_column_box, second_column_box, another_column_box, third_column_box, fourth_column_box ])
    display(widgets.VBox([boxes, itw.children[-1]]))
    itw.update()

treaty_quantities_by_selected_parties_main()



###  Chart: Treaty Quantities by Selected Topics
This report displays the number of treaties per period division and WTI category, or group of categories. The "recode 7CULT" flag recodes all treaties with "is_cultural" set to true as 7CULT, and the others as "7NOCULT". Note that treaties categorized as 7CULT are always included, even if "is_cultural" is false. When the "+other" flag is checked, *all* other treaties are included, recoded as an "OTHER" category.

Note that the grouping is currently based on "topic1"; in the code below, "topic2" is consulted only as a fallback when "topic1" does not belong to the selected category group.
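
The category lookup used by this report can be sketched as follows. As written, the code cell below inverts one entry of `category_group_settings` into a code-to-group map and then looks up "topic1" first and, failing that, "topic2", with "OTHER" as the default. The helper names `build_category_map` and `categorize` are hypothetical; the group definition is copied from the settings cell.

```python
def build_category_map(group_settings):
    """Invert {group: [codes]} into {code: group} for fast lookup."""
    return {code: group for group, codes in group_settings.items() for code in codes}


def categorize(topic1, topic2, category_map, default="OTHER"):
    """Return the group for topic1, else for topic2, else the default."""
    return category_map.get(topic1, category_map.get(topic2, default))


# Copied from the '7CULT, 7SCI, and 7EDUC+4EDUC' entry in the settings cell
group_settings = {
    "7CULT": ["7CULT"],
    "7SCIEN": ["7SCIEN"],
    "7EDUC+4EDUC": ["7EDUC", "4EDUC"],
}

category_map = build_category_map(group_settings)

print(categorize("4EDUC", None, category_map))      # topic1 matches a group
print(categorize("3TRADE", "7CULT", category_map))  # falls back to topic2
print(categorize("3TRADE", "5AIR", category_map))   # neither topic matches
```

Inverting the settings once keeps each per-treaty lookup a constant-time dictionary access instead of a search through every group's code list.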

** TODO: Recode 7CULT does not work as intended. It should be possible to choose either WTI's original classification OR Ben's recoding according to is_cultural (then renamed 7CORR). ** <br>
** TODO: Try multiselect of countries** <br>
** DONE: Add way to display a chart per country ** <br>


In [9]:
# Code
%matplotlib inline
from IPython.display import clear_output

def plot_display_quantity_of_topics(pivot, kind, stacked, xlabel='', ylabel='', plot_style='classic', figsize=(12,10), **kwargs):

    matplotlib.style.use(plot_style)
    
    ax = pivot.plot(kind=kind, stacked=stacked, figsize=figsize, **kwargs)

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    # legend = ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=4)
    legend = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    legend.get_frame().set_linewidth(0.0)
    
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)    

def display_quantity_of_topics(
    period,
    category_map_name,
    recode_7cult=False,
    normalize_values=False,
    include_other=False,
    output_format='chart',
    plot_style='classic',
    parties_selection=None
    ):
    try:
        
        data = state.treaties.copy()
        data = data.loc[(data.signed_period!='other')]

        if isinstance(parties_selection, list):
            data = data.loc[(data.party1.isin(parties_selection)|(data.party2.isin(parties_selection)))]
           
        category_map = category_group_maps[category_map_name]
        
        if not include_other:
            data = data.loc[(data.topic1.isin(category_map.keys())) | (data.topic2.isin(category_map.keys()))]

        if data.shape[0] == 0:
            # Guard against parties_selection being None (its default value)
            print('No data for: ' + ','.join(parties_selection or []))
            return
            
        data['category'] = data.apply(lambda x: category_map.get(x['topic1'], category_map.get(x['topic2'], 'OTHER')), axis=1)
        data = data\
                .groupby([period, 'category'])\
                .size()\
                .reset_index()\
                .rename(columns={ period: 'Period', 'category': 'Category', 0: 'Count' })

        pivot = pd.pivot_table(data, index=['Period'], values=["Count"], columns=['Category'], fill_value=0)
        pivot.columns = [ x[-1] for x in pivot.columns ]
        
        if period == 'signed_year':
            missing_years = [ x for x in range(data.Period.min(), data.Period.max() + 1) if x not in pivot.index ]
            for year in missing_years:
                pivot.loc[year] = len(pivot.columns) * [0]
            pivot.sort_index(axis=0, inplace=True)

        if normalize_values is True:
            pivot = pivot.div(0.01 * pivot.sum(1), axis=0)

        if output_format.startswith('plot'):

            label1 = 'Number of treaties' if not normalize_values else 'Share%'
            # Guard against parties_selection being None (its default value)
            title = 'Parties ' + ', '.join(parties_selection or [])
            
            ylabel = label1 if 'barh' not in output_format else '' 
            xlabel = label1 if 'barh' in output_format else ''

            stacked = 'stacked' in output_format
            kind = output_format.split('_')[1]
            height = 10 if 'barh' in output_format and period == 'signed_year' else 6

            plot_display_quantity_of_topics(
                pivot, kind=kind, stacked=stacked, xlabel=xlabel, ylabel=ylabel, plot_style=plot_style, figsize=(16,height), title=title
            )

        elif output_format == 'chart':
            print('bokeh plot not implemented')
            data.plot.line(figsize=(12,8))
        elif output_format == 'table':
            display(data)
            #display(HTML(data.to_html()))
        else:
            display(pivot)
    except Exception as ex:
        logger.error(ex)
        # raise
        
def treaty_quantities_by_selected_topics_main():

    party_options = {
        'PartyOf5': parties_of_interest,
        'China': [ 'CHINA' ],
        'France': [ 'FRANCE' ],
        'Italy': [ 'ITALY' ],
        'UK': [ 'UK' ],
        'India': [ 'INDIA' ],
        'Germany, after 1991': [ 'GERMU' ],
        'Germany, before 1945': [ 'GERMAN' ],
        'East Germany': [ 'GERME' ],
        'West Germany': [ 'GERMW' ],
        'Germany, allied occupation': [ 'GERMA' ],
        'Germany (all)': [ 'GERMU', 'GERMAN', 'GERME', 'GERMW', 'GERMA' ]
    }
    
    widget_container = BaseWidgetUtility(
        period=widgets.Dropdown(
            options=period_group_options,
            value='signed_period',
            description='Period:', layout=widgets.Layout(width='300px')
        ),
        category_map_name=widgets.Dropdown(
            options=category_group_maps.keys(),
            value='7CULTURE',
            description='Category:', layout=widgets.Layout(width='300px')
        ),
        recode_7cult=widgets.ToggleButton(
            description='Recode 7CULT',
            tooltip='Treat all treaties with cultural=yes as 7CULT',
            value=False, layout=widgets.Layout(width='120px')
        ),
        normalize_values=widgets.ToggleButton(
            description='Normalize%',
            tooltip='Display shares per category instead of count', layout=widgets.Layout(width='120px')
        ),
        include_other=widgets.ToggleButton(
            description='+Other', value=False,  layout=widgets.Layout(width='120px')
        ),
        output_format=widgets.Dropdown(
            description='Output',
            value='plot_bar_stacked',
            options=output_formats, **drop_style
        ),
        plot_style=widgets.Dropdown(
            options=matplotlib_plot_styles,
            value='seaborn-pastel',
            description='Style:', **drop_style
        ),   
        parties_selection=widgets.Dropdown(
            options=party_options,
            value=party_options['PartyOf5'],
            description='Parties:',
            layout=widgets.Layout(width='300px')
        ),
        chart_per_party=widgets.ToggleButton(
            description='Chart per party',
            tooltip='Display one chart per party', layout=widgets.Layout(width='120px')
        )
    )

    def display_quantity_of_topics_proxy(
        period,
        category_map_name,
        recode_7cult=False,
        normalize_values=False,
        include_other=False,
        output_format='chart',
        plot_style='classic',
        parties_selection=None,
        chart_per_party=False
    ):
        # clear_output()
        party_groups = [ [ x ] for x in parties_selection ] if chart_per_party else [ parties_selection ]
        for party_group in party_groups:
            display_quantity_of_topics(period, category_map_name, recode_7cult, normalize_values,
                                       include_other, output_format, plot_style, party_group)
        
    itw = widgets.interactive(
        display_quantity_of_topics_proxy,
        period=widget_container.period,
        category_map_name=widget_container.category_map_name,
        recode_7cult=widget_container.recode_7cult,
        normalize_values=widget_container.normalize_values,
        include_other=widget_container.include_other,
        output_format=widget_container.output_format,
        plot_style=widget_container.plot_style,
        parties_selection=widget_container.parties_selection,
        chart_per_party=widget_container.chart_per_party
    )

    boxes = widgets.HBox(
        [
            widgets.VBox([ widget_container.period, widget_container.category_map_name, widget_container.parties_selection]),        
            widgets.VBox([ widget_container.recode_7cult, widget_container.normalize_values, widget_container.chart_per_party]),
            widgets.VBox([ widget_container.include_other]),
            widgets.VBox([ widget_container.output_format, widget_container.plot_style ])
        ]
    )
    display(widgets.VBox([boxes, itw.children[-1]]))
    itw.update()

treaty_quantities_by_selected_topics_main()


### Task: Headnote word toplist and word-pair co-occurrence toplist
This report displays headnote toplists: either single-word occurrence toplists or word-pair co-occurrence toplists, depending on whether "Cooccurrence" is toggled on. The result is grouped by the selected division's periods or by year.

The word co-occurrence is defined as the number of times a pair of words co-occur in the same headnote. The length of the headnote is ignored in the computation (all pairs have equal weight). Multiple occurrences of a word in a headnote are taken into account, i.e. "cultural exchange cultural" is counted as two co-occurrences, and "cultural exchange exchange cultural" as four co-occurrences.

Stopwords are always removed from the co-occurrence computation, whilst they are removed from the single-word occurrence toplist only if the "Remove stopwords" flag is checked. The removal is based on NLTK's list of English stopwords (run ```nltk.corpus.stopwords.words('english')``` to display all stopwords).

The toplist can be filtered so that only treaties involving any of, or a specific one of, the five parties of interest are included, and words can be excluded based on character length. Each resulting group can also be restricted both by a maximum number of words or pairs to display per group and by a minimum (co-)occurrence count.
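
The counting rule described above can be illustrated with a minimal sketch. It is not the notebook's implementation: the function name `count_co_occurrences` is hypothetical, and it assumes plain whitespace tokenization instead of NLTK's tokenizer. Each pair of distinct words contributes the product of the two words' occurrence counts within a headnote.

```python
from collections import Counter
from itertools import combinations

def count_co_occurrences(headnotes, stopwords=frozenset()):
    """Count word-pair co-occurrences within each headnote (hypothetical helper).

    A pair of distinct words contributes count(w1) * count(w2) per headnote,
    so repeated words are counted once per combination of positions.
    """
    totals = Counter()
    for headnote in headnotes:
        # Assumption: naive whitespace tokenization, lower-cased
        words = [w for w in headnote.lower().split() if w not in stopwords]
        word_counts = Counter(words)
        # Sort the words so each pair is always stored in the same order
        for w1, w2 in combinations(sorted(word_counts), 2):
            totals[(w1, w2)] += word_counts[w1] * word_counts[w2]
    return totals

two = count_co_occurrences(["cultural exchange cultural"])
four = count_co_occurrences(["cultural exchange exchange cultural"])
print(two[("cultural", "exchange")], four[("cultural", "exchange")])
```

Headnote length plays no role here: every pair gets the same weight regardless of how many other words the headnote contains.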

In [10]:
# Code
from nltk.stem import WordNetLemmatizer
import qgrid

toggle_style = dict(icon='', layout=widgets.Layout(width='140px', left='0'))

class HeadnoteTokenServiceOLD():

    def __init__(self, tokenizer, stopwords=None, lemmatizer=None, min_word_size=2):
        
        self.transforms = [
            tokenizer,
            lambda ws: ( x for x in ws if len(x) >= min_word_size ),
            lambda ws: ( x for x in ws if any(ch.isalpha() for ch in x)) 
        ]
        
        if stopwords is not None:
            self.transforms += [ lambda ws: ( x for x in ws if x not in stopwords ) ]
            
        if lemmatizer is not None:
            self.transforms += [ lambda ws: ( lemmatizer(x) for x in ws ) ]

    def _apply_transforms(self, ws):
        for f in self.transforms:
            ws = f(ws)
        return list(ws)
    
    def parse_headnotes(self, treaties):
        
        headnotes = treaties['headnote']
        
        texts = [ x.lower() for x in list(headnotes) ]
        tokens = list(map(self._apply_transforms, texts))
        df = pd.DataFrame({'headnote': headnotes, 'tokens': tokens })
        
        return df
    
    def compute_stacked(self, treaties):
        
        df = self.parse_headnotes(treaties)
        
        df_stacked = pd.DataFrame(df.tokens.tolist(), index=df.index).stack()\
            .reset_index().rename(columns={'level_1': 'sequence_id', 0: 'token'})
            
        return df_stacked
    
    def compute_co_occurrence(self, treaties, pos_tags, only_cultural_treaties=False):

        # Keep only the tags that belong to the treaties of interest
        df_pos_tags = pos_tags.merge(treaties, how='inner', left_on='treaty_id', right_index=True)
        
        if only_cultural_treaties:
            df_pos_tags = df_pos_tags[(df_pos_tags.is_cultural.str.contains('yes',na=False))]

        # Self join of words within same treaty
        df_co_occurrence = pd.merge(df_pos_tags, df_pos_tags, how='inner', left_on='treaty_id', right_on='treaty_id')
        # Only consider a specific pair once
        df_co_occurrence = df_co_occurrence[(df_co_occurrence.wid_x < df_co_occurrence.wid_y)]
        # Reduce number of returned columns
        df_co_occurrence = df_co_occurrence[['treaty_id', 'year_x', 'is_cultural_x', 'lemma_x', 'lemma_y' ]]
        # Rename columns
        df_co_occurrence.columns = ['treaty_id', 'year', 'is_cultural', 'lemma_x', 'lemma_y' ]

        # Sort token pair so smallest always comes first
        lemma_x = df_co_occurrence[['lemma_x', 'lemma_y']].min(axis=1)
        lemma_y = df_co_occurrence[['lemma_x', 'lemma_y']].max(axis=1)
        df_co_occurrence['lemma_x'] = lemma_x
        df_co_occurrence['lemma_y'] = lemma_y

        return df_co_occurrence

class HeadnoteTokenCorpus():

    def __init__(self, treaties, tokenize=None, stopwords=None, lemmatize=None, min_size=2):
        
        tokenize = tokenize or nltk.tokenize.word_tokenize
        lemmatize = lemmatize or WordNetLemmatizer().lemmatize
        stopwords = stopwords or nltk.corpus.stopwords.words('english')
        
        self.transforms = [
            tokenize,
            lambda ws: ( x for x in ws if len(x) >= min_size ),
            lambda ws: ( x for x in ws if any(ch.isalpha() for ch in x)),
            lambda ws: list(set(ws)) 
        ]
        
        #if stopwords is not None:
        #    self.transforms += [ lambda ws: ( x for x in ws if x not in stopwords ) ]
            
        #if lemmatizer is not None:
        #    self.transforms += [ lambda ws: ( lemmatizer(x) for x in ws ) ]
        
        treaty_tokens = self._compute_stacked(treaties)
        vocabulary = treaty_tokens.token.unique()
        lemmas = list(map(lemmatize, vocabulary))
        lemma_map = { w: l for (w, l) in zip(*(vocabulary, lemmas)) if w != l }
        stopwords_map = { s : True for s in stopwords }
        treaty_tokens['lemma'] = treaty_tokens.token.apply(lambda x: lemma_map.get(x, x))
        treaty_tokens['is_stopword'] = treaty_tokens.token.apply(lambda x: stopwords_map.get(x, False))

        self.treaty_tokens = treaty_tokens.set_index(['treaty_id', 'sequence_id'])
        
    def _apply_transforms(self, ws):
        for f in self.transforms:
            ws = f(ws)
        return list(ws)
    
    def _parse_headnotes(self, treaties):
        
        headnotes = treaties['headnote']
        
        texts = [ x.lower() for x in list(headnotes) ]
        tokens = list(map(self._apply_transforms, texts))
        df = pd.DataFrame({'headnote': headnotes, 'tokens': tokens })
        
        return df
    
    def _compute_stacked(self, treaties):
        
        df = self._parse_headnotes(treaties)
        
        df_stacked = pd.DataFrame(df.tokens.tolist(), index=df.index).stack()\
            .reset_index().rename(columns={'level_1': 'sequence_id', 0: 'token'})
            
        return df_stacked
    
def compute_co_occurrance(treaties):
    
    treaty_tokens = state.treaty_headnote_corpus.treaty_tokens
    i1 = treaties.index
    # i2 = treaty_tokens.reset_index().set_index('treaty_id').index
    i2 = treaty_tokens.index.get_level_values(0)
    treaty_tokens = treaty_tokens[i2.isin(i1)]
    
    treaty_tokens = treaty_tokens.loc[treaty_tokens.is_stopword==False]
    treaty_tokens = treaty_tokens.reset_index().drop(['is_stopword', 'sequence_id'], axis=1).set_index('treaty_id')

    co_occurrance = treaty_tokens.merge(treaty_tokens, how='inner', left_index=True, right_index=True)
    co_occurrance = co_occurrance.loc[(co_occurrance['token_x'] < co_occurrance['token_y'])]
    #co_occurrance['token'] = co_occurrance.apply(lambda row: row[groupby_pair[0]] + ' - ' + row[groupby_pair[1]], axis=1)
    co_occurrance['token'] = co_occurrance.apply(lambda row: ' - '.join([row['token_x'].upper(), row['token_y'].upper()]), axis=1)
    co_occurrance['lemma'] = co_occurrance.apply(lambda row: ' - '.join([row['lemma_x'].upper(), row['lemma_y'].upper()]), axis=1)
    co_occurrance = co_occurrance.assign(is_stopword=False, sequence_id=0)[['sequence_id', 'token', 'lemma', 'is_stopword']]
    
    return co_occurrance

def create_bigram_transformer(documents):
    import gensim.models.phrases
    bigram = gensim.models.phrases.Phrases(map(nltk.tokenize.word_tokenize, documents))
    return lambda ws: bigram[ws]

def remove_snake_case(snake_str):
    return ' '.join(x.title() for x in snake_str.split('_'))

def get_top_partiesssss(data, period, party_name, n_top=5):
    xd = data.groupby([period, party_name]).size().rename('TopCount').reset_index()
    top_list = xd.groupby([period]).apply(lambda x: x.nlargest(n_top, 'TopCount'))\
        .reset_index(level=0, drop=True)\
        .set_index([period, party_name])
    return top_list

result=None
def display_headnote_toplist(
    period=None,
    parties=None,
    extra_groupbys=None,
    only_is_cultural=True,
    use_lemma=False,
    compute_co_occurance=False,
    remove_stopwords=True,
    min_word_size=2,
    n_min_count=1,
    output_format='table',
    n_top=50
    # plot_style=tw.plot_style
):
    global ihnw, result
    
    try:
        hnw.progress.value = 1    
        treaties = state.treaties.loc[state.treaties.signed_period != 'other']

        if state.treaty_headnote_corpus is None:
            print('Preparing headnote corpus for first time use')
            state.treaty_headnote_corpus = HeadnoteTokenCorpus(treaties=treaties)

        if only_is_cultural:
            treaties = treaties.loc[(treaties.is_cultural)]

        if parties is not None:
            ids = state.stacked_treaties.loc[(state.stacked_treaties.party.isin(parties))].index
            treaties = treaties.loc[ids]

        hnw.progress.value += 1

        if compute_co_occurance:

            treaty_tokens = compute_co_occurrance(treaties)

        else:

            treaty_tokens = state.treaty_headnote_corpus.treaty_tokens

            if remove_stopwords is True:
                treaty_tokens = treaty_tokens.loc[treaty_tokens.is_stopword==False]

            treaty_tokens = treaty_tokens.reset_index().set_index('treaty_id')

        hnw.progress.value += 1

        treaty_tokens = treaty_tokens\
            .merge(treaties, how='inner', left_index=True, right_index=True)\
            .drop(['sequence', 'is_cultural_yesno', 'source', 'signed', 'headnote', 'is_cultural',
                   'topic1', 'topic2', 'title'], axis=1)

        hnw.progress.value += 1

        token_or_lemma = 'token' if not use_lemma else 'lemma'

        groupbys  = []
        groupbys += [ period ] if not period is None else []
        groupbys += (extra_groupbys or [])
        groupbys += [ token_or_lemma ]

        result = treaty_tokens.groupby(groupbys).size().reset_index().rename(columns={0: 'Count'})

        hnw.progress.value += 1

        # Keep only the n_top most frequent words (or pairs) within each group
        group_keys = groupbys[:-1]
        if len(group_keys) > 0:
            result = result.sort_values('Count', ascending=False).groupby(group_keys).head(n_top)
        else:
            result = result.nlargest(n_top, 'Count')

        if min_word_size > 0:
            result = result.loc[result[token_or_lemma].str.len() >= min_word_size]

        if n_min_count > 1:
            result = result.loc[result.Count >= n_min_count]

        hnw.progress.value += 1

        result = result.sort_values(groupbys[:-1] + ['Count'], ascending=len(groupbys[:-1])*[True] + [False])

        hnw.progress.value += 1

        if output_format in ('table', 'qgrid'):
            result.columns = [ remove_snake_case(x) for x in result.columns ]
            if output_format == 'table':
                display(HTML(result.to_html()))
            else:
                qgrid_widget = qgrid.show_grid(result, show_toolbar=True)
                display(qgrid_widget)
        elif output_format == 'unstack':
            result = result.set_index(groupbys).unstack(level=0).fillna(0).astype('int32')
            result.columns = [ x[1] for x in result.columns ]
            display(HTML(result.to_html()))
        elif output_format == 'unstack_plot':
            result = result.set_index(list(reversed(groupbys))).unstack(level=0).fillna(0).astype('int32')
            result.columns = [ x[1] for x in result.columns ]
            result.plot(kind='bar', figsize=(16,8))

    except Exception as ex:
        logger.error(ex)
        
    hnw.progress.value += 1
    hnw.progress.value = 0

hnw = BaseWidgetUtility(
    period=widgets.Dropdown(
        options={
            '': None,
            'Year': 'signed_year',
            'Default division': 'signed_period',
            'Alt. division': 'signed_period_alt'
        },
        value='signed_period',
        description='Period:', **drop_style
    ),
    parties=widgets.Dropdown(
        options=default_party_options,
        value=None,
        description='Parties:', **drop_style
    ),
    use_lemma=widgets.ToggleButton(
        description='Use lemma', value=False,
        tooltip='Use WordNet lemma', **toggle_style
    ),
    remove_stopwords=widgets.ToggleButton(
        description='Remove stopwords', value=True,
        tooltip='Do not include stopwords', **toggle_style
    ),
    extra_groupbys=widgets.Dropdown(
        options={
            '': None,
            'Topic': [ 'Topic' ],
        },
        value=None,
        description='Groupbys:', **drop_style
    ),
    min_word_size=widgets.BoundedIntText(
        value=2, min=0, max=5, step=1,
        description='Min word:', layout=widgets.Layout(width='140px')
    ),
    only_is_cultural=widgets.ToggleButton(
        description='Only Cultural', value=True,
        tooltip='Display only "is_cultural" treaties', **toggle_style
    ),
    compute_co_occurance=widgets.ToggleButton(
        description='Cooccurrence', value=True,
        tooltip='Compute Cooccurrence', **toggle_style
    ),
    output_format=widgets.Dropdown(
        description='Output', value='table',
        options={
            'Table': 'table',
            'Qgrid': 'qgrid',
            'Unstack': 'unstack',
            'Unstack plot': 'unstack_plot'
        }, **drop_style
    ),
    plot_style=widgets.Dropdown(
        options=matplotlib_plot_styles,
        value='seaborn-pastel',
        description='Style:', **drop_style
    ),
    n_top=widgets.IntSlider(
        value=25, min=2, max=100, step=10,
        description='Top/grp #:', # continuous_update=False,
    ),
    n_min_count=widgets.IntSlider(
        value=2, min=1, max=10, step=1,
        tooltip='Filter out words with count less than specified value',
        description='Min count:', # continuous_update=False,
    ),
    progress=wf.create_int_progress_widget(min=0, max=10, step=1, value=0, layout=widgets.Layout(width='99%')),
)

ihnw = widgets.interactive(
    display_headnote_toplist,
    period=hnw.period,
    parties=hnw.parties,
    extra_groupbys=hnw.extra_groupbys,
    only_is_cultural=hnw.only_is_cultural,
    n_min_count=hnw.n_min_count,
    n_top=hnw.n_top,
    min_word_size=hnw.min_word_size,
    use_lemma=hnw.use_lemma,
    compute_co_occurance=hnw.compute_co_occurance,
    remove_stopwords=hnw.remove_stopwords,
    output_format=hnw.output_format,
    # plot_style=tw.plot_style
)

boxes = widgets.HBox(
    [
        widgets.VBox([ hnw.period, hnw.parties, hnw.min_word_size ]),
        widgets.VBox([ hnw.extra_groupbys, hnw.n_top, hnw.n_min_count]),
        widgets.VBox([ hnw.only_is_cultural, hnw.use_lemma, hnw.remove_stopwords, hnw.compute_co_occurance]),
        widgets.VBox([ hnw.output_format, hnw.progress ])
    ]
)
display(widgets.VBox([boxes, ihnw.children[-1]]))
ihnw.update()



###  <span style='color:blue'>**Mandatory Step**</span>: Prepare Treaty Text Corpora

This code cell is a mandatory step for subsequent text corpus statistics. 

This step processes the treaty texts from the given compressed archive (ZIP file), one language at a time, and stores each corpus in the efficient Matrix Market (MM) format. A corpus is only stored if it has not been stored previously, or if "Force Update" is selected. Note that an update MUST be forced whenever the treaty archive is updated; otherwise the text in the new archive is ignored.
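
The caching rule described above can be sketched as follows. This is a minimal illustration, not the notebook's code: the helper name `corpus_needs_update` and the file names are hypothetical. It also shows why a forced update is needed after the archive changes: the check only looks at whether the cache files exist, not at whether they are newer than the archive.

```python
import os
import tempfile

def corpus_needs_update(cache_files, force=False):
    """Return True if the cached corpus must be (re)built (hypothetical helper).

    A rebuild is needed when it is forced, or when any cache file is missing.
    """
    return force or not all(os.path.isfile(name) for name in cache_files)

with tempfile.TemporaryDirectory() as folder:
    # Assumption: illustrative cache file names, loosely modelled on the notebook's naming
    cache_files = [os.path.join(folder, name) for name in ('corpus_en.mm', 'corpus_en.dict.gz')]
    first_run = corpus_needs_update(cache_files)               # nothing cached yet
    for name in cache_files:
        open(name, 'w').close()                                 # simulate a stored corpus
    second_run = corpus_needs_update(cache_files)              # cache is complete
    forced_run = corpus_needs_update(cache_files, force=True)  # "Force Update" selected

print(first_run, second_run, forced_run)
```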

In [18]:
# Code

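# Return a new list with the items of x sorted by key function f
sort_chained = lambda x, f: sorted(x, key=f)
    
def ls_sorted(path):
    # List the files matching the glob pattern, oldest first (by modification time)
    return sort_chained(filter(os.path.isfile, glob.glob(path)), os.path.getmtime)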
       
class CompressedFileReader(object):

    def __init__(self, archive_pattern, filename_pattern='*.txt'):
        self.archive_pattern = archive_pattern
        self.filename_pattern = filename_pattern

    def __iter__(self):

        for zip_path in glob.glob(self.archive_pattern):
            with zipfile.ZipFile(zip_path) as zip_file:
                filenames = [ name for name in zip_file.namelist() if fnmatch.fnmatch(name, self.filename_pattern) ]
                for filename in filenames:
                    try:
                        with zip_file.open(filename, 'r') as text_file:
                            content = text_file.read()
                            content = gensim.utils.to_unicode(content, 'utf8', errors='ignore')
                            content = content.replace('-\r\n', '').replace('-\n', '')
                            yield os.path.basename(filename), content
                    except:
                        print('Unicode error: {}'.format(filename))
                        raise
                        
class TreatyCorpus(TextCorpus):

    def __init__(self, content_iterator, dictionary=None, metadata=False, character_filters=None,
                 tokenizer=None, token_filters=None, bigram_transform=False
    ):
        self.content_iterator = content_iterator
        
        token_filters = [
           (lambda tokens: [ x.lower() for x in tokens ]),
           (lambda tokens: [ x for x in tokens if any(map(lambda x: x.isalpha(), x)) ])
        ] + (token_filters or [])
        
        #if bigram_transform is True:
        #    train_corpus = TreatyCorpus(content_iterator, token_filters=[ x.lower() for x in tokens ])
        #    phrases = gensim.models.phrases.Phrases(train_corpus)
        #    bigram = gensim.models.phrases.Phraser(phrases)
        #    token_filters.append(
        #        lambda tokens: bigram[tokens]
        #    )           
        
        super(TreatyCorpus, self).__init__(
            input=True,
            dictionary=dictionary,
            metadata=metadata,
            character_filters=character_filters,
            tokenizer=tokenizer,
            token_filters=token_filters
        )
        
    def getstream(self):
        """Generate documents from the underlying plain text collection (of one or more files).
        Yields
        ------
        str
            Document read from plain-text file.
        Notes
        -----
        After generator end - initialize self.length attribute.
        """
        filenames = []
        num_texts = 0
        for filename, content in self.content_iterator:
            yield content
            filenames.append(filename)
            num_texts += 1
        self.length = num_texts
        self.filenames = filenames
        self.document_names = self._compile_document_names()
                 
    def get_texts(self):
        '''
        This is mandatory method from gensim.corpora.TextCorpus. Returns stream of documents.
        '''
        for document in self.getstream():
            yield self.preprocess_text(document)
            
    def preprocess_text(self, text):
            """Apply `self.character_filters`, `self.tokenizer`, `self.token_filters` to a single text document.
            Parameters
            ---------
            text : str
                Document read from plain-text file.
            Return
            ------
            list of str
                List of tokens extracted from `text`.
            """
            for character_filter in self.character_filters:
                text = character_filter(text)

            tokens = self.tokenizer(text)
            for token_filter in self.token_filters:
                tokens = token_filter(tokens)

            return tokens
        
    def _compile_document_names(self):
        
        document_names = pd.DataFrame(dict(
            document_name=self.filenames,
            treaty_id=[ x.split('_')[0] for x in self.filenames ]
        )).reset_index().rename(columns={'index': 'document_id'})
        
        document_names = document_names.set_index('document_id')   
        dupes = document_names.groupby('treaty_id').size().loc[lambda x: x > 1]
        
        if len(dupes) > 0:
            logger.critical('Warning! Duplicate treaties found in corpus: {}'.format(' '.join(list(dupes.index))))
            
        return document_names

class MmCorpusStatisticsService():
    
    def __init__(self, corpus, dictionary, language):
        self.corpus = corpus
        self.dictionary = dictionary
        self.stopwords = nltk.corpus.stopwords.words(language[1])
        _ = dictionary[0]
        
    def get_total_token_frequencies(self):
        # Sum each token's frequency over all documents in the corpus
        dictionary = self.corpus.dictionary
        freqencies = np.zeros(len(dictionary.id2token))
        for document in self.corpus:
            for i, f in document:
                freqencies[i] += f
        return freqencies

    def get_document_token_frequencies(self):
        '''
        Returns a DataFrame with per document token frequencies i.e. "melts" doc-term matrix
        '''
        # One (document_id, token_id, count) row per non-zero entry of the doc-term matrix
        data = ((document_id, x[0], x[1]) for document_id, values in enumerate(self.corpus) for x in values )
        df = pd.DataFrame(list(data), columns=['document_id', 'token_id', 'count'])
        df = df.merge(self.corpus.document_names, left_on='document_id', right_index=True)

        return df

    def compute_word_frequencies(self, remove_stopwords):
        id2token = self.dictionary.id2token
        term_freqencies = np.zeros(len(id2token))
        document_stats = []
        for document in self.corpus:
            for i, f in document:
                term_freqencies[i] += f
        stopwords = set(self.stopwords).intersection(set(id2token.values()))
        df = pd.DataFrame({
            'token_id': list(id2token.keys()),
            'token': list(id2token.values()),
            'frequency': term_freqencies,
            'dfs': [ self.dictionary.dfs.get(i, 0) for i in id2token.keys() ]
        })
        df['is_stopword'] = df.token.apply(lambda x: x in stopwords)
        if remove_stopwords is True:
            df = df.loc[(df.is_stopword==False)]
        df['frequency'] = df.frequency.astype(np.int64)
        df = df[['token_id', 'token', 'frequency', 'dfs', 'is_stopword']].sort_values('frequency', ascending=False)
        return df.set_index('token_id')

    def compute_document_stats(self):
        id2token = self.dictionary.id2token
        # Restrict the stopwords to the vocabulary and keep them in a set for fast lookups
        stopwords = set(self.stopwords).intersection(set(id2token.values()))
        document_names = self.corpus.document_names
        df = pd.DataFrame({
            'document_id': document_names.index,
            'document_name': document_names.document_name,
            'treaty_id': document_names.treaty_id,
            'size': [ sum(f for (_, f) in document) for document in self.corpus],
            'stopwords': [ sum(f for (i, f) in document if id2token[i] in stopwords) for document in self.corpus],
        }).set_index('document_name')
        df[['size', 'stopwords']] = df[['size', 'stopwords']].astype('int')
        return df

    def compute_word_stats(self):
        df = self.compute_document_stats()[['size', 'stopwords']]
        df_agg = df.agg(['count', 'mean', 'std', 'min', 'median', 'max', 'sum']).reset_index()
        legend_map = {
            'count': 'Documents',
            'mean': 'Mean words',
            'std': 'Std',
            'min': 'Min',
            'median': 'Median',
            'max': 'Max',
            'sum': 'Sum words'
        }
        df_agg['index'] = df_agg['index'].apply(lambda x: legend_map[x]).astype('str')
        df_agg = df_agg.set_index('index')
        df_agg[df_agg.columns] = df_agg[df_agg.columns].astype('int')
        return df_agg.reset_index()
    
#@staticmethod

class ExtMmCorpus(gensim.corpora.MmCorpus):
    """Extension of MmCorpus that allow TF normalization based on document length.
    """

    @staticmethod
    def norm_tf_by_D(doc):
        # D is the document length (total token count); return a list, not a lazy map object
        D = sum([x[1] for x in doc])
        return doc if D == 0 else [ (tf[0], tf[1]/D) for tf in doc ]

    def __init__(self, fname):
        gensim.corpora.MmCorpus.__init__(self, fname)
        
    def __iter__(self):
        for doc in gensim.corpora.MmCorpus.__iter__(self):
            yield self.norm_tf_by_D(doc)

    def __getitem__(self, docno):
        return self.norm_tf_by_D(gensim.corpora.MmCorpus.__getitem__(self, docno))

class TreatyCorpusSaveLoad():

    def __init__(self, source_folder, lang):
        
        self.mm_filename = os.path.join(source_folder, 'corpus_{}.mm'.format(lang))
        self.dict_filename = os.path.join(source_folder, 'corpus_{}.dict.gz'.format(lang))
        self.document_index = os.path.join(source_folder, 'corpus_{}_documents.csv'.format(lang))
        
    def store_as_mm_corpus(self, treaty_corpus):
        
        gensim.corpora.MmCorpus.serialize(self.mm_filename, treaty_corpus, id2word=treaty_corpus.dictionary.id2token)
        treaty_corpus.dictionary.save(self.dict_filename)
        treaty_corpus.document_names.to_csv(self.document_index, sep='\t')

    def load_mm_corpus(self, normalize_by_D=False):
    
        corpus_type = ExtMmCorpus if normalize_by_D else gensim.corpora.MmCorpus
        corpus = corpus_type(self.mm_filename)
        corpus.dictionary = gensim.corpora.Dictionary.load(self.dict_filename)
        corpus.document_names = pd.read_csv(self.document_index, sep='\t').set_index('document_id')  

        return corpus
    
    def exists(self):
        return os.path.isfile(self.mm_filename) and \
            os.path.isfile(self.dict_filename) and \
            os.path.isfile(self.document_index)

def store_mm_corpora(source_path, force, languages):
    
    try:
        print('Current archive:{}'.format(source_path))
        tokenizer = nltk.tokenize.word_tokenize
        source_folder = os.path.split(source_path)[0]
        for language in languages.split(','):
            loader = TreatyCorpusSaveLoad(source_folder, language)
            if not loader.exists() or force:
                print('Processing: {}'.format(language))
                stream = CompressedFileReader(source_path, filename_pattern='*_{}*.txt'.format(language))
                treaty_corpus = TreatyCorpus(stream, tokenizer=tokenizer)        
                loader.store_as_mm_corpus(treaty_corpus)
        print('Corpus is up-to-date!')
    except Exception as ex:
        logger.error(ex)
        
current_archives = (ls_sorted('./data/*.zip') or [])

cuw = BaseWidgetUtility(
    source_path=widgets.Dropdown(
        options=current_archives,
        value=current_archives[-1] if len(current_archives) else None,
        description='Corpus:' #, **drop_style
    ),
    force_corpus_update=widgets.ToggleButton(
        description='Force Update',
        tooltip='Force refresh saved corpus cache (a performance feature). Use when ZIP-archive has been updated.',
        value=False #, **toggle_style
    )
)

icuw = widgets.interactive(
    store_mm_corpora,
    source_path=cuw.source_path,
    force=cuw.force_corpus_update,
    languages='en,it,fr,de'
)

display(widgets.VBox([widgets.HBox([cuw.source_path, cuw.force_corpus_update]), icuw.children[-1]]))

icuw.update()



In [19]:
# Verify that all checked en/fr/de files in WTI exist in corpus



### Task: Basic Corpus Statistics

In [20]:
# Code 

corpus = None
def display_token_toplist(source_folder, language, statistics='', remove_stopwords=False):
    global tlw, corpus
    try:
        
        tlw.progress.value = 1

        corpus = TreatyCorpusSaveLoad(source_folder=source_folder, lang=language[0]).load_mm_corpus()

        tlw.progress.value = 2
        service = MmCorpusStatisticsService(corpus, dictionary=corpus.dictionary, language=language)

        print("Corpus consists of {} documents, {} words in total and a vocabulary size of {} tokens."\
                  .format(len(corpus), corpus.dictionary.num_pos, len(corpus.dictionary)))

        tlw.progress.value = 3
        if statistics == 'word_freqs':
            display(service.compute_word_frequencies(remove_stopwords))
        elif statistics == 'documents':
            display(service.compute_document_stats())
        elif statistics == 'word_count':
            display(service.compute_word_stats())
        else:
            print('Unknown: ' + statistics)
            
    except Exception as ex:
        logger.error(ex)
        
    tlw.progress.value = 5
    tlw.progress.value = 0
    
tlw = BaseWidgetUtility(
    language=widgets.Dropdown(
        options={
            'English': ('en', 'english'),
            'French': ('fr', 'french'),
            'German': ('de', 'german'),
            'Italian': ('it', 'italian')
        },
        value=('en', 'english'),
        description='Language:', **drop_style
    ),
    statistics=widgets.Dropdown(
        options={
            'Word freqs': 'word_freqs',
            'Documents': 'documents',
            'Word count': 'word_count'
        },
        value='word_count',
        description='Statistics:', **drop_style
    ),    
    remove_stopwords=widgets.ToggleButton(
        description='Remove stopwords', value=True,
        tooltip='Do not include stopwords in token toplist', **toggle_style
    ),    
    progress=wf.create_int_progress_widget(min=0, max=5, step=1, value=0) #, layout=widgets.Layout(width='100%')),
)

itlw = widgets.interactive(
    display_token_toplist,
    source_folder='./data',
    language=tlw.language,
    statistics=tlw.statistics,
    remove_stopwords=tlw.remove_stopwords
)

boxes = widgets.HBox(
    [
        tlw.language, tlw.statistics, tlw.remove_stopwords, tlw.progress
    ]
)
display(widgets.VBox([boxes, itlw.children[-1]]))
itlw.update()



### <span style='color: red'>WORK IN PROGRESS</span> Task: Network Visualization of Signed Treaties
<table>
    <tr><th>Layout algorithm</th><th>C</th><th>K</th><th>p</th><th>Note</th></tr>
    <tr><td>NetworkX ([fruchterman_reingold](https://networkx.github.io/documentation/networkx-1.11/reference/generated/networkx.drawing.layout.fruchterman_reingold_layout.html))</td><td></td><td>Optimal distance between nodes.</td><td></td><td>Fruchterman-Reingold force-directed algorithm.</td></tr>
    <tr><td>NetworkX ([Spectral](https://networkx.github.io/documentation/networkx-1.11/reference/generated/networkx.drawing.nx_pylab.draw_spectral.html#networkx.drawing.nx_pylab.draw_spectral))</td><td></td><td></td><td></td><td>Position nodes using the eigenvectors of the graph Laplacian.</td></tr>
    <tr><td>NetworkX (Circular)</td><td></td><td></td><td></td><td>Position nodes on a circle.</td></tr>
    <tr><td>NetworkX ([Shell](https://networkx.github.io/documentation/networkx-1.11/reference/generated/networkx.drawing.nx_pylab.draw_shell.html#networkx.drawing.nx_pylab.draw_shell))</td><td></td><td></td><td></td><td>Position nodes in a bipartite graph in concentric circles.</td></tr>
    <tr><td>NetworkX (Kamada-Kawai)</td><td></td><td></td><td></td><td></td></tr>
    <tr><td>graph-tool ([arf](https://graph-tool.skewed.de/static/doc/draw.html#graph_tool.draw.arf_layout))</td><td>Attracting force between adjacent vertices (a).</td><td>Opposing force between vertices (d).</td><td></td><td></td></tr>
    <tr><td>graph-tool ([sfdp](https://graph-tool.skewed.de/static/doc/draw.html#graph_tool.draw.sfdp_layout))</td><td>Relative strength of repulsive forces (C/100).</td><td>Optimal edge length.</td><td>Strength of the attractive force between connected components (gamma).</td><td></td></tr>
    <tr><td>graph-tool ([fruchterman_reingold](https://graph-tool.skewed.de/static/doc/draw.html#graph_tool.draw.fruchterman_reingold_layout))</td><td>Repulsive force between vertices (a=2*N*K).</td><td>Attracting force between adjacent vertices (r = 2\*C).</td><td></td><td>Fruchterman-Reingold force-directed algorithm.</td></tr>
    <tr><td>graphviz ([neato](https://graphviz.gitlab.io/_pages/pdf/neatoguide.pdf))</td><td></td><td>K=K</td><td></td><td></td></tr>
    <tr><td>graphviz ([dot](https://graphviz.gitlab.io/_pages/pdf/dotguide.pdf))</td><td></td><td>K=K</td><td></td><td></td></tr>
    <tr><td>graphviz (circo)</td><td></td><td>K=K</td><td></td><td></td></tr>
    <tr><td>graphviz (fdp)</td><td></td><td>K=K</td><td></td><td></td></tr>
    <tr><td>graphviz (sfdp)</td><td></td><td>K=K</td><td></td><td></td></tr>
</table>

*N = Number of nodes in graph*
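
As a rough illustration of what the K column controls in the force-directed rows, the following is a minimal pure-Python sketch of a Fruchterman-Reingold style iteration. It is not the NetworkX or graph-tool implementation: the function name `toy_spring_layout`, the cooling schedule, and the party labels are assumptions made for this example only. Nodes repel each other with force k²/d and connected nodes attract with force d²/k, so `k` acts as the preferred distance between neighbours.

```python
import math
import random

def toy_spring_layout(nodes, edges, k=0.1, iterations=50, seed=0):
    """Toy force-directed layout; k is the preferred distance between nodes."""
    nodes = list(nodes)
    rng = random.Random(seed)
    pos = {n: [rng.random(), rng.random()] for n in nodes}
    temperature = 0.1  # maximum step per iteration (assumed cooling schedule)
    for _ in range(iterations):
        disp = {n: [0.0, 0.0] for n in nodes}
        # Repulsion between every pair of nodes: k^2 / d
        for i, u in enumerate(nodes):
            for v in nodes[i + 1:]:
                dx, dy = pos[u][0] - pos[v][0], pos[u][1] - pos[v][1]
                d = max(math.hypot(dx, dy), 1e-9)
                force = k * k / d
                disp[u][0] += dx / d * force
                disp[u][1] += dy / d * force
                disp[v][0] -= dx / d * force
                disp[v][1] -= dy / d * force
        # Attraction along every edge (parallel edges count separately): d^2 / k
        for u, v in edges:
            dx, dy = pos[u][0] - pos[v][0], pos[u][1] - pos[v][1]
            d = max(math.hypot(dx, dy), 1e-9)
            force = d * d / k
            disp[u][0] -= dx / d * force
            disp[u][1] -= dy / d * force
            disp[v][0] += dx / d * force
            disp[v][1] += dy / d * force
        # Move each node along its displacement, capped by the temperature
        for n in nodes:
            length = max(math.hypot(disp[n][0], disp[n][1]), 1e-9)
            step = min(length, temperature)
            pos[n][0] += disp[n][0] / length * step
            pos[n][1] += disp[n][1] / length * step
        temperature *= 0.95
    return {n: (p[0], p[1]) for n, p in pos.items()}

# Hypothetical parties; the repeated pair stands for two treaties between the same states
nodes = ["FRANCE", "ITALY", "UK", "GERMANY"]
edges = [("FRANCE", "ITALY"), ("FRANCE", "UK"), ("ITALY", "GERMANY"), ("FRANCE", "ITALY")]
layout = toy_spring_layout(nodes, edges, k=0.3)
```

Because every parallel edge adds its own attraction in this sketch, parties that sign many treaties with each other are pulled closer together.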
        
- **TODO: Add "Recode 7CULT" button**<br>
- **TODO: Add category and year as edge labels**<br>
- **TODO: Split result into one graph per category**

In [22]:
# Visualize treaties
import bokeh.palettes as pals
import community
import types

%run ./common/network_utility
%run ./common/plot_utility
    
periods_division = [
    (1919, 1939), (1940, 1944), (1919, 1944), (1945, 1955), (1956, 1966), (1967, 1972)
    ]

label_text_opts=dict(
    x_offset=0, #y_offset=5,
    level='overlay',
    text_align='center',
    text_baseline='bottom',
    render_mode='canvas',
    text_font="Tahoma",
    text_font_size="9pt",
    text_color='black'
    )

network_plot_opts = dict(
    x_axis_type=None,
    y_axis_type=None,
    background_fill_color='white',
    line_opts=dict(color='green', alpha=0.5 ),
    node_opts=dict(color=None, level='overlay', alpha=1.0),
    )

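def network_edges_to_dicts(network, layout):
    # Build one dict per edge (endpoints, line coordinates, edge attributes) from the graph passed in
    LD = [ extend(dict(source=u,target=v,xs=[layout[u][0], layout[v][0]], ys=[layout[u][1], layout[v][1]]), d) for u, v, d in network.edges(data=True) ]
    # Sort by signing date so that positional slices correspond to time windows
    LD.sort(key=lambda x: x['signed'])
    # Transpose the list of dicts into a dict of columns
    edges = dict(zip(LD[0],zip(*[d.values() for d in LD])))
    return edges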

def pandas_to_network_edges(data):
    return [ (x[0], x[1], { y: x[j] for j, y in enumerate(data.columns)}) for i, x in data.iterrows() ]
    
def get_party_network_data(
    parties,
    period,
    only_is_cultural=True,
    party_name='party'
    ):
    
    global state
    
    data = state.stacked_treaties.copy()

    data = data.loc[(data.signed_period!='other')]

    if only_is_cultural:
        data = data.loc[(data.is_cultural==True)]

    if isinstance(parties, list):
        data = data.loc[(data.party.isin(parties))]
    else:
        data = data.loc[(data.reversed==False)]

    # Keep treaties signed within the selected (start_year, end_year) period
    data = data.loc[(data.signed_year.between(period[0], period[1]))]
    data = data.sort_values('signed')

    # data = data.groupby(['party', 'party_other']).size().reset_index().rename(columns={0: 'weight'})
    
    # Copy the column subset so the assignments below do not write to a view
    data = data[[ 'party', 'party_other', 'signed', 'topic', 'headnote']].copy()

    if party_name != 'party':
        for column in ['party', 'party_other']:
            data[column] = data[column].apply(lambda x: state.get_party_name(x, party_name))

    data['weight'] = 1.0
            
    return data

def create_party_network(data, K, node_partition, palette): #, multigraph=True):

    #if multigraph:
    
    edges_data = pandas_to_network_edges(data)

    G = nx.MultiGraph(K=K)
    G.add_edges_from(edges_data)
    #else:
    #    edges_data = [ tuple(x) for x in data.values ]
    #    print(edges_data)
    #    G = nx.Graph(K=K)
    #    G.add_weighted_edges_from(edges_data)

    if node_partition is not None:
        partition = community.best_partition(G)
        partition_color = { n: palette[p % len(palette)] for n, p in partition.items() }
        nx.set_node_attributes(G, partition, 'community')
        nx.set_node_attributes(G, partition_color, 'fill_color')
    else:
        #nx.set_node_attributes(G, 0, 'community')
        nx.set_node_attributes(G, palette[0], 'fill_color')

    nx.set_node_attributes(G, dict(G.degree()), name='degree')
    nx.set_node_attributes(G, dict(nx.betweenness_centrality(G, weight=None)), name='betweenness')
    nx.set_node_attributes(G, dict(nx.closeness_centrality(G)), name='closeness')
    
    # if not multigraph:
    #    nx.set_node_attributes(G, dict(nx.eigenvector_centrality(G, weight=None)), name='eigenvector')
        
    # nx.set_node_attributes(G, dict(nx.communicability_betweenness_centrality(G)), name='communicability_betweenness')
    
    return G

def setup_node_size(nodes, node_size, node_size_range):

    if node_size in nodes.keys() and node_size_range is not None:
        nodes['clamped_size'] = clamp_values(nodes[node_size], node_size_range)
        node_size = 'clamped_size'
    return node_size
    
def setup_label_y_offset(nodes, node_size):

    label_y_offset = 'y_offset' if node_size in nodes.keys() else node_size + 5
    if label_y_offset == 'y_offset':
        nodes['y_offset'] = [ y + r for (y, r) in zip(nodes['y'], [ r / 2.0 + 5 for r in nodes[node_size] ]) ]
    return label_y_offset

def plot_party_network(
    nodes,
    edges,
    node_description=None,
    node_size=5,
    node_opts=None,
    line_opts=None,
    text_opts=None,
    element_id='nx_id3',
    figsize=(900, 900),
    tools=None,
    **figkwargs
    ):
    
    edges_source = bm.ColumnDataSource(edges)
    nodes_source = bm.ColumnDataSource(nodes)

    node_opts = extend(DFLT_NODE_OPTS, node_opts or {})
    line_opts = extend(DFLT_EDGE_OPTS, line_opts or {})

    p = figure(plot_width=figsize[0], plot_height=figsize[1], tools=tools or TOOLS, **figkwargs)

    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    
    if 'line_color' in edges.keys():
        line_opts = extend(line_opts, { 'line_color': 'line_color', 'alpha': 1.0})

    r_lines = p.multi_line('xs', 'ys', line_width='weight', source=edges_source, **line_opts)
    r_nodes = p.circle('x', 'y', size=node_size, source=nodes_source, **node_opts)

    if 'fill_color' in nodes.keys():
        r_nodes.glyph.fill_color = 'fill_color'

    if node_description is not None:
        p.add_tools(bm.HoverTool(renderers=[r_nodes], tooltips=None, callback=WidgetUtility.\
            glyph_hover_callback(nodes_source, 'node_id', text_ids=node_description.index, \
                                 text=node_description, element_id=element_id))
        )

    label_opts = extend(DFLT_TEXT_OPTS, text_opts or {})

    p.add_layout(bm.LabelSet(source=nodes_source, **label_opts))

    # return p
    handle = bp.show(p, notebook_handle=True)
    return types.SimpleNamespace(
        handle=handle,
        edges_source=edges_source,
        nodes_source=nodes_source,
        nodes=nodes,
        edges=edges,
    )

handle_data = None

def display_party_network(
    parties,
    period,
    only_is_cultural=True,
    layout_algorithm='',
    C=1.0,
    K=0.10,
    p1=0.10,
    output='bokeh_plot',
    party_name='party',
    node_size_range=[40,60],
    refresh=False,
    palette_name=None,
    width=900,
    height=900,
    node_size=None,
    node_partition=None,
    weight_threshold=0.0,
    weight_scale=1.0,
    weight_normalize=True,
    category_map_name=None
    ):
    
    global state, zn, G, layout, handle_data
    
    try:
        #multigraph = False
        figsize=(width, height)
        if palette_name is None:
            palette = pals.RdYlBu[11]
        else:
            # Use the largest variant of the selected palette
            palette_id = max(pals.all_palettes[palette_name].keys())
            palette = pals.all_palettes[palette_name][palette_id]

        zn.refresh.value = False
        zn.progress.value = 1
        
        data = get_party_network_data(parties, period, only_is_cultural, party_name)
        
        if category_map_name is not None:
            category_map = category_group_maps[category_map_name]
            data = data.loc[(data.topic.isin(category_map.keys()))]
            data['category'] = data.apply(lambda x: category_map.get(x['topic'], 'OTHER'), axis=1)
            
            line_palette = pals.Set1[8]
            group_keys = category_group_settings[category_map_name].keys()
            line_palette_map = { k: i % len(line_palette) for i, k in enumerate(group_keys) }
            # print(line_palette_map)
            data['line_color'] = data.category.apply(lambda x: line_palette[line_palette_map[x]])

        else:
            data['category'] = data.topic
            
        zn.progress.value = 2
        
        #if not multigraph:
        #    data = data.groupby(['party', 'party_other']).size().reset_index().rename(columns={0: 'weight'})
        
        G = create_party_network(data, K, node_partition, palette) #, multigraph)
        
        zn.progress.value = 3
        
        if output == 'bokeh_plot':
            
            if weight_threshold > 0:
                G = filter_by_weight(G, weight_threshold)
            
            layout = layout_network(G, layout_algorithm, **dict(scale=1.0, K=K, C=C, p=p1))
            zn.progress.value = 4
            
            edges = network_edges_to_dicts(G, layout)
            nodes = NetworkUtility.get_node_attributes(G, layout)
            
            edges = { k: list(edges[k]) for k in edges }
            nodes = { k: list(nodes[k]) for k in nodes }
    
            node_size = node_size if not node_size is None else node_size_range[0]
            node_size = setup_node_size(nodes, node_size, node_size_range)
            
            y_offset = setup_label_y_offset(nodes, node_size)
            text_opts = extend(label_text_opts, dict(y_offset=y_offset, x_offset=0))
            zn.progress.value = 5 
        
            handle_data = plot_party_network(
                nodes=nodes,
                edges=edges,
                figsize=figsize,
                node_size=node_size,
                text_opts=text_opts,
                **network_plot_opts
            )
            
            zn.progress.value = 6
            zn.time_travel_range.max = len(edges['source'])
            zn.time_travel_range.value = [0, len(edges['source'])]
            
            #bp.show(p)

        elif output == 'table':
            display(data)
        else:
            display(pivot_ui(data))
    except Exception as ex:
        print(ex)
        raise
    finally:
        zn.progress.value = 0

zn = BaseWidgetUtility(
    period=widgets.Dropdown(
        options={
            '{} to {}'.format(x[0], x[1]): x for x in list(set(period_divisions[0] + period_divisions[1]))
        },
        value=period_divisions[0][0],
        description='Period:', layout=widgets.Layout(width='220px')
    ),
    category_map_name=widgets.Dropdown(
        options=category_group_maps.keys(),
        description='Category:', layout=widgets.Layout(width='300px')
    ),
    parties=widgets.Dropdown(
        description='Parties:',
        options=default_party_options,
        value=default_party_options['PartyOf5'],
        layout=widgets.Layout(width='220px')
    ),
    party_name=widgets.Dropdown(
        description='Name:',
        options={
            'WTI Code': 'party',
            'WTI Name': 'party_name',
            'WTI Short': 'short_name',
            'CC': 'country_code',
            'Country': 'party_country'
        },
        value='short_name',
        layout=widgets.Layout(width='220px')
    ),
    node_size=widgets.Dropdown(
        description='Node size:',
        options={
            '(default)': None,
            'Degree centrality': 'degree',
            'Closeness centrality': 'closeness',
            'Betweenness centrality': 'betweenness',
            'Eigenvector centrality': 'eigenvector'            
            #'communicability_betweenness': 'communicability_betweenness'
        },
        value=None,
        layout=widgets.Layout(width='220px')
    ),
    palette=widgets.Dropdown(
        description='Color:',
        options={
            palette_name: palette_name
                    for palette_name in pals.all_palettes.keys()
                        if any([ len(x) > 7 for x in pals.all_palettes[palette_name].values()])
        },
        layout=widgets.Layout(width='220px')
    ),
    C=widgets.IntSlider(
        description='C', min=0, max=100, step=1, value=1,
        continuous_update=False, layout=widgets.Layout(width='240px', height='30px')  # , orientation='vertical'
    ),
    K=widgets.FloatSlider(
        description='K', min=0.01, max=1.0, step=0.01, value=0.10,
        continuous_update=False, layout=widgets.Layout(width='240px', height='30px')  # , orientation='vertical'
    ),
    p=widgets.FloatSlider(
        description='p', min=0.01, max=2.0, step=0.01, value=1.10,
        continuous_update=False, layout=widgets.Layout(width='240px', height='30px')  # , orientation='vertical', 
    ),
    node_size_range=widgets.IntRangeSlider(
        description='Node size range',
        value=[20, 40], min=5, max=100, step=1,
        continuous_update=False,
        layout=widgets.Layout(width='240px', height='30px'),  # , orientation='vertical'
        style={'font-size': '9pt' }
    ),
    fig_width=widgets.IntSlider(
        description='Width', min=600, max=1600, step=100, value=1000,
        continuous_update=False, layout=widgets.Layout(width='240px', height='30px')  # , orientation='vertical'
    ),
    fig_height=widgets.IntSlider(
        description='Height', min=600, max=1600, step=100, value=700,
        continuous_update=False, layout=widgets.Layout(width='240px', height='30px')  # , orientation='vertical'
    ),
    only_is_cultural=widgets.ToggleButton(
        description='Only Cultural', value=True,
        tooltip='Display only "is_cultural" treaties', layout=widgets.Layout(width='100px')
    ),
    output=widgets.Dropdown(
        description='Output:',
        options={ 'Plot': 'bokeh_plot', 'List': 'table' }, # ,'Framework': 'framework_plot' },
        value='bokeh_plot',
        layout=widgets.Layout(width='220px')
    ),
    layout_algorithm=widgets.Dropdown(
        description='Layout',
        options=layout_function_name,
        value='graphtool_sfdp',
        layout=widgets.Layout(width='220px')
    ),
    progress=wf.create_int_progress_widget(min=0, max=6, step=1, value=0, layout=widgets.Layout(width="50%")),
    refresh=widgets.ToggleButton(
        description='Refresh', value=False,
        tooltip='Update plot', layout=widgets.Layout(width='100px')
    ),
    node_partition=widgets.Dropdown(
        description='Partition:',
        options={
            '(default)': None,
            'Louvain': 'louvain',
        },
        value=None,
        layout=widgets.Layout(width='220px')
    ),
    simple_mode=widgets.Checkbox(
        value=False,
        description='Simple',
        disabled=False,
        layout=widgets.Layout(width='150px')
    ),
    time_travel_range=widgets.IntRangeSlider(
        value=[0, 100],
        min=0,
        max=100,
        step=1,
        description='Time travel',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
        layout=widgets.Layout(width='80%')
    ),
    time_travel_label=widgets.Label(value="")
) 

def on_value_change(change):
    display_mode = 'none' if change['new'] == True else ''
    zn.node_partition.layout.display = display_mode
    zn.node_size.layout.display = display_mode
    zn.node_size_range.layout.display = display_mode
    zn.layout_algorithm.layout.display = display_mode
    zn.C.layout.display = display_mode
    zn.K.layout.display = display_mode
    zn.p.layout.display = display_mode
    zn.fig_width.layout.display = display_mode
    zn.fig_height.layout.display = display_mode
    zn.palette.layout.display = display_mode

zn.simple_mode.observe(on_value_change, names='value')
zn.simple_mode.value = True

wn = widgets.interactive(
    display_party_network,
    parties=zn.parties,
    period=zn.period,
    only_is_cultural=zn.only_is_cultural,
    layout_algorithm=zn.layout_algorithm,
    C=zn.C,
    K=zn.K,
    p1=zn.p,
    output=zn.output,
    party_name=zn.party_name,
    node_size_range=zn.node_size_range,
    refresh=zn.refresh,
    palette_name=zn.palette,
    width=zn.fig_width,
    height=zn.fig_height,
    node_size=zn.node_size,
    node_partition=zn.node_partition,
    category_map_name=zn.category_map_name
)

boxes = widgets.HBox([
    widgets.VBox([widgets.HBox([zn.parties, zn.only_is_cultural]), widgets.HBox([zn.period, zn.refresh]),
                  widgets.HBox([zn.category_map_name]),
                  widgets.HBox([zn.simple_mode, zn.progress])
                 ]),
    widgets.VBox([zn.layout_algorithm, zn.party_name, zn.output, zn.palette]),
    widgets.VBox([zn.K, zn.C, zn.p, zn.node_partition]),
    widgets.VBox([zn.fig_width, zn.fig_height, zn.node_size, zn.node_size_range]),
])

display(widgets.VBox([boxes, wn.children[-1]]))

wn.update()

#Code
def display_partial_party_network(edge_range=(0, 100)):
    global handle_data
    # Nothing to update until the network has been plotted at least once
    if handle_data is None:
        return
    start, stop = edge_range
    # Edges are sorted by signing date, so a positional slice acts as a time window
    edges = { k: handle_data.edges[k][start:stop] for k in handle_data.edges } 
    if len(edges['source']) == 0:
        return
    # Keep only the nodes that are endpoints of the selected edges
    nodes_indices = list(set(edges['source'] + edges['target']))
    df = pd.DataFrame({k:handle_data.nodes[k] for k in handle_data.nodes if isinstance(handle_data.nodes[k], list) }).set_index('id')
    nodes = df.loc[nodes_indices].reset_index().to_dict(orient='list')
    nodes = extend(nodes, {k:handle_data.nodes[k] for k in handle_data.nodes if not isinstance(handle_data.nodes[k], list) })
    handle_data.edges_source.data.update(edges)    
    handle_data.nodes_source.data.update(nodes)
    min_year, max_year = min(handle_data.edges_source.data['signed']).year, max(handle_data.edges_source.data['signed']).year
    zn.time_travel_range.description = '{}-{}'.format(min_year, max_year)
    bokeh.io.push_notebook(handle=handle_data.handle)
    
iw_time_travel = widgets.interactive(
    display_partial_party_network,
    edge_range=zn.time_travel_range
)
time_travel_box = widgets.VBox([widgets.VBox([zn.time_travel_label, zn.time_travel_range]), iw_time_travel.children[-1]])

display(time_travel_box)


### <span style='color: red'>WORK IN PROGRESS</span> Task: Treaty Keyword Extraction (using TF-IDF weighting)
- [ML Wiki.org](http://mlwiki.org/index.php/TF-IDF)
- [Wikipedia](https://en.wikipedia.org/wiki/Tf%E2%80%93idf)
- Spärck Jones, K. (1972). "A Statistical Interpretation of Term Specificity and Its Application in Retrieval".
- Manning, C.D.; Raghavan, P.; Schutze, H. (2008). "Scoring, term weighting, and the vector space model". ([PDF](http://nlp.stanford.edu/IR-book/pdf/06vect.pdf))
- https://markroxor.github.io/blog/tfidf-pivoted_norm/
$\frac{\text{tf-idf}}{\sqrt{\mathrm{rowSums}(\text{tf-idf}^2)}}$
- https://nlp.stanford.edu/IR-book/html/htmledition/pivoted-normalized-document-length-1.html

Neural Network Methods in Natural Language Processing, Yoav Goldberg:
![image.png](attachment:image.png)
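
To make the normalisation above concrete, here is a small pure-Python sketch of TF-IDF with row-wise L2 normalisation. It assumes raw term counts and a base-2 logarithm for the inverse document frequency, which as far as we know matches the default weighting of gensim's `TfidfModel`. The function name `tfidf_l2` and the sample headnotes are invented for illustration.

```python
import math
from collections import Counter

def tfidf_l2(documents):
    """Return one {token: weight} dict per tokenised document."""
    n_docs = len(documents)
    # Document frequency: in how many documents each token occurs
    doc_freq = Counter(token for doc in documents for token in set(doc))
    rows = []
    for doc in documents:
        counts = Counter(doc)
        # Raw term frequency times log2 inverse document frequency
        weights = {t: tf * math.log2(n_docs / doc_freq[t]) for t, tf in counts.items()}
        # Divide each row by its Euclidean length (the square root of the row sum of squares)
        norm = math.sqrt(sum(w * w for w in weights.values()))
        rows.append({t: (w / norm if norm > 0 else 0.0) for t, w in weights.items()})
    return rows

# Invented headnotes, already lower-cased
docs = [
    "cultural agreement on exchange".split(),
    "agreement on trade".split(),
    "cultural exchange of students".split(),
]
rows = tfidf_l2(docs)
```

After normalisation every non-empty row has unit length, so scores are comparable between short and long treaties. Tokens that occur in every document receive a weight of zero.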

In [27]:
# Code
from scipy.sparse import csr_matrix

    
def get_top_tfidf_words(data, n_top=5):
    top_list = data.groupby(['treaty_id'])\
        .apply(lambda x: x.nlargest(n_top, 'score'))\
        .reset_index(level=0, drop=True)
    return top_list

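def compute_tfidf_scores(corpus, dictionary, smartirs='ntc'):
    # Note: `smartirs` is currently unused; the model runs with gensim's default weighting
    #model = gensim.models.logentropy_model.LogEntropyModel(corpus, normalize=True)
    model = gensim.models.tfidfmodel.TfidfModel(corpus, dictionary=dictionary, normalize=True) #, smartirs=smartirs)
    rows, cols, scores = [], [], []
    for r, document in enumerate(corpus): 
        vector = model[document]
        if not vector:
            # Skip empty documents; zip(*[]) cannot be unpacked into two names
            continue
        c, v = zip(*vector)
        rows += (len(c) * [ int(r) ])
        cols += c
        scores += v
        
    # Sparse document-term matrix: one row per document, one column per token id
    return csr_matrix((scores, (rows, cols)))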
    
if True: #'tfidf_cache' not in globals():
    tfidf_cache = {
    }
    
def display_tfidf_scores(source_folder, language, period, n_top=5, threshold=0.001):
    
    global state, tfw, tfidf_cache
    
    try:
        treaties = state.treaties

        tfw.progress.value = 0
        tfw.progress.value += 1
        if language[0] not in tfidf_cache.keys():
            corpus = TreatyCorpusSaveLoad(source_folder=source_folder, lang=language[0])\
                .load_mm_corpus(normalize_by_D=True)
            document_names = corpus.document_names
            dictionary = corpus.dictionary
            _ = dictionary[0]

            tfw.progress.value += 1
            A = compute_tfidf_scores(corpus, dictionary)

            tfw.progress.value += 1
            scores = pd.DataFrame(
                [ (i, j, dictionary.id2token[j], A[i, j]) for i, j in zip(*A.nonzero())],
                columns=['document_id', 'token_id', 'token', 'score']
            )
            tfw.progress.value += 1
            scores = scores.merge(document_names, how='inner', left_on='document_id', right_index=True)\
                .drop(['document_id', 'token_id', 'document_name'], axis=1)

            scores = scores[['treaty_id', 'token', 'score']]\
                .sort_values(['treaty_id', 'score'], ascending=[True, False])

            tfidf_cache[language[0]] = scores

        scores = tfidf_cache[language[0]]
        if threshold > 0:
            scores = scores.loc[scores.score >= threshold]

        tfw.progress.value += 1

        #scores = get_top_tfidf_words(scores, n_top=5)
        #scores = scores.groupby(['treaty_id']).sum() 

        scores = scores.groupby(['treaty_id'])\
            .apply(lambda x: x.nlargest(n_top, 'score'))\
            .reset_index(level=0, drop=True)\
            .set_index('treaty_id')

        if period is not None:
            periods = state.treaties[period]
            scores = scores.merge(periods.to_frame(), left_index=True, right_index=True, how='inner')\
                .groupby([period, 'token']).score.mean()\
                .reset_index() #.sort_values('token')

        #['token'].apply(' '.join)

        print(scores)
    except Exception as ex:
        logger.error(ex)
        raise
        
    tfw.progress.value = 0

#if 'tfidf_scores' not in globals():
#    tfidf_scores = compute_document_tfidf(corpus, corpus.dictionary, state.treaties)
#    tfidf_scores = tfidf_scores.sort_values(['treaty_id', 'score'], ascending=[True, False])

tfw = BaseWidgetUtility(
    language=widgets.Dropdown(
        options={
            'English': ('en', 'english'),
            'French': ('fr', 'french'),
            'German': ('de', 'german'),
            'Italian': ('it', 'italian')
        },
        value=('en', 'english'),
        description='Language:', **drop_style
    ),
    remove_stopwords=widgets.ToggleButton(
        description='Remove stopwords', value=True,
        tooltip='Do not include stopwords in token toplist', **toggle_style
    ),    
    n_top=widgets.IntSlider(
        value=5, min=1, max=25, step=1,
        description='Top #:',
        continuous_update=False
    ),
    threshold=widgets.FloatSlider(
        value=0.001, min=0.0, max=0.5, step=0.01,
        description='Threshold:',
        tooltip='Words with a TF-IDF score below this value are filtered out',
        continuous_update=False,
        readout_format='.3f',
    ), 
    period=widgets.Dropdown(
        options={
            '': None,
            'Year': 'signed_year',
            'Default division': 'signed_period',
            'Alt. division': 'signed_period_alt'
        },
        value='signed_period',
        description='Period:', **drop_style
    ),
    output=widgets.Dropdown(
        options={
            '': None,
            'Year': 'signed_year',
            'Default division': 'signed_period',
            'Alt. division': 'signed_period_alt'
        },
        value='signed_period',
        description='Output:', **drop_style
    ),
    progress=widgets.IntProgress(min=0, max=5, step=1, value=0) #, layout=widgets.Layout(width='100%')),
)

itfw = widgets.interactive(
    display_tfidf_scores,
    source_folder='./data',
    language=tfw.language,
    n_top=tfw.n_top,
    threshold=tfw.threshold,
    period=tfw.period
)

boxes = widgets.HBox(
    [
        widgets.VBox([tfw.language, tfw.period]),
        widgets.VBox([tfw.n_top, tfw.threshold]),
        widgets.VBox([tfw.progress, tfw.output])
    ]
)

display(widgets.VBox([boxes, itfw.children[-1]]))
itfw.update()



### <span style='color:red'>IGNORE EVERYTHING BELOW</span>

## Network info and statistics

__START HERE__

[set_node_attributes](https://networkx.github.io/documentation/stable/reference/generated/networkx.classes.function.set_node_attributes.html#networkx.classes.function.set_node_attributes)


In [None]:
#Code
def display_partial_party_network(position=1):
    global handle_data
    edge_count = len(handle_data.edges['source'])
    edges = { k: handle_data.edges[k][:position] for k in handle_data.edges } 
    nodes_indices = set(edges['source'] + edges['target'])
    df = pd.DataFrame({k:handle_data.nodes[k] for k in handle_data.nodes if isinstance(handle_data.nodes[k], list) }).set_index('id')
    nodes = df.loc[nodes_indices].reset_index().to_dict(orient='list')
    nodes = extend(nodes, {k:handle_data.nodes[k] for k in handle_data.nodes if not isinstance(handle_data.nodes[k], list) })
    handle_data.edges_source.data.update(edges)
    handle_data.nodes_source.data.update(nodes)
    bokeh.io.push_notebook(handle=handle_data.handle)
    
iw_time_travel = widgets.interactive(
    display_partial_party_network,
    position=zn.time_travel_position
)
time_travel_box = widgets.VBox([zn.time_travel_position, iw_time_travel.children[-1]])

display(time_travel_box)

In [None]:
# Code
import nltk
from nltk.stem import WordNetLemmatizer

class CoOccurrance():

    def __init__(self, tokenizer, stopwords=None, lemmatizer=None, min_word_size=2):
        
        self.transforms = [
            tokenizer,
            lambda ws: ( x for x in ws if len(x) >= min_word_size ),
            lambda ws: ( x for x in ws if any(ch.isalpha() for ch in x)) 
        ]
        
        if stopwords is not None:
            self.transforms += [ lambda ws: ( x for x in ws if x not in stopwords ) ]
            
        if lemmatizer is not None:
            self.transforms += [ lambda ws: ( lemmatizer(x) for x in ws ) ]

    def _apply_transforms(self, ws):
        for f in self.transforms:
            ws = f(ws)
        return list(ws)
    
    def compute(self, headnotes):
        
        texts = [ x.lower() for x in list(headnotes) ]
        tokens = list(map(self._apply_transforms, texts))
        df = pd.DataFrame({'headnote': headnotes, 'tokens': tokens })
        
        df_stacked = pd.DataFrame(df.tokens.tolist(), index=df.index).stack()\
            .reset_index().rename(columns={'level_1': 'sequence_id', 0: 'token'})
            
        return df_stacked
    
    def compute_co_occurrence(self, treaties, pos_tags, only_cultural_treaties=False):

        # Keep only tags that belong to the treaties of interest
        df_pos_tags = pos_tags.merge(treaties, how='inner', left_on='treaty_id', right_index=True)
        
        if only_cultural_treaties:
            # is_cultural is treated as a boolean flag, as elsewhere in this notebook
            df_pos_tags = df_pos_tags[(df_pos_tags.is_cultural == True)]

        # Self join of words within same treaty
        df_co_occurrence = pd.merge(df_pos_tags, df_pos_tags, how='inner', left_on='treaty_id', right_on='treaty_id')
        # Only consider a specific pair once
        df_co_occurrence = df_co_occurrence[(df_co_occurrence.wid_x < df_co_occurrence.wid_y)]
        # Reduce number of returned columns
        df_co_occurrence = df_co_occurrence[['treaty_id', 'year_x', 'is_cultural_x', 'lemma_x', 'lemma_y' ]]
        # Rename columns
        df_co_occurrence.columns = ['treaty_id', 'year', 'is_cultural', 'lemma_x', 'lemma_y' ]

        # Sort token pair so smallest always comes first
        lemma_x = df_co_occurrence[['lemma_x', 'lemma_y']].min(axis=1)
        lemma_y = df_co_occurrence[['lemma_x', 'lemma_y']].max(axis=1)
        df_co_occurrence['lemma_x'] = lemma_x
        df_co_occurrence['lemma_y'] = lemma_y

        return df_co_occurrence
    
def create_bigram_transformer(documents):
    import gensim.models.phrases
    bigram = gensim.models.phrases.Phrases(map(nltk.tokenize.word_tokenize, documents))
    return lambda ws: bigram[ws]

treaties = state.treaties.loc[(state.treaties.is_cultural)]
headnotes = treaties['headnote']
stopwords = nltk.corpus.stopwords.words('english')
tokenizer = nltk.tokenize.word_tokenize
lemmatizer = WordNetLemmatizer().lemmatize
df = CoOccurrance(tokenizer=tokenizer, stopwords=stopwords, lemmatizer=lemmatizer, min_word_size=2).compute(headnotes)
df.head()