Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor lib/exla/nif.ex #19

Merged
merged 1 commit into from
Nov 14, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 86 additions & 76 deletions lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,178 +8,188 @@ defmodule Exla.NIF do
:erlang.load_nif(path, 0)
end

defmacrop nif_error() do
quote do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about having it call nif_error(ENV.function) which then builds the error message? That’s so we don’t build the same string over and over again.

{name, arity} = __ENV__.function
raise "failed to load implementation of #{inspect(__MODULE__)}.#{name}/#{arity}"
end
end

def new_builder(_name),
do: raise("Failed to load implementation of #{__MODULE__}.new_builder/1")
do: nif_error()

def binary_to_shaped_buffer(_client, _binary, _shape, _device_ordinal),
do: raise("Failed to load implementation of #{__MODULE__}.binary_to_shaped_buffer/4.")
do: nif_error()

def on_host_shape(_buffer),
do: raise("Failed to load implementation of #{__MODULE__}.on_host_shape/1")
do: nif_error()

def make_shape(_type, _dims),
do: raise("Failed to load implementation of #{__MODULE__}.make_shape/2.")
do: nif_error()

def make_scalar_shape(_type),
do: raise("Failed to load implementation of #{__MODULE__}.make_scalar_shape/1.")
do: nif_error()

def human_string(_shape),
do: raise("Failed to load implementation of #{__MODULE__}.human_string/1.")
do: nif_error()

def create_r0(_value), do: raise("Failed to load implementation of #{__MODULE__}.create_r0/1.")
def create_r0(_value),
do: nif_error()

def literal_to_string(_literal),
do: raise("Failed to load implementation of #{__MODULE__}.literal_to_string/1.")
do: nif_error()

def parameter(_builder, _number, _shape, _name),
do: raise("Failed to load implementation of #{__MODULE__}.parameter/3.")
do: nif_error()

def add(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.add/3.")
do: nif_error()

def sub(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.sub/3.")
do: nif_error()

def mul(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.mul/3.")
do: nif_error()

def div(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.div/3.")
do: nif_error()

def rem(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.rem/3.")
do: nif_error()

def min(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.min/3.")
do: nif_error()

def max(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.max/3.")
do: nif_error()

def logical_and(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.and/3.")
do: nif_error()

def logical_or(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.or/3.")
do: nif_error()

def logical_xor(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.xor/3.")
do: nif_error()

def shift_left(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.shift_left/3.")
do: nif_error()

def shift_right_arithmetic(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.shift_right_arithmetic/3.")
do: nif_error()

def shift_right_logical(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.shift_right_logical/3.")
do: nif_error()

def eq(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.eq/3.")
do: nif_error()

def eq_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.eq_total_order/3.")
do: nif_error()

def ne(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.ne/3.")
do: nif_error()

def ne_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.ne_total_order/3.")
do: nif_error()

def ge(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.ge/3.")
do: nif_error()

def ge_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.ge_total_order/3.")
do: nif_error()

def gt(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.gt/3.")
do: nif_error()

def gt_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.gt_total_order/3.")
do: nif_error()

def le(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.le/3.")
do: nif_error()

def le_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.le_total_order/3.")
do: nif_error()

def lt(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.lt/3.")
do: nif_error()

def lt_total_order(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.lt/3.")
do: nif_error()

def pow(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.pow/3.")
do: nif_error()

def complex(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.complex/3.")
do: nif_error()

def atan2(_a, _b, _broadcast_dims),
do: raise("Failed to load implementation of #{__MODULE__}.atan2/3.")

def abs(_a), do: raise("Failed to load implementation of #{__MODULE__}.abs/1.")
def exp(_a), do: raise("Failed to load implementation of #{__MODULE__}.exp/1.")
def expm1(_a), do: raise("Failed to load implementation of #{__MODULE__}.expm1/1.")
def floor(_a), do: raise("Failed to load implementation of #{__MODULE__}.floor/1.")
def ceil(_a), do: raise("Failed to load implementation of #{__MODULE__}.ceil/1.")
def round(_a), do: raise("Failed to load implementation of #{__MODULE__}.round/1.")
def log(_a), do: raise("Failed to load implementation of #{__MODULE__}.log/1.")
def log1p(_a), do: raise("Failed to load implementation of #{__MODULE__}.log1p/1.")
def logistic(_a), do: raise("Failed to load implementation of #{__MODULE__}.logistic/1.")
def sign(_a), do: raise("Failed to load implementation of #{__MODULE__}.sign/1.")
def clz(_a), do: raise("Failed to load implementation of #{__MODULE__}.clz/1.")
def cos(_a), do: raise("Failed to load implementation of #{__MODULE__}.cos/1.")
def sin(_a), do: raise("Failed to load implementation of #{__MODULE__}.sin/1.")
def tanh(_a), do: raise("Failed to load implementation of #{__MODULE__}.tanh/1.")
def real(_a), do: raise("Failed to load implementation of #{__MODULE__}.real/1.")
def imag(_a), do: raise("Failed to load implementation of #{__MODULE__}.imag/1.")
def sqrt(_a), do: raise("Failed to load implementation of #{__MODULE__}.sqrt/1.")
def rsqrt(_a), do: raise("Failed to load implementation of #{__MODULE__}.rsqrt/1.")
def cbrt(_a), do: raise("Failed to load implementation of #{__MODULE__}.cbrt/1.")
def is_finite(_a), do: raise("Failed to load implementation of #{__MODULE__}.is_finite/1.")
def logical_not(_a), do: raise("Failed to load implementation of #{__MODULE__}.not/1.")
def neg(_a), do: raise("Failed to load implementation of #{__MODULE__}.neg/1.")
def conj(_a), do: raise("Failed to load implementation of #{__MODULE__}.conj/1.")
def copy(_a), do: raise("Failed to load implementation of #{__MODULE__}.copy/1.")
do: nif_error()

def abs(_a), do: nif_error()
def exp(_a), do: nif_error()
def expm1(_a), do: nif_error()
def floor(_a), do: nif_error()
def ceil(_a), do: nif_error()
def round(_a), do: nif_error()
def log(_a), do: nif_error()
def log1p(_a), do: nif_error()
def logistic(_a), do: nif_error()
def sign(_a), do: nif_error()
def clz(_a), do: nif_error()
def cos(_a), do: nif_error()
def sin(_a), do: nif_error()
def tanh(_a), do: nif_error()
def real(_a), do: nif_error()
def imag(_a), do: nif_error()
def sqrt(_a), do: nif_error()
def rsqrt(_a), do: nif_error()
def cbrt(_a), do: nif_error()
def is_finite(_a), do: nif_error()
def logical_not(_a), do: nif_error()
def neg(_a), do: nif_error()
def conj(_a), do: nif_error()
def copy(_a), do: nif_error()

def population_count(_a),
do: raise("Failed to load implementation of #{__MODULE__}.population_count/1.")
do: nif_error()

def dot(_a, _b), do: raise("Failed to load implementation of #{__MODULE__}.dot/2.")
def dot(_a, _b),
do: nif_error()

def constant_r0(_builder, _value),
do: raise("Failed to load implementation of #{__MODULE__}.constant_r0/2.")
do: nif_error()

def constant_r1(_length, _value),
do: raise("Failed to load implementation of #{__MODULE__}.constant_r1/2.")
do: nif_error()

def get_or_create_local_client(_platform, _number_of_replicas, _intra_op_parallelism_threads),
do: raise("Failed to load implementation of #{__MODULE__}.get_or_create_local_client/3.")
do: nif_error()

def get_default_device_ordinal(_client),
do: raise("Failed to load implementation of #{__MODULE__}.get_default_device_ordinal/1.")
do: nif_error()

def get_device_count(_client),
do: raise("Failed to load implementation of #{__MODULE__}.get_device_count/1.")
do: nif_error()

def get_computation_hlo_proto(_computation),
do: raise("Failed to load implementation of #{__MODULE__}.get_computation_hlo_proto/0.")
do: nif_error()

def get_computation_hlo_text(_computation),
do: raise("Failed to load implementation of #{__MODULE__}.get_computation_hlo_text/0.")
do: nif_error()

def build(_builder, _root), do: raise("Failed to load implementation of #{__MODULE__}.build/2.")
def build(_builder, _root),
do: nif_error()

def compile(_client, _computation, _argument_layouts, _options),
do: raise("Failed to load implementation of #{__MODULE__}.compile/4.")
do: nif_error()

def run(_executable, _arguments, _run_options),
do: raise("Failed to load implementation of #{__MODULE__}.run/3.")
do: nif_error()

def literal_to_shaped_buffer(_client, _literal, _device_ordinal, _allocator),
do: raise("Failed to load implementation of #{__MODULE__}.literal_to_shaped_buffer/4.")
do: nif_error()

def shaped_buffer_to_literal(_client, _shaped_buffer),
do: raise("Failed to load implementation of #{__MODULE__}.shaped_buffer_to_literal/2.")
do: nif_error()
end