# Imports

In [1]:
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets, linear_model
from scipy.linalg import svd
import sklearn.cluster as cluster
import plotly.express as px
import plotly.graph_objects as go

np.random.seed(1234)

In [2]:
warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", None)
# mpl.rcParams.update({"axes.grid": True})

# 9.1 Polynomial Features

In [3]:
def get_vehicle_data():
    vehicle_data = pd.read_csv("../08/data/auto-mpg.csv")
    vehicle_data["horsepower"] = pd.to_numeric(
        vehicle_data["horsepower"], errors="coerce"
    )
    vehicle_data.rename(columns={"horsepower": "hp"}, inplace=True)
    vehicle_data.dropna(inplace=True)
    return vehicle_data

In [4]:
vehicle_data = get_vehicle_data()
vehicle_data.head()

Unnamed: 0,mpg,cylinders,displacement,hp,weight,acceleration,model year,origin,car name
0,18.0,8,307.0,130.0,3504,12.0,70,1,chevrolet chevelle malibu
1,15.0,8,350.0,165.0,3693,11.5,70,1,buick skylark 320
2,18.0,8,318.0,150.0,3436,11.0,70,1,plymouth satellite
3,16.0,8,304.0,150.0,3433,12.0,70,1,amc rebel sst
4,17.0,8,302.0,140.0,3449,10.5,70,1,ford torino


In [9]:
from sklearn.preprocessing import PolynomialFeatures

vehicle_data = get_vehicle_data()
poly_transform = PolynomialFeatures(degree=2, include_bias=False)
vehicle_data_with_squared_features = pd.DataFrame(
    poly_transform.fit_transform(vehicle_data[["hp", "weight", "displacement"]]),
    columns=poly_transform.get_feature_names_out(),
)

vehicle_data_with_squared_features

Unnamed: 0,hp,weight,displacement,hp^2,hp weight,hp displacement,weight^2,weight displacement,displacement^2
0,130.0,3504.0,307.0,16900.0,455520.0,39910.0,12278016.0,1075728.0,94249.0
1,165.0,3693.0,350.0,27225.0,609345.0,57750.0,13638249.0,1292550.0,122500.0
2,150.0,3436.0,318.0,22500.0,515400.0,47700.0,11806096.0,1092648.0,101124.0
3,150.0,3433.0,304.0,22500.0,514950.0,45600.0,11785489.0,1043632.0,92416.0
4,140.0,3449.0,302.0,19600.0,482860.0,42280.0,11895601.0,1041598.0,91204.0
...,...,...,...,...,...,...,...,...,...
387,86.0,2790.0,140.0,7396.0,239940.0,12040.0,7784100.0,390600.0,19600.0
388,52.0,2130.0,97.0,2704.0,110760.0,5044.0,4536900.0,206610.0,9409.0
389,84.0,2295.0,135.0,7056.0,192780.0,11340.0,5267025.0,309825.0,18225.0
390,79.0,2625.0,120.0,6241.0,207375.0,9480.0,6890625.0,315000.0,14400.0
