Skip to content

Commit

Permalink
Add header reading functions
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-nicola committed Jun 5, 2023
1 parent 4ce3484 commit cdfaaf0
Show file tree
Hide file tree
Showing 2 changed files with 487 additions and 0 deletions.
203 changes: 203 additions & 0 deletions header/read.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Copyright 2023 The NLP Odyssey Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package header

import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"strconv"

"github.com/nlpodyssey/safetensors/dtype"
)

type rawDecodedHeader map[string]map[string]any

const metadataKey = "__metadata__"

// Read reads and parses from "r" the initial part of a safetensors
// data stream.
func Read(r io.Reader) (Header, error) {
size, err := readHeaderSize(r)
switch {
case err != nil:
return Header{}, err
case size < 2:
return Header{}, fmt.Errorf("header size too small: %d", size)
case size > math.MaxInt-8:
return Header{}, fmt.Errorf("header size too large: %d", size)
}

raw, err := readAndDecodeJSON(r, int64(size))
if err != nil {
return Header{}, fmt.Errorf("failed to JSON-decode header: %w", err)
}

h, err := convertRawHeader(raw)
if err != nil {
return Header{}, err
}

h.ByteBufferOffset = 8 + int(size)
return h, nil
}

func readHeaderSize(r io.Reader) (uint64, error) {
var arr [8]byte
b := arr[:]
if _, err := io.ReadFull(r, b); err != nil {
return 0, fmt.Errorf("failed to read header size: %w", err)
}
return binary.LittleEndian.Uint64(b), nil
}

func readAndDecodeJSON(r io.Reader, size int64) (rawDecodedHeader, error) {
dec := json.NewDecoder(&io.LimitedReader{R: r, N: size})
dec.UseNumber()

var raw rawDecodedHeader
if err := dec.Decode(&raw); err != nil {
return nil, err
}
// take care of possible padding spaces after JSON object
if off := dec.InputOffset(); off != size {
if _, err := dec.Token(); err == nil {
return nil, fmt.Errorf("unexpected data at byte offset %d", off)
} else if err != io.EOF {
return nil, err
}
}
return raw, nil
}

func convertRawHeader(raw rawDecodedHeader) (h Header, err error) {
if rawMeta, ok := raw[metadataKey]; ok {
delete(raw, metadataKey)
if h.Metadata, err = convertRawMetadata(rawMeta); err != nil {
return
}
}
h.Tensors, err = convertRawTensors(raw)
return
}

func convertRawMetadata(raw map[string]any) (Metadata, error) {
if len(raw) == 0 {
return nil, nil
}
metadata := make(Metadata, len(raw))
for key, rawVal := range raw {
var ok bool
if metadata[key], ok = rawVal.(string); !ok {
return nil, fmt.Errorf("failed to interpret header metadata: found non-string value for key %q", key)
}
}
return metadata, nil
}

func convertRawTensors(raw rawDecodedHeader) (Tensors, error) {
if len(raw) == 0 {
return nil, nil
}
tensors := make(Tensors, len(raw))
for key, rawVal := range raw {
var err error
if tensors[key], err = convertRawTensor(rawVal); err != nil {
return nil, fmt.Errorf("failed to interpret header tensor %q: %w", key, err)
}
}
return tensors, nil
}

func convertRawTensor(raw map[string]any) (t Tensor, err error) {
if t.DType, err = convertRawTensorDType(raw); err != nil {
return
}
if t.Shape, err = convertRawTensorShape(raw); err != nil {
return
}
if t.Offsets, err = convertRawTensorOffsets(raw); err != nil {
return
}
if len(raw) != 3 {
err = errors.New("JSON object contains unknown keys")
}
return
}

func convertRawTensorDType(raw map[string]any) (dtype.DType, error) {
rawDType, ok := raw["dtype"]
if !ok {
return 0, errors.New(`"dtype" is missing`)
}
strDType, ok := rawDType.(string)
if !ok {
return 0, errors.New(`found non-string "dtype" value`)
}
var dt dtype.DType
if err := dt.UnmarshalText([]byte(strDType)); err != nil {
return 0, fmt.Errorf(`invalid "dtype" value: %q`, strDType)
}
return dt, nil
}

func convertRawTensorShape(raw map[string]any) ([]int, error) {
rawShape, ok := raw["shape"]
if !ok {
return nil, fmt.Errorf(`"shape" is missing`)
}
rawSlice, ok := rawShape.([]any)
if !ok {
return nil, errors.New(`found non-array "shape" value`)
}
shape := make([]int, len(rawSlice))
for i, rawItem := range rawSlice {
var err error
if shape[i], err = convertNonNegInt(rawItem); err != nil {
return nil, fmt.Errorf(`failed to interpret "shape" value at index %d: %w`, i, err)
}
}
return shape, nil
}

func convertRawTensorOffsets(raw map[string]any) (Offsets, error) {
rawOffsets, ok := raw["data_offsets"]
if !ok {
return Offsets{}, fmt.Errorf(`"data_offsets" is missing`)
}
rawSlice, ok := rawOffsets.([]any)
if !ok {
return Offsets{}, errors.New(`found non-array "data_offsets" value`)
}
if l := len(rawSlice); l != 2 {
return Offsets{}, fmt.Errorf(`bad "data_offsets" length: expected 2, actual %d`, l)
}
var parsed [2]int
for i, rawItem := range rawSlice {
var err error
if parsed[i], err = convertNonNegInt(rawItem); err != nil {
return Offsets{}, fmt.Errorf(`failed to interpret "data_offsets" value at index %d: %w`, i, err)
}
}
return Offsets{Begin: parsed[0], End: parsed[1]}, nil
}

func convertNonNegInt(value any) (int, error) {
jNum, ok := value.(json.Number)
if !ok {
return 0, errors.New("value is not a number")
}
num, err := strconv.ParseInt(jNum.String(), 10, strconv.IntSize)
if err != nil {
return 0, fmt.Errorf("failed to convert value %q to int: %w", jNum.String(), err)
}
if num < 0 {
return 0, fmt.Errorf("value is negative: %d", num)
}
return int(num), nil
}
Loading

0 comments on commit cdfaaf0

Please sign in to comment.