In [1]:
import joblib
import numpy as np
import tarfile
import boto3
import io
from sagemaker import get_execution_role
import pandas as pd
from multiprocessing import Pool
from glob import glob
from sklearn.metrics.pairwise import cosine_similarity
from plotnine import ggplot, ggtitle, geoms, aes, theme_classic, scales, theme, labels, element_blank, ylim,facet_wrap
import networkx as nx
import pygraphviz as pgv # pygraphviz should be available
import json
from sklearn.cluster import AgglomerativeClustering
from IPython.core.display import HTML, display
from string import Template

In [2]:
# Topic descriptions for every chamber/k. Files are read in sorted order so the row order
# (which make_json indexes positionally) does not depend on filesystem listing order.
descriptions = pd.concat([pd.read_csv(path) for path in sorted(glob('../data/descriptions/*.csv'))])

# List the S3 keys of every trained unigram NMF model tarball.
connection = boto3.client('s3')
paginator = connection.get_paginator('list_objects_v2')

pages = paginator.paginate(Bucket='ascsagemaker', Prefix="JMP_congressional_nmf/unigram_models")
# A page with no matching objects carries no 'Contents' key at all, hence .get(..., []).
unigram_tarballs = [
    obj['Key']
    for page in pages
    for obj in page.get('Contents', [])
    if obj['Key'].endswith('.tar.gz')
]
# Fail here with a clear message rather than with an IndexError in display_topic_graph later.
assert unigram_tarballs, 'no .tar.gz model archives found under the S3 prefix'


In [3]:
def make_json(persistent_array,k1,k2,chamber,order_1, order_2):
    """Build the node/link payload consumed by the d3 plot template.

    Parameters
    ----------
    persistent_array : 2-D array; every entry equal to 1 at [i, j] becomes a
        link from node ``"{k1}_{i}"`` to node ``"{k2}_{j}"``.
    k1, k2 : the two levels (topic counts) being connected.
    chamber : converted with ``int`` and matched against the ``chamber``
        column of the global ``descriptions`` frame.
    order_1, order_2 : sequences of topic indices; each value ``r`` yields a
        node named ``"{k}_{r}"`` whose description is row ``r`` of that
        level's description rows.

    Returns
    -------
    dict with ``"nodes"`` (name / level / descrip) and ``'links'``
    (source / target).
    """
    in_chamber = descriptions.loc[descriptions.chamber == int(chamber)]

    def level_nodes(k, order):
        rows = in_chamber.loc[in_chamber.k == k].to_numpy()
        # Drops the first and the last two columns; presumably identifier
        # columns such as chamber/k -- TODO confirm against the CSV schema.
        return [
            {"name": f"{k}_{r}", 'level': k, 'descrip': list(rows[r, 1:-2])}
            for r in order
        ]

    nodes = level_nodes(k1, order_1) + level_nodes(k2, order_2)

    sources, targets = np.where(persistent_array == 1)
    links = [
        {"source": f'{k1}_{src}', 'target': f'{k2}_{dst}'}
        for src, dst in zip(sources, targets)
    ]

    return {"nodes": nodes, 'links': links}

In [4]:
def map_persistent_topics(k1,k2,chamber,Hs, thresh=0.75):
    """Link topics of the ``k1`` model to sufficiently similar topics of the ``k2`` model.

    Parameters
    ----------
    k1, k2 : keys into ``Hs`` selecting the two topic matrices (one topic per row).
    chamber : forwarded unchanged to ``make_json``.
    Hs : mapping from topic count to topic matrix.
    thresh : minimum cosine similarity for a pair of topics to be linked.

    Returns
    -------
    dict with ``'array'`` -- a 0/1 matrix of shape (rows of H1, rows of H2) -- and
    ``'json'`` -- the node/link graph built by ``make_json``.
    """
    H1 = Hs[k1]
    H2 = Hs[k2]

    # NOTE(review): n_clusters equals the number of topics, so each topic is its own
    # cluster and labels_ should be a permutation of 0..n-1. The d3 template positions
    # nodes by the number in their name, so this "order" does not visibly reorder the
    # plot -- confirm whether a dendrogram leaf order was intended.
    H1_order = _fit_cosine_clustering(H1).labels_
    H2_order = _fit_cosine_clustering(H2).labels_

    # Whole (n1, n2) similarity matrix in one call instead of one call per topic pair.
    topic_array = cosine_similarity(H1, H2)

    persistent_array = np.where(topic_array >= thresh, 1, 0)
    return {
        'array': persistent_array,
        'json': make_json(persistent_array, k1, k2, chamber, H1_order, H2_order),
    }


def _fit_cosine_clustering(H):
    """Fit complete-linkage, cosine-metric clustering on the rows of cosine_similarity(H), one cluster per row."""
    n_topics = H.shape[0]
    try:
        # scikit-learn >= 1.2 names the distance argument `metric`; `affinity` was removed in 1.4.
        model = AgglomerativeClustering(n_clusters=n_topics, metric='cosine', linkage='complete')
    except TypeError:
        # Older scikit-learn releases only accept `affinity`.
        model = AgglomerativeClustering(n_clusters=n_topics, affinity='cosine', linkage='complete')
    return model.fit(cosine_similarity(H))
    

In [32]:
# Outer HTML shell: a target div plus a <script> into which the d3 code is substituted.
base_template = Template('''

<div id='graph-div'></div>
<script> $js_text </script>

''')

# d3 (v5 -- it relies on `d3.event`, removed in v6) drawing code. `python_data` is
# substituted with the JSON graph ({nodes: [...], links: [...]}); node names are
# "<level>_<index>" and determine each rect's x (by level) and y (by index).
# Note: string.Template reserves the dollar sign, so the JS must not contain one
# apart from the placeholder.
plot_template = Template('''


var margin = {top:20, right:20, bottom:30, left:40},
    width = 1100,
    height = 2000 - margin.top - margin.bottom;

// Column x-offset and first-row y-offset for each supported level (topic count).
// Declared before the nodes are drawn because xSpacer/yCenter read them immediately.
var X_BY_LEVEL = {25: 1, 50: 200, 75: 400, 100: 600, 125: 800};
var Y_START_BY_LEVEL = {125: 10, 100: 200, 75: 400, 50: 600, 25: 800};

var svg = d3.select('#graph-div').append('svg')
    .attr("width",width)
    .attr("height",height);

// Remove the tooltip left on <body> by a previous run of this cell before adding a new one.
d3.select('#topic-graph-tooltip').remove();
var div = d3.select("body").append("div")
    .attr("id", "topic-graph-tooltip")
    .attr("class", "tooltip")
    .style("opacity", 0)
    .style('background','white')
    .style('border','2px solid black')
    .style('position','absolute');

var data = $python_data;

var nodes = svg.selectAll('rect')
    .data(data.nodes)
    .enter()
    .append('rect')
    .attr('id',function(d){return "x"+d.name})
    .attr('x',function(d){return xSpacer(d.level)})
    .attr('width',20)
    .attr('y',function(d){return yCenter(d.name)})
    .attr('height',10)
    .attr('stroke','black')
    .attr('fill','white');

nodes
    .on('mouseover',function(d){
        div.style('opacity',1)
        .html(d.descrip)
        .style('left',d3.event.pageX + 40 + 'px')
        .style('top',d3.event.pageY - 40 + 'px');
    })
    .on('mouseout',function(d){
        div.style('opacity',0);
    })
    .on('click',function(d){
        // Read the selection state BEFORE resetting the fills, so a second click
        // on the selected node deselects it instead of re-selecting it.
        var wasSelected = d3.select(this).attr('fill') == 'red';
        nodes.attr('fill','white');
        // Clear link highlights left by any previously selected node.
        links.style('stroke', null).style('stroke-width', null);
        if (!wasSelected){
            d3.select(this).attr('fill','red');
            highlighter(d);
        }
    });

var links = svg.selectAll('line')
    .data(data.links)
    .enter()
    .append('line')
    .attr('x1',function(d){return xSpacer(d.source.split('_')[0]) + 20})
    .attr('x2',function(d){return xSpacer(d.target.split('_')[0])})
    .attr('y1',function(d){return yCenter(d.source) + 5})
    .attr('y2',function(d){return yCenter(d.target) + 5})
    .attr('stroke','black');


// Paint every link touching node d red and thicken it.
function highlighter(d){
    links
        .filter(function(x){ return x.source == d.name || x.target == d.name; })
        .style('stroke','red')
        .style('stroke-width','3px');
}

// x position of a level's column (undefined for an unknown level).
function xSpacer(level){
    return X_BY_LEVEL[level];
}

// y position of node "<level>_<index>": the level's start offset plus 15px per index.
function yCenter(name){
    var parts = name.split('_');
    return Y_START_BY_LEVEL[parts[0]] + parts[1] * 15;
}
''')

In [33]:
def display_topic_graph(chamber_ix,thresh,combos=((25,50),(50,75),(75,100),(100,125))):
    """Build the merged multi-level topic-persistence graph for one model archive.

    Despite its name this function renders nothing: it returns the dict that is
    JSON-encoded into ``plot_template``.

    Parameters
    ----------
    chamber_ix : index into ``unigram_tarballs`` (e.g. -1 for the last key).
    thresh : cosine-similarity threshold forwarded to ``map_persistent_topics``.
    combos : (k1, k2) pairs of topic counts to link. The d3 template only has
        positions for levels 25, 50, 75, 100 and 125.

    Returns
    -------
    dict with ``'nodes'`` (each node name appears exactly once) and ``'links'``.
    """
    key = unigram_tarballs[chamber_ix]
    payload = connection.get_object(Bucket='ascsagemaker', Key=key)['Body'].read()

    # The context manager closes the archive even if a model fails to load.
    with tarfile.open(fileobj=io.BytesIO(payload)) as tar:
        members = tar.getmembers()
        # The chamber id is the second '_'-separated field of the first member's name.
        chamber = members[0].name.split('_')[1]
        print(f'got members -- {chamber}')
        Hs = _load_topic_matrices(tar, members)

    return _merge_topic_graphs(chamber, Hs, thresh, combos)


def _load_topic_matrices(tar, members):
    """Load each model file in the archive and index its 'H' matrix by row count."""
    Hs = {}
    for member in members:
        # extractfile() returns None for directories and other non-regular members.
        if not member.isfile():
            continue
        # SECURITY: joblib.load unpickles, which can execute arbitrary code --
        # only use on archives from a trusted bucket.
        H = joblib.load(tar.extractfile(member))['H']
        # A later model with the same row count overwrites an earlier one.
        Hs[H.shape[0]] = H
    return Hs


def _merge_topic_graphs(chamber, Hs, thresh, combos):
    """Concatenate the per-pair graphs, keeping only the first node with each name."""
    nodes, links, seen = [], [], set()
    for k1, k2 in combos:
        graph = map_persistent_topics(k1, k2, chamber, Hs, thresh)['json']
        # Interior levels (e.g. 50) are emitted by both (25, 50) and (50, 75); without
        # this check the plot draws overlapping rects that share one element id.
        for node in graph['nodes']:
            if node['name'] not in seen:
                seen.add(node['name'])
                nodes.append(node)
        links.extend(graph['links'])
    return {'nodes': nodes, 'links': links}
    

In [34]:
# The plot template is written against d3 v5 (it uses `d3.event`, removed in v6), so load that version.
d3_script_tag = "<script src='https://d3js.org/d3.v5.min.js'></script>"
HTML(d3_script_tag)


In [36]:
# Graph for the last listed model archive (index -1) at a 0.3 cosine-similarity threshold,
# inlined as JSON into the d3 template and rendered as the cell's output.
persistent_json = display_topic_graph(chamber_ix=-1, thresh=0.3)

js_text = plot_template.substitute(python_data=json.dumps(persistent_json))
HTML(base_template.substitute(js_text=js_text))

got members -- 114
