# Predict North America Sales with Linear Regression

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Load the video-game sales table; the file is comma-separated, which is
# what read_csv assumes when no separator is given.
df = pd.read_csv("vgsales.csv")

# Preview the first rows to check that the columns parsed as expected
df.head()

Unnamed: 0,Rank,Name,Platform,Year,Genre,Publisher,NA_Sales,EU_Sales,JP_Sales,Other_Sales,Global_Sales
0,1,Wii Sports,Wii,2006.0,Sports,Nintendo,41.49,29.02,3.77,8.46,82.74
1,2,Super Mario Bros.,NES,1985.0,Platform,Nintendo,29.08,3.58,6.81,0.77,40.24
2,3,Mario Kart Wii,Wii,2008.0,Racing,Nintendo,15.85,12.88,3.79,3.31,35.82
3,4,Wii Sports Resort,Wii,2009.0,Sports,Nintendo,15.75,11.01,3.28,2.96,33.0
4,5,Pokemon Red/Pokemon Blue,GB,1996.0,Role-Playing,Nintendo,11.27,8.89,10.22,1.0,31.37


In [3]:
# Restrict to numeric columns before computing pairwise correlations
numeric_df = df.select_dtypes(include='number')
correlation_matrix = numeric_df.corr()

# Styled view of the matrix: axis=None applies one colour scale to the whole
# table, and values are shown with two decimals.
formatted_correlation = (
    correlation_matrix.style
    .background_gradient(cmap='coolwarm', axis=None)
    .format(precision=2)
)

# Last expression of the cell, so the styled matrix is rendered as its output
formatted_correlation

Unnamed: 0,Rank,Year,NA_Sales,EU_Sales,JP_Sales,Other_Sales,Global_Sales
Rank,1.0,0.18,-0.4,-0.38,-0.27,-0.33,-0.43
Year,0.18,1.0,-0.09,0.01,-0.17,0.04,-0.07
NA_Sales,-0.4,-0.09,1.0,0.77,0.45,0.63,0.94
EU_Sales,-0.38,0.01,0.77,1.0,0.44,0.73,0.9
JP_Sales,-0.27,-0.17,0.45,0.44,1.0,0.29,0.61
Other_Sales,-0.33,0.04,0.63,0.73,0.29,1.0,0.75
Global_Sales,-0.43,-0.07,0.94,0.9,0.61,0.75,1.0


In [None]:
# Select the model input and target. EU_Sales is the regional sales column
# most correlated with NA_Sales in the matrix above (0.77); Global_Sales
# correlates more strongly (0.94) but is not used as an input here.
X = df[["EU_Sales"]] # (n_rows, 1) DataFrame: a single input feature, kept 2-D via the double brackets
y = df[["NA_Sales"]] # (n_rows, 1) DataFrame: the target (output) column
# Display the input shape as the cell output
X.shape

In [None]:
# Fit a scikit-learn linear regression of y on X.
from sklearn import linear_model

# fit() returns the estimator itself, so construction and fitting are chained
regr = linear_model.LinearRegression().fit(X, y)

# Show the fitted estimator as the cell output
regr

In [None]:
# Fitted parameters: theta_0 is the intercept and theta_1 the coefficient
# on the single input column (EU_Sales).
print('theta_0', regr.intercept_, 'theta_1', regr.coef_)

# Predict on the same data the model was fitted on
pred = regr.predict(X)

# Visualize how the fitted line follows the training data.
# Axis labels and title name the columns actually plotted (X is EU_Sales,
# y is NA_Sales); the previous labels described an unrelated
# population/profit dataset.
plt.scatter(X, y)
plt.plot(X, pred, color="r")
plt.xlabel('EU_Sales')
plt.ylabel('NA_Sales')
plt.title('NA_Sales as a function of EU_Sales')
plt.show()