In this notebook I will show the different options to save and load a model, as well as some additional objects produced during training. 

On a given day, you train a model...

In [1]:
import pickle
import numpy as np
import pandas as pd
import torch
import shutil

from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint, LRHistory
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from sklearn.model_selection import train_test_split

  return f(*args, **kwds)


In [2]:
df = load_adult(as_frame=True)
df.head()

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
# For convenience, we'll replace '-' with '_'
df.columns = [c.replace("-", "_") for c in df.columns]
# binary target
df["target"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df.head()

Unnamed: 0,age,workclass,fnlwgt,education,educational_num,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,0
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,0
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,1
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,1
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,0


In [4]:
train, valid = train_test_split(df, test_size=0.2, stratify=df.target)
# the test data will be used lately as if it was "fresh", new data coming after some time...
valid, test = train_test_split(valid, test_size=0.5, stratify=valid.target)

In [5]:
print(f"train shape: {train.shape}")
print(f"valid shape: {valid.shape}")
print(f"test shape: {test.shape}")

train shape: (39073, 15)
valid shape: (4884, 15)
test shape: (4885, 15)


In [6]:
cat_embed_cols = [
    "workclass",
    "education",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital_gain",
    "capital_loss",
    "native_country",
]
continuous_cols = ["age", "hours_per_week"]

In [7]:
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_embed_cols, continuous_cols=continuous_cols, scale=True
)
X_tab_train = tab_preprocessor.fit_transform(train)
y_train = train.target.values
X_tab_valid = tab_preprocessor.transform(valid)
y_valid = valid.target.values

In [8]:
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_dropout=0.1,
    continuous_cols=continuous_cols,
    mlp_hidden_dims=[400, 200],
    mlp_dropout=0.5,
    mlp_activation="leaky_relu",
)
model = WideDeep(deeptabular=tab_mlp)

In [9]:
model

WideDeep(
  (deeptabular): Sequential(
    (0): TabMlp(
      (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(
        (cat_embed): DiffSizeCatEmbeddings(
          (embed_layers): ModuleDict(
            (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0)
            (emb_layer_capital_loss): Embedding(95, 20, padding_idx=0)
            (emb_layer_education): Embedding(17, 8, padding_idx=0)
            (emb_layer_gender): Embedding(3, 2, padding_idx=0)
            (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)
            (emb_layer_native_country): Embedding(43, 13, padding_idx=0)
            (emb_layer_occupation): Embedding(16, 7, padding_idx=0)
            (emb_layer_race): Embedding(6, 4, padding_idx=0)
            (emb_layer_relationship): Embedding(7, 4, padding_idx=0)
            (emb_layer_workclass): Embedding(10, 5, padding_idx=0)
          )
          (embedding_dropout): Dropout(p=0.1, inplace=False)
        )
        (cont_norm): BatchNorm1d(2, eps

In [10]:
early_stopping = EarlyStopping()
model_checkpoint = ModelCheckpoint(
    filepath="tmp_dir/adult_tabmlp_model",
    save_best_only=True,
    verbose=1,
    max_save=1,
)

trainer = Trainer(
    model,
    objective="binary",
    callbacks=[early_stopping, model_checkpoint],
    metrics=[Accuracy],
)

trainer.fit(
    X_train={"X_tab": X_tab_train, "target": y_train},
    X_val={"X_tab": X_tab_valid, "target": y_valid},
    n_epochs=4,
    batch_size=256,
)

epoch 1: 100%|██████████| 153/153 [00:03<00:00, 49.33it/s, loss=0.433, metrics={'acc': 0.7912}]
valid: 100%|██████████| 20/20 [00:00<00:00, 51.28it/s, loss=0.344, metrics={'acc': 0.8516}]
epoch 2:   3%|▎         | 5/153 [00:00<00:03, 45.26it/s, loss=0.391, metrics={'acc': 0.8229}]


Epoch 00001: val_loss improved from inf to 0.34433, saving model to tmp_dir/adult_tabmlp_model_1.p


epoch 2: 100%|██████████| 153/153 [00:03<00:00, 47.43it/s, loss=0.389, metrics={'acc': 0.8196}]
valid: 100%|██████████| 20/20 [00:00<00:00, 105.10it/s, loss=0.334, metrics={'acc': 0.8583}]
epoch 3:   3%|▎         | 5/153 [00:00<00:03, 46.44it/s, loss=0.385, metrics={'acc': 0.8051}]


Epoch 00002: val_loss improved from 0.34433 to 0.33399, saving model to tmp_dir/adult_tabmlp_model_2.p


epoch 3: 100%|██████████| 153/153 [00:03<00:00, 50.25it/s, loss=0.368, metrics={'acc': 0.829}] 
valid: 100%|██████████| 20/20 [00:00<00:00, 106.84it/s, loss=0.318, metrics={'acc': 0.8591}]
epoch 4:   3%|▎         | 5/153 [00:00<00:03, 47.02it/s, loss=0.37, metrics={'acc': 0.8142}] 


Epoch 00003: val_loss improved from 0.33399 to 0.31802, saving model to tmp_dir/adult_tabmlp_model_3.p


epoch 4: 100%|██████████| 153/153 [00:03<00:00, 49.17it/s, loss=0.358, metrics={'acc': 0.8359}]
valid: 100%|██████████| 20/20 [00:00<00:00, 106.41it/s, loss=0.318, metrics={'acc': 0.8622}]



Epoch 00004: val_loss improved from 0.31802 to 0.31791, saving model to tmp_dir/adult_tabmlp_model_4.p
Model weights restored to best epoch: 4


# Save model: option 1

save (and load) a model as you woud do with any other torch model

In [11]:
torch.save(model, "tmp_dir/model_saved_option_1.pt")

In [12]:
torch.save(model.state_dict(), "tmp_dir/model_state_dict_saved_option_1.pt")

# Save model: option 2

use the `trainer`. The `trainer` will also save the training history and the learning rate history (if learning rate schedulers are used)

In [13]:
trainer.save(path="tmp_dir/", model_filename="model_saved_option_2.pt")

or the state dict

In [14]:
trainer.save(
    path="tmp_dir/",
    model_filename="model_state_dict_saved_option_2.pt",
    save_state_dict=True,
)

In [15]:
%%bash

ls tmp_dir/

adult_tabmlp_model_4.p
eary_stop.pkl
history
model_saved_option_1.pt
model_saved_option_2.pt
model_state_dict_saved_option_1.pt
model_state_dict_saved_option_2.pt
tab_preproc.pkl


In [16]:
%%bash

ls tmp_dir/history/

train_eval_history.json


Note that since we have used the `ModelCheckpoint` Callback, `adult_tabmlp_model_2.p` is the model state dict of the model at epoch 2, i.e. same as `model_state_dict_saved_option_1.p` or `model_state_dict_saved_option_2.p`. 

# Save preprocessors and callbacks

...just pickle them

In [17]:
with open("tmp_dir/tab_preproc.pkl", "wb") as dp:
    pickle.dump(tab_preprocessor, dp)

In [18]:
with open("tmp_dir/eary_stop.pkl", "wb") as es:
    pickle.dump(early_stopping, es)

In [19]:
%%bash

ls tmp_dir/

adult_tabmlp_model_4.p
eary_stop.pkl
history
model_saved_option_1.pt
model_saved_option_2.pt
model_state_dict_saved_option_1.pt
model_state_dict_saved_option_2.pt
tab_preproc.pkl


And that is pretty much all you need to resume training or directly predict, let's see

# Run New experiment: prepare new dataset, load model, and predict

In [20]:
test.head()

Unnamed: 0,age,workclass,fnlwgt,education,educational_num,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
16063,22,Self-emp-inc,171041,Bachelors,13,Never-married,Handlers-cleaners,Own-child,White,Male,0,0,40,United-States,0
48719,66,Self-emp-not-inc,102686,Masters,14,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,20,United-States,1
22580,43,Private,102180,Masters,14,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,1
47921,18,Self-emp-inc,174202,HS-grad,9,Never-married,Transport-moving,Own-child,White,Male,0,0,60,United-States,0
24205,36,State-gov,112074,Doctorate,16,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,45,United-States,0


In [21]:
with open("tmp_dir/tab_preproc.pkl", "rb") as tp:
    tab_preprocessor_new = pickle.load(tp)

In [22]:
X_test_tab = tab_preprocessor_new.transform(test)
y_test = test.target

In [23]:
tab_mlp_new = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_dropout=0.1,
    continuous_cols=continuous_cols,
    mlp_hidden_dims=[400, 200],
    mlp_dropout=0.5,
    mlp_activation="leaky_relu",
)
model = WideDeep(deeptabular=tab_mlp_new)

In [24]:
model.load_state_dict(torch.load("tmp_dir/model_state_dict_saved_option_2.pt"))

<All keys matched successfully>

In [25]:
trainer = Trainer(
    model,
    objective="binary",
)

In [26]:
preds = trainer.predict(X_tab=X_test_tab)

predict: 100%|██████████| 20/20 [00:00<00:00, 126.24it/s]


In [27]:
from sklearn.metrics import accuracy_score

In [28]:
accuracy_score(y_test, preds)

0.8589559877175026

In [29]:
shutil.rmtree("tmp_dir/")