In [234]:
import data_loader as dl
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np

# Load the raw dataset through the project's data_loader module. The cells below
# select the veh_* columns and comb_weight from this frame.
geo_edf = dl.get_raw_data()

In [235]:
# Categorical columns describing the current and the next vehicle. A row that is
# missing any of them cannot be drawn as a link, so it is dropped below.
vehicle_cols = ['veh_type_1', 'veh_pt_1', 'next_veh_type_1', 'next_veh_pt_1']

df_for_sankey = geo_edf[vehicle_cols + ['comb_weight']].copy()

# Treat both NaN and the -1 sentinel as "missing", but only in the categorical
# columns. comb_weight is deliberately left untouched: filling it with -1 would
# inject negative weights into the later aggregation.
vehicle_codes = df_for_sankey[vehicle_cols].fillna(-1)
has_all_codes = (vehicle_codes != -1).all(axis=1)

# Keep only rows where every vehicle code is present, then store codes as ints.
df_for_sankey = df_for_sankey.loc[has_all_codes].copy()
df_for_sankey[vehicle_cols] = vehicle_codes.loc[has_all_codes].astype(int)


In [236]:
# Tag every code with the side of the flow it belongs to: "C_" for the current
# vehicle columns and "N_" for the next vehicle columns.
column_prefixes = {
    'veh_type_1': "C_",
    'veh_pt_1': "C_",
    'next_veh_type_1': "N_",
    'next_veh_pt_1': "N_",
}
for column_name, prefix in column_prefixes.items():
    df_for_sankey[column_name] = prefix + df_for_sankey[column_name].astype(str)



In [257]:
# Node identifiers are "<veh_pt code>_<veh_type code>", e.g. "C_1_C_2" for a
# source node and "N_3_N_1" for a target node.
df_for_sankey['source'] = df_for_sankey['veh_pt_1'] + "_" + df_for_sankey['veh_type_1']
df_for_sankey['target'] = df_for_sankey['next_veh_pt_1'] + "_" + df_for_sankey['next_veh_type_1']

# Total weight per (source, target) link. The weight column is selected
# explicitly: a string passed positionally to GroupBy.sum() is bound to
# `numeric_only`, it does not choose which column to aggregate.
df_x = (
    df_for_sankey
    .groupby(['source', 'target'])['comb_weight']
    .sum()
    .reset_index()
)

# Optional: uncomment to drop links whose total weight is 20 or less.
# df_x = df_x[df_x['comb_weight'] > 20].copy()


In [258]:

def prepare_df_sankey(df, source, target):
    '''
        Add integer node indices for a Sankey diagram to `df`.

        The node list is built from the unique values of the `source` column
        followed by the unique values of the `target` column, each in order of
        first appearance. Sort `df` by source and then by target beforehand if
        a stable node order is wanted.

        Parameters
        ----------
        df : pandas.DataFrame
            Link table. It is modified in place: the columns `source_index`
            and `target_index` are added (or overwritten).
        source : str
            Name of the column holding the source node of each link.
        target : str
            Name of the column holding the target node of each link.

        Returns
        -------
        (df, label_list)
            The same DataFrame object that was passed in, and the list of node
            labels whose positions the index columns refer to.
    '''
    label_list = df[source].unique().tolist() + df[target].unique().tolist()

    # Position of the first occurrence of each label, matching what
    # list.index() returns, built once instead of scanning the list per row.
    label_position = {}
    for position, label in enumerate(label_list):
        label_position.setdefault(label, position)

    df['source_index'] = df[source].map(label_position)
    df['target_index'] = df[target].map(label_position)

    return df, label_list

# Index the aggregated links. `df` is the frame handed back by the helper, with
# source_index / target_index columns added.
df, label_list = prepare_df_sankey(df=df_x, source='source', target='target')


In [259]:
# Code -> display-name lookups for the current vehicle and the next vehicle.
# NOTE(review): the current and next codings differ (e.g. type code 2 is
# "Large Car" for the current vehicle but "Other" for the next one) - confirm
# both against the survey codebook.
curr_veh_pt_dict = {1: "ICE", 2: "BEV", 3: "HEV", 4: "PHEV"}
# "Large Car" is written without a trailing space so node labels are clean and
# consistent with next_veh_type_dict.
curr_veh_type_dict = {1: "Small Car", 2: "Large Car", 3: "Pickup Truck", 4: "Other"}

next_veh_pt_dict = {1: "Other", 2: "BEV", 3: "HEV", 4: "PHEV", 5: "ICE"}
next_veh_type_dict = {1: "Small Car", 2: "Other", 3: "Pickup Truck", 4: "Large Car"}

# Map every "C_<pt>_C_<type>" source identifier to "<pt name>__<type name>".
curr_dict = {
    f"C_{pt_code}_C_{type_code}": f"{pt_name}__{type_name}"
    for pt_code, pt_name in curr_veh_pt_dict.items()
    for type_code, type_name in curr_veh_type_dict.items()
}

# Same for the "N_<pt>_N_<type>" target identifiers.
next_dict = {
    f"N_{pt_code}_N_{type_code}": f"{pt_name}__{type_name}"
    for pt_code, pt_name in next_veh_pt_dict.items()
    for type_code, type_name in next_veh_type_dict.items()
}

# Identifiers without an entry in the lookup become NaN labels.
df['source_label'] = df['source'].map(curr_dict)
df['target_label'] = df['target'].map(next_dict)

df

Unnamed: 0,source,target,comb_weight,source_index,target_index,source_label,target_label
0,C_1_C_1,N_1_N_1,207.266973,0,4,ICE__Small Car,Other__Small Car
1,C_1_C_1,N_1_N_2,87.483859,0,5,ICE__Small Car,Other__Other
2,C_1_C_1,N_1_N_3,22.267873,0,6,ICE__Small Car,Other__Pickup Truck
4,C_1_C_1,N_2_N_1,48.615027,0,7,ICE__Small Car,BEV__Small Car
8,C_1_C_1,N_3_N_1,51.074275,0,8,ICE__Small Car,HEV__Small Car
9,C_1_C_1,N_3_N_2,24.679626,0,9,ICE__Small Car,HEV__Other
12,C_1_C_1,N_4_N_1,47.850957,0,10,ICE__Small Car,PHEV__Small Car
13,C_1_C_1,N_4_N_2,20.785618,0,11,ICE__Small Car,PHEV__Other
20,C_1_C_2,N_1_N_1,47.745733,1,4,ICE__Large Car,Other__Small Car
21,C_1_C_2,N_1_N_2,167.944068,1,5,ICE__Large Car,Other__Other


In [260]:

# Link weights, one per (source, target) row of `df`.
values = df['comb_weight'].to_list()

# Node labels: source labels first, then target labels, each in order of first
# appearance.
# NOTE(review): this assumes every source/target identifier maps to a distinct,
# non-missing label so the order lines up with source_index / target_index -
# TODO confirm.
source_labels = df['source_label'].unique().tolist()
target_labels = df['target_label'].unique().tolist()
label_list = source_labels + target_labels

# One x position per node, sized from the data instead of a fixed-length list:
# source nodes in the left column (0), target nodes in the right column (1).
node_x = [0] * len(source_labels) + [1] * len(target_labels)

fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=label_list,
        # color="blue",
        x=node_x,
    ),
    link=dict(
        source=df['source_index'].to_list(),
        target=df['target_index'].to_list(),
        value=values,
    )
)])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()


In [261]:
# Plain-list copies of the link table columns, kept for inspection in the
# cells that follow.
source, destinations, values = (
    df[column].to_list() for column in ('source', 'target', 'comb_weight')
)

# Node labels: unique source labels followed by unique target labels.
label_list = [
    *df['source_label'].unique().tolist(),
    *df['target_label'].unique().tolist(),
]


In [255]:
# Inspect the node labels (unique source labels followed by unique target labels).
label_list

['ICE__Small Car',
 'ICE__Large Car ',
 'ICE__Pickup Truck',
 'HEV__Small Car',
 'Other__Small Car',
 'Other__Other',
 'Other__Pickup Truck',
 'BEV__Small Car',
 'HEV__Small Car',
 'HEV__Other',
 'PHEV__Small Car',
 'PHEV__Other',
 'Other__Large Car',
 'HEV__Large Car']

In [256]:
# Inspect the source node index assigned to each link.
df['source_index'].tolist()

[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3]

In [246]:
# Inspect the raw source node identifiers, one entry per link.
source

['C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_1',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_2',
 'C_1_C_3',
 'C_1_C_3',
 'C_3_C_1']

In [263]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Placeholder Sankey data: four links feeding two target nodes.
# Replace these lists with the actual Sankey diagram data.
source_nodes = [0, 1, 2, 3]
target_nodes = [4, 4, 5, 5]
values = [10, 15, 5, 20]
labels = ['Node 0', 'Node 1', 'Node 2', 'Node 3', 'Node 4', 'Node 5']

# Create the Sankey diagram trace
sankey_trace = go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=labels
    ),
    link=dict(
        source=source_nodes,
        target=target_nodes,
        value=values
    )
)

# Create the figure and add the Sankey trace
fig = go.Figure(data=[sankey_trace])

# One slider step per link: step i shows only link i and zeroes the others.
steps = []
for i, link_value in enumerate(values):
    step_values = [0] * len(values)
    step_values[i] = link_value
    steps.append(dict(
        method="restyle",
        # restyle spreads a top-level list across traces, so the per-link
        # array is wrapped in an outer list to reach the single Sankey trace
        # as one array.
        args=["link.value", [step_values]],
        label=f"Step {i + 1}"
    ))

# Add the slider to the layout
sliders = [dict(
    active=0,
    currentvalue={"prefix": "Step: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(sliders=sliders)

# Show the plot
fig.show()


In [265]:
import pandas as pd
import plotly.graph_objects as go
import plotly.offline as pyo
import plotly.tools as tls
from plotly.subplots import make_subplots
import numpy as np

# Create sample data: nine rows, three per category, with seeded random values.
# NOTE: this rebinds `df`, which earlier cells use for the aggregated Sankey
# links; re-run those cells before reusing that frame.
np.random.seed(0)
df = pd.DataFrame({'Category': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Value': np.random.randint(0, 100, 9)})

# Create initial Sankey diagram with one link per row of `df`.
# NOTE(review): only three labels are supplied while the links reference node
# indices 0-5 - confirm whether the target nodes should be labelled as well.
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=df['Category'].unique()
    ),
    link=dict(
        source=[0, 1, 2, 0, 1, 2, 0, 1, 2],
        target=[3, 3, 3, 4, 4, 4, 5, 5, 5],
        value=df['Value']
    )
)])

# Create the slider: one step per category.
slider_steps = []
for category in df['Category'].unique():
    mask = df['Category'] == category
    # Keep one value per link so each value stays on its own source/target
    # pair; links of the other categories are zeroed rather than dropped.
    filtered_values = df['Value'].where(mask, 0)
    slider_step = dict(
        method='restyle',
        label=category,
        # The outer list targets the single Sankey trace with the whole array.
        args=[{'link.value': [filtered_values.tolist()]}],
    )
    slider_steps.append(slider_step)

sliders = [dict(
    active=0,
    # pad={"t": 50},
    steps=slider_steps
)]

# Add slider to the layout
fig.update_layout(sliders=sliders)

# Show the plot
fig.show()
