Permalink
Browse files

Add __protocol_for__ that returns the protocol that matches the given…

… structure, related to #250
  • Loading branch information...
1 parent b901ed9 commit 7c26b215c7530838400d673b973bb81abb5d598b @josevalim josevalim committed Apr 24, 2012
Showing with 49 additions and 19 deletions.
  1. +48 −18 lib/protocol.ex
  2. +1 −1 test/elixir/protocol_test.exs
View
66 lib/protocol.ex
@@ -14,16 +14,18 @@ defmodule Protocol do
# according to the only/except rules. If no protocol
# matches, returns nil;
#
+ # * `__protocol_for__!/1` - same as above but raises an error if protocol is not found
+ #
def defprotocol(name, args, opts) do
funs = to_funs(args)
quote do
defmodule unquote(name) do
def __protocol__(:name), do: unquote(name)
def __protocol__(:functions), do: unquote(funs)
- conversions = Protocol.conversions_for(unquote(opts))
- Protocol.functions(__MODULE__, conversions, unquote(funs))
- Protocol.protocol_for(__MODULE__, conversions)
+ { conversions, fallback } = Protocol.conversions_for(unquote(opts))
+ Protocol.functions(__MODULE__, unquote(funs), fallback)
+ Protocol.protocol_for(__MODULE__, conversions, fallback)
end
end
end
@@ -79,8 +81,7 @@ defmodule Protocol do
# It simply detects the protocol using __protocol_for__ and
# then dispatches to it.
@doc false
- def functions(module, conversions, funs) do
- fallback = if L.keyfind(Tuple, 1, conversions), do: Tuple, else: Any
+ def functions(module, funs, fallback) do
contents = lc fun in L.reverse(funs), do: each_function(fun, fallback)
Module.eval_quoted module, contents, [], file: __FILE__, line: __LINE__
end
@@ -89,19 +90,44 @@ defmodule Protocol do
# the module to dispatch to. Returns module.Record for records
# which should be properly handled by the dispatching function.
@doc false
- def protocol_for(module, conversions) do
+ def protocol_for(module, conversions, fallback) do
contents = lc kind in conversions, do: each_protocol_for(kind, conversions)
# If we don't implement all protocols and any is not in the
# list, we need to add a final clause that returns nil.
if !L.member({ Any, :is_any }, conversions) && length(conversions) < 10 do
contents = contents ++ [quote do
- def __protocol_for__(_) do
+ defp __raw_protocol__(_) do
nil
end
end]
end
+ # Finally add __protocol_for__ and __protocol_for__!
+ contents = contents ++ [quote do
+ def __protocol_for__(arg) do
+ case __raw_protocol__(arg) do
+ match: __MODULE__.Record
+ target = Module.concat(__MODULE__, :erlang.element(1, arg))
+ if :erlang.function_exported(target, :__protocol__, 1) do
+ target
+ else:
+ Module.concat(__MODULE__, unquote(fallback))
+ end
+ match: other
+ other
+ end
+ end
+
+ def __protocol_for__!(arg) do
+ if module = __protocol_for__(arg) do
+ module
+ else:
+ raise Protocol.UndefinedError, protocol: __MODULE__, structure: arg
+ end
+ end
+ end]
+
Module.eval_quoted module, contents, [], file: __FILE__, line: __LINE__
end
@@ -111,12 +137,16 @@ defmodule Protocol do
def conversions_for(opts) do
kinds = all_types
- if only = Keyword.get(opts, :only, false) do
- L.map(fn(i) -> L.keyfind(i, 1, kinds) end, only)
- else:
- except = Keyword.get(opts, :except, [Any])
- L.foldl(fn(i, list) -> L.keydelete(i, 1, list) end, kinds, except)
- end
+ conversions =
+ if only = Keyword.get(opts, :only, false) do
+ L.map(fn(i) -> L.keyfind(i, 1, kinds) end, only)
+ else:
+ except = Keyword.get(opts, :except, [Any])
+ L.foldl(fn(i, list) -> L.keydelete(i, 1, list) end, kinds, except)
+ end
+
+ fallback = if L.keyfind(Tuple, 1, conversions), do: Tuple, else: Any
+ { conversions, fallback }
end
## Helpers
@@ -138,7 +168,7 @@ defmodule Protocol do
end
# Returns a quoted expression that allow to checks
- # if a variable named first is built or not.
+ # if a variable named first is built in or not.
defp is_builtin?([{h,_}]) do
quote do
first == unquote(h)
@@ -155,7 +185,7 @@ defmodule Protocol do
# If this is the case, module.Record will be returned.
defp each_protocol_for({ _, :is_record }, conversions) do
quote do
- def __protocol_for__(arg) when is_tuple(arg) and is_atom(:erlang.element(1, arg)) do
+ defp __raw_protocol__(arg) when is_tuple(arg) and is_atom(:erlang.element(1, arg)) do
first = :erlang.element(1, arg)
case unquote(is_builtin?(conversions)) do
match: true
@@ -175,7 +205,7 @@ defmodule Protocol do
# Special case any as we don't need to generate a guard.
defp each_protocol_for({ _, :is_any }, _) do
quote do
- def __protocol_for__(_) do
+ defp __raw_protocol__(_) do
__MODULE__.Any
end
end
@@ -184,7 +214,7 @@ defmodule Protocol do
# Generate all others protocols.
defp each_protocol_for({ kind, fun }, _) do
quote do
- def __protocol_for__(arg) when unquote(fun).(arg) do
+ defp __raw_protocol__(arg) when unquote(fun).(arg) do
Module.concat __MODULE__, unquote(kind)
end
end
@@ -203,7 +233,7 @@ defmodule Protocol do
quote do
def unquote(name).(unquote_splicing(args)) do
args = [unquote_splicing(args)]
- case __protocol_for__(xA) do
+ case __raw_protocol__(xA) do
match: __MODULE__.Record
try do
target = Module.concat(__MODULE__, :erlang.element(1, xA))
View
2 test/elixir/protocol_test.exs
@@ -80,7 +80,7 @@ defmodule ProtocolTest do
assert_protocol_for(ProtocolTest.WithAll, List, [1,2,3])
assert_protocol_for(ProtocolTest.WithAll, Tuple, {})
assert_protocol_for(ProtocolTest.WithAll, Tuple, {1,2,3})
- assert_protocol_for(ProtocolTest.WithAll, Record, {Bar,2,3})
+ assert_protocol_for(ProtocolTest.WithAll, Tuple, {Bar,2,3})
assert_protocol_for(ProtocolTest.WithAll, BitString, "foo")
assert_protocol_for(ProtocolTest.WithAll, BitString, <<1>>)
assert_protocol_for(ProtocolTest.WithAll, PID, Process.self)

0 comments on commit 7c26b21

Please sign in to comment.