/
memoize.go
335 lines (298 loc) · 9.94 KB
/
memoize.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package memoize defines a "promise" abstraction that enables
// memoization of the result of calling an expensive but idempotent
// function.
//
// Call p = NewPromise(debug, f) to obtain a promise for the future result
// of calling f, and call p.Get(ctx, arg) to obtain that result. All calls
// to p.Get that are not cancelled return the result of a single call of f.
// Get blocks if the function has not finished (or started).
//
// A Store is a map of arbitrary keys to promises. Use Store.Promise
// to create a promise in the store. All calls to Store.Promise with the
// same key return the same promise as long as it is in the store. These
// promises are reference-counted and must be explicitly released. Once the
// last reference is released, the promise is removed from the store,
// unless the store's EvictionPolicy is NeverEvict.
package memoize
import (
"context"
"fmt"
"reflect"
"runtime/trace"
"sync"
"sync/atomic"
"cuelang.org/go/internal/golangorgx/tools/xcontext"
)
// Function is the signature of a computation whose result a Promise
// memoizes.
//
// When arg implements RefCounted, the promise acquires it before invoking
// the Function and releases it once the Function is done with it.
//
// Concurrent Promise.Get calls that share the same (implicit) key run the
// Function only once, and every caller receives that single result. arg
// must therefore not influence the result in any way that the promise's
// key does not already capture.
//
// arg exists so that the Function closure need not itself retain large
// objects (in practice, the snapshot) that any caller can supply at call
// time.
type Function func(ctx context.Context, arg any) any
// A RefCounted is a value whose functional lifetime is determined by
// reference counting.
//
// Its Acquire method is called before the Function is invoked, and
// the corresponding release is called when the Function returns.
// Usually both events happen within a single call to Get, so Get
// would be fine with a "borrowed" reference, but if the context is
// cancelled, Get may return before the Function is complete, causing
// the argument to escape, and potential premature destruction of the
// value. For a reference-counted type, this requires a pair of
// increment/decrement operations to extend its life.
type RefCounted interface {
// Acquire prevents the value from being destroyed until the
// returned function is called.
Acquire() func()
}
// A Promise is the memoized result of a call to a Function. The call is
// made on first demand, and its eventual result is shared with every
// caller of Get.
type Promise struct {
	// debug labels the promise for observability (e.g. trace region names).
	debug string

	// refcount is the number of outstanding references handed out by
	// Store.Promise. It belongs to the containing Store and is guarded by
	// that Store's promisesMu, not by mu.
	refcount int32

	// mu guards all of the fields that follow.
	mu sync.Mutex

	// state is the promise's position in its lifecycle:
	//
	//   - IDLE: nothing is running. A Promise starts here. It returns here
	//     when every Get waiting on a run has been cancelled: the run's
	//     context is cancelled, its work is abandoned, and the next Get
	//     starts over.
	//   - RUNNING: a Get has demanded the value and the function is being
	//     evaluated. waiters counts the Get calls blocked on the result, and
	//     done signals the next state transition.
	//   - COMPLETED: evaluation finished. value holds the result and done
	//     has been closed, releasing all waiters.
	state state

	// done is created on entering RUNNING and closed on leaving it.
	done chan struct{}

	// cancel is set on entering RUNNING; calling it cancels the computation.
	cancel context.CancelFunc

	// waiters is the number of Get calls waiting on the current run.
	waiters uint

	// function computes the value. It is cleared once the value is known so
	// that it can be garbage-collected.
	function Function

	// value is the function's result; it is meaningful only when COMPLETED.
	value interface{}
}
// NewPromise creates a Promise that calls function on first demand and
// memoizes its result.
//
// debug labels the promise in logs and metrics; it should come from a
// small, fixed vocabulary. NewPromise panics if function is nil.
func NewPromise(debug string, function Function) *Promise {
	if function == nil {
		panic("nil function")
	}
	p := new(Promise)
	p.debug, p.function = debug, function
	return p
}
// state is the lifecycle state of a Promise.
type state int

// The states a Promise moves through. Like the EvictionPolicy constants,
// they carry their named type, so they are not interchangeable with plain
// integers.
const (
	stateIdle      state = iota // newly constructed, or last waiter was cancelled
	stateRunning                // run was called and not cancelled
	stateCompleted              // function call ran to completion
)
// Cached returns the promise's value if the computation has already
// completed, and nil otherwise. It never starts the computation.
func (p *Promise) Cached() interface{} {
	p.mu.Lock()
	completed, v := p.state == stateCompleted, p.value
	p.mu.Unlock()

	if !completed {
		return nil
	}
	return v
}
// Get returns the promise's value, computing it first if necessary.
//
// Every successful Get on a given promise yields the same value, and the
// underlying function runs to completion at most once.
//
// If ctx is cancelled before the value is available, Get returns
// (nil, ctx.Err()). Once every concurrent Get has been cancelled, the
// context passed to the function is cancelled as well. A later Get may then
// invoke the function again.
//
// arg is forwarded to the function; see Function and RefCounted.
func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	p.mu.Lock()
	// run and wait are entered with p.mu held and take care of releasing it.
	switch p.state {
	case stateIdle:
		return p.run(ctx, arg)
	case stateRunning:
		return p.wait(ctx)
	case stateCompleted:
		v := p.value
		p.mu.Unlock()
		return v, nil
	}
	panic("unknown state")
}
// run starts p.function in a new goroutine, then waits for its result
// exactly as wait does. p.mu must be held on entry; wait releases it.
//
// The function runs under a fresh child of xcontext.Detach(ctx). That
// context is detached from the calling Get's cancellation, so it is
// cancelled only when every waiter has given up (see wait) or once a
// result has been recorded.
//
// If arg implements RefCounted, it is acquired before the goroutine starts
// and released when the goroutine's work is finished.
func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) {
	childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
	p.cancel = cancel
	p.state = stateRunning
	p.done = make(chan struct{})
	function := p.function // Read under the lock

	// Make sure that the argument isn't destroyed while we're running in it.
	release := func() {}
	if rc, ok := arg.(RefCounted); ok {
		release = rc.Acquire()
	}

	go func() {
		// Honor the context.WithCancel contract by always cancelling this
		// run's context when the goroutine exits. Without this, a run that
		// completed successfully would never call cancel. childCtx would then
		// stay reachable through p.cancel for the lifetime of the promise.
		defer cancel()

		trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() {
			defer release()

			// Just in case the function does something expensive without checking
			// the context, double-check we're still alive.
			if childCtx.Err() != nil {
				return
			}

			v := function(childCtx, arg)
			if childCtx.Err() != nil {
				return
			}

			p.mu.Lock()
			defer p.mu.Unlock()
			// It's theoretically possible that the promise has been cancelled out
			// of the run that started us, and then started running again since we
			// checked childCtx above. Even so, that should be harmless, since each
			// run should produce the same results.
			if p.state != stateRunning {
				return
			}

			p.value = v
			p.function = nil // aid GC
			p.state = stateCompleted
			close(p.done)

			// The result is recorded, so the current run's context is no longer
			// needed; drop it so it can be collected. In the rare case described
			// above, p.cancel belongs to a newer, duplicate run, and cancelling
			// it stops that redundant work.
			p.cancel()
			p.cancel = nil
		})
	}()
	return p.wait(ctx)
}
// wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked.
//
// wait registers the caller as a waiter and snapshots the done channel
// while p.mu is still held. It then releases p.mu before blocking, and
// re-acquires it to inspect the outcome.
//
// Results:
//   - (value, nil) if the run completed;
//   - (nil, ctx.Err()) if ctx was cancelled first. If this caller was the
//     last waiter of a still-running computation, that computation is
//     abandoned and p returns to the idle state, so a later Get restarts it.
func (p *Promise) wait(ctx context.Context) (interface{}, error) {
	p.waiters++
	// Capture done while still holding the lock: once unlocked, p.done may be
	// reset to nil by the cancellation path below and replaced by a later run.
	done := p.done
	p.mu.Unlock()
	select {
	case <-done:
		p.mu.Lock()
		defer p.mu.Unlock()
		// waiters is not decremented on this path; the count is only acted on
		// while the promise is running.
		if p.state == stateCompleted {
			return p.value, nil
		}
		// NOTE(review): done is also closed by the cancellation path below, but
		// only after waiters has dropped to zero, so this (nil, nil) result
		// looks unreachable from the code in this file. Confirm before relying
		// on it.
		return nil, nil
	case <-ctx.Done():
		p.mu.Lock()
		defer p.mu.Unlock()
		p.waiters--
		// The last cancelled waiter abandons the computation: it cancels the
		// function's context, closes done, and resets p to idle so that a
		// future Get starts a fresh run.
		if p.waiters == 0 && p.state == stateRunning {
			p.cancel()
			close(p.done)
			p.state = stateIdle
			p.done = nil
			p.cancel = nil
		}
		return nil, ctx.Err()
	}
}
// EvictionPolicy determines what a Store does with a key once no
// references to its promise remain.
type EvictionPolicy int

const (
	// ImmediatelyEvict drops a key from the Store as soon as its last
	// reference is released. It is the zero value, and thus the default.
	ImmediatelyEvict EvictionPolicy = 0

	// NeverEvict keeps keys in the Store even after their last reference
	// is released.
	NeverEvict EvictionPolicy = 1
)
// Store associates arbitrary keys (anything usable as a map key) with
// reference-counted promises.
//
// A zero Store is ready to use and applies the ImmediatelyEvict policy.
// Call NewStore to choose a different EvictionPolicy.
type Store struct {
	evictionPolicy EvictionPolicy

	// promisesMu guards promises and the refcount of every Promise in it.
	promisesMu sync.Mutex

	// promises is allocated lazily by Store.Promise.
	promises map[interface{}]*Promise
}
// NewStore returns an empty Store that applies the given eviction policy.
func NewStore(policy EvictionPolicy) *Store {
	s := new(Store)
	s.evictionPolicy = policy
	return s
}
// Promise returns the promise stored under key, creating it from function
// if the key is absent. It also returns a release func for the reference
// just acquired.
//
// Calls with the same key return the same promise for as long as it stays
// in the store, each incrementing its reference count; function is used
// only when a new promise is created. The caller must call the returned
// release func when it no longer needs the promise, and must not call it
// more than once (a second call panics).
//
// When the last reference is released, the promise is removed from the
// store unless the store's policy is NeverEvict.
//
// The dynamic type of key is used as the promise's debug label; a nil key
// is labelled "<nil>".
func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) {
	store.promisesMu.Lock()
	p, ok := store.promises[key]
	if !ok {
		// %T produces the same text as reflect.TypeOf(key).String(), but does
		// not panic when key is nil (reflect.TypeOf(nil) returns nil).
		p = NewPromise(fmt.Sprintf("%T", key), function)
		if store.promises == nil {
			store.promises = map[interface{}]*Promise{}
		}
		store.promises[key] = p
	}
	p.refcount++
	store.promisesMu.Unlock()

	var released int32
	release := func() {
		if !atomic.CompareAndSwapInt32(&released, 0, 1) {
			panic("release called more than once")
		}
		store.promisesMu.Lock()
		p.refcount--
		if p.refcount == 0 && store.evictionPolicy != NeverEvict {
			// Inv: if p.refcount > 0, then store.promises[key] == p.
			delete(store.promises, key)
		}
		store.promisesMu.Unlock()
	}
	return p, release
}
// Stats reports, for each dynamic key type, how many keys of that type the
// store currently holds.
func (s *Store) Stats() map[reflect.Type]int {
	s.promisesMu.Lock()
	defer s.promisesMu.Unlock()

	counts := make(map[reflect.Type]int)
	for key := range s.promises {
		counts[reflect.TypeOf(key)]++
	}
	return counts
}
// DebugOnlyIterate calls f(k, v) for each key k in the store whose promise
// has completed with a non-nil value v. Promises that are idle, still
// running, or completed with a nil value are skipped.
//
// f is invoked while the store's lock is held (and each promise's own lock
// is briefly taken to read its value), so f must not call back into the
// Store. Use it for debugging only.
func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
	s.promisesMu.Lock()
	defer s.promisesMu.Unlock()

	for key, promise := range s.promises {
		v := promise.Cached()
		if v == nil {
			continue
		}
		f(key, v)
	}
}