-
Notifications
You must be signed in to change notification settings - Fork 12
/
UDPPeerGroup.scala
146 lines (116 loc) · 5.48 KB
/
UDPPeerGroup.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package io.iohk.scalanet.peergroup
import java.net.InetSocketAddress
import java.util.concurrent.ConcurrentHashMap
import io.iohk.decco.Codec
import io.iohk.scalanet.peergroup.PeerGroup.TerminalPeerGroup
import io.iohk.scalanet.peergroup.UDPPeerGroup._
import io.netty.bootstrap.Bootstrap
import io.netty.buffer.Unpooled
import io.netty.channel._
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.DatagramPacket
import io.netty.channel.socket.nio.NioDatagramChannel
import io.netty.util
import monix.eval.Task
import monix.execution.Scheduler
import monix.reactive.Observable
import org.slf4j.LoggerFactory
import scala.collection.JavaConverters._
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, Promise}
import scala.util.Success
/**
  * A [[TerminalPeerGroup]] that exchanges messages of type `M` as UDP datagrams through Netty.
  *
  * A single Netty [[Bootstrap]] is used both to bind the local address (see [[initialize]]) and to
  * open outbound channels (see [[client]]). Messages are encoded and decoded with the implicit `codec`.
  *
  * @param config    the address to bind to and the address this process advertises
  * @param scheduler runs the `Future` and `Task` continuations created by this class
  * @param codec     serializes messages to, and deserializes them from, byte buffers
  * @tparam M the message type carried by the channels of this peer group
  */
class UDPPeerGroup[M](val config: Config)(implicit scheduler: Scheduler, codec: Codec[M])
    extends TerminalPeerGroup[InetMultiAddress, M]() {

  private val log = LoggerFactory.getLogger(getClass)

  // Notified with a channel the first time a datagram arrives on it (see ChannelImpl.channelRead).
  private val channelSubscribers =
    new Subscribers[Channel[InetMultiAddress, M]](s"Channel Subscribers for UDPPeerGroup@'$processAddress'")

  private val workerGroup = new NioEventLoopGroup()

  // Message subscribers keyed by Netty channel id, so that every ChannelImpl wrapping the same
  // Netty channel publishes to, and reads from, the same Subscribers instance.
  private val activeChannels = new ConcurrentHashMap[ChannelId, Subscribers[M]]().asScala

  /**
    * 64 kilobytes is the theoretical maximum size of a complete IP datagram
    * https://stackoverflow.com/questions/9203403/java-datagrampacket-udp-maximum-send-recv-buffer-size
    */
  private val bootstrap = new Bootstrap()
    .group(workerGroup)
    .channel(classOf[NioDatagramChannel])
    .option[RecvByteBufAllocator](ChannelOption.RCVBUF_ALLOCATOR, new DefaultMaxBytesRecvByteBufAllocator)
    .handler(new ChannelInitializer[NioDatagramChannel]() {
      override def initChannel(ch: NioDatagramChannel): Unit = {
        // The ChannelImpl constructor registers itself on the channel's pipeline; its remote
        // address is unknown until the first datagram is read.
        new ChannelImpl(ch, Promise[InetMultiAddress]())
      }
    })

  // The bind is started eagerly at construction time; initialize() only waits for its outcome.
  private val serverBind: ChannelFuture = bootstrap.bind(config.bindAddress)

  /**
    * Waits for the local bind to finish.
    *
    * @return a task that completes once the bind has succeeded, or fails with the bind error
    */
  override def initialize(): Task[Unit] =
    toTask(serverBind).map(_ => log.info(s"Server bound to address ${config.bindAddress}"))

  /** The address this peer group advertises, taken from the configuration. */
  override def processAddress: InetMultiAddress = config.processAddress

  /**
    * Opens a channel to the given remote peer.
    *
    * The connect is initiated as soon as this method is called, not when the returned task is run.
    *
    * @param to the remote address to connect to
    * @return a task yielding the channel, or failing with the connect error
    */
  override def client(to: InetMultiAddress): Task[Channel[InetMultiAddress, M]] = {
    val connectFuture = bootstrap.connect(to.inetSocketAddress)
    toTask(connectFuture).map { _ =>
      val nettyChannel = connectFuture.channel().asInstanceOf[NioDatagramChannel]
      // The remote address is already known for outbound channels.
      new ChannelImpl(nettyChannel, Promise.successful(to))
    }
  }

  /** Stream of channels, each emitted when its first inbound datagram is read. */
  override def server(): Observable[Channel[InetMultiAddress, M]] = channelSubscribers.messageStream

  /**
    * Closes the bound channel and then shuts down the Netty event loop group.
    *
    * The steps run in sequence, so a failure to close the bound channel fails the task before
    * the event loop group shutdown is requested.
    */
  override def shutdown(): Task[Unit] =
    for {
      _ <- toTask(serverBind.channel().close())
      _ <- toTask(workerGroup.shutdownGracefully())
    } yield ()

  /**
    * Couples a Netty datagram channel with the peer-group [[Channel]] abstraction. It is both
    * the pipeline handler receiving datagrams and the object handed out to users.
    *
    * @param nettyChannel          the underlying Netty channel
    * @param promisedRemoteAddress completed up front for outbound channels, or on the first
    *                              received datagram otherwise
    */
  private class ChannelImpl(
      val nettyChannel: NioDatagramChannel,
      promisedRemoteAddress: Promise[InetMultiAddress]
  )(implicit codec: Codec[M])
      extends ChannelInboundHandlerAdapter
      with Channel[InetMultiAddress, M] {

    nettyChannel.pipeline().addLast(this)

    // Resolved once the remote address is known; looks up (or registers) the subscribers
    // shared by all wrappers of this Netty channel.
    private val messageSubscribersF: Future[Subscribers[M]] =
      promisedRemoteAddress.future.map { remoteAddress =>
        log.debug(
          s"New channel created with id ${nettyChannel.id()} from ${nettyChannel.localAddress()} to $remoteAddress"
        )
        activeChannels.getOrElseUpdate(nettyChannel.id, new Subscribers[M])
      }

    /** The remote address. Blocks the caller until it is known. */
    override def to: InetMultiAddress = Await.result(promisedRemoteAddress.future, Duration.Inf)

    /**
      * Encodes and sends a message to the remote peer once its address is known.
      *
      * @return a task that completes when the write has been flushed, or fails with the write error
      */
    override def sendMessage(message: M): Task[Unit] =
      for {
        remoteAddress <- Task.fromFuture(promisedRemoteAddress.future)
        sendResult <- sendMessage(message, remoteAddress, nettyChannel)
      } yield sendResult

    /** Inbound message stream. Blocks the caller until the remote address is known. */
    override def in: Observable[M] = Await.result(messageSubscribersF, Duration.Inf).messageStream

    /** Unregisters this channel's subscribers; the underlying Netty channel is left open. */
    override def close(): Task[Unit] = {
      activeChannels.remove(nettyChannel.id)
      Task.unit
    }

    /**
      * Handles an inbound datagram: learns the remote address from the first datagram (and
      * announces the channel to server subscribers), then decodes and publishes the message.
      */
    override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
      val datagram = msg.asInstanceOf[DatagramPacket]
      val remoteAddress = datagram.sender()

      if (!promisedRemoteAddress.isCompleted) {
        promisedRemoteAddress.complete(Success(new InetMultiAddress(remoteAddress)))
        channelSubscribers.notify(this)
      }

      // NOTE(review): nothing is published when decoding does not yield a message, and the
      // datagram buffer is never released here. Releasing after decode looks desirable but is
      // only safe if the codec copies what it needs from the buffer -- confirm before changing.
      codec.decode(datagram.content().nioBuffer().asReadOnlyBuffer()).map { m =>
        messageSubscribersF.foreach { messageSubscribers =>
          log.debug(
            s"Processing inbound message from remote address remote ${ctx.channel().remoteAddress()} " +
              s"to local address ${ctx.channel().localAddress()} via channel id ChannelId ${ctx.channel().id()}."
          )
          messageSubscribers.notify(m)
        }
      }
    }

    // Wraps the encoded message in a datagram addressed to `to`, with this process's address as sender.
    private def sendMessage(message: M, to: InetMultiAddress, nettyChannel: NioDatagramChannel): Task[Unit] = {
      val nettyBuffer = Unpooled.wrappedBuffer(codec.encode(message))
      val packet = new DatagramPacket(nettyBuffer, to.inetSocketAddress, processAddress.inetSocketAddress)
      toTask(nettyChannel.writeAndFlush(packet))
    }
  }

  /**
    * Adapts a Netty future to a `Task[Unit]`.
    *
    * The listener is attached immediately. The task succeeds when the Netty operation succeeded
    * and fails with the operation's cause otherwise, so bind, connect, write and close errors
    * are propagated instead of being reported as success.
    *
    * @param f the Netty future to observe
    */
  private def toTask(f: util.concurrent.Future[_]): Task[Unit] = {
    val promisedCompletion = Promise[Unit]()
    f.addListener((_: util.concurrent.Future[_]) => {
      if (f.isSuccess) {
        promisedCompletion.complete(Success(()))
      } else {
        // Guard against a missing cause so the promise is never failed with null.
        val cause =
          Option(f.cause()).getOrElse(new IllegalStateException("Netty operation did not complete successfully"))
        promisedCompletion.failure(cause)
      }
      ()
    })
    Task.fromFuture(promisedCompletion.future)
  }
}
/** Companion of the `UDPPeerGroup` class; holds its configuration type. */
object UDPPeerGroup {

  /**
    * Configuration for a `UDPPeerGroup`.
    *
    * @param bindAddress    the local socket address
    * @param processAddress the address associated with this process
    */
  case class Config(bindAddress: InetSocketAddress, processAddress: InetMultiAddress)

  object Config {

    /**
      * Builds a configuration whose process address wraps the bind address itself.
      *
      * @param bindAddress the local socket address, also used as the process address
      */
    def apply(bindAddress: InetSocketAddress): Config = {
      val derivedProcessAddress = new InetMultiAddress(bindAddress)
      new Config(bindAddress, derivedProcessAddress)
    }
  }
}