# Test some visualizations and data transforms

## Import all packages

In [None]:
import pandas as pd
import plotly.express as px
## Dash
from dash import  dcc
import dash_bootstrap_components as dbc
import dash
from jupyter_dash import JupyterDash
from dash import html
import dash_html_components as html

import pyTigerGraph as tg

import dash_cytoscape as cyto



# Finally some TigerGraph

Before beginning, we must execute two queries that populate portions of our graph. More specifically, running the `ex2_main_query` query followed by the `algo_louvain` query creates the referral edges used to link prescribers together. These edges are then used to detect communities of prescribers, which are useful in analyzing the relationships among healthcare providers, claims, and referrals.

To run these queries, we simply execute the following:


In [None]:
# Connection settings for the TigerGraph instance. Each value can be overridden
# through an environment variable so that hosts and credentials do not have to
# be edited into the notebook; the literals are only the fallbacks used when
# the corresponding variable is unset.
import os

TG_HOST = os.environ.get("TG_HOST", "http://protomolecule.magichome")
TG_USERNAME = os.environ.get("TG_USERNAME", "tigergraph")  # This should remain the same...
TG_PASSWORD = os.environ.get("TG_PASSWORD", "Tigergraph")  # Shh, it's our password!
TG_GRAPHNAME = os.environ.get("TG_GRAPHNAME", "HealthCareReferral")  # The name of the graph
conn = tg.TigerGraphConnection(host=TG_HOST, username=TG_USERNAME, password=TG_PASSWORD, graphname=TG_GRAPHNAME)

# A previously issued API token is assigned directly instead of requesting a
# new one (the request is kept below for reference).
# NOTE(review): the fallback token is a secret stored in the notebook; prefer
# supplying TG_TOKEN from the environment and rotating the embedded value.
# token = conn.getToken(conn.createSecret())
token = os.environ.get("TG_TOKEN", '25gofu2jjfbj5kaf891pn2ms7p9kved7')
conn.apiToken = token

print("Connected!")
# print(token)

# Run the two installed queries the rest of the notebook depends on, in order.
conn.runInstalledQuery("ex2_main_query")
conn.runInstalledQuery("algo_louvain")

In [None]:
# Print each vertex type of the graph followed by its current vertex count.
print("Vertices \n")

for vertex in conn.getVertexTypes():
  print(f" {vertex}: {conn.getVertexCount(vertex)!r}")


In [None]:
# Print each edge type of the graph followed by its current edge count.
print("Edges \n")

for edge in conn.getEdgeTypes():
  print(f" {edge}: {conn.getEdgeCount(edge)!r}")


In [None]:
# Prescriber id whose claims are inspected in this and the following cells.
person_num = "pre78"

# Run the installed getClaims query for that prescriber and keep the 'claims'
# entry of the first result set.
claims = conn.runInstalledQuery("getClaims", params={"inputPrescriber": person_num})[0]['claims']
# Build a DataFrame from the claim records; as the cell's last expression it
# becomes the cell output.
pd.DataFrame.from_records(claims)


In [None]:
# Query the prescriber's claims again and walk them once, collecting a running
# index and a description for every claim plus a per-title claim count.
claims = conn.runInstalledQuery("getClaims", params={"inputPrescriber": person_num})[0]['claims']
title_map = {}
count_list = []
description_list = []

for number, claim in enumerate(claims):
  title = claim['attributes']['CodeGroupTitle']
  desc = claim['attributes']['ICD10CodeDescription']

  # A blank description is replaced by a placeholder text.
  desc = "None provided!" if desc == "" else desc

  # First sighting of a title starts at 1, later ones increment.
  title_map[title] = title_map.get(title, 0) + 1

  count_list.append(number)
  description_list.append(desc)

In [None]:
# Split the per-title tallies into two parallel lists (titles, counts).
titleList = list(title_map)
countList = list(title_map.values())

# Next, we create a bar chart using a DataFrame

countData = pd.DataFrame(data=zip(titleList, countList), columns=['Claim Title', 'Count'])
bar = px.bar(
    countData,
    x='Claim Title',
    y='Count',
    title='',
    color_discrete_sequence=["#DDEE00"] * len(countData),
)

# Categorical x axis sorted by category name; tight margins, ggplot2 look.
bar.update_xaxes(type='category', categoryorder='category ascending')
bar.update_layout(margin={'l': 1, 'r': 1, 't': 1, 'b': 1}, template='ggplot2', xaxis_title=None)

bar.show()

In [None]:
# Two-column frame pairing each claim's running index with its description.
descriptionData = pd.DataFrame(data=(zip(count_list, description_list)), columns=['Claim Num', 'Claim Description'])
# NOTE(review): `header` is built here but nothing visible in the notebook
# passes it to the table — confirm whether it is still needed.
header = [html.Thead(html.Tr([html.Th("Claim Number"), html.Th("Claim Description")]))]

# Bootstrap table generated from the frame, wrapped in a fixed-height Div that
# scrolls vertically.
table = html.Div(
    dbc.Table.from_dataframe(descriptionData, striped=True, bordered=True),
    style={'overflowY':'scroll', 'height':'450px'}
)


In [None]:
def getNetwork(person_num):
  """Build a Cytoscape component showing the community of ``person_num``.

  Runs the installed ``Print_community`` query and converts the edge list
  stored under ``'@@edgeList'`` in the second result set into Cytoscape
  elements: one node per distinct prescriber id and one edge per list entry.
  The node whose id equals ``person_num`` carries the ``red`` class so the
  stylesheet can highlight it.

  Args:
    person_num: Prescriber id passed to the query as ``inputPrescriber``.

  Returns:
    A ``cyto.Cytoscape`` component with id ``'cytoscape'``.
  """
  comms = conn.runInstalledQuery("Print_community", params={"inputPrescriber": person_num})[1]['@@edgeList']

  # Ids that already have a node element. Recording them guarantees each node
  # is emitted exactly once no matter how many edges touch it.
  seen = set()
  els = []

  def add_node(node_id):
    # Emit a node element the first time an id is encountered.
    if node_id in seen:
      return
    seen.add(node_id)
    node = {'data': {'id': node_id, 'label': node_id}}
    # Highlight the queried prescriber whether it shows up as an edge source
    # or as an edge target.
    if node_id == person_num:
      node['classes'] = 'red'
    els.append(node)

  for entry in comms:
    source = entry['from_id']
    target = entry['to_id']
    add_node(source)
    add_node(target)
    els.append({'data': {'source': source, 'target': target}})

  network = cyto.Cytoscape(
                  id='cytoscape',
                  elements=els,
                  layout={'name': 'breadthfirst', 'padding':0, 'x1':-1000},
                  stylesheet= [
                      {
                          # Every node shows its label.
                          'selector': 'node',
                          'style': {
                              'content': 'data(label)'
                          }
                      },
                      {
                          # Nodes tagged with the 'red' class are drawn in red.
                          'selector': '.red',
                          'style': {
                              'background-color': 'red',
                            }
                      }
                  ],

                style={'width': '100%', 'height': '500px', 'margin-left':0}
              )

  return network


# GSQL source of the Claim_Titles query: it visits every Claim vertex, adds 1
# to a map entry keyed by the claim's CodeGroupTitle, and prints the map as
# `freqClaims`. The assignment must be live code: the print below reads the
# variable, and the triple-quoted string has to be opened for the GSQL text to
# be valid Python.
Claim_Titles = '''USE GRAPH HealthCareReferral
CREATE QUERY Claim_Titles() FOR GRAPH HealthCareReferral {
    MapAccum<STRING, SumAccum<INT>> @@allClaims;

    start = {Claim.*};

    claims = SELECT c
                FROM start:c
                ACCUM @@allClaims += (c.CodeGroupTitle -> 1);

    PRINT @@allClaims as freqClaims;
}
INSTALL QUERY Claim_Titles'''

# Send the statements to the server and show its response.
print(conn.gsql(Claim_Titles, options=[]))


In [None]:
def getClaimsPieChart():
  """Return a donut-style pie chart of the claim frequencies.

  Runs the installed ``Claim_Titles`` query and plots the ``freqClaims``
  mapping of its first result set: keys become slice names, values become
  slice sizes.
  """
  freq_by_title = conn.runInstalledQuery("Claim_Titles")[0]['freqClaims']

  fig = px.pie(
      names=list(freq_by_title.keys()),
      values=list(freq_by_title.values()),
      hole=0.2,
  )

  # Wide canvas, no legend; labels and percentages are drawn inside the slices.
  fig.update_layout(
      width=2000,
      title_x=0.5,
      showlegend=False,
      margin={'l': 10, 'r': 10, 't': 10, 'b': 10},
  )
  fig.update_traces(textposition='inside', textinfo='label+percent')
  return fig


In [None]:
# Build the pie chart; as the cell's only expression, the returned figure is
# the cell output.
getClaimsPieChart()

In [None]:
# GSQL source of the Helper_NumMembers sub-query (declared to return INT).
# It stores the input prescriber's communityId in @@cid, then scans the
# referral edges of all Prescriber vertices and adds 1 to @@numMembers for
# every edge whose two endpoints both carry that communityId.
# NOTE(review): the accumulator is incremented per matching edge, so the value
# looks like an edge count rather than a count of distinct members — confirm
# that this is the intended "community size".
# Only the string is defined here; the install call below is commented out.
Helper_NumMembers = '''USE GRAPH HealthCareReferral
CREATE QUERY Helper_NumMembers(vertex<Prescriber> inputPrescriber) FOR GRAPH HealthCareReferral
RETURNS (INT) {
  SumAccum<INT> @@numMembers;
    SumAccum<int> @@cid;

    Start={inputPrescriber};
    Start=Select s from Start:s post-accum @@cid += s.communityId;

    Start = {Prescriber.*};

    Start = select s from Start:s-(referral:e)-:t
            where s.communityId == @@cid and s.communityId == t.communityId
            accum @@numMembers += 1;

  RETURN @@numMembers;
}
INSTALL QUERY Helper_NumMembers'''

# print(conn.gsql(Helper_NumMembers, options=[]))

In [None]:
# GSQL source of the Helper_NumReferrals sub-query (declared to return INT).
# Starting from the input prescriber, it adds 1 to @@numReferrals for every
# referral edge that leads to a Prescriber vertex and returns the total.
# Only the string is defined here; the install call below is commented out.
Helper_NumReferrals = '''USE GRAPH HealthCareReferral
CREATE QUERY Helper_NumReferrals(vertex<Prescriber> inputPrescriber) FOR GRAPH HealthCareReferral
RETURNS (INT) {

  SumAccum<INT> @@numReferrals;

    Start={inputPrescriber};

  referrals = SELECT p1
                FROM Start:p1 -(referral:r) -Prescriber:p2
                ACCUM @@numReferrals += 1;

  RETURN @@numReferrals;
}
INSTALL QUERY Helper_NumReferrals'''

# print(conn.gsql(Helper_NumReferrals, options=[]))


In [None]:
# GSQL source of the CommVsReferrals query. For every Prescriber vertex it
# feeds the pair (Helper_NumMembers(p) -> Helper_NumReferrals(p)) into a
# MapAccum<INT, AvgAccum>, so each key ends up with the average of the values
# added under it, and prints the map as `commReferrals`.
# It calls both helper sub-queries, so those presumably need to be installed
# first — verify on the target server.
# Only the string is defined here; the install call below is commented out.
CommVsReferrals = '''USE GRAPH HealthCareReferral
CREATE QUERY CommVsReferrals() FOR GRAPH HealthCareReferral {
  MapAccum<INT, AvgAccum> @@commReferrals;

  start = {Prescriber.*};

  allPres = SELECT p
              FROM start:p
              ACCUM @@commReferrals += (Helper_NumMembers(p) -> Helper_NumReferrals(p));

  PRINT @@commReferrals as commReferrals;
}
INSTALL QUERY CommVsReferrals'''

# print(conn.gsql(CommVsReferrals, options=[]))


In [None]:
def getScatterChart():
  """Return a scatter plot of community size against average referrals.

  Runs the installed ``CommVsReferrals`` query and plots the ``commReferrals``
  mapping of its first result set: keys on the x axis, values on the y axis.
  The values also drive marker size and colour.
  """
  avg_by_size = conn.runInstalledQuery("CommVsReferrals")[0]['commReferrals']

  sizes = list(avg_by_size)
  avgs = list(avg_by_size.values())

  fig = px.scatter(x=sizes, y=avgs, size=avgs, color=avgs)
  fig.update_coloraxes(colorbar_title="Avg")
  fig.update_layout(
      title="Prescriber Community Size vs. Average Referrals per Prescriber",
      xaxis_title="Prescriber Community Size",
      yaxis_title="Avg Referrals per Prescriber",
      width=1000,
  )
  return fig


In [None]:
# Build the scatter chart; as the cell's only expression, the returned figure
# is the cell output.
getScatterChart()

In [None]:
  # Header card holding the centred dashboard title.
  # NOTE(review): this statement is indented although it sits at the top level
  # of the cell; a plain Python module would reject that — confirm the
  # notebook tolerates it, or dedent the line.
  titleCard =  dbc.Card([
                dbc.CardBody([
                              html.Center(html.H1("TigerGraph's HealthCare Starter Kit", className='card-title')),
                            ])
              ],
              color='light', # Options include: primary, secondary, info, success, warning, danger, light, dark
              style={
                  "width":"55rem",
                  #"margin-left":"1rem",
                  "margin-top":"1rem",
                  "margin-bottom":"1rem"
                  }
            )

# One list item per vertex type (colour 'info') and one per edge type (colour
# 'success'), each labelled "<type>: <count>".
vItems = [
    dbc.ListGroupItem(f"{vertex}: {conn.getVertexCount(vertex)!r}", color='info')
    for vertex in conn.getVertexTypes()
]
eItems = [
    dbc.ListGroupItem(f"{edge}: {conn.getEdgeCount(edge)!r}", color='success')
    for edge in conn.getEdgeTypes()
]

# Vertices first, then edges, laid out as a single horizontal group.
listItems = [*vItems, *eItems]

statsListGroup = dbc.ListGroup(listItems, horizontal=True)


In [None]:
# Build both graph-wide figures once, then wrap each in an outlined 'info'
# card with a heading, a short prompt and the graph itself.
pieChart = getClaimsPieChart()
scatterChart = getScatterChart()

pieChartCard = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H1("All Submitted Claims", className='card-title'),
                html.P(
                    "Which categories of claims are most frequent?\n Which areas should prescribers focus on?",
                    className='card-body',
                ),
                dcc.Graph(id='Pie Chart', figure=pieChart),
            ]
        )
    ],
    outline=True,
    color='info',  # Options include: primary, secondary, info, success, warning, danger, light, dark
    style={
        "width": "50rem",
        "margin-right": "1rem",
        "margin-bottom": "1rem",
    },
)

scatterChartCard = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H1("Prescriber Communities", className='card-title'),
                html.P(
                    "Do communities lead to more business? How do the number of referrals compare to community size?",
                    className='card-body',
                ),
                dcc.Graph(id='Scatter Chart', figure=scatterChart),
            ]
        )
    ],
    outline=True,
    color='info',  # Options include: primary, secondary, info, success, warning, danger, light, dark
    style={
        "width": "50rem",
        "margin-left": "1rem",
        "margin-bottom": "1rem",
    },
)


In [None]:
def getClaims(person_num):
  """Summarise one prescriber's claims for the dashboard.

  Runs the installed ``getClaims`` query and derives a description table and
  a per-title bar chart from the ``claims`` entry of its first result set.

  Args:
    person_num: Prescriber id passed to the query as ``inputPrescriber``.

  Returns:
    A 4-tuple ``(count, table, bar, max_key)``:
      count   -- number of claims returned by the query;
      table   -- scrollable ``html.Div`` wrapping a table of claim
                 descriptions (numbered from 0);
      bar     -- bar figure of claim counts per ``CodeGroupTitle``;
      max_key -- the title with the most claims, or ``"N/A"`` when the
                 prescriber has no claims.
  """
  claims = conn.runInstalledQuery("getClaims", params={"inputPrescriber": person_num})[0]['claims']

  title_map = {}
  description_list = []

  for claim in claims:
    attributes = claim['attributes']
    title = attributes['CodeGroupTitle']
    desc = attributes['ICD10CodeDescription']

    # A blank description is replaced by a placeholder text.
    description_list.append("None provided!" if desc == "" else desc)
    title_map[title] = title_map.get(title, 0) + 1

  # We'll create a table w/ the Descriptions!
  descriptionData = pd.DataFrame(
      data=list(enumerate(description_list)),
      columns=['Claim Num', 'Claim Description'],
  )
  table = html.Div(
      dbc.Table.from_dataframe(descriptionData, striped=True, bordered=True),
      style={'overflowY':'scroll', 'height':'450px'}
  )

  # We'll create a bar chart w/ the Claim Titles
  countData = pd.DataFrame(data=list(title_map.items()), columns=['Claim Title', 'Count'])
  # No `color` column is given, so only the first entry of the colour sequence
  # is used; a one-element list also stays valid when there are no rows.
  bar = px.bar(countData, x='Claim Title', y='Count', title='', color_discrete_sequence=["#DDEE00"])

  bar.update_xaxes(type='category', categoryorder='category ascending')
  bar.update_layout(margin=dict(l=1, r=1, t=1, b=1), template='ggplot2', xaxis_title=None)

  # max() raises ValueError on an empty mapping; fall back to a string so the
  # caller can still concatenate the result into its summary sentence.
  max_key = max(title_map, key=title_map.get) if title_map else "N/A"

  return len(claims), table, bar, max_key


In [None]:
def getPrescriberInfo(person_num):
  """Assemble the per-prescriber section of the dashboard.

  Args:
    person_num: Prescriber id forwarded to ``getNetwork`` and ``getClaims``.

  Returns:
    A 5-tuple ``(prescriberTitleCard, tableCard, barCard, networkCard,
    network)``: four ``dbc.Card`` components plus the bare network component
    that is also embedded in ``networkCard``.
  """
  network = getNetwork(person_num)
  number, table, bar, max_title = getClaims(person_num)

  def info_card(body_children, width):
    # All four cards share outline, colour and margins; only the body
    # content and the width differ.
    return dbc.Card(
        [dbc.CardBody(body_children)],
        outline=True,
        color='info',
        style={
            "width": width,
            "margin-left": "1rem",
            "margin-bottom": "1rem",
            "margin-top": "1rem",
        },
    )

  heading = html.H1("  Prescriber " + person_num + "'s Claims  ", className='card-title')
  summary = html.P(
      "This individual has a total of " + repr(number) + " claims. Their most referred to specialization is: " + max_title,
      className='card-body',
  )
  prescriberTitleCard = info_card(
      [
          html.Center(dbc.Badge([heading], color="light")),
          html.Center(summary),
      ],
      "98rem",
  )

  tableCard = info_card(
      [
          html.H1("Claims by Prescriber", className='card-title'),
          html.P("A detailed description of each claim...", className='card-body'),
          table,
      ],
      "50rem",
  )

  barCard = info_card(
      [
          html.H1("Claims by Category", className='card-title'),
          html.P("Which claims are being prescribed most?", className='card-body'),
          dcc.Graph(id='Bar Chart', figure=bar),
      ],
      "50rem",
  )

  networkCard = info_card(
      [
          html.H1("Prescriber Network", className='card-title'),
          html.P("Who's part of this prescriber's community?", className='card-body'),
          network,
      ],
      "50rem",
  )

  return prescriberTitleCard, tableCard, barCard, networkCard, network


In [None]:
# Create the Dash app, build the per-prescriber cards, and stack everything
# into one centred page: title, graph statistics, the two graph-wide charts,
# then the prescriber section (title, table next to bar chart, network).
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

person_num = "pre78"
prescriberTitleCard, tableCard, barCard, networkCard, network = getPrescriberInfo(person_num)

app.layout = html.Center(
    html.Div(
        [
            dbc.Row(titleCard, justify="center"),
            dbc.Row(statsListGroup, justify="center"),
            html.Br(),
            dbc.Row([pieChartCard, scatterChartCard], justify="center"),
            html.Hr(),
            prescriberTitleCard,
            dbc.Row(
                [
                    dbc.Col([tableCard]),
                    dbc.Col([barCard]),
                ],
                justify="center",
            ),
            networkCard,
        ]
    )
)

app.run_server(mode='external')  # 'external' or 'inline'
# Alternative invocation kept for reference:
# app.run_server(mode='external', port=8555)
# Further options that were tried: dev_tools_ui=True, debug=True,
# dev_tools_hot_reload=True, threaded=True


In [None]:
# Evaluate the network component returned by getPrescriberInfo so that it is
# this cell's output.
network