In [2]:
import sagemaker
import os
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder

In [3]:
# S3 location for every artifact of this experiment. A fixed, shared
# experimentation bucket is used instead of session.default_bucket().
bucket = 'asurion-ml-experimentation'
prefix = 'custom_preprocessing'

session = sagemaker.session.Session()
region = session.boto_region_name
role = sagemaker.get_execution_role()

In [None]:
# Resource tags attached to SageMaker jobs, as Key/Value pairs.
_tag_values = {
    "PLATFORM": "FO-ML",
    "BUSINESS_REGION": "GLOBAL",
    "BUSINESS_UNIT": "MOBILITY",
    "CLIENT": "MULTI_TENANT",
}
tags = [{"Key": tag_key, "Value": tag_value} for tag_key, tag_value in _tag_values.items()]

In [5]:
# Build the data URI from the configured bucket/prefix so the load location
# follows the config cell instead of a duplicated hard-coded path.
data_uri = "s3://{}/{}/data/1_full_data.csv".format(bucket, prefix)
df = pd.read_csv(data_uri)
df.head()

Unnamed: 0,zip_agg_customer_subtype,zip_agg_number_of_houses,zip_agg_avg_size_household,zip_agg_avg_age,zip_agg_customer_main_type,zip_agg_roman_catholic,zip_agg_protestant,zip_agg_other_religion,zip_agg_no_religion,zip_agg_married,...,nbr_private_accident_ins_policies,nbr_family_accidents_ins_policies,nbr_disability_ins_policies,nbr_fire_policies,nbr_surfboard_policies,nbr_boat_policies,nbr_bicycle_policies,nbr_property_ins_policies,nbr_ss_ins_policies,nbr_mobile_home_policies
0,Lower class large families,1,3,2,Family with grown ups,0,5,1,3,7,...,0,0,0,1,0,0,0,0,0,0
1,Mixed small town dwellers,1,2,2,Family with grown ups,1,4,1,4,6,...,0,0,0,1,0,0,0,0,0,0
2,Mixed small town dwellers,1,2,2,Family with grown ups,0,4,2,4,3,...,0,0,0,1,0,0,0,0,0,0
3,"Modern, complete families",1,3,3,Average Family,2,3,2,4,5,...,0,0,0,1,0,0,0,0,0,0
4,Large family farms,1,4,2,Farmers,1,4,1,4,7,...,0,0,0,1,0,0,0,0,0,0


In [6]:
# Target column and the categorical features that get one-hot encoded.
label_col = 'nbr_mobile_home_policies'
cat_feats = ['zip_agg_customer_subtype', 'zip_agg_customer_main_type']
nbr_cols = df.shape[1]
# Number of raw input features (every column except the label). Computed
# rather than hard-coded; an earlier note recorded 85 for this dataset.
INPUT_FEATURES_SIZE = nbr_cols - 1

# Positional feature/label splits (iloc[:, -1]) assume the label is the last column.
assert df.columns[-1] == label_col, 'expected {!r} to be the last column'.format(label_col)

In [7]:
# Use the fully-qualified option key: the bare 'max_rows' pattern also matches
# 'styler.render.max_rows' in pandas >= 1.4, so set_option raises OptionError.
pd.set_option('display.max_rows', 500)

# Count missing values per column (all columns shown thanks to the option above).
df.isna().sum()

zip_agg_customer_subtype                 0
zip_agg_number_of_houses                 0
zip_agg_avg_size_household               0
zip_agg_avg_age                          0
zip_agg_customer_main_type               0
zip_agg_roman_catholic                   0
zip_agg_protestant                       0
zip_agg_other_religion                   0
zip_agg_no_religion                      0
zip_agg_married                          0
zip_agg_living_together                  0
zip_agg_other_relation                   0
zip_agg_singles                          0
zip_agg_household_without_children       0
zip_agg_household_with_children          0
zip_agg_high_level_education             0
zip_agg_medium_level_education           0
zip_agg_lower_level_education            0
zip_agg_high_status                      0
zip_agg_entrepreneur                     0
zip_agg_farmer                           0
zip_agg_middle_management                0
zip_agg_skilled_labourers                0
zip_agg_uns

In [8]:
# Show the value distribution of every non-categorical column (sorted by name).
numeric_columns = df.columns.difference(cat_feats)
for column in numeric_columns:
    print(column)
    display(df[column].value_counts())
    print('\n')

contri_agricultural_machines_policies


0    9790
4      11
3       8
2       7
6       5
1       1
Name: contri_agricultural_machines_policies, dtype: int64



contri_bicycle_policies


0    9573
1     249
Name: contri_bicycle_policies, dtype: int64



contri_boat_policies


0    9777
4      18
2       8
1       6
3       6
5       4
6       3
Name: contri_boat_policies, dtype: int64



contri_car_policies


0    4825
6    3910
5    1013
7      64
8       5
4       4
9       1
Name: contri_car_policies, dtype: int64



contri_delivery_van_policies


0    9730
6      71
5      17
7       4
Name: contri_delivery_van_policies, dtype: int64



contri_disability_ins_policies


0    9784
6      32
7       4
4       1
5       1
Name: contri_disability_ins_policies, dtype: int64



contri_family_accidents_ins_policies


0    9744
2      50
3      28
Name: contri_family_accidents_ins_policies, dtype: int64



contri_fire_policies


0    4464
4    2142
3    1541
2     901
5     263
6     252
1     245
7      12
8       2
Name: contri_fire_policies, dtype: int64



contri_life_ins


0    9308
4     172
3     141
5      65
6      65
2      51
1      14
7       4
9       1
8       1
Name: contri_life_ins, dtype: int64



contri_lorry_policies


0    9808
6      10
7       2
4       1
9       1
Name: contri_lorry_policies, dtype: int64



contri_moped_policies


0    9150
3     481
4      99
2      63
5      27
6       2
Name: contri_moped_policies, dtype: int64



contri_motorcycle/scooter_policies


0    9460
4     207
6      79
5      70
3       4
7       2
Name: contri_motorcycle/scooter_policies, dtype: int64



contri_private_accident_ins_policies


0    9777
2      24
3       7
1       6
4       4
5       2
6       2
Name: contri_private_accident_ins_policies, dtype: int64



contri_private_third_party_ins


0    5903
2    3562
1     341
3      16
Name: contri_private_third_party_ins, dtype: int64



contri_property_ins_policies


0    9740
1      34
2      30
4       8
3       7
6       2
5       1
Name: contri_property_ins_policies, dtype: int64



contri_ss_ins_policies


0    9687
4      68
3      34
2      31
5       2
Name: contri_ss_ins_policies, dtype: int64



contri_surfboard_policies


0    9813
2       5
1       3
3       1
Name: contri_surfboard_policies, dtype: int64



contri_third_party_ins_(agriculture)


0    9613
4     108
3      92
2       7
1       2
Name: contri_third_party_ins_(agriculture), dtype: int64



contri_third_party_ins_(firms)


0    9688
2      53
3      36
4      26
1       9
6       5
5       5
Name: contri_third_party_ins_(firms), dtype: int64



contri_tractor_policies


0    9576
3     142
4      45
5      42
6      16
7       1
Name: contri_tractor_policies, dtype: int64



contri_trailer_policies


0    9719
2      62
1      30
3       9
5       1
4       1
Name: contri_trailer_policies, dtype: int64



nbr_agricultural_machines_policies


0    9790
1      21
2       7
3       2
6       1
4       1
Name: nbr_agricultural_machines_policies, dtype: int64



nbr_bicycle_policies


0    9573
1     193
2      53
3       2
4       1
Name: nbr_bicycle_policies, dtype: int64



nbr_boat_policies


0    9777
1      40
2       5
Name: nbr_boat_policies, dtype: int64



nbr_car_policies


0     4825
1     4580
2      384
3       21
4        8
7        1
6        1
12       1
5        1
Name: nbr_car_policies, dtype: int64



nbr_delivery_van_policies


0    9730
1      83
2       4
3       3
4       1
5       1
Name: nbr_delivery_van_policies, dtype: int64



nbr_disability_ins_policies


0    9784
1      34
2       4
Name: nbr_disability_ins_policies, dtype: int64



nbr_family_accidents_ins_policies


0    9744
1      78
Name: nbr_family_accidents_ins_policies, dtype: int64



nbr_fire_policies


1    5116
0    4464
2     221
3      11
4       6
5       2
7       1
6       1
Name: nbr_fire_policies, dtype: int64



nbr_life_ins


0    9308
1     305
2     170
3      23
4      13
5       2
8       1
Name: nbr_life_ins, dtype: int64



nbr_lorry_policies


0    9808
1       9
2       3
3       1
4       1
Name: nbr_lorry_policies, dtype: int64



nbr_mobile_home_policies


0    9236
1     586
Name: nbr_mobile_home_policies, dtype: int64



nbr_moped_policies


0    9150
1     647
2      24
3       1
Name: nbr_moped_policies, dtype: int64



nbr_motorcycle/scooter_policies


0    9460
1     337
2      22
3       2
8       1
Name: nbr_motorcycle/scooter_policies, dtype: int64



nbr_private_accident_ins_policies


0    9777
1      45
Name: nbr_private_accident_ins_policies, dtype: int64



nbr_private_third_party_ins


0    5903
1    3909
2      10
Name: nbr_private_third_party_ins, dtype: int64



nbr_property_ins_policies


0    9740
1      81
2       1
Name: nbr_property_ins_policies, dtype: int64



nbr_ss_ins_policies


0    9687
1     134
2       1
Name: nbr_ss_ins_policies, dtype: int64



nbr_surfboard_policies


0    9813
1       9
Name: nbr_surfboard_policies, dtype: int64



nbr_third_party_ins_(agriculture)


0    9613
1     209
Name: nbr_third_party_ins_(agriculture), dtype: int64



nbr_third_party_ins_(firms)


0    9688
1     133
5       1
Name: nbr_third_party_ins_(firms), dtype: int64



nbr_tractor_policies


0    9576
1     184
2      46
3       7
4       6
6       2
5       1
Name: nbr_tractor_policies, dtype: int64



nbr_trailer_policies


0    9719
1      96
2       5
3       2
Name: nbr_trailer_policies, dtype: int64



zip_agg_1_car


6    2822
7    2338
5    2106
9     829
4     740
8     435
3     400
2     102
0      30
1      20
Name: zip_agg_1_car, dtype: int64



zip_agg_2_cars


0    3078
2    2999
1    2454
3     638
4     531
5     105
6      14
7       2
9       1
Name: zip_agg_2_cars, dtype: int64



zip_agg_average_income


3    3232
4    3063
5    1268
2    1110
6     646
7     228
8     121
1      78
9      38
0      38
Name: zip_agg_average_income, dtype: int64



zip_agg_avg_age


3    5154
2    2409
4    1777
5     329
1     104
6      49
Name: zip_agg_avg_age, dtype: int64



zip_agg_avg_size_household


3    4513
2    3616
4    1132
1     452
5     106
6       3
Name: zip_agg_avg_size_household, dtype: int64



zip_agg_entrepreneur


0    7031
1    2009
2     600
5      94
3      70
4      18
Name: zip_agg_entrepreneur, dtype: int64



zip_agg_farmer


0    6985
1    1462
2     815
3     256
4     144
5     108
6      21
8      15
9      13
7       3
Name: zip_agg_farmer, dtype: int64



zip_agg_high_level_education


0    3621
1    2176
2    1921
3     927
4     577
5     338
6     123
7      83
8      39
9      17
Name: zip_agg_high_level_education, dtype: int64



zip_agg_high_status


0    2576
2    2278
1    2119
3    1282
4     641
5     415
6     248
7     159
9      57
8      47
Name: zip_agg_high_status, dtype: int64



zip_agg_home_owners


9    1663
0    1255
7    1185
6     978
5     912
1     871
4     834
8     756
2     712
3     656
Name: zip_agg_home_owners, dtype: int64



zip_agg_household_with_children


4    1983
5    1869
3    1596
6    1322
2    1096
7     601
1     475
8     341
9     296
0     243
Name: zip_agg_household_with_children, dtype: int64



zip_agg_household_without_children


3    2517
4    2493
2    1798
5    1043
0     613
1     608
6     516
7     160
9      47
8      27
Name: zip_agg_household_without_children, dtype: int64



zip_agg_income_30-45.000


4    2207
3    1989
5    1585
2    1569
0     782
6     671
1     486
7     337
9     134
8      62
Name: zip_agg_income_30-45.000, dtype: int64



zip_agg_income_45-75.000


3    2028
2    1962
4    1709
0    1505
1    1132
5     828
6     232
7     170
9     161
8      95
Name: zip_agg_income_45-75.000, dtype: int64



zip_agg_income_75-122.000


0    5464
1    2251
2    1256
3     429
4     278
5     119
6      10
8       7
9       6
7       2
Name: zip_agg_income_75-122.000, dtype: int64



zip_agg_income_<_30.000


0    2164
2    1893
3    1826
1    1060
4     981
5     971
6     512
7     261
9      78
8      76
Name: zip_agg_income_<_30.000, dtype: int64



zip_agg_income_>123.000


0    8253
1    1269
2     188
3      66
4      38
5       4
6       2
9       1
7       1
Name: zip_agg_income_>123.000, dtype: int64



zip_agg_living_together


0    4185
1    3402
2    1790
3     260
4     133
5      32
6      18
7       2
Name: zip_agg_living_together, dtype: int64



zip_agg_lower_level_education


5    1740
6    1472
4    1448
3    1128
2    1099
7    1090
9     531
0     494
8     424
1     396
Name: zip_agg_lower_level_education, dtype: int64



zip_agg_married


7    2800
6    2015
5    1628
9    1345
8     603
4     550
3     402
2     252
1     119
0     108
Name: zip_agg_married, dtype: int64



zip_agg_medium_level_education


4    2394
3    2332
2    1648
5    1227
0     711
1     624
6     526
7     243
9      64
8      53
Name: zip_agg_medium_level_education, dtype: int64



zip_agg_middle_management


2    2508
3    2348
4    1573
0    1164
5     726
1     699
6     348
7     295
9     137
8      24
Name: zip_agg_middle_management, dtype: int64



zip_agg_national_health_service


7    2520
5    1644
6    1451
9    1436
8    1178
4     628
2     518
3     315
0     109
1      23
Name: zip_agg_national_health_service, dtype: int64



zip_agg_no_car


2    2611
0    2475
3    1871
1    1327
4    1020
5     288
6     145
7      40
9      27
8      18
Name: zip_agg_no_car, dtype: int64



zip_agg_no_religion


3    2476
4    2245
2    1778
5    1572
0     773
6     394
1     378
7     179
9      14
8      13
Name: zip_agg_no_religion, dtype: int64



zip_agg_number_of_houses


1     8915
2      821
3       64
7        8
4        4
5        3
6        3
10       2
8        2
Name: zip_agg_number_of_houses, dtype: int64



zip_agg_other_relation


2    2944
0    1981
3    1965
4    1140
1     900
5     421
6     299
7     100
9      41
8      31
Name: zip_agg_other_relation, dtype: int64



zip_agg_other_religion


0    3460
1    3391
2    2294
3     415
4     220
5      42
Name: zip_agg_other_religion, dtype: int64



zip_agg_private_health_insurance


2    2520
4    1668
0    1436
3    1415
1    1178
5     639
7     518
6     316
9     109
8      23
Name: zip_agg_private_health_insurance, dtype: int64



zip_agg_protestant


4    2666
5    2533
6    1180
3    1022
7     995
2     666
9     290
1     225
0     127
8     118
Name: zip_agg_protestant, dtype: int64



zip_agg_purchasing_power_class


3    2556
6    1587
4    1539
5     964
1     938
7     777
2     731
8     730
Name: zip_agg_purchasing_power_class, dtype: int64



zip_agg_rented_house


0    1663
9    1255
2    1177
3     961
4     908
8     874
5     863
1     755
7     710
6     656
Name: zip_agg_rented_house, dtype: int64



zip_agg_roman_catholic


0    5420
1    2744
2    1213
3     243
4     123
5      30
6      25
7      11
9      10
8       3
Name: zip_agg_roman_catholic, dtype: int64



zip_agg_singles


0    2916
2    2143
1    1619
3    1439
4     890
5     416
6     222
7     107
8      36
9      34
Name: zip_agg_singles, dtype: int64



zip_agg_skilled_labourers


2    2327
0    1995
3    1967
1    1523
4    1006
5     524
6     299
7     119
8      38
9      24
Name: zip_agg_skilled_labourers, dtype: int64



zip_agg_social_class_a


0    2871
1    2626
2    2056
3    1168
4     452
5     228
6     152
7     147
9     106
8      16
Name: zip_agg_social_class_a, dtype: int64



zip_agg_social_class_b1


2    3009
1    2549
0    2275
3    1300
4     459
5     129
6      53
9      26
8      13
7       9
Name: zip_agg_social_class_b1, dtype: int64



zip_agg_social_class_b2


2    2778
3    2025
0    1694
1    1434
4    1104
5     588
6     176
8      11
7       9
9       3
Name: zip_agg_social_class_b2, dtype: int64



zip_agg_social_class_c


5    1953
4    1929
3    1845
2    1468
6     800
0     634
1     478
7     396
9     202
8     117
Name: zip_agg_social_class_c, dtype: int64



zip_agg_social_class_d


0    4376
1    2658
2    1434
3     757
4     378
5     158
6      37
7      22
9       1
8       1
Name: zip_agg_social_class_d, dtype: int64



zip_agg_unskilled_labourers


2    2460
3    1817
1    1685
0    1636
4    1289
5     564
6     216
7     101
9      37
8      17
Name: zip_agg_unskilled_labourers, dtype: int64





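Numeric features are already bucketed into small non-negative integer codes (mostly 0–9; a few count columns such as `nbr_car_policies` go slightly higher), so they are passed through without further scaling.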

In [9]:
# Hold out 10% for testing, stratified on the (imbalanced) label.
train, test = train_test_split(
    df,
    test_size=0.1,
    random_state=12,
    stratify=df[label_col],
)

In [10]:
# Select the label by name rather than by position so the split does not
# silently depend on the label being the last column.
train_X = train.drop(columns=[label_col])
train_y = train[label_col]

In [74]:
# One-hot encode the categorical columns; all other columns pass through unchanged.
col_transformer = ColumnTransformer(
    [('encoder', OneHotEncoder(), cat_feats)],
    remainder='passthrough',
)

# fit() returns the transformer itself; fit_transform() returns the encoded
# matrix, which is what the following cell slices with processed_df[:2].
processed_df = col_transformer.fit_transform(train_X, train_y)

Column order of the transformed output: the one-hot-encoded categorical columns come first, followed by the passthrough (remainder) columns in their original order.

In [16]:
# Preview the first two encoded rows.
processed_df[0:2]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 1., 2., 2., 1., 0., 2., 6., 5., 1., 3., 1., 4., 5., 1., 5.,
        4., 0., 0., 1., 1., 3., 6., 0., 1., 1., 8., 1., 9., 0., 4., 1.,
        4., 9., 0., 1., 5., 1., 3., 0., 5., 2., 0., 0., 0., 6., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 1., 2., 5., 2., 5., 2., 1., 3., 0., 6., 6., 2., 2., 0., 2.,
        7., 2., 0., 1., 2., 1., 5., 0., 2., 3., 5., 1., 2., 7., 6., 0.,
        3., 9., 0., 3., 6., 1., 0., 0.,

In [31]:
# First two raw training rows, for comparison with the encoded preview.
train_X.iloc[:2]

Unnamed: 0,zip_agg_customer_subtype,zip_agg_number_of_houses,zip_agg_avg_size_household,zip_agg_avg_age,zip_agg_customer_main_type,zip_agg_roman_catholic,zip_agg_protestant,zip_agg_other_religion,zip_agg_no_religion,zip_agg_married,...,nbr_life_ins,nbr_private_accident_ins_policies,nbr_family_accidents_ins_policies,nbr_disability_ins_policies,nbr_fire_policies,nbr_surfboard_policies,nbr_boat_policies,nbr_bicycle_policies,nbr_property_ins_policies,nbr_ss_ins_policies
7038,"Young, low educated",1,2,2,Living well,1,0,2,6,5,...,0,0,0,0,0,0,0,0,0,0
1530,Religious elderly singles,1,2,5,Retired and Religeous,2,5,2,1,3,...,0,0,0,0,0,0,0,0,0,0


In [17]:
# Display the ColumnTransformer's configuration (one-hot encoder on cat_feats, passthrough remainder).
col_transformer

ColumnTransformer(n_jobs=None, remainder='passthrough', sparse_threshold=0.3,
                  transformer_weights=None,
                  transformers=[('encoder',
                                 OneHotEncoder(categories='auto', drop=None,
                                               dtype=<class 'numpy.float64'>,
                                               handle_unknown='error',
                                               sparse=True),
                                 ['zip_agg_customer_subtype',
                                  'zip_agg_customer_main_type'])],
                  verbose=False)

In [75]:
# Raw one-hot output names; the encoder prefixes them with positional placeholders
# ("x0_", "x1_") rather than the source column names.
# NOTE(review): get_feature_names is deprecated in newer scikit-learn in favour of
# get_feature_names_out — confirm against the installed version.
col_transformer.named_transformers_['encoder'].get_feature_names()

array(['x0_Affluent senior apartments', 'x0_Affluent young families',
       'x0_Career and childcare',
       "x0_Couples with teens 'Married with children'",
       "x0_Dinki's (double income no kids)", 'x0_Etnically diverse',
       'x0_Family starters', 'x0_Fresh masters in the city',
       'x0_High Income, expensive child', 'x0_High status seniors',
       'x0_Large family farms', 'x0_Large family, employed child',
       'x0_Large religous families', 'x0_Low income catholics',
       'x0_Lower class large families', 'x0_Middle class families',
       'x0_Mixed apartment dwellers', 'x0_Mixed rurals',
       'x0_Mixed seniors', 'x0_Mixed small town dwellers',
       'x0_Modern, complete families', 'x0_Own home elderly',
       'x0_Porchless seniors: no front yard',
       'x0_Religious elderly singles', 'x0_Residential elderly',
       'x0_Senior cosmopolitans', 'x0_Seniors in apartments',
       'x0_Single youth', 'x0_Stable family', 'x0_Students in apartments',
       'x0_Suburb

In [76]:
# Number of one-hot columns produced by the encoder.
col_transformer.named_transformers_['encoder'].get_feature_names().shape[0]

49

In [77]:
# Build readable one-hot column names ("<feature>_<category>") by handing the
# original column names to the encoder, instead of string-replacing its
# "x0"/"x1" placeholders: str.replace rewrites every occurrence in the name, and
# startswith("x1") would also match "x10" with more than ten categorical features.
encoder = col_transformer.named_transformers_['encoder']
if hasattr(encoder, 'get_feature_names_out'):
    one_hot_cols = encoder.get_feature_names_out(cat_feats)
else:  # older scikit-learn
    one_hot_cols = encoder.get_feature_names(cat_feats)
col_names = [str(name) for name in one_hot_cols]

col_names

['zip_agg_customer_subtype_Affluent senior apartments',
 'zip_agg_customer_subtype_Affluent young families',
 'zip_agg_customer_subtype_Career and childcare',
 "zip_agg_customer_subtype_Couples with teens 'Married with children'",
 "zip_agg_customer_subtype_Dinki's (double income no kids)",
 'zip_agg_customer_subtype_Etnically diverse',
 'zip_agg_customer_subtype_Family starters',
 'zip_agg_customer_subtype_Fresh masters in the city',
 'zip_agg_customer_subtype_High Income, expensive child',
 'zip_agg_customer_subtype_High status seniors',
 'zip_agg_customer_subtype_Large family farms',
 'zip_agg_customer_subtype_Large family, employed child',
 'zip_agg_customer_subtype_Large religous families',
 'zip_agg_customer_subtype_Low income catholics',
 'zip_agg_customer_subtype_Lower class large families',
 'zip_agg_customer_subtype_Middle class families',
 'zip_agg_customer_subtype_Mixed apartment dwellers',
 'zip_agg_customer_subtype_Mixed rurals',
 'zip_agg_customer_subtype_Mixed seniors',

In [69]:
# Full transformed-column order: one-hot names first, then the passthrough columns.
[*col_names, *train_X.drop(columns=cat_feats).columns]

['zip_agg_customer_subtype_Affluent senior apartments',
 'zip_agg_customer_subtype_Affluent young families',
 'zip_agg_customer_subtype_Career and childcare',
 "zip_agg_customer_subtype_Couples with teens 'Married with children'",
 "zip_agg_customer_subtype_Dinki's (double income no kids)",
 'zip_agg_customer_subtype_Etnically diverse',
 'zip_agg_customer_subtype_Family starters',
 'zip_agg_customer_subtype_Fresh masters in the city',
 'zip_agg_customer_subtype_High Income, expensive child',
 'zip_agg_customer_subtype_High status seniors',
 'zip_agg_customer_subtype_Large family farms',
 'zip_agg_customer_subtype_Large family, employed child',
 'zip_agg_customer_subtype_Large religous families',
 'zip_agg_customer_subtype_Low income catholics',
 'zip_agg_customer_subtype_Lower class large families',
 'zip_agg_customer_subtype_Middle class families',
 'zip_agg_customer_subtype_Mixed apartment dwellers',
 'zip_agg_customer_subtype_Mixed rurals',
 'zip_agg_customer_subtype_Mixed seniors',

In [86]:
%%writefile custom_preprocess.py
import pandas as pd
import numpy as np

import time
import sys
from io import StringIO
import os
import shutil

import argparse
import csv
import joblib
import json

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline

from sagemaker_containers.beta.framework import (
    content_types, encoders, env, modules, transformer, worker)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--label_col', type=str, default='nbr_mobile_home_policies')
    args = parser.parse_args()
    
    cat_feats = ['zip_agg_customer_subtype', 'zip_agg_customer_main_type']
    
    input_files = [ os.path.join(args.train, file) for file in os.listdir(args.train) ]
    if len(input_files) == 0:
        raise ValueError(('There are no files in {}.\n' +
                          'This usually indicates that the channel ({}) was incorrectly specified,\n' +
                          'the data specification in S3 was incorrectly specified or the role specified\n' +
                          'does not have permission to access the data.').format(args.train, "train"))

    raw_data = [ pd.read_csv(file) for file in input_files ]
    concat_data = pd.concat(raw_data)
    
    number_of_columns_x = concat_data.shape[1]
    train_y = concat_data.iloc[:,number_of_columns_x-1]
    train_X = concat_data.iloc[:,:number_of_columns_x-1]
    
    col_transformer = ColumnTransformer([
            ('encoder', OneHotEncoder(), cat_feats)],
        remainder='passthrough')
        
    col_transformer.fit(train_X, train_y)

    joblib.dump(col_transformer, os.path.join(args.model_dir, "model.joblib"))

    print("saved model!")
    
    one_hot_cols = col_transformer.named_transformers_['encoder'].get_feature_names()
    feature_names = []

    for i, col in enumerate(cat_feats):
        del_str = f'x{i}'
        col_list = [itm for itm in one_hot_cols if itm.startswith(del_str)]
        feature_names = feature_names + [x.replace(del_str, col) for x in col_list]
        
    feature_names = feature_names + list(train_X.drop(cat_feats, axis=1).columns)
    
    joblib.dump(feature_names, os.path.join(args.model_dir, "selected_feature_names.joblib"))
    
    print("Selected features are: {}".format(feature_names))
    
def input_fn(input_data, content_type):
    '''Parse input data payload into a pandas DataFrame.

    Accepts csv, parquet, or json file types.

    Args:
        input_data: request body. ``str`` or ``bytes`` for csv/json; ``bytes``
            (or a path / file-like object) for parquet.
        content_type: MIME type of the payload.

    Returns:
        pandas.DataFrame parsed from the payload.

    Raises:
        ValueError: if ``content_type`` is not supported.
    '''
    from io import BytesIO

    print('Running input function')

    if content_type == 'application/x-parquet':
        # read_parquet needs a path or file-like object, not raw bytes.
        if isinstance(input_data, (bytes, bytearray)):
            input_data = BytesIO(input_data)
        return pd.read_parquet(input_data)

    if content_type in ('text/csv', 'application/json'):
        # Text formats: decode bytes payloads before wrapping them for pandas.
        if isinstance(input_data, (bytes, bytearray)):
            text = input_data.decode('utf-8')
        else:
            text = input_data
        if content_type == 'text/csv':
            return pd.read_csv(StringIO(text))
        return pd.read_json(StringIO(text))

    raise ValueError("{} not supported by script".format(content_type))
        
def output_fn(prediction, accept):
    '''Format prediction output.

    The default accept/content-type between containers for serial inference is JSON.
    We also want to set the ContentType or mimetype as the same value as accept so the next
    container can read the response payload correctly.

    Args:
        prediction: transformed feature matrix (numpy array or scipy sparse matrix).
        accept: requested response MIME type ('application/json' or 'text/csv').

    Returns:
        worker.Response carrying the encoded payload with ``mimetype=accept``.

    Raises:
        RuntimeError: if ``accept`` is not supported.
    '''
    print('Running output function')

    # Sparse matrices have no .tolist(); densify so both branches can encode rows.
    if hasattr(prediction, 'toarray'):
        prediction = prediction.toarray()

    if accept == 'application/json':
        instances = [{'features': row} for row in prediction.tolist()]
        json_output = {'instances': instances}
        return worker.Response(json.dumps(json_output), mimetype=accept)

    if accept == 'text/csv':
        return worker.Response(encoders.encode(prediction, accept), mimetype=accept)

    # Previously raised the undefined name RuntimeException (a NameError) with an unformatted message.
    raise RuntimeError('{} accept type is not supported by this script'.format(accept))
        
def predict_fn(input_data, model, label_column='nbr_mobile_home_policies'):
    '''Preprocess input data.

    The default predict_fn uses .predict(), but our model is a preprocessor
    so we want to use .transform().

    Args:
        input_data: pandas DataFrame produced by input_fn.
        model: fitted ColumnTransformer loaded by model_fn.
        label_column: name of the label column that may accompany the features
            (e.g. when batch-transforming the training CSV).

    Returns:
        Dense numpy array of transformed features. If the payload has exactly one
        column more than the model was fitted on, that extra column is taken to be
        ``label_column`` and is prepended as the first output column.

    Raises:
        ValueError: if the payload's column count is neither the fitted feature
            count nor one more than it.
    '''
    print('Running predict_function')

    print('Input data shape at predict_fn: {}'.format(input_data.shape))

    # Raw feature count the preprocessor was fitted on (replaces the previously
    # undefined INPUT_FEATURES_SIZE constant).
    n_features = model.n_features_in_
    n_columns = input_data.shape[1]

    if n_columns == n_features:
        label = None
        features = model.transform(input_data)
    elif n_columns == n_features + 1:
        label = input_data[label_column].to_numpy()
        features = model.transform(input_data.drop(columns=[label_column]))
    else:
        raise ValueError(
            'Expected {} feature columns (or {} including {!r}), got {}'.format(
                n_features, n_features + 1, label_column, n_columns))

    # ColumnTransformer may return a scipy sparse matrix; np.insert and .tolist() need a dense array.
    if hasattr(features, 'toarray'):
        features = features.toarray()

    if label is None:
        return features
    return np.insert(features, 0, label, axis=1)
    
def model_fn(model_dir):
    '''Deserialize the fitted preprocessor stored as model.joblib under model_dir.'''
    print('Running model function')
    model_path = os.path.join(model_dir, 'model.joblib')
    return joblib.load(model_path)

Overwriting custom_preprocess.py


In [71]:
WORK_DIRECTORY = "data"

# DataFrame.to_csv does not create parent directories, so ensure the local folder exists.
os.makedirs(WORK_DIRECTORY, exist_ok=True)
train_csv_path = os.path.join(WORK_DIRECTORY, "train.csv")
train.to_csv(train_csv_path, index=False)

# Upload to s3://<bucket>/<prefix>/training_data/ for the training job's "train" channel.
train_input = session.upload_data(
    path=train_csv_path,
    bucket=bucket,
    key_prefix="{}/{}".format(prefix, "training_data"),
)

In [87]:
from sagemaker.sklearn.estimator import SKLearn

# Training entry point written above: fits the ColumnTransformer and saves it as the model artifact.
script_path = "custom_preprocess.py"
model_output_path = os.path.join("s3://", bucket, prefix, "preprocessing_model/")

estimator_config = dict(
    entry_point=script_path,
    role=role,
    output_path=model_output_path,
    instance_type="ml.m5.large",
    sagemaker_session=None,
    framework_version="1.0-1",
    py_version="py3",
    tags=tags,
)
sklearn_preprocessor = SKLearn(**estimator_config)

sklearn_preprocessor.fit({"train": train_input})

2023-02-02 20:39:43 Starting - Starting the training job...
2023-02-02 20:40:09 Starting - Preparing the instances for trainingProfilerReport-1675370383: InProgress
......
2023-02-02 20:41:11 Downloading - Downloading input data...
2023-02-02 20:41:42 Training - Downloading the training image.....[34m2023-02-02 20:42:21,746 sagemaker-containers INFO     Imported framework sagemaker_sklearn_container.training[0m
[34m2023-02-02 20:42:21,749 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2023-02-02 20:42:21,757 sagemaker_sklearn_container.training INFO     Invoking user training script.[0m
[34m2023-02-02 20:42:22,019 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2023-02-02 20:42:22,031 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2023-02-02 20:42:22,042 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2023

In [89]:
# S3 key of the training job's model artifact:
# <prefix>/preprocessing_model/<job name>/output/model.tar.gz
training_job_name = sklearn_preprocessor.latest_training_job.job_name
key_prefix = "/".join([prefix, "preprocessing_model", training_job_name, "output", "model.tar.gz"])
session.download_data(path="./", bucket=bucket, key_prefix=key_prefix)

In [90]:
# Unpack the downloaded artifact (contains model.joblib and selected_feature_names.joblib).
!tar xvzf model.tar.gz

model.joblib
selected_feature_names.joblib


In [91]:
import joblib

# Feature names saved by the training script: one-hot columns first, then passthrough columns.
feature_list = [*joblib.load("selected_feature_names.joblib")]
print(feature_list)

['zip_agg_customer_subtype_Affluent senior apartments', 'zip_agg_customer_subtype_Affluent young families', 'zip_agg_customer_subtype_Career and childcare', "zip_agg_customer_subtype_Couples with teens 'Married with children'", "zip_agg_customer_subtype_Dinki's (double income no kids)", 'zip_agg_customer_subtype_Etnically diverse', 'zip_agg_customer_subtype_Family starters', 'zip_agg_customer_subtype_Fresh masters in the city', 'zip_agg_customer_subtype_High Income, expensive child', 'zip_agg_customer_subtype_High status seniors', 'zip_agg_customer_subtype_Large family farms', 'zip_agg_customer_subtype_Large family, employed child', 'zip_agg_customer_subtype_Large religous families', 'zip_agg_customer_subtype_Low income catholics', 'zip_agg_customer_subtype_Lower class large families', 'zip_agg_customer_subtype_Middle class families', 'zip_agg_customer_subtype_Mixed apartment dwellers', 'zip_agg_customer_subtype_Mixed rurals', 'zip_agg_customer_subtype_Mixed seniors', 'zip_agg_customer

In [92]:
# Define a SKLearn Transformer from the trained SKLearn Estimator
transformer_output = os.path.join("s3://", bucket, prefix, "Feature_selection_output/")
transformer = sklearn_preprocessor.transformer(
    instance_count=1,
    instance_type="ml.m5.large",
    output_path=transformer_output,
    assemble_with="Line",
    accept="text/csv",
)

In [None]:
# Preprocess training input
transformer.transform(train_input, content_type="text/csv")
print("Waiting for transform job: " + transformer.latest_transform_job.job_name)
transformer.wait()
preprocessed_train = transformer.output_path