From 52c15173205b4590ff91f013e4f980eb0498f56b Mon Sep 17 00:00:00 2001 From: limpo1989 Date: Fri, 10 Nov 2023 11:25:06 +0800 Subject: [PATCH] Redesign the BeanInit interface to get away from Go-Spring dependence --- gs/gs.go | 13 +++++ gs/gs_bean.go | 37 ++++++++++-- gs/gs_test.go | 159 +++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 169 insertions(+), 40 deletions(-) diff --git a/gs/gs.go b/gs/gs.go index 33f07d8c..10df0979 100644 --- a/gs/gs.go +++ b/gs/gs.go @@ -79,6 +79,19 @@ type Context interface { Go(fn func(ctx context.Context)) } +type contextKey struct{} + +func WithContext(ctx Context) context.Context { + return context.WithValue(ctx.Context(), contextKey{}, ctx) +} + +func FromContext(ctx context.Context) Context { + if val := ctx.Value(contextKey{}); val != nil { + return val.(Context) + } + return nil +} + type tempContainer struct { props *conf.Properties beans []*BeanDefinition diff --git a/gs/gs_bean.go b/gs/gs_bean.go index 37866e80..b51ebdc6 100644 --- a/gs/gs_bean.go +++ b/gs/gs_bean.go @@ -17,6 +17,7 @@ package gs import ( + "context" "errors" "fmt" "reflect" @@ -66,7 +67,7 @@ func BeanID(typ interface{}, name string) string { } type BeanInit interface { - OnInit(ctx Context) error + OnInit(ctx context.Context) error } type BeanDestroy interface { @@ -215,9 +216,24 @@ func validLifeCycleFunc(fnType reflect.Type, beanValue reflect.Value) bool { if !utils.IsFuncType(fnType) { return false } - if fnType.NumIn() != 1 || !utils.HasReceiver(fnType, beanValue) { + + switch fnType.NumIn() { + case 1: + // func(bean) + // func(bean) error + if !utils.HasReceiver(fnType, beanValue) { + return false + } + case 2: + // func(bean, ctx) + // func(bean, ctx) error + if !utils.HasReceiver(fnType, beanValue) || !utils.IsContextType(fnType.In(1)) { + return false + } + default: return false } + return utils.ReturnNothing(fnType) || utils.ReturnOnlyError(fnType) } @@ -227,7 +243,7 @@ func (d *BeanDefinition) Init(fn interface{}) *BeanDefinition { d.init = fn return d } - panic(errors.New("init should be func(bean) or func(bean)error")) + panic(errors.New("init should be func(bean,[ctx]) or func(bean,[ctx])error")) } // Destroy Set the destruction function for a bean. @@ -236,7 +252,7 @@ func (d *BeanDefinition) Destroy(fn interface{}) *BeanDefinition { d.destroy = fn return d } - panic(errors.New("destroy should be func(bean) or func(bean)error")) + panic(errors.New("destroy should be func(bean,[ctx]) or func(bean,[ctx])error")) } // Export indicates the types of interface to export. @@ -283,14 +299,19 @@ func (d *BeanDefinition) export(exports ...interface{}) error { func (d *BeanDefinition) constructor(ctx Context) error { if d.init != nil { fnValue := reflect.ValueOf(d.init) - out := fnValue.Call([]reflect.Value{d.Value()}) + fnValues := []reflect.Value{d.Value()} + if fnValue.Type().NumIn() > 1 { + fnValues = append(fnValues, reflect.ValueOf(WithContext(ctx))) + } + + out := fnValue.Call(fnValues) if len(out) > 0 && !out[0].IsNil() { return out[0].Interface().(error) } } if f, ok := d.Interface().(BeanInit); ok { - if err := f.OnInit(ctx); err != nil { + if err := f.OnInit(WithContext(ctx)); err != nil { return err } } @@ -300,6 +321,10 @@ func (d *BeanDefinition) constructor(ctx Context) error { func (d *BeanDefinition) destructor() { if d.destroy != nil { fnValue := reflect.ValueOf(d.destroy) + fnValues := []reflect.Value{d.Value()} + if fnValue.Type().NumIn() > 1 { + fnValues = append(fnValues, reflect.ValueOf(context.Background())) + } fnValue.Call([]reflect.Value{d.Value()}) } diff --git a/gs/gs_test.go b/gs/gs_test.go index d7f8b2da..5b43d36d 100644 --- a/gs/gs_test.go +++ b/gs/gs_test.go @@ -18,6 +18,7 @@ package gs import ( "bytes" + "context" "errors" "fmt" "image" @@ -926,6 +927,8 @@ type destroyable interface { Init() Destroy() InitWithError() error + InitWithCtx(ctx context.Context) + InitWithCtxError(ctx context.Context) error DestroyWithError() error } @@ -951,6 +954,27 @@ func (d *callDestroy) InitWithError() error { return fmt.Errorf("error") } +func (d *callDestroy) InitWithCtx(ctx context.Context) { + if d.i == 0 { + d.inited = true + } + if nil == FromContext(ctx) { + panic("invalid context") + } +} + +func (d *callDestroy) InitWithCtxError(ctx context.Context) error { + if d.i == 0 { + d.inited = true + + if nil == FromContext(ctx) { + return fmt.Errorf("invalid context") + } + return nil + } + return fmt.Errorf("error") +} + func (d *callDestroy) DestroyWithError() error { if d.i == 0 { d.destroyed = true @@ -971,15 +995,29 @@ func TestRegisterBean_InitFunc(t *testing.T) { t.Run("call init", func(t *testing.T) { - c := New() - c.Object(new(callDestroy)).Init((*callDestroy).Init) - err := runTest(c, func(p Context) { - var d *callDestroy - err := p.Get(&d) + { + c := New() + c.Object(new(callDestroy)).Init((*callDestroy).Init) + err := runTest(c, func(p Context) { + var d *callDestroy + err := p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.inited) + }) assert.Nil(t, err) - assert.True(t, d.inited) - }) - assert.Nil(t, err) + } + + { + c := New() + c.Object(new(callDestroy)).Init((*callDestroy).InitWithCtx) + err := runTest(c, func(p Context) { + var d *callDestroy + err := p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.inited) + }) + assert.Nil(t, err) + } }) t.Run("call init with error", func(t *testing.T) { @@ -991,32 +1029,72 @@ func TestRegisterBean_InitFunc(t *testing.T) { assert.Error(t, err, "error") } - c := New() - p := conf.New() - p.Set("int", 0) - c.Object(&callDestroy{}).Init((*callDestroy).InitWithError) + { + c := New() + c.Object(&callDestroy{i: 1}).Init((*callDestroy).InitWithCtxError) + err := c.Refresh() + assert.Error(t, err, "error") + } - err := c.Properties().Refresh(p) - assert.Nil(t, err) - err = runTest(c, func(p Context) { - var d *callDestroy - err = p.Get(&d) + { + c := New() + p := conf.New() + p.Set("int", 0) + c.Object(&callDestroy{}).Init((*callDestroy).InitWithError) + + err := c.Properties().Refresh(p) assert.Nil(t, err) - assert.True(t, d.inited) - }) - assert.Nil(t, err) + err = runTest(c, func(p Context) { + var d *callDestroy + err = p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.inited) + }) + assert.Nil(t, err) + } + + { + c := New() + p := conf.New() + p.Set("int", 0) + c.Object(&callDestroy{}).Init((*callDestroy).InitWithCtxError) + + err := c.Properties().Refresh(p) + assert.Nil(t, err) + err = runTest(c, func(p Context) { + var d *callDestroy + err = p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.inited) + }) + assert.Nil(t, err) + } }) t.Run("call interface init", func(t *testing.T) { - c := New() - c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.Init) - err := runTest(c, func(p Context) { - var d destroyable - err := p.Get(&d) + { + c := New() + c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.Init) + err := runTest(c, func(p Context) { + var d destroyable + err := p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.(*callDestroy).inited) + }) assert.Nil(t, err) - assert.True(t, d.(*callDestroy).inited) - }) - assert.Nil(t, err) + } + + { + c := New() + c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.InitWithCtx) + err := runTest(c, func(p Context) { + var d destroyable + err := p.Get(&d) + assert.Nil(t, err) + assert.True(t, d.(*callDestroy).inited) + }) + assert.Nil(t, err) + } }) t.Run("call interface init with error", func(t *testing.T) { @@ -1028,6 +1106,13 @@ func TestRegisterBean_InitFunc(t *testing.T) { assert.Error(t, err, "error") } + { + c := New() + c.Provide(func() destroyable { return &callDestroy{i: 1} }).Init(destroyable.InitWithCtxError) + err := c.Refresh() + assert.Error(t, err, "error") + } + c := New() p := conf.New() p.Set("int", 0) @@ -1932,22 +2017,22 @@ func TestApplicationContext_Close(t *testing.T) { assert.Panic(t, func() { c := New() c.Object(func() {}).Destroy(func() {}) - }, "destroy should be func\\(bean\\) or func\\(bean\\)error") + }, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error") assert.Panic(t, func() { c := New() c.Object(func() {}).Destroy(func() int { return 0 }) - }, "destroy should be func\\(bean\\) or func\\(bean\\)error") + }, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error") assert.Panic(t, func() { c := New() c.Object(func() {}).Destroy(func(int) {}) - }, "destroy should be func\\(bean\\) or func\\(bean\\)error") + }, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error") assert.Panic(t, func() { c := New() c.Object(func() {}).Destroy(func(int, int) {}) - }, "destroy should be func\\(bean\\) or func\\(bean\\)error") + }, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error") }) t.Run("call destroy fn", func(t *testing.T) { @@ -2848,7 +2933,10 @@ func TestLazy(t *testing.T) { type memory struct { } -func (m *memory) OnInit(ctx Context) error { +func (m *memory) OnInit(ctx context.Context) error { + if nil == FromContext(ctx) { + panic("invalid context") + } fmt.Println("memory.OnInit") return nil } @@ -2861,7 +2949,10 @@ type table struct { _ *memory `autowire:""` } -func (t *table) OnInit(ctx Context) error { +func (t *table) OnInit(ctx context.Context) error { + if nil == FromContext(ctx) { + panic("invalid context") + } fmt.Println("table.OnInit") return nil }