Skip to content

Commit

Permalink
Add support for parameters in scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jimtng committed Jun 9, 2022
1 parent ec1fae6 commit ad144f2
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 98 deletions.
95 changes: 85 additions & 10 deletions esphome/components/script/__init__.py
Expand Up @@ -2,7 +2,8 @@
import esphome.config_validation as cv
from esphome import automation
from esphome.automation import maybe_simple_id
from esphome.const import CONF_ID, CONF_MODE
from esphome.const import CONF_ID, CONF_MODE, CONF_PARAMETERS
from esphome.core import CORE

CODEOWNERS = ["@esphome/core"]
script_ns = cg.esphome_ns.namespace("script")
Expand All @@ -16,6 +17,7 @@
QueueingScript = script_ns.class_("QueueingScript", Script, cg.Component)
ParallelScript = script_ns.class_("ParallelScript", Script)

CONF_SCRIPT = "script"
CONF_SINGLE = "single"
CONF_RESTART = "restart"
CONF_QUEUED = "queued"
Expand All @@ -29,6 +31,26 @@
CONF_PARALLEL: ParallelScript,
}

# TODO add support for entity ids as params
SCRIPT_PARAMS_NATIVE_TYPES = {
"bool": bool,
"int": cg.int32,
"float": float,
"string": cg.std_string,
"bool[]": cg.std_vector.template(bool),
"int[]": cg.std_vector.template(cg.int32),
"float[]": cg.std_vector.template(float),
"string[]": cg.std_vector.template(cg.std_string),
}


def get_script(script_id):
scripts = CORE.config.get(CONF_SCRIPT, {})
for script in scripts:
if script.get(CONF_ID, None) == script_id:
return script
return {}


def check_max_runs(value):
if CONF_MAX_RUNS not in value:
Expand All @@ -47,6 +69,19 @@ def assign_declare_id(value):
return value


def parameters_to_template(args):
template_args = []
func_args = []
script_arg_names = []
for name, type_ in args.items():
native_type = SCRIPT_PARAMS_NATIVE_TYPES[type_]
template_args.append(native_type)
func_args.append((native_type, name))
script_arg_names.append(name)
template = cg.TemplateArguments(*template_args)
return template, func_args


CONFIG_SCHEMA = automation.validate_automation(
{
# Don't declare id as cv.declare_id yet, because the ID type
Expand All @@ -56,6 +91,11 @@ def assign_declare_id(value):
*SCRIPT_MODES, lower=True
),
cv.Optional(CONF_MAX_RUNS): cv.positive_int,
cv.Optional(CONF_PARAMETERS, default={}): cv.Schema(
{
cv.validate_id_name: cv.one_of(*SCRIPT_PARAMS_NATIVE_TYPES, lower=True),
}
),
},
extra_validators=cv.All(check_max_runs, assign_declare_id),
)
Expand All @@ -65,7 +105,8 @@ async def to_code(config):
# Register all variables first, so that scripts can use other scripts
triggers = []
for conf in config:
trigger = cg.new_Pvariable(conf[CONF_ID])
template, func_args = parameters_to_template(conf[CONF_PARAMETERS])
trigger = cg.new_Pvariable(conf[CONF_ID], template)
# Add a human-readable name to the script
cg.add(trigger.set_name(conf[CONF_ID].id))

Expand All @@ -75,10 +116,10 @@ async def to_code(config):
if conf[CONF_MODE] == CONF_QUEUED:
await cg.register_component(trigger, conf)

triggers.append((trigger, conf))
triggers.append((trigger, func_args, conf))

for trigger, conf in triggers:
await automation.build_automation(trigger, [], conf)
for trigger, func_args, conf in triggers:
await automation.build_automation(trigger, func_args, conf)


@automation.register_action(
Expand All @@ -87,12 +128,43 @@ async def to_code(config):
maybe_simple_id(
{
cv.Required(CONF_ID): cv.use_id(Script),
}
cv.Optional(cv.validate_id_name): cv.templatable(cv.valid),
},
),
)
async def script_execute_action_to_code(config, action_id, template_arg, args):
from esphome import core

async def get_ordered_args(config, script_params):
config_args = config.copy()
config_args.pop(CONF_ID)

# match script_args to the formal parameter order
script_args = []
for _, name in script_params:
if name not in config_args:
raise cv.Invalid(
f"Missing parameter: '{name}' in script.execute {config[CONF_ID]}"
)
arg = await cg.templatable(config_args[name], args, None)
script_args.append(arg)
return script_args

script = get_script(config[CONF_ID])
params = script.get(CONF_PARAMETERS, [])
template, script_params = parameters_to_template(params)
script_args = await get_ordered_args(config, script_params)

# We need to use the parent class 'Script' as the template argument
# to match the partial specialization of the ScriptExecuteAction template
parent_class = core.ID(None, type=Script)
parent_class.type = parent_class.type.template(template)
template_arg = cg.TemplateArguments(parent_class.type, *template_arg)

paren = await cg.get_variable(config[CONF_ID])
return cg.new_Pvariable(action_id, template_arg, paren)
var = cg.new_Pvariable(action_id, template_arg, paren)
cg.add(var.set_args(*script_args))
return var


@automation.register_action(
Expand All @@ -101,7 +173,8 @@ async def script_execute_action_to_code(config, action_id, template_arg, args):
maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}),
)
async def script_stop_action_to_code(config, action_id, template_arg, args):
paren = await cg.get_variable(config[CONF_ID])
full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID])
template_arg = cg.TemplateArguments(full_id.type, *template_arg)
return cg.new_Pvariable(action_id, template_arg, paren)


Expand All @@ -111,7 +184,8 @@ async def script_stop_action_to_code(config, action_id, template_arg, args):
maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}),
)
async def script_wait_action_to_code(config, action_id, template_arg, args):
paren = await cg.get_variable(config[CONF_ID])
full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID])
template_arg = cg.TemplateArguments(full_id.type, *template_arg)
var = cg.new_Pvariable(action_id, template_arg, paren)
await cg.register_component(var, {})
return var
Expand All @@ -123,5 +197,6 @@ async def script_wait_action_to_code(config, action_id, template_arg, args):
automation.maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}),
)
async def script_is_running_to_code(config, condition_id, template_arg, args):
paren = await cg.get_variable(config[CONF_ID])
full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID])
template_arg = cg.TemplateArguments(full_id.type, *template_arg)
return cg.new_Pvariable(condition_id, template_arg, paren)
64 changes: 1 addition & 63 deletions esphome/components/script/script.cpp
@@ -1,67 +1,5 @@
#include "script.h"
#include "esphome/core/log.h"

namespace esphome {
namespace script {

static const char *const TAG = "script";

void SingleScript::execute() {
if (this->is_action_running()) {
ESP_LOGW(TAG, "Script '%s' is already running! (mode: single)", this->name_.c_str());
return;
}

this->trigger();
}

void RestartScript::execute() {
if (this->is_action_running()) {
ESP_LOGD(TAG, "Script '%s' restarting (mode: restart)", this->name_.c_str());
this->stop_action();
}

this->trigger();
}

void QueueingScript::execute() {
if (this->is_action_running()) {
// num_runs_ is the number of *queued* instances, so total number of instances is
// num_runs_ + 1
if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) {
ESP_LOGW(TAG, "Script '%s' maximum number of queued runs exceeded!", this->name_.c_str());
return;
}

ESP_LOGD(TAG, "Script '%s' queueing new instance (mode: queued)", this->name_.c_str());
this->num_runs_++;
return;
}

this->trigger();
// Check if the trigger was immediate and we can continue right away.
this->loop();
}

void QueueingScript::stop() {
this->num_runs_ = 0;
Script::stop();
}

void QueueingScript::loop() {
if (this->num_runs_ != 0 && !this->is_action_running()) {
this->num_runs_--;
this->trigger();
}
}

void ParallelScript::execute() {
if (this->max_runs_ != 0 && this->automation_parent_->num_running() >= this->max_runs_) {
ESP_LOGW(TAG, "Script '%s' maximum number of parallel runs exceeded!", this->name_.c_str());
return;
}
this->trigger();
}

} // namespace script
namespace script {} // namespace script
} // namespace esphome

0 comments on commit ad144f2

Please sign in to comment.