-
Notifications
You must be signed in to change notification settings - Fork 63
/
NIO2SocketServerGroup.scala
161 lines (142 loc) · 5.64 KB
/
NIO2SocketServerGroup.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*
* Copyright 2014 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.http4s.blaze.channel.nio2
import java.net.InetSocketAddress
import java.nio.channels._
import java.util.Date
import java.util.concurrent.ThreadFactory
import org.http4s.blaze.channel._
import org.http4s.blaze.pipeline.Command.Connected
import org.http4s.blaze.util.{BasicThreadFactory, Execution}
import org.log4s.getLogger
import scala.util.{Failure, Success, Try}
import scala.util.control.NonFatal
@deprecated("Prefer NIO1 over NIO2", "0.14.15")
object NIO2SocketServerGroup {

  /** Thread factory handed to [[fixedGroup]] when the caller does not provide one.
    * It is built with the `blaze-nio2-fixed-pool` prefix and `daemonThreads = false`.
    */
  private val DefaultThreadFactory =
    BasicThreadFactory(prefix = "blaze-nio2-fixed-pool", daemonThreads = false)

  /** Builds an NIO2 server group backed by a dedicated, fixed-size `AsynchronousChannelGroup`.
    *
    * @param workerThreads size of the thread pool backing the new channel group
    * @param bufferSize size of the buffers used for socket IO
    * @param channelOptions socket options applied to every accepted client connection
    * @param threadFactory factory used to create the pool's worker threads
    * @return a server group owning the newly created channel group
    */
  def fixedGroup(
      workerThreads: Int = DefaultPoolSize,
      bufferSize: Int = DefaultBufferSize,
      channelOptions: ChannelOptions = ChannelOptions.DefaultOptions,
      threadFactory: ThreadFactory = DefaultThreadFactory
  ): NIO2SocketServerGroup =
    apply(
      bufferSize = bufferSize,
      group = Some(AsynchronousChannelGroup.withFixedThreadPool(workerThreads, threadFactory)),
      channelOptions = channelOptions
    )

  /** Builds an NIO2 server group on top of an optional `AsynchronousChannelGroup`.
    *
    * @param bufferSize size of the buffers used for socket IO
    * @param group channel group to run in; `None` selects the JVM-wide default group
    * @param channelOptions socket options applied to every accepted client connection
    * @return the new server group
    */
  def apply(
      bufferSize: Int = 8 * 1024,
      group: Option[AsynchronousChannelGroup] = None,
      channelOptions: ChannelOptions = ChannelOptions.DefaultOptions
  ): NIO2SocketServerGroup = {
    // The JDK interprets a `null` group as "use the system-wide default group".
    val jdkGroup: AsynchronousChannelGroup = group.getOrElse(null)
    new NIO2SocketServerGroup(bufferSize, jdkGroup, channelOptions)
  }
}
/** A [[ServerChannelGroup]] whose server channels run on an NIO2 `AsynchronousChannelGroup`.
  *
  * @param bufferSize size of the buffers used for socket IO on accepted connections
  * @param group the channel group to run in, or `null` for the JVM-wide default group
  * @param channelOptions socket options applied to every accepted client connection
  */
final class NIO2SocketServerGroup private (
    bufferSize: Int,
    group: AsynchronousChannelGroup,
    channelOptions: ChannelOptions)
    extends ServerChannelGroup {
  private[this] val logger = getLogger

  /** Closes the group along with all current connections.
    *
    * __WARNING:__ the system-wide default group (used when no explicit group was supplied) is
    * __NOT__ shut down; calling this method in that case throws an `IllegalStateException`.
    */
  override def closeGroup(): Unit =
    if (group != null) {
      logger.info("Closing NIO2 SocketChannelServerGroup")
      group.shutdownNow()
    } else
      throw new IllegalStateException(
        "Cannot shut down the system default AsynchronousChannelGroup.")

  /** Opens a server socket in this group, binds it to `address` and starts accepting
    * connections, handing each accepted socket to `service`.
    *
    * The returned channel's `socketAddress` is the address actually bound, read back from the
    * socket. If binding fails, the freshly opened socket is closed before the failure is
    * returned, so a failed bind does not leak a file descriptor.
    *
    * @param address local address to bind
    * @param service builds the pipeline for each accepted connection
    * @return the listening server channel, or the failure raised while opening/binding it
    */
  def bind(address: InetSocketAddress, service: SocketPipelineBuilder): Try[ServerChannel] =
    Try {
      val ch = AsynchronousServerSocketChannel.open(group)
      val boundAddress =
        try {
          ch.bind(address)
          ch.getLocalAddress.asInstanceOf[InetSocketAddress]
        } catch {
          case NonFatal(t) =>
            // Don't leak the opened socket when binding (e.g. address already in use) fails.
            try ch.close()
            catch { case NonFatal(closeErr) => t.addSuppressed(closeErr) }
            throw t
        }
      val serverChannel = new NIO2ServerChannel(boundAddress, ch, service)
      serverChannel.listen()
      serverChannel
    }

  private[this] final class NIO2ServerChannel(
      val socketAddress: InetSocketAddress,
      ch: AsynchronousServerSocketChannel,
      service: SocketPipelineBuilder)
      extends ServerChannel {

    override protected def closeChannel(): Unit =
      if (ch.isOpen()) {
        logger.info(s"Closing NIO2 channel $socketAddress at ${new Date}")
        try ch.close()
        catch {
          case NonFatal(t) => logger.debug(t)("Failure during channel close")
        }
      }

    /** Logs `e` as the reason the listen loop gave up, then closes this server channel. */
    def errorClose(e: Throwable): Unit = {
      logger.error(e)("Server socket channel closed with error.")
      normalClose()
    }

    /** Closes this server channel, logging (not propagating) any non-fatal failure. */
    def normalClose(): Unit =
      try close()
      catch {
        case NonFatal(e) =>
          logger.error(e)("Error on NIO2ServerChannel shutdown invoked by listen loop.")
      }

    /** Issues one asynchronous accept; each successful accept re-arms the loop. */
    def listen(): Unit = {
      val handler = new CompletionHandler[AsynchronousSocketChannel, Null] {
        override def completed(conn: AsynchronousSocketChannel, attachment: Null): Unit =
          // Continue the circle of life even if setting up `conn` went wrong, so a single bad
          // connection (or a throwing pipeline builder) cannot stop the server from accepting.
          try initConnection(conn)
          finally listen()

        override def failed(exc: Throwable, attachment: Null): Unit =
          exc match {
            case _: AsynchronousCloseException => normalClose()
            case _: ClosedChannelException => normalClose()
            case _: ShutdownChannelGroupException => normalClose()
            case _ =>
              logger.error(exc)(s"Error accepting connection on address $socketAddress")
              // If the server channel cannot go on, disconnect it.
              // NOTE(review): when `ch` is still open the accept loop is NOT re-armed here, so
              // the server stops accepting while staying open. Resuming safely likely needs a
              // backoff (e.g. to avoid spinning on "too many open files") — TODO decide.
              if (!ch.isOpen()) errorClose(exc)
          }
      }
      try ch.accept(null, handler)
      catch { case NonFatal(t) => handler.failed(t, null) }
    }

    /** Builds and attaches the pipeline for a freshly accepted connection.
      *
      * Non-fatal failures close `conn` and are logged rather than propagated:
      *  - a builder that throws synchronously;
      *  - a builder that rejects the connection;
      *  - an error while applying socket options or attaching the pipeline.
      */
    private[this] def initConnection(conn: AsynchronousSocketChannel): Unit = {
      // A builder that throws synchronously is treated like one returning a failed Future.
      val pipeline =
        try service(new NIO2SocketConnection(conn))
        catch { case NonFatal(t) => scala.concurrent.Future.failed(t) }

      pipeline.onComplete {
        case Success(tail) =>
          try {
            channelOptions.applyToChannel(conn)
            tail
              .base(new ByteBufferHead(conn, bufferSize))
              .sendInboundCommand(Connected)
          } catch {
            case NonFatal(t) =>
              val address = closeConnection(conn)
              logger.error(t)(s"Failed to initialize connection from $address")
          }
        case Failure(ex) =>
          val address = closeConnection(conn)
          logger.info(ex)(s"Rejected connection from $address")
      }(Execution.trampoline)
    }

    /** Closes an abandoned connection and returns its remote address rendered for logging.
      * Neither the address lookup nor the close is allowed to throw.
      */
    private[this] def closeConnection(conn: AsynchronousSocketChannel): String = {
      val address = Try(conn.getRemoteAddress) match {
        case Success(a) => String.valueOf(a)
        case Failure(_) => "<unknown address>"
      }
      try conn.close()
      catch {
        case NonFatal(t) => logger.debug(t)(s"Failure closing connection from $address")
      }
      address
    }
  }
}