In [1]:
import numpy as np
import pandas as pd
import torch

from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.initializers import KaimingNormal, XavierNormal
from pytorch_widedeep.callbacks import LRHistory, ModelCheckpoint, EarlyStopping
from pytorch_widedeep.metrics import BinaryAccuracy

In [2]:
# Seed the global NumPy and PyTorch generators so that a fresh
# "Restart Kernel -> Run All" is repeatable as far as these generators go.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Load and tidy the raw data -------------------------------------------
df = pd.read_csv('../data/adult/adult.csv.zip')
# Column names contain dashes; normalise them to underscores so they can be
# used as attributes (e.g. `df.age`) and referenced consistently below.
df.columns = [c.replace("-", "_") for c in df.columns]
# Ten bin edges -> nine buckets labelled 0..8. `pd.cut` uses right-closed
# intervals by default, i.e. (16, 25], (25, 30], ..., (60, 91].
df['age_buckets'] = pd.cut(df.age, bins=[16, 25, 30, 35, 40, 45, 50, 55, 60, 91], labels=np.arange(9))
# Binary label: 1 when the raw income string contains ">50K", otherwise 0.
df['income_label'] = df["income"].str.contains(">50K", regex=False).astype(int)
# Drop the raw income column by reassignment rather than mutating in place.
df = df.drop('income', axis=1)

# --- Column roles ---------------------------------------------------------
wide_cols = ['age_buckets', 'education', 'relationship','workclass','occupation',
    'native_country','gender']
crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]
# (column name, embedding dimension) pairs for the deep component.
cat_embed_cols = [('education',10), ('relationship',8), ('workclass',10),
    ('occupation',10),('native_country',10)]
continuous_cols = ["age","hours_per_week"]
# Keep the column name and the label array under separate names so `target`
# always refers to one kind of object (the array passed to `model.fit`).
target_col = 'income_label'
target = df[target_col].values

# --- Preprocess inputs for each model component ---------------------------
prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = prepare_wide.fit_transform(df)
prepare_deep = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)
X_deep = prepare_deep.fit_transform(df)

# --- Build the model ------------------------------------------------------
wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)
# NOTE: with two hidden layers and a single dropout value, the printed model
# shows p=0.0 on the first dense layer and p=0.5 on the second.
deepdense = DeepDense(hidden_layers=[64,32], dropout=[0.5],
                      deep_column_idx=prepare_deep.deep_column_idx,
                      embed_input=prepare_deep.embeddings_input,
                      continuous_cols=continuous_cols)
model = WideDeep(wide=wide, deepdense=deepdense)
# Last expression of the cell: renders the model architecture.
model

WideDeep(
  (wide): Wide(
    (wide_linear): Linear(in_features=805, out_features=1, bias=True)
  )
  (deepdense): Sequential(
    (0): DeepDense(
      (embed_layers): ModuleDict(
        (emb_layer_education): Embedding(16, 10)
        (emb_layer_native_country): Embedding(42, 10)
        (emb_layer_occupation): Embedding(15, 10)
        (emb_layer_relationship): Embedding(6, 8)
        (emb_layer_workclass): Embedding(9, 10)
      )
      (dense): Sequential(
        (dense_layer_0): Sequential(
          (0): Linear(in_features=50, out_features=64, bias=True)
          (1): LeakyReLU(negative_slope=0.01, inplace=True)
          (2): Dropout(p=0.0, inplace=False)
        )
        (dense_layer_1): Sequential(
          (0): Linear(in_features=64, out_features=32, bias=True)
          (1): LeakyReLU(negative_slope=0.01, inplace=True)
          (2): Dropout(p=0.5, inplace=False)
        )
      )
    )
    (1): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [3]:
# One optimizer per model component, each left at its default
# hyper-parameters (no learning rate or other options are passed).
wide_params = model.wide.parameters()
deep_params = model.deepdense.parameters()

wide_opt = torch.optim.Adam(params=wide_params)
deep_opt = RAdam(deep_params)

In [4]:
# Step-wise learning-rate decay, one scheduler per optimizer. The two
# components use different step sizes: 3 for wide, 5 for deepdense.
wide_sch = torch.optim.lr_scheduler.StepLR(
    optimizer=wide_opt,
    step_size=3,
)
deep_sch = torch.optim.lr_scheduler.StepLR(
    optimizer=deep_opt,
    step_size=5,
)

In [5]:
# Per-component settings, keyed by the names the components were given when
# building the model (`wide` and `deepdense`).
initializers = dict(wide=KaimingNormal, deepdense=XavierNormal)
optimizers = dict(wide=wide_opt, deepdense=deep_opt)
schedulers = dict(wide=wide_sch, deepdense=deep_sch)

# LRHistory and EarlyStopping are passed as classes (not instances);
# ModelCheckpoint is instantiated with the path prefix for saved weights.
checkpoint = ModelCheckpoint(filepath='../model_weights/wd_out')
callbacks = [LRHistory, EarlyStopping, checkpoint]

# Metric is likewise passed as a class.
metrics = [BinaryAccuracy]

In [6]:
# Wire the training configuration assembled above into the model.
model.compile(
    method='logistic',
    initializers=initializers,
    optimizers=optimizers,
    lr_schedulers=schedulers,
    callbacks=callbacks,
    metrics=metrics,
)

In [7]:
# Train for 10 epochs in mini-batches of 256. `val_split=0.2` presumably
# holds out 20% of the rows for validation (the log shows 153 training and
# 39 validation batches per epoch).
model.fit(
    X_wide=X_wide,
    X_deep=X_deep,
    target=target,
    n_epochs=10,
    batch_size=256,
    val_split=0.2,
)

epoch 1: 100%|██████████| 153/153 [00:01<00:00, 90.38it/s, loss=0.554, metrics={'acc': 0.7213}]
valid: 100%|██████████| 39/39 [00:00<00:00, 143.57it/s, loss=0.45, metrics={'acc': 0.7365}]
epoch 2: 100%|██████████| 153/153 [00:01<00:00, 99.12it/s, loss=0.403, metrics={'acc': 0.8173}] 
valid: 100%|██████████| 39/39 [00:00<00:00, 137.70it/s, loss=0.382, metrics={'acc': 0.8183}]
epoch 3: 100%|██████████| 153/153 [00:01<00:00, 97.10it/s, loss=0.368, metrics={'acc': 0.8321}] 
valid: 100%|██████████| 39/39 [00:00<00:00, 142.64it/s, loss=0.365, metrics={'acc': 0.8319}]
epoch 4: 100%|██████████| 153/153 [00:01<00:00, 97.67it/s, loss=0.356, metrics={'acc': 0.8356}]
valid: 100%|██████████| 39/39 [00:00<00:00, 139.90it/s, loss=0.358, metrics={'acc': 0.8354}]
epoch 5: 100%|██████████| 153/153 [00:01<00:00, 98.53it/s, loss=0.351, metrics={'acc': 0.8373}]
valid: 100%|██████████| 39/39 [00:00<00:00, 142.39it/s, loss=0.355, metrics={'acc': 0.8368}]
epoch 6: 100%|██████████| 153/153 [00:01<00:00, 96.25i

> /Users/javier/pytorch-widedeep/pytorch_widedeep/callbacks.py(338)on_epoch_end()
-> if self.wait >= self.patience:
(Pdb) c


epoch 9: 100%|██████████| 153/153 [00:01<00:00, 91.56it/s, loss=0.347, metrics={'acc': 0.8401}]
valid: 100%|██████████| 39/39 [00:00<00:00, 125.78it/s, loss=0.353, metrics={'acc': 0.8392}]
epoch 10: 100%|██████████| 153/153 [00:01<00:00, 93.83it/s, loss=0.346, metrics={'acc': 0.8406}]
valid: 100%|██████████| 39/39 [00:00<00:00, 138.04it/s, loss=0.351, metrics={'acc': 0.8399}]


In [8]:
# Show the per-epoch training record: a dict with `train_loss`, `train_acc`,
# `val_loss` and `val_acc`, one value per epoch.
# NOTE(review): `_history` is a private attribute; prefer a public accessor
# if the library provides one - TODO confirm.
model.history._history

{'train_loss': [0.554220751804464,
  0.40329983148699494,
  0.3677725032264111,
  0.3562958643716924,
  0.35065410904635014,
  0.34857911046813517,
  0.34745271260442295,
  0.34648635554936974,
  0.34684639136775647,
  0.34612662651959586],
 'train_acc': [0.7213,
  0.8173,
  0.8321,
  0.8356,
  0.8373,
  0.8398,
  0.8399,
  0.8395,
  0.8401,
  0.8406],
 'val_loss': [0.45026036103566486,
  0.3821376669101226,
  0.36544298820006543,
  0.3581323524316152,
  0.3551228948128529,
  0.35395471866314226,
  0.3530398324514047,
  0.35306169665776765,
  0.352843076754839,
  0.3511059482892354],
 'val_acc': [0.7365,
  0.8183,
  0.8319,
  0.8354,
  0.8368,
  0.8392,
  0.8389,
  0.8387,
  0.8392,
  0.8399]}

In [9]:
# Show the learning rates recorded during training, one list per component
# parameter group (`lr_wide_0`, `lr_deepdense_0`). In the recorded output the
# wide rate drops tenfold every 3 entries and the deepdense rate every 5,
# matching the scheduler step sizes.
model.lr_history

{'lr_wide_0': [0.001,
  0.001,
  0.001,
  0.0001,
  0.0001,
  0.0001,
  1.0000000000000003e-05,
  1.0000000000000003e-05,
  1.0000000000000003e-05,
  1.0000000000000002e-06,
  1.0000000000000002e-06],
 'lr_deepdense_0': [0.001,
  0.001,
  0.001,
  0.001,
  0.001,
  0.0001,
  0.0001,
  0.0001,
  0.0001,
  0.0001,
  1.0000000000000003e-05]}