STL-related updates were added in evaluation
acarcelik committed Jan 25, 2024
1 parent e2ff59c commit 377bfe8
Showing 5 changed files with 159 additions and 51 deletions.
6 changes: 5 additions & 1 deletion bark_ml/environments/counterfactual_runtime.py
@@ -9,6 +9,7 @@
import numpy as np
import logging
import matplotlib.pyplot as plt
import types

# bark
from bark.runtime.commons.parameters import ParameterServer
@@ -87,7 +88,10 @@ def ReplaceBehaviorModel(self, agent_id=None, behavior=None):
cloned_world = self._world.Copy()
evaluators = self._evaluator._bark_eval_fns
for eval_key, eval_fn in evaluators.items():
cloned_world.AddEvaluator(eval_key, eval_fn())
if isinstance(eval_fn, types.LambdaType):
cloned_world.AddEvaluator(eval_key, eval_fn())
else:
cloned_world.AddEvaluator(eval_key, eval_fn)
if behavior is not None:
cloned_world.agents[agent_id].behavior_model = behavior
return cloned_world
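
For context on the branch added above: entries in the evaluator dictionary may now be either zero-argument factories (lambdas) or already-constructed evaluator objects (as the STL rule evaluators are stored in evaluator_configs.py below), so the copy has to decide whether to call the entry or register it directly. A minimal standalone sketch of the same pattern; the _FakeWorld class and the example keys are illustrative, not part of the repository:

import types

class _FakeWorld:
    """Illustrative stand-in for a BARK world, only to make the sketch runnable."""
    def __init__(self):
        self.evaluators = {}

    def AddEvaluator(self, key, evaluator):
        self.evaluators[key] = evaluator

def add_evaluators(world, eval_fns):
    for key, fn in eval_fns.items():
        # types.LambdaType is an alias of types.FunctionType, so both lambdas and
        # plain functions are treated as zero-argument factories and get called;
        # anything else (e.g. an EvaluatorSTL instance) is registered as-is.
        world.AddEvaluator(key, fn() if isinstance(fn, types.LambdaType) else fn)

world = _FakeWorld()
add_evaluators(world, {
    "collision": lambda: object(),  # factory: called to build a fresh evaluator
    "rule_g1": object(),            # pre-built instance: registered directly
})
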
74 changes: 58 additions & 16 deletions bark_ml/evaluators/evaluator_configs.py
@@ -8,6 +8,8 @@
EvaluatorStepCount, EvaluatorDrivableArea
from bark.core.geometry import Point2d
from bark_ml.evaluators.general_evaluator import *
from bark_ml.evaluators.stl.evaluator_stl import *
from bark_ml.evaluators.stl.label_functions.safe_distance_label_function import *

class GoalReached(GeneralEvaluator):
def __init__(self, params):
@@ -130,8 +132,22 @@ def __init__(self, params):
"pot_center_functor": PotentialCenterlineFunctor(self._params),
"pot_vel_functor": PotentialVelocityFunctor(self._params)
})

class EvaluatorConfigurator(GeneralEvaluator):
def __init__(self, params):
self._params = params

try:
quantized = self._params["ML"]["EvaluatorConfigurator"]["RulesConfigs"]["quantized"]
except KeyError:
quantized = False

# rule_functor_prefix = "TrafficRuleSTL" if quantized else "TrafficRuleLTL"
# rule_functor_name = f"{rule_functor_prefix}Functor"

# rule_impl_prefix = "traffic_rule_stl" if quantized else "traffic_rule_ltl"
# rule_impl_name = f"{rule_impl_prefix}_functor"

# add mapping of functors to keys
self._fn_key_map = {
"CollisionFunctor" : "collision_functor",
@@ -149,12 +165,13 @@ def __init__(self, params):
"CollisionDrivableAreaFunctor" : "collision_drivable_area_functor",
"PotentialGoalReachedVelocityFunctor": "pot_goal_vel_functor",
"MaxStepCountAsGoalFunctor": "max_step_count_as_goal_functor",
"PotentialGoalPolyFunctor": "pot_goal_poly_functor",
"TrafficRuleLTLFunctor": "traffic_rule_ltl_functor"
"PotentialGoalPolyFunctor": "pot_goal_poly_functor",
# rule_functor_name: rule_impl_name
}
self._params = params

functor_configs = self._params["ML"]["EvaluatorConfigurator"]["EvaluatorConfigs"]["FunctorConfigs"]
functor_config_params_dict = functor_configs.ConvertToDict()

# initialize functor and functor-weight dicts
eval_fns = {}
# get values for each item
@@ -173,29 +190,54 @@ def __init__(self, params):
rules_configs = self._params["ML"]["EvaluatorConfigurator"]["RulesConfigs"]

for rule_config in rules_configs["Rules"]:
# print("Rule name:", rule_config["RuleName"])

# parse label function for each rule
labels_list = []

for label_conf in rule_config["RuleConfig"]["labels"]:
label_params_dict = label_conf["params"].ConvertToDict()

if label_conf["type"] == "EgoBeyondPointLabelFunction" or label_conf["type"] == "AgentBeyondPointLabelFunction":
merge_point = label_params_dict["point"]
label_params_dict["point"] = Point2d(merge_point[0],merge_point[1])
labels_list.append(eval("{}(*(label_params_dict.values()))".format(label_conf["type"])))
print("labels_list:",labels_list)
label = eval("{}(*(label_params_dict.values()))".format(label_conf["type"]))
labels_list.append(label)

# instance rule evaluator for each rule
#TODO: check if evaluatorLTL can access private function in python
ltl_formula_ = rule_config["RuleConfig"]["params"]["formula"]
print("ltl_formula_:",ltl_formula_)
tmp_ltl_settings = {}
tmp_ltl_settings["agent_id"] = 1
tmp_ltl_settings["ltl_formula"] = ltl_formula_
tmp_ltl_settings["label_functions"] = labels_list

tmp_ltl_eval = eval("{}(**tmp_ltl_settings)".format(rule_config["RuleConfig"]["type"]))
bark_evals[rule_config["RuleName"]] = lambda: tmp_ltl_eval
tl_formula_ = rule_config["RuleConfig"]["params"]["formula"]
# print("ltl_formula_:",tl_formula_)
# print("labels_list:",labels_list)

try:
eval_return_robustness_only = rule_config["RuleConfig"]["params"]["eval_return_robustness_only"]
except KeyError:
eval_return_robustness_only = True

tmp_tl_settings = {}
# Check if the key exists in tmp_tl_settings; if not, create a nested dictionary
if rule_config["RuleName"] not in tmp_tl_settings:
tmp_tl_settings[rule_config["RuleName"]] = {}

tmp_tl_settings[rule_config["RuleName"]]["agent_id"] = 1
tmp_tl_settings[rule_config["RuleName"]]["ltl_formula"] = tl_formula_
tmp_tl_settings[rule_config["RuleName"]]["label_functions"] = labels_list

if quantized:
tmp_tl_settings[rule_config["RuleName"]]["eval_return_robustness_only"] = eval_return_robustness_only

tmp_tl_eval = eval("{}(**tmp_tl_settings[rule_config['RuleName']])".format(rule_config["RuleConfig"]["type"]))
# print("tmp_tl_eval:", tmp_tl_eval)
# print("lambda: tmp_tl_eval: ", lambda: tmp_tl_eval)
bark_evals[rule_config["RuleName"]] = tmp_tl_eval
# bark_evals[rule_config["RuleName"]] = lambda: tmp_tl_eval

# add rule functors to bark_ml_eval_fns
functor_n_ = rule_config["RuleName"] + "_ltl_functor"
eval_fns[functor_n_] = eval("{}(rule_config)".format("TrafficRuleLTLFunctor"))
functor_n_ = rule_config["RuleName"] + ("_stl_functor" if quantized else "_ltl_functor")
eval_fns[functor_n_] = eval("{}(rule_config)".format("TrafficRuleSTLFunctor" if quantized else "TrafficRuleLTLFunctor"))

# print("bark_evals: ", bark_evals)
# print("eval_fns: ", eval_fns)

super().__init__(params=self._params, bark_eval_fns=bark_evals, bark_ml_eval_fns=eval_fns)
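
The parameter layout that drives this constructor is not shown in the diff; the sketch below is an assumed structure inferred only from the keys the code reads above (ML, EvaluatorConfigurator, EvaluatorConfigs, FunctorConfigs, RulesConfigs, quantized, Rules, RuleName, RuleConfig, type, params, formula, eval_return_robustness_only, labels, point). All concrete values, the label_str field, and the formula are illustrative, and the real code reads these through a ParameterServer rather than a plain dict:

example_params = {
    "ML": {
        "EvaluatorConfigurator": {
            "EvaluatorConfigs": {
                "FunctorConfigs": {
                    # Per-functor parameter blocks; their contents are collapsed
                    # in this diff, so they are left empty here.
                    "CollisionFunctor": {},
                    "GoalFunctor": {},
                },
            },
            "RulesConfigs": {
                "quantized": True,  # True -> STL evaluator/functor, False -> LTL
                "Rules": [
                    {
                        "RuleName": "merge_rule",
                        "RuleConfig": {
                            "type": "EvaluatorSTL",
                            "params": {
                                "formula": "G !beyond_merge_point",    # example formula
                                "eval_return_robustness_only": False,  # defaults to True if absent
                            },
                            "labels": [
                                {
                                    # Constructor args are passed positionally from this
                                    # dict's values; "point" is rewritten to a Point2d
                                    # above, the other fields are assumptions.
                                    "type": "EgoBeyondPointLabelFunction",
                                    "params": {"label_str": "beyond_merge_point",
                                               "point": [104.0, 108.0]},
                                },
                            ],
                        },
                    },
                ],
            },
        },
    },
}

With quantized set to True, the loop above would register a bark evaluator under "merge_rule" and a reward functor under "merge_rule_stl_functor".
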
55 changes: 53 additions & 2 deletions bark_ml/evaluators/general_evaluator.py
@@ -4,6 +4,7 @@
# This software is released under the MIT License.
# https://opensource.org/licenses/MIT
import numpy as np
import types

from bark.core.world.evaluation import \
EvaluatorGoalReached, EvaluatorCollisionEgoAgent, \
@@ -245,7 +246,7 @@ def __call__(self, observed_world, action, eval_results):
cur_pot = self.DistancePotential(
cur_dist, self._params["MaxDist", "", 100.],
self._params["DistExponent", "", 0.2])
print("!!!!!!!!!!!!!current potential is ", 0.99*cur_pot - prev_pot)
# print("!!!!!!!!!!!!!current potential is ", 0.99*cur_pot - prev_pot)
return False, self.WeightedReward(self._params["Gamma", "", 0.99]*cur_pot - prev_pot), {}
return False, 0, {}

@@ -423,6 +424,48 @@ def Reset(self):
# TODO: MIN/MAX functor for defined state value
# TODO: Deviation functor for state-difference (desired vel. and x,y)

class TrafficRuleSTLFunctor(Functor):
    def __init__(self, params):
        self._params = params
        self.traffic_rule_violation_pre = 0
        self.traffic_rule_violation_post = 0
        self.traffic_rule_violations = 0
        self.traffic_rule_eval_result = ""
        super().__init__(params=self._params)

    def __call__(self, observed_world, action, eval_results):
        self.traffic_rule_eval_result = eval_results[self._params["RuleName"]]
        # print("Eval result in functor: ", self.traffic_rule_eval_result)

        if isinstance(self.traffic_rule_eval_result, str):
            results = self.traffic_rule_eval_result.split(";")
            self.traffic_rule_violation_post = float(results[0])
            self.traffic_rule_robustness = float(results[1])

            max_vio_num = self._params["ViolationTolerance", "", 15]
            if self.traffic_rule_violation_post < self.traffic_rule_violation_pre:
                self.traffic_rule_violation_pre = self.traffic_rule_violation_post

            current_traffic_rule_violations = self.traffic_rule_violation_post - self.traffic_rule_violation_pre
            self.traffic_rule_violations = self.traffic_rule_violations + current_traffic_rule_violations
            self.traffic_rule_violation_pre = self.traffic_rule_violation_post
            # print("current traffic rule violations:", self.traffic_rule_violations)

            if self.traffic_rule_violations > max_vio_num:
                return True, 0, {}
        else:
            # print("WARNING: # of violations are NOT considered")
            self.traffic_rule_robustness = self.traffic_rule_eval_result

        return False, self.WeightedReward(self.traffic_rule_robustness), {}

    def Reset(self):
        # self.traffic_rule_violation_pre = 0
        self.traffic_rule_violation_post = 0
        self.traffic_rule_violations = 0
        super().Reset()

class GeneralEvaluator:
"""Evaluator using Functors"""

@@ -455,13 +498,15 @@ def __init__(self,
}

def Evaluate(self, observed_world, action):

"""Returns information about the current world state."""
eval_results = observed_world.Evaluate()
reward = 0.
scheduleTerminate = False

for _, eval_fn in self._bark_ml_eval_fns.items():
t, r, i = eval_fn(observed_world, action, eval_results)

eval_results = {**eval_results, **i} # merge info
reward += r # accumulate reward
if t: # if any of the t are True -> terminal
@@ -473,7 +518,13 @@ def Reset(self, world):
world.ClearEvaluators()
for eval_name, eval_fn in self._bark_eval_fns.items():
#TODO: check if reset evaluatorLTL is needed
world.AddEvaluator(eval_name, eval_fn())
# print("world.AddEvaluator(eval_name): ", eval_name)
# print("world.AddEvaluator(eval_fn()): ", eval_fn())
if isinstance(eval_fn, types.LambdaType):
world.AddEvaluator(eval_name, eval_fn())
else:
# print(f"Reset method - Eval name: {eval_name}, eval_fn: {eval_fn}")
world.AddEvaluator(eval_name, eval_fn)
for _, eval_func in self._bark_ml_eval_fns.items():
eval_func.Reset()
return world
28 changes: 18 additions & 10 deletions bark_ml/evaluators/stl/evaluator_stl.py
@@ -2,26 +2,34 @@
from bark_ml.evaluators.stl.label_functions.base_label_function import BaseQuantizedLabelFunction

class EvaluatorSTL(EvaluatorLTL):
def __init__(self, agent_id: int, ltl_formula_str: str, label_functions):
super().__init__(agent_id, ltl_formula_str, label_functions)
def __init__(self, agent_id: int, ltl_formula: str, label_functions, eval_return_robustness_only: bool = True):
super().__init__(agent_id, ltl_formula, label_functions)
self.robustness = float('inf')
self.label_functions_stl = label_functions
self.eval_return_robustness_only = eval_return_robustness_only

def Evaluate(self, observed_world):
eval_return = super().Evaluate(observed_world)
# print(f"Evaluate return: {eval_return}")
# print(f"Evaluate safety_violations: {super().safety_violations}")
# TODO: Should we remove the # of safety violations? We should subtract the robustness, shouldn't we?
eval_return = eval_return - self.compute_robustness()
eval_return = super().Evaluate(observed_world)

# print(f"Evaluate STL return: {eval_return}")

if self.eval_return_robustness_only:
eval_return = self.compute_robustness()
else:
eval_return = str(eval_return) + ";" + str(self.compute_robustness())
# print(f"Evaluate return updated: {eval_return}")

return eval_return

def compute_robustness(self):
self.robustness = float('inf')

for le in self.label_functions:
if isinstance(le, BaseQuantizedLabelFunction):

for le in self.label_functions:
# print(le)
if isinstance(le, BaseQuantizedLabelFunction):
self.robustness = min(self.robustness, le.get_current_robustness())

# print("------------------")
if self.robustness == float('inf') or self.robustness == float('-inf'):
self.robustness = 0.0

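
To illustrate the aggregation in compute_robustness, here is a self-contained sketch: the stub label class is hypothetical, hasattr stands in for the isinstance check against BaseQuantizedLabelFunction, and only the min-aggregation plus the infinity fallback mirror the method above:

class _StubQuantizedLabel:
    """Hypothetical stand-in for a BaseQuantizedLabelFunction subclass."""
    def __init__(self, robustness):
        self._robustness = robustness

    def get_current_robustness(self):
        return self._robustness

def compute_robustness(label_functions):
    # Rule robustness is the minimum over all quantized labels; purely
    # Boolean labels do not contribute.
    robustness = float('inf')
    for le in label_functions:
        if hasattr(le, "get_current_robustness"):
            robustness = min(robustness, le.get_current_robustness())
    # If no quantized label contributed, fall back to 0.0 as in the method above.
    if robustness in (float('inf'), float('-inf')):
        robustness = 0.0
    return robustness

print(compute_robustness([_StubQuantizedLabel(-0.4), _StubQuantizedLabel(0.1)]))  # -0.4
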
