# In [4]:  (Jupyter cell prompt left over from a notebook export)
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go

# Module-level Dash application object; the layout and callbacks defined
# below are attached to it.
app = dash.Dash(name=__name__)

# Define the Sankey diagram functions
def _two_round_sankey(title, values):
    """Build a two-column Cooperate/Defect Sankey figure.

    Nodes 0-1 ("Cooperate", "Defect") form the left column and nodes 2-3
    ("Cooperate", "Defect") the right column; each left node links to both
    right nodes.

    Args:
        title: Figure title.
        values: Exactly four link weights, in the order
            Cooperate->Cooperate, Cooperate->Defect,
            Defect->Cooperate, Defect->Defect.

    Returns:
        A ``go.Figure`` holding a single Sankey trace.

    Raises:
        ValueError: If ``values`` does not contain exactly four weights.
    """
    values = list(values)
    if len(values) != 4:
        raise ValueError(f"expected 4 link values, got {len(values)}")

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=["Cooperate", "Defect", "Cooperate", "Defect"],
            color=["green", "red", "green", "red"],
        ),
        link=dict(
            source=[0, 0, 1, 1],  # (Cooperate, Cooperate, Defect, Defect) on the left
            target=[2, 3, 2, 3],  # (Cooperate, Defect, Cooperate, Defect) on the right
            value=values,
        ),
    )])

    fig.update_layout(title_text=title, font_size=24)
    return fig


def create_sankey_human():
    """Return the two-column Sankey figure for the human data (titled 'Human')."""
    return _two_round_sankey('Human', [3821, 7686, 3635, 10584])


def create_sankey_gpt3():
    """Return the two-column Sankey figure for the GPT-3 data (titled 'GPT-3')."""
    return _two_round_sankey('GPT-3', [7, 15, 4, 4])


def create_sankey_gpt4():
    """Return the two-column Sankey figure for the GPT-4 data (titled 'GPT-4')."""
    return _two_round_sankey('GPT-4', [0, 26, 1, 3])

def _five_round_annotations():
    """Return the annotation dicts drawn beneath a five-column Sankey diagram.

    Round labels ("Round 1", "2", "3", "4", "5") are placed at
    x = 0, 0.25, 0.5, 0.755, 1 (paper coordinates). Smaller grey move labels
    sit between them.
    NOTE(review): the grey labels are presumably the opponent's move in each
    transition — confirm against the study before changing them.
    """
    # (x position in paper coordinates, text, rendered as grey move label?)
    marks = [
        (0, "Round 1", False),
        (0.14, "Defect", True),
        (0.25, "2", False),
        (0.38, "Defect", True),
        (0.5, "3", False),
        (0.62, "Cooperate", True),
        (0.755, "4", False),
        (0.93, "Cooperate", True),
        (1, "5", False),
    ]
    annotations = []
    for x, text, is_move_label in marks:
        note = dict(x=x, y=-0.2, xref='paper', yref='paper', showarrow=False,
                    text=text)
        if is_move_label:
            note['font'] = dict(size=12, color='grey')
        annotations.append(note)
    return annotations


def _five_round_sankey(title, values):
    """Build a five-column Cooperate/Defect Sankey figure.

    In column k (k = 0..4), node 2k is Cooperate (green) and node 2k+1 is
    Defect (red). Each node links to both nodes of column k+1, which gives
    16 links in total.

    Args:
        title: Figure title.
        values: Exactly 16 link weights, ordered to match the ``source`` /
            ``target`` lists below (per source node: ->Cooperate, ->Defect).

    Returns:
        A ``go.Figure`` with one Sankey trace and the round annotations.

    Raises:
        ValueError: If ``values`` does not contain exactly 16 weights.
    """
    values = list(values)
    if len(values) != 16:
        raise ValueError(f"expected 16 link values, got {len(values)}")

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            # Same node outline as every other diagram in this app (it was
            # previously missing from the GPT-4 five-round chart only).
            line=dict(color="black", width=0.5),
            # Only the outer columns carry text; the inner rounds are
            # identified by the annotations beneath the chart.
            label=["Cooperate", "Defect", "", "", "", "", "", "", "Cooperate", "Defect"],
            color=["green", "red"] * 5,
        ),
        link=dict(
            source=[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7],
            target=[2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, 8, 9],
            value=values,
        ),
    )])

    fig.update_layout(
        title_text=title,
        font_size=14,
        annotations=_five_round_annotations(),
    )
    return fig


def create_sankey_gpt4_five():
    """Return the five-round Sankey figure for the GPT-4 data (titled 'GPT-4')."""
    return _five_round_sankey(
        'GPT-4', [0, 29, 1, 0, 0, 1, 25, 4, 17, 8, 4, 1, 17, 4, 8, 1])


def create_sankey_gpt3_five():
    """Return the five-round Sankey figure for the GPT-3 data (titled 'GPT-3')."""
    return _five_round_sankey(
        'GPT-3', [6, 15, 1, 8, 5, 2, 9, 14, 10, 4, 5, 11, 12, 3, 2, 13])
# Combined App layout
# Page layout: two selectors side by side, a link to the article, then the
# two graphs fed by the callbacks ('sankey-graph-human', 'sankey-graph-gpt').
app.layout = html.Div([
    html.H3("Sankey Diagrams"),
    # Chooses the two-column diagram shown in 'sankey-graph-human'.
    # clearable=False: clearing the dropdown would send value=None to the
    # callback, which must then guess what to draw.
    dcc.Dropdown(
        id='graph-selector-human',
        options=[
            {'label': 'Human', 'value': 'Human'},
            {'label': 'GPT-3', 'value': 'GPT-3'},
            {'label': 'GPT-4', 'value': 'GPT-4'},
        ],
        value='Human',  # Default value
        clearable=False,
        style={'width': '48%', 'display': 'inline-block'},
    ),
    # Chooses the five-round diagram shown in 'sankey-graph-gpt'.
    dcc.Dropdown(
        id='graph-selector-gpt',
        options=[
            {'label': 'GPT-4', 'value': 'GPT-4'},
            {'label': 'GPT-3', 'value': 'GPT-3'},
        ],
        value='GPT-4',  # Default value
        clearable=False,
        style={'width': '48%', 'float': 'right', 'display': 'inline-block'},
    ),
    html.Br(),
    html.A("Read the related article", href="https://www.pnas.org/doi/10.1073/pnas.2313925121"),
    dcc.Graph(id='sankey-graph-human'),
    dcc.Graph(id='sankey-graph-gpt'),
])


# Callback for Human/GPT-3/GPT-4 graph
@app.callback(
    Output('sankey-graph-human', 'figure'),
    [Input('graph-selector-human', 'value')]
)
def update_graph_human(selected_value):
    """Return the two-column Sankey figure chosen in 'graph-selector-human'.

    Args:
        selected_value: One of 'Human', 'GPT-3', 'GPT-4'. ``None`` (cleared
            dropdown) or any other value yields an empty figure. Previously
            such values silently fell through to the GPT-4 diagram.

    Returns:
        A plotly figure for the 'sankey-graph-human' graph.
    """
    builders = {
        'Human': create_sankey_human,
        'GPT-3': create_sankey_gpt3,
        'GPT-4': create_sankey_gpt4,
    }
    build = builders.get(selected_value)
    if build is None:
        return go.Figure()
    return build()
        

# Callback for GPT-3/GPT-4 graph
@app.callback(
    Output('sankey-graph-gpt', 'figure'),
    [Input('graph-selector-gpt', 'value')]
)
def update_graph_gpt(selected_value):
    """Return the five-round Sankey figure chosen in 'graph-selector-gpt'.

    Args:
        selected_value: One of 'GPT-4', 'GPT-3'. ``None`` (cleared dropdown)
            or any other value yields an empty figure. Previously such values
            silently fell through to the GPT-3 diagram.

    Returns:
        A plotly figure for the 'sankey-graph-gpt' graph.
    """
    builders = {
        'GPT-4': create_sankey_gpt4_five,
        'GPT-3': create_sankey_gpt3_five,
    }
    build = builders.get(selected_value)
    if build is None:
        return go.Figure()
    return build()

# Run the app
if __name__ == '__main__':
    # Dash.run_server() is deprecated since Dash 2.x and removed in Dash 3;
    # Dash.run() is the supported entry point with the same keyword arguments.
    app.run(debug=True)
