In [19]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

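# plotly.colors + PIL.ImageColor back the colour helpers get_color / get_continuous_color (adapted from Adam's Stack Overflow answer)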
import plotly.colors
from PIL import ImageColor

# show every column when displaying wide product frames
pd.options.display.max_columns = None


In [20]:
def get_unit_price_comparison_data(df, sorting_value='ratio_mini_lt_full'):
    '''
    Pair every mini-size product with its standard-size counterpart and reshape for plotting.

    For each product (matched on product_id, product_name, brand_name) every
    'mini size' row is joined to every 'standard size' row. Only pairs whose mini
    amount_adj is smaller than the standard amount_adj are kept. Each pair gets
    mini_to_standard_ratio = unit_price_mini / unit_price_standard
    (< 1: mini is better value per oz, > 1: standard is better value).

    Args:
        df: product table; reads product_id, product_name, brand_name, swatch_group,
            amount_adj, unit_price and lvl_2_cat.
        sorting_value: currently unused - kept for the commented-out
            sort_product_comparison_data step.

    Returns:
        Long-format DataFrame with one row per pair per unit price ('variable' is
        'unit_price_mini' or 'unit_price_standard', 'value' holds the price),
        left-merged back onto df's rows for the same product, plus 'pretty_ratio'
        (ratio rounded to 2 dp, as text) and 'display_name'
        ("brand,<br>category (ratio)").
    '''
    product_keys = ['product_id', 'product_name', 'brand_name']

    # for each product, compare all mini sizes to all standard sizes (per-product cross join)
    df_compare = df[df['swatch_group'] == 'mini size'].merge(
        df[df['swatch_group'] == 'standard size'],
        on=product_keys,
        suffixes=('_mini', '_standard'),
    )
    # only calculate ratio in one direction; .copy() so the column assignment below
    # targets an independent frame rather than a filtered view (SettingWithCopyWarning)
    df_compare = df_compare[df_compare['amount_adj_mini'] < df_compare['amount_adj_standard']].copy()
    # if ratio < 1, mini is better value per oz, if ratio > 1, standard is better value
    df_compare['mini_to_standard_ratio'] = df_compare['unit_price_mini'] / df_compare['unit_price_standard']
    df_compare = df_compare.reset_index().rename(columns={'index': 'prod_rank'})

    # df_compare = sort_product_comparison_data(df_compare, sorting_value)

    # melt only the two unit-price columns: 'value' stays numeric instead of every
    # column being melted into one object-dtype column and filtered afterwards
    df_compare = df_compare.melt(
        id_vars=['product_id', 'brand_name', 'product_name', 'prod_rank',
                 'amount_adj_mini', 'amount_adj_standard', 'mini_to_standard_ratio'],
        value_vars=['unit_price_mini', 'unit_price_standard'],
    )
    # NOTE(review): df holds one row per product size, so this left merge repeats each
    # melted row once per matching df row of the product - confirm that is intended.
    df_compare = df_compare.merge(df, on=product_keys, how='left')
    df_compare['pretty_ratio'] = df_compare['mini_to_standard_ratio'].round(2).astype(str)
    df_compare['display_name'] = (
        df_compare['brand_name'] + ",<br>" + df_compare['lvl_2_cat']
        + " (" + df_compare['pretty_ratio'] + ")"
    )
    return df_compare

In [21]:
# Load the aggregated product table and build the long-format mini/standard
# unit-price comparison. The derived frame gets its own name so the raw table
# is not overwritten by a frame with a completely different shape.
df_raw = pd.read_csv('../data/agg_prod_data.csv')

df_unit_price_compare = get_unit_price_comparison_data(df_raw)

In [25]:
def get_color(colorscale_name, loc):
    """
    Look up colour(s) on a named plotly continuous colour scale.

    :param colorscale_name: plotly colour-scale name accepted by ColorscaleValidator (e.g. 'plotly3_r')
    :param loc: a normalized position in [0, 1], or an iterable of such positions
    :return: an rgb string for a scalar loc, or a list of rgb strings for an iterable loc
    """
    from _plotly_utils.basevalidators import ColorscaleValidator

    # The validator's two constructor args are only labels (property name, parent name);
    # validate_coerce turns the scale name into [[pos, colour], ...] pairs.
    scale = ColorscaleValidator("colorscale", "").validate_coerce(colorscale_name)

    if not hasattr(loc, "__iter__"):
        return get_continuous_color(scale, loc)
    return [get_continuous_color(scale, point) for point in loc]
        
# get_color (above) retrieves colours from a continuous plotly colour scale, given the
# scale's name and a normalized location (or iterable of locations) between 0 and 1;
# get_continuous_color (below) performs the interpolation for a single location.
# Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale

def get_continuous_color(colorscale, intermed):
    """
    Plotly continuous colorscales assign colors to the range [0, 1]. This function computes the intermediate
    color for any value in that range.

    Plotly doesn't make the colorscales directly accessible in a common format.
    Some are ready to use:

        colorscale = plotly.colors.PLOTLY_SCALES["Greens"]

    Others are just swatches that need to be constructed into a colorscale:

        viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
        colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)

    Values outside [0, 1] (or outside the scale's first/last stop) are clamped to the
    nearest end colour.

    :param colorscale: A plotly continuous colorscale, [[pos, colour], ...], with RGB or hex colour strings.
    :param intermed: value in the range [0, 1]
    :return: color in rgb string format
    :rtype: str
    :raises ValueError: if colorscale is empty, or intermed is NaN on a multi-colour scale
    """
    if len(colorscale) < 1:
        raise ValueError("colorscale must have at least one color")

    def to_rgb(colour):
        # some color scale names (such as cividis) return hex stops: [[loc1, "hex1"], ...]
        return colour if colour[0] != "#" else "rgb" + str(ImageColor.getcolor(colour, "RGB"))

    if len(colorscale) == 1 or intermed <= 0:
        return to_rgb(colorscale[0][1])
    if intermed >= 1:
        return to_rgb(colorscale[-1][1])
    if intermed != intermed:  # NaN is the only value not equal to itself
        raise ValueError("intermed must be a number in [0, 1], got NaN")

    low_cutoff, low_color = colorscale[0]
    if intermed <= low_cutoff:
        # scale starts above 0 and intermed falls below its first stop
        return to_rgb(low_color)

    for cutoff, color in colorscale[1:]:
        if intermed > cutoff:
            low_cutoff, low_color = cutoff, color
        else:
            high_cutoff, high_color = cutoff, color
            break
    else:
        # scale ends below 1 and intermed falls past its last stop
        return to_rgb(low_color)

    return plotly.colors.find_intermediate_color(
        lowcolor=to_rgb(low_color),
        highcolor=to_rgb(high_color),
        intermed=((intermed - low_cutoff) / (high_cutoff - low_cutoff)),
        colortype="rgb",
    )

In [251]:
# Name of the plotly colour scale shared by the ratio colour bar (to move into global styling later).
COLOUR_SCALE = "plotly3_r"


# Create subplots, left pair plot, right scatter - both sharing colour bar
import plotly.subplots as sp


def normalize_colour_value(data_point, all_values):
    """
    Min-max normalize data_point onto [0, 1] relative to all_values, for use with the colour bar.

    Args:
        data_point: a number, or an array/Series of numbers (normalized element-wise).
        all_values: the values defining the range; NaN entries are ignored.

    Returns:
        (data_point - min) / (max - min). When every value in all_values is identical
        there is no spread to normalize, so 0.5 (mid-scale) is returned instead of
        dividing by zero.
    """
    values = np.asarray(all_values, dtype=float)
    # nan-aware: builtin min()/max() return NaN whenever the first element is NaN
    low, high = np.nanmin(values), np.nanmax(values)
    span = high - low
    if span == 0:
        # "* 0.0 + 0.5" keeps the input's shape/index (scalar or Series) and propagates NaN
        return (data_point - low) * 0.0 + 0.5
    return (data_point - low) / span


def create_pairplot_title(default_plot_title, brand_title_filter, category_title_filter):
    """
    Build the title for the pair-plot subplot.

    Args:
        default_plot_title: title used when no brand/category filter is active.
        brand_title_filter: optional brand name (falsy = not filtered); title-cased in the output.
        category_title_filter: optional category name (falsy = not filtered); title-cased in the output.

    Returns:
        default_plot_title when neither filter is set, otherwise
        "Top 10 <Brand> <Category> Products" containing only the filters provided.
    """
    if not brand_title_filter and not category_title_filter:
        return default_plot_title
    brand = f" {brand_title_filter.title()}" if brand_title_filter else ""
    category = f" {category_title_filter.title()}" if category_title_filter else ""
    return f"Top 10{brand}{category} Products"

def create_scatterplot_title(default_plot_title, brand_title_filter, category_title_filter):
    """
    Build the title for the size/price scatter subplot.

    Args:
        default_plot_title: title used when no brand/category filter is active.
        brand_title_filter: optional brand name (falsy = not filtered); title-cased in the output.
        category_title_filter: optional category name (falsy = not filtered); title-cased in the output.

    Returns:
        default_plot_title when neither filter is set, otherwise
        "Explore <Brand> <Category> Products By Size And Price" containing only the filters provided.
    """
    if not brand_title_filter and not category_title_filter:
        return default_plot_title
    brand = f" {brand_title_filter.title()}" if brand_title_filter else ""
    category = f" {category_title_filter.title()}" if category_title_filter else ""
    return f"Explore{brand}{category} Products By Size And Price"


# Hover text shared by both subplots: product, brand, size, price, category, ratio.
_TOOLTIP_TEMPLATE = ('{}<br>{}<br>Size: {} oz.<br>Price: ${}<br>Category: {}'
                     '<br>Mini-to-Standard Ratio: {:.2f}')

# Scatter marker symbol per product size category.
_MARKER_SHAPES = {'mini size': 'circle', 'standard size': 'square',
                  'refill size': 'diamond', 'value size': 'cross'}


def create_join_pair_scatter(df_product_pairs, df_base, brand_title_filter=None, category_title_filter=None):
    """
    Build a 1x2 figure: mini-vs-standard pair plot (left) and size/price scatter (right).

    Left subplot is a slope/pair plot of mini-standard product pairs
    Y = unit price
    colour = mini-to-standard unit price ratio
    initially, product and brand names were shown with pairs on subplot but
    this takes up too much space - all additional info has been moved to hover

    Right subplot is all products in scatter plot
    X = size
    Y = price
    colour = sharing the same ratio colour-bar scale as the left subplot
    this means only products with a ratio (mini and standard sizes) get non-grey colours
    4 product size categories are shown with scatter shape - Refill, Value, Mini, Standard

    Args:
        df_product_pairs: one row per mini/standard pair; reads product_name, brand_name,
            mini_to_standard_ratio and the _mini/_standard variants of unit_price,
            amount_a, price and lvl_2_cat. Best kept to ~10 rows for readability.
        df_base: one row per product size; reads amount_a, price, swatch_group,
            product_name, brand_name, lvl_2_cat and mini_to_standard_ratio
            (NaN ratio = drawn grey).
        brand_title_filter: optional brand label forwarded to the subplot-title helpers.
        category_title_filter: optional category label forwarded to the subplot-title helpers.

    Returns:
        The plotly Figure built by sp.make_subplots.
    """
    pair_plot_title = create_pairplot_title("Unit Price Comparison Of Products", brand_title_filter, category_title_filter)
    scatter_plot_title = create_scatterplot_title("Explore Products By Size And Price", brand_title_filter, category_title_filter)

    fig = sp.make_subplots(rows=1, cols=2, column_widths=[0.4, 0.6],
                           subplot_titles=(pair_plot_title, scatter_plot_title))

    # One ratio range for both subplots so a given ratio gets the same colour on the
    # pair lines (explicit colours) and on the scatter (colour bar). Series.min/max
    # skip NaN, unlike builtin min()/max().
    all_ratios = pd.concat([df_base['mini_to_standard_ratio'],
                            df_product_pairs['mini_to_standard_ratio']])
    ratio_min, ratio_max = all_ratios.min(), all_ratios.max()

    _add_pair_traces(fig, df_product_pairs, ratio_min, ratio_max)
    _add_size_price_scatter(fig, df_base, ratio_min, ratio_max)
    return fig


def _format_tooltip(row, suffix=''):
    """Hover text for one product size; suffix picks the '_mini'/'_standard' columns of a pair row."""
    return _TOOLTIP_TEMPLATE.format(
        row['product_name'], row['brand_name'],
        row['amount_a' + suffix], row['price' + suffix], row['lvl_2_cat' + suffix],
        row['mini_to_standard_ratio'],
    )


def _marker_symbols(df):
    """Marker symbol per row from swatch_group; unknown groups fall back to a circle instead of KeyError."""
    return [_MARKER_SHAPES.get(group, 'circle') for group in df['swatch_group']]


def _add_pair_traces(fig, df_product_pairs, ratio_min, ratio_max):
    """Left subplot: one Mini -> Standard unit-price line per pair, coloured by its ratio."""
    # a continuous colour scale can't vary along a single line trace, so pairs are drawn one-by-one
    for _, row in df_product_pairs.iterrows():
        ratio = row['mini_to_standard_ratio']
        if pd.isna(ratio):
            colour = 'grey'
        else:
            # with a single distinct ratio there is no spread to normalize: use mid-scale
            position = (normalize_colour_value(ratio, (ratio_min, ratio_max))
                        if ratio_max > ratio_min else 0.5)
            colour = get_color(COLOUR_SCALE, position)

        fig.add_trace(go.Scatter(
            x=['Mini', 'Standard'],
            y=[row['unit_price_mini'], row['unit_price_standard']],
            mode='markers+lines',
            marker=dict(color=colour, showscale=False),
            # plotly does not derive line colour from marker colour, so set it explicitly
            line=dict(width=5, color=colour),
            showlegend=False,
            hovertemplate='Unit Price: %{y:.2f}$/oz. <br>%{text}',
            # each line is made of two markers, text = [marker_mini, marker_standard]
            text=[_format_tooltip(row, '_mini'), _format_tooltip(row, '_standard')],
        ), row=1, col=1)

    # Added once, after the traces exist (shapes are skipped on empty subplots).
    fig.add_vline(x=0, line_width=1, line_dash="dash", line_color="grey", row=1, col=1)
    fig.add_vline(x=1, line_width=1, line_dash="dash", line_color="grey", row=1, col=1)

    fig.update_xaxes(
        title_text='Product Size',
        type='category',
        tickmode='array',
        # really difficult to get categorical axis spacing right
        range=[-0.2, 1.3],
        linecolor='rgb(204, 204, 204)',
        row=1, col=1,
    )
    fig.update_yaxes(
        title_text='Unit Price ($/oz.)',
        showgrid=False,
        zeroline=True,
        showline=False,
        showticklabels=True,
        row=1, col=1,
    )
    fig.update_layout(showlegend=False, plot_bgcolor='white')


def _add_size_price_scatter(fig, df_base, ratio_min, ratio_max):
    """Right subplot: all product sizes by size/price; rows with a ratio share the colour bar."""
    has_ratio = df_base['mini_to_standard_ratio'].notna()
    df_no_ratio = df_base[~has_ratio]
    df_w_ratio = df_base[has_ratio]

    # grey markers, no mini-to-standard ratio
    fig.add_trace(go.Scatter(
        x=df_no_ratio['amount_a'],
        y=df_no_ratio['price'],
        mode="markers",
        marker=dict(color='grey', symbol=_marker_symbols(df_no_ratio)),
        opacity=0.6,
        hovertemplate='Size: %{x}oz.<br>Price: $%{y}, %{text}',
        text=[_format_tooltip(row) for _, row in df_no_ratio.iterrows()],
    ), row=1, col=2)

    fig.add_trace(go.Scatter(
        x=df_w_ratio['amount_a'],
        y=df_w_ratio['price'],
        mode="markers",
        marker=dict(
            color=df_w_ratio['mini_to_standard_ratio'],
            colorbar=dict(
                title="Mini-to-Standard <br>Unit Price Ratio",
                thickness=25,
                bordercolor='white',
                outlinecolor='white',
                x=1,
                xref="container",
            ),
            colorscale=COLOUR_SCALE,
            cmin=ratio_min,
            cmax=ratio_max,
            symbol=_marker_symbols(df_w_ratio),
        ),
        hovertemplate='Size: %{x}oz.<br>Price: $%{y}, %{text}',
        text=[_format_tooltip(row) for _, row in df_w_ratio.iterrows()],
    ), row=1, col=2)

    fig.update_xaxes(title_text="Size (oz.)", row=1, col=2)
    fig.update_yaxes(title_text="Price ($)", row=1, col=2)




# Pair every mini size with its standard size, then attach each product's ratio back
# onto the full product table so the scatter can colour by it.
df_raw = pd.read_csv('../data/agg_prod_data.csv')

# NOTE(review): these pairing steps repeat the start of get_unit_price_comparison_data;
# consider extracting a shared helper.
df_compare = df_raw[df_raw['swatch_group'] == 'mini size'].merge(
    df_raw[df_raw['swatch_group'] == 'standard size'],
    on=['product_id', 'product_name', 'brand_name'],
    suffixes=('_mini', '_standard'),
)
# keep each pair once (mini strictly smaller); .copy() so the column assignment below
# acts on an independent frame, not a filtered view (SettingWithCopyWarning)
df_compare = df_compare[df_compare['amount_adj_mini'] < df_compare['amount_adj_standard']].copy()
# if ratio < 1, mini is better value per oz, if ratio > 1, standard is better value
df_compare['mini_to_standard_ratio'] = df_compare['unit_price_mini'] / df_compare['unit_price_standard']
df_compare = df_compare.reset_index().rename(columns={'index': 'prod_rank'})

# NOTE(review): a product with several mini/standard pairs has several ratios, so this
# left merge duplicates that product's rows (one copy per pair) - confirm intended.
df = df_raw.merge(
    df_compare[['product_id', 'product_name', 'brand_name', 'mini_to_standard_ratio']],
    on=['product_id', 'product_name', 'brand_name'],
    how='left',
)

# value/refill sizes are not part of a mini-vs-standard pair: leave them uncoloured (grey)
df.loc[df['swatch_group'].isin(['value size', 'refill size']), 'mini_to_standard_ratio'] = np.nan

# Pair-plot window (positional rows 50-69 of df_compare).
# NOTE(review): create_join_pair_scatter suggests <= 10 pairs for readability; this is 20.
PAIR_PLOT_ROWS = slice(50, 70)
fig = create_join_pair_scatter(df_compare.iloc[PAIR_PLOT_ROWS], df)


In [252]:
# Render the two-panel figure built in the previous cell.
fig.show()

In [255]:
# # put data on 0-1
# def min_max(X, min_value, max_value):

#     X_std = (X - X.min()) / (X.max() - X.min())
#     X_scaled = X_std * (max_value - min_value) + min_value
#     return X_scaled

# df_compare['minmax_mini_to_standard_ratio'] = df_compare['mini_to_standard_ratio'].apply(normalize_colour_value, df_compare['mini_to_standard_ratio'])

# df_compare['ratio_color'] = plotly.colors.sample_colorscale(plotly.colors.get_colorscale(COLOUR_SCALE), df_compare['minmax_mini_to_standard_ratio'])# low=0, high=13)



In [188]:
# colour at the top end (normalized location 1) of the shared scale
get_color(colorscale_name=COLOUR_SCALE, loc=1)

'rgb(5, 8, 184)'

In [210]:

# X_scaled = scale * X + min - X.min(axis=0) * scale


In [211]:
# 'ratio_color' previously existed only because a since-commented-out cell had been run
# (hidden kernel state). Derive it explicitly so this cell survives Restart & Run All:
# min-max normalize each pair's ratio, then look its colour up on the shared scale.
ratios = df_compare['mini_to_standard_ratio']
ratio_color = pd.Series(
    get_color(COLOUR_SCALE, normalize_colour_value(ratios, ratios)),
    index=df_compare.index,
    name='ratio_color',
)
ratio_color

0      rgb(254, 192, 254)
1      rgb(254, 190, 254)
2      rgb(254, 182, 254)
3      rgb(254, 180, 254)
4      rgb(254, 148, 252)
              ...        
994    rgb(254, 176, 253)
995        rgb(5, 8, 184)
996    rgb(254, 193, 254)
997    rgb(254, 191, 254)
998    rgb(254, 193, 254)
Name: ratio_color, Length: 999, dtype: object