<a href="https://colab.research.google.com/github/gitmystuff/DTSC5502/blob/main/Module_03-EDA_and_Visualization/Plotly_Advertising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Plotly Advertising Scatter Plot

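Run every cell of this notebook in order; the interactive figure and its checkboxes are displayed by the final cell.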

In [11]:
# Plotly example: advertising spend per channel vs. sales, with toggleable colors and trend lines
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# from scipy.stats import pearsonr, ttest_ind

import plotly.graph_objects as go
from ipywidgets import widgets
from plotly.subplots import make_subplots

# Load only the three advertising channels and the sales target.
df = pd.read_csv(
    'https://raw.githubusercontent.com/gitmystuff/Datasets/main/Advertising.csv',
    usecols=['TV', 'radio', 'newspaper', 'sales'],
)

# Hold out 25% of the rows as a test set; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    df.drop(columns='sales'), df['sales'], test_size=0.25, random_state=42,
)

# Learn the scaling statistics from the training split only, then apply them to
# both splits. The transformed arrays are wrapped back into DataFrames so the
# channels can still be selected by column name.
scaler = StandardScaler().fit(X_train)
X_train = pd.DataFrame(scaler.transform(X_train), columns=X_train.columns)
X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

# Fit one single-feature linear regression per advertising channel on the
# scaled training data. The predictions for those same training rows provide
# the y-values of each channel's trend line.
tv_model = LinearRegression()
tv_model.fit(X_train[['TV']], y_train)
tv_preds = tv_model.predict(X_train[['TV']])

radio_model = LinearRegression()
radio_model.fit(X_train[['radio']], y_train)
radio_preds = radio_model.predict(X_train[['radio']])

paper_model = LinearRegression()
paper_model.fit(X_train[['newspaper']], y_train)
paper_preds = paper_model.predict(X_train[['newspaper']])

def _channel_scatter(column):
    """Return a light-gray marker trace of scaled `column` spend against training sales."""
    return go.Scatter(
        x=X_train[column],
        y=y_train,
        mode='markers',
        marker=dict(color='lightgray'),
        name=column,
    )


def _channel_trend(column, preds, color):
    """Return an initially hidden line trace of `preds` over scaled `column` spend."""
    return go.Scatter(
        x=X_train[column],
        y=preds,
        mode='lines',
        marker=dict(color=color),
        name=f'{column} trend line',
        visible=False,
    )


# Scatter traces start out gray; the trend lines stay hidden until a checkbox reveals them.
TV = _channel_scatter('TV')
radio = _channel_scatter('radio')
newspaper = _channel_scatter('newspaper')

tv_trend = _channel_trend('TV', tv_preds, 'hotpink')
radio_trend = _channel_trend('radio', radio_preds, 'lime')
paper_trend = _channel_trend('newspaper', paper_preds, 'cyan')

def _unticked_checkbox(label):
    """Return a Checkbox captioned `label` that starts unchecked."""
    return widgets.Checkbox(description=label, value=False)


# Toggles that act on all three channels at once.
show_colors = _unticked_checkbox('Show Colors')
show_trends = _unticked_checkbox('Show Trends')

# Toggles that act on a single channel.
show_paper_color = _unticked_checkbox('Paper Color')
show_radio_color = _unticked_checkbox('Radio Color')
show_tv_color = _unticked_checkbox('TV Color')

# Lay each group of toggles out as its own horizontal row.
container = widgets.HBox(children=[show_colors, show_trends])
container2 = widgets.HBox(children=[show_tv_color, show_radio_color, show_paper_color])

# Trace order matters: the checkbox handler addresses traces by position
# (0-2 are the scatter traces, 3-5 are their trend lines).
figure_traces = [TV, radio, newspaper, tv_trend, radio_trend, paper_trend]

figure_layout = go.Layout(
    title=dict(text='Advertising Scatterplot'),
    width=800,
    height=600,
    xaxis_title='X',
    yaxis_title='y',
)

# FigureWidget (rather than a plain Figure) so later property changes are
# reflected in the displayed plot.
g1 = go.FigureWidget(data=figure_traces, layout=figure_layout)

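# Legacy version of the handler, kept commented out for reference; the active response() is defined below.
# def response(change):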
#   with g1.batch_update():
#     if show_colors.value==True:
#       g1.data[0].marker=dict(color='red')
#       g1.data[1].marker=dict(color='springgreen')
#       g1.data[2].marker=dict(color='blue')
#     else:
#       g1.data[0].marker=dict(color='lightgray')
#       g1.data[1].marker=dict(color='lightgray')
#       g1.data[2].marker=dict(color='lightgray')

#     if show_trends.value==True:
#       g1.data[3].visible = True
#       g1.data[4].visible = True
#       g1.data[5].visible = True
#       g1.data[0].opacity = .8
#       g1.data[1].opacity = .6
#       g1.data[2].opacity = .2
#     else:
#       g1.data[3].visible = False
#       g1.data[4].visible = False
#       g1.data[5].visible = False
#       g1.data[0].opacity = 1
#       g1.data[1].opacity = 1
#       g1.data[2].opacity = 1

#     if show_tv_color.value==True:
#       g1.data[0].marker=dict(color='red')
#       g1.data[3].visible = True
#     else:
#       g1.data[0].marker=dict(color='lightgray')
#       g1.data[3].visible = False

#     if show_radio_color.value==True:
#       g1.data[1].marker=dict(color='springgreen')
#       g1.data[4].visible = True
#     else:
#       g1.data[1].marker=dict(color='lightgray')
#       g1.data[4].visible = False

#     if show_paper_color.value==True:
#       g1.data[2].marker=dict(color='blue')
#       g1.data[5].visible = True
#     else:
#       g1.data[2].marker=dict(color='lightgray')
#       g1.data[5].visible = False

def response(change):
    """Refresh trace colors, trend-line visibility and marker opacity from the checkboxes.

    Used as the ``observe`` callback for every checkbox. ``change`` (the event
    payload) is not inspected; the current value of each checkbox is re-read
    on every call.

    For each channel:
      * the markers take the channel color when the channel's own checkbox OR
        'Show Colors' is ticked, otherwise 'lightgray';
      * the trend line is visible when the channel's own checkbox OR
        'Show Trends' is ticked;
      * marker opacity drops to the channel's reduced value only while
        'Show Trends' is ticked, otherwise it is 1.0.
    """
    # One row per channel:
    # (scatter trace index, trend trace index, checkbox, color, opacity under 'Show Trends')
    channels = (
        (0, 3, show_tv_color, 'red', 0.8),
        (1, 4, show_radio_color, 'springgreen', 0.6),
        (2, 5, show_paper_color, 'blue', 0.2),
    )

    # Group all property changes so they are sent to the figure as one update.
    with g1.batch_update():
        colors_everywhere = show_colors.value
        trends_everywhere = show_trends.value

        for points_idx, line_idx, checkbox, color, faded in channels:
            points = g1.data[points_idx]
            line = g1.data[line_idx]
            ticked = checkbox.value

            if ticked or colors_everywhere:
                points.marker.color = color
            else:
                points.marker.color = 'lightgray'

            line.visible = ticked or trends_everywhere

            # Only the global trends toggle fades the markers; a per-channel
            # checkbox leaves opacity untouched.
            points.opacity = faded if trends_everywhere else 1.0

# Re-run the handler whenever the `value` trait of any checkbox changes.
for checkbox in (show_colors, show_trends, show_paper_color, show_radio_color, show_tv_color):
    checkbox.observe(response, names="value")

In [12]:
# Turn on Colab's custom widget manager (support for third-party widgets).
# NOTE(review): presumably required for the plotly FigureWidget to render in Colab -- confirm.
from google.colab import output
output.enable_custom_widget_manager()

Support for third party widgets will remain active for the duration of the session. To disable support:

In [13]:
# Turn third-party widget support back off for this session.
# NOTE(review): the following cell re-enables it, so this cell looks like a demonstration only -- confirm.
from google.colab import output
output.disable_custom_widget_manager()

In [14]:
# Re-enable the custom widget manager after the disable call in the previous cell.
from google.colab import output
output.enable_custom_widget_manager()

In [15]:
# Display the figure with the per-channel checkbox row and then the global checkbox row stacked beneath it.
widgets.VBox([g1, container2, container])

VBox(children=(FigureWidget({
    'data': [{'marker': {'color': 'lightgray'},
              'mode': 'markers',…