In this notebook I will show the different options to save and load a model, as well as some additional objects produced during training. 

On a given day, you train a model...

In [1]:
import pickle
import numpy as np
import pandas as pd
import torch
import shutil

from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint, LRHistory
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from sklearn.model_selection import train_test_split

  return f(*args, **kwds)


In [2]:
# Load the raw Adult census data directly from the zipped CSV and peek at it
df = pd.read_csv('data/adult/adult.csv.zip')
df.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
# For convenience, we'll replace '-' with '_' in the column names
df.columns = [c.replace("-", "_") for c in df.columns]
# Binary target: 1 when the income label contains ">50K", else 0.
# Vectorised, literal (non-regex) string check instead of a per-row apply,
# and an explicit rebinding instead of `inplace=True`.
df = (
    df.assign(target=df["income"].str.contains(">50K", regex=False).astype(int))
    .drop(columns="income")
)
df.head()

Unnamed: 0,age,workclass,fnlwgt,education,educational_num,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,0
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,0
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,1
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,1
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,0


In [4]:
# Fixed seed so the split (and therefore every downstream number) is reproducible
SPLIT_SEED = 1
train, valid = train_test_split(
    df, test_size=0.2, stratify=df.target, random_state=SPLIT_SEED
)
# the test data will be used later as if it were "fresh", new data coming after some time...
valid, test = train_test_split(
    valid, test_size=0.5, stratify=valid.target, random_state=SPLIT_SEED
)

In [5]:
# Report the size of each split, one line per split
print(
    "\n".join(
        f"{split_name} shape: {split_df.shape}"
        for split_name, split_df in (("train", train), ("valid", valid), ("test", test))
    )
)

train shape: (39073, 15)
valid shape: (4884, 15)
test shape: (4885, 15)


In [6]:
# Columns passed to the WidePreprocessor (wide / linear component)
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native_country",
    "gender",
]
# Column pairs passed to the WidePreprocessor as crossed_cols
crossed_cols = [
    ("education", "occupation"),
    ("native_country", "occupation"),
]

In [7]:
# Columns passed to TabPreprocessor as embed_cols: object columns plus numeric
# columns with fewer than MAX_EMBED_CARDINALITY distinct values. The parentheses
# make the precedence explicit so "target" is excluded regardless of its dtype.
# Every remaining non-target column is treated as continuous.
MAX_EMBED_CARDINALITY = 200
cat_embed_cols = [
    col
    for col in train.columns
    if col != "target"
    and (train[col].dtype == "O" or train[col].nunique() < MAX_EMBED_CARDINALITY)
]
num_cols = [c for c in train.columns if c not in cat_embed_cols + ["target"]]

In [8]:
# Fit on the training split; the validation split is only transformed with the
# already-fitted preprocessor
wide_preprocessor = WidePreprocessor(
    wide_cols=wide_cols,
    crossed_cols=crossed_cols,
)
X_wide_train = wide_preprocessor.fit_transform(train)
X_wide_valid = wide_preprocessor.transform(valid)

In [9]:
# Deep-tabular inputs: fit on train, only transform valid
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_embed_cols,
    continuous_cols=num_cols,
    scale=True,
)
X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_valid = tab_preprocessor.transform(valid)

# Targets as numpy arrays
y_train = train["target"].values
y_valid = valid["target"].values

In [10]:
# NOTE: wide_dim, column_idx and embeddings_input are needed again to rebuild
# this exact architecture before loading saved weights, so keep them (or the
# fitted preprocessors that provide them) somewhere.
wide = Wide(wide_dim=wide_preprocessor.wide_dim)
deeptabular = TabMlp(
    embed_input=tab_preprocessor.embeddings_input,
    column_idx=tab_preprocessor.column_idx,
)
model = WideDeep(deeptabular=deeptabular, wide=wide)

In [11]:
# Display the assembled WideDeep architecture (wide + deeptabular components)
model

WideDeep(
  (wide): Wide(
    (wide_linear): Embedding(773, 1, padding_idx=0)
  )
  (deeptabular): Sequential(
    (0): TabMlp(
      (embed_layers): ModuleDict(
        (emb_layer_age): Embedding(75, 18, padding_idx=0)
        (emb_layer_capital_gain): Embedding(121, 23, padding_idx=0)
        (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0)
        (emb_layer_education): Embedding(17, 8, padding_idx=0)
        (emb_layer_educational_num): Embedding(17, 8, padding_idx=0)
        (emb_layer_gender): Embedding(3, 2, padding_idx=0)
        (emb_layer_hours_per_week): Embedding(96, 20, padding_idx=0)
        (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)
        (emb_layer_native_country): Embedding(43, 13, padding_idx=0)
        (emb_layer_occupation): Embedding(16, 7, padding_idx=0)
        (emb_layer_race): Embedding(6, 4, padding_idx=0)
        (emb_layer_relationship): Embedding(7, 4, padding_idx=0)
        (emb_layer_workclass): Embedding(10, 5, padding_idx=0)


In [12]:
# Early stopping with the library defaults, plus a checkpoint callback that only
# writes a file when the monitored metric improves (save_best_only) and keeps at
# most one checkpoint file on disk (max_save=1)
early_stopping = EarlyStopping()
model_checkpoint = ModelCheckpoint(
    filepath="tmp_dir/adult_tabmlp_model",
    max_save=1,
    save_best_only=True,
    verbose=1,
)

trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy],
    callbacks=[early_stopping, model_checkpoint],
)

trainer.fit(
    X_train=dict(X_wide=X_wide_train, X_tab=X_tab_train, target=y_train),
    X_val=dict(X_wide=X_wide_valid, X_tab=X_tab_valid, target=y_valid),
    batch_size=256,
    n_epochs=2,
)

epoch 1: 100%|██████████| 153/153 [00:04<00:00, 31.06it/s, loss=0.479, metrics={'acc': 0.7839}]
valid: 100%|██████████| 20/20 [00:00<00:00, 48.70it/s, loss=0.348, metrics={'acc': 0.8444}]
epoch 2:   3%|▎         | 4/153 [00:00<00:04, 32.06it/s, loss=0.365, metrics={'acc': 0.8385}]


Epoch 00001: val_loss improved from inf to 0.34800, saving model to tmp_dir/adult_tabmlp_model_1.p


epoch 2: 100%|██████████| 153/153 [00:04<00:00, 32.33it/s, loss=0.354, metrics={'acc': 0.8379}]
valid: 100%|██████████| 20/20 [00:00<00:00, 92.91it/s, loss=0.322, metrics={'acc': 0.8511}] 


Epoch 00002: val_loss improved from 0.34800 to 0.32204, saving model to tmp_dir/adult_tabmlp_model_2.p
Model weights restored to best epoch: 2





# Save model: option 1

Save (and load) a model as you would with any other PyTorch model.

In [13]:
# Option 1a: save the entire model object
torch.save(obj=model, f="tmp_dir/model_saved_option_1.pt")

In [14]:
# Option 1b: save only the model's state dict
torch.save(obj=model.state_dict(), f="tmp_dir/model_state_dict_saved_option_1.pt")

# Save model: option 2

Use the `trainer`. The `trainer` will also save the training history and the learning-rate history (if learning-rate schedulers are used).

In [15]:
# Save the whole model through the trainer (it also writes the training history)
trainer.save(model_filename="model_saved_option_2.pt", path="tmp_dir/")

Or save only the state dict:

In [16]:
# Same, but persist only the state dict
trainer.save(
    path="tmp_dir/",
    save_state_dict=True,
    model_filename="model_state_dict_saved_option_2.pt",
)

In [17]:
%%bash

# one entry per line: everything saved so far
ls -1 tmp_dir/

adult_tabmlp_model_2.p
history
model_saved_option_1.pt
model_saved_option_2.pt
model_state_dict_saved_option_1.pt
model_state_dict_saved_option_2.pt


In [18]:
%%bash

# one entry per line: files written by trainer.save into history/
ls -1 tmp_dir/history/

train_eval_history.json


Note that since we have used the `ModelCheckpoint` callback, `adult_tabmlp_model_2.p` is the state dict of the model at epoch 2, i.e. the same as `model_state_dict_saved_option_1.pt` or `model_state_dict_saved_option_2.pt`.

# Save preprocessors and callbacks

...just pickle them.

In [19]:
# Persist the fitted wide preprocessor so new data can be transformed later
with open("tmp_dir/wide_preproc.pkl", "wb") as wide_fh:
    pickle.dump(wide_preprocessor, wide_fh)

In [20]:
# Persist the fitted tabular preprocessor as well
with open("tmp_dir/tab_preproc.pkl", "wb") as tab_fh:
    pickle.dump(tab_preprocessor, tab_fh)

In [21]:
# Callbacks are plain Python objects too, so they can be pickled the same way
# (NOTE: the file name is spelled "eary_stop.pkl")
with open("tmp_dir/eary_stop.pkl", "wb") as es_fh:
    pickle.dump(early_stopping, es_fh)

In [22]:
%%bash

# one entry per line: model files plus the pickled preprocessors/callback
ls -1 tmp_dir/

adult_tabmlp_model_2.p
eary_stop.pkl
history
model_saved_option_1.pt
model_saved_option_2.pt
model_state_dict_saved_option_1.pt
model_state_dict_saved_option_2.pt
tab_preproc.pkl
wide_preproc.pkl


And that is pretty much all you need to resume training or to predict directly. Let's see.

# Run a new experiment: prepare the new dataset, load the model, and predict

In [23]:
# The held-out split plays the role of brand-new data
test.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,educational_num,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
32823,33,Private,201988,Masters,14,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,45,United-States,0
40713,31,Private,231826,HS-grad,9,Married-civ-spouse,Other-service,Husband,White,Male,0,0,52,Mexico,0
16020,38,Private,24126,Some-college,10,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,40,United-States,0
32766,38,State-gov,312528,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,37,United-States,0
9713,40,Self-emp-not-inc,121012,Prof-school,15,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,1977,50,United-States,1


In [24]:
# Restore the fitted wide preprocessor (only unpickle files you created / trust)
with open("tmp_dir/wide_preproc.pkl", "rb") as wide_fh:
    wide_preprocessor_new = pickle.load(wide_fh)

In [25]:
# Restore the fitted tabular preprocessor (only unpickle files you created / trust)
with open("tmp_dir/tab_preproc.pkl", "rb") as tab_fh:
    tab_preprocessor_new = pickle.load(tab_fh)

In [26]:
# Encode the "new" data with the restored, already-fitted preprocessors
# (transform only, no refit)
X_test_wide = wide_preprocessor_new.transform(test)
X_test_tab = tab_preprocessor_new.transform(test)
# numpy array, consistent with how y_train / y_valid were built
y_test = test.target.values

In [27]:
# Rebuild the architecture exclusively from the restored preprocessors.
# Every component must come from this "new session": reusing the `wide` object
# from the training session only works through leftover kernel state and raises
# NameError in a fresh kernel.
wide_new = Wide(wide_dim=wide_preprocessor_new.wide_dim)
deeptabular_new = TabMlp(
    column_idx=tab_preprocessor_new.column_idx,
    embed_input=tab_preprocessor_new.embeddings_input,
)
model = WideDeep(wide=wide_new, deeptabular=deeptabular_new)

In [28]:
# map_location="cpu" lets the weights load even if they were saved from a GPU
# and this session has none; load_state_dict then copies them into the model
model.load_state_dict(
    torch.load("tmp_dir/model_state_dict_saved_option_2.pt", map_location="cpu")
)

<All keys matched successfully>

In [29]:
# Inference only: no callbacks or metrics are needed
trainer = Trainer(model, objective="binary")

In [30]:
# Predict on the freshly encoded test data
preds = trainer.predict(X_tab=X_test_tab, X_wide=X_test_wide)

predict: 100%|██████████| 20/20 [00:00<00:00, 86.04it/s]


In [31]:
from sklearn.metrics import accuracy_score

In [32]:
# Accuracy of the reloaded model on the held-out data
accuracy_score(y_true=y_test, y_pred=preds)

0.8554759467758444

In [33]:
# Clean up every artifact written by this notebook
shutil.rmtree(path="tmp_dir/")