In [3]:
import plotly.express as px
import pandas as pd
import numpy as np
import plotly.graph_objects as go

series_dict = pd.read_pickle('data/episodes_dict.pkl')
const_title_dict = pd.read_pickle('data/const_title_dict.pkl')

In [4]:
from dash import Dash, html, dcc, Output, Input
from plotly.subplots import make_subplots
import dash_bootstrap_components as dbc 
app = Dash(__name__,external_stylesheets=[dbc.themes.BOOTSTRAP])

dropdown_options = [{'label': name, 'value': tconst} for tconst, name in const_title_dict.items()]

# Define the layout of the app
app.layout = html.Div([
   html.Div( # header div
        style={'display': 'flex', 'padding': '20px', 'box-sizing': 'border-box'},
        children=[
            # Left Div 
            html.Div(
                style={'width': '60%', 'display': 'flex', 'alignItems': 'center', 'justifyContent': 'space-between'},
                children=[
                    # Sub Div for Title
                    html.Div(
                        children=[html.H1(id='tv-show-title', children="Select a TV show", style={'textAlign': 'left'})]
                    ), 
                    # Sub Div for Dropdown
                    html.Div(
                        children=[
                            dcc.Dropdown(
                                id='tv-show',
                                options=dropdown_options,
                                value=list(const_title_dict.keys())[0]
                            )
                        ],
                        style={'width': '50%'}  
                    )
                ]
            ),
        # Right Div 
            html.Div(
                style={'width': '40%'},
            )
        ]
    ),
    # Main content Div
    html.Div(
        id="content-div",
        children=[
            # Plot Div
            html.Div(
                id="plot-div",
                children=[
                    html.Div(
                        id="wh",
                    style={
                            'display': 'flex',
                        'flexDirection': 'row',  # Align children in a row
                        'justifyContent': 'flex-start',
                        'alignItems': 'center',# Distribute extra space evenly
                        'width': '100%',
                        # Take up the full width of the parent
                        },
                    children=[
                        html.Div(
            [
                html.Span("Heatmap", style={'marginRight': '10px','marginLeft': '10px'}),
                dbc.Checklist(
                    options=[{"label": "", "value": "toggle"}],
                    value=["toggle"],  # The switch starts in the on position
                    id="plot-toggle",
                    switch=True,
                    style={'marginTop':'5px'}
                ),
                html.Span("Line Plot", style={'marginLeft': '3px','marginRight': '15px'}),
            ],
            style={'flex': '0 0 auto', 'display': 'flex', 'alignItems': 'center'}
        ),
                html.Div(
                    id='season-break-container',
                    children=[
                        dcc.Dropdown(
                id='yaxis-column',
                options=[
                    {'label': 'Number of Votes', 'value': 'numVotes'},
                    {'label': 'Average Rating', 'value': 'averageRating'}
                ],
                value='numVotes',
                style={'width':'300px', 'alignItems': 'center','justifyContent': 'flex-start'},
            ),
            
                dbc.Checklist(
                    id='break-by-season',
                    options=[
                        {'label': 'Break by Seasons', 'value': 'by_season'}
                    ],
                    value=[],
                    style={'marginRight':'15px','marginLeft':'15px','marginTop':'5px'}),
                    ],
                    style={'flex': '1', 'display': 'none', 'justifyContent': 'flex-end', 'alignItems': 'center'})  # Allow the item to grow
                
                    ]
                ),
                    dcc.Graph(id='combined-graph')],
                style={'display': 'block', 'width': '60%', 'verticalAlign': 'top', 'padding-left': '100px'}
            ),
            # Controls Div
            html.Div(
                id="side-panel",
                children=[
                    # Upper Div
                    html.Div(
                        id="upper-side-div",
                        style={ 'height': '50%', 'textAlign': 'center','boxSizing':'border-box'}
                    ),
                    # Lower Div
                    html.Div(
                        id="lower-side-div",
                        style={ 'height': '50%', 'textAlign': 'center','boxSizing':'border-box'}
                    )
                ],
                style={'display': 'inline-block', 'width': '40%', 'verticalAlign': 'top'}
            ),
        ],
        style={'display':'flex','flex-direction': 'row', 'box-sizing': 'border-box','height': 'calc(100vh - 110px)'}
    ),
], style={'height': '100%', 'margin':'0', 'padding':'0','box-sizing': 'border-box'})



# Define callback to update graph
@app.callback(
    [Output('tv-show-title', 'children'),
     Output('combined-graph', 'figure'),
     Output('season-break-container', 'style')],
    [Input('tv-show', 'value'),
     Input('plot-toggle', 'value'),
     Input('yaxis-column', 'value'),
     Input('break-by-season', 'value')]
)
def update_graph(tv_show, toggle_value,yaxis_column_name, break_by_season):
    #take selected show dataframe
    show_df = series_dict[tv_show]
    if toggle_value: #plots for line plot
        checklist_style = {'display': 'flex'}
        labels={"averageRating": "Average Rating", "episode": "Episode", "numVotes": "Number of Votes"}
        fig = make_subplots(rows=2, cols=1, shared_xaxes=False, vertical_spacing=0.15, subplot_titles=(f'Episodes by {labels[yaxis_column_name]}', 'Top 5 and Bottom 5 Episodes'))
        if "by_season" in break_by_season:
            grouped = show_df.groupby('seasonNumber')
            for season, group in grouped:
                fig.add_trace(
                    go.Scatter(
                        x=group['normalizedEpisodeNumber'],
                        y=group[yaxis_column_name],
                        mode='lines+markers',
                        name=f'Season {season}',
                        hoverinfo='text',
                        hovertext=(
                            'Season: ' + group['seasonNumber'].astype(str) +
                            ', Episode: ' + group['episodeNumber'].astype(str)+
                            f'<br>{labels[yaxis_column_name]}: ' + group[yaxis_column_name].astype(str)
                            
                            ),
                    ),
                    row=1, col=1  
                    )
            fig.update_layout(
                xaxis=dict(
                    showgrid=False,  
                    zeroline=False  
                ))
            fig.update_xaxes(
                tickvals=[1, show_df['normalizedEpisodeNumber'].max()],
                ticktext=["Start Season", "End Season"],
                row=1,col=1
            )
        else:
            fig.add_trace(px.line(show_df,
                x = "episode",
                y=yaxis_column_name,
                custom_data = ["primaryTitle"],
                labels=labels,
                title=f'Episodes by {labels.get(yaxis_column_name, yaxis_column_name)}',
            ).update_traces(hovertemplate=None,line_color='navy').data[0],
            row=1, col=1
            )
            fig.update_traces(
                hovertemplate="<br>".join([f"<b>%{{x}} : %{{customdata[0]}}</b>", f"{labels.get(yaxis_column_name, yaxis_column_name)}:%{{y}}"]),
                textposition='top center',
            )
            fig.update_layout(
                xaxis=dict(
                    showgrid=False
                ),
            )
        # dataframe for bar chart    
        combined = pd.concat([show_df.nlargest(5, yaxis_column_name), show_df.nsmallest(5, yaxis_column_name).iloc[::-1]])
        combined['rank'] = list(range(1, 6)) + list(range(len(show_df) - 4, len(show_df) + 1))
        colorscale = px.colors.sequential.Viridis
        
        # Normalize y-axis values to a 0-1 scale to match the colorscale
        y_values = combined[yaxis_column_name]
        normalized_values = (y_values - y_values.min()) / (y_values.max() - y_values.min())
        colors = [colorscale[int(value * (len(colorscale) - 1))] for value in normalized_values]
        
        #add scatter plot points to the first graph
        fig.add_trace(
            go.Scatter(
            x=combined['episode'],  
            y=combined[yaxis_column_name],
            mode='markers',
            marker=dict(
            color=colors,
            size=10,
            showscale=False),
            hoverinfo='text',
            hovertext=(
                'Season: ' + combined['seasonNumber'].astype(str) +
                ', Episode: ' + combined['episodeNumber'].astype(str) +
                '<br>Average Rating: ' + combined['averageRating'].astype(str) +
                '<br>Num Votes: ' + combined['numVotes'].astype(str) 
            )
                
                ),
                row=1, col=1  
                    )
        
        fig.add_trace(
            go.Bar(
                x=combined['rank'],
                y=combined[yaxis_column_name],
                text=combined.apply(lambda row: f'S{row.seasonNumber} E{row.episodeNumber}', axis=1),
                name='Top 5 and Bottom 5 Episodes',
                marker_color=combined[yaxis_column_name],
                hoverinfo='text',
                hovertext=combined['primaryTitle'],
                marker=dict(colorscale='Viridis')
        ),
            row=2, col=1
                )
        fig.update_xaxes(title_text='Rank', row=2, col=1, type='category')

    # Update y-axis title based on selected metric
        fig.update_yaxes(title_text=yaxis_column_name, row=1, col=1)
        fig.update_yaxes(title_text=yaxis_column_name, row=2, col=1)
        fig.update_layout(height=700, showlegend=False,margin=dict(l=80, r=40, t=60, b=40))

    #heatmap
    else:
        checklist_style = {'display': 'none'}
        size_ref = 0 
        max_ep = max(show_df["episodeNumber"])
        max_s = max(show_df["seasonNumber"])
        #define size_ref for markers based on no. of episodes and seasons
        if max_s < 9 and max_ep < 15:
            size_ref = 40
        else:
            size_ref = 30
            
        show_df['marker_size'] = np.sqrt(show_df['numVotes'] / max(show_df['numVotes'])) * size_ref
        num_episodes = show_df['episodeNumber'].nunique()
        num_seasons = show_df['seasonNumber'].nunique()
        max_marker_size = np.sqrt(show_df['numVotes'].max()) * size_ref

        # Calculate the total height and widht of the figure based on max no. of seasons and episodes
        if max_s == 3:
            total_figure_height = num_seasons * 110
        elif max_s > 3 and max_s < 7:
            total_figure_height = num_seasons * 97
        else:
            total_figure_height =  num_seasons * 70 if size_ref == 30 else num_seasons * 85
        if max_ep < 20:
            total_figure_width =  num_episodes * 70 if size_ref == 30 else num_episodes * 90
        else:
            total_figure_width =  num_episodes * 68

        fig = go.Figure(layout=go.Layout(height=total_figure_height, width = total_figure_width))

        # Add the scatter plot for each data point
        fig.add_trace(go.Scatter(
            x=show_df['episodeNumber'],
            y=show_df['seasonNumber'],
            mode='markers',
            marker=dict(
                symbol='square',
                size=show_df['marker_size'],
                color=show_df['averageRating'],
                colorscale='rdylgn',
                colorbar=dict(title='Average <br> Rating'),
                showscale=True,
                cmin=3,
                cmax=10,
                opacity=0.7  # Adjust opacity for better visibility
            ),
            hoverinfo='text',
            hovertext=(
                'Season: ' + show_df['seasonNumber'].astype(str) +
                ', Episode: ' + show_df['episodeNumber'].astype(str) +
                '<br>Average Rating: ' + show_df['averageRating'].astype(str) +
                '<br>Num Votes: ' + show_df['numVotes'].astype(str) 
            )
        ))
            
        size_samples = [min(show_df['marker_size']), max(show_df['marker_size'])]
        
            
        # Update axes
        fig.update_yaxes(
            title='Season Number',
            tickmode='array',
            tickvals=sorted(show_df['seasonNumber'].unique()),
            type='category',
            autorange="reversed",
            showgrid = False
        )

        fig.update_xaxes(
            title='Episode Number',
            tickmode='array',
            tickvals=sorted(show_df['episodeNumber'].unique()),
            type='category',
            showgrid = False
        )
        fig.update_layout(title='Episodes by Rating and Number of Votes',title_x=0.5)
  
    return const_title_dict.get(tv_show), fig, checklist_style

# Run the app
if __name__ == '__main__':
    app.run_server(server=8008,jupyter_mode="external")

Dash app running on http://127.0.0.1:8050/
