/
build.gradle
86 lines (77 loc) · 3.29 KB
/
build.gradle
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
plugins {
    // Gradle's built-in application plugin; the `application { }` and `run { }`
    // blocks in this script configure what it contributes.
    id("application")
}
repositories {
    // Additional Maven repository: Sonatype OSS snapshots.
    maven {
        url = "https://oss.sonatype.org/content/repositories/snapshots/"
    }
}
dependencies {
    implementation("commons-cli:commons-cli:${commons_cli_version}")
    implementation("org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}")

    implementation(project(":basicdataset"))
    implementation(project(":model-zoo"))

    // The runtime engine is chosen from the `ai.djl.default_engine` system
    // property of the build JVM. Any value other than "PyTorch" or
    // "TensorFlow" (including an unset property) selects MXNet.
    def selectedEngine = System.getProperty("ai.djl.default_engine")
    switch (selectedEngine) {
        case "PyTorch":
            runtimeOnly(project(":pytorch:pytorch-model-zoo"))
            runtimeOnly("ai.djl.pytorch:pytorch-native-auto:${pytorch_version}")
            break
        case "TensorFlow":
            runtimeOnly(project(":tensorflow:tensorflow-model-zoo"))
            runtimeOnly("ai.djl.tensorflow:tensorflow-native-auto:${tensorflow_version}")
            break
        default:
            runtimeOnly(project(":mxnet:mxnet-model-zoo"))
            runtimeOnly("ai.djl.mxnet:mxnet-native-auto:${mxnet_version}")
            break
    }

    // TestNG for tests, without its junit:junit dependency.
    testImplementation("org.testng:testng:${testng_version}") {
        exclude(group: "junit", module: "junit")
    }
}
application {
    // Entry point comes from the `main` system property (-Dmain=...), falling
    // back to the ObjectDetection example when the property is not set.
    def entryPoint = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection")
    mainClassName = entryPoint
}
run {
    // Forward the build JVM's system properties to the launched application,
    // minus user.dir, and force file.encoding to UTF-8.
    systemProperties(System.properties)
    systemProperties.remove("user.dir")
    systemProperty "file.encoding", "UTF-8"

    environment "TF_CPP_MIN_LOG_LEVEL", "1" // turn off TensorFlow print out
}
// Runs the ListModels example on the main source set's runtime classpath.
task listmodels(type: JavaExec) {
    main = "ai.djl.examples.inference.ListModels"
    classpath = sourceSets.main.runtimeClasspath

    // Forward the build JVM's system properties, minus user.dir, and force
    // file.encoding to UTF-8.
    systemProperties(System.properties)
    systemProperties.remove("user.dir")
    systemProperty "file.encoding", "UTF-8"
}
/**
 * Runs ai.djl.examples.inference.benchmark.Benchmark in a forked JVM on the
 * main source set's runtime classpath, with the heap capped at 2 GB.
 *
 * When the first task request on the command line carries an `--args=<value>`
 * entry whose space-separated value contains the token `-t`, the engine named
 * by the `ai.djl.default_engine` system property is restricted to one thread
 * through the environment variables / system properties set below.
 * NOTE(review): `-t` is presumably the benchmark's thread-count option --
 * confirm against the Benchmark command-line definition.
 */
task benchmark(type: JavaExec) {
    environment("TF_CPP_MIN_LOG_LEVEL", "1") // turn off TensorFlow print out

    // Forward the build JVM's system properties first, so that the per-engine
    // values set further down win over anything forwarded from here.
    systemProperties System.getProperties()
    systemProperties.remove("user.dir")
    systemProperty("file.encoding", "UTF-8")

    String engine = System.getProperty("ai.djl.default_engine")
    // Arguments of the first task request; empty when there is none.
    List<String> arguments = gradle.startParameter["taskRequests"]["args"].getAt(0) ?: []
    for (String argument : arguments) {
        if (argument.trim().startsWith("--args")) {
            // Only the `--args=<value>` spelling is recognised here.
            String[] line = argument.split("=", 2)
            if (line.length == 2) {
                line = line[1].split(" ")
                if (line.contains("-t")) {
                    if (engine == "MXNet") {
                        environment("MXNET_ENGINE_TYPE", "NaiveEngine")
                        environment("OMP_NUM_THREADS", "1")
                    } else if (engine == "PyTorch") {
                        // Set on this task only; calling System.setProperty here
                        // would modify the build JVM itself and leak into other
                        // tasks and later builds that reuse the same daemon.
                        systemProperty("ai.djl.pytorch.num_interop_threads", "1")
                        systemProperty("ai.djl.pytorch.num_threads", "1")
                    } else if (engine == "TensorFlow") {
                        environment("OMP_NUM_THREADS", "1")
                        environment("TF_NUM_INTEROP_THREADS", "1")
                        environment("TF_NUM_INTRAOP_THREADS", "1")
                    }
                }
                break
            }
        }
    }

    classpath = sourceSets.main.runtimeClasspath
    // restrict the jvm heap size for better monitoring benchmark
    // (maxHeapSize applies to this task's forked JVM; applicationDefaultJvmArgs
    // is a project-level application-plugin setting and does not affect it)
    maxHeapSize = "2g"
    main = "ai.djl.examples.inference.benchmark.Benchmark"
}
// Skip the tar distribution archive.
tasks.getByName("distTar").enabled = false