Skip to content

Commit

Permalink
feat: supports Safetensors.dump/1
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Aug 3, 2023
1 parent 62d8745 commit 3107378
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
77 changes: 60 additions & 17 deletions lib/safetensors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,59 @@ defmodule Safetensors do

@header_metadata_key "__metadata__"

@dtype_mapping %{
"BF16" => :bf16,
"F64" => :f64,
"F32" => :f32,
"F16" => :f16,
"I64" => :s64,
"I32" => :s32,
"I16" => :s16,
"I8" => :s8,
"U64" => :u64,
"U32" => :u32,
"U16" => :u16,
"U8" => :u8
# "BOOL" => :u8
@type_to_dtype %{
{:bf, 16} => "BF16",
{:f, 64} => "F64",
{:f, 32} => "F32",
{:f, 16} => "F16",
{:s, 64} => "I64",
{:s, 32} => "I32",
{:s, 16} => "I16",
{:s, 8} => "I8",
{:u, 64} => "U64",
{:u, 32} => "U32",
{:u, 16} => "U16",
{:u, 8} => "U8"
}

@dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k}

def dump(tensors) when is_map(tensors) do
{header, buffer} =
tensors
|> Enum.map_reduce(
<<>>,
fn {tensor_name, tensor}, buffer ->
binary = Nx.to_binary(tensor)

offset = byte_size(buffer)

{
{
tensor_name,
Jason.OrderedObject.new(
dtype: tensor |> Nx.type() |> type_to_dtype(),
shape: tensor |> Nx.shape() |> Tuple.to_list(),
data_offsets: [offset, offset + byte_size(binary)]
)
},
buffer <> binary
}
end
)

header_json =
header
|> Jason.OrderedObject.new()
|> Jason.encode!()

<<
String.length(header_json)::unsigned-64-integer-little,
header_json::binary,
buffer::binary
>>
end

def load!(data) when is_binary(data) do
<<
header_size::unsigned-64-integer-little,
Expand All @@ -43,15 +80,21 @@ defmodule Safetensors do
"shape" => shape
} = tensor_info

type = @dtype_mapping[dtype] || raise "unrecognized dtype #{dtype}"

{
tensor_name,
buffer
|> binary_part(offset_start, offset_end - offset_start)
|> Nx.from_binary(type)
|> Nx.from_binary(dtype |> dtype_to_type())
|> Nx.reshape(List.to_tuple(shape))
}
end)
end

defp type_to_dtype(type) do
@type_to_dtype[type] || raise "unrecognized type #{inspect(type)}"
end

defp dtype_to_type(dtype) do
@dtype_to_type[dtype] || raise "unrecognized dtype #{inspect(dtype)}"
end
end
12 changes: 12 additions & 0 deletions test/safetensors_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@ defmodule SafetensorsTest do
use ExUnit.Case
doctest Safetensors

test "dump" do
binary =
%{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}
|> Safetensors.dump()

# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L22-L25
# with the header padding removed and changed numbers
assert binary ==
~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00)
end

test "load" do
# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40
Expand Down

0 comments on commit 3107378

Please sign in to comment.