-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4ce3484
commit cdfaaf0
Showing
2 changed files
with
487 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
// Copyright 2023 The NLP Odyssey Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package header | ||
|
||
import ( | ||
"encoding/binary" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"math" | ||
"strconv" | ||
|
||
"github.com/nlpodyssey/safetensors/dtype" | ||
) | ||
|
||
type rawDecodedHeader map[string]map[string]any | ||
|
||
const metadataKey = "__metadata__" | ||
|
||
// Read reads and parses from "r" the initial part of a safetensors | ||
// data stream. | ||
func Read(r io.Reader) (Header, error) { | ||
size, err := readHeaderSize(r) | ||
switch { | ||
case err != nil: | ||
return Header{}, err | ||
case size < 2: | ||
return Header{}, fmt.Errorf("header size too small: %d", size) | ||
case size > math.MaxInt-8: | ||
return Header{}, fmt.Errorf("header size too large: %d", size) | ||
} | ||
|
||
raw, err := readAndDecodeJSON(r, int64(size)) | ||
if err != nil { | ||
return Header{}, fmt.Errorf("failed to JSON-decode header: %w", err) | ||
} | ||
|
||
h, err := convertRawHeader(raw) | ||
if err != nil { | ||
return Header{}, err | ||
} | ||
|
||
h.ByteBufferOffset = 8 + int(size) | ||
return h, nil | ||
} | ||
|
||
func readHeaderSize(r io.Reader) (uint64, error) { | ||
var arr [8]byte | ||
b := arr[:] | ||
if _, err := io.ReadFull(r, b); err != nil { | ||
return 0, fmt.Errorf("failed to read header size: %w", err) | ||
} | ||
return binary.LittleEndian.Uint64(b), nil | ||
} | ||
|
||
func readAndDecodeJSON(r io.Reader, size int64) (rawDecodedHeader, error) { | ||
dec := json.NewDecoder(&io.LimitedReader{R: r, N: size}) | ||
dec.UseNumber() | ||
|
||
var raw rawDecodedHeader | ||
if err := dec.Decode(&raw); err != nil { | ||
return nil, err | ||
} | ||
// take care of possible padding spaces after JSON object | ||
if off := dec.InputOffset(); off != size { | ||
if _, err := dec.Token(); err == nil { | ||
return nil, fmt.Errorf("unexpected data at byte offset %d", off) | ||
} else if err != io.EOF { | ||
return nil, err | ||
} | ||
} | ||
return raw, nil | ||
} | ||
|
||
func convertRawHeader(raw rawDecodedHeader) (h Header, err error) { | ||
if rawMeta, ok := raw[metadataKey]; ok { | ||
delete(raw, metadataKey) | ||
if h.Metadata, err = convertRawMetadata(rawMeta); err != nil { | ||
return | ||
} | ||
} | ||
h.Tensors, err = convertRawTensors(raw) | ||
return | ||
} | ||
|
||
func convertRawMetadata(raw map[string]any) (Metadata, error) { | ||
if len(raw) == 0 { | ||
return nil, nil | ||
} | ||
metadata := make(Metadata, len(raw)) | ||
for key, rawVal := range raw { | ||
var ok bool | ||
if metadata[key], ok = rawVal.(string); !ok { | ||
return nil, fmt.Errorf("failed to interpret header metadata: found non-string value for key %q", key) | ||
} | ||
} | ||
return metadata, nil | ||
} | ||
|
||
func convertRawTensors(raw rawDecodedHeader) (Tensors, error) { | ||
if len(raw) == 0 { | ||
return nil, nil | ||
} | ||
tensors := make(Tensors, len(raw)) | ||
for key, rawVal := range raw { | ||
var err error | ||
if tensors[key], err = convertRawTensor(rawVal); err != nil { | ||
return nil, fmt.Errorf("failed to interpret header tensor %q: %w", key, err) | ||
} | ||
} | ||
return tensors, nil | ||
} | ||
|
||
func convertRawTensor(raw map[string]any) (t Tensor, err error) { | ||
if t.DType, err = convertRawTensorDType(raw); err != nil { | ||
return | ||
} | ||
if t.Shape, err = convertRawTensorShape(raw); err != nil { | ||
return | ||
} | ||
if t.Offsets, err = convertRawTensorOffsets(raw); err != nil { | ||
return | ||
} | ||
if len(raw) != 3 { | ||
err = errors.New("JSON object contains unknown keys") | ||
} | ||
return | ||
} | ||
|
||
func convertRawTensorDType(raw map[string]any) (dtype.DType, error) { | ||
rawDType, ok := raw["dtype"] | ||
if !ok { | ||
return 0, errors.New(`"dtype" is missing`) | ||
} | ||
strDType, ok := rawDType.(string) | ||
if !ok { | ||
return 0, errors.New(`found non-string "dtype" value`) | ||
} | ||
var dt dtype.DType | ||
if err := dt.UnmarshalText([]byte(strDType)); err != nil { | ||
return 0, fmt.Errorf(`invalid "dtype" value: %q`, strDType) | ||
} | ||
return dt, nil | ||
} | ||
|
||
func convertRawTensorShape(raw map[string]any) ([]int, error) { | ||
rawShape, ok := raw["shape"] | ||
if !ok { | ||
return nil, fmt.Errorf(`"shape" is missing`) | ||
} | ||
rawSlice, ok := rawShape.([]any) | ||
if !ok { | ||
return nil, errors.New(`found non-array "shape" value`) | ||
} | ||
shape := make([]int, len(rawSlice)) | ||
for i, rawItem := range rawSlice { | ||
var err error | ||
if shape[i], err = convertNonNegInt(rawItem); err != nil { | ||
return nil, fmt.Errorf(`failed to interpret "shape" value at index %d: %w`, i, err) | ||
} | ||
} | ||
return shape, nil | ||
} | ||
|
||
func convertRawTensorOffsets(raw map[string]any) (Offsets, error) { | ||
rawOffsets, ok := raw["data_offsets"] | ||
if !ok { | ||
return Offsets{}, fmt.Errorf(`"data_offsets" is missing`) | ||
} | ||
rawSlice, ok := rawOffsets.([]any) | ||
if !ok { | ||
return Offsets{}, errors.New(`found non-array "data_offsets" value`) | ||
} | ||
if l := len(rawSlice); l != 2 { | ||
return Offsets{}, fmt.Errorf(`bad "data_offsets" length: expected 2, actual %d`, l) | ||
} | ||
var parsed [2]int | ||
for i, rawItem := range rawSlice { | ||
var err error | ||
if parsed[i], err = convertNonNegInt(rawItem); err != nil { | ||
return Offsets{}, fmt.Errorf(`failed to interpret "data_offsets" value at index %d: %w`, i, err) | ||
} | ||
} | ||
return Offsets{Begin: parsed[0], End: parsed[1]}, nil | ||
} | ||
|
||
func convertNonNegInt(value any) (int, error) { | ||
jNum, ok := value.(json.Number) | ||
if !ok { | ||
return 0, errors.New("value is not a number") | ||
} | ||
num, err := strconv.ParseInt(jNum.String(), 10, strconv.IntSize) | ||
if err != nil { | ||
return 0, fmt.Errorf("failed to convert value %q to int: %w", jNum.String(), err) | ||
} | ||
if num < 0 { | ||
return 0, fmt.Errorf("value is negative: %d", num) | ||
} | ||
return int(num), nil | ||
} |
Oops, something went wrong.