-
Notifications
You must be signed in to change notification settings - Fork 0
/
ridge.py
118 lines (86 loc) · 3.55 KB
/
ridge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torchml as ml
class Ridge(ml.Model):
    """
    <a href="https://github.com/learnables/torchml/blob/master/torchml/linear_model/ridge.py">[Source]</a>

    ## Description

    Linear regression with L2 penalty term.

    $$ w = (X^TX + \\lambda I)^{-1}X^Ty $$

    * `w` - weights of the linear regression with L2 penalty
    * `X` - variates
    * `λ`- constant that multiplies the L2 term
    * `y` - covariates

    The above equation is the closed-form solution for ridge's objective function

    $$ \\min_w \\frac{1}{2} \\Vert Xw - y \\Vert^2 + \\frac{1}{2} \\lambda \\Vert w \\Vert^2 $$

    ## References

    1. Arthur E. Hoerl and Robert W. Kennard's introduction to Ridge Regression [paper](https://www.jstor.org/stable/1271436)
    2. Datacamp Lasso and Ridge Regression Tutorial [tutorial](https://www.datacamp.com/tutorial/tutorial-lasso-ridge-regression#data-splitting-and-scaling)
    3. The scikit-learn [documentation page](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html?highlight=lasso#sklearn.linear_model.ridge)

    ## Arguments

    * `alpha` (float, default=1.0) - Constant that multiplies the L2 term. alpha must be a non-negative float.
    * `fit_intercept` (bool, default=False) - Whether or not to fit intercept in the model.
    * `normalize` (bool, default=False) - Accepted for scikit-learn API parity; not used by this closed-form implementation. normalize will be deprecated in the future.
    * `copy_X` (bool, default=True) - If True, X will be copied before fitting so the caller's tensor is never touched.
    * `solver` (string, default='auto') - Accepted for scikit-learn API parity; the closed-form solution is always used.

    ## Example

    ~~~python
    ridge = Ridge()
    ~~~
    """

    def __init__(
        self,
        *,
        alpha: float = 1.0,
        fit_intercept: bool = False,
        normalize: bool = False,
        copy_X: bool = True,
        solver: str = "auto"
    ):
        super(Ridge, self).__init__()
        # The docstring requires a non-negative alpha; enforce it here instead
        # of letting a negative value silently produce an indefinite system.
        if alpha < 0:
            raise ValueError(
                "alpha must be a non-negative float, got {!r}".format(alpha)
            )
        self.alpha = alpha
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.copy_X = copy_X
        self.solver = solver

    def fit(self, X: torch.Tensor, y: torch.Tensor):
        """
        ## Description

        Compute the weights for the model given variates and covariates.

        ## Arguments

        * `X` (Tensor) - Input variates.
        * `y` (Tensor) - Target covariates.

        ## Example

        ~~~python
        ridge = Ridge()
        ridge.fit(X_train, y_train)
        ~~~
        """
        assert X.shape[0] == y.shape[0], "Number of X and y rows don't match"
        device = X.device
        # Honor the copy_X contract: work on a private copy so the caller's
        # tensor is guaranteed untouched.
        if self.copy_X:
            X = X.clone()
        if self.fit_intercept:
            # Prepend a column of ones so the first weight acts as the bias.
            X = torch.cat([torch.ones(X.shape[0], 1, device=device), X], dim=1)
        # L2 penalty term will not apply when alpha is 0 (plain least squares).
        if self.alpha == 0:
            self.weight = torch.pinverse(X.T @ X) @ X.T @ y
        else:
            ridge = self.alpha * torch.eye(X.shape[1], device=device)
            # intercept term is not penalized when fit_intercept is true
            if self.fit_intercept:
                ridge[0, 0] = 0
            self.weight = torch.pinverse((X.T @ X) + ridge) @ X.T @ y

    def predict(self, X: torch.Tensor):
        """
        ## Description

        Predict covariates by the trained model. `fit` must be called first.

        ## Arguments

        * `X` (Tensor) - Input variates.

        ## Example

        ~~~python
        ridge = Ridge()
        ridge.fit(X_train, y_train)
        ridge.predict(X_test)
        ~~~
        """
        if self.fit_intercept:
            # Mirror the augmentation done in fit so shapes line up with weight.
            X = torch.cat([torch.ones(X.shape[0], 1, device=X.device), X], dim=1)
        return X @ self.weight