In [None]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State

In [None]:
print('Loading data...')
input_dir = '../data/'

# Read the three input tables found under `input_dir`.
sell_price = pd.read_csv('%s/sell_prices.csv' % input_dir)
calendar = pd.read_csv('%s/calendar.csv' % input_dir)
# For the training table, strip the '_evaluation' marker from every id
# (plain substring replacement, not a regular expression).
train = pd.read_csv('%s/sales_train_evaluation.csv' % input_dir).assign(
    id=lambda frame: frame['id'].str.replace('_evaluation', '', regex=False)
)

# Identifier columns; every other column of `train` is treated as a
# time-series column.
cat_cols = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
ts_cols = [name for name in train.columns if name not in cat_cols]
# Map each time-series column name to the integer that follows its first
# two characters.
ts_dict = dict(zip(ts_cols, (int(name[2:]) for name in ts_cols)))

# Report how many distinct values each identifier column holds.
for col in cat_cols:
    print('   N_unique %s: %i' % (col, train[col].nunique()))

In [None]:
def plot_time_series(data, ax=None, show=False):
    """Draw one line per row of ``data`` against the day number.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per series. Every column name must end in ``_<integer>``
        (for example ``d_17``); that trailing integer is used as the x
        coordinate of the column. The row label becomes the legend entry.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure and axes are created when omitted.
    show : bool, default False
        When true, call ``plt.show(block=False)`` after drawing.

    Returns
    -------
    None
    """
    if ax is None:
        _, ax = plt.subplots()

    # The x coordinates are identical for every row, so parse the column
    # names once instead of once per plotted line.
    day_numbers = [int(col.split('_')[-1]) for col in data.columns]
    for ind in data.index:
        ax.plot(day_numbers, data.loc[ind].values, '-', label=ind)

    # Only build a legend when at least one line was drawn; asking for a
    # legend on axes without labelled artists makes matplotlib warn.
    if len(data.index) > 0:
        ax.legend(loc='best')
    ax.set_xlabel('day number')
    ax.set_ylabel('items sold')
    if show:
        plt.show(block=False)

# Sales summed over every row of `train`, kept as a one-row frame whose
# single row is labelled 'all'.
all_sales = train[ts_cols].sum(axis=0).to_frame('all').transpose()

# Sales summed at each single grouping level.
state_sales = train.groupby('state_id')[ts_cols].sum()
store_sales = train.groupby('store_id')[ts_cols].sum()
cat_sales = train.groupby('cat_id')[ts_cols].sum()
dept_sales = train.groupby('dept_id')[ts_cols].sum()
item_sales = train.groupby('item_id')[ts_cols].sum()
# Un-aggregated series: the time-series columns indexed by `id`.
id_sales = train.set_index('id')[ts_cols]

# Sales summed over pairs of grouping levels.
state_cat_sales = train.groupby(['state_id', 'cat_id'])[ts_cols].sum()
# Grouped by dept_id (not cat_id); grouping by cat_id here would merely
# duplicate state_cat_sales.
state_dept_sales = train.groupby(['state_id', 'dept_id'])[ts_cols].sum()
store_cat_sales = train.groupby(['store_id', 'cat_id'])[ts_cols].sum()
store_dept_sales = train.groupby(['store_id', 'dept_id'])[ts_cols].sum()

plot_time_series(all_sales)

In [7]:
def run_dash_app(dataframe):
    """Build and launch a Dash app for browsing grouped sales time series.

    The page offers a dropdown of identifier columns to group by, a second
    dropdown (with a "Select All" checkbox) listing the resulting groups,
    and a line chart with one trace per chosen group.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain the time-series columns, recognised by a ``d_`` name
        prefix and a trailing ``_<integer>`` day number, plus at least one
        of the identifier columns ``id``, ``item_id``, ``dept_id``,
        ``cat_id``, ``store_id``, ``state_id``.

    Returns
    -------
    None
        The function starts the Dash server with ``debug=True``.
    """
    # Offer only the identifier columns that the frame actually has, so a
    # selection can never name a column that groupby would reject.
    id_columns = [
        col
        for col in ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
        if col in dataframe.columns
    ]
    # The time-series columns and their day numbers do not depend on any
    # user input, so compute them once rather than on every callback.
    ts_columns = [col for col in dataframe.columns if col.startswith('d_')]
    day_numbers = [int(col.split('_')[-1]) for col in ts_columns]

    def as_column_list(selection):
        """Normalise a dropdown value to a list of column names."""
        return [selection] if isinstance(selection, str) else list(selection)

    def group_labels(frame, columns):
        """Join the values of `columns` with '_' to get one label per row."""
        return frame[columns].astype(str).agg('_'.join, axis=1)

    app = Dash(__name__)

    app.layout = html.Div([
        dcc.Dropdown(
            id='column-selector',
            options=[{'label': col, 'value': col} for col in id_columns],
            multi=True,
            placeholder='Select a column'
        ),
        html.Div([
            dcc.Dropdown(
                id='value-selector',
                multi=True,
                placeholder='Select grouped values'
            ),
            dcc.Checklist(
                id='select-all',
                options=[{'label': 'Select All', 'value': 'all'}],
                inline=True
            ),
        ]),

        dcc.Graph(id='line-chart', config={'displayModeBar': True}),
    ])

    @app.callback(
        [Output('value-selector', 'options'),
         Output('value-selector', 'value')],
        [Input('column-selector', 'value'),
         Input('select-all', 'value')]
    )
    def update_value_dropdown(selected_column, select_all):
        """List the groups for the chosen columns; tick them all on request."""
        if not selected_column:
            return [], []

        columns = as_column_list(selected_column)
        groups = dataframe.groupby(columns).size().reset_index()
        labels = list(group_labels(groups, columns).unique())
        options = [{'label': label, 'value': label} for label in labels]

        if select_all and 'all' in select_all:
            return options, labels

        return options, []

    @app.callback(
        Output('line-chart', 'figure'),
        [Input('column-selector', 'value'),
         Input('value-selector', 'value')]
    )
    def update_chart(selected_column, selected_values):
        """Plot the summed series of every selected group."""
        if not selected_column or not selected_values:
            return go.Figure()

        columns = as_column_list(selected_column)
        # Sum only the time-series columns. A bare `.sum()` would also
        # aggregate the remaining identifier columns, and with pandas >= 2.0
        # (numeric_only defaults to False) that means concatenating their
        # strings for every group.
        totals = dataframe.groupby(columns)[ts_columns].sum().reset_index()
        labels = group_labels(totals, columns)
        chosen = labels.isin(selected_values)

        fig = go.Figure()
        series = totals.loc[chosen, ts_columns].to_numpy()
        for label, sales in zip(labels[chosen], series):
            fig.add_trace(
                go.Scatter(
                    x=day_numbers,
                    y=sales,
                    mode='lines',
                    name=label
                )
            )

        fig.update_layout(xaxis_title="Days", yaxis_title="Items Sold")
        return fig

    # Newer Dash releases expose `run` and have dropped `run_server`;
    # fall back to `run_server` only for older releases without `run`.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(debug=True)

# Launch the interactive explorer on the training table.
run_dash_app(dataframe=train)