diff --git a/chell_test.go b/chell_test.go index 7aa1911..2615ec7 100644 --- a/chell_test.go +++ b/chell_test.go @@ -118,8 +118,24 @@ func TestDumpOneFilterOnlyFields(t *testing.T) { err = Dump(&taskSchema3, &task, Only("title", "simple_user"), FieldAliasMapTagName("json")) assert.Nil(t, err) - data, _ = json.Marshal(taskSchema) + data, _ = json.Marshal(taskSchema3) assert.Equal(t, `{"title":"Finish your jobs.","simple_user":{"name":"user:1"},"unknown":""}`, string(data)) + + // nil pointer + var taskSchema4 *TaskSchema + err = Dump(&taskSchema4, &task, Only("title", "simple_user"), FieldAliasMapTagName("json")) + assert.Nil(t, err) + + data, _ = json.Marshal(taskSchema4) + assert.Equal(t, `{"title":"Finish your jobs.","simple_user":{"name":"user:1"},"unknown":""}`, string(data)) + + // non-nil pointer to schema + var taskSchema5 = &TaskSchema{Unknown: "unknown"} + err = Dump(&taskSchema5, &task, Only("title", "simple_user"), FieldAliasMapTagName("json")) + assert.Nil(t, err) + + data, _ = json.Marshal(taskSchema5) + assert.Equal(t, `{"title":"Finish your jobs.","simple_user":{"name":"user:1"},"unknown":"unknown"}`, string(data)) } func TestDumpOneExcludeFields(t *testing.T) { diff --git a/schema.go b/schema.go index 75d07f1..37abb7d 100644 --- a/schema.go +++ b/schema.go @@ -31,12 +31,16 @@ func newSchema(v interface{}) *schema { case reflect.Ptr: // var schema *SchemaStruct // ptr := &schema - typ, err := innerStructType(rv.Type()) - if err != nil { - panic(fmt.Errorf("cannot get schema struct: %s", err)) + if rv.Elem().IsNil() { + typ, err := innerStructType(rv.Type()) + if err != nil { + panic(fmt.Errorf("cannot get schema struct: %s", err)) + } + schemaValue = reflect.New(typ).Elem() + rv.Elem().Set(schemaValue.Addr()) + } else { + schemaValue = rv.Elem().Elem() } - schemaValue = reflect.New(typ).Elem() - rv.Elem().Set(schemaValue.Addr()) default: panic("expect a pointer to struct") }