diff --git a/environment.go b/environment.go index dd493d7e2..9f2dc9dd4 100644 --- a/environment.go +++ b/environment.go @@ -2,6 +2,7 @@ package httpexpect import ( "errors" + "sync" "time" ) @@ -13,9 +14,9 @@ import ( // env.Put("key", "value") // value := env.GetString("key") type Environment struct { - noCopy noCopy - chain *chain - data map[string]interface{} + mu sync.RWMutex + chain *chain + data map[string]interface{} } // NewEnvironment returns a new Environment. @@ -58,6 +59,9 @@ func (e *Environment) Put(key string, value interface{}) { opChain := e.chain.enter("Put(%q)", key) defer opChain.leave() + e.mu.Lock() + defer e.mu.Unlock() + e.data[key] = value } @@ -72,6 +76,9 @@ func (e *Environment) Delete(key string) { opChain := e.chain.enter("Delete(%q)", key) defer opChain.leave() + e.mu.Lock() + defer e.mu.Unlock() + delete(e.data, key) } @@ -86,6 +93,9 @@ func (e *Environment) Has(key string) bool { opChain := e.chain.enter("Has(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + _, ok := e.data[key] return ok } @@ -102,6 +112,9 @@ func (e *Environment) Get(key string) interface{} { opChain := e.chain.enter("Get(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, _ := envValue(opChain, e.data, key) return value @@ -118,6 +131,9 @@ func (e *Environment) GetBool(key string) bool { opChain := e.chain.enter("GetBool(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return false @@ -150,6 +166,9 @@ func (e *Environment) GetInt(key string) int { opChain := e.chain.enter("GetInt(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return 0 @@ -235,6 +254,9 @@ func (e *Environment) GetFloat(key string) float64 { opChain := e.chain.enter("GetFloat(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return 0 @@ -275,6 +297,9 @@ func (e *Environment) GetString(key string) string { opChain := e.chain.enter("GetString(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return "" @@ -306,6 +331,9 @@ func (e *Environment) GetBytes(key string) []byte { opChain := e.chain.enter("GetBytes(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return nil @@ -338,6 +366,9 @@ func (e *Environment) GetDuration(key string) time.Duration { opChain := e.chain.enter("GetDuration(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return time.Duration(0) @@ -370,6 +401,9 @@ func (e *Environment) GetTime(key string) time.Time { opChain := e.chain.enter("GetTime(%q)", key) defer opChain.leave() + e.mu.RLock() + defer e.mu.RUnlock() + value, ok := envValue(opChain, e.data, key) if !ok { return time.Unix(0, 0) diff --git a/environment_test.go b/environment_test.go index 02d78856b..dfe698b2e 100644 --- a/environment_test.go +++ b/environment_test.go @@ -32,6 +32,23 @@ func TestEnvironment_Constructors(t *testing.T) { }) } +func TestEnvironment_Reentrant(t *testing.T) { + reporter := newMockReporter(t) + + env := NewEnvironment(reporter) + + reportCalled := false + reporter.reportCb = func() { + env.Put("good_key", 123) + reportCalled = true + } + + env.Get("bad_key") + env.chain.assertFailed(t) + + assert.True(t, reportCalled) +} + func TestEnvironment_Basic(t *testing.T) { env := newEnvironment(newMockChain(t)) diff --git a/mocks_test.go b/mocks_test.go index 6b14084cf..316c8f2c9 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -183,6 +183,7 @@ func (l *mockLogger) Logf(message string, args ...interface{}) { type mockReporter struct { testing *testing.T reported bool + reportCb func() } func newMockReporter(t *testing.T) *mockReporter { @@ -192,6 +193,10 @@ func newMockReporter(t *testing.T) *mockReporter { func (r *mockReporter) Errorf(message string, args ...interface{}) { r.testing.Logf("Fail: "+message, args...) r.reported = true + + if r.reportCb != nil { + r.reportCb() + } } type mockFormatter struct {