-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
lasso_regression.py
62 lines (48 loc) · 1.94 KB
/
lasso_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Import helper functions
from mlfromscratch.supervised_learning import LassoRegression
from mlfromscratch.utils import k_fold_cross_validation_sets, normalize, mean_squared_error
from mlfromscratch.utils import train_test_split, polynomial_features, Plot
def main():
    """Fit a Lasso regression to the 2016 temperature data and plot the result.

    Reads the tab-separated file
    ``mlfromscratch/data/TempLinkoping2016.txt``, uses its ``time`` column as
    the single input feature and its ``temp`` column as the target, holds out
    40% of the samples as a test set and trains a ``LassoRegression`` model
    on the rest.

    Two matplotlib figures are shown (each ``plt.show()`` call blocks until
    the window is closed on interactive backends):

    1. ``model.training_errors`` plotted against its index.
    2. Train/test scatter with the model's prediction over all samples.

    The mean squared error on the test set is printed to stdout.
    """
    # Load temperature data
    data = pd.read_csv('mlfromscratch/data/TempLinkoping2016.txt', sep="\t")

    # atleast_2d(...).T turns the 1-D "time" column into an (n_samples, 1)
    # column vector, i.e. a feature matrix with one feature.
    X = np.atleast_2d(data["time"].values).T  # fraction of the year [0, 1]
    y = data["temp"].values

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)

    # Keep the regularization factor in one place so the value reported in
    # the printed message below can never drift from the one the model used.
    reg_factor = 0.05

    model = LassoRegression(degree=15,
                            reg_factor=reg_factor,
                            learning_rate=0.001,
                            n_iterations=4000)
    model.fit(X_train, y_train)

    # Training error plot
    n = len(model.training_errors)
    training, = plt.plot(range(n), model.training_errors, label="Training Error")
    plt.legend(handles=[training])
    plt.title("Error Plot")
    plt.ylabel('Mean Squared Error')
    plt.xlabel('Iterations')
    plt.show()

    # Evaluate on the held-out samples only.
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print ("Mean squared error: %s (given by reg. factor: %s)" % (mse, reg_factor))

    # Prediction over every sample, used to draw the fitted curve.
    y_pred_line = model.predict(X)

    # Color map
    cmap = plt.get_cmap('viridis')

    # X is a fraction of the year; scale it to a day number for the x-axis.
    # NOTE(review): 366 presumably reflects 2016 being a leap year - confirm.
    days_in_year = 366

    # Plot the results
    m1 = plt.scatter(days_in_year * X_train, y_train, color=cmap(0.9), s=10)
    m2 = plt.scatter(days_in_year * X_test, y_test, color=cmap(0.5), s=10)
    plt.plot(days_in_year * X, y_pred_line, color='black', linewidth=2, label="Prediction")
    plt.suptitle("Lasso Regression")
    plt.title("MSE: %.2f" % mse, fontsize=10)
    plt.xlabel('Day')
    plt.ylabel('Temperature in Celcius')
    plt.legend((m1, m2), ("Training data", "Test data"), loc='lower right')
    plt.show()
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()