|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +"""Defense-in-depth RestrictedUnpickler for TensorWatch. |
| 5 | +
|
| 6 | +Blocks modules commonly exploited in pickle deserialization attacks while |
| 7 | +allowing the data types that TensorWatch legitimately serializes (numpy |
| 8 | +arrays, torch tensors, TensorWatch data classes, built-in collections, etc.). |
| 9 | +
|
| 10 | +WARNING: This is NOT a complete sandbox. A determined attacker may still find |
| 11 | +bypass techniques. Do not load pickle data from untrusted sources. |
| 12 | +""" |
| 13 | + |
| 14 | +import io |
| 15 | +import pickle |
| 16 | +import logging |
| 17 | + |
| 18 | +_BLOCKED_MODULES = frozenset({ |
| 19 | + # OS / filesystem access |
| 20 | + 'os', 'posix', 'nt', 'os.path', |
| 21 | + 'shutil', 'pathlib', |
| 22 | + 'tempfile', 'glob', 'fnmatch', |
| 23 | + # Process / subprocess execution |
| 24 | + 'subprocess', 'multiprocessing', |
| 25 | + 'pty', 'commands', |
| 26 | + # Code compilation / execution |
| 27 | + 'code', 'codeop', 'compileall', |
| 28 | + 'importlib', 'runpy', 'pkgutil', |
| 29 | + # Network |
| 30 | + 'socket', 'http', 'urllib', 'ftplib', 'smtplib', 'xmlrpc', |
| 31 | + 'socketserver', 'asyncio', |
| 32 | + # Low-level / FFI |
| 33 | + 'ctypes', 'mmap', |
| 34 | + # Interactive / debug |
| 35 | + 'pdb', 'profile', 'webbrowser', |
| 36 | + # Signal handling |
| 37 | + 'signal', |
| 38 | +}) |
| 39 | + |
| 40 | +# Specific names blocked from builtins module |
| 41 | +_BLOCKED_BUILTINS = frozenset({ |
| 42 | + 'eval', 'exec', 'compile', '__import__', |
| 43 | + 'open', 'input', 'breakpoint', |
| 44 | + 'exit', 'quit', |
| 45 | + 'globals', 'locals', 'vars', |
| 46 | + 'getattr', 'setattr', 'delattr', |
| 47 | +}) |
| 48 | + |
| 49 | + |
| 50 | +class RestrictedUnpickler(pickle.Unpickler): |
| 51 | + """Unpickler that blocks known-dangerous modules and callables. |
| 52 | +
|
| 53 | + Allowed: numpy, torch, tensorwatch, collections, standard data types. |
| 54 | + Blocked: os, subprocess, socket, ctypes, importlib, etc. |
| 55 | + """ |
| 56 | + |
| 57 | + def find_class(self, module, name): |
| 58 | + top_module = module.split('.')[0] |
| 59 | + |
| 60 | + if top_module in _BLOCKED_MODULES: |
| 61 | + raise pickle.UnpicklingError( |
| 62 | + "Blocked: unpickling {}.{} is not allowed " |
| 63 | + "(module '{}' is restricted)".format(module, name, top_module)) |
| 64 | + |
| 65 | + if top_module == 'builtins' and name in _BLOCKED_BUILTINS: |
| 66 | + raise pickle.UnpicklingError( |
| 67 | + "Blocked: unpickling builtins.{} is not allowed".format(name)) |
| 68 | + |
| 69 | + return super().find_class(module, name) |
| 70 | + |
| 71 | + |
| 72 | +def restricted_loads(data: bytes): |
| 73 | + """Deserialize bytes using RestrictedUnpickler.""" |
| 74 | + return RestrictedUnpickler(io.BytesIO(data)).load() |
| 75 | + |
| 76 | + |
| 77 | +def restricted_load(f): |
| 78 | + """Deserialize from a file object using RestrictedUnpickler.""" |
| 79 | + return RestrictedUnpickler(f).load() |
0 commit comments