In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.experimental import (
    enable_iterative_imputer,
)
from sklearn import (
    preprocessing,
    impute,
)
from sklearn.model_selection import (
    train_test_split,
)

In [2]:
# Location of the Titanic workbook, relative to the working directory.
url = "datasets/titanic3.xls"
# Both names are bound to the same loaded frame; orig_df is an alias of df,
# not an independent copy.
orig_df = df = pd.read_excel(url)

In [3]:
def tweak_titanic(df):
    """Drop unused columns and dummy-encode what is left.

    The columns ``name``, ``ticket``, ``home.dest``, ``boat``, ``body`` and
    ``cabin`` are removed, then the remaining frame is passed through
    ``pd.get_dummies`` with ``drop_first=True`` so each encoded column loses
    its first level. The input frame itself is not modified.

    Returns the new, encoded DataFrame.
    """
    unused_cols = [
        "name",
        "ticket",
        "home.dest",
        "boat",
        "body",
        "cabin",
    ]
    trimmed = df.drop(columns=unused_cols)
    return pd.get_dummies(trimmed, drop_first=True)


def get_train_test_X_y(
    df,
    y_col,
    size=0.3,
    std_cols=None,
    num_cols=None,
    random_state=42,
):
    """Split ``df`` into train/test features and target, then impute and scale.

    The imputer and the scaler are both fitted on the training split only and
    then applied to the test split.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the feature columns together with ``y_col``.
    y_col : str
        Name of the target column; it is removed from the features.
    size : float, default 0.3
        Forwarded to ``train_test_split`` as ``test_size``.
    std_cols : sequence of str, optional
        Columns to standardize with ``StandardScaler``. ``None`` or an empty
        sequence skips scaling. These columns are returned as ``float64``.
    num_cols : sequence of str, optional
        Columns run through ``IterativeImputer``. ``None`` selects
        ``pclass``, ``age``, ``sibsp``, ``parch`` and ``fare``.
    random_state : int, default 42
        Seed forwarded to ``train_test_split``.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.
    """
    if num_cols is None:
        num_cols = ["pclass", "age", "sibsp", "parch", "fare"]
    else:
        num_cols = list(num_cols)
    # Normalize to a list so that an Index/ndarray argument cannot raise the
    # ambiguous-truth-value error a bare ``if std_cols:`` would hit.
    std_cols = [] if std_cols is None else list(std_cols)

    y = df[y_col]
    X = df.drop(columns=y_col)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=size, random_state=random_state
    )
    # The split frames are derived from X; copy them so the writes below land
    # in independent frames and cannot trigger SettingWithCopyWarning.
    X_train = X_train.copy()
    X_test = X_test.copy()

    if num_cols:
        fi = impute.IterativeImputer()
        X_train.loc[:, num_cols] = fi.fit_transform(
            X_train[num_cols]
        )
        X_test.loc[:, num_cols] = fi.transform(X_test[num_cols])

    if std_cols:
        # Standardized values are non-integral, so integer columns must be
        # widened first; writing floats into an int column in place is a
        # deprecated silent upcast in recent pandas releases.
        float_cast = {col: "float64" for col in std_cols}
        X_train = X_train.astype(float_cast)
        X_test = X_test.astype(float_cast)

        std = preprocessing.StandardScaler()
        X_train.loc[:, std_cols] = std.fit_transform(
            X_train[std_cols]
        )
        X_test.loc[:, std_cols] = std.transform(X_test[std_cols])

    return X_train, X_test, y_train, y_test

In [4]:
# Encode the raw frame, then split it; the four listed columns are standardized.
ti_df = tweak_titanic(orig_df)
std_cols = ["pclass", "age", "sibsp", "fare"]
X_train, X_test, y_train, y_test = get_train_test_X_y(
    ti_df, y_col="survived", std_cols=std_cols
)

In [5]:
# Display the training features returned by get_train_test_X_y.
X_train

Unnamed: 0,pclass,age,sibsp,parch,fare,sex_male,embarked_Q,embarked_S
1214,0.825248,-0.167248,-0.498616,0,-0.473625,True,False,True
677,0.825248,-0.205255,-0.498616,0,-0.488146,True,False,True
534,-0.363317,-0.751526,-0.498616,0,-0.145246,False,False,True
1174,0.825248,-2.153148,6.897852,2,0.679608,False,False,True
864,0.825248,-0.049178,-0.498616,0,-0.490434,False,False,True
...,...,...,...,...,...,...,...,...
1095,0.825248,-0.166508,-0.498616,0,-0.493196,False,True,False
1130,0.825248,-0.829564,-0.498616,0,-0.490434,False,False,True
1294,0.825248,-0.010159,-0.498616,0,-0.332756,True,False,True
860,0.825248,-0.205255,-0.498616,0,-0.487593,False,False,True
