# In [5]:  (Jupyter notebook cell marker left over from the export)
import dash
from dash import dcc, html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

# Dash app initialization
# Single Dash application object for this module; the layout and every
# @app.callback below attach to it. The Bootstrap stylesheet comes from
# dash-bootstrap-components.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Constants
# URL of the CSV that load_data() reads.
# NOTE(review): the file name suggests a cleaned Open Food Facts yogurt
# export -- confirm against the repository before relying on its schema.
data_path = 'https://raw.githubusercontent.com/data-driven-prototypes-2023/s2_Olaf_Lesniak/main/openfoodfacts_yogurt_clean.csv'

# Data Loading
# Parsed datasets keyed by source path, so the CSV is fetched and parsed once
# per process instead of once per callback invocation.
_DATA_CACHE = {}


def load_data():
    """Return the product dataset as a DataFrame the caller may freely modify.

    The CSV at ``data_path`` is read on first use only; later calls are served
    from an in-memory cache. Column names are stripped of surrounding
    whitespace and the ``labels`` column is lower-cased.

    Returns:
        pandas.DataFrame: a copy of the cached frame, indexed by the first
        CSV column.
    """
    source = data_path
    if source not in _DATA_CACHE:
        df = pd.read_csv(source, index_col=0)
        df.columns = df.columns.str.strip()  # Remove any leading/trailing spaces
        df['labels'] = df['labels'].str.lower()  # Lowercase the label column
        _DATA_CACHE[source] = df
    # Hand out a copy so callers can filter or mutate without touching the cache.
    return _DATA_CACHE[source].copy()

# Similarity Computation
def compute_similarity(product, recommendations):
    """Find the candidates nutritionally closest to a reference product.

    Both frames are reduced to the same seven nutrient columns. A
    StandardScaler is fitted on the candidates only and then applied to the
    reference rows; a NearestNeighbors model fitted on the scaled candidates
    is queried with the scaled reference.

    Args:
        product: DataFrame holding the reference product row(s).
        recommendations: DataFrame of candidate substitutes.

    Returns:
        The ``(distances, indices)`` pair produced by ``kneighbors``, with at
        most five neighbours per reference row. Indices are positions within
        ``recommendations``.
    """
    nutrient_columns = [
        'fat_value_per_100g',
        'carbohydrates_value_per_g',
        'sugars_value',
        'proteins_value_per_100g',
        'salt_value_in_grams',
        'saturated_fat_value_per_100g',
        'energy-kcal_value',
    ]
    # Never ask for more neighbours than there are candidates.
    neighbour_count = min(5, len(recommendations))

    standardiser = StandardScaler()
    candidate_matrix = standardiser.fit_transform(recommendations[nutrient_columns])
    query_matrix = standardiser.transform(product[nutrient_columns])

    model = NearestNeighbors(n_neighbors=neighbour_count)
    model.fit(candidate_matrix)
    return model.kneighbors(query_matrix)

# Plot creations (Line, Box, Cluster)
def create_line_plot(df, nutriment, y_axis_type, plot_type):
    """Build the nutrient distribution figure, one series per Nutri-Score grade.

    Args:
        df: Products to plot; must contain ``nutriment`` and
            ``'off:nutriscore_grade'`` columns.
        nutriment: Name of the nutrient column to summarise.
        y_axis_type: ``'density'`` normalises the distribution; any other
            value plots raw counts.
        plot_type: ``'line'`` draws the outline of a 50-bin histogram per
            grade; any other value draws a Plotly Express histogram.

    Returns:
        The Plotly figure for the chosen plot type.
    """
    axis_label = nutriment.replace("_", " ").capitalize()
    use_density = y_axis_type == 'density'

    if plot_type == 'line':
        fig_line = go.Figure()

        # Sorted so trace and legend order is stable instead of following the
        # order in which grades happen to appear in the data.
        grades = sorted(df['off:nutriscore_grade'].dropna().unique(), key=str)
        for grade in grades:
            # np.histogram raises on NaN input, so missing measurements are
            # dropped before binning.
            values = df.loc[df['off:nutriscore_grade'] == grade, nutriment].dropna()
            if values.empty:
                continue
            counts, bins = np.histogram(values, bins=50, density=use_density)
            bins_center = 0.5 * (bins[1:] + bins[:-1])
            fig_line.add_trace(go.Scatter(
                x=bins_center,
                y=counts,
                mode='lines',
                name=f'Grade {grade}',
                line=dict(color=color_map.get(grade, 'blue')),
            ))
    else:
        fig_line = px.histogram(
            df, x=nutriment, color='off:nutriscore_grade', nbins=50,
            color_discrete_map=color_map,
            histnorm='probability density' if use_density else None,
        )

    # Titles and axis labels are identical for both plot types.
    fig_line.update_layout(
        title=f'Distribution of {axis_label}',
        xaxis_title=axis_label,
        yaxis_title='Density' if use_density else 'Frequency',
        legend_title='Nutri-Score Grade'
    )
    return fig_line

def create_box_plot(df, nutriment):
    """Return a box plot of one nutrient grouped by Nutri-Score grade.

    The nutrient is on the x axis and the grade on the y axis. Boxes are
    coloured through ``color_map`` and the product name is added to the hover
    data.
    """
    grade_column = 'off:nutriscore_grade'
    pretty_name = nutriment.replace("_", " ").capitalize()
    axis_labels = {nutriment: pretty_name, grade_column: 'Nutri-Score Grade'}

    return px.box(
        df,
        x=nutriment,
        y=grade_column,
        color=grade_column,
        title=f'{pretty_name} by Nutri-Score Grade',
        hover_data=['product_name_en'],
        labels=axis_labels,
        color_discrete_map=color_map,
    )

def create_cluster_plot(df, x_cluster, y_cluster, size_cluster, cluster_size):
    """Build the scatter plot of two nutrients, optionally sized by a third.

    Args:
        df: Products to plot.
        x_cluster: Column for the x axis.
        y_cluster: Column for the y axis.
        size_cluster: Column that drives marker size, or the string ``'None'``
            (also accepted: a real ``None``) for uniform markers.
        cluster_size: Value of the size slider; a larger value gives larger
            markers.

    Returns:
        A Plotly figure, or an empty dict when either axis column is missing.
    """
    if not (x_cluster and y_cluster):
        return {}

    # The size dropdown uses the string 'None' as its "no sizing" value; a real
    # None is treated the same way instead of reaching df[None].
    size_column = None if size_cluster in (None, 'None') else size_cluster
    if size_column is not None:
        # Plotly Express rejects NaN marker sizes, so rows without a size
        # value are left out of the plot.
        df = df.dropna(subset=[size_column])

    fig_cluster = px.scatter(df, x=x_cluster, y=y_cluster, color='off:nutriscore_grade',
                             size=size_column,
                             hover_data=['product_name_en'],
                             title='Product Clusters',
                             labels={x_cluster: x_cluster.replace("_", " ").capitalize(),
                                     y_cluster: y_cluster.replace("_", " ").capitalize()},
                             color_discrete_map=color_map)

    if size_column is not None:
        largest = df[size_column].max()
        # Skip custom scaling when the largest size is 0 or NaN (no rows):
        # sizeref would be 0 or NaN and the markers could not be scaled.
        if largest > 0:
            fig_cluster.update_traces(marker=dict(
                sizemode='diameter',
                sizeref=2. * largest / cluster_size ** 2,
                sizemin=2,
            ))
    return fig_cluster

# Nutri-Score grade -> CSS colour name, used to colour traces and grade labels.
# Only the grades 'a' to 'e' have an entry.
color_map = dict(
    a='green',
    b='lightgreen',
    c='yellow',
    d='orange',
    e='red',
)

# Dash Layout (User Interface)
def _nutrient_options(include_none=False):
    """Return a fresh list of dropdown options for the seven nutrient columns.

    With ``include_none=True`` a leading ``'None'`` option is added; the
    marker-size dropdown uses it to mean "no sizing".
    """
    options = [
        {'label': 'Fat 🥑', 'value': 'fat_value_per_100g'},
        {'label': 'Carbohydrates 🍞', 'value': 'carbohydrates_value_per_g'},
        {'label': 'Sugars 🍭', 'value': 'sugars_value'},
        {'label': 'Proteins 🍗', 'value': 'proteins_value_per_100g'},
        {'label': 'Salt 🧂', 'value': 'salt_value_in_grams'},
        {'label': 'Saturated Fat 🥓', 'value': 'saturated_fat_value_per_100g'},
        {'label': 'Energy ⚡', 'value': 'energy-kcal_value'}
    ]
    if include_none:
        options.insert(0, {'label': 'None', 'value': 'None'})
    return options


def _nutrient_dropdown(component_id, default_value, include_none=False):
    """Build one non-clearable nutrient dropdown with the shared option list."""
    return dcc.Dropdown(
        id=component_id,
        options=_nutrient_options(include_none),
        value=default_value,
        clearable=False,
        className='mb-4'
    )


app.layout = dbc.Container([
    # Page header and usage guide
    dbc.Row([
        dbc.Col(html.H1("🥛 Open Food Facts Yogurt Analysis 🍦"), width=12, className='text-center my-4')
    ]),
    dbc.Row([
        dbc.Col(html.H3("📊 Interactive Plots and Features 📊"), width=12, className='text-center my-4')
    ]),
    dbc.Row([
        dbc.Col(html.P("Welcome to the Yogurt Analysis Dashboard! As an avid yogurt fan, I created this tool to help you find and compare different yogurt products. Here's a guide to what you can do with this dashboard:"), width=12, className='text-center my-4'),
        dbc.Col(html.Ul([
            html.Li("Select a nutrient (e.g., Fat, Carbohydrates) to visualize its distribution across different yogurt products."),
            html.Li("Filter products by brand or label to narrow down your analysis."),
            html.Li("Choose between line plots or histograms and toggle between frequency and density views."),
            html.Li("View box plots to compare nutrient levels across different Nutri-Score grades."),
            html.Li("Create scatter plots to visualize clusters of products based on two selected nutrients, with optional point sizing by a third nutrient."),
            html.Li("Adjust the size of points manually using the slider."),
            html.Li("Find healthier substitutes for your favorite yogurt products, complete with similarity scores and Nutri-Score improvements."),
            html.Li("Enjoy exploring the world of yogurts and finding healthier alternatives! 🥛🍦")
        ]), width=12)
    ]),
    # Global filters: nutrient, brand(s) and free-text label
    dbc.Row([
        dbc.Col(_nutrient_dropdown('nutriment-dropdown', 'fat_value_per_100g'), width=4),
        dbc.Col(dcc.Dropdown(
            id='brand-dropdown',
            placeholder='Select a brand',
            options=[],
            multi=True,
            className='mb-4'
        ), width=4),
        dbc.Col(dbc.Input(
            id='label-input',
            placeholder='Type a label...',
            type='text',
            className='mb-4'
        ), width=4),
    ]),
    dbc.Row([
        dbc.Col(html.Div(id='label-count-output'), width=12, className='text-center my-2')
    ]),
    # Distribution plot controls
    dbc.Row([
        dbc.Col(dcc.RadioItems(
            id='plot-type',
            options=[
                {'label': 'Line Plot 📈', 'value': 'line'},
                {'label': 'Histogram 📊', 'value': 'histogram'}
            ],
            value='line',
            # Dash style dictionaries use camelCase property names.
            labelStyle={'display': 'inline-block', 'marginRight': '10px'}
        ), width=6),
        dbc.Col(dcc.RadioItems(
            id='y-axis-type',
            options=[
                {'label': 'Frequency 📏', 'value': 'frequency'},
                {'label': 'Density 🌡️', 'value': 'density'}
            ],
            value='frequency',
            labelStyle={'display': 'inline-block', 'marginRight': '10px'}
        ), width=6)
    ], style={"marginTop": 20}),
    dbc.Row([
        dbc.Col(dbc.Button("Clear Selection 🧹", id='clear-button', color='primary', className='mb-4'), width=12)
    ]),
    dbc.Row([
        dbc.Col(html.H4("Nutrient Distribution Plot 📊"), width=12, className='text-center my-2'),
        # The previous text said density works only with line plots, but the
        # histogram view is normalised as well when density is selected.
        dbc.Col(html.P("Select a nutrient and choose between line plots or histograms to visualize its distribution across different yogurt products. Toggle between frequency and density views; both are available for line plots and histograms."), width=12, className='text-center my-2'),
        dbc.Col(dcc.Graph(id='line-plot'), width=12)
    ]),
    dbc.Row([
        dbc.Col(html.H4("Nutrient Comparison by Nutri-Score Grade 📊"), width=12, className='text-center my-2'),
        dbc.Col(html.P("View box plots to compare nutrient levels across different Nutri-Score grades. Hover over the points to see the product names."), width=12, className='text-center my-2'),
        dbc.Col(dcc.Graph(id='boxplot'), width=12)
    ]),
    # Cluster scatter plot: x / y / marker-size selectors and size slider
    dbc.Row([
        dbc.Col(html.H4("Product Clusters Scatter Plot 📈"), width=12, className='text-center my-2'),
        dbc.Col(html.P("Create scatter plots to visualize clusters of products based on two selected nutrients, with optional point sizing by a third nutrient. Adjust the size of points manually using the slider."), width=12, className='text-center my-2'),
        dbc.Col(_nutrient_dropdown('x-cluster-dropdown', 'fat_value_per_100g'), width=4),
        dbc.Col(_nutrient_dropdown('y-cluster-dropdown', 'carbohydrates_value_per_g'), width=4),
        dbc.Col(_nutrient_dropdown('size-cluster-dropdown', 'None', include_none=True), width=4)
    ]),
    dbc.Row([
        dbc.Col(dcc.Slider(
            id='cluster-size-slider',
            min=1,
            max=20,
            step=1,
            value=4,
            marks={i: str(i) for i in range(1, 21)},
            className='mb-4'
        ), width=12)
    ]),
    dbc.Row([
        dbc.Col(dcc.Graph(id='cluster-plot'), width=12)
    ]),
    # Substitute finder
    dbc.Row([
        dbc.Col(html.H4("Find Healthier Substitutes 🍦"), width=12, className='text-center my-2'),
        dbc.Col(html.P("Select a product to find healthier substitutes, complete with similarity scores and Nutri-Score improvements."), width=12, className='text-center my-2'),
        dbc.Col(dcc.Dropdown(
            id='product-dropdown',
            placeholder='Select a product',
            options=[],
            clearable=False,
            className='mb-4'
        ), width=12)
    ]),
    dbc.Row([
        dbc.Col(html.Div(id='selected-product-details'), width=12)
    ]),
    dbc.Row([
        dbc.Col(html.Div(id='substitute-output'), width=12)
    ])
], fluid=True)

@app.callback(
    [Output('brand-dropdown', 'options'),
     Output('label-count-output', 'children'),
     Output('product-dropdown', 'options')],
    [Input('label-input', 'value'),
     Input('nutriment-dropdown', 'value')]
)
def set_brand_options(label, nutriment):
    """Populate the brand and product dropdowns and the label match counter.

    Args:
        label: Free text from the label input; when non-empty only products
            whose ``labels`` contain it (case-insensitive) are kept.
        nutriment: Selected nutrient. Not used in the computation; as a
            registered Input it only re-triggers this callback.

    Returns:
        A ``(brand_options, label_message, product_options)`` tuple. Brands
        are ordered by their number of distinct product names, largest first.
    """
    df = load_data()

    # Filter by label if provided
    if label:
        # regex=False: the label is typed by the user, so characters such as
        # '(' or '[' must match literally instead of raising a regex error.
        df = df[df['labels'].str.contains(label, case=False, na=False, regex=False)]
        # Every remaining row matched, so the row count is the match count.
        label_output = f"Label '{label}' appears in {len(df)} products"
    else:
        label_output = ""

    # Get brand counts and sort them
    brand_counts = df.groupby('brands')['product_name_en'].nunique().reset_index()
    brand_counts = brand_counts.sort_values(by='product_name_en', ascending=False)
    brands = [
        {'label': f"{brand} ({count} products)", 'value': brand}
        for brand, count in zip(brand_counts['brands'], brand_counts['product_name_en'])
    ]

    # One option per distinct, non-missing product name: the dropdown must not
    # receive duplicate or NaN option values.
    product_names = df['product_name_en'].dropna().unique()
    products = [{'label': name, 'value': name} for name in product_names]

    return brands, label_output, products

@app.callback(
    [Output('line-plot', 'figure'),
     Output('boxplot', 'figure'),
     Output('cluster-plot', 'figure')],
    [Input('nutriment-dropdown', 'value'),
     Input('brand-dropdown', 'value'),
     Input('label-input', 'value'),
     Input('plot-type', 'value'),
     Input('y-axis-type', 'value'),
     Input('x-cluster-dropdown', 'value'),
     Input('y-cluster-dropdown', 'value'),
     Input('size-cluster-dropdown', 'value'),
     Input('cluster-size-slider', 'value')]
)
def update_plots(nutriment, selected_brands, label, plot_type, y_axis_type, x_cluster, y_cluster, size_cluster, cluster_size):
    """Rebuild the three figures whenever a filter or plot option changes.

    The dataset is narrowed to the selected brands and to products whose
    ``labels`` contain the typed text, then passed to the three plot builders.

    Returns:
        A ``(distribution, box, cluster)`` tuple of figures, or three empty
        dicts when the nutrient column is missing or no product remains.
    """
    df = load_data()

    # Handle the case where selected_brands might be None or empty
    if selected_brands:
        df = df[df['brands'].isin(selected_brands)]

    # Handle the case where label might be None or empty
    if label:
        # regex=False: the label is typed by the user, so characters such as
        # '(' or '[' must match literally instead of raising a regex error.
        df = df[df['labels'].str.contains(label, case=False, na=False, regex=False)]

    # Return empty plots if no data
    if nutriment not in df.columns or df.empty:
        return {}, {}, {}

    fig_line = create_line_plot(df, nutriment, y_axis_type, plot_type)
    fig_box = create_box_plot(df, nutriment)
    fig_cluster = create_cluster_plot(df, x_cluster, y_cluster, size_cluster, cluster_size)

    return fig_line, fig_box, fig_cluster

# Nutri-Score grades ordered from best to worst; used to rank substitutes.
_GRADE_ORDER = ('a', 'b', 'c', 'd', 'e')

# Column headers shared by the selected-product and substitute tables.
_NUTRIENT_HEADERS = ("Fat 🥑", "Carbs 🍞", "Sugars 🍭", "Proteins 🍗", "Salt 🧂", "Saturated Fat 🥓", "Energy ⚡")

# Shared table styling (Dash style dictionaries use camelCase keys).
_TABLE_STYLE = {'width': '100%', 'borderCollapse': 'collapse', 'textAlign': 'center',
                'fontFamily': 'Arial, sans-serif', 'border': '1px solid grey'}


def _header_row(first_title, last_title):
    """Build the table header: first column, brand, nutrients, last column."""
    titles = [first_title, "Brand", *_NUTRIENT_HEADERS, last_title]
    return html.Thead([html.Tr([html.Th(title) for title in titles])])


def _nutrient_cells(row):
    """Render the seven nutrient values of one product row as table cells."""
    return [
        html.Td(f"{row['fat_value_per_100g']}g"),
        html.Td(f"{row['carbohydrates_value_per_g']}g"),
        html.Td(f"{row['sugars_value']}g"),
        html.Td(f"{row['proteins_value_per_100g']}g"),
        html.Td(f"{row['salt_value_in_grams']}g"),
        html.Td(f"{row['saturated_fat_value_per_100g']}g"),
        html.Td(f"{row['energy-kcal_value']} kcal"),
    ]


def _grade_span(grade, suffix=''):
    """Render a grade as coloured text, tolerating missing or unknown grades."""
    text = grade.upper() if isinstance(grade, str) else 'N/A'
    # Same 'blue' fallback the line plot uses for grades absent from color_map.
    colour = color_map.get(grade, 'blue')
    return html.Span(f"{text}{suffix}", style={'color': colour, 'fontSize': '20px', 'textAlign': 'center'})


@app.callback(
    [Output('selected-product-details', 'children'),
     Output('substitute-output', 'children')],
    [Input('product-dropdown', 'value')]
)
def find_substitute(selected_product):
    """Show the selected product and up to five better-graded substitutes.

    Candidates are the products with a strictly better Nutri-Score grade than
    the selection. For a grade-'a' product nothing is better, so other 'a'
    and 'b' products are offered instead. The nearest candidates come from
    ``compute_similarity``; each is shown with a similarity percentage of
    ``100 / (1 + distance)``.

    Args:
        selected_product: Product name chosen in the dropdown, or a falsy
            value when nothing is selected.

    Returns:
        A ``(selected_product_details, substitutes)`` tuple of Dash
        components or plain message strings.
    """
    if not selected_product:
        return "Please select a product to find substitutes.", ""

    df = load_data()

    product = df[df['product_name_en'] == selected_product]

    if product.empty:
        return "Selected product not found in the database.", ""

    # Several rows may share a name; the first one is displayed and ranked.
    selected = product.iloc[0]
    selected_nutriscore = selected['off:nutriscore_grade']

    # Display selected product details in a styled table (HTML format)
    selected_product_details = html.Table([
        _header_row("Product", "Nutri-Score"),
        html.Tbody([
            html.Tr([
                html.Td(selected_product),
                html.Td(selected['brands']),
                *_nutrient_cells(selected),
                html.Td(_grade_span(selected_nutriscore)),
            ])
        ])
    ], style={**_TABLE_STYLE, 'marginBottom': '20px'})

    # Recommendation logic: always recommend a better-graded product, then
    # rank the top 5 by similarity score.
    if selected_nutriscore not in _GRADE_ORDER:
        # A missing or unrecognised grade cannot be ranked against others.
        return selected_product_details, "No healthier substitutes found."

    selected_rank = _GRADE_ORDER.index(selected_nutriscore)
    # Nothing beats 'a', so fall back to the two best grades for it.
    wanted_grades = _GRADE_ORDER[:selected_rank] if selected_rank else _GRADE_ORDER[:2]
    recommendations = df[df['off:nutriscore_grade'].isin(wanted_grades)]

    # Ensure recommendations do not include the selected product
    recommendations = recommendations[recommendations['product_name_en'] != selected_product]

    # In case of no recommendations found
    if recommendations.empty:
        return selected_product_details, "No healthier substitutes found."

    # Compute similarity
    distances, indices = compute_similarity(product, recommendations)

    substitute_list = []
    for distance, position in zip(distances[0], indices[0]):
        rec = recommendations.iloc[position]
        similarity_score = 1 / (1 + distance) * 100  # Similarity score as a percentage

        # Generate arrow indicators for Nutri-Score improvement
        nutriscore_diff = selected_rank - _GRADE_ORDER.index(rec['off:nutriscore_grade'])
        arrow = ''
        if nutriscore_diff == 1:
            arrow = '⬆️'
        elif nutriscore_diff > 1:
            arrow = '⬆️⬆️'

        substitute_list.append(html.Tr([
            html.Td(rec['product_name_en']),
            html.Td(rec['brands']),
            *_nutrient_cells(rec),
            html.Td(html.Div([
                _grade_span(rec['off:nutriscore_grade'], suffix=f" {arrow}"),
                # Bar whose width mirrors the similarity percentage.
                html.Div(style={'height': '20px', 'width': f'{similarity_score:.1f}%', 'backgroundColor': 'lightblue', 'textAlign': 'center', 'lineHeight': '20px'}, children=f"{similarity_score:.1f}%")
            ], style={'display': 'flex', 'flexDirection': 'column', 'alignItems': 'center'}))
        ]))

    return selected_product_details, html.Table([
        _header_row("Substitute Product", "Nutri-Score & Similarity"),
        html.Tbody(substitute_list)
    ], style=dict(_TABLE_STYLE))

@app.callback(
    [Output('brand-dropdown', 'value'),
     Output('label-input', 'value')],
    Input('clear-button', 'n_clicks'),
    [State('brand-dropdown', 'value'),
     State('label-input', 'value')]
)
def clear_selection(n_clicks, brand_value, label_value):
    """Reset the brand and label filters once the clear button has been clicked.

    While the click count is still falsy the current filter values are
    returned unchanged.
    """
    button_pressed = bool(n_clicks)
    if not button_pressed:
        return brand_value, label_value
    # Empty brand selection and empty label text.
    return [], ""

# Run the app
if __name__ == '__main__':
    # Recent Dash releases start the development server with `app.run`, which
    # supersedes `run_server`; fall back to `run_server` only on releases
    # that do not provide `run`.
    start_server = app.run if hasattr(app, 'run') else app.run_server
    start_server(debug=True, port=8051)