In [134]:
%load_ext autoreload
%autoreload 2

# Autoreload matters here: without it, edits to the project's .py modules imported below are not picked up until the kernel is restarted.

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import data_loader

from sklearn.model_selection import  train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

import common.columns as columns
from common.plots import plot_survived_by_category
from common.selectors import get_survived_counts

from features.title_adder import TitleAdder
from features.person_type_adder import (
  PersonTypeAdder,
  CHILD_TYPE,
  MAN_TYPE,
  WOMAN_TYPE
)

from features.column_dropper import ColumnDropper

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [135]:
loader = data_loader.DataLoader()

# get_data() returns a pair; only the first frame (the one carrying the target column) is used here.
train_set, _ = loader.get_data()

# Split off the target: `labels` holds it as a one-column DataFrame, `train_set` keeps every other column.
# NOTE(review): many scikit-learn estimators expect a 1-D target and warn on a column-vector DataFrame;
# consider selecting a Series (`train_set[columns.SURVIVED]`) if downstream cells allow it.
labels = train_set[[columns.SURVIVED]]
train_set = train_set.drop(columns=[columns.SURVIVED])

In [136]:
# Hold out 25% of the labelled rows for evaluation (scikit-learn's default ratio, spelled out);
# the fixed seed keeps the split reproducible across Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(
  train_set,
  labels,
  test_size=0.25,
  random_state=42,
)

In [137]:
# Numerical features: fill missing values with the column mean, then standardize.
numerical_pipeline = Pipeline([
  ('imputer', SimpleImputer(strategy='mean')),
  ('scaler', StandardScaler())
])

# Categorical features: mark missing values with an explicit 'N/A' category, then one-hot encode.
# handle_unknown='ignore' encodes categories unseen during fit as all zeros instead of raising.
# The encoder keeps its default output type on purpose: the `sparse` keyword was renamed to
# `sparse_output` in scikit-learn 1.2 and removed in 1.4, so passing it fails on newer versions.
# Dense output is requested once, on the ColumnTransformer below, which works on every version.
categorical_pipeline = Pipeline([
  ('imputer', SimpleImputer(strategy='constant', fill_value='N/A')),
  ('encoder', OneHotEncoder(handle_unknown='ignore'))
])

# Columns not listed here are passed through unchanged (no imputation, no scaling).
# sparse_threshold=0 always yields a dense ndarray, even when the one-hot block is sparse.
transformer = ColumnTransformer([
  ('numerical', numerical_pipeline, [columns.AGE, columns.FARE]),
  ('categorical', categorical_pipeline, [columns.SEX, columns.EMBARKED])
], remainder='passthrough', sparse_threshold=0)

# Remove the Cabin, PassengerId, Name and Ticket columns before the column transformer runs,
# so they do not end up in the passthrough remainder.
preparation_pipeline = Pipeline([
  ('column_dropper', ColumnDropper([columns.CABIN, columns.PASSENGER_ID, columns.NAME, columns.TICKET])),
  ('column_transformer', transformer)
])

In [138]:
X_train_prepared = preparation_pipeline.fit_transform(X_train)
X_train_prepared_df = pd.DataFrame(X_train_prepared)

# Sanity checks: one output row per input row, and no missing values anywhere.
# The passthrough columns are never imputed, so a NaN there would otherwise reach the model silently.
assert X_train_prepared.shape[0] == len(X_train), "preparation pipeline changed the row count"
assert not X_train_prepared_df.isna().any().any(), "missing values remain after preparation"

X_train_prepared_df.head(10)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,0.0,-0.032568,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0
1,-0.34011,-0.487331,0.0,1.0,0.0,0.0,0.0,1.0,3.0,0.0,0.0
2,-0.417034,-0.342854,1.0,0.0,0.0,0.0,0.0,1.0,2.0,0.0,2.0
3,-0.570884,-0.478201,0.0,1.0,0.0,0.0,0.0,1.0,3.0,0.0,0.0
4,-2.192453,2.314937,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,2.0
5,-0.417034,-0.119836,1.0,0.0,0.0,0.0,0.0,1.0,2.0,1.0,0.0
6,-0.263185,-0.470362,1.0,0.0,0.0,0.0,0.0,1.0,3.0,0.0,0.0
7,0.352211,-0.2168,0.0,1.0,0.0,0.0,0.0,1.0,2.0,1.0,0.0
8,-0.647808,4.464151,1.0,0.0,1.0,0.0,0.0,0.0,1.0,2.0,2.0
9,0.0,-0.473756,0.0,1.0,0.0,0.0,1.0,0.0,3.0,0.0,0.0
