// ConcurrentRestrictions.scala (198 lines, 176 loc, 6.53 KB)
package sbt
/** Describes restrictions on concurrent execution for a set of tasks.
*
* An implementation maintains a summary of type `G` for a collection of tasks and
* decides whether that collection of tasks may run at the same time.
*
* @tparam A the type of a task
*/
trait ConcurrentRestrictions[A] {

  /** Abstract type of the internal state that summarizes a collection of tasks. */
  type G

  /** The state describing a collection that contains no tasks. */
  def empty: G

  /** Returns the state `g` extended with the task `a`. */
  def add(g: G, a: A): G

  /** Returns the state `g` with the previously added task `a` taken out. */
  def remove(g: G, a: A): G

  /** Tells whether the tasks described by `g` may execute concurrently.
  *
  * Implementations must obey the following laws:
  *
  *  1. forall g: G, a: A; valid(g) => valid(remove(g,a))
  *  2. forall a: A; valid(add(empty, a))
  *  3. forall g: G, a: A; valid(g) <=> valid(remove(add(g, a), a))
  *  4. (implied by 1,2,3) valid(empty)
  *  5. forall g: G, a: A, b: A; !valid(add(g,a)) => !valid(add(add(g,b), a))
  */
  def valid(g: G): Boolean
}
import java.util.{LinkedList,Queue}
import java.util.concurrent.{Executor, Executors, ExecutorCompletionService}
import annotation.tailrec
object ConcurrentRestrictions
{
  /** A ConcurrentRestrictions instance that places no restrictions on concurrently executing tasks.
  * Its state type is `Unit`: adding or removing a task never changes the state, and every state is valid.
  * @tparam A the type of a task */
  def unrestricted[A]: ConcurrentRestrictions[A] =
    new ConcurrentRestrictions[A]
    {
      type G = Unit
      def empty = ()
      def add(g: G, a: A) = ()
      def remove(g: G, a: A) = ()
      def valid(g: G) = true
    }

  /** A ConcurrentRestrictions instance that allows at most `i` tasks to execute concurrently.
  * The state is the number of tasks currently added.
  * @tparam A the type of a task
  * @param i the maximum number of concurrently executing tasks; must be at least 1 (checked with `assert`) */
  def limitTotal[A](i: Int): ConcurrentRestrictions[A] =
  {
    assert(i >= 1, "Maximum must be at least 1 (was " + i + ")")
    new ConcurrentRestrictions[A]
    {
      type G = Int
      def empty = 0
      def add(g: Int, a: A) = g + 1
      def remove(g: Int, a: A) = g - 1
      def valid(g: Int) = g <= i
    }
  }

  /** A key object used for associating information with a task.*/
  final case class Tag(name: String)
  /** Attribute key holding a [[TagMap]]; presumably attached to tasks by code outside this file. */
  val tagsKey = AttributeKey[TagMap]("tags", "Attributes restricting concurrent execution of tasks.")
  /** A standard tag describing the number of tasks that do not otherwise have any tags.*/
  val Untagged = Tag("untagged")
  /** A standard tag describing the total number of tasks. */
  val All = Tag("all")
  /** Associates each tag with an integer count; `tagged` sums these counts per tag over the added tasks. */
  type TagMap = Map[Tag, Int]

  /** Implements concurrency restrictions on tasks based on Tags.
  * The state is the per-tag sum of the tags of all added tasks. Every task also contributes 1 to [[All]],
  * and a task without any tags of its own contributes 1 to [[Untagged]].
  * @tparam A type of a task
  * @param get extracts tags from a task
  * @param validF defines whether a set of tasks are allowed to execute concurrently based on their merged tags*/
  def tagged[A](get: A => TagMap, validF: TagMap => Boolean): ConcurrentRestrictions[A] =
    new ConcurrentRestrictions[A]
    {
      type G = TagMap
      def empty: TagMap = Map.empty
      def add(g: TagMap, a: A) = merge(g, a, get)(_ + _)
      def remove(g: TagMap, a: A) = merge(g, a, get)(_ - _)
      def valid(g: TagMap) = validF(g)
    }

  /** Combines the tags of task `a` into the state `m` using `f`.
  * `tagged` passes addition when adding a task and subtraction when removing it.
  * [[Untagged]] is adjusted by 1 when `a` has no tags of its own, and [[All]] is always adjusted by 1. */
  private[this] def merge[A](m: TagMap, a: A, get: A => TagMap)(f: (Int,Int) => Int): TagMap =
  {
    val aTags = get(a)
    val base = merge(m, aTags)(f)
    // Decide "untagged" from the task's own tags, not from the merged state: the merged state is only
    // empty for the first task, so keying on it under-counted untagged tasks and never decremented Untagged.
    val un = if(aTags.isEmpty) update(base, Untagged, 1)(f) else base
    update(un, All, 1)(f)
  }

  /** Returns `m` with the entry for `a` set to `f(current, b)`, or to `b` when `a` has no entry yet. */
  private[this] def update[A,B](m: Map[A,B], a: A, b: B)(f: (B,B) => B): Map[A,B] =
  {
    val newb =
      (m get a) match {
        case Some(bv) => f(bv,b)
        case None => b
      }
    m.updated(a,newb)
  }

  /** Folds every entry of `n` into `m` using `update` with the combining function `f`. */
  private[this] def merge[A,B](m: Map[A,B], n: Map[A,B])(f: (B,B) => B): Map[A,B] =
    n.foldLeft(m) { case (acc, (a, b)) => update(acc, a, b)(f) }

  /** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution.
  * Work is executed on a newly created cached thread pool.
  * @return a pair, with _1 being the CompletionService and _2 a function that shuts down the backing thread pool (via `shutdownNow`).
  * @tparam A the task type
  * @tparam R the type of data that will be computed by the CompletionService.
  * @param tags the restrictions on concurrently executing tasks
  * @param warn receives a message whenever `tags` is observed to violate the [[ConcurrentRestrictions]] laws */
  def completionService[A,R](tags: ConcurrentRestrictions[A], warn: String => Unit): (CompletionService[A,R], () => Unit) =
  {
    val pool = Executors.newCachedThreadPool()
    (completionService[A,R](pool, tags, warn), () => pool.shutdownNow() )
  }

  /** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution
  * and using the provided Executor to manage execution on threads.
  * Tasks that `tags` does not allow to run alongside the currently running tasks are queued and re-checked each time a task finishes.
  * @tparam A the task type
  * @tparam R the type of data that will be computed by the CompletionService.
  * @param backing executes the work of tasks that are allowed to run
  * @param tags the restrictions on concurrently executing tasks
  * @param warn receives a message whenever `tags` is observed to violate the [[ConcurrentRestrictions]] laws */
  def completionService[A,R](backing: Executor, tags: ConcurrentRestrictions[A], warn: String => Unit): CompletionService[A,R] =
  {
    /** Represents submitted work for a task.*/
    final class Enqueue(val node: A, val work: () => R)
    new CompletionService[A,R]
    {
      /** Backing service used to manage execution on threads once all constraints are satisfied. */
      private[this] val jservice = new ExecutorCompletionService[R](backing)
      /** The description of the currently running tasks, used by `tags` to manage restrictions.*/
      private[this] var tagState = tags.empty
      /** The number of running tasks. */
      private[this] var running = 0
      /** Tasks that cannot be run yet because they cannot execute concurrently with the currently running tasks.*/
      private[this] val pending = new LinkedList[Enqueue]

      def submit(node: A, work: () => R): Unit = synchronized
      {
        val newState = tags.add(tagState, node)
        // if the new task is allowed to run concurrently with the currently running tasks,
        // submit it to be run by the backing j.u.c.CompletionService
        if(tags valid newState)
        {
          tagState = newState
          submitValid( node, work )
        }
        else
        {
          // with nothing running, no completion will re-check `pending`, so a lawful restriction must not get here
          if(running == 0) errorAddingToIdle()
          pending.add( new Enqueue(node, work) )
        }
      }

      /** Starts `work` on the backing service; `cleanup(node)` runs when the work finishes, even if it throws. */
      private[this] def submitValid(node: A, work: () => R) =
      {
        running += 1
        val wrappedWork = () => try work() finally cleanup(node)
        CompletionService.submit(wrappedWork, jservice)
      }

      /** Records that `node` finished, then submits any pending tasks that are now allowed to execute. */
      private[this] def cleanup(node: A): Unit = synchronized
      {
        running -= 1
        tagState = tags.remove(tagState, node)
        if(!tags.valid(tagState)) warn("Invalid restriction: removing a completed node from a valid system must result in a valid system.")
        submitValid(new LinkedList[Enqueue])
      }

      private[this] def errorAddingToIdle() = warn("Invalid restriction: adding a node to an idle system must be allowed.")

      /** Submits pending tasks that are now allowed to execute.
      * Each pending task is examined once; tasks still not allowed are collected in `tried`
      * and returned to `pending` once the queue has been drained. */
      @tailrec private[this] def submitValid(tried: Queue[Enqueue]): Unit =
        if(pending.isEmpty)
        {
          if(!tried.isEmpty)
          {
            if(running == 0) errorAddingToIdle()
            pending.addAll(tried)
          }
        }
        else
        {
          val next = pending.remove()
          val newState = tags.add(tagState, next.node)
          if(tags.valid(newState))
          {
            tagState = newState
            submitValid(next.node, next.work)
          }
          else
          {
            tried.add(next)
          }
          // continue with the remaining pending tasks regardless of which branch ran above
          submitValid(tried)
        }

      /** Blocks until a submitted task completes and returns its result; `get()` rethrows a task failure wrapped in an ExecutionException. */
      def take(): R = jservice.take().get()
    }
  }
}