-
Notifications
You must be signed in to change notification settings - Fork 58
/
WorkerGroup.java
108 lines (95 loc) · 3.27 KB
/
WorkerGroup.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.wlm;
import ai.djl.Device;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
/**
 * The {@link WorkerGroup} manages the {@link WorkerPool} for a particular {@link Device}.
 *
 * <p>A group holds the {@link WorkerThread}s that were created for one device together with the
 * minimum and maximum number of workers configured for that device.
 *
 * <p>The worker list is a {@link CopyOnWriteArrayList}. The worker limits are plain fields and
 * this class performs no synchronization of its own when they are read or updated.
 *
 * @param <I> the input type of the workers
 * @param <O> the output type of the workers
 */
public class WorkerGroup<I, O> {

    /** The pool this group belongs to; supplies the config, job queue and thread pool. */
    private final WorkerPool<I, O> workerPool;

    /** The device every worker of this group is created for. */
    private final Device device;

    /** Lower worker limit; initialised from the pool config and clamped to {@code maxWorkers}. */
    private int minWorkers;

    /** Upper worker limit; package-private. */
    int maxWorkers;

    /** The workers created by {@link #addThreads(int, boolean)}; package-private. */
    List<WorkerThread<I, O>> workers;

    /**
     * Constructs a new {@code WorkerGroup} with the default limits of the pool's configuration.
     *
     * @param workerPool the pool that owns this group
     * @param device the device the workers of this group run on
     */
    WorkerGroup(WorkerPool<I, O> workerPool, Device device) {
        this.workerPool = workerPool;
        this.device = device;
        workers = new CopyOnWriteArrayList<>();
        WorkerPoolConfig<I, O> wpc = workerPool.getWpc();

        // Default workers from worker type, may be overridden by configureWorkers on init or scale
        minWorkers = wpc.getMinWorkers(device);
        maxWorkers = wpc.getMaxWorkers(device);
        // The configured maximum wins when the defaults are contradictory.
        minWorkers = Math.min(minWorkers, maxWorkers);
    }

    /**
     * Returns the device of the worker group.
     *
     * @return the device of the worker group
     */
    public Device getDevice() {
        return device;
    }

    /**
     * Returns a list of workers.
     *
     * <p>The returned list is the group's own list, not a copy.
     *
     * @return a list of workers
     */
    public List<WorkerThread<I, O>> getWorkers() {
        return workers;
    }

    /**
     * Returns the min number of workers for the model and device.
     *
     * @return the min number of workers for the model and device
     */
    public int getMinWorkers() {
        return minWorkers;
    }

    /**
     * Returns the max number of workers for the model and device.
     *
     * @return the max number of workers for the model and device
     */
    public int getMaxWorkers() {
        return maxWorkers;
    }

    /**
     * Configures minimum and maximum number of workers.
     *
     * <p>A negative argument leaves the corresponding limit unchanged. Unlike the constructor,
     * this method does not clamp the minimum to the maximum; the values are stored as given.
     *
     * @param minWorkers the minimum number of workers, or a negative value to keep the current one
     * @param maxWorkers the maximum number of workers, or a negative value to keep the current one
     */
    public void configureWorkers(int minWorkers, int maxWorkers) {
        if (minWorkers >= 0) {
            this.minWorkers = minWorkers;
        }
        if (maxWorkers >= 0) {
            this.maxWorkers = maxWorkers;
        }
    }

    /**
     * Creates {@code count} new workers for this group's device and submits them for execution.
     *
     * <p>Each worker is built from the pool's configuration, bound to the pool's job queue,
     * registered in {@link #getWorkers()} and then submitted to the pool's thread pool. Nothing
     * happens when {@code count} is zero or negative.
     *
     * @param count the number of workers to create
     * @param permanent forwarded to {@code optFixPoolThread} of the worker builder; presumably
     *     marks the worker as part of the fixed pool (to be confirmed against {@code
     *     WorkerThread})
     */
    void addThreads(int count, boolean permanent) {
        WorkerPoolConfig<I, O> wpc = workerPool.getWpc();
        ExecutorService threadPool = workerPool.getThreadPool();
        for (int i = 0; i < count; ++i) {
            WorkerThread<I, O> thread =
                    WorkerThread.builder(wpc)
                            .setDevice(device)
                            .setJobQueue(workerPool.getJobQueue())
                            .optFixPoolThread(permanent)
                            .build();
            // Register before submitting so the worker is visible in the list once it runs.
            workers.add(thread);
            threadPool.submit(thread);
        }
    }
}