# Cross-Validation

![k-folds](images/k_folds.png)

## Import Data and Tools

In [2]:
# Dataset loader from seaborn, and the scikit-learn pieces used in this notebook.
from seaborn import load_dataset
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import (
    KFold,
    cross_val_score,
    train_test_split,
)

# Load the diamonds example dataset by name; later cells treat it as a DataFrame.
diamonds = load_dataset('diamonds')

In [3]:
# Peek at the first few rows to see which columns are available.
diamonds.head()

Unnamed: 0,carat,cut,color,clarity,depth,table,price,x,y,z
0,0.23,Ideal,E,SI2,61.5,55.0,326,3.95,3.98,2.43
1,0.21,Premium,E,SI1,59.8,61.0,326,3.89,3.84,2.31
2,0.23,Good,E,VS1,56.9,65.0,327,4.05,4.07,2.31
3,0.29,Premium,I,VS2,62.4,58.0,334,4.2,4.23,2.63
4,0.31,Good,J,SI2,63.3,58.0,335,4.34,4.35,2.75


## Set Variables

In [4]:
# Features: keep only the float-typed columns of the dataset.
# NOTE(review): price is displayed as an integer in the preview above, so it
# should not be picked up as a feature here — confirm with X.columns.
X = diamonds.select_dtypes(include=float)
# Target: the price column.
y = diamonds['price']

## Split

In [5]:
# Hold out a test set; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

## Linear Regression with Cross-Validation

In [6]:
# Linear regression with default settings; it is handed, unfitted, to
# cross_val_score in the cells that follow.
lr = LinearRegression()

In [7]:
# Score the linear model with 3-fold cross-validation on the training split;
# one score comes back per fold.
cross_val_score(estimator=lr, X=X_train, y=y_train, cv=3)

array([0.85771951, 0.84608522, 0.86076541])

## Building Splits with `KFold`

In [8]:
# split() only builds a generator here; no indices are produced until it is
# iterated (see the loop in the next cell).
KFold(n_splits=3).split(X_train)

<generator object _BaseKFold.split at 0x12d8b48b8>

In [9]:
# Each item yielded by split() is a (train indices, test indices) pair of
# positional row indices into X_train; unpack it directly in the loop header.
for train_indices, test_indices in KFold(n_splits=3).split(X_train):
    print(train_indices, test_indices)

[13485 13486 13487 ... 40452 40453 40454] [    0     1     2 ... 13482 13483 13484]
[    0     1     2 ... 40452 40453 40454] [13485 13486 13487 ... 26967 26968 26969]
[    0     1     2 ... 26967 26968 26969] [26970 26971 26972 ... 40452 40453 40454]


## Passing into `cross_val_score()`

In [10]:
# An explicit KFold splitter can be supplied as `cv`; with three unshuffled
# folds it reproduces the scores obtained earlier with cv=3.
cross_val_score(
    estimator=lr,
    X=X_train,
    y=y_train,
    cv=KFold(n_splits=3),
)

array([0.85771951, 0.84608522, 0.86076541])

## Shuffling Before Splitting

In [11]:
# Shuffle the rows before cutting ten folds; the fixed seed keeps the fold
# assignment repeatable between runs.
cross_val_score(
    estimator=lr,
    X=X_train,
    y=y_train,
    cv=KFold(n_splits=10, shuffle=True, random_state=42),
)

array([0.78772743, 0.84104572, 0.86252499, 0.8744147 , 0.85274907,
       0.85587413, 0.86258267, 0.85832762, 0.86044553, 0.86060175])