# Multiple Regression

xxx

## The Model

You can improve a `single linear regression` model with additional features (independent variables):

<img src="images/multiple_linear_regression1.png" alt="" style="width: 600px;"/>

In multiple regression the vector of parameters is usually called β.

Assumptions:
- The columns of `x` are `linearly independent` - there is no way to write any one as a weighted sum of some of the others. If not met, we could not estimate `beta`.
- The columns of `x` are all uncorrelated with the `errors e`.

In [2]:
from scratch.linear_algebra import dot, Vector

# pay attention to constant terms added to beta and x_i on the first position
# beta = [alpha, beta_1, ..., beta_k]
# x_i = [1, x_i1, ..., x_ik]

def predict(x: Vector, beta: Vector) -> float:
    '''assumes that the first element of x is 1'''
    return dot(x, beta)

# For example, x_i (independent variables vector):
[
    1,  # constant term
    34, # feature 1
    2,  # feature 2
    0   # feature 3...
]

[1, 34, 2, 0]

As we did in the simple linear model, we’ll choose `beta` to minimize the `sum of squared errors`. Finding an exact solution is not simple to do by hand, which means we’ll need to use `gradient descent`. The error function is almost identical to the one used for `simple linear regression` but instead of expecting parameters `alpha`, `beta` it will take a vector of arbitrary length:

In [3]:
from typing import List

# -e = alpha + beta_i * x_i - y_i
def error(x: Vector, y: float, beta: Vector) -> float:
    return predict(x, beta) - y

In [4]:
# -e^2 == e^2
def squared_error(x: Vector, y: float, beta: Vector) -> float:
    return error(x, y, beta) ** 2

In [5]:
# 30 = 4*1 + 4*2 + 4*3 + e

x = [1, 2, 3]
y = 30
beta = [4, 4, 4] # so prediction = 4 + 8 + 12 = 24

In [6]:
error(x, y, beta)

-6

In [7]:
squared_error(x, y, beta)

36

- [Gradient calculation (in Polish)](https://www.youtube.com/watch?v=woqa2CUxw00)

In [10]:
# The gradient (vector of partial derivatives)
# err = (alpha + beta_i * x_i - y_i)
def sqerror_gradient(x: Vector, y: float, beta: Vector) -> Vector:
    err = error(x, y, beta)
    return [2 * err * x_i for x_i in x] # x_0 = 1, so we have partial deriverative in respect to alpha = 2*err*1

In [12]:
# grad_a = 2 * err = -12
# grad_b1 = 2 * err * x_1 = -12 * 2 = -24
# grad_b2 = 2 * err * x_2 = -12 * 3 = -36

sqerror_gradient(x, y, beta)

-6


[-12, -24, -36]

At this point we are ready to find the optimal `beta` using `gradient descent`. The `least_squares_fit` function can work with any dataset:

In [14]:
inputs: List[List[float]] = [[1.,49,4,0],[1,41,9,0],[1,40,8,0],[1,25,6,0],[1,21,1,0],[1,21,0,0],[1,19,3,0],[1,19,0,0],[1,18,9,0],[1,18,8,0],[1,16,4,0],[1,15,3,0],[1,15,0,0],[1,15,2,0],[1,15,7,0],[1,14,0,0],[1,14,1,0],[1,13,1,0],[1,13,7,0],[1,13,4,0],[1,13,2,0],[1,12,5,0],[1,12,0,0],[1,11,9,0],[1,10,9,0],[1,10,1,0],[1,10,1,0],[1,10,7,0],[1,10,9,0],[1,10,1,0],[1,10,6,0],[1,10,6,0],[1,10,8,0],[1,10,10,0],[1,10,6,0],[1,10,0,0],[1,10,5,0],[1,10,3,0],[1,10,4,0],[1,9,9,0],[1,9,9,0],[1,9,0,0],[1,9,0,0],[1,9,6,0],[1,9,10,0],[1,9,8,0],[1,9,5,0],[1,9,2,0],[1,9,9,0],[1,9,10,0],[1,9,7,0],[1,9,2,0],[1,9,0,0],[1,9,4,0],[1,9,6,0],[1,9,4,0],[1,9,7,0],[1,8,3,0],[1,8,2,0],[1,8,4,0],[1,8,9,0],[1,8,2,0],[1,8,3,0],[1,8,5,0],[1,8,8,0],[1,8,0,0],[1,8,9,0],[1,8,10,0],[1,8,5,0],[1,8,5,0],[1,7,5,0],[1,7,5,0],[1,7,0,0],[1,7,2,0],[1,7,8,0],[1,7,10,0],[1,7,5,0],[1,7,3,0],[1,7,3,0],[1,7,6,0],[1,7,7,0],[1,7,7,0],[1,7,9,0],[1,7,3,0],[1,7,8,0],[1,6,4,0],[1,6,6,0],[1,6,4,0],[1,6,9,0],[1,6,0,0],[1,6,1,0],[1,6,4,0],[1,6,1,0],[1,6,0,0],[1,6,7,0],[1,6,0,0],[1,6,8,0],[1,6,4,0],[1,6,2,1],[1,6,1,1],[1,6,3,1],[1,6,6,1],[1,6,4,1],[1,6,4,1],[1,6,1,1],[1,6,3,1],[1,6,4,1],[1,5,1,1],[1,5,9,1],[1,5,4,1],[1,5,6,1],[1,5,4,1],[1,5,4,1],[1,5,10,1],[1,5,5,1],[1,5,2,1],[1,5,4,1],[1,5,4,1],[1,5,9,1],[1,5,3,1],[1,5,10,1],[1,5,2,1],[1,5,2,1],[1,5,9,1],[1,4,8,1],[1,4,6,1],[1,4,0,1],[1,4,10,1],[1,4,5,1],[1,4,10,1],[1,4,9,1],[1,4,1,1],[1,4,4,1],[1,4,4,1],[1,4,0,1],[1,4,3,1],[1,4,1,1],[1,4,3,1],[1,4,2,1],[1,4,4,1],[1,4,4,1],[1,4,8,1],[1,4,2,1],[1,4,4,1],[1,3,2,1],[1,3,6,1],[1,3,4,1],[1,3,7,1],[1,3,4,1],[1,3,1,1],[1,3,10,1],[1,3,3,1],[1,3,4,1],[1,3,7,1],[1,3,5,1],[1,3,6,1],[1,3,1,1],[1,3,6,1],[1,3,10,1],[1,3,2,1],[1,3,4,1],[1,3,2,1],[1,3,1,1],[1,3,5,1],[1,2,4,1],[1,2,2,1],[1,2,8,1],[1,2,3,1],[1,2,1,1],[1,2,9,1],[1,2,10,1],[1,2,9,1],[1,2,4,1],[1,2,5,1],[1,2,0,1],[1,2,9,1],[1,2,9,1],[1,2,0,1],[1,2,1,1],[1,2,1,1],[1,2,4,1],[1,1,0,1],[1,1,2,1],[1,1,2,1],[1,1,5,1],[1,1,3,1],[1,1,10,1],[1,1,6,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,4,1],[1,1,9,1],[1,1,9,1],[1,1,4,1],[1,1,2,1],[1,1,9,1],[1,1,0,1],[1,1,8,1],[1,1,6,1],[1,1,1,1],[1,1,1,1],[1,1,5,1]]

In [29]:
import random
import tqdm
from scratch.linear_algebra import vector_mean
from scratch.gradient_descent import gradient_step

def least_squares_fit(xs: List[Vector], ys: List[float], 
                      learning_rate: float = 0.001, num_steps: int = 1000,
                     batch_size: int = 1) -> Vector:
    '''Find the beta that minimizes the sum of squared errors
        assuming the model y = dot(x, beta)'''
    # Start with random guess
    # xs[0] -> [1.0, 49, 4, 0]
    guess = [random.random() for _ in xs[0]]
    '''guess = 
    [0.2385977483303291,
     0.6553070011285017,
     0.5627055254613955,
     0.3790588563128069]
    '''
    for _ in tqdm.trange(num_steps, desc="least squares fit"):
        for start in range(0, len(xs), batch_size):
            batch_xs = xs[start:start+batch_size]
            batch_ys = ys[start:start+batch_size]
            #print(batch_xs)
            #print("-------")
            #print(batch_ys)
            #print("---------------")
            
            # vector_mean computes the element-wise average
            gradient = vector_mean([sqerror_gradient(x, y, guess)
                                   for x, y in zip(batch_xs, batch_ys)])
            '''gradient = 
            29.49471348565843
            26.31325484669618
            25.45433711032261
            16.293954316745346
            12.557753890325948
            '''
    return gradient#guess

In [17]:
from scratch.statistics import daily_hours_good
from scratch.gradient_descent import gradient_step

In [20]:
[1,    # constant term
 49,   # number of friends
 4,    # work hours per day
 0]    # doesn't have PhD
inputs[:10]

[[1.0, 49, 4, 0],
 [1, 41, 9, 0],
 [1, 40, 8, 0],
 [1, 25, 6, 0],
 [1, 21, 1, 0],
 [1, 21, 0, 0],
 [1, 19, 3, 0],
 [1, 19, 0, 0],
 [1, 18, 9, 0],
 [1, 18, 8, 0]]

In [18]:
daily_hours_good[:10]

[1.1461666666666666,
 0.8541666666666666,
 0.868,
 0.6393333333333333,
 0.7423333333333333,
 0.9521666666666667,
 0.8566666666666667,
 0.6903333333333334,
 0.5203333333333333,
 0.5793333333333333]

In [30]:
least_squares_fit(inputs, daily_hours_good, 0.001, 100, 1)

least squares fit:   8%|▊         | 8/100 [00:00<00:01, 57.07it/s]

29.49471348565843
26.31325484669618
25.45433711032261
16.293954316745346
12.557753890325948
12.095057742734177
11.764703007945878
11.172447898503902
13.025991638048032
12.714128823789595
10.197401055858906
9.605983319485333
8.700728210043357
9.181287171893564
10.802934576519078
8.217839954594885
8.542369435519989
7.751481180071519
9.610824732288803
8.543902956180162
8.222343994329956
8.431877514990127
7.1338967770312784
8.966273849908736
8.412385594460266
6.123483080392773
6.109483080392772
7.77532663261006
8.288385594460266
6.180316413726106
7.70913048501829
7.684463818351623
8.122022780201828
8.605748408718704
7.658463818351623
5.75045359946767
7.288767670759853
6.72804204224298
7.140404856501417
7.8596640056784635
7.7506640056784635
5.4032320106858664
5.264398677352533
7.039908896236486
8.145526819936899
7.590301191420026
6.897879415311383
5.663624305869407
7.820997339011797
8.029026819936899
7.261105043828256
5.833624305869407
5.2685653440192
6.36834993438628
6.966575562903153
6.44

least squares fit:  19%|█▉        | 19/100 [00:00<00:01, 54.93it/s]


5.442158513568643
3.5494816280180226
4.040873923201563
3.119426705902886
4.269544629603299
3.4321523344197593
4.424074110528402
3.757652334419759
2.7557305583111162
5.353329219970379
3.336622853494656
3.5316523344197592
4.4049074438617355
3.9560151486781954
4.262544629603299
2.7693972249777827
4.326544629603299
5.358995886637046
3.1144267059028863
3.5399856677530925
3.079926705902886
2.7648972249777826
4.097348482011529
3.140930745637955
2.520205117121082
4.1888820026717015
2.6392345980461855
2.2480089695293124
4.594244816930138
4.588107631188575
4.410244816930138
2.8935974123046218
3.4021268932297257
1.9528128219375422
4.4707448169301385
4.542911483596805
1.9581461552708754
2.202342302862646
2.145175636195979
3.262430745637955
1.4190912331557386
1.8299835283392785
1.857150195005945
2.9069053044479216
2.0791796759310484
4.115552709073437
3.2261014520396913
1.4705912331557385
3.7368270805565644
3.004934785373025
2.373542490189485
3.9028565614816677
3.7921898948150012
2.3708758235228182

least squares fit:  30%|███       | 30/100 [00:00<00:01, 50.67it/s]


4.45173269772493
6.3584390642006525
4.20053655013316
4.141369883466494
6.211742916608882
5.42932518023531
4.801432885051769
3.0407559995011497
5.757384142085515
4.50823673746
5.745384142085515
5.62652132782708
3.4391188137595865
4.342040589868229
4.169207256534896
3.1075893328344835
3.8130111089431264
3.4351188137595865
3.949677775609793
3.5619816280180228
4.181373923201562
4.061040589868229
5.442158513568643
3.5494816280180226
4.040873923201563
3.119426705902886
4.269544629603299
3.4321523344197593
4.424074110528402
3.757652334419759
2.7557305583111162
5.353329219970379
3.336622853494656
3.5316523344197592
4.4049074438617355
3.9560151486781954
4.262544629603299
2.7693972249777827
4.326544629603299
5.358995886637046
3.1144267059028863
3.5399856677530925
3.079926705902886
2.7648972249777826
4.097348482011529
3.140930745637955
2.520205117121082
4.1888820026717015
2.6392345980461855
2.2480089695293124
4.594244816930138
4.588107631188575
4.410244816930138
2.8935974123046218
3.402126893229

least squares fit:  42%|████▏     | 42/100 [00:00<00:01, 54.69it/s]

2.202342302862646
2.145175636195979
3.262430745637955
1.4190912331557386
1.8299835283392785
1.857150195005945
2.9069053044479216
2.0791796759310484
4.115552709073437
3.2261014520396913
1.4705912331557385
3.7368270805565644
3.004934785373025
2.373542490189485
3.9028565614816677
3.7921898948150012
2.3708758235228182
1.8284835283392784
3.7765232281483345
1.429424566489072
3.509160413889898
3.2551014520396913
1.666954047414175
1.5912873807475085
2.7445719711145884
29.49471348565843
26.31325484669618
25.45433711032261
16.293954316745346
12.557753890325948
12.095057742734177
11.764703007945878
11.172447898503902
13.025991638048032
12.714128823789595
10.197401055858906
9.605983319485333
8.700728210043357
9.181287171893564
10.802934576519078
8.217839954594885
8.542369435519989
7.751481180071519
9.610824732288803
8.543902956180162
8.222343994329956
8.431877514990127
7.1338967770312784
8.966273849908736
8.412385594460266
6.123483080392773
6.109483080392772
7.77532663261006
8.288385594460266
6.18

least squares fit:  55%|█████▌    | 55/100 [00:01<00:00, 50.79it/s]


5.75045359946767
7.288767670759853
6.72804204224298
7.140404856501417
7.8596640056784635
7.7506640056784635
5.4032320106858664
5.264398677352533
7.039908896236486
8.145526819936899
7.590301191420026
6.897879415311383
5.663624305869407
7.820997339011797
8.029026819936899
7.261105043828256
5.833624305869407
5.2685653440192
6.36834993438628
6.966575562903153
6.4486832677196135
7.193938377161589
5.471765531346039
5.217736050420936
5.754628345604476
7.273609083563326
5.339736050420936
5.5057655313460385
6.1349911598629125
6.886579602638222
4.727010421904063
7.369442416896659
7.474138564488428
6.240824493196246
6.109824493196246
5.594936237747775
5.689602904414442
4.166622166455592
4.518514461639132
6.491858013856418
7.060250309039958
5.581602904414442
4.9310439425642345
4.979877275897568
5.8597990520062115
6.203995199597981
6.046661866264648
6.627554161448188
4.981710609230901
6.350191347189751
4.761018501374202
5.409744129891075
4.600185168040868
6.185832572666384
3.501900577673789
3.7367

least squares fit:  67%|██████▋   | 67/100 [00:01<00:00, 51.85it/s]

5.665106944149512
3.4299005776737888
5.886303091741281
4.8478518347075354
4.800758138914964
4.498395324656527
4.940120953173401
6.119542729282044
5.2628171007651705
5.18131710076517
4.394561991323194
5.0064542865067345
5.248983767431837
3.7715070692080563
6.112076249942215
4.909595511983366
5.398654473833573
4.944762178650033
4.797762178650033
6.3216057308673195
5.067291659575136
4.25603655013316
4.914762178650033
4.799095511983366
6.291909583275549
4.45173269772493
6.3584390642006525
4.20053655013316
4.141369883466494
6.211742916608882
5.42932518023531
4.801432885051769
3.0407559995011497
5.757384142085515
4.50823673746
5.745384142085515
5.62652132782708
3.4391188137595865
4.342040589868229
4.169207256534896
3.1075893328344835
3.8130111089431264
3.4351188137595865
3.949677775609793
3.5619816280180228
4.181373923201562
4.061040589868229
5.442158513568643
3.5494816280180226
4.040873923201563
3.119426705902886
4.269544629603299
3.4321523344197593
4.424074110528402
3.757652334419759
2.755

least squares fit:  82%|████████▏ | 82/100 [00:01<00:00, 55.19it/s]


3.9560151486781954
4.262544629603299
2.7693972249777827
4.326544629603299
5.358995886637046
3.1144267059028863
3.5399856677530925
3.079926705902886
2.7648972249777826
4.097348482011529
3.140930745637955
2.520205117121082
4.1888820026717015
2.6392345980461855
2.2480089695293124
4.594244816930138
4.588107631188575
4.410244816930138
2.8935974123046218
3.4021268932297257
1.9528128219375422
4.4707448169301385
4.542911483596805
1.9581461552708754
2.202342302862646
2.145175636195979
3.262430745637955
1.4190912331557386
1.8299835283392785
1.857150195005945
2.9069053044479216
2.0791796759310484
4.115552709073437
3.2261014520396913
1.4705912331557385
3.7368270805565644
3.004934785373025
2.373542490189485
3.9028565614816677
3.7921898948150012
2.3708758235228182
1.8284835283392784
3.7765232281483345
1.429424566489072
3.509160413889898
3.2551014520396913
1.666954047414175
1.5912873807475085
2.7445719711145884
29.49471348565843
26.31325484669618
25.45433711032261
16.293954316745346
12.5577538903259

least squares fit:  97%|█████████▋| 97/100 [00:01<00:00, 57.79it/s]

10.197401055858906
9.605983319485333
8.700728210043357
9.181287171893564
10.802934576519078
8.217839954594885
8.542369435519989
7.751481180071519
9.610824732288803
8.543902956180162
8.222343994329956
8.431877514990127
7.1338967770312784
8.966273849908736
8.412385594460266
6.123483080392773
6.109483080392772
7.77532663261006
8.288385594460266
6.180316413726106
7.70913048501829
7.684463818351623
8.122022780201828
8.605748408718704
7.658463818351623
5.75045359946767
7.288767670759853
6.72804204224298
7.140404856501417
7.8596640056784635
7.7506640056784635
5.4032320106858664
5.264398677352533
7.039908896236486
8.145526819936899
7.590301191420026
6.897879415311383
5.663624305869407
7.820997339011797
8.029026819936899
7.261105043828256
5.833624305869407
5.2685653440192
6.36834993438628
6.966575562903153
6.4486832677196135
7.193938377161589
5.471765531346039
5.217736050420936
5.754628345604476
7.273609083563326
5.339736050420936
5.5057655313460385
6.1349911598629125
6.886579602638222
4.727010

least squares fit: 100%|██████████| 100/100 [00:01<00:00, 54.50it/s]

5.886303091741281
4.8478518347075354
4.800758138914964
4.498395324656527
4.940120953173401
6.119542729282044
5.2628171007651705
5.18131710076517
4.394561991323194
5.0064542865067345
5.248983767431837
3.7715070692080563
6.112076249942215
4.909595511983366
5.398654473833573
4.944762178650033
4.797762178650033
6.3216057308673195
5.067291659575136
4.25603655013316
4.914762178650033
4.799095511983366
6.291909583275549
4.45173269772493
6.3584390642006525
4.20053655013316
4.141369883466494
6.211742916608882
5.42932518023531
4.801432885051769
3.0407559995011497
5.757384142085515
4.50823673746
5.745384142085515
5.62652132782708
3.4391188137595865
4.342040589868229
4.169207256534896
3.1075893328344835
3.8130111089431264
3.4351188137595865
3.949677775609793
3.5619816280180228
4.181373923201562
4.061040589868229
5.442158513568643
3.5494816280180226
4.040873923201563
3.119426705902886
4.269544629603299
3.4321523344197593
4.424074110528402
3.757652334419759
2.7557305583111162
5.353329219970379
3.336




[5.489143942229177, 5.489143942229177, 27.445719711145884, 5.489143942229177]