In [12]:
import math
import os
import pickle
from ast import literal_eval

import chart_studio.plotly as py
import networkx as nx
import pandas as pd
import plotly
import plotly.graph_objects as go

from addEdge import addEdge


In [74]:
filepath = r"H:\Shared drives\Pandemic Data\slf_model\outputs\slf_scenarios"
header_path = os.path.join(filepath, "header.csv")
header = pd.read_csv(header_path)

# The "starting_countries" row of header.csv holds a Python-literal sequence in
# its third column (positional index 2 - NOTE(review): confirm the column name
# and prefer selecting it by name). literal_eval parses it without running eval.
# na=False: an empty `attributes` cell must count as "no match" rather than
# putting NaN into the boolean mask (which pandas refuses to index with).
countries_mask = header.attributes.str.contains("starting_countries", na=False)
countries_list = literal_eval(header[countries_mask].values[0, 2])

# First entry of the parsed sequence is the list of starting countries.
starting_countries = countries_list[0]

viz_data_path = os.path.join(filepath, "summary_data")

# NOTE(review): pickle.load can execute arbitrary code - only load this file
# from a trusted location. The context manager closes the file handle.
with open(os.path.join(viz_data_path, "summary_data.p"), "rb") as summary_file:
    summary_data = pickle.load(summary_file)
def country_codes(path="country_names.csv"):
    """Build a lookup from country name to ISO3 code.

    Reads a CSV with ``NAME`` and ``ISO3`` columns (a custom file derived from
    the probability file). Two entries are fixed in code: ``"Origin"`` maps to
    ``"ORG"`` unless the CSV itself lists "Origin", and ``"Taiwan"`` is always
    forced to ``"TWN"`` after the CSV is read.

    Args:
        path: CSV file to read. Defaults to ``"country_names.csv"`` in the
            current working directory, matching the original behavior.

    Returns:
        dict mapping country name -> ISO3 code. When a name appears more than
        once in the CSV, the last row wins.
    """
    names_data = pd.read_csv(path)

    codes = {"Origin": "ORG"}
    codes.update(zip(names_data["NAME"], names_data["ISO3"]))
    # Override applied last so it wins over whatever the CSV says for Taiwan.
    codes["Taiwan"] = "TWN"
    return codes

prop_dict = summary_data['BAU_alpha0.2_lamda3.64_6801-6804']['network']['prop_dict']

# Countries to plot: the starting countries plus every country whose prop_dict
# value is positive. Copy the list first: appending to an alias of
# starting_countries would also mutate starting_countries (and
# countries_list[0]). Later `in starting_countries` checks would then match
# every propagated country, and each re-run of this cell would grow the list
# again.
prop_countries = list(starting_countries)
prop_countries.extend(
    country
    for country, prop in prop_dict.items()
    if prop > 0 and country not in starting_countries
)




In [75]:
G = summary_data['BAU_alpha0.2_lamda3.64_6801-6804']['network']['diGraph']

# Restrict to the plotted countries. `subgraph` returns a view whose node/edge
# attribute dicts are shared with the stored graph, so `.copy()` is needed to
# keep the attribute writes below (and the 'pos' writes later) out of
# summary_data.
G = G.subgraph(prop_countries).copy()
country_codes_dict = country_codes()

# View of G with the 'China' node hidden.
# NOTE(review): H is not used anywhere in this notebook - confirm it is still needed.
H = nx.restricted_view(G, ['China'], [])

# Log-scaled introduction counts per edge.
# NOTE(review): math.log raises ValueError if num_intros <= 0; assumes every
# stored edge has at least one introduction - confirm.
for u, v, edge_data in G.edges(data=True):
    edge_data['log_intros'] = math.log(edge_data['num_intros'])


In [76]:
cent = nx.degree_centrality(G)

# Bucket nodes into shells: t0 = starting countries, then t1/t2/t3 by
# decreasing degree-centrality bands (> 0.5, > 0.1, everything else).
t0, t1, t2, t3 = [], [], [], []
for country, centrality in cent.items():
    if country in starting_countries:
        bucket = t0
    elif centrality > 0.5:
        bucket = t1
    elif centrality > 0.1:
        bucket = t2
    else:
        bucket = t3
    bucket.append(country)


In [77]:
# Concentric rings, in order: starting countries, then the three centrality bands.
shells = [t0, t1, t2, t3]
pos = nx.shell_layout(G, nlist=shells)

In [82]:

# Flatten the layout into parallel lists for the node scatter trace, and store
# each node's coordinates on the graph for the edge/annotation cells.
x_pos, y_pos = [], []
iso_labels, hover_labels = [], []
for country, xy in pos.items():
    node_x, node_y = xy
    x_pos.append(node_x)
    y_pos.append(node_y)
    iso_labels.append(country_codes_dict[country])
    G.nodes[country]['pos'] = xy
    hover_labels.append(country + '<br> degree centrality: ' + str(cent[country]))


In [83]:
# Arrowhead angle handed to addEdge.
# NOTE(review): presumably in degrees - confirm against addEdge.
arrowangle = 15

# One plotly trace per directed edge, so each arrow gets its own hover text.
edge_trace_list = []
for source, target in G.edges():
    start = G.nodes[source]["pos"]
    end = G.nodes[target]["pos"]

    # addEdge is given empty coordinate lists and returns them filled with the
    # vertices of the line plus its arrowhead at the "end" of the edge.
    edge_x_pos, edge_y_pos = addEdge(
        start, end, [], [], 1, "end", 0.04, arrowangle, 50
    )

    # Repeat the label once per vertex addEdge produced, so every point of the
    # drawn arrow shows the same hover text. Counting the vertices (instead of
    # hard-coding 9) keeps text aligned with x/y if addEdge's output changes.
    edge_text = source + ' to ' + target
    edge_text_list = [edge_text] * len(edge_x_pos)

    trace = go.Scatter(
        x=edge_x_pos,
        y=edge_y_pos,
        line=dict(width=2, color='white'),
        hoverinfo="text",
        text=edge_text_list,
        mode="lines",
        opacity=0.5,
    )
    edge_trace_list.append(trace)

# Node markers at the shell-layout positions; hover shows name + centrality.
node_trace = go.Scatter(
    x=x_pos,
    y=y_pos,
    mode="markers",
    hoverinfo="text",
    hovertext=hover_labels,
    marker=dict(
        # NOTE(review): with showscale=False and a fixed color, the colorscale
        # and colorbar below are not rendered; kept for a future
        # value-based coloring.
        showscale=False,
        colorscale='aggrnyl',
        reversescale=False,
        color='blue',
        size=50,
        colorbar=dict(
            thickness=35,
            # Nested title dict replaces the deprecated `titlefont` /
            # `titleside` colorbar properties.
            title=dict(
                text="Probability of Introduction",
                font=dict(color="white", size=14),
                side="right",
            ),
            tickfont=dict(color="white"),
            xanchor="left",
            bgcolor="#19191a",
        ),
        line_width=3,
        line_color='grey',
    ),
)
    
# Edge traces first, then the node markers, so nodes are drawn over the arrows.
# `title_font_size` replaces the deprecated layout `titlefont_size` alias.
fig = go.Figure(
    data=[*edge_trace_list, node_trace],
    layout=go.Layout(
        plot_bgcolor="#19191a",
        paper_bgcolor="#19191a",
        title_font_size=16,
        showlegend=False,
        hovermode="closest",
        margin=dict(b=0, l=0, r=0, t=0, pad=0),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)



# ISO3 text labels floating over each node; the "Origin" node gets no label.
annotations = []
for node in G.nodes():
    if node == "Origin":
        continue
    x, y = G.nodes[node]["pos"]
    node_text = country_codes_dict[node]
    label = dict(
        x=x,
        y=y,
        xref="x",
        yref="y",
        text=node_text,
        xanchor="right",
        xshift=15,
        font=dict(color='white', size=12),
        showarrow=False,
        arrowhead=1,
        ax=-10,
        ay=-10,
    )
    annotations.append(label)

# No explicit height: the figure uses plotly's default size.
fig.update_layout(showlegend=False, annotations=annotations)
fig
