fix: replace eval with a safer alternative #147

Merged
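This PR replaces the eval()-based evaluation of trainer-controller rules with the simpleeval package's EvalWithCompoundTypes, which evaluates a restricted expression grammar against an explicit set of names and functions. A minimal sketch of the change in approach (not code from this PR; the rule and metric value are illustrative):

from simpleeval import EvalWithCompoundTypes

metrics = {"loss": 0.42}  # illustrative metric values

# Old approach: eval() with __builtins__ stripped; attribute chains such as
# "().__class__.__base__.__subclasses__()" can still escape the intended sandbox.
unsafe = eval("loss < 1.0", {"__builtins__": None}, metrics)  # pylint: disable=eval-used

# New approach: a restricted evaluator over an explicit set of names; dunder
# attributes and undefined functions are rejected instead of being executed.
safe = EvalWithCompoundTypes(names=metrics).eval("loss < 1.0")

assert unsafe is True and safe is True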
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -36,7 +36,8 @@ dependencies = [
"trl",
"peft>=0.8.0",
"datasets>=2.15.0",
"fire"
"fire",
"simpleeval",
]

[project.optional-dependencies]
3 changes: 3 additions & 0 deletions tests/data/trainercontroller/__init__.py
@@ -29,6 +29,9 @@
TRAINER_CONFIG_INCORRECT_SOURCE_EVENT_EXPOSED_METRICS_YAML = os.path.join(
_DATA_DIR, "incorrect_source_event_exposed_metrics.yaml"
)
TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_invalid_type_rule.yaml"
)
TRAINER_CONFIG_TEST_MALICIOUS_OS_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_malicious_os_rule.yaml"
)
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_with_invalid_type_rule.yaml
@@ -0,0 +1,10 @@
controller-metrics:
- name: loss
class: Loss
controllers:
- name: loss-controller-wrong-os-rule
triggers:
- on_log
rule: "2+2"
operations:
- hfcontrols.should_training_stop
29 changes: 25 additions & 4 deletions tests/trainercontroller/test_tuning_trainercontroller.py
@@ -17,9 +17,9 @@

# Standard
from dataclasses import dataclass
- from typing import Any

# Third Party
+ from simpleeval import FunctionNotDefined
from transformers import IntervalStrategy, TrainerControl, TrainerState
import pytest

@@ -32,7 +32,6 @@
import tests.data.trainercontroller as td

# Local
- from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler
import tuning.config.configs as config
import tuning.trainercontroller as tc

@@ -204,6 +203,25 @@ def test_custom_operation_invalid_action_handler():
)


def test_invalid_type_rule():
"""Tests the invalid type rule using configuration
`examples/trainer-controller-configs/loss_with_invalid_type_rule.yaml`
"""
test_data = _setup_data()
with pytest.raises(TypeError) as exception_handler:
tc_callback = tc.TrainerControllerCallback(
td.TRAINER_CONFIG_TEST_INVALID_TYPE_RULE_YAML
)
control = TrainerControl(should_training_stop=False)
# Trigger on_init_end to perform registration of handlers to events
tc_callback.on_init_end(
args=test_data.args, state=test_data.state, control=control
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
assert str(exception_handler.value) == "Rule failed due to incorrect type usage"


def test_malicious_os_rule():
"""Tests the malicious rule using configuration
`examples/trainer-controller-configs/loss_with_malicious_os_rule.yaml`
@@ -235,14 +253,17 @@ def test_malicious_input_rule():
td.TRAINER_CONFIG_TEST_MALICIOUS_INPUT_RULE_YAML
)
control = TrainerControl(should_training_stop=False)
- with pytest.raises(TypeError) as exception_handler:
+ with pytest.raises(FunctionNotDefined) as exception_handler:
# Trigger on_init_end to perform registration of handlers to events
tc_callback.on_init_end(
args=test_data.args, state=test_data.state, control=control
)
# Trigger rule and test the condition
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control)
- assert str(exception_handler.value) == "Rule failed due to incorrect type usage"
+ assert (
+ str(exception_handler.value)
+ == "Function 'input' not defined, for expression 'input('Please enter your password:')'."
+ )


def test_invalid_trigger():
166 changes: 166 additions & 0 deletions tests/utils/test_evaluator.py
@@ -0,0 +1,166 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from typing import Tuple

# Third Party
import numpy as np
import pytest

# Local
from tuning.utils.evaluator import get_evaluator


def test_mailicious_inputs_to_eval():
"""Tests the malicious rules"""
rules: list[Tuple[str, bool, str]] = [
# Valid rules
("", False, "flags['is_training'] == False"),
("", False, "not flags['is_training']"),
("", True, "-10 < loss"),
("", True, "+1000 > loss"),
("", True, "~1000 < loss"),
("", True, "(10 + 10) < loss"),
("", True, "(20 - 10) < loss"),
("", True, "(20/10) < loss"),
("", True, "(20 % 10) < loss"),
("", False, "loss < 1.0"),
("", False, "(loss < 1.0)"),
("", False, "loss*loss < 1.0"),
("", False, "loss*loss*loss < 1.0"),
("", False, "(loss*loss)*loss < 1.0"),
("", True, "int(''.join(['3', '4'])) < loss"),
("", True, "loss < 9**9"),
("", False, "loss < sqrt(xs[0]*xs[0] + xs[1]*xs[1])"),
("", True, "len(xs) > 2"),
("", True, "loss < abs(-100)"),
("", True, "loss == flags.aaa.bbb[0].ccc"),
("", True, "array3d[0][1][1] == 4"),
("", True, "numpyarray[0][1][1] == 4"),
(
"",
True,
"len(xs) == 4 and xs[0] == 1 and (xs[1] == 0 or xs[2] == 0) and xs[3] == 2",
),
# Invalid rules
(
"'aaa' is not defined for expression 'loss == aaa.bbb[0].ccc'",
False,
"loss == aaa.bbb[0].ccc",
),
("0", False, "loss == flags[0].ccc"), # KeyError
(
"Attribute 'ddd' does not exist in expression 'loss == flags.ddd[0].ccc'",
False,
"loss == flags.ddd[0].ccc",
),
(
"Sorry, access to __attributes or func_ attributes is not available. (__class__)",
False,
"'x'.__class__",
),
(
"Lambda Functions not implemented",
False,
# Try to instantiate and call Quitter
"().__class__.__base__.__subclasses__()[141]('', '')()",
),
(
"Lambda Functions not implemented",
False,
# pylint: disable=line-too-long
"[x for x in ().__class__.__base__.__subclasses__() if x.__name__ == 'Quitter'][0]('', '')()",
),
(
"Function 'getattr' not defined, for expression 'getattr((), '__class__')'.",
False,
"getattr((), '__class__')",
),
(
"Function 'getattr' not defined, for expression 'getattr((), '_' '_class_' '_')'.",
False,
"getattr((), '_' '_class_' '_')",
),
(
"Sorry, I will not evalute something that long.",
False,
'["hello"]*10000000000',
),
(
"Sorry, I will not evalute something that long.",
False,
"'i want to break free'.split() * 9999999999",
),
(
"Lambda Functions not implemented",
False,
"(lambda x='i want to break free'.split(): x * 9999999999)()",
),
(
"Sorry, NamedExpr is not available in this evaluator",
False,
"(x := 'i want to break free'.split()) and (x * 9999999999)",
),
("Sorry! I don't want to evaluate 9 ** 387420489", False, "9**9**9**9"),
(
"Function 'mymetric1' not defined, for expression 'mymetric1() > loss'.",
True,
"mymetric1() > loss",
),
(
"Function 'mymetric2' not defined, for expression 'mymetric2(loss) > loss'.",
True,
"mymetric2(loss) > loss",
),
]
metrics = {
"loss": 42.0,
"flags": {"is_training": True, "aaa": {"bbb": [{"ccc": 42.0}]}},
"xs": [1, 0, 0, 2],
"array3d": [
[
[1, 2],
[3, 4],
],
[
[5, 6],
[7, 8],
],
],
"numpyarray": (np.arange(8).reshape((2, 2, 2)) + 1),
}

evaluator = get_evaluator(metrics=metrics)

for validation_error, expected_rule_is_true, rule in rules:
rule_parsed = evaluator.parse(expr=rule)
if validation_error == "":
actual_rule_is_true = evaluator.eval(
expr=rule,
previously_parsed=rule_parsed,
)
assert (
actual_rule_is_true == expected_rule_is_true
), "failed to execute the rule"
else:
with pytest.raises(Exception) as exception_handler:
evaluator.eval(
expr=rule,
previously_parsed=rule_parsed,
)
assert str(exception_handler.value) == validation_error
37 changes: 26 additions & 11 deletions tuning/trainercontroller/callback.py
@@ -17,12 +17,13 @@

# Standard
from importlib import resources as impresources
- from typing import List, Union
+ from typing import Dict, List, Union
import inspect
import os
import re

# Third Party
+ from simpleeval import EvalWithCompoundTypes, FeatureNotAvailable, NameNotDefined
from transformers import (
TrainerCallback,
TrainerControl,
@@ -43,6 +44,7 @@
from tuning.trainercontroller.operations import (
operation_handlers as default_operation_handlers,
)
from tuning.utils.evaluator import get_evaluator

logger = logging.get_logger(__name__)

@@ -174,7 +176,7 @@ def __init__(self, trainer_controller_config: Union[dict, str]):
self.register_operation_handlers(default_operation_handlers)

# controls
- self.control_actions_on_event: dict[str, list[Control]] = {}
+ self.control_actions_on_event: Dict[str, list[Control]] = {}

# List of fields produced by the metrics
self.metrics = {}
@@ -208,23 +210,26 @@ def _compute_metrics(self, event_name: str, **kwargs):
self.metrics[m.get_name()] = m.compute(event_name=event_name, **kwargs)

def _take_control_actions(self, event_name: str, **kwargs):
"""Invokes the act() method for all the operations registered for a given event. \
Note here that the eval() is invoked with `__builtins__` set to None. \
This is a precaution to restric the scope of eval(), to only the \
fields produced by the metrics.
"""Invokes the act() method for all the operations registered for a given event.

Args:
event_name: str. Event name.
kwargs: List of arguments (key, value)-pairs.
"""
if event_name in self.control_actions_on_event:
+ evaluator = get_evaluator(metrics=self.metrics)
for control_action in self.control_actions_on_event[event_name]:
rule_succeeded = False
try:
- # pylint: disable=eval-used
- rule_succeeded = eval(
- control_action.rule, {"__builtins__": None}, self.metrics
+ rule_succeeded = evaluator.eval(
+ expr=control_action.rule_str,
+ previously_parsed=control_action.rule,
)
+ if not isinstance(rule_succeeded, bool):
+ raise TypeError(
+ "expected the rule to evaluate to a boolean. actual type: %s"
+ % (type(rule_succeeded))
+ )
except TypeError as et:
raise TypeError("Rule failed due to incorrect type usage") from et
except ValueError as ev:
@@ -235,6 +240,14 @@ def _take_control_actions(self, event_name: str, **kwargs):
raise NameError(
"Rule failed due to use of disallowed variables"
) from en
except NameNotDefined as en1:
raise NameError(
"Rule failed because some of the variables are not defined"
) from en1
except FeatureNotAvailable as ef:
raise NotImplementedError(
"Rule failed because it uses some unsupported features"
) from ef
if rule_succeeded:
for operation_action in control_action.operation_actions:
logger.info(
@@ -374,9 +387,11 @@ def on_init_end(
% (controller_name, event_name)
)
# Generates the byte-code for the rule from the trainer configuration
+ curr_rule = controller[CONTROLLER_RULE_KEY]
control = Control(
- name=controller_name,
- rule=compile(controller_rule, "", "eval"),
+ name=controller[CONTROLLER_NAME_KEY],
+ rule_str=curr_rule,
+ rule=EvalWithCompoundTypes.parse(expr=curr_rule),
operation_actions=[],
)
for control_operation_name in controller_ops:
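For context, the new exception handling above converts simpleeval's own error types into standard Python exceptions. A minimal sketch of the underlying behavior (the rule strings are illustrative, and the evaluator is built directly rather than through the callback):

from simpleeval import EvalWithCompoundTypes, FeatureNotAvailable, NameNotDefined

evaluator = EvalWithCompoundTypes(names={"loss": 2.0})

for rule in ["undefined_metric < loss", "'x'.__class__"]:
    try:
        evaluator.eval(rule)
    except NameNotDefined as err:
        # surfaced by the callback as NameError("Rule failed because some of the variables are not defined")
        print(f"{rule!r}: {err}")
    except FeatureNotAvailable as err:
        # surfaced by the callback as NotImplementedError("Rule failed because it uses some unsupported features")
        print(f"{rule!r}: {err}")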
4 changes: 3 additions & 1 deletion tuning/trainercontroller/control.py
@@ -18,6 +18,7 @@
# Standard
from dataclasses import dataclass
from typing import List, Optional
import ast

# Local
from tuning.trainercontroller.operations import Operation
@@ -36,5 +37,6 @@ class Control:
"""Stores the name of control, rule byte-code corresponding actions"""

name: str
- rule: Optional[object] = None # stores bytecode of the compiled rule
+ rule_str: str
+ rule: Optional[ast.AST] = None # stores the abstract syntax tree of the parsed rule
operation_actions: Optional[List[OperationAction]] = None
20 changes: 20 additions & 0 deletions tuning/utils/evaluator.py
@@ -0,0 +1,20 @@
# Standard
from math import sqrt

# Third Party
from simpleeval import DEFAULT_FUNCTIONS, DEFAULT_NAMES, EvalWithCompoundTypes


def get_evaluator(metrics: dict) -> EvalWithCompoundTypes:
"""Returns an evaluator that can be used to evaluate simple Python expressions."""
all_names = {
**metrics,
**DEFAULT_NAMES.copy(),
}
all_funcs = {
"abs": abs,
"len": len,
"sqrt": sqrt,
**DEFAULT_FUNCTIONS.copy(),
}
return EvalWithCompoundTypes(functions=all_funcs, names=all_names)
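For reference, a minimal usage sketch of the new helper, following the parse-then-eval pattern exercised in tests/utils/test_evaluator.py (the metric values here are illustrative):

from tuning.utils.evaluator import get_evaluator

metrics = {"loss": 42.0, "xs": [1, 0, 0, 2]}  # illustrative metric values
evaluator = get_evaluator(metrics=metrics)

rule = "len(xs) > 2 and loss > sqrt(xs[0] * xs[0] + xs[3] * xs[3])"
parsed = evaluator.parse(expr=rule)  # parse once, e.g. when the controller config is loaded
result = evaluator.eval(expr=rule, previously_parsed=parsed)  # evaluate against current metrics
assert result is True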