Skip to content

Commit

Permalink
hdf5: mark types with illegal pointer chains
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jul 2, 2018
1 parent 21dab89 commit c257073
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 43 deletions.
16 changes: 13 additions & 3 deletions h5d_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import (

type Dataset struct {
Identifier

typ *Datatype
}

func newDataset(id C.hid_t) *Dataset {
return &Dataset{Identifier{id}}
func newDataset(id C.hid_t, typ *Datatype) *Dataset {
return &Dataset{Identifier: Identifier{id}, typ: typ}
}

func createDataset(id C.hid_t, name string, dtype *Datatype, dspace *Dataspace, dcpl *PropList) (*Dataset, error) {
Expand All @@ -35,7 +37,7 @@ func createDataset(id C.hid_t, name string, dtype *Datatype, dspace *Dataspace,
if err := checkID(hid); err != nil {
return nil, err
}
return newDataset(hid), nil
return newDataset(hid, dtype), nil
}

// Close releases and terminates access to a dataset.
Expand Down Expand Up @@ -180,3 +182,11 @@ func (s *Dataset) Datatype() (*Datatype, error) {
}
return NewDatatype(dtype_id), nil
}

// hasIllegalGoPointer returns whether the Dataset is known to have
// a Go pointer to Go pointer chain. If the Dataset was created by
// a call to OpenDataset without a read operation, it will be false,
// but will not be a valid reflection of the real situation.
func (s *Dataset) hasIllegalGoPointer() bool {
return s.typ.hasIllegalGoPointer()
}
2 changes: 1 addition & 1 deletion h5g_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (g *CommonFG) OpenDataset(name string) (*Dataset, error) {
if err := checkID(hid); err != nil {
return nil, err
}
return newDataset(hid), nil
return newDataset(hid, nil), nil
}

// NumObjects returns the number of objects in the Group.
Expand Down
1 change: 1 addition & 0 deletions h5t_shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,5 +273,6 @@ func makeGoStringDatatype() *Datatype {
if err != nil {
panic(err)
}
dt.goPtrPathLen = 1 // This is the first field of the string header.
return dt
}
38 changes: 28 additions & 10 deletions h5t_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (

type Datatype struct {
Identifier

goPtrPathLen int
}

type TypeClass C.H5T_class_t
Expand Down Expand Up @@ -108,8 +110,7 @@ func OpenDatatype(c CommonFG, name string, tapl_id int) (*Datatype, error) {

// NewDatatype creates a Datatype from an hdf5 id.
func NewDatatype(id C.hid_t) *Datatype {
t := &Datatype{Identifier{id}}
return t
return &Datatype{Identifier: Identifier{id}}
}

// CreateDatatype creates a new datatype. The value of class must be T_COMPOUND,
Expand Down Expand Up @@ -152,10 +153,14 @@ func (t *Datatype) Committed() bool {
return C.H5Tcommitted(t.id) > 0
}

// Copy copies an existing datatype. The returned datatype must be closed by the
// user when it is no longer needed.
// Copy copies an existing datatype.
func (t *Datatype) Copy() (*Datatype, error) {
return copyDatatype(t.id)
c, err := copyDatatype(t.id)
if err != nil {
return nil, err
}
c.goPtrPathLen = t.goPtrPathLen
return c, nil
}

// copyDatatype should be called by any function wishing to return
Expand Down Expand Up @@ -204,7 +209,7 @@ func NewArrayType(base_type *Datatype, dims []int) (*ArrayType, error) {
if err := checkID(hid); err != nil {
return nil, err
}
t := &ArrayType{Datatype{Identifier{hid}}}
t := &ArrayType{Datatype{Identifier: Identifier{hid}}}
return t, nil
}

Expand Down Expand Up @@ -242,7 +247,8 @@ func NewVarLenType(base_type *Datatype) (*VarLenType, error) {
if err := checkID(id); err != nil {
return nil, err
}
t := &VarLenType{Datatype{Identifier{id}}}
t := &VarLenType{Datatype{Identifier: Identifier{id}}}
t.goPtrPathLen = 1 // This is the first field of the slice header.
return t, nil
}

Expand All @@ -263,7 +269,7 @@ func NewCompoundType(size int) (*CompoundType, error) {
if err := checkID(id); err != nil {
return nil, err
}
t := &CompoundType{Datatype{Identifier{id}}}
t := &CompoundType{Datatype{Identifier: Identifier{id}}}
return t, nil
}

Expand Down Expand Up @@ -438,14 +444,18 @@ func NewDataTypeFromType(t reflect.Type) (*Datatype, error) {
if err != nil {
return nil, err
}
var ptrPathLen int
n := t.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
var field_dt *Datatype = nil
var field_dt *Datatype
field_dt, err = NewDataTypeFromType(f.Type)
if err != nil {
return nil, err
}
if field_dt.goPtrPathLen > ptrPathLen {
ptrPathLen = field_dt.goPtrPathLen
}
offset := int(f.Offset + 0)
if field_dt == nil {
return nil, fmt.Errorf("pb with field [%d-%s]", i, f.Name)
Expand All @@ -460,9 +470,11 @@ func NewDataTypeFromType(t reflect.Type) (*Datatype, error) {
}
}
dt = &cdt.Datatype
dt.goPtrPathLen += ptrPathLen

case reflect.Ptr:
return NewDataTypeFromType(t.Elem())
dt, err = NewDataTypeFromType(t.Elem())
dt.goPtrPathLen++

default:
// Should never happen.
Expand All @@ -472,6 +484,12 @@ func NewDataTypeFromType(t reflect.Type) (*Datatype, error) {
return dt, err
}

// hasIllegalGoPointer returns whether the Datatype is known to have
// a Go pointer to Go pointer chain.
func (t *Datatype) hasIllegalGoPointer() bool {
return t != nil && t.goPtrPathLen > 1
}

func getArrayDims(dt reflect.Type) []int {
result := []int{}
if dt.Kind() == reflect.Array {
Expand Down
102 changes: 73 additions & 29 deletions h5t_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,74 @@

package hdf5

import (
"runtime"
"testing"
"time"
)
import "testing"

func TestSimpleDatatypes(t *testing.T) {
// Smoke tests for the simple datatypes
tests := []interface{}{
int(0),
int8(0),
int16(0),
int32(0),
int64(0),
uint(0),
uint8(0),
uint16(0),
uint32(0),
uint64(0),
float32(0),
float64(0),
string(""),
bool(true),
tests := []struct {
v interface{}
hasIllegalPtr bool
}{
{v: int(0), hasIllegalPtr: false},
{v: int8(0), hasIllegalPtr: false},
{v: int16(0), hasIllegalPtr: false},
{v: int32(0), hasIllegalPtr: false},
{v: int64(0), hasIllegalPtr: false},
{v: uint(0), hasIllegalPtr: false},
{v: uint8(0), hasIllegalPtr: false},
{v: uint16(0), hasIllegalPtr: false},
{v: uint32(0), hasIllegalPtr: false},
{v: uint64(0), hasIllegalPtr: false},
{v: float32(0), hasIllegalPtr: false},
{v: float64(0), hasIllegalPtr: false},
{v: string(""), hasIllegalPtr: false},
{v: ([]int)(nil), hasIllegalPtr: false},
{v: [1]int{0}, hasIllegalPtr: false},
{v: bool(true), hasIllegalPtr: false},
{v: (*int)(nil), hasIllegalPtr: false},
{v: (*int8)(nil), hasIllegalPtr: false},
{v: (*int16)(nil), hasIllegalPtr: false},
{v: (*int32)(nil), hasIllegalPtr: false},
{v: (*int64)(nil), hasIllegalPtr: false},
{v: (*uint)(nil), hasIllegalPtr: false},
{v: (*uint8)(nil), hasIllegalPtr: false},
{v: (*uint16)(nil), hasIllegalPtr: false},
{v: (*uint32)(nil), hasIllegalPtr: false},
{v: (*uint64)(nil), hasIllegalPtr: false},
{v: (*float32)(nil), hasIllegalPtr: false},
{v: (*float64)(nil), hasIllegalPtr: false},
{v: (*string)(nil), hasIllegalPtr: true},
{v: (*[]int)(nil), hasIllegalPtr: true},
{v: (*[1]int)(nil), hasIllegalPtr: false},
{v: (*bool)(nil), hasIllegalPtr: false},
{v: (**int)(nil), hasIllegalPtr: true},
{v: (**int8)(nil), hasIllegalPtr: true},
{v: (**int16)(nil), hasIllegalPtr: true},
{v: (**int32)(nil), hasIllegalPtr: true},
{v: (**int64)(nil), hasIllegalPtr: true},
{v: (**uint)(nil), hasIllegalPtr: true},
{v: (**uint8)(nil), hasIllegalPtr: true},
{v: (**uint16)(nil), hasIllegalPtr: true},
{v: (**uint32)(nil), hasIllegalPtr: true},
{v: (**uint64)(nil), hasIllegalPtr: true},
{v: (**float32)(nil), hasIllegalPtr: true},
{v: (**float64)(nil), hasIllegalPtr: true},
{v: (**string)(nil), hasIllegalPtr: true},
{v: (**[]int)(nil), hasIllegalPtr: true},
{v: (**[1]int)(nil), hasIllegalPtr: true},
{v: (**bool)(nil), hasIllegalPtr: true},
}

for test := range tests {
NewDatatypeFromValue(test)
// Test again for usage with ptrs
NewDatatypeFromValue(&test)
for _, test := range tests {
dt, err := NewDatatypeFromValue(test.v)
if err != nil {
t.Errorf("unexpected error: %v", err)
continue
}
gotIllegalPtr := dt.hasIllegalGoPointer()
if gotIllegalPtr != test.hasIllegalPtr {
t.Errorf("unexpected illegal pointer status for %T: got:%t want:%t", test.v, gotIllegalPtr, test.hasIllegalPtr)
}
}
}

Expand All @@ -49,6 +88,9 @@ func TestArrayDatatype(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if dt.hasIllegalGoPointer() {
t.Errorf("unexpected illegal pointer for %T", val)
}
adt := ArrayType{*dt}
if adt.NDims() != dims {
t.Errorf("wrong number of dimensions: got %d, want %d", adt.NDims(), dims)
Expand All @@ -75,13 +117,19 @@ func TestStructDatatype(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if dtype.hasIllegalGoPointer() {
t.Errorf("unexpected illegal pointer for %T", test)
}
dtypes = append(dtypes, dtype)

// pointer to value
dtype, err = NewDatatypeFromValue(test)
dtype, err = NewDatatypeFromValue(&test)
if err != nil {
t.Fatal(err)
}
if !dtype.hasIllegalGoPointer() {
t.Errorf("expected illegal pointer for %T", &test)
}
dtypes = append(dtypes, dtype)

for _, dtype := range dtypes {
Expand Down Expand Up @@ -138,8 +186,4 @@ func TestCloseBehavior(t *testing.T) {
t.Fatal(err)
}
defer dtype.Close()

// Sleep to ensure GC runs before returning
runtime.GC()
time.Sleep(100 * time.Millisecond)
}

0 comments on commit c257073

Please sign in to comment.