# Chapter 15. Multiple regression

## 15.1 The model

$$y_i = \alpha + \beta_1 x_{i1} + \dots + \beta_k x_{ik} + \varepsilon_i$$

## 15.2 Further assumptions of the least squares model

1. The columns of x are linearly independent — that is, there is no way to write any one of them as a weighted sum of some of the others.
2. The columns of x are all uncorrelated with the errors ε.

## 15.3 Fitting the model

In [3]:
from typing import List
# A vector is represented as a plain Python list of floats.
Vector = List[float]

def dot(v: Vector, w: Vector) -> float:
    """Return the inner product v[0]*w[0] + ... + v[n-1]*w[n-1].

    Raises AssertionError if v and w have different lengths.
    """
    assert len(v) == len(w), "vectors must be same length"
    total = 0
    for a, b in zip(v, w):
        total += a * b
    return total

def predict(x: Vector, beta: Vector) -> float:
    """Model prediction for one input row: the linear combination dot(x, beta).

    In this notebook every input row starts with a constant 1 (see `inputs`),
    so beta[0] acts as the intercept alpha.
    """
    prediction = dot(x, beta)
    return prediction

def error(x: Vector, y: float, beta: Vector) -> float:
    """Signed residual for one observation: predicted value minus actual y."""
    predicted = predict(x, beta)
    return predicted - y

def squared_error(x: Vector, y: float, beta: Vector) -> float:
    """Square of the residual error(x, y, beta) for a single observation."""
    residual = error(x, y, beta)
    return residual ** 2

# Hand-checked example: dot([1, 2, 3], [4, 4, 4]) = 24, so the error against
# y = 30 is 24 - 30 = -6 and the squared error is 36.
x, y, beta = [1, 2, 3], 30, [4, 4, 4]

assert error(x, y, beta) == -6
assert squared_error(x, y, beta) == 36

In [4]:
def sqerror_gradient(x: Vector, y: float, beta: Vector) -> Vector:
    """Gradient of squared_error(x, y, beta) with respect to beta.

    The partial derivative for coefficient i is 2 * error(x, y, beta) * x[i].
    """
    scale = 2 * error(x, y, beta)
    return [scale * component for component in x]

# With the example above the error is -6, so the gradient is -12 * x.
assert sqerror_gradient(x, y, beta) == [-12, -24, -36]

In [14]:
import random 
import tqdm 


def vector_sum(vectors: List[Vector]) -> Vector:
    """Add a non-empty list of equal-length vectors component by component.

    Raises AssertionError if the list is empty or the lengths differ.
    """
    assert vectors, "no vectors provided!"

    width = len(vectors[0])
    assert all(len(v) == width for v in vectors), "different sizes!"

    # zip(*vectors) yields the i-th entries of every vector together, so
    # summing each column gives the i-th entry of the result.
    return [sum(column) for column in zip(*vectors)]

def scalar_multiply(c: float, v: Vector) -> Vector:
    """Return a new vector with every component of v scaled by c."""
    return [c * component for component in v]

assert scalar_multiply(2, [1, 2, 3]) == [2, 4, 6]

def vector_mean(vectors: List[Vector]) -> Vector:
    """Component-wise average of a list of equal-length vectors."""
    count = len(vectors)
    # The factor is computed before summing, so an empty list fails here
    # with ZeroDivisionError.
    inverse = 1 / count
    return scalar_multiply(inverse, vector_sum(vectors))

def add(v: Vector, w: Vector) -> Vector:
    """Return the component-wise sum of two vectors of equal length."""
    assert len(v) == len(w), "vectors must be the same length"
    return [left + right for left, right in zip(v, w)]

def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
    """Return v + step_size * gradient.

    Pass a negative step_size to move against the gradient (descent).
    """
    assert len(v) == len(gradient)
    return add(v, scalar_multiply(step_size, gradient))


In [15]:
def least_squares_fit(xs: List[Vector],
                      ys: List[float],
                      learning_rate: float = 0.001,
                      num_steps: int = 1000,
                      batch_size: int = 1) -> Vector:
    """Find the beta that minimizes the sum of squared errors for y = dot(x, beta).

    Uses minibatch gradient descent: each of the num_steps passes walks the
    data in consecutive (unshuffled) batches of batch_size rows and steps
    against the batch-averaged gradient of the squared error.

    Args:
        xs: input rows; beta gets one coefficient per entry of xs[0].
        ys: the target value for each row of xs (one scalar per row).
        learning_rate: size of each gradient-descent step.
        num_steps: number of passes over the data.
        batch_size: number of rows averaged per gradient step.

    Returns:
        The fitted coefficient vector.

    The starting guess comes from random.random(), so seed the RNG for
    reproducible results.
    """
    # zip() would silently drop unmatched rows inside a batch, so fail loudly.
    assert len(xs) == len(ys), "xs and ys must be the same length"

    # Start with a random guess, one coefficient per input column.
    guess = [random.random() for _ in xs[0]]

    for _ in tqdm.trange(num_steps, desc='least squares fit'):
        for start in range(0, len(xs), batch_size):
            batch_xs = xs[start:start + batch_size]
            batch_ys = ys[start:start + batch_size]

            # Averaging (rather than summing) the per-row gradients keeps the
            # effective step size independent of batch_size.
            gradient = vector_mean([sqerror_gradient(x, y, guess)
                                    for x, y in zip(batch_xs, batch_ys)])
            # Negative step: move downhill on the squared error.
            guess = gradient_step(guess, gradient, -learning_rate)
    return guess

In [16]:
daily_minutes = [1,68.77,51.25,52.08,38.36,44.54,57.13,51.4,41.42,31.22,34.76,54.01,38.79,47.59,49.1,27.66,41.03,36.73,48.65,28.12,46.62,35.57,32.98,35,26.07,23.77,39.73,40.57,31.65,31.21,36.32,20.45,21.93,26.02,27.34,23.49,46.94,30.5,33.8,24.23,21.4,27.94,32.24,40.57,25.07,19.42,22.39,18.42,46.96,23.72,26.41,26.97,36.76,40.32,35.02,29.47,30.2,31,38.11,38.18,36.31,21.03,30.86,36.07,28.66,29.08,37.28,15.28,24.17,22.31,30.17,25.53,19.85,35.37,44.6,17.23,13.47,26.33,35.02,32.09,24.81,19.33,28.77,24.26,31.98,25.73,24.86,16.28,34.51,15.23,39.72,40.8,26.06,35.76,34.76,16.13,44.04,18.03,19.65,32.62,35.59,39.43,14.18,35.24,40.13,41.82,35.45,36.07,43.67,24.61,20.9,21.9,18.79,27.61,27.21,26.61,29.77,20.59,27.53,13.82,33.2,25,33.1,36.65,18.63,14.87,22.2,36.81,25.53,24.62,26.25,18.21,28.08,19.42,29.79,32.8,35.99,28.32,27.79,35.88,29.06,36.28,14.1,36.63,37.49,26.9,18.58,38.48,24.48,18.95,33.55,14.24,29.04,32.51,25.63,22.22,19,32.73,15.16,13.9,27.2,32.01,29.27,33,13.74,20.42,27.32,18.23,35.35,28.48,9.08,24.62,20.12,35.26,19.92,31.02,16.49,12.16,30.7,31.22,34.65,13.13,27.51,33.2,31.57,14.1,33.42,17.44,10.12,24.42,9.82,23.39,30.93,15.03,21.67,31.09,33.29,22.61,26.89,23.48,8.38,27.81,32.35,23.84]
num_friends = [100.0,49,41,40,25,21,21,19,19,18,18,16,15,15,15,15,14,14,13,13,13,13,12,12,11,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,8,8,8,8,8,8,8,8,8,8,8,8,8,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]

# Position of the single user with 100 friends, which is dropped as an outlier.
outlier = num_friends.index(100)

# Rebuild both series without that one position (new lists; originals untouched).
num_friends_good = num_friends[:outlier] + num_friends[outlier + 1:]
daily_minutes_good = daily_minutes[:outlier] + daily_minutes[outlier + 1:]

In [17]:
# Seed the RNG so the random starting guess in least_squares_fit (and hence
# the fitted beta asserted on below) is reproducible.
random.seed(0)

inputs: List[List[float]] = [[1.,49,4,0],[1,41,9,0],[1,40,8,0],[1,25,6,0],[1,21,1,0],[1,21,0,0],[1,19,3,0],[1,19,0,0],[1,18,9,0],[1,18,8,0],[1,16,4,0],[1,15,3,0],[1,15,0,0],[1,15,2,0],[1,15,7,0],[1,14,0,0],[1,14,1,0],[1,13,1,0],[1,13,7,0],[1,13,4,0],[1,13,2,0],[1,12,5,0],[1,12,0,0],[1,11,9,0],[1,10,9,0],[1,10,1,0],[1,10,1,0],[1,10,7,0],[1,10,9,0],[1,10,1,0],[1,10,6,0],[1,10,6,0],[1,10,8,0],[1,10,10,0],[1,10,6,0],[1,10,0,0],[1,10,5,0],[1,10,3,0],[1,10,4,0],[1,9,9,0],[1,9,9,0],[1,9,0,0],[1,9,0,0],[1,9,6,0],[1,9,10,0],[1,9,8,0],[1,9,5,0],[1,9,2,0],[1,9,9,0],[1,9,10,0],[1,9,7,0],[1,9,2,0],[1,9,0,0],[1,9,4,0],[1,9,6,0],[1,9,4,0],[1,9,7,0],[1,8,3,0],[1,8,2,0],[1,8,4,0],[1,8,9,0],[1,8,2,0],[1,8,3,0],[1,8,5,0],[1,8,8,0],[1,8,0,0],[1,8,9,0],[1,8,10,0],[1,8,5,0],[1,8,5,0],[1,7,5,0],[1,7,5,0],[1,7,0,0],[1,7,2,0],[1,7,8,0],[1,7,10,0],[1,7,5,0],[1,7,3,0],[1,7,3,0],[1,7,6,0],[1,7,7,0],[1,7,7,0],[1,7,9,0],[1,7,3,0],[1,7,8,0],[1,6,4,0],[1,6,6,0],[1,6,4,0],[1,6,9,0],[1,6,0,0],[1,6,1,0],[1,6,4,0],[1,6,1,0],[1,6,0,0],[1,6,7,0],[1,6,0,0],[1,6,8,0],[1,6,4,0],[1,6,2,1],[1,6,1,1],[1,6,3,1],[1,6,6,1],[1,6,4,1],[1,6,4,1],[1,6,1,1],[1,6,3,1],[1,6,4,1],[1,5,1,1],[1,5,9,1],[1,5,4,1],[1,5,6,1],[1,5,4,1],[1,5,4,1],[1,5,10,1],[1,5,5,1],[1,5,2,1],[1,5,4,1],[1,5,4,1],[1,5,9,1],[1,5,3,1],[1,5,10,1],[1,5,2,1],[1,5,2,1],[1,5,9,1],[1,4,8,1],[1,4,6,1],[1,4,0,1],[1,4,10,1],[1,4,5,1],[1,4,10,1],[1,4,9,1],[1,4,1,1],[1,4,4,1],[1,4,4,1],[1,4,0,1],[1,4,3,1],[1,4,1,1],[1,4,3,1],[1,4,2,1],[1,4,4,1],[1,4,4,1],[1,4,8,1],[1,4,2,1],[1,4,4,1],[1,3,2,1],[1,3,6,1],[1,3,4,1],[1,3,7,1],[1,3,4,1],[1,3,1,1],[1,3,10,1],[1,3,3,1],[1,3,4,1],[1,3,7,1],[1,3,5,1],[1,3,6,1],[1,3,1,1],[1,3,6,1],[1,3,10,1],[1,3,2,1],[1,3,4,1],[1,3,2,1],[1,3,1,1],[1,3,5,1],[1,2,4,1],[1,2,2,1],[1,2,8,1],[1,2,3,1],[1,2,1,1],[1,2,9,1],[1,2,10,1],[1,2,9,1],[1,2,4,1],[1,2,5,1],[1,2,0,1],[1,2,9,1],[1,2,9,1],[1,2,0,1],[1,2,1,1],[1,2,1,1],[1,2,4,1],[1,1,0,1],[1,1,2,1],[1,1,2,1],[1,1,5,1],[1,1,3,1],[1,1,10,1],[1,1,6,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,4,1
],[1,1,9,1],[1,1,9,1],[1,1,4,1],[1,1,2,1],[1,1,9,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,1,1],[1,1,1,1],[1,1,5,1]]

learning_rate = 0.001
# 5000 passes of minibatch gradient descent over all rows, 25 rows per batch.
beta = least_squares_fit(inputs, daily_minutes_good, learning_rate,
                         num_steps=5000, batch_size=25)
# beta[0] is the intercept (inputs rows start with a constant 1);
# beta[1] multiplies the second input column.
assert 30.5 < beta[0] < 30.7
assert 0.96 < beta[1] < 1.00

least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2464.47it/s]


## 15.4 Interpreting the model

Think of each coefficient of the model as an all-else-being-equal estimate of the impact of its factor: how much y changes when that factor increases by one unit while all the other factors are held fixed.

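If the effect of one factor depends on the value of another, add interaction terms — for example, the product of the two variables — as additional columns of x.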

## 15.5 Goodness of fit

In [21]:
def mean(xs: List[float]) -> float:
    """Arithmetic mean of xs (an empty list raises ZeroDivisionError)."""
    total = sum(xs)
    count = len(xs)
    return total / count

def de_mean(xs: List[float]) -> List[float]:
    """Shift xs to have mean 0 by subtracting mean(xs) from every element."""
    center = mean(xs)
    return [value - center for value in xs]

def total_sum_of_squares(y: Vector) -> float:
    """Total squared variation of y around its mean: sum((y_i - mean(y)) ** 2)."""
    deviations = de_mean(y)
    return sum(d ** 2 for d in deviations)

In [22]:
def multiple_r_squared(xs: List[Vector], ys: Vector, beta: Vector) -> float:
    """Coefficient of determination R^2 for the model dot(x, beta).

    R^2 = 1 - SSE / TSS, where SSE sums the squared residuals over all
    (x, y) pairs and TSS is the total squared variation of ys around its
    mean. A TSS of 0 (all ys equal) raises ZeroDivisionError.
    """
    residual_sq_total = 0
    for row, target in zip(xs, ys):
        residual_sq_total += error(row, target, beta) ** 2
    return 1.0 - residual_sq_total / total_sum_of_squares(ys)

assert 0.67 < multiple_r_squared(inputs, daily_minutes_good, beta) < 0.68

In a multiple regression, we also need to look at the standard errors of the coefficients, which measure how certain we are about our estimates of each βi. The regression as a whole may fit our data very well, but if some of the independent variables are correlated (or irrelevant), their coefficients might not mean much.

## 15.6 Digression: the bootstrap

In [29]:
from typing import TypeVar, Callable 

X = TypeVar('X')        # element type of the data being resampled
Stat = TypeVar('Stat')  # type of the statistic computed from each resample

def bootstrap_sample(data: List[X]) -> List[X]:
    """Draw len(data) elements uniformly at random from data, with replacement."""
    size = len(data)
    return [random.choice(data) for _ in range(size)]

def bootstrap_statistic(data: List[X],
                        stats_fn: Callable[[List[X]], Stat],
                        num_samples: int) -> List[Stat]:
    """Apply stats_fn to each of num_samples bootstrap resamples of data."""
    # A lazy generator: each resample is drawn only just before stats_fn runs
    # on it, so any RNG use inside stats_fn stays interleaved with the draws.
    resamples = (bootstrap_sample(data) for _ in range(num_samples))
    return [stats_fn(resample) for resample in resamples]

In [34]:
# 101 points, each within half a unit of 100
close_to_100 = [99.5 + random.random() for _ in range(101)]

# 101 points: one near 100, then 50 near 0, then 50 near 200
far_from_100 = [99.5 + random.random()]
far_from_100 += [random.random() for _ in range(50)]
far_from_100 += [200 + random.random() for _ in range(50)]

In [42]:
def _median_odd(xs: List[float]) -> float:
    """Median of an odd-length list: the single middle element once sorted."""
    middle = len(xs) // 2
    return sorted(xs)[middle]

def _median_even(xs: List[float]) -> float:
    """Median of an even-length list: the average of the two middle elements."""
    ordered = sorted(xs)
    upper = len(ordered) // 2  # e.g. length 4 -> middle indices 1 and 2
    lower = upper - 1
    return (ordered[lower] + ordered[upper]) / 2

def median(v: List[float]) -> float:
    """Return the middle-most value of v (mean of the two middle values if len(v) is even)."""
    if len(v) % 2 == 0:
        return _median_even(v)
    return _median_odd(v)

def de_mean(xs: List[float]) -> List[float]:
    """Return xs re-centred to mean 0 (each element minus mean(xs)).

    NOTE(review): identical to the de_mean defined in the goodness-of-fit
    section; this redefinition shadows that one.
    """
    mu = mean(xs)
    return [item - mu for item in xs]

def variance(xs: List[float]) -> float:
    """Sample variance: sum of squared deviations from the mean, divided by n - 1."""
    assert len(xs) >= 2, "variance requires at least two elements"

    deviations = de_mean(xs)
    return sum_of_squares(deviations) / (len(xs) - 1)

import math

def standard_deviation(xs: List[float]) -> float:
    """Sample standard deviation of xs: the square root of variance(xs)."""
    var = variance(xs)
    return math.sqrt(var)

def sum_of_squares(v: Vector) -> float:
    """Sum of the squared components of v, i.e. dot(v, v)."""
    squared_length = dot(v, v)
    return squared_length

In [38]:
# 100 bootstrap medians for each dataset, computed in the same order as listed.
medians_close, medians_far = [bootstrap_statistic(dataset, median, 100)
                              for dataset in (close_to_100, far_from_100)]

In [43]:
# Bootstrap medians of the tightly clustered data stay near 100, while for the
# split data a resampled median can land in either the ~0 or the ~200 cluster.
assert standard_deviation(medians_close) < 1
assert standard_deviation(medians_far) > 90

## 15.7 Standard errors of regression coefficients

Repeatedly take a `bootstrap_sample` of the data and estimate beta from that sample. If a coefficient varies greatly across samples, we can't be at all confident in our estimate of it; if it barely varies, we can be fairly confident.

In [46]:
from typing import Tuple 

import datetime 

def estimate_sample_beta(pairs: List[Tuple[Vector, float]],
                         learning_rate: float = 0.001,
                         num_steps: int = 5000,
                         batch_size: int = 25,
                         verbose: bool = True) -> Vector:
    """Fit least_squares_fit to one bootstrap sample of (x, y) pairs.

    Intended as the stats_fn for bootstrap_statistic, which calls it with a
    single argument, so every other parameter has a default.

    Args:
        pairs: (input row, target) pairs, e.g. a bootstrap resample of
            zip(inputs, daily_minutes_good).
        learning_rate: gradient-descent step size. Defaults to 0.001, the
            same value as the notebook-level learning_rate used for the
            full-data fit. It is passed explicitly rather than read from a
            global.
        num_steps: passes over the sample.
        batch_size: rows per gradient step.
        verbose: when True (the default), print each fitted beta.

    Returns:
        The fitted coefficient vector for this sample.
    """
    x_sample = [x for x, _ in pairs]
    y_sample = [y for _, y in pairs]
    beta = least_squares_fit(x_sample, y_sample, learning_rate, num_steps, batch_size)
    if verbose:
        print('bootstrap sample', beta)
    return beta

random.seed(0)
# Resample (input row, daily minutes) pairs together and refit beta on each
# of 100 bootstrap samples.
bootstrap_betas = bootstrap_statistic(
    list(zip(inputs, daily_minutes_good)),
    estimate_sample_beta,
    num_samples=100,
)

least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2526.52it/s]
least squares fit:   5%|▍         | 240/5000 [00:00<00:01, 2396.28it/s]

bootstrap sample [30.49402029547432, 1.0393791030498776, -1.9516851948558502, 0.7483721251697333]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2490.36it/s]
least squares fit:   5%|▍         | 245/5000 [00:00<00:01, 2444.33it/s]

bootstrap sample [30.149963287526045, 1.0005300432763113, -2.0650380122822543, 3.177179854834797]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2459.08it/s]
least squares fit:   5%|▍         | 235/5000 [00:00<00:02, 2344.33it/s]

bootstrap sample [29.202826897693722, 1.0017089956376213, -1.5294248424787367, 0.9528580285760854]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2578.29it/s]
least squares fit:   5%|▍         | 242/5000 [00:00<00:01, 2414.37it/s]

bootstrap sample [31.29481217471851, 0.959264729494101, -1.9120875473727545, 0.039471107599519425]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2477.19it/s]
least squares fit:   5%|▍         | 234/5000 [00:00<00:02, 2336.73it/s]

bootstrap sample [32.124144227949955, 0.8569794405277468, -1.9936770520754086, 1.0416943131373024]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2547.90it/s]
least squares fit:   5%|▌         | 256/5000 [00:00<00:01, 2553.84it/s]

bootstrap sample [31.8691994453096, 0.7748022870492418, -2.0087625702876446, -1.2407036547656678]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2580.35it/s]
least squares fit:  10%|█         | 513/5000 [00:00<00:01, 2546.72it/s]

bootstrap sample [31.08119759650208, 0.998386254386918, -1.9833984114987815, 0.9567646217580389]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2460.05it/s]
least squares fit:  10%|█         | 502/5000 [00:00<00:01, 2513.66it/s]

bootstrap sample [29.254530450577782, 0.9763387220017684, -1.7430339427043595, 1.9944240584590935]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2516.21it/s]
least squares fit:   5%|▍         | 241/5000 [00:00<00:01, 2399.75it/s]

bootstrap sample [31.649174199331632, 0.9389340937491032, -1.9733848473304205, -0.15249287969349437]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2363.47it/s]
least squares fit:   5%|▍         | 227/5000 [00:00<00:02, 2260.51it/s]

bootstrap sample [30.040109260720964, 1.0531247386421572, -1.7694878560354388, 1.302971911084249]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2471.51it/s]
least squares fit:   5%|▍         | 245/5000 [00:00<00:01, 2449.59it/s]

bootstrap sample [29.066927054721297, 1.2792640005590372, -1.937339904947856, 0.9183668519320846]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2497.49it/s]
least squares fit:   5%|▍         | 232/5000 [00:00<00:02, 2317.35it/s]

bootstrap sample [31.740476303331718, 0.9538879291586574, -2.0689725879612477, 1.4785830120835612]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2539.95it/s]
least squares fit:   5%|▌         | 253/5000 [00:00<00:01, 2527.04it/s]

bootstrap sample [29.46654084062671, 0.9837739845117637, -1.9915052407093472, 3.150029950640157]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2385.16it/s]
least squares fit:   4%|▍         | 224/5000 [00:00<00:02, 2238.99it/s]

bootstrap sample [30.97515705531374, 0.9420086669374396, -2.0367671746636065, 0.6323599067111714]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2408.51it/s]
least squares fit:  10%|▉         | 485/5000 [00:00<00:01, 2402.21it/s]

bootstrap sample [31.478778128163995, 0.8623617407485805, -1.8798782324632368, -0.11949170941208796]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2498.16it/s]
least squares fit:  10%|▉         | 491/5000 [00:00<00:01, 2457.94it/s]

bootstrap sample [33.87286992682308, 0.8824018752321863, -1.8978803929581156, -1.0333647107478692]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2519.12it/s]
least squares fit:   5%|▌         | 250/5000 [00:00<00:01, 2494.12it/s]

bootstrap sample [29.272206898314987, 1.0899411603739348, -1.8911943299601002, 3.162677841885805]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2521.91it/s]
least squares fit:   5%|▌         | 257/5000 [00:00<00:01, 2567.87it/s]

bootstrap sample [30.83577809561691, 1.0242186355671827, -1.9209251081222494, 1.3383795133620962]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2518.44it/s]
least squares fit:   5%|▌         | 259/5000 [00:00<00:01, 2584.34it/s]

bootstrap sample [28.211162672015906, 1.458352440392638, -1.70241171517105, 0.94520401518726]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2526.08it/s]
least squares fit:   5%|▌         | 260/5000 [00:00<00:01, 2596.42it/s]

bootstrap sample [29.93552336056075, 0.9470529669956465, -1.8491245571618218, 0.8573641103651921]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2553.64it/s]
least squares fit:   5%|▌         | 258/5000 [00:00<00:01, 2572.39it/s]

bootstrap sample [30.636052325886002, 0.9966176913889679, -1.8308401560119625, 0.13862673979220685]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2397.74it/s]
least squares fit:   9%|▉         | 464/5000 [00:00<00:01, 2319.93it/s]

bootstrap sample [30.855945311129382, 0.9925731301194982, -1.834813509355548, 1.9711641797749935]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2333.59it/s]
least squares fit:   9%|▉         | 469/5000 [00:00<00:01, 2353.22it/s]

bootstrap sample [29.77226728370608, 1.0493381798575807, -1.6999309651266667, 0.9221651877575128]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2385.97it/s]
least squares fit:   5%|▌         | 252/5000 [00:00<00:01, 2513.04it/s]

bootstrap sample [28.78470035821775, 0.9629668755117143, -1.7818333154132011, 1.905170320676074]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2417.06it/s]
least squares fit:   5%|▍         | 239/5000 [00:00<00:01, 2384.35it/s]

bootstrap sample [31.769457992268443, 0.9040180814550004, -1.867677593282121, -0.7957987643064021]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2447.25it/s]
least squares fit:  10%|▉         | 476/5000 [00:00<00:01, 2342.53it/s]

bootstrap sample [30.06836087625883, 0.9237365767889174, -1.7326788050658604, 1.9044381512517516]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2412.29it/s]
least squares fit:   5%|▍         | 241/5000 [00:00<00:01, 2407.33it/s]

bootstrap sample [29.248924522774857, 1.0251706036709467, -1.6396068581125103, 1.7875039505127974]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2462.49it/s]
least squares fit:   9%|▉         | 468/5000 [00:00<00:01, 2311.89it/s]

bootstrap sample [26.160390418551746, 1.3566609275406472, -1.880731098382104, 3.884946816272215]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2450.60it/s]
least squares fit:  10%|█         | 500/5000 [00:00<00:01, 2484.57it/s]

bootstrap sample [31.9708823034869, 0.8717159490168253, -1.8037586194211703, -0.23788897755135452]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2478.18it/s]
least squares fit:   5%|▍         | 246/5000 [00:00<00:01, 2458.75it/s]

bootstrap sample [30.580903591168788, 0.9610711598856186, -1.8984859248085817, -0.0023187395782722853]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2472.64it/s]
least squares fit:  10%|▉         | 497/5000 [00:00<00:01, 2461.35it/s]

bootstrap sample [31.433330253362577, 0.8768141821390377, -1.7328584033279486, -0.10210988051437449]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2511.17it/s]
least squares fit:   5%|▍         | 244/5000 [00:00<00:01, 2436.32it/s]

bootstrap sample [30.984236860566945, 1.036149466142919, -2.2200016449095226, 1.0886749895563579]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2506.46it/s]
least squares fit:   5%|▍         | 243/5000 [00:00<00:01, 2426.58it/s]

bootstrap sample [29.28237674477942, 1.0858388836439414, -1.74280602847473, 1.4397328297413239]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2508.38it/s]
least squares fit:   5%|▌         | 251/5000 [00:00<00:01, 2504.04it/s]

bootstrap sample [30.65460047430859, 0.9454408039075628, -1.7320071301269266, -0.1485862182089119]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2461.41it/s]
least squares fit:  10%|█         | 500/5000 [00:00<00:01, 2482.09it/s]

bootstrap sample [29.118139496835955, 0.893808801696634, -1.9153563192896768, 2.0598345811489462]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2489.65it/s]
least squares fit:   5%|▍         | 246/5000 [00:00<00:01, 2455.66it/s]

bootstrap sample [29.95488463195938, 0.9940567914003665, -1.7605085370056377, 1.6096257131696945]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2481.94it/s]
least squares fit:   5%|▍         | 230/5000 [00:00<00:02, 2291.25it/s]

bootstrap sample [31.000164842544855, 0.962368315956188, -1.9115208623969564, 0.7473190835230143]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2506.14it/s]
least squares fit:   5%|▌         | 255/5000 [00:00<00:01, 2548.65it/s]

bootstrap sample [30.825204744897878, 0.8912590208026672, -1.770469093607492, 0.7459655949536621]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2509.35it/s]
least squares fit:   5%|▌         | 259/5000 [00:00<00:01, 2587.36it/s]

bootstrap sample [29.366136812382027, 1.012558241060829, -1.6182773155952548, 1.017025703754098]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2499.15it/s]
least squares fit:   5%|▍         | 247/5000 [00:00<00:01, 2463.32it/s]

bootstrap sample [29.94295970143512, 1.0167217566773747, -1.5621167917565122, -0.1030904763985423]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2477.41it/s]
least squares fit:   5%|▍         | 240/5000 [00:00<00:01, 2397.63it/s]

bootstrap sample [29.962898858207343, 1.0652251821283687, -1.926924147663584, 2.0385736378519947]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2530.85it/s]
least squares fit:  10%|█         | 523/5000 [00:00<00:01, 2590.40it/s]

bootstrap sample [30.525302041791818, 0.9658944102293198, -1.8870631894489638, 0.5367690208128764]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2539.30it/s]
least squares fit:  10%|█         | 508/5000 [00:00<00:01, 2534.07it/s]

bootstrap sample [30.6785320563259, 1.0139828545599132, -1.7817299670979692, 1.6026393229652947]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2511.76it/s]
least squares fit:   5%|▌         | 253/5000 [00:00<00:01, 2529.77it/s]

bootstrap sample [30.09073747870904, 1.0047123547747132, -1.9560265455918162, 2.75259429615735]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2522.57it/s]
least squares fit:   5%|▌         | 250/5000 [00:00<00:01, 2499.47it/s]

bootstrap sample [30.937038893678913, 0.9670590611928079, -2.1124811600264293, 0.3258045605146989]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2528.70it/s]
least squares fit:   5%|▌         | 252/5000 [00:00<00:01, 2517.00it/s]

bootstrap sample [28.78930855900356, 1.1730115746597942, -1.7835138640623003, 3.2623158308236095]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2513.07it/s]
least squares fit:   5%|▌         | 258/5000 [00:00<00:01, 2577.99it/s]

bootstrap sample [31.14754912309552, 0.9326436111603991, -1.7707952504307622, -1.099359064604309]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2506.74it/s]
least squares fit:   5%|▌         | 257/5000 [00:00<00:01, 2562.37it/s]

bootstrap sample [31.813727613195653, 0.9683784085384612, -2.019078886892217, 0.7501206686148623]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2548.53it/s]
least squares fit:  10%|█         | 512/5000 [00:00<00:01, 2569.25it/s]

bootstrap sample [30.223398353230184, 0.9373764744862103, -1.5323607166675373, -0.01469994075209993]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2515.31it/s]
least squares fit:   5%|▌         | 257/5000 [00:00<00:01, 2568.36it/s]

bootstrap sample [28.01777500444891, 1.091598878794665, -1.6190191022832499, 2.397154344588124]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2516.37it/s]
least squares fit:   5%|▌         | 263/5000 [00:00<00:01, 2626.70it/s]

bootstrap sample [29.342668886496593, 0.9815156932180105, -1.9184777914462317, 1.5482939749639446]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2416.22it/s]
least squares fit:   5%|▍         | 235/5000 [00:00<00:02, 2349.60it/s]

bootstrap sample [32.53937166649288, 1.060883971208886, -2.2704689582768722, 0.3681597653761555]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2372.82it/s]
least squares fit:   5%|▍         | 232/5000 [00:00<00:02, 2311.43it/s]

bootstrap sample [30.106198499206915, 0.9657134612613777, -1.7191529436530641, -0.6267619207221298]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2497.08it/s]
least squares fit:   5%|▍         | 249/5000 [00:00<00:01, 2489.43it/s]

bootstrap sample [29.993282359977137, 0.9757399392816419, -1.9767875486880904, 2.048669364846268]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2513.10it/s]
least squares fit:   5%|▌         | 250/5000 [00:00<00:01, 2496.73it/s]

bootstrap sample [30.571136409586924, 1.066488813531558, -1.6618835177744289, -0.19985556821698586]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2513.78it/s]
least squares fit:   5%|▍         | 248/5000 [00:00<00:01, 2475.99it/s]

bootstrap sample [30.9490097252882, 0.9597396222139453, -1.9214823753987709, 1.25885503487694]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2385.17it/s]
least squares fit:   5%|▍         | 227/5000 [00:00<00:02, 2266.75it/s]

bootstrap sample [31.887007554673076, 0.9506671496957437, -2.152653973374404, 1.6869486505999165]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2497.23it/s]
least squares fit:  10%|█         | 509/5000 [00:00<00:01, 2549.90it/s]

bootstrap sample [29.081704350215187, 1.0495038787355986, -1.6920009023683746, 3.609080049949202]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2354.34it/s]
least squares fit:   5%|▍         | 235/5000 [00:00<00:02, 2347.39it/s]

bootstrap sample [31.479546830562978, 1.1296437640969312, -1.8930013630375897, 0.2328971438009536]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2465.65it/s]
least squares fit:  10%|▉         | 491/5000 [00:00<00:01, 2455.48it/s]

bootstrap sample [30.610973912805395, 1.0065894319911013, -1.836243246680104, 0.4499397217455247]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2505.82it/s]
least squares fit:   5%|▍         | 240/5000 [00:00<00:01, 2396.84it/s]

bootstrap sample [31.80927695488258, 0.9821469730488941, -2.007959621103926, -0.2411398745050231]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2486.02it/s]
least squares fit:   5%|▍         | 236/5000 [00:00<00:02, 2357.48it/s]

bootstrap sample [31.024210851804416, 0.9515774062029452, -1.9408222914617927, 0.6442854716394794]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2401.76it/s]
least squares fit:   9%|▉         | 445/5000 [00:00<00:02, 2221.64it/s]

bootstrap sample [28.908141235990033, 1.0556273838810308, -1.7935754991375803, 2.082266951237433]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2281.56it/s]
least squares fit:   5%|▍         | 231/5000 [00:00<00:02, 2306.40it/s]

bootstrap sample [30.025383087071763, 0.9490311032868943, -1.8905462953821093, 1.614968102502849]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2360.35it/s]
least squares fit:   9%|▉         | 451/5000 [00:00<00:02, 2213.02it/s]

bootstrap sample [31.344911606937217, 0.9596230552550087, -2.084944019182774, 1.0635864768954955]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2290.08it/s]
least squares fit:   8%|▊         | 409/5000 [00:00<00:02, 2068.47it/s]

bootstrap sample [30.88785665879763, 0.9739691303740718, -1.750496781109518, -2.0086684580110616]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2064.57it/s]
least squares fit:   4%|▍         | 209/5000 [00:00<00:02, 2087.76it/s]

bootstrap sample [30.524172097277887, 0.9468432200060536, -1.7489583214704674, -0.42947540813439916]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 1774.92it/s]
least squares fit:  10%|█         | 503/5000 [00:00<00:01, 2460.10it/s]

bootstrap sample [33.73887281461898, 0.8342998931764716, -2.0056583070815233, -1.0048943591784738]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2246.91it/s]
least squares fit:   4%|▍         | 217/5000 [00:00<00:02, 2166.36it/s]

bootstrap sample [29.04731144829789, 0.9737448743420717, -1.7622553843049413, 0.9744871197165679]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2404.09it/s]
least squares fit:  10%|█         | 501/5000 [00:00<00:01, 2482.46it/s]

bootstrap sample [30.849086975523765, 1.1142041012783979, -2.055393538038613, 1.8606960468590918]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2445.59it/s]
least squares fit:  10%|█         | 507/5000 [00:00<00:01, 2490.46it/s]

bootstrap sample [31.20227902410509, 1.0148203879553739, -1.831139817867853, -0.12803605188562736]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2540.50it/s]
least squares fit:  10%|▉         | 488/5000 [00:00<00:01, 2443.12it/s]

bootstrap sample [30.44951242588852, 0.9188875408835141, -1.6623667661150159, 0.41561209518705605]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2502.65it/s]
least squares fit:   5%|▌         | 259/5000 [00:00<00:01, 2587.87it/s]

bootstrap sample [30.93743647937649, 0.9178249912706591, -1.91788395405578, 0.8027340312172657]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2545.09it/s]
least squares fit:   9%|▉         | 448/5000 [00:00<00:02, 2274.50it/s]

bootstrap sample [33.07304817934109, 0.7669188362229076, -1.8621104803815107, -0.5344373694611129]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2529.04it/s]
least squares fit:   5%|▌         | 257/5000 [00:00<00:01, 2563.18it/s]

bootstrap sample [30.98035452936738, 0.9608047189289309, -1.8571138381579286, 1.2456516010877]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2526.96it/s]
least squares fit:   5%|▌         | 252/5000 [00:00<00:01, 2519.83it/s]

bootstrap sample [29.890049201061156, 0.9320508621300003, -1.815157140889288, 1.6197634219660293]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2473.41it/s]
least squares fit:   5%|▍         | 246/5000 [00:00<00:01, 2451.44it/s]

bootstrap sample [32.7497073906742, 0.8163410438179741, -1.6727937223778233, -1.627203273138944]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2520.84it/s]
least squares fit:  10%|▉         | 496/5000 [00:00<00:01, 2475.58it/s]

bootstrap sample [32.23550207094589, 0.9915112587422378, -2.201685593411146, 0.659721525694611]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2531.43it/s]
least squares fit:   5%|▍         | 249/5000 [00:00<00:01, 2478.87it/s]

bootstrap sample [30.238346353105722, 0.9812068545490507, -1.9183149068660714, 2.4252389104819785]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2509.15it/s]
least squares fit:  10%|▉         | 475/5000 [00:00<00:01, 2385.53it/s]

bootstrap sample [30.574120085079127, 0.9174840515163696, -1.791824539551342, 0.9221993996446399]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2367.92it/s]
least squares fit:  10%|▉         | 477/5000 [00:00<00:01, 2427.50it/s]

bootstrap sample [30.200588272490595, 0.9290781608340558, -1.5128386060160508, -0.27164281164191667]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2579.40it/s]
least squares fit:  10%|█         | 503/5000 [00:00<00:01, 2485.02it/s]

bootstrap sample [30.568001921567824, 1.0423323558239714, -2.0539328282484295, 2.070512986336468]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2553.22it/s]
least squares fit:   5%|▌         | 261/5000 [00:00<00:01, 2592.76it/s]

bootstrap sample [32.24170594033224, 0.928943846288939, -1.9597146432475416, -0.3283270089441192]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2469.08it/s]
least squares fit:   4%|▍         | 223/5000 [00:00<00:02, 2224.51it/s]

bootstrap sample [32.867472630955, 1.0159010210188608, -2.0279568468137548, -0.5177147877542921]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2338.63it/s]
least squares fit:  10%|▉         | 478/5000 [00:00<00:01, 2391.79it/s]

bootstrap sample [29.215869116992934, 1.0071212080144287, -1.9567505776484149, 3.7248516336467894]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2362.19it/s]
least squares fit:   5%|▌         | 250/5000 [00:00<00:01, 2499.17it/s]

bootstrap sample [29.731620363957564, 1.0022351904608418, -1.6056750069107464, 0.38365805637548667]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2258.55it/s]
least squares fit:   5%|▍         | 246/5000 [00:00<00:01, 2453.67it/s]

bootstrap sample [32.67347841435598, 0.8824434637692973, -1.9909101579029314, 0.04871947146702888]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2309.31it/s]
least squares fit:   5%|▌         | 251/5000 [00:00<00:01, 2505.89it/s]

bootstrap sample [29.15775553756214, 1.0683351454601346, -1.7096121511993116, 3.2616854857255317]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2514.96it/s]
least squares fit:   5%|▌         | 262/5000 [00:00<00:01, 2607.72it/s]

bootstrap sample [30.488240564960755, 1.0353317712496077, -1.9149562223503453, 2.595089595245561]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2474.68it/s]
least squares fit:   5%|▌         | 261/5000 [00:00<00:01, 2607.03it/s]

bootstrap sample [31.498485256154574, 0.865173722648589, -1.9003285713857743, -0.4448014961070426]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2545.57it/s]
least squares fit:  11%|█         | 532/5000 [00:00<00:01, 2665.43it/s]

bootstrap sample [28.568637146495437, 0.9377084816305519, -1.6697079214548882, 2.0378528186926]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2512.92it/s]
least squares fit:   5%|▌         | 266/5000 [00:00<00:01, 2656.84it/s]

bootstrap sample [30.8880890003392, 0.9480046573855904, -1.9409732963472783, -0.38053847722698325]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2544.09it/s]
least squares fit:  10%|█         | 517/5000 [00:00<00:01, 2617.98it/s]

bootstrap sample [30.449302174408764, 1.121851483678596, -1.9621516796699556, 2.244341597832513]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2482.91it/s]
least squares fit:   5%|▍         | 234/5000 [00:00<00:02, 2329.23it/s]

bootstrap sample [30.266204204539516, 1.006867325680496, -2.1198992898486466, 0.5362851256019128]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2524.18it/s]
least squares fit:   9%|▉         | 472/5000 [00:00<00:01, 2402.11it/s]

bootstrap sample [29.33031812613639, 1.0424517245684064, -1.8849226826885934, 2.2650387817258566]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2515.04it/s]
least squares fit:   5%|▌         | 264/5000 [00:00<00:01, 2632.55it/s]

bootstrap sample [31.777389538816966, 0.8928310423632744, -1.9269578522157438, 0.048635890062917325]


least squares fit: 100%|██████████| 5000/5000 [00:02<00:00, 2471.72it/s]
least squares fit:  10%|▉         | 496/5000 [00:00<00:01, 2509.04it/s]

bootstrap sample [28.291072745509133, 1.1873361941623277, -1.8546687169062575, 2.6390276558088757]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2549.68it/s]
least squares fit:   5%|▍         | 233/5000 [00:00<00:02, 2329.23it/s]

bootstrap sample [31.725525991297584, 0.8939775539447468, -1.843559060469148, -0.6224324630864045]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2529.20it/s]
least squares fit:  11%|█         | 529/5000 [00:00<00:01, 2653.28it/s]

bootstrap sample [30.2731194689144, 0.8005769229528958, -1.6991234036996576, 0.9748341305369917]


least squares fit: 100%|██████████| 5000/5000 [00:01<00:00, 2586.23it/s]

bootstrap sample [31.75696506368952, 1.0790800487199688, -2.0880894078200054, 1.6420943383461737]





In [47]:
# Standard error of each coefficient = standard deviation of that coefficient
# across all bootstrap fits. The number of coefficients is taken from the fits
# themselves instead of being hard-coded, so this keeps working if the model
# gains or loses features.
bootstrap_standard_errors = [
    standard_deviation([beta[i] for beta in bootstrap_betas])
    for i in range(len(bootstrap_betas[0]))
]
print(bootstrap_standard_errors)

[1.2715078186272781, 0.10318410116073963, 0.15510591689663628, 1.2490975248051257]


In [48]:
import math 

def normal_cdf(x: float, mu: float = 0, sigma: float = 1) -> float:
    """Return P(X <= x) for X drawn from a normal distribution with mean mu
    and standard deviation sigma.

    Uses the identity CDF(x) = (1 + erf((x - mu) / (sqrt(2) * sigma))) / 2.
    """
    scaled = (x - mu) / math.sqrt(2) / sigma
    return (1 + math.erf(scaled)) / 2

In [49]:
def p_value(beta_hat_j: float, sigma_hat_j: float) -> float:
    """Two-sided p-value for the hypothesis that a coefficient's true value is 0.

    The ratio beta_hat_j / sigma_hat_j is treated as a standard-normal test
    statistic. The result is twice the probability mass in the tail on the
    same side as the estimate: the upper tail when beta_hat_j > 0, otherwise
    the lower tail.
    """
    z = beta_hat_j / sigma_hat_j
    # Upper tail for a positive estimate, lower tail otherwise.
    tail = 1 - normal_cdf(z) if beta_hat_j > 0 else normal_cdf(z)
    return 2 * tail

assert p_value(30.58, 1.27) < 0.001
assert p_value(0.972, 0.103) < 0.001
assert p_value(-1.865, 0.155) < 0.001
assert p_value(0.923, 1.249) > 0.4

## 15.8 Regularization

Regularization is an approach in which we add to the error term a penalty that gets larger as beta gets larger. We then minimize the combined error and penalty. The more importance we place on the penalty term, the more we discourage large coefficients.

In [50]:
# alpha is a hyperparameter controlling how harsh the penalty is.
# It is sometimes called "lambda", but `lambda` is already a reserved keyword in Python.

def ridge_penalty(beta: Vector, alpha: float) -> float:
    """Return alpha times the sum of squares of beta[1:].

    beta[0] is left out of the penalty. It is the intercept term, assuming
    each input vector starts with a constant 1.
    """
    slopes = beta[1:]
    return alpha * dot(slopes, slopes)

def squared_error_ridge(x: Vector,
                        y: float,
                        beta: Vector,
                        alpha: float) -> float:
    """Return the squared prediction error on (x, y) plus the ridge penalty on beta."""
    residual = error(x, y, beta)
    penalty = ridge_penalty(beta, alpha)
    return residual ** 2 + penalty

def ridge_penalty_gradient(beta: Vector, alpha: float) -> Vector:
    """Return the gradient of the ridge penalty with respect to beta.

    The partial derivative of alpha * beta_j**2 is 2 * alpha * beta_j. The
    first entry is always 0 because beta[0] is excluded from the penalty.
    """
    scale = 2 * alpha
    return [0., *(scale * coef for coef in beta[1:])]

def sqerror_ridge_gradient(x: Vector,
                           y: float,
                           beta: Vector,
                           alpha: float) -> Vector:
    """Return the gradient of squared_error_ridge for a single data point.

    This is the element-wise sum of the squared-error gradient and the
    ridge-penalty gradient.
    """
    data_gradient = sqerror_gradient(x, y, beta)
    penalty_gradient = ridge_penalty_gradient(beta, alpha)
    return add(data_gradient, penalty_gradient)


In [51]:
def least_squares_fit_ridge(xs: List[Vector],
                            ys: List[float],
                            alpha: float,
                            learning_rate: float,
                            num_steps: int,
                            batch_size: int = 1) -> Vector:
    """Fit ridge-regularized least squares with minibatch gradient descent.

    Args:
        xs: input vectors, one per data point, all the same length.
        ys: target values aligned with ``xs``.
        alpha: ridge penalty strength. With 0 the penalty gradient vanishes
            and this reduces to ordinary least squares.
        learning_rate: step size; it is negated and passed to ``gradient_step``.
        num_steps: number of full passes over the data.
        batch_size: number of points whose gradients are averaged per step.
            The last batch of a pass may be smaller.

    Returns:
        The fitted coefficient vector, the same length as each x.
    """
    # Start from a random guess: each coefficient is uniform in [0, 1).
    guess = [random.random() for _ in xs[0]]

    for _ in range(num_steps):
        for start in range(0, len(xs), batch_size):
            batch_xs = xs[start:start + batch_size]
            batch_ys = ys[start:start + batch_size]

            # Average the per-point (squared error + ridge) gradients over the batch.
            gradient = vector_mean([sqerror_ridge_gradient(x, y, guess, alpha)
                                    for x, y in zip(batch_xs, batch_ys)])
            # Negative step: presumably gradient_step(v, g, s) moves v by s * g,
            # so this steps against the gradient (descent). TODO confirm.
            guess = gradient_step(guess, gradient, -learning_rate)

    return guess

In [52]:
random.seed(0)
# alpha = 0: no penalty at all, i.e. plain least squares.
beta_0 = least_squares_fit_ridge(inputs, daily_minutes_good,
                                 alpha=0.0,
                                 learning_rate=learning_rate,
                                 num_steps=5000,
                                 batch_size=25)
assert 5 < dot(beta_0[1:], beta_0[1:]) < 6
assert 0.67 < multiple_r_squared(inputs, daily_minutes_good, beta_0) < 0.69

In [53]:
# A small penalty shrinks the slope coefficients a little; the fit is about as good.
beta_0_1 = least_squares_fit_ridge(inputs, daily_minutes_good,
                                   alpha=0.1,
                                   learning_rate=learning_rate,
                                   num_steps=5000,
                                   batch_size=25)
assert 4 < dot(beta_0_1[1:], beta_0_1[1:]) < 5
assert 0.67 < multiple_r_squared(inputs, daily_minutes_good, beta_0_1) < 0.69

In [54]:
# alpha = 1 shrinks the slopes further while R-squared stays in the same range.
beta_1 = least_squares_fit_ridge(inputs, daily_minutes_good,
                                 alpha=1,
                                 learning_rate=learning_rate,
                                 num_steps=5000,
                                 batch_size=25)
assert 3 < dot(beta_1[1:], beta_1[1:]) < 4
assert 0.67 < multiple_r_squared(inputs, daily_minutes_good, beta_1) < 0.69

In [55]:
# A heavy penalty shrinks the slopes much more, and the fit gets noticeably worse.
beta_10 = least_squares_fit_ridge(inputs, daily_minutes_good,
                                  alpha=10,
                                  learning_rate=learning_rate,
                                  num_steps=5000,
                                  batch_size=25)
assert 1 < dot(beta_10[1:], beta_10[1:]) < 2
assert 0.5 < multiple_r_squared(inputs, daily_minutes_good, beta_10) < 0.6

In [None]:
def lasso_penalty(beta: Vector, alpha: float) -> float:
    """Return alpha times the sum of absolute values of beta[1:].

    As with the ridge penalty, beta[0] is not penalized. Unlike ridge, each
    coefficient contributes its absolute value rather than its square.
    """
    return alpha * sum(abs(beta_i) for beta_i in beta[1:])

Whereas the ridge penalty shrank the coefficients overall, the lasso penalty tends to force coefficients to be 0, which makes it good for learning sparse models. Unfortunately, because the absolute value is not differentiable at 0, it’s not amenable to plain gradient descent, which means that we won’t be able to solve it from scratch.

## 15.9 For further exploration

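scikit-learn has a `linear_model` module that provides a `LinearRegression` model similar to ours, as well as `Ridge` and `Lasso` regression and other types of regularization.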

In [57]:
from sklearn import linear_model

In [58]:
import statsmodels