
A tool to safely load machine/deep learning models.


goeckslab/model-unpickler


model-unpickler

A tool to load machine/deep learning models securely.

Many machine/deep learning libraries (PyTorch, scikit-learn and others) save trained models using Python's pickle module, and pickle is well known for its ability to execute arbitrary code when loading objects from untrusted sources.
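To illustrate why this matters, here is a minimal, harmless sketch (the `Payload` class is hypothetical). An object's `__reduce__` method tells pickle which callable to invoke at load time, so simply calling `pickle.loads` runs that callable. Here it is only `print`; a real exploit could name something like `os.system` instead.

```python
import pickle


class Payload:
    """A stand-in for a malicious object; it only calls print()."""

    def __reduce__(self):
        # pickle records "call print with these arguments" instead of the object itself
        return (print, ("print() was called while unpickling",))


data = pickle.dumps(Payload())

# Loading the bytes runs print(): the side effect happens inside pickle.loads
result = pickle.loads(data)
print(result)  # None, the return value of print()
```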

This library provides a secure way to load pickled models by overriding the find_class method of Python's standard Unpickler class and checking every global name against a whitelist. Only globals in the whitelist are allowed in loaded model objects; when an untrusted global name is found, loading is interrupted to prevent any potential exploit.
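The mechanism can be sketched with the standard library alone, following the "Restricting Globals" pattern from the Python `pickle` documentation. This is not the package's actual `SafeUnpickler`; `AllowlistUnpickler`, `restricted_loads`, and the two-entry `ALLOWED_GLOBALS` set are illustrative assumptions.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allowlist; the real project ships its own whitelist file
ALLOWED_GLOBALS = {
    "collections.OrderedDict",
    "collections.Counter",
}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that resolves only globals listed in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        # find_class is called for every global reference in the pickle stream
        if f"{module}.{name}" in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        # Abort loading as soon as an untrusted global shows up
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    return AllowlistUnpickler(io.BytesIO(data)).load()


# An allowed global loads normally
print(restricted_loads(pickle.dumps(OrderedDict(a=1))))

# A disallowed global (here builtins.print) stops the load
try:
    restricted_loads(pickle.dumps(print))
except pickle.UnpicklingError as exc:
    print(exc)
```

Raising inside `find_class` means the forbidden callable is never even imported, let alone called.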

The library also provides utilities to quickly update the global whitelist when the corresponding machine learning libraries are updated.
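One plausible way to build such a list (an assumption, not necessarily what `scripts/gen_whitelist.py` does) is to load pickles produced by trusted code with an Unpickler that records every global passed to `find_class`. `RecordingUnpickler` and `collect_globals` are hypothetical names. Because this performs a real load, it must only be fed pickles you created yourself.

```python
import io
import pickle
from collections import Counter, OrderedDict


class RecordingUnpickler(pickle.Unpickler):
    """Records every global a pickle references while loading it normally."""

    def __init__(self, file):
        super().__init__(file)
        self.seen = set()

    def find_class(self, module, name):
        self.seen.add(f"{module}.{name}")
        return super().find_class(module, name)


def collect_globals(trusted_pickles):
    """Union of globals referenced by pickles produced from trusted code."""
    names = set()
    for data in trusted_pickles:
        unpickler = RecordingUnpickler(io.BytesIO(data))
        unpickler.load()  # safe only because the inputs are trusted
        names |= unpickler.seen
    return names


samples = [pickle.dumps(OrderedDict(a=1)), pickle.dumps(Counter("abc"))]
for entry in sorted(collect_globals(samples)):
    print(entry)
```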

Usage

For scikit-learn or other galaxy-ml supported models

from model_unpickler import safe_load_model

# Open the model file in binary mode; loading stops if a non-whitelisted global is found
with open('path_to_model', 'rb') as f:
    model = safe_load_model(f)

For torch models

import pickle

import torch
from model_unpickler import SafeUnpickler


# Override pickle's Unpickler with SafeUnpickler before calling `torch.load`
pickle.Unpickler = SafeUnpickler

model = torch.load('path_to_model')
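Replacing `pickle.Unpickler` changes it for the whole process. A hypothetical context manager like the one below limits the swap to a single block and restores the original afterwards. Two hedged caveats: the patch only affects code that looks up `pickle.Unpickler` at call time (`pickle.load` and `pickle.loads` are bound to the C implementation and are not affected), and whether a given `torch.load` version resolves its unpickler this way is worth verifying for the version you use. `LoggingUnpickler` stands in for `SafeUnpickler` so the sketch runs without the package installed.

```python
import contextlib
import io
import pickle


@contextlib.contextmanager
def replaced_unpickler(unpickler_cls):
    """Temporarily set pickle.Unpickler, restoring the original afterwards."""
    original = pickle.Unpickler
    pickle.Unpickler = unpickler_cls
    try:
        yield
    finally:
        pickle.Unpickler = original


class LoggingUnpickler(pickle.Unpickler):
    """Stand-in for SafeUnpickler so this sketch runs without the package."""

    def find_class(self, module, name):
        print(f"resolving {module}.{name}")
        return super().find_class(module, name)


def third_party_load(data):
    # Mimics a library that looks up pickle.Unpickler at call time
    return pickle.Unpickler(io.BytesIO(data)).load()


with replaced_unpickler(LoggingUnpickler):
    third_party_load(pickle.dumps(complex(1, 2)))  # prints "resolving builtins.complex"

# Outside the block, pickle.Unpickler is the original class again
```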

To generate a new whitelist

python scripts/gen_whitelist.py

To use a custom whitelist file, set the environment variable PK_WHITELIST_PATH

# linux
export PK_WHITELIST_PATH='path_to_new_whitelist_file'
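For illustration only, the lookup could work like the sketch below: prefer `PK_WHITELIST_PATH` when it is set, otherwise fall back to a bundled default. The default file name, the helper names, and the one-name-per-line file format are assumptions; check the generated whitelist for the real format.

```python
import os
import tempfile

DEFAULT_WHITELIST = "whitelist.txt"  # hypothetical bundled default path


def whitelist_path(default=DEFAULT_WHITELIST):
    """Prefer PK_WHITELIST_PATH when set, otherwise fall back to the default."""
    return os.environ.get("PK_WHITELIST_PATH", default)


def read_whitelist(path):
    """Read one fully qualified name per line, ignoring blanks and # comments."""
    with open(path, encoding="utf-8") as fh:
        return {
            line.strip()
            for line in fh
            if line.strip() and not line.lstrip().startswith("#")
        }


# Demo: point PK_WHITELIST_PATH at a temporary file
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("# custom whitelist\ncollections.OrderedDict\nnumpy.ndarray\n")

os.environ["PK_WHITELIST_PATH"] = tmp.name
print(sorted(read_whitelist(whitelist_path())))  # ['collections.OrderedDict', 'numpy.ndarray']
os.remove(tmp.name)
del os.environ["PK_WHITELIST_PATH"]
```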
