# Interpreting the coefficients of a linear regression

A regression coefficient describes the expected change in the response variable for a one-unit increase in a covariate while all other covariates are held constant.

In this notebook, we will deepen this intuition with a hands-on example.

In [None]:
import numpy as np
import pandas as pd

import statsmodels.formula.api as smf

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Apply seaborn's 'poster' plotting context (larger labels, lines and markers)
# globally, so every figure below uses it.
sns.set_context(context='poster')

## Generate data

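We generate synthetic data using a form of structural equation modeling: for each group $g \in \{G_1, G_2\}$ we draw $N$ observations with
$$
X \sim \mathcal{N}(0, 1), \qquad Y = \mu_g + \beta_g X + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, 1),
$$
where $\mu_g$ (`mean_g1`, `mean_g2`) is the group-specific intercept and $\beta_g$ (`beta_g1`, `beta_g2`) the group-specific slope.
Because the true coefficients are known, we can check whether the fitted model recovers them.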

In [None]:
# Number of observations simulated per group.
N = 1000

# True slopes of Y on X for groups G_1 and G_2.
beta_g1, beta_g2 = 1.4, -0.8

# True intercepts (expected Y at X = 0) for groups G_1 and G_2.
mean_g1, mean_g2 = -2, 10

In [None]:
np.random.seed(42)

# Covariate: 2 * N standard-normal draws; rows [0, N) belong to G_1, rows [N, 2N) to G_2.
X = np.random.normal(size=2 * N)

# Response: group-specific slope times X plus unit-variance Gaussian noise centred
# on the group-specific intercept. The list items are evaluated left to right, so
# the G_1 noise is drawn before the G_2 noise (same RNG order as before).
Y = np.concatenate(
    [
        beta_g1 * X[:N] + np.random.normal(mean_g1, size=N),
        beta_g2 * X[N:] + np.random.normal(mean_g2, size=N),
    ]
)

# Group label for every row, aligned with X and Y.
group = ['$G_1$' if row < N else '$G_2$' for row in range(2 * N)]

In [None]:
# Assemble the simulated data into one long-format frame. `group` is stored as a
# categorical so the formula API treats it as a factor.
df = pd.DataFrame({'X': X, 'Y': Y, 'group': group}).astype({'group': 'category'})

df.head()

## Fit model

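The model, where $D$ is a dummy variable equal to 1 for observations in $G_2$ and 0 for observations in $G_1$ (the reference level):
$$
Y \sim \beta_0 + \beta_1 \cdot D + \beta_2 \cdot X + \beta_3 \cdot X \cdot D
$$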

In [None]:
# `X * group` expands to `X + group + X:group`: both main effects plus their interaction.
mod = smf.ols('Y ~ X * group', data=df)
fit = mod.fit()

## Investigate result

### Retrieve coefficients

In [None]:
# `summary()` bundles several tables; index 1 is the per-coefficient table
# (estimates, standard errors, test statistics, p-values, confidence intervals).
res = fit.summary()
coef_table = res.tables[1]
coef_table

In [None]:
# Point estimates keyed by formula term name (e.g. 'Intercept', 'X').
coefs = fit.params

coefs

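### Understand their meaning

The `T.` in the term names indicates treatment coding with $G_1$ as the reference level. Hence `Intercept` and `X` are the intercept and slope of $G_1$. The `group` term and the `X:group` interaction term are the *differences* in intercept and slope between $G_2$ and $G_1$. To recover the parameters of $G_2$, we add these differences to the corresponding $G_1$ estimates.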

In [None]:
# G_1 is the reference level: its intercept and slope are the bare terms.
fitted_mean_g1 = coefs['Intercept']
fitted_beta_g1 = coefs['X']

# G_2 = reference (G_1) + the group-specific offsets estimated by the model.
fitted_mean_g2 = fitted_mean_g1 + coefs['group[T.$G_2$]']
fitted_beta_g2 = fitted_beta_g1 + coefs['X:group[T.$G_2$]']

In [None]:
# One row per generating parameter: the value used to simulate the data next to
# the value recovered from the fitted model.
pd.DataFrame(
    [
        ('beta_g1', beta_g1, fitted_beta_g1),
        ('beta_g2', beta_g2, fitted_beta_g2),
        ('mean_g1', mean_g1, fitted_mean_g1),
        ('mean_g2', mean_g2, fitted_mean_g2),
    ],
    columns=['label', 'true_value', 'fitted_value'],
)

In [None]:
def annotate_plot(space, mean, beta, color, ax=None):
    """Draw the line ``mean + beta * x`` and label it with its equation.

    Draws three things on ``ax``:
    - a solid line through the points ``(space, mean + beta * space)``;
    - a dashed horizontal line at ``mean`` (the line's value at x = 0);
    - the equation as a text box at the middle element of ``space``.

    Parameters
    ----------
    space : numpy array
        x-values at which to evaluate the line, e.g. from ``np.linspace``.
        Must be non-empty, because the label is placed at its middle element.
    mean : float
        Intercept of the line.
    beta : float
        Slope of the line.
    color : matplotlib color
        Color of the line, the dashed intercept marker and the label text.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
        Previously this was always a global ``ax`` from the calling cell.
    """
    if ax is None:
        ax = plt.gca()

    values = mean + beta * space

    ax.plot(space, values, color=color)
    ax.axhline(mean, ls='dashed', color=color)

    # Print the slope's sign explicitly so a negative slope reads "a - b·x"
    # rather than "a + -b·x". The raw f-string keeps the TeX "\cdot" from being
    # parsed as an invalid Python escape sequence.
    sign = '-' if beta < 0 else '+'
    mid = len(space) // 2
    ax.text(
        space[mid],
        values[mid],
        rf'${mean:.2f} {sign} {abs(beta):.2f} \cdot x$',
        color=color,
        size=12,
        bbox=dict(boxstyle='round4,pad=.5', fc='0.85'),
        ha='center',
    )

In [None]:
# Scatter the raw data coloured by group, then overlay each group's fitted line.
# `ax` is kept as a global name because `annotate_plot` draws onto these axes.
fig, ax = plt.subplots(figsize=(16, 12))
sns.scatterplot(x='X', y='Y', hue='group', data=df, s=10, ax=ax)

# (group label, fitted intercept, fitted slope) in category order. This lines
# each group up with the palette colour seaborn assigned to its hue level.
fitted_lines = [
    ('$G_1$', fitted_mean_g1, fitted_beta_g1),
    ('$G_2$', fitted_mean_g2, fitted_beta_g2),
]
for (label, mean, beta), color in zip(fitted_lines, sns.color_palette()):
    # Draw each line only over the X-range actually observed in that group.
    x_group = df.loc[df['group'] == label, 'X']
    annotate_plot(
        np.linspace(x_group.min(), x_group.max()),
        mean,
        beta,
        color,
    )

ax.set_title('Fitted regression lines per group')
ax.legend(bbox_to_anchor=(1, 0.5), loc='center left', frameon=False);