Skip to content

Commit

Permalink
all: switch to using a custom Reader interface instead of bytes.Reader to avoid extra copying of data
Browse files Browse the repository at this point in the history

Signed-off-by: deadprogram <ron@hybridgroup.com>
  • Loading branch information
deadprogram committed Mar 3, 2024
1 parent f328896 commit 3bac1f9
Show file tree
Hide file tree
Showing 21 changed files with 127 additions and 103 deletions.
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package config

import (
"errors"
"io"

"github.com/hybridgroup/wasman/tollstation"
)

// Reader is the input abstraction introduced by this commit in place of
// *bytes.Reader. It combines sequential reads (io.Reader) with reads at an
// explicit offset (io.ReaderAt).
//
// NOTE(review): other hunks of this commit refer to this type as utils.Reader
// and call Seek on it; this declaration embeds no io.Seeker — confirm against
// the definition that is actually imported by those packages.
type Reader interface {
io.Reader
io.ReaderAt
}

const (
// MemoryPageSize is the unit of memory length in WebAssembly,
// and is defined as 2^16 = 65536.
Expand Down
40 changes: 25 additions & 15 deletions expr/expr.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package expr

import (
"bytes"
"fmt"

"github.com/hybridgroup/wasman/leb128decode"
Expand All @@ -16,49 +15,60 @@ type Expression struct {
}

// ReadExpression will read an expr.Expression from the io.Reader
func ReadExpression(r *bytes.Reader) (*Expression, error) {
b, err := r.ReadByte()
func ReadExpression(r utils.Reader) (*Expression, error) {
var b [1]byte
_, err := r.Read(b[:])
if err != nil {
return nil, fmt.Errorf("read opcode: %v", err)
}

remainingBeforeData := int64(r.Len())
offsetAtData := r.Size() - remainingBeforeData

op := OpCode(b)
n := uint64(0)
op := OpCode(b[0])

switch op {
case OpCodeI32Const:
_, _, err = leb128decode.DecodeInt32(r)
_, n, err = leb128decode.DecodeInt32(r)
case OpCodeI64Const:
_, _, err = leb128decode.DecodeInt64(r)
_, n, err = leb128decode.DecodeInt64(r)
case OpCodeF32Const:
_, err = utils.ReadFloat32(r)
n = 4
case OpCodeF64Const:
_, err = utils.ReadFloat64(r)
n = 8
case OpCodeGlobalGet:
_, _, err = leb128decode.DecodeUint32(r)
_, n, err = leb128decode.DecodeUint32(r)
default:
return nil, fmt.Errorf("%v for opcodes.OpCode: %#x", types.ErrInvalidTypeByte, b)
return nil, fmt.Errorf("%v for opcodes.OpCode: %#x", types.ErrInvalidTypeByte, b[0])
}

if err != nil {
return nil, fmt.Errorf("read value: %v", err)
}

if b, err = r.ReadByte(); err != nil {
if _, err = r.Read(b[:]); err != nil {
return nil, fmt.Errorf("look for end opcode: %v", err)
}

if b != byte(OpCodeEnd) {
if b[0] != byte(OpCodeEnd) {
return nil, fmt.Errorf("constant expression has not terminated")
}

data := make([]byte, remainingBeforeData-int64(r.Len())-1)
if _, err := r.ReadAt(data, offsetAtData); err != nil {
// skip back
if _, err := r.Seek(-1*int64(n+1), 1); err != nil {
return nil, fmt.Errorf("error seeking back to read Expression Data")
}

data := make([]byte, n)
if _, err := r.Read(data); err != nil {
return nil, fmt.Errorf("error re-buffering Expression Data")
}

// skip past end opcode
if _, err := r.Read(b[:]); err != nil {
return nil, fmt.Errorf("error skipping past OpCodeEnd")
}

return &Expression{
OpCode: op,
Data: data,
Expand Down
4 changes: 2 additions & 2 deletions expr/expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func TestReadExpr(t *testing.T) {
} {
actual, err := expr.ReadExpression(bytes.NewReader(c.bytes))
if err != nil {
t.Fail()
t.Error(err)
}
if !reflect.DeepEqual(c.exp, actual) {
t.Fail()
t.Errorf("expected %v, got %v", c.exp, actual)
}
}
})
Expand Down
67 changes: 34 additions & 33 deletions leb128decode/leb128.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package leb128decode

import (
"bytes"
"errors"
"fmt"

"github.com/hybridgroup/wasman/utils"
)

const (
Expand All @@ -19,73 +20,73 @@ var (
)

// DecodeUint32 will decode a uint32 from io.Reader, returning it as the ret with the bytes length l which it read.
func DecodeUint32(r *bytes.Reader) (ret uint32, bytesRead uint64, err error) {
func DecodeUint32(r utils.Reader) (ret uint32, bytesRead uint64, err error) {
// Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go
// with the modification on the overflow handling tailored for 32-bits.
var s uint32
var b byte
var b [1]byte
for i := 0; i < maxVarintLen32; i++ {
b, err = r.ReadByte()
_, err = r.Read(b[:])
if err != nil {
return 0, 0, err
}
if b < 0x80 {
if b[0] < 0x80 {
// Unused bits must be all zero.
if i == maxVarintLen32-1 && (b&0xf0) > 0 {
if i == maxVarintLen32-1 && (b[0]&0xf0) > 0 {
return 0, 0, errOverflow32
}
return ret | uint32(b)<<s, uint64(i) + 1, nil
return ret | uint32(b[0])<<s, uint64(i) + 1, nil
}
ret |= (uint32(b) & 0x7f) << s
ret |= (uint32(b[0]) & 0x7f) << s
s += 7
}
return 0, 0, errOverflow32
}

// DecodeUint64 will decode a uint64 from io.Reader, returning it as the ret with the bytes length l which it read.
func DecodeUint64(r *bytes.Reader) (ret uint64, bytesRead uint64, err error) {
func DecodeUint64(r utils.Reader) (ret uint64, bytesRead uint64, err error) {
// Derived from https://github.com/golang/go/blob/aafad20b617ee63d58fcd4f6e0d98fe27760678c/src/encoding/binary/varint.go
var s uint64
var b byte
var b [1]byte
for i := 0; i < maxVarintLen64; i++ {
b, err = r.ReadByte()
_, err = r.Read(b[:])
if err != nil {
return 0, 0, err
}
if b < 0x80 {
if b[0] < 0x80 {
// Unused bits (non first bit) must all be zero.
if i == maxVarintLen64-1 && b > 1 {
if i == maxVarintLen64-1 && b[0] > 1 {
return 0, 0, errOverflow64
}
return ret | uint64(b)<<s, uint64(i) + 1, nil
return ret | uint64(b[0])<<s, uint64(i) + 1, nil
}
ret |= (uint64(b) & 0x7f) << s
ret |= (uint64(b[0]) & 0x7f) << s
s += 7
}
return 0, 0, errOverflow64
}

// DecodeInt32 will decode a int32 from io.Reader, returning it as the ret with the bytes length l which it read.
func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) {
func DecodeInt32(r utils.Reader) (ret int32, bytesRead uint64, err error) {
var shift int
var b byte
var b [1]byte
for {
b, err = r.ReadByte()
_, err = r.Read(b[:])
if err != nil {
return 0, 0, fmt.Errorf("readByte failed: %w", err)
}
ret |= (int32(b) & 0x7f) << shift
ret |= (int32(b[0]) & 0x7f) << shift
shift += 7
bytesRead++
if b&0x80 == 0 {
if shift < 32 && (b&0x40) != 0 {
if b[0]&0x80 == 0 {
if shift < 32 && (b[0]&0x40) != 0 {
ret |= ^0 << shift
}
// Over flow checks.
// fixme: can be optimized.
if bytesRead > 5 {
return 0, 0, errOverflow32
} else if unused := b & 0b00110000; bytesRead == 5 && ret < 0 && unused != 0b00110000 {
} else if unused := b[0] & 0b00110000; bytesRead == 5 && ret < 0 && unused != 0b00110000 {
return 0, 0, errOverflow32
} else if bytesRead == 5 && ret >= 0 && unused != 0x00 {
return 0, 0, errOverflow32
Expand All @@ -96,7 +97,7 @@ func DecodeInt32(r *bytes.Reader) (ret int32, bytesRead uint64, err error) {
}

// DecodeInt33AsInt64 will decode a int33 from io.Reader, returning it as the int64 ret with the bytes length l which it read.
func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) {
func DecodeInt33AsInt64(r utils.Reader) (ret int64, bytesRead uint64, err error) {
const (
int33Mask int64 = 1 << 7
int33Mask2 = ^int33Mask
Expand All @@ -107,13 +108,13 @@ func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error
)
var shift int
var b int64
var rb byte
var rb [1]byte
for shift < 35 {
rb, err = r.ReadByte()
_, err = r.Read(rb[:])
if err != nil {
return 0, 0, fmt.Errorf("readByte failed: %w", err)
}
b = int64(rb)
b = int64(rb[0])
ret |= (b & int33Mask2) << shift
shift += 7
bytesRead++
Expand Down Expand Up @@ -145,30 +146,30 @@ func DecodeInt33AsInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error
}

// DecodeInt64 will decode a int64 from io.Reader, returning it as the ret with the bytes length l which it read.
func DecodeInt64(r *bytes.Reader) (ret int64, bytesRead uint64, err error) {
func DecodeInt64(r utils.Reader) (ret int64, bytesRead uint64, err error) {
const (
int64Mask3 = 1 << 6
int64Mask4 = ^0
)
var shift int
var b byte
var b [1]byte
for {
b, err = r.ReadByte()
_, err = r.Read(b[:])
if err != nil {
return 0, 0, fmt.Errorf("readByte failed: %w", err)
}
ret |= (int64(b) & 0x7f) << shift
ret |= (int64(b[0]) & 0x7f) << shift
shift += 7
bytesRead++
if b&0x80 == 0 {
if shift < 64 && (b&int64Mask3) == int64Mask3 {
if b[0]&0x80 == 0 {
if shift < 64 && (b[0]&int64Mask3) == int64Mask3 {
ret |= int64Mask4 << shift
}
// Over flow checks.
// fixme: can be optimized.
if bytesRead > 10 {
return 0, 0, errOverflow64
} else if unused := b & 0b00111110; bytesRead == 10 && ret < 0 && unused != 0b00111110 {
} else if unused := b[0] & 0b00111110; bytesRead == 10 && ret < 0 && unused != 0b00111110 {
return 0, 0, errOverflow64
} else if bytesRead == 10 && ret >= 0 && unused != 0x00 {
return 0, 0, errOverflow64
Expand Down
12 changes: 3 additions & 9 deletions module.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@ package wasman

import (
"bytes"
"io"
"io/ioutil"

"github.com/hybridgroup/wasman/config"
"github.com/hybridgroup/wasman/utils"
"github.com/hybridgroup/wasman/wasm"
)

// Module is same to wasm.Module
type Module = wasm.Module

// NewModule is a wrapper to the wasm.NewModule
func NewModule(config config.ModuleConfig, r io.Reader) (*Module, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}

return wasm.NewModule(config, bytes.NewReader(b))
func NewModule(config config.ModuleConfig, r utils.Reader) (*Module, error) {
return wasm.NewModule(config, r)
}

// NewModuleFromBytes is a wrapper to the wasm.NewModule that avoids having to
Expand Down
7 changes: 4 additions & 3 deletions segments/code.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package segments

import (
"bytes"
"fmt"
"io"

"github.com/hybridgroup/wasman/expr"
"github.com/hybridgroup/wasman/leb128decode"
"github.com/hybridgroup/wasman/utils"
)

// CodeSegment is one unit in the wasman.Module's CodeSection
Expand All @@ -16,7 +16,7 @@ type CodeSegment struct {
}

// ReadCodeSegment reads one CodeSegment from the io.Reader
func ReadCodeSegment(r *bytes.Reader) (*CodeSegment, error) {
func ReadCodeSegment(r utils.Reader) (*CodeSegment, error) {
ss, _, err := leb128decode.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("get the size of code segment: %w", err)
Expand All @@ -34,6 +34,7 @@ func ReadCodeSegment(r *bytes.Reader) (*CodeSegment, error) {

var numLocals uint32
var n uint32
var b [1]byte
for i := uint32(0); i < ls; i++ {
n, bytesRead, err = leb128decode.DecodeUint32(r)
remaining -= int64(bytesRead) + 1 // +1 for the subsequent ReadByte
Expand All @@ -44,7 +45,7 @@ func ReadCodeSegment(r *bytes.Reader) (*CodeSegment, error) {
}
numLocals += n

if _, err := r.ReadByte(); err != nil {
if _, err := r.Read(b[:]); err != nil {
return nil, fmt.Errorf("read type of local") // TODO: save read localType
}
}
Expand Down
4 changes: 2 additions & 2 deletions segments/data.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package segments

import (
"bytes"
"fmt"
"io"

"github.com/hybridgroup/wasman/expr"
"github.com/hybridgroup/wasman/leb128decode"
"github.com/hybridgroup/wasman/utils"
)

// DataSegment is one unit of the wasman.Module's DataSection, initializing
Expand All @@ -20,7 +20,7 @@ type DataSegment struct {
}

// ReadDataSegment reads one DataSegment from the io.Reader
func ReadDataSegment(r *bytes.Reader) (*DataSegment, error) {
func ReadDataSegment(r utils.Reader) (*DataSegment, error) {
d, _, err := leb128decode.DecodeUint32(r)
if err != nil {
return nil, fmt.Errorf("read memory index: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions segments/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ func TestDataSegment(t *testing.T) {
t.Run(utils.IntToString(i), func(t *testing.T) {
actual, err := segments.ReadDataSegment(bytes.NewReader(c.bytes))
if err != nil {
t.Fail()
t.Error(err)
}
if !reflect.DeepEqual(c.exp, actual) {
t.Fail()
t.Errorf("expected %v, got %v", c.exp, actual)
}
})
}
Expand Down
Loading

0 comments on commit 3bac1f9

Please sign in to comment.