Skip to content

Commit afe9390

Browse files
sytelusclaude
andcommitted
add RestrictedUnpickler and env-var HMAC key sharing for pickle hardening
Add safe_pickle.py with a RestrictedUnpickler that blocks known-dangerous modules (os, subprocess, socket, ctypes, importlib, etc.) during pickle deserialization as defense-in-depth. Replace raw pickle.load/loads with restricted variants in file_stream.py and zmq_wrapper.py. Add TENSORWATCH_HMAC_KEY environment variable support to get_hmac_key() so multi-process Watcher/WatcherClient setups can share the HMAC signing key without code changes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f08d801 commit afe9390

4 files changed

Lines changed: 109 additions & 7 deletions

File tree

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ TensorWatch supports Python 3.x and is tested with PyTorch 0.4-1.x. Most feature
5858
> - All incoming ZMQ messages are HMAC-SHA256 verified **before**
5959
> deserialization (`ZmqWrapper.verify_and_loads`). Messages with invalid
6060
> signatures are rejected without being deserialized.
61+
> - A `RestrictedUnpickler` blocks known-dangerous modules (`os`,
62+
> `subprocess`, `socket`, `ctypes`, etc.) as defense-in-depth.
63+
> - For multi-process setups, set the `TENSORWATCH_HMAC_KEY` environment
64+
> variable to a shared hex-encoded secret (e.g.
65+
> `export TENSORWATCH_HMAC_KEY=$(python -c "import os; print(os.urandom(32).hex())")`).
66+
> Alternatively, set `ZmqWrapper._hmac_key` directly in code before
67+
> calling `initialize()`.
6168
>
6269
> **User responsibilities:**
6370
> - Ensure the HMAC key is kept secret and shared only with trusted processes.
@@ -68,6 +75,11 @@ TensorWatch supports Python 3.x and is tested with PyTorch 0.4-1.x. Most feature
6875
> `FileStream` (in `file_stream.py`) uses `pickle.load()` to read stream data
6976
> from files. A crafted pickle file can execute arbitrary code when loaded.
7077
>
78+
> **Mitigations in place:**
79+
> - A `RestrictedUnpickler` blocks known-dangerous modules (`os`,
80+
> `subprocess`, etc.) as defense-in-depth. This is **not** a complete
81+
> sandbox — determined attackers may find bypasses.
82+
>
7183
> **User responsibilities:**
7284
> - **Only open TensorWatch data files (`.log`, `.pkl`) that you created
7385
> yourself or that come from a source you fully trust.**

tensorwatch/file_stream.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pickle, os
66
from typing import Any
77
from . import utils
8+
from .safe_pickle import restricted_load
89
import time
910

1011
class FileStream(Stream):
@@ -38,7 +39,7 @@ def read_all(self, from_stream:'Stream'=None):
3839
if self._file is not None:
3940
self._file.seek(0, 0) # we may filter this stream multiple times
4041
while not utils.is_eof(self._file):
41-
yield pickle.load(self._file)
42+
yield restricted_load(self._file)
4243
for item in super(FileStream, self).read_all():
4344
yield item
4445

@@ -48,7 +49,7 @@ def load(self, from_stream:'Stream'=None):
4849
if self._file is not None:
4950
self._file.seek(0, 0) # we may filter this stream multiple times
5051
while not utils.is_eof(self._file):
51-
stream_item = pickle.load(self._file)
52+
stream_item = restricted_load(self._file)
5253
self.write(stream_item)
5354
super(FileStream, self).load()
5455

tensorwatch/safe_pickle.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""Defense-in-depth RestrictedUnpickler for TensorWatch.
5+
6+
Blocks modules commonly exploited in pickle deserialization attacks while
7+
allowing the data types that TensorWatch legitimately serializes (numpy
8+
arrays, torch tensors, TensorWatch data classes, built-in collections, etc.).
9+
10+
WARNING: This is NOT a complete sandbox. A determined attacker may still find
11+
bypass techniques. Do not load pickle data from untrusted sources.
12+
"""
13+
14+
import io
15+
import pickle
16+
import logging
17+
18+
_BLOCKED_MODULES = frozenset({
19+
# OS / filesystem access
20+
'os', 'posix', 'nt', 'os.path',
21+
'shutil', 'pathlib',
22+
'tempfile', 'glob', 'fnmatch',
23+
# Process / subprocess execution
24+
'subprocess', 'multiprocessing',
25+
'pty', 'commands',
26+
# Code compilation / execution
27+
'code', 'codeop', 'compileall',
28+
'importlib', 'runpy', 'pkgutil',
29+
# Network
30+
'socket', 'http', 'urllib', 'ftplib', 'smtplib', 'xmlrpc',
31+
'socketserver', 'asyncio',
32+
# Low-level / FFI
33+
'ctypes', 'mmap',
34+
# Interactive / debug
35+
'pdb', 'profile', 'webbrowser',
36+
# Signal handling
37+
'signal',
38+
})
39+
40+
# Specific names blocked from builtins module
41+
_BLOCKED_BUILTINS = frozenset({
42+
'eval', 'exec', 'compile', '__import__',
43+
'open', 'input', 'breakpoint',
44+
'exit', 'quit',
45+
'globals', 'locals', 'vars',
46+
'getattr', 'setattr', 'delattr',
47+
})
48+
49+
50+
class RestrictedUnpickler(pickle.Unpickler):
51+
"""Unpickler that blocks known-dangerous modules and callables.
52+
53+
Allowed: numpy, torch, tensorwatch, collections, standard data types.
54+
Blocked: os, subprocess, socket, ctypes, importlib, etc.
55+
"""
56+
57+
def find_class(self, module, name):
58+
top_module = module.split('.')[0]
59+
60+
if top_module in _BLOCKED_MODULES:
61+
raise pickle.UnpicklingError(
62+
"Blocked: unpickling {}.{} is not allowed "
63+
"(module '{}' is restricted)".format(module, name, top_module))
64+
65+
if top_module == 'builtins' and name in _BLOCKED_BUILTINS:
66+
raise pickle.UnpicklingError(
67+
"Blocked: unpickling builtins.{} is not allowed".format(name))
68+
69+
return super().find_class(module, name)
70+
71+
72+
def restricted_loads(data: bytes):
73+
"""Deserialize bytes using RestrictedUnpickler."""
74+
return RestrictedUnpickler(io.BytesIO(data)).load()
75+
76+
77+
def restricted_load(f):
78+
"""Deserialize from a file object using RestrictedUnpickler."""
79+
return RestrictedUnpickler(f).load()

tensorwatch/zmq_wrapper.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import functools, sys, logging
1313
from threading import Thread, Event
1414
from . import utils
15+
from .safe_pickle import restricted_loads
1516
import weakref, logging
1617

1718
class ZmqWrapper:
@@ -25,12 +26,21 @@ class ZmqWrapper:
2526
@staticmethod
2627
def get_hmac_key() -> bytes:
2728
"""Get or generate the HMAC signing key used to authenticate ZMQ messages.
28-
Set ZmqWrapper._hmac_key before calling initialize() to use a specific
29-
key (e.g. for multi-process setups where Watcher and WatcherClient run
30-
in separate processes).
29+
30+
Key resolution order:
31+
1. ZmqWrapper._hmac_key if set directly (e.g. via application code).
32+
2. TENSORWATCH_HMAC_KEY environment variable (hex-encoded, for
33+
multi-process setups where Watcher and WatcherClient run in
34+
separate processes).
35+
3. A random 32-byte key generated with os.urandom (single-process
36+
default).
3137
"""
3238
if ZmqWrapper._hmac_key is None:
33-
ZmqWrapper._hmac_key = os.urandom(32)
39+
env_key = os.environ.get('TENSORWATCH_HMAC_KEY')
40+
if env_key:
41+
ZmqWrapper._hmac_key = bytes.fromhex(env_key)
42+
else:
43+
ZmqWrapper._hmac_key = os.urandom(32)
3444
return ZmqWrapper._hmac_key
3545

3646
@staticmethod
@@ -52,7 +62,7 @@ def verify_and_loads(signed_data: bytes):
5262
expected = hmac.new(ZmqWrapper.get_hmac_key(), payload, hashlib.sha256).digest()
5363
if not hmac.compare_digest(sig, expected):
5464
raise ValueError("HMAC verification failed - rejecting untrusted message")
55-
return pickle.loads(payload)
65+
return restricted_loads(payload)
5666

5767
@staticmethod
5868
def initialize():

0 commit comments

Comments
 (0)