Skip to content

Commit

Permalink
Pytorch Load / Save Plugin
Browse files Browse the repository at this point in the history
This plugin checks for the use of `torch.load` and `torch.save`.
Using `torch.load` with untrusted data can lead to arbitrary code
execution, and improper use of `torch.save` might expose sensitive
data or lead to data corruption.

Signed-off-by: Luke Hinds <luke@stacklok.com>
  • Loading branch information
lukehinds committed Mar 3, 2024
1 parent 4c5b3c8 commit 8b92a02
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 0 deletions.
25 changes: 25 additions & 0 deletions bandit/blacklists/calls.py
Expand Up @@ -320,6 +320,19 @@
| | | - os.tmpnam | |
+------+---------------------+------------------------------------+-----------+
B704: pytorch_load_save
Use of unsafe PyTorch load. `torch.load` can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead to data
corruption.
+------+---------------------+--------------------------------------+-----------+
| ID | Name | Calls | Severity |
+======+=====================+======================================+===========+
| B704 | pytorch_load_save| | - torch.load | Medium |
| B704 | pytorch_load_save| | - torch.save | Medium |
+------+---------------------+--------------------------------------+-----------+
"""
import sys

Expand Down Expand Up @@ -685,6 +698,18 @@ def gen_blacklist():
)
)

sets.append(
utils.build_conf_dict(
"pytorch_load_save",
"B704",
issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
["torch.load", "torch.save"],
"Use of unsafe PyTorch load or save",
"MEDIUM",
)
)


# skipped B324 (used in bandit/plugins/hashlib_new_insecure_functions.py)

# skipped B325 as the check for a call to os.tempnam and os.tmpnam have
Expand Down
66 changes: 66 additions & 0 deletions bandit/plugins/pytorch_load_save.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
=========================================
B704: Test for unsafe PyTorch load or save
=========================================
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.
:Example:
.. code-block:: none
>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
9
10 print("Model loaded successfully!")
.. seealso::
- https://cwe.mitre.org/data/definitions/94.html
.. versionadded:: 1.7.8
"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id("B704") # Ensure the test ID is unique and does not conflict with existing Bandit tests
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load"],
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
text="Use of unsafe PyTorch load or save",
cwe=issue.Cwe.UNTRUSTED_INPUT,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
@@ -0,0 +1,5 @@
-----------------------
B704: pytorch_load_save
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
16 changes: 16 additions & 0 deletions examples/pytorch_load_save.py
@@ -0,0 +1,16 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

print("Model loaded successfully!")
3 changes: 3 additions & 0 deletions setup.cfg
Expand Up @@ -148,6 +148,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

[build_sphinx]
all_files = 1
build-dir = doc/build
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Expand Up @@ -930,3 +930,11 @@ def test_tarfile_unsafe_members(self):
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 1},
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 1, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)

0 comments on commit 8b92a02

Please sign in to comment.