diff --git a/lib/safetensors.ex b/lib/safetensors.ex index 719d6e5..5ca6356 100644 --- a/lib/safetensors.ex +++ b/lib/safetensors.ex @@ -7,22 +7,59 @@ defmodule Safetensors do @header_metadata_key "__metadata__" - @dtype_mapping %{ - "BF16" => :bf16, - "F64" => :f64, - "F32" => :f32, - "F16" => :f16, - "I64" => :s64, - "I32" => :s32, - "I16" => :s16, - "I8" => :s8, - "U64" => :u64, - "U32" => :u32, - "U16" => :u16, - "U8" => :u8 - # "BOOL" => :u8 + @type_to_dtype %{ + {:bf, 16} => "BF16", + {:f, 64} => "F64", + {:f, 32} => "F32", + {:f, 16} => "F16", + {:s, 64} => "I64", + {:s, 32} => "I32", + {:s, 16} => "I16", + {:s, 8} => "I8", + {:u, 64} => "U64", + {:u, 32} => "U32", + {:u, 16} => "U16", + {:u, 8} => "U8" } + @dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k} + + def dump(tensors) when is_map(tensors) do + {header, buffer} = + tensors + |> Enum.map_reduce( + <<>>, + fn {tensor_name, tensor}, buffer -> + binary = Nx.to_binary(tensor) + + offset = byte_size(buffer) + + { + { + tensor_name, + Jason.OrderedObject.new( + dtype: tensor |> Nx.type() |> type_to_dtype(), + shape: tensor |> Nx.shape() |> Tuple.to_list(), + data_offsets: [offset, offset + byte_size(binary)] + ) + }, + buffer <> binary + } + end + ) + + header_json = + header + |> Jason.OrderedObject.new() + |> Jason.encode!() + + << + String.length(header_json)::unsigned-64-integer-little, + header_json::binary, + buffer::binary + >> + end + def load!(data) when is_binary(data) do << header_size::unsigned-64-integer-little, @@ -43,15 +80,21 @@ defmodule Safetensors do "shape" => shape } = tensor_info - type = @dtype_mapping[dtype] || raise "unrecognized dtype #{dtype}" - { tensor_name, buffer |> binary_part(offset_start, offset_end - offset_start) - |> Nx.from_binary(type) + |> Nx.from_binary(dtype |> dtype_to_type()) |> Nx.reshape(List.to_tuple(shape)) } end) end + + defp type_to_dtype(type) do + @type_to_dtype[type] || raise "unrecognized type #{inspect(type)}" + end + + defp dtype_to_type(dtype) do + @dtype_to_type[dtype] || raise "unrecognized dtype #{inspect(dtype)}" + end end diff --git a/test/safetensors_test.exs b/test/safetensors_test.exs index 80d929d..c9ca416 100644 --- a/test/safetensors_test.exs +++ b/test/safetensors_test.exs @@ -2,6 +2,18 @@ defmodule SafetensorsTest do use ExUnit.Case doctest Safetensors + test "dump" do + binary = + %{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)} + |> Safetensors.dump() + + # source: + # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L22-L25 + # with the header padding removed and changed numbers + assert binary == + ~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00) + end + test "load" do # source: # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40