package org.apache.spark.sql.execution.streaming
// CLONE
//
// This file contains a copy of org.apache.spark.sql.execution.streaming.TextSocketSource,
// modified to emit more logging information from initialize(), getOffset() and getBatch().
// So that the additional logging does not overly slow down processing, we don't print the
// content of every event received over the socket right away; instead we append each line
// to a buffer whose contents can be retrieved in one shot via getMsgs().
//
// Note that we use the same package name as the original data source because the original
// class accesses some methods that are only visible within the scope of this package.
//
// Our streaming job will use this modified class if we pass the provider's name to format()
// when constructing a stream reader, as in:
//
//    sparkSession.readStream.format("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider2")
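//
// A fuller, minimal usage sketch (the host, port, checkpoint-free console sink and a text
// server such as `nc -lk 9999` are hypothetical example values, not requirements of this
// class):
//
//    val lines = sparkSession.readStream
//      .format("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider2")
//      .option("host", "localhost")
//      .option("port", "9999")
//      .option("includeTimestamp", "true")
//      .load()
//    val query = lines.writeStream.format("console").start()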
import java.io.{BufferedReader, IOException, InputStreamReader}
import java.net.Socket
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.util.{Calendar, Date, Locale}
import javax.annotation.concurrent.GuardedBy
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.util.{Failure, Success, Try}
object TextSocketSource2 {
val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
StructField("timestamp", TimestampType) :: Nil)
val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
val msgs: ListBuffer[String] = mutable.ListBuffer[String]() // CLONE
def getMsgs(): List[String] = msgs.toList // CLONE
}
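// Driver-side usage sketch for the buffer above: once a query has processed some data, the
// accumulated socket log can be dumped in one shot. Note that the buffer lives in the JVM
// that runs the source and grows without bound:
//
//    TextSocketSource2.getMsgs().foreach(println)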
/**
* A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
* This source will *not* work in production applications due to multiple reasons, including no
* support for fault recovery and keeping all of the text read in memory forever.
*/
class TextSocketSource2(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
extends Source with Logging {
@GuardedBy("this")
private var socket: Socket = null
@GuardedBy("this")
private var readThread: Thread = null
/**
* All batches from `lastOffsetCommitted + 1` to `currentOffset`, inclusive.
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
protected val batches = new ListBuffer[(String, Timestamp)]
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
@GuardedBy("this")
protected var lastOffsetCommitted: LongOffset = new LongOffset(-1)
initialize()
private def initialize(): Unit = synchronized {
socket = new Socket(host, port)
val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
readThread = new Thread(s"TextSocketSource($host, $port)") {
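// Marked as a daemon thread so the blocking readLine() loop below never keeps the
// JVM alive on its own after the query stops.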
setDaemon(true)
override def run(): Unit = {
try {
while (true) {
val line: String = reader.readLine()
TextSocketSource2.msgs += s"socket read at ${new Date().toString} line: $line" // CLONE
//System.out.println(s"socket read at ${new Date().toString} line:" + line); // CLONE
if (line == null) {
// End of file reached
logWarning(s"Stream closed by $host:$port")
return
}
TextSocketSource2.this.synchronized {
val newData = (line,
Timestamp.valueOf(
TextSocketSource2.DATE_FORMAT.format(Calendar.getInstance().getTime()))
)
currentOffset = currentOffset + 1
batches.append(newData)
}
}
} catch {
case e: IOException =>
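// An IOException here typically means the socket was closed in stop();
// let the reader thread exit quietly.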
}
}
}
readThread.start()
}
/** Returns the schema of the data from this source */
override def schema: StructType = if (includeTimestamp) TextSocketSource2.SCHEMA_TIMESTAMP
else TextSocketSource2.SCHEMA_REGULAR
override def getOffset: Option[Offset] = synchronized {
val retval = if (currentOffset.offset == -1) {
None
} else {
Some(currentOffset)
}
println(s" at ${new Date().toString} getOffset: " + retval) // CLONE
retval
}
/** Returns the data that is between the offsets (`start`, `end`]. */
override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
println(s" at ${new Date().toString} getBatch start:" + start + ". end: " + end) // CLONE
val startOrdinal =
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
// Internal buffer only holds the batches after lastOffsetCommitted
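// Worked example of the slice arithmetic below: with lastOffsetCommitted = 4,
// start = Some(LongOffset(4)) and end = LongOffset(7), we get startOrdinal = 5 and
// endOrdinal = 8, hence sliceStart = 0 and sliceEnd = 3, i.e. the three
// uncommitted rows at offsets 5, 6 and 7.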
val rawList = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
batches.slice(sliceStart, sliceEnd)
}
val rdd = sqlContext.sparkContext
.parallelize(rawList)
.map {
case (v, ts) =>
//println(s" to row at ${new Date().toString} $v")
InternalRow(UTF8String.fromString(v), ts.getTime)
}
sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
override def commit(end: Offset): Unit = synchronized {
val newOffset = LongOffset.convert(end).getOrElse(
sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
s"originate with an instance of this class")
)
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
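// For example, lastOffsetCommitted = 2 and end = LongOffset(5) give offsetDiff = 3,
// so trimStart(3) below drops the buffered rows for offsets 3, 4 and 5.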
if (offsetDiff < 0) {
sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
}
batches.trimStart(offsetDiff)
lastOffsetCommitted = newOffset
}
/** Stop this source. */
override def stop(): Unit = synchronized {
if (socket != null) {
try {
// Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
// stop the readThread is to close the socket.
socket.close()
} catch {
case e: IOException =>
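// Ignore errors while closing; we are shutting down anyway.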
}
socket = null
}
}
override def toString: String = s"TextSocketSource[host: $host, port: $port]"
}
class TextSocketSourceProvider2 extends StreamSourceProvider with DataSourceRegister with Logging {
private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
case Success(bool) => bool
case Failure(_) =>
throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
}
}
/** Returns the name and schema of the source that can be used to continually read data. */
override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
logWarning("The socket source should not be used for production applications! " +
"It does not support recovery.")
if (!parameters.contains("host")) {
throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
}
if (!parameters.contains("port")) {
throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
}
if (schema.nonEmpty) {
throw new AnalysisException("The socket source does not support a user-specified schema.")
}
val sourceSchema =
if (parseIncludeTimestamp(parameters)) {
TextSocketSource2.SCHEMA_TIMESTAMP
} else {
TextSocketSource2.SCHEMA_REGULAR
}
("textSocket", sourceSchema)
}
override def createSource(
sqlContext: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
val host = parameters("host")
val port = parameters("port").toInt
new TextSocketSource2(host, port, parseIncludeTimestamp(parameters), sqlContext)
}
/** String that represents the format that this data source provider uses. */
override def shortName(): String = "socket2" // CLONE
}
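// Because shortName() returns "socket2", the provider could also be selected with
// format("socket2") instead of the fully qualified class name, but only if the class is
// registered for java.util.ServiceLoader lookup via a
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file. Without that
// registration, use the fully qualified name shown in the header comment:
//
//    sparkSession.readStream.format("socket2")  // requires service registration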