diff --git a/lib/machinery.ex b/lib/machinery.ex index a8e024fac2..dc3701b2f6 100644 --- a/lib/machinery.ex +++ b/lib/machinery.ex @@ -91,7 +91,10 @@ defmodule Machinery do {:error, @guarded_error} true -> - struct = Map.put(struct, :state, next_state) + struct = struct + |> Transition.run_before_callbacks(next_state, module) + |> Map.put(:state, next_state) + |> Transition.run_after_callbacks(next_state, module) {:ok, struct} end end diff --git a/lib/machinery/transition.ex b/lib/machinery/transition.ex index e2ef97530b..a683d524ae 100644 --- a/lib/machinery/transition.ex +++ b/lib/machinery/transition.ex @@ -5,6 +5,19 @@ defmodule Machinery.Transition do It's meant to be for internal use only. """ + @doc """ + Function responsible for checking if the transition from a state to another + was specifically declared. + """ + @spec declared_transition?(list, atom, atom) :: boolean + def declared_transition?(transitions, current_state, next_state) do + case Map.fetch(transitions, current_state) do + {:ok, [_|_] = declared_states} -> Enum.member?(declared_states, next_state) + {:ok, declared_state} -> declared_state == next_state + :error -> false + end + end + @doc """ Default guard transition fallback to make sure all transitions are permitted unless another existing guard condition exists. @@ -17,16 +30,25 @@ defmodule Machinery.Transition do error in FunctionClauseError -> guard_transition_fallback?(error) end - @doc """ - Function responsible for checking if the transition from a state to another - was specifically declared. - """ - @spec declared_transition?(list, atom, atom) :: boolean - def declared_transition?(transitions, current_state, next_state) do - case Map.fetch(transitions, current_state) do - {:ok, [_|_] = declared_states} -> Enum.member?(declared_states, next_state) - {:ok, declared_state} -> declared_state == next_state - :error -> false + def run_before_callbacks(struct, state, module) do + module.before_transition(struct, state) + rescue + error in UndefinedFunctionError -> callbacks_fallback(struct, error) + error in FunctionClauseError -> callbacks_fallback(struct, error) + end + + def run_after_callbacks(struct, state, module) do + module.after_transition(struct, state) + rescue + error in UndefinedFunctionError -> callbacks_fallback(struct, error) + error in FunctionClauseError -> callbacks_fallback(struct, error) + end + + defp callbacks_fallback(struct, error) do + if error.function in [:after_transition, :before_transition] && error.arity == 2 do + struct + else + raise error end end @@ -40,4 +62,4 @@ defmodule Machinery.Transition do raise error end end -end \ No newline at end of file +end diff --git a/test/machinery/transition_test.exs b/test/machinery/transition_test.exs new file mode 100644 index 0000000000..53687c2876 --- /dev/null +++ b/test/machinery/transition_test.exs @@ -0,0 +1,16 @@ +defmodule MachineryTest.TransitionTest do + use ExUnit.Case + doctest Machinery.Transition + alias Machinery.Transition + + test "declared_transition?/3 based on a map of transitions, current and next state" do + transitions = %{ + created: [:partial, :completed], + partial: :completed + } + assert Transition.declared_transition?(transitions, :created, :partial) + assert Transition.declared_transition?(transitions, :created, :completed) + assert Transition.declared_transition?(transitions, :partial, :completed) + refute Transition.declared_transition?(transitions, :partial, :created) + end +end diff --git a/test/machinery_test.exs b/test/machinery_test.exs index d8fa257629..a1f7a60164 100644 --- a/test/machinery_test.exs +++ b/test/machinery_test.exs @@ -16,9 +16,8 @@ defmodule MachineryTest do } def guard_transition(struct, :completed) do - - # Code to unquote code into this AST that - # will force and exception. + # Code to simulate and force an exception inside a + # guard function. if Map.get(struct, :force_exception) do Machinery.non_existing_function_should_raise_error() end @@ -28,7 +27,7 @@ defmodule MachineryTest do end defmodule TestModule do - defstruct state: nil, missing_fields: nil + defstruct state: nil, missing_fields: nil, force_exception: false use Machinery, states: [:created, :partial, :completed], @@ -36,6 +35,20 @@ defmodule MachineryTest do created: [:partial, :completed], partial: :completed } + + def before_transition(struct, :partial) do + # Code to simulate and force an exception inside a + # guard function. + if Map.get(struct, :force_exception) do + Machinery.non_existing_function_should_raise_error() + end + + Map.put(struct, :missing_fields, true) + end + + def after_transition(struct, :completed) do + Map.put(struct, :missing_fields, false) + end end test "All internal functions should be injected into AST" do @@ -76,8 +89,8 @@ defmodule MachineryTest do end test "Modules without guard conditions should allow transitions by default" do - struct = %TestModule{state: :created, missing_fields: true} - assert {:ok, %TestModule{state: :completed, missing_fields: true}} = Machinery.transition_to(struct, :completed) + struct = %TestModule{state: :created} + assert {:ok, %TestModule{state: :completed}} = Machinery.transition_to(struct, :completed) end test "Implict rescue on the guard clause internals should raise any other excepetion not strictly related to missing guard_tranistion/2 existence" do @@ -86,4 +99,22 @@ defmodule MachineryTest do Machinery.transition_to(wrong_struct, :completed) end end + + test "after_transition/2 and before_transition/2 callbacks should be automatically executed" do + struct = %TestModule{} + assert struct.missing_fields == nil + + {:ok, partial_struct} = Machinery.transition_to(struct, :partial) + assert partial_struct.missing_fields == true + + {:ok, completed_struct} = Machinery.transition_to(struct, :completed) + assert completed_struct.missing_fields == false + end + + test "Implict rescue on the callbacks internals should raise any other excepetion not strictly related to missing callbacks_fallback/2 existence" do + wrong_struct = %TestModule{state: :created, force_exception: true} + assert_raise UndefinedFunctionError, fn() -> + Machinery.transition_to(wrong_struct, :partial) + end + end end