SparkConnectService.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.service

import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
import io.grpc.{Server, Status}
import io.grpc.netty.NettyServerBuilder
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver

import org.apache.spark.SparkEnv
import org.apache.spark.annotation.{Since, Unstable}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, SparkConnectServiceGrpc}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.ExtendedMode
/**
* The SparkConnectService implementation.
*
* This class implements the service stub from the generated code of GRPC.
*
* @param debug
*   when true, delegates debug behavior to the request handlers.
*/
@Unstable
@Since("3.4.0")
class SparkConnectService(
debug: Boolean)
extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
with Logging {
/**
* This is the main entry method for Spark Connect and all calls to execute a plan.
*
* The plan execution is delegated to the [[SparkConnectStreamHandler]]. Specific error
* handling should be implemented directly in that handler; this method only catches errors
* that escape it and reports them to the client as generic gRPC errors.
*
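* For illustration, a failure thrown by the handler reaches the client as a
* `StatusRuntimeException` with code `UNKNOWN` (a sketch, assuming a generated
* blocking stub named `stub`):
*
* {{{
*   try {
*     stub.executePlan(request).forEachRemaining(_ => ())
*   } catch {
*     case e: io.grpc.StatusRuntimeException =>
*       assert(e.getStatus.getCode == io.grpc.Status.Code.UNKNOWN)
*   }
* }}}
*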
* @param request
*   the request carrying the plan to execute and the user context.
* @param responseObserver
*   the stream observer used to send result batches or an error back to the client.
*/
override def executePlan(request: Request, responseObserver: StreamObserver[Response]): Unit = {
try {
new SparkConnectStreamHandler(responseObserver).handle(request)
} catch {
case e: Throwable =>
log.error("Error executing plan.", e)
responseObserver.onError(
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
}
}
/**
* Analyze a plan to provide metadata and debugging information.
*
* This method is called to generate the explain plan for a Spark Connect plan. In its simplest
* implementation, the logical plan generated by the [[SparkConnectPlanner]] is used to build a
* [[Dataset]], and the explain string is derived from the query execution details.
*
* Errors during planning are returned via the [[StreamObserver]] interface.
*
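* The core of that derivation, in isolation (a sketch; `session` and `logicalPlan` are
* bound exactly as in the method body below):
*
* {{{
*   val ds = Dataset.ofRows(session, logicalPlan)
*   val explainString = ds.queryExecution.explainString(ExtendedMode)
* }}}
*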
* @param request
*   the request carrying the plan to analyze and the user context.
* @param responseObserver
*   the stream observer used to send the [[AnalyzeResponse]] or an error back to the client.
*/
override def analyzePlan(
request: Request,
responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
try {
val session =
SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
val logicalPlan = request.getPlan.getOpTypeCase match {
case proto.Plan.OpTypeCase.ROOT =>
new SparkConnectPlanner(request.getPlan.getRoot, session).transform()
case _ =>
responseObserver.onError(
new UnsupportedOperationException(
s"${request.getPlan.getOpTypeCase} not supported for analysis."))
return
}
val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(ExtendedMode)
val resp = proto.AnalyzeResponse
.newBuilder()
.setExplainString(explainString)
.setClientId(request.getClientId)
resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava)
resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava)
responseObserver.onNext(resp.build())
responseObserver.onCompleted()
} catch {
case e: Throwable =>
log.error("Error analyzing plan.", e)
responseObserver.onError(
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
}
}
}
/**
* Object used for referring to SparkSessions in the SessionCache.
*
* @param userId
*   the id of the user owning the session.
* @param session
*   the isolated [[SparkSession]] associated with that user.
*/
@Unstable
@Since("3.4.0")
private[connect] case class SessionHolder(userId: String, session: SparkSession)
/**
* Static instance of the SparkConnectService.
*
* Used to start the overall Spark Connect service and provides the global state needed to
* manage the separate SparkSessions of the different users connecting to the cluster.
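*
* A minimal usage sketch (assumes an active [[SparkSession]] already exists, since new
* isolated sessions are forked from it):
*
* {{{
*   SparkConnectService.start()
*   // ... serve client requests ...
*   SparkConnectService.stop()
* }}}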
*/
@Unstable
@Since("3.4.0")
object SparkConnectService {
private val CACHE_SIZE = 100
private val CACHE_TIMEOUT_SECONDS = 3600
// Type alias for the SessionCacheKey. Right now this is a String, but the alias allows us
// to switch to a different or more complex type easily.
private type SessionCacheKey = String
private var server: Server = _
private val userSessionMapping =
cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]()
// Simple builder for creating the cache of Sessions. A negative cacheSize or timeoutSeconds
// disables the corresponding bound.
private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = {
var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
if (cacheSize >= 0) {
cacheBuilder = cacheBuilder.maximumSize(cacheSize)
}
if (timeoutSeconds >= 0) {
cacheBuilder = cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS)
}
cacheBuilder
}
/**
* Based on the `key`, find an existing [[SessionHolder]] or create one holding a new
* isolated [[SparkSession]].
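*
* While the entry remains in the cache, repeated lookups for the same key return the same
* holder (a sketch):
*
* {{{
*   val a = SparkConnectService.getOrCreateIsolatedSession("alice")
*   val b = SparkConnectService.getOrCreateIsolatedSession("alice")
*   assert(a.session eq b.session)
* }}}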
*/
private[connect] def getOrCreateIsolatedSession(key: SessionCacheKey): SessionHolder = {
userSessionMapping.get(
key,
() => {
SessionHolder(key, newIsolatedSession())
})
}
private def newIsolatedSession(): SparkSession = {
SparkSession.active.newSession()
}
/**
* Starts the GRPC service.
*
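* With debug mode enabled, the ProtoReflection service is registered, so tools such as
* grpcurl can introspect the API (a sketch; `<port>` stands for the configured value of
* CONNECT_GRPC_BINDING_PORT):
*
* {{{
*   grpcurl -plaintext localhost:<port> list
* }}}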
*/
def startGRPCService(): Unit = {
val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
val port = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
val sb = NettyServerBuilder
.forPort(port)
.addService(new SparkConnectService(debugMode))
// If debug mode is configured, load the ProtoReflection service so that tools like
// grpcurl can introspect the API for debugging.
if (debugMode) {
sb.addService(ProtoReflectionService.newInstance())
}
server = sb.build
server.start()
}
// Starts the service
def start(): Unit = {
startGRPCService()
}
def stop(): Unit = {
if (server != null) {
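// shutdownNow() forcefully shuts the server down, cancelling calls still in flight.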
server.shutdownNow()
}
}
}