In [17]:
from LDA_simplified import LDA

In [18]:
# Create the project's LDA wrapper (LDA_simplified.LDA).
# The first argument is presumably the initial number of clusters/topics
# (it matches the default of the "Number of clusters" input in the app layout).
# The data path is relative to the notebook's working directory; change
# DATA_PATH instead of hard-coding machine-specific absolute paths.
N_TOPICS = 5
DATA_PATH = 'InfovisVAST-papers.jig.txt'
lda = LDA(N_TOPICS, DATA_PATH)

Build the nodes and edges for the Cytoscape network graph.

In [19]:
# extract the color from class settings: will be defined in stylesheet
# the label is the top 4 words
# position defined by lda
def get_graph_topic_nodes():
    return [{'data': {'id': id_, 'label': vals[0]},
             'classes': f'topic_{vals[1]}'                                                                                        #,"opacity": 0.1
            ,'position':{'x' : vals[2][0] , 'y' : vals[2][1]}
            } for id_,vals in lda.get_topic_nodes_().items()                        
        ]

In [20]:
def get_graph_document_nodes():
    """Build one Cytoscape node element per document.

    Node ids are the keys of ``lda.get_doc_nodes_()``. The second field of
    each value becomes the node's CSS class, which ``update_stylesheet`` maps
    to the matching background color. Every node has ``size`` 1000 in its
    data and is drawn as a circle.
    """
    document_nodes = []
    for doc_id, doc_info in lda.get_doc_nodes_().items():
        document_nodes.append({
            'data': {'id': doc_id, 'size': 1000},
            'style': {'shape': 'circle'},
            'classes': doc_info[1],
        })
    return document_nodes


In [21]:
def get_graph_cos_sim_edges():
    """Build the document-to-document (cosine similarity) edge elements.

    Each item of ``lda.get_edges_()`` contributes one edge. Only its first two
    fields are used, as source and target ids. The edge label reads
    ``"<source> -> <target>"``.
    """
    similarity_edges = []
    for pair in lda.get_edges_():
        src, dst = pair[0], pair[1]
        similarity_edges.append({'data': {'source': src, 'target': dst, 'label': f'{src} -> {dst}'}})
    return similarity_edges

In [22]:
def get_doc_topic_edges():
    """Build one invisible edge from every document node to its topic node.

    The target is the third field of the document's entry in
    ``lda.get_doc_nodes_()``. Edges are styled white with opacity 0, so they
    are not visible. They also carry ``edgeLength`` 200 and ``size`` 5 in their
    data, presumably so the force-directed 'cose' layout keeps documents close
    to their topic (TODO confirm).
    """
    def _invisible_edge(doc_id, topic_id):
        # One hidden document -> topic edge.
        return {
            'data': {'source': doc_id, 'target': topic_id, 'label': f'{doc_id} -> {topic_id}',
                     "edgeLength": 200, 'size': 5},
            'style': {'line-color': 'white', "opacity": 0},
        }

    return [_invisible_edge(doc_id, doc_info[2]) for doc_id, doc_info in lda.get_doc_nodes_().items()]

In [23]:
def update_stylesheet():
    """Build the Cytoscape stylesheet for the current clustering.

    Call it again after the model changes (e.g. after ``lda.update_lda()``).
    Two class rules are generated for every cluster color:

    * ``.<color>``: fills document nodes with that color. A document node's
      class is its color (see ``get_graph_document_nodes``).
    * ``.topic_<color>``: draws topic nodes as white rectangles with a colored
      border, sized to and centered on their wrapped label text.

    Colors are taken from both the document nodes and the topic nodes.
    Previously they came from the documents only, so a topic whose color no
    document carries got no ``.topic_<color>`` rule, and its label
    (``content: data(label)``) was never rendered.

    Returns:
        list of dict: all document rules, followed by all topic rules, each in
        sorted color order.
    """
    doc_colors = {vals[1] for vals in lda.get_doc_nodes_().values()}
    topic_colors = {vals[1] for vals in lda.get_topic_nodes_().values()}
    colors = sorted(doc_colors | topic_colors)

    node_classes = [{
        'selector': f'.{c}',
        'style': {'background-color': c},
    } for c in colors]

    topic_classes = [{
        'selector': f'.topic_{c}',
        'style': {
            'border-color': c,
            'border-width': 2,
            'background-color': 'white',
            'shape': 'rectangle',
            'content': 'data(label)',
            'text-halign': 'center',
            'text-valign': 'center',
            'text-wrap': 'wrap',
            'width': 'label',
            'height': 'label',
        },
    } for c in colors]

    return node_classes + topic_classes
                 
                           
            

In [24]:
import plotly.express as px
def build_cluster_summary_view():
    """Assemble the inputs of the Cluster Summary treemap.

    Returns a dict with four parallel lists: ``labels``, ``parents``,
    ``marker_colors`` and ``text_info``.

    * Entry 0 is the white root node ``'Clusters'``.
    * Every following entry is one topic from ``lda.get_topic_nodes_()``
      (id, colored with its ``vals[1]``) and is parented to that root.
    * Newlines in a topic's ``vals[0]`` text become ``<br>``, so plotly
      renders them as line breaks.
    """
    root = 'Clusters'
    topics = list(lda.get_topic_nodes_().items())
    return {
        'labels': [root] + [topic_id for topic_id, _ in topics],
        'parents': [''] + [root for _ in topics],
        'marker_colors': ["white"] + [info[1] for _, info in topics],
        'text_info': [''] + [info[0].replace('\n', '<br>') for _, info in topics],
    }

# Treemap inputs for the initial model; used to build the first
# 'clust_sum_graph' figure in the app layout.
initial_sum_view = build_cluster_summary_view()

In [25]:
def build_cluster_merge_list():
    """Options for the 'Merge clusters' checklist.

    There is one option per key of ``lda.get_topic_nodes()``. The key itself
    is the label, and the key with the ``'Cluster '`` prefix removed is the
    value.
    """
    return [{'label': name, 'value': name.replace('Cluster ', '')}
            for name in lda.get_topic_nodes().keys()]


In [32]:
from jupyter_dash import JupyterDash  #  pip install jupyter-dash
import dash_cytoscape as cyto  # pip install dash-cytoscape==0.2.0 or higher
import dash_html_components as html
import dash_core_components as dcc

from dash.dependencies import Output, Input,State
import pandas as pd  # pip install pandas


import dash_bootstrap_components as dbc #pip install dash-bootstrap-components
import dash_table
import dash
import plotly.graph_objects as go
import json
import plotly.graph_objs as go


import dash_dangerously_set_inner_html



# Alternative stylesheet kept for reference; the app below uses the
# Bootstrap MATERIA theme instead (SKETCHY was presumably another option).
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = JupyterDash(
    __name__,
    external_stylesheets=[dbc.themes.MATERIA],
    # Skip Dash's validation of callback ids against the initial layout.
    suppress_callback_exceptions=True,
)


# Page layout (a static value, built once when this cell runs).
#   Row 1: Cluster Relation View controls and Cytoscape graph, Cluster Summary
#          treemap, cluster-merge checklist, and the Term Weight table and
#          bar chart (the last two hidden until a callback reveals them).
#   Row 2: Parallel Coordinates View, its noise-filter slider, and the word cloud.
#   Row 3: document view (hidden until a callback reveals it).
app.layout = html.Div([
    html.Div(html.H1('iVisClustering: An Interactive Visual Document Clustering via Topic Modeling',
                     style={'backgroundColor': 'lightgray'})),
    ############## FIRST ROW ########################
    dbc.Row([
        dbc.Col([
            # Output target of the debug callback that prints Cytoscape events.
            html.Div(id='empty-div', children=''),
            dcc.Location(id='url', refresh=True),
            html.H3('Cluster Relation View'),
            html.Br(),
            html.A(html.Button('Reset settings'), href='/'),
            html.Br(),
            html.I("Number of clusters (min: 2 , max: 10)"),
            dcc.Input(id="input1", type="number", min=2, value=5, max=10),
            # Opened by the graph callback when "Update" is pressed without a valid value.
            dbc.Modal([
                dbc.ModalHeader("Warning!"),
                dbc.ModalBody("Number of clusters should be within 2 and 10"),
                dbc.ModalFooter(dbc.Button(
                    "Close", id="close_warning_num_of_cluster", className="ml-auto")),
            ],
                id="number_of_topics_warning",
                centered=True),
            html.Button("Update", id="update_button", n_clicks=0),
            html.Br(),
            html.I("Cosine similarity:"),
            html.Div(id="output"),
            # Cosine-similarity threshold for the document-document edges.
            dcc.Slider(
                id='my-daq-slider-ex',
                min=0, max=1, value=0.4, step=0.05
            ),
            html.Br(),
            # No callback listens to 'refresh'. An enabled 1 ms interval would
            # still update n_intervals in the browser continuously, so it is
            # kept disabled.
            dcc.Interval(id='refresh', interval=1, disabled=True),
            html.I("Delete marked node/document:"),
            html.Button("Delete", id="delete_button", n_clicks=0)], width=2),
        # Opened when "Delete" is pressed without a document node tapped.
        dbc.Modal([
            dbc.ModalHeader("Warning!"),
            dbc.ModalBody("You have to select a document (node in graphs, except document summary) to delete"),
            dbc.ModalFooter(dbc.Button(
                "Close", id="close_delete_document_warning", className="ml-auto")),
        ],
            id="delete_document_warning",
            centered=True),
        dbc.Col([
            cyto.Cytoscape(
                id='cytoscape',
                minZoom=0.2,
                maxZoom=2,
                autoRefreshLayout=True,
                layout={'name': 'cose', 'animate': True},
                style={'width': '100%', 'height': '400px'},
                elements=get_graph_topic_nodes() + get_graph_document_nodes() + get_graph_cos_sim_edges() + get_doc_topic_edges(),
                stylesheet=update_stylesheet()
            )
        ], width=3),
        ###### initialise the place for the cluster summary view,
        ###### update it with a callback
        dbc.Col([html.Div([
            html.H3('Cluster Summary View'),
            html.Div([dcc.Graph(id='clust_sum_graph',
                                style={'width': '100%', 'height': '40vh', },
                                figure=go.Figure(data=go.Treemap(labels=initial_sum_view['labels'],
                                                                 parents=initial_sum_view['parents'],
                                                                 marker_colors=initial_sum_view['marker_colors'],
                                                                 text=initial_sum_view['text_info']),
                                                 layout=go.Layout(margin={'t': 0, 'l': 0, 'r': 5, 'b': 2})))]),
            html.Div([
                html.Button("Delete Cluster", id="delete_cluster_button", n_clicks=0),
                dbc.Modal([
                    dbc.ModalHeader("Warning!"),
                    dbc.ModalBody("You have to select a cluster first from the Cluster summary view"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-warning", className="ml-auto")
                    ),
                ],
                    id="cluster_delete_failed_warning",
                    centered=True,
                )])])
        ], width=2),
        dbc.Col([html.Div([
            html.I("Merge clusters:"),
            html.Br(),
            dcc.Checklist(
                options=build_cluster_merge_list(),
                id='cluster_merge_checklist'
            ),
            html.Button("Merge Clusters", id="merge_cluster_button", n_clicks=0),
            dbc.Modal([
                dbc.ModalHeader("Warning!"),
                # This modal guards the merge checklist. It previously reused the
                # delete-cluster text, which points to the wrong control.
                dbc.ModalBody("You have to select at least one cluster to merge from the checklist first"),
                dbc.ModalFooter(
                    dbc.Button("Close", id="close_cluster_merge_warning", className="ml-auto")
                ),
            ],
                id="cluster_merge_failed_warning",
                centered=True,
            )])], width=1),
        ### initialize the views, update by callback
        dbc.Col([html.Div([
            html.H3('Term Weight View'),
            # Number of top words shown in the table and bar chart.
            dcc.Input(id="input_barchart", type="number", min=2, value=10),
            html.Div([
                dash_table.DataTable(
                    id='table',
                    editable=True,
                    style_cell={'width': '50px',
                                'height': '30px',
                                'textAlign': 'left'}
                )
            ], id='term-weighttable')
        ], style={'display': 'none'}, id='term-weight-input')], width=2),
        dbc.Col([html.Div([
            html.Button('Change Prob.', id='change-probs', n_clicks=0),
            html.Div([], id='bar_chart')
        ], style={'display': 'none'}, id='change-prob-btn-div')], width=2)
    ],
        style={'padding': 10}
    ),

    ########### SECOND ROW: parallel coordinates plot ###################
    dbc.Row([
        dbc.Col([
            html.H3('Parallel Coordinates View'),
            dcc.Graph(id='parallel_coord',
                      style={'width': '130vh', 'height': '50vh'},
                      figure=go.Figure(data=go.Parcoords(
                          line=dict(color=lda.get_parall_coord_df()['Dominant_Topic'],
                                    colorscale=list(lda.get_colors().values())[:lda.get_k()]),
                          dimensions=[
                              dict(range=[0, 1],
                                   label=f'Cluster {i}',
                                   values=lda.get_parall_coord_df()[i])
                              for i in range(lda.get_k())
                          ]
                      )))
        ], width=5),
        dbc.Col([html.I("Set threshold for the probability, that a document belongs to the cluster :"),
                 html.I("Filter out noisy documents"),
                 dcc.Slider(
                     id='pc_slider',
                     min=0, max=1, value=0.4, step=0.05
                 ),
                 ], width=2),
        dbc.Col([html.Div([
            html.H3('Word Cloud'),
            html.Div([], id='word_cloud')
        ], style={'display': 'none'}, id='word_cloud_style')], width=5)
    ]),
    ########### THIRD ROW: document view #####################
    dbc.Row([
        dbc.Col([
            html.Div([
                html.Div([], id='dt')
            ], style={'display': 'none'}, id='dt_input')], width=8)
    ])])
    
    
   
    
        
#])    

  

######### Graph / parallel coordinates / cluster summary callback ##########


def _build_graph_elements():
    """Return every Cytoscape element of the Cluster Relation View.

    The order is: topic nodes, document nodes, cosine-similarity edges, then
    the invisible document-to-topic edges.
    """
    return (get_graph_topic_nodes() + get_graph_document_nodes()
            + get_graph_cos_sim_edges() + get_doc_topic_edges())


def _cluster_colorscale():
    """Return the colors of the active clusters, in ``lda.get_colors()`` order."""
    return list(lda.get_colors().values())[:lda.get_k()]


def _build_parcoords_figure(df, colorscale):
    """Build a parallel-coordinates figure with one axis per cluster.

    ``df`` needs a ``'Dominant_Topic'`` column, which drives the line color
    through ``colorscale``. It also needs one column per cluster, labelled
    ``0 .. lda.get_k() - 1``; each is plotted on a fixed [0, 1] range.
    """
    return go.Figure(data=go.Parcoords(
        line=dict(color=df['Dominant_Topic'], colorscale=colorscale),
        dimensions=[dict(range=[0, 1], label=f'Cluster {i}', values=df[i])
                    for i in range(lda.get_k())]))


def _build_highlighted_parcoords_figure(node_id):
    """Build a parallel-coordinates figure with the tapped document drawn in black.

    Black is appended after the cluster colors. On a copy of the frame (made
    by ``reset_index``), the document's ``Dominant_Topic`` is set to
    ``lda.get_k()``, one past the last cluster, so only its line uses black.

    The old code discarded the result of ``set_index`` and indexed ``.loc``
    with the string node id. That appended a new row instead of recoloring
    the document. Cluster nodes and unknown ids now yield the plain figure.
    """
    df = lda.get_parall_coord_df().reset_index()
    # NOTE(review): assumes document node ids are 1-based row positions of
    # the parallel-coordinates frame, as the original code intended. Confirm
    # against LDA_simplified.
    df.index = range(1, len(df) + 1)
    if 'Cluster' in node_id or int(node_id) not in df.index:
        return _build_parcoords_figure(df, _cluster_colorscale())
    df.loc[int(node_id), 'Dominant_Topic'] = lda.get_k()
    return _build_parcoords_figure(df, _cluster_colorscale() + ['black'])


def _rebuild_cluster_views():
    """Recompute every view that depends on the set of clusters.

    Returns ``(elements, layout, parcoords_figure, stylesheet,
    cluster_summary_figure, checklist_values, checklist_options)``. The merge
    checklist selection is cleared, and its options are rebuilt from the
    current clusters.
    """
    elements = _build_graph_elements()
    layout = {'name': 'cose'}
    parcoords_figure = _build_parcoords_figure(lda.get_parall_coord_df(), _cluster_colorscale())
    stylesheet = update_stylesheet()
    summary = build_cluster_summary_view()
    # Same margins as the treemap in the initial layout (one path used r=15 before).
    cluster_summary_figure = go.Figure(
        go.Treemap(labels=summary['labels'],
                   parents=summary['parents'],
                   marker_colors=summary['marker_colors'],
                   text=summary['text_info']),
        layout=go.Layout(margin={'t': 0, 'l': 0, 'r': 5, 'b': 2}))
    return (elements, layout, parcoords_figure, stylesheet, cluster_summary_figure,
            None, build_cluster_merge_list())


@app.callback(    
    Output('cytoscape', 'stylesheet'),        
    Output('cytoscape','elements'),
    Output('cytoscape','layout'),
    Output('parallel_coord','figure'),
    Output('clust_sum_graph','figure'),
    Output('cluster_delete_failed_warning','is_open'),
    Output('number_of_topics_warning','is_open'),
    Output('delete_document_warning','is_open'),
    Output('cluster_merge_failed_warning','is_open'),
    Output('cluster_merge_checklist', 'value'),
    Output('cluster_merge_checklist', 'options'),
    Input('my-daq-slider-ex', 'value'),
    Input("update_button", "n_clicks"),
    Input('delete_button', "n_clicks"),
    Input('cytoscape','tapNodeData'),
    Input('cytoscape', 'selectedNodeData'),
    Input('pc_slider','value'),
    Input('change-probs', "n_clicks"),
    Input('clust_sum_graph','clickData'),
    Input('delete_cluster_button', "n_clicks"),
    Input('merge_cluster_button', "n_clicks"),
    Input("close-warning", "n_clicks"),
    Input("close_warning_num_of_cluster", "n_clicks"),
    Input("close_delete_document_warning", "n_clicks"),
    Input("close_cluster_merge_warning", "n_clicks"),
    State('parallel_coord','figure'),
    State("input1", "value"),
    State('cytoscape','stylesheet'),
    State('cytoscape', 'elements'),
    State('cytoscape', 'layout'),
    State('clust_sum_graph','figure'),
    State('clust_sum_graph','clickData'),
    State('cluster_delete_failed_warning','is_open'),
    State('number_of_topics_warning','is_open'),
    State('delete_document_warning','is_open'),
    State('cluster_merge_failed_warning','is_open'),
    State('cluster_merge_checklist','value'),
    State('cluster_merge_checklist', 'options'),

    prevent_initial_call = True
    
)
def update_graph(value_slider, update_n_button,delete_button, tapNodeData,selectedNodeData,
                 pc_slider,change_prob_button,clust_sum_data,delete_cluster_button,merge_cluster_button,clust_delete_warn_button,
                 close_warning_num_of_cluster,close_delete_document_warning,close_cluster_merge_warning,
                 pc_figure,cluster_number_value, stylesheet, elements,layout,clust_sum_graph, clust_sum_latest,
                 cluster_delete_failed_warning,number_of_topics_warning,delete_document_warning,cluster_merge_warning,
                 cluster_merge_checklist_vals,cluster_merge_checklist_opts):
    """Single callback for the graph, parallel coordinates, treemap, merge checklist and warnings.

    The branch is chosen by the id of the component that fired
    (``ctx.triggered[0]``). Every output a branch does not touch is echoed
    back from its matching ``State``. The return list follows the order of
    the ``Output`` declarations above.
    """
    ctx = dash.callback_context
    clicked_element = ctx.triggered[0]['prop_id'].split('.')[0]

    # By default every output is returned unchanged.
    figure = pc_figure
    clust_sum_figure = clust_sum_graph
    cluster_delete_warn_dialog = cluster_delete_failed_warning
    update_cluster_warn_dialog = number_of_topics_warning
    document_delete_warn_dialog = delete_document_warning
    close_cluster_merge_warning_dialog = cluster_merge_warning
    cluster_merge_checklist_values = cluster_merge_checklist_vals
    cluster_merge_checklist_options = cluster_merge_checklist_opts

    #################### cosine similarity ############################
    if clicked_element == 'my-daq-slider-ex':
        lda.set_cosine_sim_treshold(ctx.triggered[0]['value'])
        elements = _build_graph_elements()

    #################  number of clusters ##############################
    elif clicked_element == 'update_button':
        if cluster_number_value is None:
            # Presumably the number field is empty or outside its min/max bounds.
            update_cluster_warn_dialog = True
        else:
            lda.set_number_of_clusters(cluster_number_value)
            lda.update_lda()
            (elements, layout, figure, stylesheet, clust_sum_figure,
             cluster_merge_checklist_values, cluster_merge_checklist_options) = _rebuild_cluster_views()

    ################### highlight rows #################################
    # The tapped document is highlighted on the parallel coordinates plot.
    elif clicked_element == 'cytoscape' and tapNodeData is not None:
        figure = _build_highlighted_parcoords_figure(tapNodeData['id'])

    ################## delete documents #################################
    elif clicked_element == 'delete_button':
        # tapNodeData is None when no node was tapped yet.
        if tapNodeData is not None and 'Cluster' not in tapNodeData['id']:
            lda.remove_document(int(tapNodeData['id']))
            elements = _build_graph_elements()
        else:
            document_delete_warn_dialog = True

    ######### Filter the parallel coordinates by the given threshold #####
    elif clicked_element == 'pc_slider':
        lda.filter_parall_coords_topic_contribution(ctx.triggered[0]['value'])
        figure = _build_parcoords_figure(lda.get_filtered_parall_coords_df(), _cluster_colorscale())

    ################## select cluster in the summary view ###############
    elif clicked_element == 'clust_sum_graph':
        # Remember the clicked cluster so "Delete Cluster" knows what to delete.
        if clust_sum_data is not None:
            lda.set_last_selected_cluster_from_clust_sum_view(clust_sum_data['points'][0]['label'])

    ################## delete cluster ###################################
    elif clicked_element == 'delete_cluster_button':
        if clust_sum_latest is None:
            # No cluster has been selected in the Cluster Summary View yet.
            cluster_delete_warn_dialog = True
        else:
            lda.delete_cluster()
            (elements, layout, figure, stylesheet, clust_sum_figure,
             cluster_merge_checklist_values, cluster_merge_checklist_options) = _rebuild_cluster_views()

    ################## merge clusters ###################################
    elif clicked_element == 'merge_cluster_button':
        if not cluster_merge_checklist_vals:
            close_cluster_merge_warning_dialog = True
        # TODO: merging the selected clusters is not implemented yet.

    ################## close warning dialogs ############################
    elif clicked_element == 'close-warning':
        cluster_delete_warn_dialog = False
    elif clicked_element == 'close_warning_num_of_cluster':
        update_cluster_warn_dialog = False
    elif clicked_element == 'close_cluster_merge_warning':
        close_cluster_merge_warning_dialog = False
    elif clicked_element == 'close_delete_document_warning':
        document_delete_warn_dialog = False

    return [stylesheet, elements, layout, figure, clust_sum_figure,
            cluster_delete_warn_dialog, update_cluster_warn_dialog, document_delete_warn_dialog,
            close_cluster_merge_warning_dialog, cluster_merge_checklist_values, cluster_merge_checklist_options]

@app.callback(
    Output('empty-div', 'children'),
    Input('cytoscape', 'mouseoverNodeData'),
    Input('cytoscape','mouseoverEdgeData'),
    Input('cytoscape','tapEdgeData'),
    Input('cytoscape','tapNodeData'),
    Input('cytoscape','selectedNodeData')
)
def update_layout(mouse_on_node, mouse_on_edge, tap_edge, tap_node, snd):
    """Debug hook: print Cytoscape hover, tap and selection events to stdout.

    Returns a fixed hint string that is shown in the 'empty-div' element.
    """
    separator = "------------------------------------------------------------"
    events = (
        ("Mouse on Node: {}", mouse_on_node),
        ("Mouse on Edge: {}", mouse_on_edge),
        ("Tapped Edge: {}", tap_edge),
        ("Tapped Node: {}", tap_node),
    )
    for template, payload in events:
        print(template.format(payload))
    print(separator)
    print("All selected Nodes: {}".format(snd))
    print(separator)

    return 'see print statement for nodes and edges selected.'


# Highlight document words:
# https://www.machinelearningplus.com/nlp/topic-modeling-visualization-how-to-present-results-lda-models/#6.-What-is-the-Dominant-topic-and-its-percentage-contribution-in-each-document

############################  Term-weight view   #######################


def _build_term_weight_bar(word_probs, bar_color):
    """Build the horizontal bar chart of ``word_probs`` ('Words' / 'Probabilities').

    Rows are sorted by ascending probability before plotting, so the most
    probable word ends up at the top of the horizontal bar chart. Both
    callers now use this ordering.
    """
    return html.Div([dcc.Graph(
        id='horizontal_bar_plot',
        style={'width': '150%', 'height': '400px'},
        figure=px.bar(word_probs.sort_values(by="Probabilities", ascending=True),
                      x="Probabilities",
                      y="Words",
                      color_discrete_sequence=[bar_color] * len(word_probs),  # cluster color
                      orientation='h'))])


def _build_word_cloud(cluster, n_words=20):
    """Build a scatter-based 'word cloud' of the top ``n_words`` words of ``cluster``.

    Words are placed at random positions. Font size is probability * 2000,
    and every word uses the cluster's color.
    """
    # Imported here: `random` used to come only from a later exploration
    # cell, so this failed with a NameError on a fresh kernel.
    import random

    top_words = lda.get_top_n_word_probs_for_topic_i(cluster, n_words)
    n_points = len(top_words)  # one position per word
    scatter = go.Scatter(
        x=[random.random() for _ in range(n_points)],
        y=[random.random() for _ in range(n_points)],
        mode='text',
        text=top_words.Words,
        marker={'opacity': 0.3},
        textfont={'size': top_words.Probabilities * 2000,
                  'color': lda.get_colors()[cluster]})
    layout = go.Layout({"plot_bgcolor": "rgba(0, 0, 0,0)",
                        'xaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False},
                        'yaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False}})
    # 'word_cloud' is already the id of the container Div in the layout, so
    # the Graph gets its own id to avoid a duplicate component id.
    return html.Div([dcc.Graph(id='word_cloud_plot',
                               figure=go.Figure(data=[scatter], layout=layout))])


@app.callback(
    Output('table', 'data'),
    Output('table', 'columns'),
    Output('bar_chart', 'children'),
    Output('change-prob-btn-div', 'style'),
    Output('term-weight-input', 'style'),  
    Output('word_cloud_style', "style"),
    Output('word_cloud', "children"),
    Input('cytoscape', 'tapNodeData'),
    Input('input_barchart', 'value'),
    Input('change-probs', "n_clicks"),
    State('table', 'columns'),  
    State('table', 'data'),
    prevent_initial_call = True
)
def update_barplot(tapNodeData,value, btn_change_prob,table_column,table_data):
    """Fill the Term Weight View (table, bar chart) and the word cloud.

    * "Change Prob." pressed: write the (possibly edited) table probabilities
      back into the model, re-run it, and redraw the views of the last
      selected cluster.
    * A cluster node tapped (or the word count changed while one is
      tapped): show that cluster's top ``value`` words.
    * Anything else (document node, nothing tapped, empty word count):
      leave the views unchanged.
    """
    import io

    ctx = dash.callback_context
    style = {'display': 'block'}
    columns = [{'name': i, 'id': i} for i in ['Words', 'Probabilities']]
    clicked_element = ctx.triggered[0]['prop_id'].split('.')[0]

    if value is None:
        # The 'input_barchart' number field is empty or invalid.
        return dash.no_update

    ### "Change Prob." button
    if clicked_element == 'change-probs':
        # StringIO: passing a literal JSON string to read_json is deprecated.
        # read_json also coerces edited (string) probabilities back to numbers.
        changed_data = pd.read_json(io.StringIO(json.dumps(table_data)), orient='records').sort_values(
            by="Probabilities", ascending=False)

        edited_cluster = lda.get_last_selected_cluster()
        for _, row in changed_data.iterrows():
            lda.update_term_topic_weight(edited_cluster, row['Words'], row["Probabilities"])
        lda.term_prob_update_lda()
        lda.set_word_probabilities(changed_data)

        cluster = lda.get_last_selected_cluster()
        word_probs = lda.get_top_n_word_probs_for_topic_i(cluster, value)
        table_data = word_probs.sort_values(by="Probabilities", ascending=False).to_dict('records')
        bar_chart = _build_term_weight_bar(word_probs, lda.get_colors()[cluster])
        word_cloud = _build_word_cloud(cluster)

    #### a cluster node is tapped in the graph
    elif tapNodeData is not None and 'Cluster' in tapNodeData['id']:
        cluster_id = int(tapNodeData['id'].replace('Cluster ', ''))
        lda.set_last_selected_cluster(cluster_id)

        word_probs = lda.get_top_n_word_probs_for_topic_i(cluster_id, value)
        table_data = word_probs.sort_values(by="Probabilities", ascending=False).to_dict('records')
        cluster = lda.get_last_selected_cluster()
        bar_chart = _build_term_weight_bar(word_probs, lda.get_colors()[cluster])
        word_cloud = _build_word_cloud(cluster)

    else:
        return dash.no_update

    return [table_data, columns, bar_chart, style, style, style, word_cloud]


    
    
######## Show the documents from a choosen cluster in the document view #########

@app.callback(
    Output('dt', 'children'),
    Output('dt_input', 'style'),
    Input('cytoscape', 'tapNodeData'),
    prevent_initial_call=True,
)
def update_result(tapNodeData):
    """Show the tapped document node, with highlighted terms, in the document view.

    Fired when a node of the 'cytoscape' graph is tapped. Taps whose node id
    contains 'Cluster' (presumably the topic nodes) and empty payloads leave
    both outputs unchanged via ``dash.no_update``.

    Returns:
        A two-element list for ('dt', 'children') and ('dt_input', 'style'):
        a Div holding the document title as an <h2> followed by the
        highlighted document HTML, and ``{'display': 'block'}``.
    """
    if tapNodeData is None or 'Cluster' in tapNodeData['id']:
        return dash.no_update

    # Non-cluster node ids are expected to parse as integer document ids;
    # resolve the title, then select the matching corpus row(s) by title.
    node_title = lda.get_document_title_by_id(int(tapNodeData['id']))
    corpus = lda.get_data()
    doc_rows = corpus[corpus['title'] == node_title]
    highlighted_html = lda.build_term_higlights(doc_rows)

    # NOTE(review): title and highlighted text are injected as raw HTML without
    # escaping -- acceptable only for a trusted local corpus.
    title_html = dash_dangerously_set_inner_html.DangerouslySetInnerHTML(f'''<h2>{node_title}</h2>''')
    body_html = html.Div([dash_dangerously_set_inner_html.DangerouslySetInnerHTML(highlighted_html)])
    document_view = html.Div(children = [title_html, body_html])

    visible = {'display': 'block'}
    return [document_view, visible]
    
  
# Start the Dash server on a fixed local port, with dev-tools hot reloading on.
# The port is a named constant so it can be changed in one place if taken.
DASH_PORT = 8051
app.run_server(port=DASH_PORT, dev_tools_hot_reload=True)


Dash app running on http://127.0.0.1:8051/


In [27]:
import plotly
import plotly.graph_objs as go
from plotly.offline import plot
import random
# Word-cloud prototype: place the top-N words of one topic at random positions,
# with font size proportional to each word's probability.
WORD_CLOUD_TOPIC = 2
WORD_CLOUD_N_WORDS = 10
WORD_CLOUD_FONT_SCALE = 4000  # top-10 probabilities for topic 2 are roughly 0.005-0.013
WORD_CLOUD_SEED = 42

# Fetch the words once and reuse them for text and font sizes.
top_words = lda.get_top_n_word_probs_for_topic_i(WORD_CLOUD_TOPIC, WORD_CLOUD_N_WORDS)
n_words = len(top_words)

# Exactly one position per word so x, y and text line up; a private RNG makes
# the layout reproducible without reseeding the global `random` module.
layout_rng = random.Random(WORD_CLOUD_SEED)
data = go.Scatter(x=[layout_rng.random() for _ in range(n_words)],
                  y=[layout_rng.random() for _ in range(n_words)],
                  mode='text',
                  text=top_words.Words,
                  marker={'opacity': 0.3},
                  textfont={'size': top_words.Probabilities * WORD_CLOUD_FONT_SCALE,
                            'color': lda.get_colors()[WORD_CLOUD_TOPIC]})
layout = go.Layout({'xaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False},
                    'yaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False}})
fig = go.Figure(data=[data], layout=layout)

fig

In [28]:
# Word cloud for one cluster: the top `value` words of topic `cluster_id`,
# placed at random positions, sized by probability.
cluster_id = 2
value = 20
seed = 42  # fixed so the random word placement is reproducible

# Fetch the words once and reuse them for text and font sizes.
top_words = lda.get_top_n_word_probs_for_topic_i(cluster_id, value)
n_words = len(top_words)

# One position per word (x, y and text must line up); a private RNG avoids
# reseeding the global `random` module.
rng = random.Random(seed)
data = go.Scatter(
    x=[rng.random() for _ in range(n_words)],
    y=[rng.random() for _ in range(n_words)],
    mode='text',
    text=top_words.Words,
    marker={'opacity': 0.3},
    textfont={'size': top_words.Probabilities * 4000,
              'color': lda.get_colors()[cluster_id]})
layout = go.Layout({'xaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False},
                    'yaxis': {'showgrid': False, 'showticklabels': False, 'zeroline': False}})
figure = go.Figure(data=[data], layout=layout)
figure

In [29]:
# Top-10 word probabilities of topic 2, scaled by 1000 for readability
# (presumably to help choose the word-cloud font-size multiplier).
lda.get_top_n_word_probs_for_topic_i(2, 10)['Probabilities'].mul(1000)

9     5.101448
8     5.779908
7     6.137368
6     6.290117
5     8.626731
4     8.884333
3     9.376108
2    11.908143
1    11.977700
0    12.757132
Name: Probabilities, dtype: float64

In [30]:
# Scratch prototype: colour each word of one document according to its lemma's
# entry in lda.get_terms_with_best_topic(). A finished version presumably lives
# in lda.build_term_higlights (called from update_result) -- TODO confirm, then
# delete this cell.
# FIXME: this cell does not run on a fresh kernel; see the notes below.
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
import dash_core_components as dcc
# FIXME: to_dict('records') returns a list of row dicts, which has no .keys()
# and cannot be indexed by term. A term -> colour dict is presumably intended
# (e.g. dict(zip(...)) over two columns) -- TODO confirm the column names.
top_terms = lda.get_terms_with_best_topic().to_dict('records')

# FIXME: `node_title` is only assigned as a local inside update_result, so this
# line raises NameError here. Even when defined, the boolean-mask selection is a
# DataFrame (assuming get_data() returns one), which has no .split(); a text
# column presumably has to be selected first -- TODO confirm which.
data=lda.get_data()[lda.get_data()['title'] == node_title].split(' ')
for d in data:
    lemmat_d = lemmatizer.lemmatize(d)
    color = 'black'
    if lemmat_d in top_terms.keys():
        color = top_terms[lemmat_d]
    # NOTE(review): `color` is computed but never used or collected.



# Styling experiment: a Markdown child inside a Div whose style sets color red.
# `html` is not imported in this cell; it relies on an earlier import.
html.Div(dcc.Markdown('''Am I red yet?'''), style={'color':'red'})

#print(lda.get_terms_with_best_topic())

NameError: name 'node_title' is not defined

In [None]:
# Inspect the LDA state's sufficient statistics (read-only).
# Do not assign to this entry: overwriting it (e.g. with a scalar) discards the
# fitted statistics, and later cells that reduce it along an axis then fail.
vars(lda.lda_get_state())['sstats']

In [None]:
# All instance attributes of the LDA state object, for exploration.
vars(lda.lda_get_state())

In [None]:

# Instance attributes of the LDA state object (same view as vars() on the state).
vars(lda.lda_get_state())

In [None]:
# All instance attributes of the underlying LDA model, for exploration.
vars(lda.lda_get_lda_model())

In [None]:
# Number of topics the model was built with (read-only).
# Assigning this field alone does not refit the model, so it would presumably
# disagree with the statistics fitted for the 5 topics requested via LDA(5, ...).
# To change the topic count, rebuild/refit the model -- TODO confirm the API.
vars(lda.lda_get_lda_model())['num_topics']

In [None]:
# Document-topic prior of the model. The key is spelled out in full; a
# truncated key raises KeyError. (Presumably gensim's LdaModel.alpha -- confirm.)
vars(lda.lda_get_lda_model())['alpha']

In [None]:
# Instance attributes of the underlying LDA model.
vars(lda.lda_get_lda_model())

In [None]:
import matplotlib.cm as mplcm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Sample NUM_COLORS evenly spaced colours from the 'gist_rainbow' colormap,
# e.g. as a palette for the topic clusters.
NUM_COLORS = 5  # TODO confirm: assumed one colour per topic (model built with LDA(5, ...))

cmap = plt.get_cmap('gist_rainbow')
# Map integer indices 0..NUM_COLORS-1 onto the colormap's [0, 1] range.
cNorm = mcolors.Normalize(vmin=0, vmax=NUM_COLORS - 1)
scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cmap)

[mcolors.to_hex(scalarMap.to_rgba(i)) for i in range(NUM_COLORS)]


In [None]:
import numpy as np

# Marginal totals of the LDA state's sufficient statistics.
# NOTE(review): assumes sstats is 2-D, shaped (num_topics, num_terms) as in
# gensim's LdaState -- TODO confirm. Under that assumption, axis=1 gives one
# total per topic and axis=0 one total per vocabulary term.
sstats = vars(lda.lda_get_state())['sstats']
topic_totals = np.sum(sstats, axis=1)
ee = np.sum(sstats, axis=0)  # per-term totals under the assumption above
print(topic_totals)
print(ee)


In [None]:
# Current sufficient statistics held by the LDA state.
vars(lda.lda_get_state())['sstats']