Networks and Word Vectors with MeSH Labels
==========================================

In [None]:
%load_ext line_profiler
%load_ext memory_profiler
%load_ext autoreload
%autoreload 2

In [None]:
import os
import ast
import json
import itertools

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from collections import defaultdict, Counter
from datetime import datetime
from itertools import zip_longest
from matplotlib.ticker import NullFormatter

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource, HoverTool, LinearColorMapper, Range1d
from bokeh.palettes import viridis

from analysis.src.data.readnwrite import get_data_dir
from analysis.src.data.data_utilities import flatten, eval_column, grouper

pd.options.display.max_columns = 99
output_notebook()

In [None]:
from rhodonite.dynamics import PhylomemeticGraph
from rhodonite.graphs import SlidingWindowGraph
from rhodonite.spectral import association_strength

In [None]:
from gensim.corpora import Dictionary

In [None]:
from graph_tool.generation import price_network
from graph_tool.draw import graph_draw
from graph_tool.all import GraphView, Graph

In [None]:
%matplotlib inline

# Paths
# Root data directory for the project
data_path = get_data_dir()

# External data
ext_data = os.path.join(data_path, 'external')
# Raw data
raw_data = os.path.join(data_path, 'raw')
# Processed data
proc_data = os.path.join(data_path, 'processed')
# Interim data
inter_data = os.path.join(data_path, 'interim')
# Figures
fig_path = os.path.join(data_path, 'figures')

# Current UTC date, used to stamp saved files
today = datetime.utcnow()

# Formatted as "<year>_<month>_<day>"; month and day are not zero-padded
today_str = "_".join([str(x) for x in [today.year,today.month,today.day]])

## 1. Load Data

We are going to load both the GDB and the RWJF Pioneer and Global projects, and join them into a single dataframe.

In [None]:
# GDB projects, read from the raw data directory
gdb_df = pd.read_csv(os.path.join(raw_data, 'gdb.csv'))

In [None]:
# RWJF Pioneer and Global projects, read from the interim data directory
rwjf_df = pd.read_csv(os.path.join(inter_data, 'rwjf_pioneer_and_global_projects.csv'))

Now we need to join the other relevant data modules:

Dates for GDB:

In [None]:
gdb_dates_df = pd.read_csv(os.path.join(inter_data, 'gdb_dates.csv'))
# Column-wise concat aligns rows on the index, so this assumes gdb_dates.csv
# holds the same rows in the same order as gdb.csv — TODO confirm
gdb_df = pd.concat([gdb_df, gdb_dates_df], axis=1)

MeSH labels:

In [None]:
gdb_mesh_df = pd.read_csv(os.path.join(inter_data, 'gdb_mesh_labels.csv'))
rwjf_mesh_df = pd.read_csv(os.path.join(inter_data, 'rwjf_mesh_labels.csv'))

# Column-wise concat aligns rows on the index, so each label file is assumed
# to be row-aligned with its project table — TODO confirm
gdb_df = pd.concat([gdb_df, gdb_mesh_df], axis=1)
rwjf_df = pd.concat([rwjf_df, rwjf_mesh_df], axis=1)

We're going to remove projects from GitHub as they don't play nicely with MeSH terms, and Crunchbase as they're very short. There are also some projects with null descriptions.

In [None]:
# Drop GitHub and Crunchbase projects in one pass. The .copy() makes the
# result an independent frame rather than a possible view of the original.
gdb_df = gdb_df[~gdb_df['source_id'].isin(['GitHub', 'Crunchbase'])].copy()
# Replace null descriptions with empty strings. The previous chained indexing
# (df['description'][mask] = '') may write to a temporary copy and raises
# SettingWithCopyWarning on a filtered frame.
gdb_df['description'] = gdb_df['description'].fillna('')

Let's concatenate the two sets of projects and extract their descriptions

In [None]:
# Stack the RWJF projects beneath the GDB projects, index the result by
# document ID and keep the first project for each distinct description.
gdb_df = (
    pd.concat([gdb_df, rwjf_df], axis=0)
    .set_index('doc_id')
    .drop_duplicates(subset='description')
)

In [None]:
# Description values as a plain list, in gdb_df row order
descriptions = list(gdb_df['description'].values)

## Building a MeSH Label Corpus

We need to build a corpus of MeSH label transformed documents that is appropriate for the network we want to build. This will require some filtering, however first we should build a vocabulary of all the terms that we have, so that we can reference any of them by a unique ID at any time.

In [None]:
# eval_column is a project helper; presumably it parses the stringified
# 'mesh_labels' column into one list of labels per project — TODO confirm
description_mesh_labels = eval_column(gdb_df, 'mesh_labels')

For filtering later, we will calculate the counts of the MeSH labels. We know already that there are some labels which are highly over-represented, and many which occur only once in the data.

In [None]:
def frequency_filter(docs, high_threshold=None, low_threshold=None, remove=None, counter=None):
    """frequency_filter
    Filters tokens from a corpus that occur more frequently than
    high_threshold or less frequently than low_threshold, together with any
    tokens explicitly listed in remove.

    Args:
        docs (:obj:`list` of :obj:`list`): Corpus of tokenised documents.
        high_threshold (int): Upper limit for token frequency. Tokens whose
            count is strictly greater than this are dropped. Ignored if None.
        low_threshold (int): Lower limit for token frequency. Tokens whose
            count is strictly less than this are dropped. Ignored if None.
        remove (:obj:`iter`): Tokens to drop regardless of their frequency.
            Defaults to dropping nothing.
        counter (:obj:`Counter`): Precomputed token counts. If None, the
            counts are computed from docs, but only when a threshold is set.

    Returns:
        docs_filtered (:obj:`list` of :obj:`list`): One filtered document per
            input document, in the same order. Documents may end up empty.
    """
    # A set gives constant-time membership tests; the corpus can be large.
    remove = frozenset(remove) if remove is not None else frozenset()

    # Counts are only needed when at least one threshold is active.
    needs_counts = high_threshold is not None or low_threshold is not None
    if counter is None and needs_counts:
        counter = Counter(flatten(docs))

    def keep(token):
        """Return True if the token survives every active filter."""
        if token in remove:
            return False
        if high_threshold is not None and counter[token] > high_threshold:
            return False
        if low_threshold is not None and counter[token] < low_threshold:
            return False
        return True

    return [[t for t in doc if keep(t)] for doc in docs]

def filter_description_labels(description_labels, fn):
    """filter_description_labels
    Applies a predicate to every label of every description.

    Args:
        description_labels (:obj:`iter` of :obj:`iter`): One collection of
            labels per description.
        fn (:obj:`function`): Predicate passed to the built-in filter; a
            label is kept when fn(label) is truthy.

    Returns:
        (:obj:`list` of :obj:`list`): The kept labels for each description,
            in the original order.
    """
    kept_per_description = []
    for labels in description_labels:
        kept_per_description.append(list(filter(fn, labels)))
    return kept_per_description

In [None]:
# Occurrence count of every label produced by flattening the per-project lists
mesh_label_counts = Counter(flatten(description_mesh_labels))
# Display the 20 most frequent labels
mesh_label_counts.most_common(20)

In [None]:
# Labels that are dropped outright, irrespective of their frequency.
# ('Research' appeared twice in the original list; the duplicate is removed.)
mesh_labels_to_remove = [
    'Students', 'Humans', 'Animals', 'Research', 'Goals',
    'Universities', 'Research Personnel', 'United States',
    'United Kingdom', 'Awards and Prizes',
    'Faculty', 'Mice', 'Mathematics', 'Fellowships and Scholarships',
    'Surveys and Questionnaires',
]

# Drop labels seen more than 40000 times or fewer than 5 times. The counts
# from the previous cell are reused so the corpus is not counted twice.
description_mesh_labels_filtered = frequency_filter(
    description_mesh_labels,
    high_threshold=40000,
    low_threshold=5,
    remove=mesh_labels_to_remove,
    counter=mesh_label_counts,
)

In [None]:
from gensim.models.phrases import Phrases, Phraser

In [None]:
# Fit a phrase model on the filtered label lists (min_count=3), then wrap it
# in a Phraser for applying the learned merges
bigrams = Phrases(description_mesh_labels_filtered, min_count=3)
bigrammer = Phraser(bigrams)

In [None]:
# Apply the bigram phraser to every project's label list
description_mesh_labels_bigrams = [bigrammer[d] for d in description_mesh_labels_filtered]

In [None]:
# Second phrase pass over the bigram output (library default min_count, unlike
# the explicit min_count=3 used for the bigram pass)
trigrams = Phrases(description_mesh_labels_bigrams)
trigrammer = Phraser(trigrams)

In [None]:
# Apply the second-pass phraser to every project's bigram-merged label list
description_mesh_labels_trigrams = [trigrammer[d] for d in description_mesh_labels_bigrams]

In [None]:
# Any token containing '_' is split on it, its parts are de-duplicated and
# sorted, and the result is re-joined with spaces, so the same combination of
# parts always produces the same string. Other tokens pass through unchanged.
description_mesh_labels_final = [
    [' '.join(sorted(set(t.split('_')))) if '_' in t else t for t in d]
    for d in description_mesh_labels_trigrams
]

In [None]:
# gensim Dictionary built from the final label lists; later cells use its
# token2id mapping and index it by ID to recover the label text
dictionary_mesh_labels = Dictionary(description_mesh_labels_final)

## Filtering Descriptions

In [None]:
# Attach the final label list to each project. This relies on the list being
# in gdb_df row order, which holds because it was derived from gdb_df itself.
gdb_df['coocurrence_labels'] = description_mesh_labels_final

In [None]:
# Keep only projects that have more than two labels
gdb_df_co = gdb_df[gdb_df['coocurrence_labels'].str.len() > 2]

## Splitting Projects by Year

We'll take the projects from 2006 to 2017 inclusive (12 years).

In [None]:
# Keep projects dated 2006 through 2017 inclusive (upper bound is exclusive)
gdb_df_co = gdb_df_co[(gdb_df_co['year'] >= 2006) & (gdb_df_co['year'] < 2018)]

In [None]:
# Number of retained projects per year
gdb_df_co['year'].value_counts()

## Building a Sliding Window Co-occurrence Network

From here we will want to create a new set of labelled descriptions where the terms with very high counts and little semantic value are removed, and also those that appear very few times in the corpus. We will also need to map the labels to token IDs which can then act as the vertex values in our graph.

In [None]:
# One SlidingWindowGraph per year from 2006 to 2017, all sharing one dictionary.
# NOTE(review): a later cell rebinds the global name `times` to a list, so
# re-running the graph cells afterwards will not see this range.
times = range(2006, 2018)
co_graphs = [SlidingWindowGraph(gdb_df_co[gdb_df_co['year'] == t]['coocurrence_labels'],
                             dictionary=dictionary_mesh_labels, window_size=2)
          for t in times]

In [None]:
# prepare() and build() are each called for their return value, which replaces
# the list entries. NOTE(review): the cell rebinds co_graphs, so it is not safe
# to run twice without recreating the graphs first.
co_graphs = [g.prepare() for g in co_graphs]
co_graphs = [g.build() for g in co_graphs]

In [None]:
# One association-strength result per yearly graph, in the same order as co_graphs
association_strengths = [association_strength(g) for g in co_graphs]

In [None]:
# Slice bounds into the yearly lists: positions 0-2, i.e. the first three years
start_period = 0
end_period = 3

In [None]:
# Phylomemetic graph over the selected periods only; the graphs, strengths and
# times are all sliced with the same bounds so they stay aligned
pg = PhylomemeticGraph(co_graphs[start_period:end_period], association_strengths[start_period:end_period],
                       dictionary_mesh_labels, times[start_period:end_period],
                       max_weight=None, min_weight=None)

In [None]:
# NOTE(review): the output directory is a hard-coded absolute path on one
# user's machine; this cell will not run elsewhere without editing it.
# %time pg = pg.prepare('/Users/grichardson/cfinder/pg_out', '/Users/grichardson/cfinder/CFinder_commandline_mac')
%time pg = pg.prepare('/Users/grichardson/cfinder/pg_out')

In [None]:
# For each clique set, print how many cliques there are of each size
for cs in pg.clique_sets:
    print(Counter([len(c) for c in cs]))

In [None]:
# Timed build of the phylomemetic graph. The return value is discarded, so
# build() is presumably expected to modify pg in place — TODO confirm
%time pg.build(workers=4, min_clique_size=5, delta_0=0.4, parent_limit=4)

In [None]:
def find_antecedents(vertices, limit):
    """find_antecedents
    Find the antecedents of a collection of vertices by following in-edges.

    Args:
        vertices (:obj:`iter` of :obj:`Vertex`): Vertices for which the
            antecedents need to be found.
        limit (int): Maximum number of vertices with antecedents that are
            expanded at this level. The remaining budget is handed down to
            each recursive call, so it also bounds the depth of the search.

    Returns:
        antes (:obj:`list` of :obj:`list` of :obj:`Vertex`): One list of
            in-neighbours for every vertex that was expanded, in depth-first
            order.
    """
    l = 0
    antes = []
    for v in vertices:
        if l >= limit:
            break
        if v.in_degree() > 0:
            # Materialise once; in_neighbors() returns a one-shot iterator.
            parents = list(v.in_neighbors())
            antes.append(parents)
            # Pass the remaining budget down, mirroring find_descendents. The
            # original recursive calls omitted `limit` and raised a TypeError.
            antes += find_antecedents(parents, limit - l - 1)
            l += 1
    return antes

def find_descendents(vertices, limit):
    """find_descendents
    Find the descendents of a collection of vertices by following out-edges.

    Args:
        vertices (:obj:`iter` of :obj:`Vertex`): Vertices for which the
            descendents need to be found.
        limit (int): Maximum number of vertices with descendents that are
            expanded at this level. The remaining budget is handed down to
            each recursive call.

    Returns:
        desc (:obj:`list` of :obj:`list` of :obj:`Vertex`): One list of
            out-neighbours for every vertex that was expanded, in depth-first
            order.
    """
    found = []
    expanded = 0
    for vertex in vertices:
        # Once the budget is spent no further vertex can contribute anything
        if expanded >= limit:
            break
        if vertex.out_degree() <= 0:
            continue
        children = list(vertex.out_neighbors())
        found.append(children)
        found.extend(find_descendents(children, limit - expanded - 1))
        expanded += 1
    return found

In [None]:
# Keep only edges whose Jaccard weight exceeds 0.5 ...
pg_thresh = GraphView(pg, efilt=lambda e: pg.ep['jaccard_weights'][e] > 0.5)
# ... then keep only vertices that have at least one in- or out-edge
# (presumably the degrees are those seen in the edge-filtered view — TODO confirm)
pg_thresh = GraphView(pg_thresh, vfilt=lambda v: (v.out_degree() > 0) | (v.in_degree() > 0))
# Draw the thresholded graph with vertices coloured by their 'age' property
graph_draw(pg_thresh, vertex_fill_color=pg_thresh.vp['age'])

In [None]:
# For every vertex with exactly one parent, print the vertex's period, terms
# and density, followed by the period and terms of that parent.
for vertex in pg_thresh.vertices():
    if vertex.in_degree() != 1:
        continue

    vertex_term_ids = pg_thresh.vp['terms'][vertex]
    print('-', pg_thresh.vp['times'][vertex], '-')
    vertex_labels = sorted([dictionary_mesh_labels[t] for t in vertex_term_ids])
    print(' + '.join(vertex_labels))
    print(pg_thresh.vp['density'][vertex])
    print('\n=== Parents ===')

    for parent in vertex.in_neighbors():
        parent_term_ids = pg_thresh.vp['terms'][parent]
        print(pg_thresh.vp['times'][parent])
        parent_labels = sorted([dictionary_mesh_labels[t] for t in parent_term_ids])
        print(' + '.join(parent_labels))

    print('\n')

## Load Pre-Computed Graph

In [None]:
from rhodonite.dynamics import label_emergence, label_special_events

In [None]:
# Empty graph; the pre-computed phylomemetic graph is loaded into it next
pg_full = Graph()

In [None]:
# Load the pre-computed graph from the interim data directory
pg_full.load(os.path.join(inter_data, 'phylomemetics/pg_2006_2017_cat.xml.gz'))

In [None]:
def correct_density(g):
    """correct_density
    Divides each vertex's density by the number of terms it holds.

    Args:
        g (:obj:`Graph`): A graph with internal vertex properties 'terms'
            and 'density'.

    Returns:
        (:obj:`PropertyMap`): A new 'float' vertex property map holding
            density / len(terms) for every vertex of g.
    """
    corrected = g.new_vertex_property('float')
    for vertex in g.vertices():
        n_terms = len(g.vp['terms'][vertex])
        corrected[vertex] = g.vp['density'][vertex] / n_terms
    return corrected

In [None]:
# Density divided by the number of terms, for every vertex of the full graph
corr = correct_density(pg_full)

In [None]:
def filter_phylomemetic_graph(g, term, dictionary, min_jaccard=0):
    """filter_phylomemetic_graph
    Get a view of the graph restricted to vertices that contain a term.

    Args:
        g (:obj:`Graph`): A graph with a 'terms' vertex property and a
            'jaccard_weights' edge property.
        term (str): The term that every retained vertex must contain.
        dictionary: Object whose token2id mapping converts term to its ID.
        min_jaccard (float): Edges are retained only if their Jaccard
            weight is strictly greater than this value.

    Returns:
        (:obj:`GraphView`): The filtered view of g.
    """
    term_id = dictionary.token2id[term]

    def contains_term(v):
        return term_id in g.vp['terms'][v]

    def above_min_jaccard(e):
        return g.ep['jaccard_weights'][e] > min_jaccard

    return GraphView(g, vfilt=contains_term, efilt=above_min_jaccard)

In [None]:
def get_aggregate_vp(g, vp, vp_grouper, agg=None):
    """get_aggregate_vp
    Groups the values of one internal vertex property by the values of
    another and optionally reduces each group with an aggregation function.

    Args:
        g (:obj:`Graph`): A graph.
        vp (str): Name of the internal vertex property map of g whose
            values are grouped.
        vp_grouper (str): Name of the internal vertex property map of g
            whose values define the groups.
        agg (:obj:`function`): Function applied to each group of values,
            for example min, max, sum or numpy.mean. If not given, the
            groups themselves are returned.

    Returns:
        x (:obj:`list`): The sorted distinct values of the grouping property.
        y (:obj:`list`): One entry per value in x: the aggregated value of
            the group if agg is given, otherwise the array of grouped values.
    """
    values = get_vp_values(g, vp)
    group_keys = get_vp_values(g, vp_grouper)

    # Order the vertices by group key so that every group is contiguous
    order = group_keys.argsort()
    # Positions where the sorted key increases, i.e. where a new group begins
    boundaries = np.flatnonzero(np.diff(group_keys[order]) > 0) + 1
    groups = np.split(values[order], boundaries)

    x = sorted(set(group_keys))
    y = [agg(group) for group in groups] if agg else groups
    return x, y

#### Label Age, Emergence, and Special Events

In [None]:
# Label every vertex of the full graph and store the result as an internal
# vertex property named 'emergence'
emergence = label_emergence(pg_full)
pg_full.vp['emergence'] = emergence

In [None]:
# Store the per-term corrected density as the internal property 'density_c'
pg_full.vp['density_c'] = corr

In [None]:
branching, merging = label_special_events(pg_full)
# Store each map under its own name. The two assignments were previously
# crossed ('merging' <- branching, 'branching' <- merging), which disagrees
# with the thresholded-view cell later on that uses the same unpacking order.
pg_full.vp['branching'] = branching
pg_full.vp['merging'] = merging

In [None]:
# NOTE(review): label_ages is not imported or defined anywhere in the visible
# notebook (only label_emergence and label_special_events are imported), so
# this fails on a fresh kernel. Presumably it lives in rhodonite.dynamics —
# TODO confirm and add the import.
ages = label_ages(pg_full)
pg_full.vp['age'] = ages

#### Normalise Density

In [None]:
# Mean density of all vertices in each year, keyed by year
years, density_anual_mean = get_aggregate_vp(pg_full, 'density', 'times', agg=np.mean)
year_density_mean_mapping = dict(zip(years, density_anual_mean))

# Each vertex's density relative to the mean density of its own year
density_normed = pg_full.new_vertex_property('float')
for v in pg_full.vertices():
    year_mean = year_density_mean_mapping[pg_full.vp['times'][v]]
    density_normed[v] = pg_full.vp['density'][v] / year_mean

pg_full.vp['density_normed'] = density_normed

### Analysis

#### Density and Emergence

In [None]:
# NOTE(review): get_aggregate_vp returns (group keys, aggregated values), so
# the names are swapped: pg_full_density holds the emergence codes and
# pg_full_emergence holds the median normalised density of each code.
pg_full_density, pg_full_emergence = get_aggregate_vp(pg_full, 'density_normed', 'emergence', np.median)

fig, ax = plt.subplots()
# The hard-coded labels assume all four emergence codes are present, in
# ascending code order. Values are scaled so that the largest median is 1.
plt.plot(['Ephemeral', 'Emerging', 'Steady', 'Declining'], pg_full_emergence / np.max(pg_full_emergence),
        linewidth=3)

ax.set_xlabel('Emergence')
ax.set_ylabel('Median Field Density (Normalised)')
plt.tight_layout()
plt.show()

In [None]:
# Human-readable names for the integer emergence codes
emergence_map = {0: 'Ephemeral', 1: 'Emerging', 2: 'Steady', 3: 'Declining'}
# Jaccard-weight thresholds to sweep over
j_threshes = [0.4, 0.5, 0.6, 0.7, 0.8]
thresh_pgs = []
thresh_dfs = []
for j_thresh in j_threshes:
    # Bind the threshold as a default argument so the filter never depends on
    # the loop variable's later value.
    thresh_pg = GraphView(
        pg_full,
        efilt=lambda e, j=j_thresh: pg_full.ep['jaccard_weights'][e] > j
    )
    emergence = label_emergence(thresh_pg)
    branching, merging = label_special_events(thresh_pg)
    ages = label_ages(thresh_pg)
    df = pd.DataFrame({
        'branching': branching.get_array(),
        'merging': merging.get_array(),
        'emergence': emergence.get_array(),
        'age': ages.get_array(),
    })
    # The emergence codes are left numeric here and mapped to names once,
    # after the per-threshold frames are concatenated. The original cell
    # called df['emergence'].map(...) and discarded the result, a no-op.
    df.columns = [c + '_{}'.format(j_thresh) for c in df.columns]
    thresh_dfs.append(df)
    thresh_pgs.append(thresh_pg)

In [None]:
# Side-by-side frame holding one set of suffixed columns per threshold
pg_df = pd.concat(thresh_dfs, axis=1)
# thresh_pg is whatever view the sweep loop created last. The views in that
# loop pass only efilt, so presumably every view exposes the same vertices and
# the choice of view does not matter here — TODO confirm
years = get_vp_values(thresh_pg, 'times')
density = get_vp_values(thresh_pg, 'density_normed')
pg_df['year'] = years
pg_df['density'] = density

In [None]:
# Replace the integer emergence codes with their names. The dtype guard makes
# the cell safe to re-run: mapping an already-mapped column would turn every
# value into NaN.
for c in list(pg_df.columns):
    if 'emergence' in c and pd.api.types.is_numeric_dtype(pg_df[c]):
        pg_df[c] = pg_df[c].map(emergence_map)

# Drop any threshold-suffixed density or year columns, if present
suffixed_cols = [c for c in pg_df.columns if 'density_' in c or 'year_' in c]
pg_df.drop(suffixed_cols, inplace=True, axis=1)

In [None]:
# Preview the combined frame
pg_df.head()

In [None]:
# Mean normalised density per emergence category, one line per threshold.
# The last threshold in j_threshes is left out of the plot.
emergence_order = ['Ephemeral', 'Emerging', 'Steady', 'Declining']

fig, ax = plt.subplots()
for j in j_threshes[:-1]:
    cols = [c for c in pg_df.columns if str(j) in c] + ['density']
    group = pg_df[cols].groupby('emergence_{}'.format(j)).mean()
    # Fixed category order on the x axis
    group = group.loc[emergence_order]
    ax.plot(group.index.values, group['density'], linewidth=3, label=r'$\delta_0: {}$'.format(j))
    ax.scatter(group.index.values, group['density'], linewidth=3, label=None)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels)
plt.tight_layout()
plt.show()

#### Emergence and Special Events

In [None]:
# View of the full graph keeping edges with Jaccard weight above 0.45
# (NOTE(review): the name pg_5 suggests 0.5 — confirm which is intended)
pg_5 = GraphView(pg_full,
                 efilt=lambda e: pg_full.ep['jaccard_weights'][e] > 0.45)
# Recompute the labels on the view and store them as its vertex properties
emergence = label_emergence(pg_5)
pg_5.vp['emergence'] = emergence
branching, merging = label_special_events(pg_5)
pg_5.vp['branching'] = branching
pg_5.vp['merging'] = merging
# NOTE(review): label_ages is not imported or defined in the visible notebook
ages = label_ages(pg_5)
pg_5.vp['age'] = ages

In [None]:
# Values of every vertex property except 'terms', keyed by property name
props = {
    k: get_vp_values(pg_5, k)
    for k in pg_5.vertex_properties.keys()
    if k != 'terms'
}

pg_5_df = pd.DataFrame(props).drop(columns=['color', 'density', 'density_normed'])
# 'terms' is attached separately, one entry per vertex, read from the map
pg_5_df['terms'] = [pg_5.vp['terms'][v] for v in pg_5.vertices()]

In [None]:
# Exclude the first two years and the final year of the period
pg_5_df = pg_5_df[~pg_5_df['times'].isin([2006, 2007, 2017])]

In [None]:
pg_5_groupby_year = pg_5_df.groupby('times')

# Percentage of fields per year flagged as merging / branching. Only the
# needed column is aggregated; summing the whole frame would also try to sum
# the object-valued 'terms' column.
merging_by_year = pg_5_groupby_year['merging']
branching_by_year = pg_5_groupby_year['branching']
merging_frac_year = merging_by_year.sum() / merging_by_year.count() * 100
branching_frac_year = branching_by_year.sum() / branching_by_year.count() * 100
years = merging_frac_year.index.values

fig, ax = plt.subplots(1, figsize=(7, 5))
ax.plot(years, merging_frac_year, label='% Merging', linewidth=3)
ax.scatter(years, merging_frac_year, label=None)
ax.plot(years, branching_frac_year, label='% Branching', linewidth=3)
ax.scatter(years, branching_frac_year, label=None)
ax.set_xlabel('Year')
ax.set_ylabel('% of Fields')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels)
plt.show()

In [None]:
# Same merging / branching breakdown, restricted to fields that contain every
# query term.
query_terms = ['Software']
term_id = [dictionary_mesh_labels.token2id[t] for t in query_terms]

df_terms = pg_5_df[pg_5_df['terms'].apply(lambda x: all(t in x for t in term_id))]

term_groupby_year = df_terms.groupby('times')

# Term-specific names are used so the all-fields `years`, `merging_frac_year`
# and `branching_frac_year` are not overwritten; a later cell compares those
# against emergence fractions computed over all fields.
term_merging = term_groupby_year['merging']
term_branching = term_groupby_year['branching']
term_merging_frac_year = term_merging.sum() / term_merging.count() * 100
term_branching_frac_year = term_branching.sum() / term_branching.count() * 100
term_years = term_merging_frac_year.index.values

fig, ax = plt.subplots(1, figsize=(7, 5))
ax.plot(term_years, term_merging_frac_year, label='% Merging', linewidth=3)
ax.scatter(term_years, term_merging_frac_year, label=None)
ax.plot(term_years, term_branching_frac_year, label='% Branching', linewidth=3)
ax.scatter(term_years, term_branching_frac_year, label=None)
ax.set_title(' + '.join(query_terms))
ax.set_xlabel('Year')
ax.set_ylabel('% of Fields')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels)
plt.show()

In [None]:
# Fraction of each year's fields in each emergence category, with one list
# per category and one entry per year (in group order).
emergence_fractions = defaultdict(list)

for i, g in pg_5_groupby_year:
    n = len(g)
    counts = g['emergence'].value_counts()
    for j in range(4):
        # A category with no fields in this year is recorded as NaN. Using
        # .get() replaces the previous bare `except`, which hid every error.
        emergence_fractions[emergence_map[j].lower()].append(counts.get(j, np.nan) / n)

emergence_fractions_df = pd.DataFrame(emergence_fractions)

In [None]:
# Yearly merging / branching percentages over all fields, recomputed from the
# same grouping that produced emergence_fractions_df. This keeps the two
# aligned year by year even if another cell has since reassigned
# merging_frac_year / branching_frac_year.
merging_pct_year = (
    pg_5_groupby_year['merging'].sum() / pg_5_groupby_year['merging'].count() * 100
).values
branching_pct_year = (
    pg_5_groupby_year['branching'].sum() / pg_5_groupby_year['branching'].count() * 100
).values

fig, axs = plt.subplots(2, 4, figsize=(16, 6))

# Top row: merging against each emergence category; bottom row: branching.
# The y values are already percentages, so they are not multiplied by 100
# a second time.
for ax, c in zip(axs[0], emergence_fractions_df.columns):
    ax.scatter(emergence_fractions_df[c] * 100, merging_pct_year)
    ax.set_title('% {}'.format(c.title()))
for ax, c in zip(axs[1], emergence_fractions_df.columns):
    ax.scatter(emergence_fractions_df[c] * 100, branching_pct_year, color='#ff7f0e')

axs[0][0].set_ylabel('% Merging', fontsize=12)
axs[1][0].set_ylabel('% Branching', fontsize=12)
# Hide ticks and tick labels on every panel
for ax_row in axs:
    for ax_panel in ax_row:
        ax_panel.tick_params(
            axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax_panel.tick_params(
            axis='y', which='both', left=False, right=False, labelleft=False)
plt.tight_layout()
plt.show()


#### Density and Age

In [None]:
# NOTE(review): get_aggregate_vp returns (group keys, aggregated values), so
# the names are swapped: pg_5_density_age_mean holds the distinct ages and
# pg_5_age holds the mean normalised density for each age.
pg_5_density_age_mean, pg_5_age = get_aggregate_vp(pg_5, 'density_normed', 'age', agg=np.mean)

In [None]:
fig, ax = plt.subplots()
# First argument holds the ages (x) and the second the mean densities (y);
# only the first eight age groups are plotted.
ax.plot(pg_5_density_age_mean[:8], pg_5_age[:8], linewidth=3)
ax.scatter(pg_5_density_age_mean[:8], pg_5_age[:8])

ax.set_xlabel('Branch Age (years)')
# The values were aggregated with np.mean, so the axis is labelled as a mean
# rather than a median.
ax.set_ylabel('Mean Density (Normalised)')

plt.show()

#### Surrounding Densities

In [None]:
# TODO: the surrounding-densities analysis has not been written yet. The cell
# previously held only a dangling `for`, a SyntaxError that stops Run All.

#### Density and Times

In [None]:
def get_term_density_vs_time(g, dictionary, term_sets=None):
    """get_term_density_vs_time
    Calculates, for each set of terms, the yearly mean normalised density of
    the vertices that contain every term in the set.

    Args:
        g (:obj:`Graph`): A graph with internal vertex properties 'terms',
            'density_normed' and 'times'.
        dictionary: Object whose token2id mapping converts terms to IDs.
        term_sets (:obj:`list` of :obj:`list` of str): Sets of terms. A
            vertex counts towards a set only if it contains all of its terms.

    Returns:
        (:obj:`DataFrame`): Indexed by year, with one column per term set
            named by joining the set's terms with spaces. Empty if no term
            sets are given.

    Raises:
        KeyError: If a term is not present in dictionary.token2id.
    """
    # pd.concat raises on an empty list, so handle "nothing requested" here
    if not term_sets:
        return pd.DataFrame()

    vp_terms = g.vp['terms']
    vp_density = g.vp['density_normed']
    vp_times = g.vp['times']

    term_df = []
    for terms in term_sets:
        # Built once per term set rather than once per vertex
        term_ids = frozenset(dictionary.token2id[t] for t in terms)
        dens = []
        times = []
        for v in g.vertices():
            if term_ids.issubset(vp_terms[v]):
                dens.append(vp_density[v])
                times.append(vp_times[v])
        df = pd.DataFrame({' '.join(terms): dens, 'year': times})
        term_df.append(df.groupby('year').mean())
    return pd.concat(term_df, axis=1)

In [None]:
# Yearly mean normalised density for four single-term sets.
# NOTE(review): the variable name no longer matches its contents (Machine
# Learning, Algorithms, Pregnancy, Neoplasms); it is kept because the plotting
# cell below refers to it.
female_pregnancy_df = get_term_density_vs_time(
    pg_5,
    dictionary_mesh_labels,
    [['Machine Learning'],
     ['Algorithms'],
     ['Pregnancy'],
     ['Neoplasms'],
    ]
)

In [None]:
# One line per term-set column, for 2007-2016. DataFrame.plot labels each line
# with its column name, so the series can be told apart. The original bound
# `ax` to the list of Line2D objects returned by plt.plot, not to an Axes.
fig, ax = plt.subplots()
female_pregnancy_df.loc[2007:2016].plot(ax=ax)
ax.set_xlabel('Year')
ax.set_ylabel('Mean Field Density (Normalised)')
plt.tight_layout()
plt.show()

#### Linear Network Plot

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.decomposition import PCA, TruncatedSVD

In [None]:
# Binarizer with one class per dictionary key, producing a sparse matrix
mlb = MultiLabelBinarizer(classes=dictionary_mesh_labels.keys(), sparse_output=True)
# NOTE(review): pca is not used by any later cell in the visible notebook;
# only svd is fitted and applied.
pca = PCA(n_components=1)
svd = TruncatedSVD(n_components=1)

In [None]:
# One row per emerging (1), steady (2) or declining (3) field; ephemeral
# fields (0) are left out. Rows are collected as records so the cell no longer
# rebinds the global `times`, which the yearly graph cells use as their range.
records = []
for v in pg_5.vertices():
    if pg_5.vp['emergence'][v] in (1, 2, 3):
        records.append({
            'year': pg_5.vp['times'][v],
            'density': pg_5.vp['density_normed'][v],
            'terms': pg_5.vp['terms'][v],
            'vertex': int(v),
        })

# Explicit columns keep the frame well-formed even when nothing matched
df = pd.DataFrame(records, columns=['year', 'density', 'terms', 'vertex'])
df = df[df['year'] < 2017]

In [None]:
# Fit the binarizer on the term lists, then fit the one-component SVD on the
# resulting binary term matrix
mlb.fit(df['terms'])
svd.fit(mlb.transform(df['terms']))

In [None]:
# Project every field onto the single SVD component
matrix = mlb.transform(df['terms'])
svd_vals = svd.transform(matrix)

# Used as the vertical position of each field in the linear network plot
df['y_pos'] = svd_vals

In [None]:
# Base-10 log of the normalised density (used later as the colour value)
df['density_log'] = np.log10(df['density'])
# NOTE(review): the in-place set_index makes this cell non-repeatable; a second
# run fails because 'vertex' is no longer a column.
df.set_index('vertex', inplace=True)

In [None]:
# (source, target) vertex-index pairs for every edge of the view. The built-in
# int is used because the np.int alias was removed in NumPy 1.24 and raises
# AttributeError there.
edges = [(int(e.source()), int(e.target())) for e in pg_5.edges()]

In [None]:
# Exploratory lookup: dictionary tokens that contain the substring 'Robot'
[k for k in dictionary_mesh_labels.token2id.keys() if 'Robot' in k]

In [None]:
query_terms = ['Machine Learning', 'Brain']
term_id = [dictionary_mesh_labels.token2id[t] for t in query_terms]

# Fields that contain every query term
df_term = df[df['terms'].apply(lambda x: all(t in x for t in term_id))]

# Edges whose two endpoints are both among the selected fields. A set makes
# each membership test constant-time instead of a scan of the index array.
term_vertices = set(df_term.index.values)
edges_term = [e for e in edges if e[0] in term_vertices and e[1] in term_vertices]

# Endpoint coordinates, looked up in one vectorised call per column instead of
# one row lookup per edge
edge_sources = [e[0] for e in edges_term]
edge_targets = [e[1] for e in edges_term]
edges_x0 = df.loc[edge_sources, 'year'].values.astype(np.int32)
edges_x1 = df.loc[edge_targets, 'year'].values.astype(np.int32)
edges_y0 = df.loc[edge_sources, 'y_pos'].values.astype(np.float32)
edges_y1 = df.loc[edge_targets, 'y_pos'].values.astype(np.float32)

# Hover text: the comma-separated labels of each selected field
terms = [', '.join([dictionary_mesh_labels[t] for t in tokens]) for tokens in df_term['terms'].values]

In [None]:
# Vertex data for the plot: position, hover text and log density for colour
cds = ColumnDataSource({'year': df_term['year'].astype(np.int32),
                        'y_pos': df_term['y_pos'].astype(np.float32),
                        'terms': terms,
                        'density': df_term['density_log'].astype(np.float32)})
# Edge data for the plot: start and end coordinates of each segment
cds_edges = ColumnDataSource({'x0': np.array(edges_x0).astype(np.float32),
                              'x1': np.array(edges_x1).astype(np.float32),
                              'y0': np.array(edges_y0).astype(np.float32),
                              'y1': np.array(edges_y1).astype(np.float32)})

In [None]:
p = figure(width=900, height=400)

# Tooltip showing a field's terms; names restricts it to the 'vertices' glyphs
hover = HoverTool(
    tooltips=[
        ('Terms', '@terms'),
    ],
names=["vertices"])

# Colour scale spanning the log densities of the selected fields
color_mapper = LinearColorMapper(
    palette='Magma256', low=min(df_term['density_log']), high=max(df_term['density_log']))

# Edges are drawn first so that the vertex circles sit on top of them
p.segment(source=cds_edges, x0='x0', y0='y0', x1='x1', y1='y1', alpha=0.2, line_width=2, color='gray', name='edges')
p.circle(source=cds, x='year', y='y_pos', size=7, fill_alpha=0.6, line_alpha=0, name='vertices',
        color={'field': 'density', 'transform': color_mapper})

p.add_tools(hover)
p.x_range = Range1d(2006, 2017)
p.xaxis.ticker = list(range(2007, 2017))
# No y ticks are shown
p.yaxis.ticker = []
p.xaxis.major_label_text_font_size = "12pt"

show(p)

In [None]:
from graph_tool.topology import label_components

In [None]:
# Component labelling of the thresholded view; the next cell reads comps[0]
# as the per-vertex component label map
comps = label_components(pg_5)

In [None]:
fig, ax = plt.subplots(figsize=(16, 4))
# Component label of each of the first 5000 entries of the label array
ax.plot(comps[0].get_array()[:5000])

#### plt.scatter( df_term['year'], df_term['y_pos'])


In [None]:
def get_association_strength_vs_time(term_0, term_1, strengths=None, graphs=None):
    """get_association_strength_vs_time
    Collects, for each period, the association strength between two terms
    relative to that period's mean strength, and their cooccurrence count.

    Args:
        term_0 (str): First term.
        term_1 (str): Second term.
        strengths (:obj:`iter`): One association strength property map per
            period. Defaults to the notebook-level association_strengths.
        graphs (:obj:`iter`): One cooccurrence graph per period, aligned
            with strengths. Defaults to the notebook-level co_graphs.

    Returns:
        a_s (:obj:`list`): Per-period strength divided by the mean of that
            period's strengths; 0 where it could not be determined.
        co_s (:obj:`list`): Per-period cooccurrence value for the pair; 0
            where it could not be determined.
    """
    if strengths is None:
        strengths = association_strengths
    if graphs is None:
        graphs = co_graphs

    a_s = []
    co_s = []
    for a, c in zip(strengths, graphs):
        strength, count = 0, 0
        if term_0 in c.token2vertex and term_1 in c.token2vertex:
            t_0 = c.token2vertex[term_0]
            t_1 = c.token2vertex[term_1]
            edge = tuple(sorted([t_0, t_1]))
            try:
                # Look both values up before recording either. Previously a
                # failure on the second lookup, after the first had been
                # appended, added an extra entry to a_s and misaligned the
                # two output lists.
                edge_strength = a[edge] / np.mean(a.get_array())
                edge_count = c.ep.cooccurrences[edge]
            except Exception:
                # Best effort: a pair with no usable edge in this period
                # stays at 0 in both outputs.
                pass
            else:
                strength, count = edge_strength, edge_count
        a_s.append(strength)
        co_s.append(count)
    return a_s, co_s

In [None]:
# Relative association strength between the two terms, one value per period
a_s, co_s = get_association_strength_vs_time('Machine Learning', 'Bacteria')
plt.plot(a_s)

In [None]:
# Descendents of every vertex of the view, found with a limit of 2 and keyed
# by the vertex itself
preds = {v: find_descendents([v], 2) for v in pg_5.vertices()}

In [None]:
# View keeping edges with Jaccard weight above 0.8 and dropping vertices
# dated 2006, 2017 or 2018
gfilt = GraphView(
        pg_full, 
        efilt=lambda e: pg_full.ep['jaccard_weights'][e] > 0.8,
        vfilt=lambda v: pg_full.vp['times'][v] not in [2006, 2017, 2018])

In [None]:
# NOTE(review): get_vp_x_y, pg_dna, pg_ml and pg_preg are not defined or
# imported anywhere in the visible notebook, so this cell raises NameError on
# a fresh kernel. Presumably the graphs are per-term views built with
# filter_phylomemetic_graph — TODO confirm and restore their definitions.
x_dna, y_dna = get_vp_x_y(pg_dna, 'times', 'density', np.median)
x_ml, y_ml = get_vp_x_y(pg_ml, 'times', 'density', np.median)
x_preg, y_preg = get_vp_x_y(pg_preg, 'times', 'density', np.median)

In [None]:
# One line per term graph; the first point of each series is skipped
plt.plot(x_dna[1:], y_dna[1:])
plt.plot(x_ml[1:], y_ml[1:])
plt.plot(x_preg[1:], y_preg[1:])

In [None]:
def sankey(g):
    