In [1]:
# -*- coding: utf8

import sys
# Make the project's helper modules importable (presumably where amutils and
# disrupt live). The path is relative to the notebook's working directory.
# Guarded so that re-running this cell does not append duplicate entries.
if '../code/' not in sys.path:
    sys.path.append('../code/')

In [2]:
from amutils import build_graph
from amutils import build_reverse_index
from amutils import load_am_json_data

from disrupt import compute_disruption

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
# Global plot defaults: large figures, 16pt labels/titles/legends/ticks and
# thicker lines. A later cell switches style and resets figure.figsize to (12, 8).
plt.rcParams.update({
    'figure.figsize': (18, 10),
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'lines.linewidth': 2,
})

In [4]:
plt.ion()

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names. Prefer the new name when it is available and
# fall back to the legacy name on older Matplotlib.
_COLORBLIND_STYLE = ('seaborn-v0_8-colorblind'
                     if 'seaborn-v0_8-colorblind' in plt.style.available
                     else 'seaborn-colorblind')
plt.style.use(_COLORBLIND_STYLE)
plt.rcParams['figure.figsize'] = (12, 8)

In [5]:
def despine(ax=None):
    """Strip an Axes down to its bottom spine and x-axis.

    Hides the right, top and left spines, pins tick marks to the left/bottom
    positions, and then hides the y-axis entirely (ticks and labels).

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to modify; defaults to the current Axes (``plt.gca()``).
    """
    target = plt.gca() if ax is None else ax

    for side in ('right', 'top', 'left'):
        target.spines[side].set_visible(False)

    target.yaxis.set_ticks_position('left')
    target.xaxis.set_ticks_position('bottom')
    target.axes.get_yaxis().set_visible(False)

In [6]:
# Raw records keyed by id; later cells read `json_data[id_]['name']`.
json_data = load_am_json_data()

# Reverse indices from a decade / genre / style key to the ids of matching
# records (the values are iterables of ids used to filter the graph).
(decades,
 genres,
 styles) = build_reverse_index(json_data)

In [7]:
def _select_ids(decade=None, genre=None, style=None):
    """Intersect the reverse-index id lists for every filter that is set.

    Returns None when no filter is given (meaning "no restriction"); otherwise
    returns a set of ids, which may be empty when the filters share no entries.
    """
    selected = None
    for key, index in ((decade, decades), (genre, genres), (style, styles)):
        if key is None:
            continue
        ids = set(index[key])
        selected = ids if selected is None else selected & ids
    return selected


def rank_nodes(decade=None, genre=None, style=None,
               min_in=1, min_out=0, restrictive=False,
               n_samples=10000, seed=None):
    """Compute disruption scores and posterior disruption samples.

    Parameters
    ----------
    decade, genre, style : optional
        Keys into the ``decades`` / ``genres`` / ``styles`` reverse indices.
        All filters that are not None are intersected. With no filters,
        ``None`` is passed to ``build_graph`` and no post-filtering happens.
    min_in, min_out : int
        Forwarded unchanged to ``compute_disruption``.
    restrictive : bool
        Forwarded unchanged to ``build_graph``.
    n_samples : int
        Number of Dirichlet draws per node (default 10000, the value this
        function originally hard-coded).
    seed : int, optional
        Seed for reproducible posteriors. When None, draws come from the
        global ``np.random`` state, as before.

    Returns
    -------
    disrupt : pandas.DataFrame
        ``compute_disruption`` output with NaN rows dropped and a ``name``
        column added, restricted to the selected ids.
    posteriors : pandas.DataFrame
        One row per node in ``disrupt``. It holds ``n_samples`` draws of
        ``p_i - p_j`` under a Dirichlet(1 + n_i, 1 + n_j, 1 + n_k) posterior,
        plus a ``name`` column.
    """
    to_use = _select_ids(decade, genre, style)

    print('Computing disruption!')
    G = build_graph(json_data, to_use, restrictive=restrictive)
    disrupt = compute_disruption(G, min_in, min_out).dropna()

    sampler = np.random if seed is None else np.random.default_rng(seed)
    diffs = []
    # NOTE(review): assumes compute_disruption returns exactly three count
    # columns followed by the disruption score as the last column — confirm.
    for ni, nj, nk in disrupt.values[:, :-1]:
        # Flat Dirichlet(1, 1, 1) prior updated with the observed counts.
        draws = sampler.dirichlet([1 + ni, 1 + nj, 1 + nk], size=n_samples)
        diffs.append(draws[:, 0] - draws[:, 1])

    disrupt['name'] = [json_data[id_]['name'] for id_ in disrupt.index]
    posteriors = pd.DataFrame(diffs, index=disrupt.index)
    posteriors['name'] = disrupt['name']

    # The graph can presumably contain nodes outside the selection, so trim
    # back to it. Compare against None: an empty selection must yield empty
    # results, not silently fall back to every node.
    if to_use is not None:
        disrupt = disrupt.loc[disrupt.index.isin(to_use)]
        posteriors = posteriors.loc[posteriors.index.isin(to_use)]

    return disrupt, posteriors

In [8]:
from ipywidgets import interact_manual

decade_options = [None] + sorted(decades.keys())
genre_options = [None] + sorted(genres.keys())
style_options = [None] + sorted(styles.keys())

# Number of most positively / most negatively disruptive entries plotted.
TOP_N = 10


def _legend_handles(legend):
    """Return ``legend``'s handle artists on any Matplotlib version.

    ``Legend.legendHandles`` was renamed to ``legend_handles`` in Matplotlib
    3.7, and the old name was later removed.
    """
    handles = getattr(legend, 'legend_handles', None)
    return legend.legendHandles if handles is None else handles


def _plot_posteriors(posterior_subset, label):
    """Draw one KDE panel per row of ``posterior_subset`` and show the figure.

    ``posterior_subset`` is a slice of the posteriors frame returned by
    ``rank_nodes``: sample columns plus a ``name`` column, which is used as
    each panel's legend text. An empty slice prints a note instead, because
    pandas cannot plot a frame with no columns.
    """
    if posterior_subset.empty:
        print('No entries with {} disruption to plot.'.format(label))
        return

    names = posterior_subset['name']
    samples = posterior_subset.drop(columns='name')
    # One column per entry so that subplots=True yields one panel each.
    to_plot = pd.DataFrame(samples.T.values, columns=names)

    axes = to_plot.plot.kde(subplots=True, color='magenta')
    for ax in axes:
        despine(ax)
        legend = ax.legend(loc='upper left', frameon=False)
        # Keep the name text but hide the line sample beside it.
        for handle in _legend_handles(legend):
            handle.set_visible(False)
        ax.set_xlim((-1, 1))
        ax.set_xlabel('Posterior Disruption')
    plt.tight_layout()
    plt.show()
    # Close explicitly so repeated widget runs don't accumulate open figures.
    plt.close()


@interact_manual
def interactive_rank(initial_decade=decade_options,
                     genre=genre_options,
                     style=style_options,
                     min_in=[1, 2, 3, 4, 5],
                     min_out=[0, 1, 2, 3, 4, 5],
                     restrictive=False):
    """Rank entries for the chosen filters and plot their posterior disruption.

    The default values are widget specifications for ``interact_manual``
    (lists become dropdowns, the bool becomes a checkbox). The function draws
    two figures: the TOP_N highest positive scores and the TOP_N lowest
    negative scores.
    """
    disruption, posterior = rank_nodes(initial_decade, genre, style,
                                       min_in, min_out, restrictive)
    top = (disruption[disruption['disruption'] > 0]
           .nlargest(TOP_N, 'disruption'))
    bottom = (disruption[disruption['disruption'] < 0]
              .nsmallest(TOP_N, 'disruption'))

    _plot_posteriors(posterior.loc[top.index], 'positive')
    _plot_posteriors(posterior.loc[bottom.index], 'negative')

interactive(children=(Dropdown(description='initial_decade', options=(None, 1890, 1900, 1910, 1920, 1930, 1940…