# Model Persistence Example

Model persistence is an important part of deploying and sharing models. This notebook shows the two supported ways of saving and loading models: `pickle` and PyTorch's `state_dict`.

```python
import os
os.chdir('..')
```

## Pickle

Pickling is the standard way of persisting arbitrary Python objects to files. It is easy, but the resulting artifacts are not very portable: loading them back in requires the same environment, in particular the same versions of the libraries that define the pickled objects. Despite that, pickling remains the default way of persisting Scikit-learn models.
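
The portability caveat follows from how pickle works: it stores an object's attributes together with a reference to its class (module and class name), and looks that class up again at load time. The standard-library sketch below uses a hypothetical `Scaler` class as a stand-in for a fitted model. It shows a successful round trip, then a failed load once the class definition is no longer available, which is roughly what happens when the loading environment lacks the code that defined the object. A different version of that code can break loading in similar ways.

```python
import pickle


class Scaler:
    """Hypothetical stand-in for a fitted model object."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


model = Scaler(2.5)
data = pickle.dumps(model)

# The pickle holds the instance's attributes plus a reference to its class
# (module and class name), not the class's source code.
print(b"Scaler" in data)

# With the class available, a new object with the same state is rebuilt.
restored = pickle.loads(data)
print(restored([1, 2]))

# Simulate loading in an environment where the class definition is missing.
saved_cls = Scaler
del Scaler
try:
    pickle.loads(data)
    load_error = None
except (AttributeError, pickle.UnpicklingError) as err:
    load_error = err
Scaler = saved_cls  # put the class back for any later code
print("load failed:", load_error)
```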

To demonstrate, we first create a PyTorch model that mean-pools word vectors, with an SVD in between to reduce their dimensionality. The result is then fed into a linear layer that produces the final output.
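
The computation this model performs can be sketched without any dependencies. In this toy version the vocabulary, the 3-dimensional word vectors, the `projection` matrix (a stand-in for the fitted SVD components) and the linear-layer weights are all made-up values; TextWiser's internals are not reproduced here.

```python
# Toy forward pass: word vectors -> 2-d projection -> mean pooling -> linear layer.
# All values are made up for illustration.

# Hypothetical 3-d word vectors (the real model uses 50-d "en-turian" vectors).
word_vectors = {
    "this": [1.0, 0.0, 2.0],
    "is": [0.0, 1.0, 1.0],
    "one": [2.0, 2.0, 0.0],
}

# 3 x 2 matrix standing in for fitted SVD components (hypothetical values).
projection = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]

# Hypothetical linear layer with 2 inputs and 1 output.
weights, bias = [0.1, -0.2], 0.05


def project(vec):
    """Reduce one word vector to 2 dimensions (row vector times projection)."""
    return [sum(v * projection[i][j] for i, v in enumerate(vec)) for j in range(2)]


def forward(tokens):
    reduced = [project(word_vectors[t]) for t in tokens]
    # Mean pooling: average each dimension over the words of the document.
    pooled = [sum(column) / len(reduced) for column in zip(*reduced)]
    # Linear layer: weighted sum plus bias gives a single score.
    return sum(w * p for w, p in zip(weights, pooled)) + bias


print(forward(["this", "is", "one"]))  # approximately -0.1
```

Both the projection and the mean are linear, so their order does not change the pooled vector. What matters for persistence is that the projection matrix and the layer weights are learned state that has to be saved along with the model.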

```python
import torch
import torch.nn as nn
from textwiser import TextWiser, Embedding, Transformation, WordOptions, PoolOptions

docs = ['This is one document.', 'This is a second document.', 'Not the second document!']

# Word2vec embeddings -> SVD down to 2 dimensions -> mean pooling over the words
tw = TextWiser(Embedding.Word(word_option=WordOptions.word2vec, pretrained='en-turian'),
               [Transformation.SVD(n_components=2), Transformation.Pool(pool_option=PoolOptions.mean)],
               dtype=torch.float32)
tw.fit(docs)
# A linear layer maps each 2-dimensional document vector to a single output
model = nn.Sequential(tw, nn.Linear(2, 1))
model
```

```
Sequential(
  (0): TextWiser(
    (model): _Sequential(
      (0): _WordEmbeddings(
        (model): Embedding(246117, 50, sparse=True)
      )
      (1): _SVDTransformation()
      (2): _PoolTransformation()
    )
  )
  (1): Linear(in_features=2, out_features=1, bias=True)
)
```

For demonstration, we can look at the model's output on the example documents.

```python
# Get results of the model
expected = model(docs)
expected
```

```
tensor([[0.0605],
        [0.0669],
        [0.0652]], grad_fn=<AddmmBackward>)
```

The model can then be persisted with pickle and loaded back in. Note that the cell below deletes the `model` object between saving and loading, so the `model` object after loading is brand new.

```python
import pickle
from tempfile import NamedTemporaryFile

with NamedTemporaryFile() as file:
    # Serialize the whole model object into the temporary file
    pickle.dump(model, file)
    # Delete the model so that the loaded one is a brand-new object
    del model
    # Rewind the file and deserialize the model from it
    file.seek(0)
    model = pickle.load(file)
    print(model(docs))
```

```
tensor([[0.0605],
        [0.0669],
        [0.0652]], grad_fn=<AddmmBackward>)
```


As expected, the output of the loaded object is exactly the same as when it was first created.

## State dict

The recommended way of saving a PyTorch model is through its state dictionary; see the PyTorch tutorial on [saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html).
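
The idea behind the state dictionary is that the saved file holds only named parameter values, while the model's structure is rebuilt from code before the values are loaded back. The toy `TinyLinear` class below is hypothetical, not PyTorch's implementation; it mimics that pattern in plain Python, with JSON standing in for `torch.save`/`torch.load`.

```python
import json


class TinyLinear:
    """Toy stand-in for nn.Linear(2, 1); not PyTorch's implementation."""

    def __init__(self, weight=(0.0, 0.0), bias=0.0):
        self.weight = list(weight)
        self.bias = bias

    def __call__(self, x):
        return sum(w * v for w, v in zip(self.weight, x)) + self.bias

    def state_dict(self):
        # Only parameter values keyed by name; no code or class reference.
        return {"weight": list(self.weight), "bias": self.bias}

    def load_state_dict(self, state):
        # Loosely mirrors PyTorch's strict loading: keys must match exactly.
        if set(state) != {"weight", "bias"}:
            raise KeyError(f"unexpected keys: {sorted(state)}")
        self.weight = list(state["weight"])
        self.bias = state["bias"]


model = TinyLinear(weight=(0.5, -1.0), bias=0.25)
saved = json.dumps(model.state_dict())  # plain data, readable without this class

# Rebuild the architecture in code, then restore the values into it.
restored = TinyLinear()
restored.load_state_dict(json.loads(saved))
print(restored([2.0, 1.0]))  # 0.25
```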

TextWiser extends the state dictionary to also hold the non-parameter data it requires. For example, an `Embedding.TfIdf` model stores its internal Scikit-learn `TfidfVectorizer` object inside the state dictionary. Note that these objects **still get pickled**, so the Scikit-learn version needs to stay the same between saving and loading. The PyTorch version, however, **can** change without issues.
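
One way such an extended state dictionary could look is sketched below: plain numeric parameters sit next to a pickled helper object. The `VocabModel` class is hypothetical, and a `collections.Counter` stands in for the fitted `TfidfVectorizer`; TextWiser's actual implementation may differ. Recent PyTorch versions also provide `nn.Module.get_extra_state` and `set_extra_state` hooks for this kind of extra state, but whether TextWiser uses them is not assumed here.

```python
import pickle
from collections import Counter


class VocabModel:
    """Hypothetical module: numeric weights plus a fitted non-numeric helper."""

    def __init__(self):
        self.weight = [0.0, 0.0]
        self.vocab = None  # stands in for a fitted TfidfVectorizer

    def fit(self, docs):
        self.vocab = Counter(w for d in docs for w in d.lower().split())
        return self

    def state_dict(self):
        return {
            "weight": list(self.weight),        # plain numbers
            "extra": pickle.dumps(self.vocab),  # the helper object is still pickled
        }

    def load_state_dict(self, state):
        self.weight = list(state["weight"])
        # Unpickling needs a compatible version of the helper's library.
        self.vocab = pickle.loads(state["extra"])


docs = ["This is one document.", "This is a second document."]
original = VocabModel().fit(docs)
original.weight = [0.3, 0.7]

restored = VocabModel()  # rebuilt without fitting
restored.load_state_dict(original.state_dict())
print(restored.vocab["this"], restored.weight)  # 2 [0.3, 0.7]
```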

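```python
import torch

with NamedTemporaryFile() as file:
    # Save only the state dictionary, not the model object itself
    torch.save(model.state_dict(), file)
    # Get rid of the original model
    del tw
    del model
    # Create the same model architecture again
    tw = TextWiser(Embedding.Word(word_option=WordOptions.word2vec, pretrained='en-turian'),
                   [Transformation.SVD(n_components=2), Transformation.Pool(pool_option=PoolOptions.mean)],
                   dtype=torch.float32)
    # No documents are passed here; the fitted state is restored from the state dictionary below
    tw.fit()
    model = nn.Sequential(tw, nn.Linear(2, 1))
    # Rewind the file and load the saved state into the new model
    # (on PyTorch 2.6+, torch.load defaults to weights_only=True, so a state holding
    # pickled non-tensor objects may need weights_only=False)
    file.seek(0)
    model.load_state_dict(torch.load(file))
    # Do predictions with the loaded model
    predicted = model(docs)
    print(predicted)
```

```
tensor([[0.0605],
        [0.0669],
        [0.0652]], grad_fn=<AddmmBackward>)
```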


As shown here, the output is again identical before and after saving, even though the model was deleted and re-created in between.

While this approach does not fully remove the dependency on a specific environment, it does reduce it, and it has the added benefit of being compatible with third-party experiment management tools, which assume that model parameters are persisted through the state dictionary.