# Understanding Linear Regression: Predicting Sales from Advertising Spend

Welcome! In this interactive session, we'll explore the fundamental concept of **linear regression**. We'll use a real-world dataset to understand how advertising spending on different platforms (TV, radio, newspaper) can help us predict sales.

Note: You’ll find the notebook for this screencast below the video, and at the end of this module, there’s a lab where you can try out what you’ve learned.





## A Simple Example

In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

# Sample data: hours studied and the resulting test score for eight students
study_hours = np.array([1, 2, 3, 4, 5, 6, 7, 8])
test_scores = np.array([60, 70, 72, 76, 75, 82, 85, 88])

# Fit a degree-1 polynomial (a straight line): returns slope m and intercept b
m, b = np.polyfit(study_hours, test_scores, 1)

# Scatter plot of the observed data
fig = px.scatter(x=study_hours, y=test_scores,
                 labels={'x': 'Study Hours', 'y': 'Test Scores'},
                 title='Simple Linear Regression: Study Hours vs. Test Scores')
fig.update_traces(name='Observed scores', showlegend=True)  # only the scatter trace exists at this point

# Overlay the fitted line y = m*x + b as its own named trace
fig.add_trace(go.Scatter(x=study_hours, y=m * study_hours + b,
                         mode='lines', name='Regression Line', showlegend=True))
fig.show()

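In the interactive plot above, the straight line is the **linear regression line**: the line that minimizes the sum of squared vertical distances between itself and the data points. Its upward slope captures the trend that more study hours go with higher test scores.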
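`np.polyfit(..., 1)` computes the ordinary least-squares slope and intercept. For a single predictor these have a closed form: the slope is Σ(x − x̄)(y − ȳ) / Σ(x − x̄)², and the intercept is ȳ − slope · x̄. The sketch below reuses the study-hours data, computes both directly, and compares them with `np.polyfit`; the helper name `least_squares_line` is our own, not part of NumPy.

```python
import numpy as np

def least_squares_line(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    """Return (slope, intercept) of the ordinary least-squares line y = m*x + b."""
    x_mean, y_mean = x.mean(), y.mean()
    # Slope: covariance of x and y divided by the variance of x (the 1/n factors cancel)
    m = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # The fitted line always passes through the point of means (x_mean, y_mean)
    b = y_mean - m * x_mean
    return float(m), float(b)

study_hours = np.array([1, 2, 3, 4, 5, 6, 7, 8])
test_scores = np.array([60, 70, 72, 76, 75, 82, 85, 88])

m, b = least_squares_line(study_hours, test_scores)
m_np, b_np = np.polyfit(study_hours, test_scores, 1)
print(f"closed form: m={m:.4f}, b={b:.4f}")
print(f"np.polyfit:  m={m_np:.4f}, b={b_np:.4f}")
```

Both approaches give a slope of about 3.57 points per extra study hour and an intercept of about 59.93.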

## Our Dataset: Advertising Spend vs. Sales

We'll be using a dataset showing advertising spend in different media and the corresponding sales. Let's download and load it.

In [2]:
# Imports
import pandas as pd
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display, clear_output

In [3]:
# Download and unzip the dataset
!curl -L -o ./advertising-spend-vs-sales.zip \
    https://www.kaggle.com/api/v1/datasets/download/brsahan/advertising-spend-vs-sales
!unzip -o ./advertising-spend-vs-sales.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  1909  100  1909    0     0   1563      0  0:00:01  0:00:01 --:--:--  1563
Archive:  ./advertising-spend-vs-sales.zip
  inflating: Advertising.csv         


In [4]:
# Load the dataset
try:
    df = pd.read_csv('Advertising.csv')
except FileNotFoundError:
    # Raise instead of calling exit(), which would stop the notebook kernel
    raise FileNotFoundError("Advertising.csv not found. Please ensure the download and unzip were successful.")

# Display the first few rows
print("First few rows of our dataset:")
print(df.head())

# Basic information about the dataset (df.info() prints directly and returns None)
print("\nDataset information:")
df.info()

# Summary statistics
print("\nSummary statistics:")
print(df.describe())

First few rows of our dataset:
      TV  radio  newspaper  sales
0  230.1   37.8       69.2   22.1
1   44.5   39.3       45.1   10.4
2   17.2   45.9       69.3    9.3
3  151.5   41.3       58.5   18.5
4  180.8   10.8       58.4   12.9

Dataset information:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   TV         200 non-null    float64
 1   radio      200 non-null    float64
 2   newspaper  200 non-null    float64
 3   sales      200 non-null    float64
dtypes: float64(4)
memory usage: 6.4 KB

Summary statistics:
               TV       radio   newspaper       sales
count  200.000000  200.000000  200.000000  200.000000
mean   147.042500   23.264000   30.554000   14.022500
std     85.854236   14.846809   21.778621    5.217457
min      0.700000    0.000000    0.300000    1.600000
25%     74.375000    9.975000   12.750000   10.375000
50%    14

Each of the 200 rows records advertising spend on TV, radio, and newspaper, together with the resulting sales. All four columns are numeric (float64) and none has missing values.

## Exploring the Relationships Visually

Let's use a scatter plot matrix to see the relationships between all pairs of variables.

In [5]:
# Visualize the relationship between each advertising medium and sales
fig = px.scatter_matrix(df, dimensions=['TV', 'radio', 'newspaper', 'sales'], title='Scatter Plot Matrix of Advertising Spend vs. Sales')
fig.show()

In [6]:
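# In Google Colab, enable the custom widget manager so third-party widgets
# such as Plotly's FigureWidget can render. Outside Colab the import fails, so skip it.
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass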

In [7]:
from sklearn.linear_model import LinearRegression
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd  # Make sure pandas is imported
import ipywidgets as widgets
from IPython.display import display, clear_output

# Create dropdown widget for selecting advertising channel
channel_dropdown = widgets.Dropdown(
    options=['TV', 'radio', 'newspaper'],
    value='TV',
    description='Channel:',
    style={'description_width': 'initial'}
)

# Create output widget for the plot and correlation
output_area = widgets.Output()

# Function to update the plot and correlation based on the selected channel
def update_regression_plot(channel):
    with output_area:
        clear_output(wait=True)  # wait=True prevents flickering

        if channel not in df.columns:
            print(f"Error: Channel '{channel}' not found in DataFrame columns.")
            return
        if 'sales' not in df.columns:
            print("Error: 'sales' column not found in DataFrame.")
            return

        X = df[[channel]]
        y = df['sales']
        model = LinearRegression().fit(X, y)
        y_pred = model.predict(X)

        # Create the figure using Plotly Express
        fig_px = px.scatter(df, x=channel, y='sales',
                            labels={channel: f'{channel} Ad Spend ($1000s)',
                                    'sales': 'Sales'},
                            title=f'Linear Regression of Sales on {channel} Advertising')
        # Overlay the fitted line, sorted by spend so it is drawn left to right
        order = df[channel].values.argsort()
        fig_px.add_trace(go.Scatter(x=df[channel].values[order], y=y_pred[order], mode='lines',
                                    name=f'Regression Line (y={model.coef_[0]:.2f}x + {model.intercept_:.2f})',
                                    line=dict(color='red')))

        # Convert the Plotly Express figure to a FigureWidget
        fig_widget = go.FigureWidget(fig_px)

        # Display the FigureWidget within the output area
        display(fig_widget)

        correlation = df[[channel, 'sales']].corr().iloc[0, 1]
        print(f"Correlation between {channel} advertising and Sales: {correlation:.4f}")

# Define the interaction: when the dropdown value changes, call update_regression_plot
def on_channel_change(change):
    if change.new: # Ensure there's a new value
        update_regression_plot(change.new)

channel_dropdown.observe(on_channel_change, names='value')

# Display the dropdown and the output area
display(channel_dropdown, output_area)

# Trigger the display of the initial plot with the default dropdown value
# Ensure df is loaded before this call
if 'df' in globals() and isinstance(df, pd.DataFrame):
    update_regression_plot(channel_dropdown.value)
else:
    with output_area:
        print("DataFrame 'df' is not defined. Please load your data.")


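Switch between channels to see how the fitted line and the printed correlation coefficient change. The more tightly the points cluster around the line, the closer the correlation coefficient is to ±1; values near 0 mean that channel's spend, on its own, says little about sales.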
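The correlation printed under the plot is Pearson's r, the default method of `DataFrame.corr()`. With a single predictor, r² equals the R² of the least-squares line. The sketch below computes r from its definition and checks both facts; it reuses the study-hours data from the first example so it runs without the advertising CSV, and `pearson_r` is an illustrative helper of our own.

```python
import numpy as np

def pearson_r(x: np.ndarray, y: np.ndarray) -> float:
    """Pearson correlation: covariance of x and y scaled by both standard deviations."""
    dx, dy = x - x.mean(), y - y.mean()
    return float(np.sum(dx * dy) / np.sqrt(np.sum(dx ** 2) * np.sum(dy ** 2)))

x = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float)
y = np.array([60, 70, 72, 76, 75, 82, 85, 88], dtype=float)

r = pearson_r(x, y)

# R^2 of the simple least-squares fit: share of the variance in y explained by the line
m, b = np.polyfit(x, y, 1)
residuals = y - (m * x + b)
r_squared = 1 - np.sum(residuals ** 2) / np.sum((y - y.mean()) ** 2)

print(f"r = {r:.4f}, r^2 = {r**2:.4f}, R^2 of fit = {r_squared:.4f}")
```

Here r is about 0.97, and squaring it reproduces the line's R², which is why a channel with a higher correlation also gives a better-fitting single-variable line.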

## Multiple Linear Regression: Using All Variables

Now, let's predict sales using all three advertising channels.
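Written out, the model fitted in the cell below is the standard multiple linear regression equation, with an intercept plus one coefficient per channel:

$$
\text{sales} = \beta_0 + \beta_1\,\text{TV} + \beta_2\,\text{radio} + \beta_3\,\text{newspaper} + \varepsilon
$$

Each $\beta_j$ ($j = 1, 2, 3$) is the expected change in sales for a one-unit increase in that channel's spend with the other two channels held fixed, and $\varepsilon$ is the error term. Ordinary least squares chooses the $\beta$ values that minimize the sum of squared residuals.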

In [8]:
import statsmodels.formula.api as smf

# Fit the multiple linear regression model
model_multiple = smf.ols('sales ~ TV + radio + newspaper', data=df).fit()

# Print the model summary
print(model_multiple.summary())

                            OLS Regression Results                            
Dep. Variable:                  sales   R-squared:                       0.897
Model:                            OLS   Adj. R-squared:                  0.896
Method:                 Least Squares   F-statistic:                     570.3
Date:                Thu, 29 May 2025   Prob (F-statistic):           1.58e-96
Time:                        18:24:37   Log-Likelihood:                -386.18
No. Observations:                 200   AIC:                             780.4
Df Residuals:                     196   BIC:                             793.6
Df Model:                           3                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      2.9389      0.312      9.422      0.0

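The summary reports the estimated coefficient for each advertising channel, its standard error, t-statistic, and p-value (statistical significance), along with measures of overall fit. An R-squared of 0.897 means the three channels together explain about 90% of the variance in sales.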
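The coefficients and R-squared in the summary come from ordinary least squares. As a minimal NumPy sketch of those calculations (not statsmodels' actual implementation), the code below fits an intercept plus three slopes with `np.linalg.lstsq`, then computes R² = 1 − SS_res / SS_tot and adjusted R² = 1 − (1 − R²)(n − 1)/(n − p − 1). It runs on synthetic data drawn from made-up coefficients rather than the advertising CSV, and the helper names `fit_ols`, `r_squared`, and `adjusted_r_squared` are our own.

```python
import numpy as np

def fit_ols(X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Ordinary least squares with an intercept; returns [intercept, slope_1, ..., slope_p]."""
    design = np.column_stack([np.ones(len(X)), X])  # column of ones for the intercept
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)  # minimizes ||design @ beta - y||^2
    return beta

def r_squared(y: np.ndarray, y_pred: np.ndarray) -> float:
    """Share of the variance in y explained by the predictions."""
    ss_res = np.sum((y - y_pred) ** 2)    # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)  # total sum of squares around the mean
    return float(1 - ss_res / ss_tot)

def adjusted_r_squared(r2: float, n: int, p: int) -> float:
    """Penalize R^2 for p predictors (intercept not counted) fitted on n observations."""
    return 1 - (1 - r2) * (n - 1) / (n - p - 1)

# Synthetic stand-in for the advertising data: 3 predictors, hypothetical true coefficients
rng = np.random.default_rng(0)
n, p = 200, 3
X = rng.uniform(0, 100, size=(n, p))
true_beta = np.array([3.0, 0.05, 0.2, 0.0])  # intercept, then one slope per predictor
y = true_beta[0] + X @ true_beta[1:] + rng.normal(0, 0.5, size=n)  # add Gaussian noise

beta_hat = fit_ols(X, y)
y_pred = beta_hat[0] + X @ beta_hat[1:]
r2 = r_squared(y, y_pred)
print("estimated coefficients:", np.round(beta_hat, 3))
print(f"R^2 = {r2:.3f}, adjusted R^2 = {adjusted_r_squared(r2, n, p):.3f}")
```

Plugging the summary's rounded values into the formula, `adjusted_r_squared(0.897, 200, 3)` gives about 0.8954; the summary's 0.896 is computed from the unrounded R², so the small gap is a rounding effect. On the real model, statsmodels exposes these quantities as `model_multiple.params`, `model_multiple.rsquared`, and `model_multiple.rsquared_adj`.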

## Making Predictions

Let's use the fitted model to predict sales for a new combination of advertising budgets.

In [9]:
# Predict sales for one new advertising mix (same columns as the training data)
new_data = pd.DataFrame({'TV': [150], 'radio': [50], 'newspaper': [20]})
predicted_sales = model_multiple.predict(new_data)
print(f"\nPredicted sales for TV=150, Radio=50, Newspaper=20: {predicted_sales.iloc[0]:.2f}")


Predicted sales for TV=150, Radio=50, Newspaper=20: 19.21
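A linear model's prediction is just the intercept plus each coefficient times the corresponding input. The sketch below shows that arithmetic with NumPy. The coefficient values are hypothetical placeholders, not the fitted ones (the summary above is cut off after the intercept), and `predict_linear` is an illustrative helper of our own.

```python
import numpy as np

def predict_linear(intercept: float, coefs: np.ndarray, x: np.ndarray) -> float:
    """Prediction of a linear model: intercept + sum of coefficient * input."""
    return float(intercept + np.dot(coefs, x))

# Hypothetical coefficients for TV, radio and newspaper (placeholders, not fitted values)
intercept = 3.0
coefs = np.array([0.05, 0.2, 0.0])

# The same advertising mix as in the prediction above: TV=150, radio=50, newspaper=20
x_new = np.array([150.0, 50.0, 20.0])
print(f"{predict_linear(intercept, coefs, x_new):.2f}")  # 3 + 7.5 + 10 + 0 = 20.50
```

With the real model, `model_multiple.params['Intercept'] + (model_multiple.params[['TV', 'radio', 'newspaper']] * new_data.iloc[0]).sum()` should match the value returned by `model_multiple.predict(new_data)`.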
