In [75]:
import plotly.graph_objects as go
import json

# Load the Sankey description (expects top-level 'links' and 'nodes' keys,
# which the rest of this notebook reads).
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default encoding (which is not UTF-8 on every system).
with open('sankey_data.json', 'r', encoding='utf-8') as f:
    json_data = json.load(f)

# The links dictionary stores each link as three parallel keys that share a
# suffix: 'source<suffix>', 'target<suffix>' and 'value<suffix>'.
links_data = json_data['links']

# Collect the suffix of every 'source...' key, in dictionary order.
# len('source') == 6, hence the slice.
link_suffixes = [key[6:] for key in links_data if key.startswith('source')]

# Each entry is a list; only its first element is used. For the value entry
# that first element is a mapping from which the 'count' field is read.
sources = [links_data[f"source{suffix}"][0] for suffix in link_suffixes]
targets = [links_data[f"target{suffix}"][0] for suffix in link_suffixes]
values = [links_data[f"value{suffix}"][0]['count'] for suffix in link_suffixes]

nodes = json_data['nodes']

# Build the node labels from element 0 of each node's 'name' entry
# (presumably 'name' is a list holding the label -- confirm against the JSON).
first_column = []
for node in nodes:
    first_column.append(node['name'][0])


In [76]:
# Shift every link endpoint down by one (presumably the JSON uses 1-based node
# ids while Plotly expects 0-based positions -- confirm against the data).
# NOTE(review): this rebinds `sources`/`targets` from themselves, so running
# the cell twice shifts the indices twice; re-run the extraction cell first.
sources, targets = (
    [node_id - 1 for node_id in node_ids] for node_ids in (sources, targets)
)

In [119]:
# Colour palette, indexed by node position: entries 0-8 are generic colours,
# entries 9-16 are the ward colours (the same positions that receive the ward
# labels in the plotting cell).
hex_values = [
    "#1F77B4",  # 0: Blue
    "#FF7F0E",  # 1: Orange
    "#2CA02C",  # 2: Green
    "#D62728",  # 3: Red
    "#9467BD",  # 4: Purple
    "#8C564B",  # 5: Brown
    "#E377C2",  # 6: Pink
    "#7F7F7F",  # 7: Gray
    "#BEBD34",  # 8: Olive
    "#8580ba",  # 9: Ward 1
    "#ef9556",  # 10: Ward 2
    "#bdc272",  # 11: Ward 3
    "#ffa9b6",  # 12: Ward 5
    "#2ee5a3",  # 13: Ward 6
    "#f1fb79",  # 14: Ward 7
    "#adaeae",  # 15: Ward 8
    "#88c5cd",  # 16: Ward 4
]

# Every link takes the colour of the node it starts from.
link_colors = list(map(lambda node_index: hex_values[node_index], sources))

In [134]:
# Replace the raw labels of nodes 9-16 with descriptive ward labels.
# The Ward 3 entry previously repeated the Ward 1 neighbourhoods (copy-paste
# slip); it now lists Ward 3 neighbourhoods.
ward_labels = {
    9: "Ward 1: <br>(e.g./ Adams Morgan, Columbia Heights, Mount Pleasant)",
    10: "Ward 2: <br>(e.g./ Downtown, Georgetown, Logan Circle)",
    11: "Ward 3: <br>(e.g./ Cleveland Park, Tenleytown, Woodley Park)",
    12: "Ward 5: <br>(e.g./ Brookland, Eckington, Fort Totten)",
    13: "Ward 6: <br>(e.g./ Capitol Hill, Penn Quarter, NoMa)",
    14: "Ward 7: <br>(e.g./ Benning, Hillcrest, Lincoln Heights)",
    15: "Ward 8: <br>(e.g./ Anacostia, Congress Heights, Woodland)",
    16: "Ward 4: <br>(e.g./ Brightwood, Manor Park, Takoma)",
}
for node_index, ward_label in ward_labels.items():
    first_column[node_index] = ward_label


# Build the Sankey diagram: nodes are labelled from `first_column` and coloured
# from the palette; each link carries the colour chosen in `link_colors`.
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=first_column,
        color=hex_values[:len(first_column)],  # One palette entry per node
        # Include %{value}; without it the hover shows the text but no number.
        hovertemplate='%{label} Total Cases: %{value}'
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors,  # Set the colors of the links
        # Show the link's value followed by "total cases".
        hovertemplate='%{source.label} → %{target.label}: %{value} total cases'
    ))])

# Place the ward map to the right of the plot area (x=1 in paper coordinates,
# anchored on its left edge, vertically centred).
# NOTE(review): a relative file path only renders if the notebook front end can
# resolve it as a URL -- if the image is missing, embed it as a data URI or
# pass a PIL image instead.
fig.add_layout_image(
    dict(
        source="./DC_Ward_Map_2020s.svg.png",  # Path to your image
        xref="paper",
        yref="paper",
        x=1,  # Adjust the position as needed
        y=0.5,
        sizex=1,  # Adjust the size of the image
        sizey=1,
        xanchor="left",
        yanchor="middle",
        opacity=0.8,
        layer="above"
    )
)

# Single layout update (title and font were previously set twice with the same
# values); the wide right margin leaves room for the ward map.
fig.update_layout(
    title_text="Customized Sankey Diagram",
    font_size=10,
    margin=dict(
        l=100,  # Left margin
        r=300,  # Right margin
        t=50,  # Top margin
        b=50   # Bottom margin
    )
)

fig.show()