# Linear regression (CAP5602 Lecture 7)

In this demo, we train and test a linear regression model on a toy linear dataset:
$Y = aX + \epsilon$, where $\epsilon$ is zero-mean Gaussian noise with a given standard deviation. We can generate such a dataset with the [make_regression](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html) function from scikit-learn, whose `noise` parameter sets that standard deviation.

### 1. Generate the toy dataset

In [None]:
from sklearn.datasets import make_regression

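# 100 samples with 1 feature; noise is the standard deviation of the Gaussian noise.
# random_state fixes the seed so the dataset is the same on every run.
X, Y = make_regression(n_samples=100, n_features=1, noise=2.0, random_state=42)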
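
To relate the generated data back to $Y = aX + \epsilon$, `make_regression` can also return the true slope $a$ when called with `coef=True`. The following is a minimal sketch; the names `X_demo`, `Y_demo`, and `true_a`, and the seed `random_state=0`, are illustrative choices and not part of the lecture code.

```python
from sklearn.datasets import make_regression

# coef=True additionally returns the coefficient used to generate the data.
# random_state=0 is an arbitrary seed, assumed here only for reproducibility.
X_demo, Y_demo, true_a = make_regression(
    n_samples=100, n_features=1, noise=2.0, coef=True, random_state=0
)

print(X_demo.shape, Y_demo.shape)  # one row per example, one column per feature
print(float(true_a))               # the slope a in Y = aX + epsilon

# What remains after removing aX is the noise term epsilon.
residual = Y_demo - float(true_a) * X_demo[:, 0]
print(residual.std())              # should be roughly the noise level, 2.0
```

Note that `X_demo` is a 2-D array of shape `(100, 1)` while `Y_demo` is 1-D, which is the layout scikit-learn estimators expect.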

### 2. Train/test split

In [None]:
from sklearn.model_selection import train_test_split

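# Hold out 20% of the examples for testing; random_state makes the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# With 100 examples this prints (80, 1) (20, 1)
print(X_train.shape, X_test.shape)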

### 3. Visualize the dataset

In [None]:
import matplotlib.pyplot as plt

plt.scatter(X_train, Y_train, color='black') # Plot train points with black color
plt.scatter(X_test, Y_test, color='red') # Plot test points with red color
plt.show() # Show the plot

### 4. Train a linear regression model on the train set

In [None]:
from sklearn.linear_model import LinearRegression

model = LinearRegression()
model.fit(X_train, Y_train)
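
`LinearRegression` fits by ordinary least squares, and the learned slope and intercept are exposed as `coef_` and `intercept_`. As a rough sketch of what `fit` computes, the block below compares those attributes with a direct least-squares solve in NumPy. The data and the names `X_demo`, `Y_demo`, and `demo_model` are illustrative assumptions, kept separate from the variables used in this demo.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

# Illustrative data; random_state=0 is an arbitrary seed.
X_demo, Y_demo = make_regression(n_samples=100, n_features=1, noise=2.0, random_state=0)

demo_model = LinearRegression().fit(X_demo, Y_demo)

# Least squares on the design matrix [x, 1]: the column of ones models the intercept.
A = np.hstack([X_demo, np.ones((X_demo.shape[0], 1))])
(slope, intercept), *_ = np.linalg.lstsq(A, Y_demo, rcond=None)

print(demo_model.coef_[0], demo_model.intercept_)  # learned by scikit-learn
print(slope, intercept)                            # direct least-squares solve
```

The two printed pairs should agree up to floating-point error.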

### 5. Evaluate the trained linear regression model on the test set

In [None]:
from sklearn.metrics import mean_squared_error

Y_pred = model.predict(X_test)
mse = mean_squared_error(Y_test, Y_pred)

print(mse)
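
The mean squared error is the average squared difference between targets and predictions, $\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$. With `noise=2.0`, a well-fit model's test MSE should sit in the neighborhood of the noise variance, $2.0^2 = 4$, though a 20-point test set makes that estimate quite variable. The sketch below checks the formula against `mean_squared_error` on a few made-up numbers; `y_true` and `y_hat` are hypothetical values, not outputs of the model above.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

# Made-up targets and predictions, for illustration only.
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_hat = np.array([2.5, 0.0, 2.0, 8.0])

# Squared errors are 0.25, 0.25, 0.0, 1.0, so the mean is 1.5 / 4 = 0.375.
manual_mse = np.mean((y_true - y_hat) ** 2)
sk_mse = mean_squared_error(y_true, y_hat)

print(manual_mse, sk_mse)
```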

### 6. Visualize the trained model

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Generate 200 inputs evenly spaced over the input range and reshape them into a 200 x 1 matrix (each row is an example)
xx = np.linspace(X.min(), X.max(), 200).reshape(-1, 1)

# Make prediction on these inputs
yy = model.predict(xx)

# Plot a line connecting these points
plt.plot(xx, yy, color='blue', label='prediction', linewidth=2)

# Plot train points with black color
plt.scatter(X_train, Y_train, color='black', label='train')

# Plot test points with red color
plt.scatter(X_test, Y_test, color='red', label='test')

# Add the legend
plt.legend()

# Show the plot
plt.show()