In [1]:
import re
import string

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import FunctionTransformer, KBinsDiscretizer, OneHotEncoder, OrdinalEncoder, StandardScaler, TargetEncoder

In [3]:
# Name of the label column inside train.csv.
target = "Survived"

# Both CSV files are read with PassengerId as the row index.
train = pd.read_csv("../data/raw/train.csv", index_col="PassengerId")
X_test = pd.read_csv("../data/raw/test.csv", index_col="PassengerId")

# Separate the feature columns from the label; `train` itself is left untouched.
X_train, y_train = train.drop(columns=target), train[target]

In [4]:
# Display the training features (the label column has already been dropped).
X_train

Unnamed: 0_level_0,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.2500,,S
2,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
3,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.9250,,S
4,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1000,C123,S
5,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.0500,,S
...,...,...,...,...,...,...,...,...,...,...
887,2,"Montvila, Rev. Juozas",male,27.0,0,0,211536,13.0000,,S
888,1,"Graham, Miss. Margaret Edith",female,19.0,0,0,112053,30.0000,B42,S
889,3,"Johnston, Miss. Catherine Helen ""Carrie""",female,,1,2,W./C. 6607,23.4500,,S
890,1,"Behr, Mr. Karl Howell",male,26.0,0,0,111369,30.0000,C148,C


In [5]:
class WithinClassImputer(TransformerMixin, BaseEstimator):
    """Impute one column with a statistic computed within groups of other columns.

    ``fit`` learns the mean or median of ``target_col`` for every combination of
    values found in ``class_cols`` (or over the whole column when no class
    columns are given). ``transform`` replaces each missing entry of
    ``target_col`` with the statistic of the group its row belongs to.

    Parameters
    ----------
    target_col : column label
        Column of the input DataFrame to impute.
    missing_values : scalar or list-like, default=np.nan
        Placeholder(s) treated as missing. A null-like scalar (``np.nan``,
        ``None``, ``pd.NA``) matches every null entry.
    class_cols : list of column labels, str or None, default=None
        Columns whose value combinations define the groups. ``None`` or an
        empty list means a single global statistic is used.
    strategy : {"mean", "median"}, default="mean"
        Statistic used as the fill value.

    Attributes
    ----------
    statistics_ : dict or scalar
        Mapping from group key (a scalar for one class column, a tuple for
        several) to the fitted statistic, or a single scalar when there are no
        class columns.
    default_statistic_ : scalar
        Statistic of the whole column; used for rows whose group was not seen
        during ``fit`` or whose group had no observed values.
    """

    def __init__(self, target_col, *, missing_values=np.nan, class_cols=None, strategy="mean"):
        # Parameters are stored exactly as given: scikit-learn's get_params/clone
        # require that __init__ neither validates nor rewrites its arguments.
        self.target_col = target_col
        self.missing_values = missing_values
        self.class_cols = class_cols
        self.strategy = strategy

    def _resolve_class_cols(self):
        """Return ``class_cols`` as a list (``None`` -> ``[]``, a single label -> one-item list)."""
        if self.class_cols is None:
            return []
        if isinstance(self.class_cols, str):
            return [self.class_cols]
        return list(self.class_cols)

    @staticmethod
    def _missing_mask(values, missing_values):
        """Return a boolean Series marking the entries of ``values`` that count as missing."""
        if pd.api.types.is_list_like(missing_values):
            return values.isin(list(missing_values))
        if pd.isna(missing_values):
            return values.isna()
        return values == missing_values

    def fit(self, X, y=None, **fit_params):
        """Learn the per-group (and overall) statistic of ``target_col`` from DataFrame ``X``.

        Raises
        ------
        ValueError
            If ``strategy`` is neither "mean" nor "median".
        """
        if self.strategy not in ("mean", "median"):
            raise ValueError("Invalid strategy")

        class_cols = self._resolve_class_cols()
        target = X[self.target_col]
        # Blank out the placeholders so they never contribute to the statistics.
        observed = target.mask(self._missing_mask(target, self.missing_values))

        self.default_statistic_ = getattr(observed, self.strategy)()
        if class_cols:
            grouped = observed.groupby([X[col] for col in class_cols])
            self.statistics_ = getattr(grouped, self.strategy)().to_dict()
        else:
            self.statistics_ = self.default_statistic_
        return self

    def transform(self, X):
        """Return ``target_col`` of ``X`` with missing entries filled, as an ``(n_rows, 1)`` array."""
        class_cols = self._resolve_class_cols()
        values = X[self.target_col]
        missing = self._missing_mask(values, self.missing_values)

        if class_cols:
            if len(class_cols) == 1:
                row_keys = X[class_cols[0]].tolist()
            else:
                row_keys = zip(*(X[col].tolist() for col in class_cols))
            # Unseen groups (and NaN keys) look up as NaN and then fall back to
            # the overall statistic instead of raising or staying missing.
            looked_up = [self.statistics_.get(key, np.nan) for key in row_keys]
            fill = pd.Series(looked_up, index=values.index, dtype="float64")
            fill = fill.fillna(self.default_statistic_)
        else:
            fill = self.statistics_

        # Only entries flagged as missing are replaced; observed values pass through.
        imputed = values.mask(missing, fill)
        return imputed.to_numpy().reshape(-1, 1)

    def get_feature_names_out(self, input_features=None):
        """Return the single output column name (the imputed ``target_col``)."""
        return [self.target_col]

In [6]:
def cabin_to_deck(x):
    """Map a Series of cabin codes to a one-column ``Deck`` frame of grouped deck labels.

    The deck is the first character of the cabin code; null cabins become "M".
    Decks A, B, C and T are merged into "ABC", D and E into "DE", F and G into
    "FG". Any other letter is kept as is.
    """
    deck_groups = {
        "A": "ABC", "B": "ABC", "C": "ABC", "T": "ABC",
        "D": "DE", "E": "DE",
        "F": "FG", "G": "FG",
    }
    deck_letter = x.apply(lambda cabin: "M" if pd.isnull(cabin) else cabin[0])
    deck = deck_letter.replace(deck_groups)
    return deck.to_frame(name="Deck")

In [7]:
def calc_family_size(x):
    """Bucket family size (SibSp + Parch + 1) into a one-column ``FamilySizeGroup`` frame.

    Sizes fall into right-closed bins: 1 -> "Alone", 2-4 -> "Small",
    5-6 -> "Medium", 7 and above -> "Large".
    """
    size_edges = [0, 1, 4, 6, np.inf]
    size_labels = ["Alone", "Small", "Medium", "Large"]

    family_size = x["SibSp"] + x["Parch"] + 1  # +1 counts the passenger
    size_group = pd.cut(family_size, bins=size_edges, labels=size_labels)
    return size_group.to_frame(name="FamilySizeGroup")

In [8]:
def create_title(x):
    """Extract the honorific from names shaped like "Surname, Title. Given names".

    The title is the text between the first ", " and the following ".".
    Female titles are merged into "Mrs/Ms/Miss" and professional, military,
    noble and clerical titles into "Dr/Military/Noble/Clergy"; any other title
    is kept unchanged. Returns a one-column ``Title`` frame.
    """
    female_titles = ["Miss", "Mrs", "Ms", "Mlle", "Lady", "Mme", "the Countess", "Dona"]
    rare_titles = ["Dr", "Col", "Major", "Jonkheer", "Capt", "Sir", "Don", "Rev"]
    title_groups = {
        **dict.fromkeys(female_titles, "Mrs/Ms/Miss"),
        **dict.fromkeys(rare_titles, "Dr/Military/Noble/Clergy"),
    }

    after_surname = x.str.split(", ", expand=True)[1]
    raw_title = after_surname.str.split(".", expand=True)[0]
    title = raw_title.replace(title_groups)
    return title.to_frame(name="Title")

In [9]:
def create_family(x):
    """Extract the surname (text before the first ", ") with ASCII punctuation removed.

    Returns a one-column ``Family`` frame.
    """
    punctuation_class = f"[{re.escape(string.punctuation)}]"
    surname = x.str.split(", ", expand=True)[0]
    cleaned = surname.str.replace(punctuation_class, "", regex=True)
    return cleaned.to_frame(name="Family")

In [10]:
# Stage 1 - fill missing values.
# Age and Fare get the median of their class group and Embarked gets its most
# frequent value. The grouping columns are listed again under "pass" so they are
# passed through alongside the imputed columns.
imputation = ColumnTransformer(
    transformers=[
        (
            "impute-median-age",
            WithinClassImputer("Age", class_cols=["Sex", "Pclass"], strategy="median"),
            ["Age", "Sex", "Pclass"],
        ),
        (
            "impute-median-fare",
            WithinClassImputer("Fare", class_cols=["Pclass", "SibSp", "Parch"], strategy="median"),
            ["Fare", "Pclass", "SibSp", "Parch"],
        ),
        (
            "impute-mode-embarked",
            SimpleImputer(strategy="most_frequent"),
            ["Embarked"],
        ),
        (
            "pass",
            "passthrough",
            ["Sex", "Pclass", "SibSp", "Parch"],
        ),
    ],
    remainder="passthrough",
    verbose_feature_names_out=False,
    force_int_remainder_cols=False,
).set_output(transform="pandas")

In [11]:
# Stage 2 - derive new features.
# Cabin -> Deck, Name -> Family and Title, SibSp + Parch -> FamilySizeGroup.
# Columns not listed here are passed through unchanged.
creation = ColumnTransformer(
    transformers=[
        (
            "create-deck",
            FunctionTransformer(cabin_to_deck),
            "Cabin",
        ),
        (
            "create-family",
            FunctionTransformer(create_family),
            "Name",
        ),
        (
            "create-family-size",
            FunctionTransformer(calc_family_size),
            ["SibSp", "Parch"],
        ),
        (
            "create-title",
            FunctionTransformer(create_title),
            "Name",
        ),
    ],
    remainder="passthrough",
    verbose_feature_names_out=False,
    force_int_remainder_cols=False,
).set_output(transform="pandas")

In [12]:
# Stage 3 - discretize the continuous columns into quantile bins encoded as
# ordinal codes: 10 bins for Age, 13 bins for Fare. Other columns pass through.
discretization = ColumnTransformer(
    transformers=[
        (
            "discretize-age",
            KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile"),
            ["Age"],
        ),
        (
            "discretize-fare",
            KBinsDiscretizer(n_bins=13, encode="ordinal", strategy="quantile"),
            ["Fare"],
        ),
    ],
    remainder="passthrough",
    verbose_feature_names_out=False,
    force_int_remainder_cols=False,
).set_output(transform="pandas")

In [13]:
# Stage 4 - encode every remaining feature numerically.
# `remainder` is not set here, so under ColumnTransformer's default
# (remainder="drop") columns not listed below, e.g. Ticket, are not carried forward.
encoding = ColumnTransformer(
    [
        # TargetEncoder cross-fits on shuffled folds during fit_transform. Without a
        # fixed random_state the encoded training values differ on every run, so it
        # is seeded with the same value the model uses.
        ("encode-family", TargetEncoder(target_type="binary", random_state=1234), ["Family"]),
        ("encode-ordinal", OrdinalEncoder(), ["Pclass", "Age", "Fare"]),
        ("encode-onehot", OneHotEncoder(handle_unknown="warn", sparse_output=False), ["Sex", "Deck", "Embarked", "Title", "FamilySizeGroup"]),
    ],
    verbose_feature_names_out=False,
    force_int_remainder_cols=False,
).set_output(transform="pandas")

In [None]:
# Random forest classifier with hand-set hyper-parameters. random_state is fixed
# so repeated fits are reproducible, and oob_score=True requests an out-of-bag
# score. RandomForestClassifier comes from sklearn.ensemble and must be imported
# with the other sklearn imports at the top of the notebook.
model = RandomForestClassifier(
    criterion="gini",
    n_estimators=1750,
    max_depth=7,
    min_samples_split=6,
    min_samples_leaf=6,
    max_features="sqrt",
    oob_score=True,
    random_state=1234,
    n_jobs=-1,
)

In [14]:
# Full preprocessing pipeline: the four column-wise stages run in order and are
# followed by standardization of every resulting feature. Output stays a DataFrame.
pipe = Pipeline(
    steps=[
        ("imputation", imputation),          # fill missing Age / Fare / Embarked
        ("creation", creation),              # derive Deck, Family, FamilySizeGroup, Title
        ("discretization", discretization),  # quantile-bin Age and Fare
        ("encoding", encoding),              # target / ordinal / one-hot encoding
        ("scaling", StandardScaler()),
    ],
).set_output(transform="pandas")

In [15]:
# Display the assembled preprocessing pipeline.
pipe

In [16]:
# Fit every preprocessing stage on the training data and show the transformed features.
# y_train is passed because the `encoding` stage contains a TargetEncoder.
pipe.fit_transform(X_train, y_train)

Unnamed: 0_level_0,Family,Pclass,Age,Fare,Sex_female,Sex_male,Deck_ABC,Deck_DE,Deck_FG,Deck_M,...,Embarked_Q,Embarked_S,Title_Dr/Military/Noble/Clergy,Title_Master,Title_Mr,Title_Mrs/Ms/Miss,FamilySizeGroup_Alone,FamilySizeGroup_Large,FamilySizeGroup_Medium,FamilySizeGroup_Small
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,-1.489517,0.827377,-0.576950,-1.419604,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,-1.231645,-0.169907,-0.208148,1.432260
2,-0.019876,-1.566107,0.827592,1.331303,1.355574,-1.355574,2.510633,-0.280522,-0.139466,-1.835115,...,-0.307562,-1.623803,-0.155364,-0.216803,-1.175735,1.358913,-1.231645,-0.169907,-0.208148,1.432260
3,-0.019876,0.827377,0.125321,-0.869422,1.355574,-1.355574,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,0.811922,-0.169907,-0.208148,-0.698197
4,-1.489517,-1.566107,0.827592,1.056212,1.355574,-1.355574,2.510633,-0.280522,-0.139466,-1.835115,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,-1.231645,-0.169907,-0.208148,1.432260
5,2.334767,0.827377,0.827592,-0.594332,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
887,-0.019876,-0.369365,0.125321,-0.044150,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,6.436503,-0.216803,-1.175735,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
888,2.334767,-1.566107,-1.279221,0.781122,1.355574,-1.355574,2.510633,-0.280522,-0.139466,-1.835115,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,0.811922,-0.169907,-0.208148,-0.698197
889,-1.489517,0.827377,-0.928085,0.230940,1.355574,-1.355574,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,-1.231645,-0.169907,-0.208148,1.432260
890,-0.019876,-1.566107,0.125321,0.781122,-0.737695,0.737695,2.510633,-0.280522,-0.139466,-1.835115,...,-0.307562,-1.623803,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197


In [19]:
# Apply the already-fitted pipeline to the test features (transform only, no refit).
pipe.transform(X_test)

Unnamed: 0_level_0,Family,Pclass,Age,Fare,Sex_female,Sex_male,Deck_ABC,Deck_DE,Deck_FG,Deck_M,...,Embarked_Q,Embarked_S,Title_Dr/Military/Noble/Clergy,Title_Master,Title_Mr,Title_Mrs/Ms/Miss,FamilySizeGroup_Alone,FamilySizeGroup_Large,FamilySizeGroup_Medium,FamilySizeGroup_Small
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
892,1.147067,0.827377,0.827592,-1.144513,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,3.251373,-1.623803,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
893,-0.021610,0.827377,1.529863,-1.694695,1.355574,-1.355574,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,-1.231645,-0.169907,-0.208148,1.432260
894,-0.021610,-0.369365,1.529863,-0.594332,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,3.251373,-1.623803,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
895,-0.021610,0.827377,0.125321,-0.594332,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
896,2.334767,0.827377,-0.576950,-0.319241,1.355574,-1.355574,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,-1.175735,1.358913,-1.231645,-0.169907,-0.208148,1.432260
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1305,-0.021610,0.827377,-0.225814,-0.594332,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
1306,-0.021610,-1.566107,0.827592,1.606394,1.355574,-1.355574,2.510633,-0.280522,-0.139466,-1.835115,...,-0.307562,-1.623803,-0.155364,-0.216803,-1.175735,1.358913,0.811922,-0.169907,-0.208148,-0.698197
1307,-0.021610,0.827377,0.827592,-1.419604,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
1308,-0.021610,0.827377,-0.225814,-0.594332,-0.737695,0.737695,-0.398306,-0.280522,-0.139466,0.544925,...,-0.307562,0.615838,-0.155364,-0.216803,0.850532,-0.735882,0.811922,-0.169907,-0.208148,-0.698197
