Skip to content

Commit

Permalink
stages/prompt: Add initial_data prompt field and ability to select a …
Browse files Browse the repository at this point in the history
…default choice for choice fields (#5095)

* Added initial_value to model

* Added initial_value to admin panel

* Added initial_value support to flows; updated tests

* Updated default blueprints

* update docs

* Fix test

* Fix another test

* Fix yet another test

* Add placeholder migration

* Remove unused import
  • Loading branch information
sdimovv committed Apr 19, 2023
1 parent 04cc781 commit ee6edec
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 114 deletions.
1 change: 1 addition & 0 deletions authentik/policies/password/tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_prompt_data(self):
"label": "PASSWORD_LABEL",
"order": 0,
"placeholder": "PASSWORD_PLACEHOLDER",
"initial_value": "",
"required": True,
"type": "password",
"sub_text": "",
Expand Down
2 changes: 2 additions & 0 deletions authentik/stages/prompt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ class Meta:
"type",
"required",
"placeholder",
"initial_value",
"order",
"promptstage_set",
"sub_text",
"placeholder_expression",
"initial_value_expression",
]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Generated by Django 4.1.7 on 2023-03-24 17:32

from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor


def migrate_placeholder_expressions(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from authentik.stages.prompt.models import CHOICE_FIELDS

db_alias = schema_editor.connection.alias
Prompt = apps.get_model("authentik_stages_prompt", "prompt")

for prompt in Prompt.objects.using(db_alias).all():
if not prompt.placeholder_expression or prompt.type in CHOICE_FIELDS:
continue

prompt.initial_value = prompt.placeholder
prompt.initial_value_expression = True
prompt.placeholder = ""
prompt.placeholder_expression = False
prompt.save()


class Migration(migrations.Migration):
dependencies = [
("authentik_stages_prompt", "0010_alter_prompt_placeholder_alter_prompt_type"),
]

operations = [
migrations.AddField(
model_name="prompt",
name="initial_value",
field=models.TextField(
blank=True,
help_text="Optionally pre-fill the input with an initial value. When creating a fixed choice field, enable interpreting as expression and return a list to return multiple default choices.",
),
),
migrations.AddField(
model_name="prompt",
name="initial_value_expression",
field=models.BooleanField(default=False),
),
migrations.AlterField(
model_name="prompt",
name="placeholder",
field=models.TextField(
blank=True,
help_text="Optionally provide a short hint that describes the expected input value. When creating a fixed choice field, enable interpreting as expression and return a list to return multiple choices.",
),
),
migrations.RunPython(code=migrate_placeholder_expressions),
]
72 changes: 59 additions & 13 deletions authentik/stages/prompt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from authentik.lib.models import SerializerModel
from authentik.policies.models import Policy

CHOICES_CONTEXT_SUFFIX = "__choices"

LOGGER = get_logger()


Expand Down Expand Up @@ -119,15 +121,25 @@ class Prompt(SerializerModel):
placeholder = models.TextField(
blank=True,
help_text=_(
"When creating a Radio Button Group or Dropdown, enable interpreting as "
"Optionally provide a short hint that describes the expected input value. "
"When creating a fixed choice field, enable interpreting as "
"expression and return a list to return multiple choices."
),
)
initial_value = models.TextField(
blank=True,
help_text=_(
"Optionally pre-fill the input with an initial value. "
"When creating a fixed choice field, enable interpreting as "
"expression and return a list to return multiple default choices."
),
)
sub_text = models.TextField(blank=True, default="")

order = models.IntegerField(default=0)

placeholder_expression = models.BooleanField(default=False)
initial_value_expression = models.BooleanField(default=False)

@property
def serializer(self) -> Type[BaseSerializer]:
Expand All @@ -148,8 +160,8 @@ def get_choices(

raw_choices = self.placeholder

if self.field_key in prompt_context:
raw_choices = prompt_context[self.field_key]
if self.field_key + CHOICES_CONTEXT_SUFFIX in prompt_context:
raw_choices = prompt_context[self.field_key + CHOICES_CONTEXT_SUFFIX]
elif self.placeholder_expression:
evaluator = PropertyMappingEvaluator(
self, user, request, prompt_context=prompt_context, dry_run=dry_run
Expand Down Expand Up @@ -184,16 +196,9 @@ def get_placeholder(
) -> str:
"""Get fully interpolated placeholder"""
if self.type in CHOICE_FIELDS:
# Make sure to return a valid choice as placeholder
choices = self.get_choices(prompt_context, user, request, dry_run=dry_run)
if not choices:
return ""
return choices[0]

if self.field_key in prompt_context:
# We don't want to parse this as an expression since a user will
# be able to control the input
return prompt_context[self.field_key]
# Choice fields use the placeholder to define all valid choices.
# Therefore their actual placeholder is always blank
return ""

if self.placeholder_expression:
evaluator = PropertyMappingEvaluator(
Expand All @@ -211,6 +216,47 @@ def get_placeholder(
raise wrapped from exc
return self.placeholder

def get_initial_value(
self,
prompt_context: dict,
user: User,
request: HttpRequest,
dry_run: Optional[bool] = False,
) -> str:
"""Get fully interpolated initial value"""

if self.field_key in prompt_context:
# We don't want to parse this as an expression since a user will
# be able to control the input
value = prompt_context[self.field_key]
elif self.initial_value_expression:
evaluator = PropertyMappingEvaluator(
self, user, request, prompt_context=prompt_context, dry_run=dry_run
)
try:
value = evaluator.evaluate(self.initial_value)
except Exception as exc: # pylint:disable=broad-except
wrapped = PropertyMappingExpressionException(str(exc))
LOGGER.warning(
"failed to evaluate prompt initial value",
exc=wrapped,
)
if dry_run:
raise wrapped from exc
value = self.initial_value
else:
value = self.initial_value

if self.type in CHOICE_FIELDS:
# Ensure returned value is a valid choice
choices = self.get_choices(prompt_context, user, request)
if not choices:
return ""
if value not in choices:
return choices[0]

return value

def field(self, default: Optional[Any], choices: Optional[list[Any]] = None) -> CharField:
"""Get field type for Challenge and response. Choices are only valid for CHOICE_FIELDS."""
field_class = CharField
Expand Down
11 changes: 8 additions & 3 deletions authentik/stages/prompt/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class StagePromptSerializer(PassiveSerializer):
type = ChoiceField(choices=FieldTypes.choices)
required = BooleanField()
placeholder = CharField(allow_blank=True)
initial_value = CharField(allow_blank=True)
order = IntegerField()
sub_text = CharField(allow_blank=True)
choices = ListField(child=CharField(allow_blank=True), allow_empty=True, allow_null=True)
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(self, *args, **kwargs):
choices = field.get_choices(
plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
)
current = field.get_placeholder(
current = field.get_initial_value(
plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
)
self.fields[field.field_key] = field.field(current, choices)
Expand Down Expand Up @@ -197,8 +198,9 @@ def get_prompt_challenge_fields(self, fields: list[Prompt], context: dict, dry_r
serializers = []
for field in fields:
data = StagePromptSerializer(field).data
# Ensure all choices and placeholders are str, as otherwise further in
# we can fail serializer validation if we return some types such as bool
# Ensure all choices, placeholders and initial values are str, as
# otherwise further in we can fail serializer validation if we return
# some types such as bool
choices = field.get_choices(context, self.get_pending_user(), self.request, dry_run)
if choices:
data["choices"] = [str(choice) for choice in choices]
Expand All @@ -207,6 +209,9 @@ def get_prompt_challenge_fields(self, fields: list[Prompt], context: dict, dry_r
data["placeholder"] = str(
field.get_placeholder(context, self.get_pending_user(), self.request, dry_run)
)
data["initial_value"] = str(
field.get_initial_value(context, self.get_pending_user(), self.request, dry_run)
)
serializers.append(data)
return serializers

Expand Down

0 comments on commit ee6edec

Please sign in to comment.