Skip to content
Permalink
Browse files

Update UDPPeerGroup to multiplex based on local and remote addresses.

  • Loading branch information...
jtownson committed May 16, 2019
1 parent 95383a9 commit 96ab987392d782ea2cd261030412a569165a5235
@@ -19,4 +19,4 @@ case class InetMultiAddress(private[scalanet] val inetSocketAddress: InetSocketA
val state = Seq(inetAddress)
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}
}
}
@@ -17,13 +17,13 @@ import org.slf4j.LoggerFactory
* There is no enrollment process. Instances are configured with a static table of all known peers.
*/
class SimplestPeerGroup[A, AA, M](
val config: Config[A, AA],
underLyingPeerGroup: PeerGroup[AA, Either[ControlMessage[A, AA], M]]
)(
implicit aCodec: Codec[A],
aaCodec: Codec[AA],
scheduler: Scheduler
) extends PeerGroup[A, M] {
val config: Config[A, AA],
underLyingPeerGroup: PeerGroup[AA, Either[ControlMessage[A, AA], M]]
)(
implicit aCodec: Codec[A],
aaCodec: Codec[AA],
scheduler: Scheduler
) extends PeerGroup[A, M] {

private val log = LoggerFactory.getLogger(getClass)

@@ -51,7 +51,7 @@ class SimplestPeerGroup[A, AA, M](
}

private class ChannelImpl(val to: A, underlyingChannel: Channel[AA, Either[ControlMessage[A, AA], M]])
extends Channel[A, M] {
extends Channel[A, M] {

override def sendMessage(message: M): Task[Unit] = {
underlyingChannel.sendMessage(Right(message))
@@ -78,24 +78,28 @@ class TCPPeerGroup[M](val config: Config)(implicit scheduler: Scheduler, codec:

object TCPPeerGroup {
case class Config(
bindAddress: InetSocketAddress,
processAddress: InetMultiAddress,
remoteHostConfig: Map[InetAddress, Int] = Map.empty[InetAddress, Int]
bindAddress: InetSocketAddress,
processAddress: InetMultiAddress,
remoteHostConfig: Map[InetAddress, Int] = Map.empty[InetAddress, Int]
)

object Config {
def apply(bindAddress: InetSocketAddress): Config = Config(bindAddress, new InetMultiAddress(bindAddress))
}

private[scalanet] class ServerChannelImpl[M](val nettyChannel: SocketChannel)(implicit codec: Codec[M], scheduler: Scheduler)
extends Channel[InetMultiAddress, M] {
private[scalanet] class ServerChannelImpl[M](val nettyChannel: SocketChannel)(
implicit codec: Codec[M],
scheduler: Scheduler
) extends Channel[InetMultiAddress, M] {

private val log = LoggerFactory.getLogger(getClass)

log.debug(s"Creating server channel from ${nettyChannel.localAddress()} to ${nettyChannel.remoteAddress()} with channel id ${nettyChannel.id}")
log.debug(
s"Creating server channel from ${nettyChannel.localAddress()} to ${nettyChannel.remoteAddress()} with channel id ${nettyChannel.id}"
)

private val messageSubject = ReplaySubject[M]()
//new Subscribers[M](s"Subscribers for ServerChannelImpl@${nettyChannel.id}")
//new Subscribers[M](s"Subscribers for ServerChannelImpl@${nettyChannel.id}")

nettyChannel
.pipeline()
@@ -120,8 +124,10 @@ object TCPPeerGroup {
}
}

private class ClientChannelImpl[M](inetSocketAddress: InetSocketAddress, clientBootstrap: Bootstrap)(implicit codec: Codec[M], scheduler: Scheduler)
extends Channel[InetMultiAddress, M] {
private class ClientChannelImpl[M](inetSocketAddress: InetSocketAddress, clientBootstrap: Bootstrap)(
implicit codec: Codec[M],
scheduler: Scheduler
) extends Channel[InetMultiAddress, M] {

private val log = LoggerFactory.getLogger(getClass)

@@ -144,8 +150,10 @@ object TCPPeerGroup {
.addLast(new ByteArrayEncoder())
.addLast(new ChannelInboundHandlerAdapter() {
override def channelActive(ctx: ChannelHandlerContext): Unit = {
log.debug(s"Creating client channel from ${ctx.channel().localAddress()} " +
s"to ${ctx.channel().remoteAddress()} with channel id ${ctx.channel().id}")
log.debug(
s"Creating client channel from ${ctx.channel().localAddress()} " +
s"to ${ctx.channel().remoteAddress()} with channel id ${ctx.channel().id}"
)
activation.complete(Success(ctx))
}

@@ -184,7 +192,8 @@ object TCPPeerGroup {
}
}

private class MessageNotifier[M](val messageSubject: Subject[M, M])(implicit codec: Codec[M]) extends ChannelInboundHandlerAdapter {
private class MessageNotifier[M](val messageSubject: Subject[M, M])(implicit codec: Codec[M])
extends ChannelInboundHandlerAdapter {

private val log = LoggerFactory.getLogger(getClass)

@@ -1,9 +1,10 @@
package io.iohk.scalanet.peergroup

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

import io.iohk.decco.Codec
import io.iohk.decco.{Codec, DecodeFailure}
import io.iohk.scalanet.peergroup.PeerGroup.TerminalPeerGroup
import io.iohk.scalanet.peergroup.UDPPeerGroup._
import io.netty.bootstrap.Bootstrap
@@ -12,16 +13,15 @@ import io.netty.channel._
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.DatagramPacket
import io.netty.channel.socket.nio.NioDatagramChannel
import io.netty.util
import io.netty.{channel, util}
import monix.eval.Task
import monix.execution.Scheduler
import monix.reactive.Observable
import monix.reactive.subjects.{PublishSubject, ReplaySubject, Subject}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.Promise
import scala.util.Success

class UDPPeerGroup[M](val config: Config)(implicit scheduler: Scheduler, codec: Codec[M])
@@ -33,33 +33,134 @@ class UDPPeerGroup[M](val config: Config)(implicit scheduler: Scheduler, codec:

private val workerGroup = new NioEventLoopGroup()

private val activeChannels = new ConcurrentHashMap[ChannelId, Subject[M, M]]().asScala
private val activeChannels = new ConcurrentHashMap[Seq[Byte], ChannelImpl]().asScala

/**
* 64 kilobytes is the theoretical maximum size of a complete IP datagram
* https://stackoverflow.com/questions/9203403/java-datagrampacket-udp-maximum-send-recv-buffer-size
*/
private val bootstrap = new Bootstrap()
private val clientBootstrap = new Bootstrap()
.group(workerGroup)
.channel(classOf[NioDatagramChannel])
.option[RecvByteBufAllocator](ChannelOption.RCVBUF_ALLOCATOR, new DefaultMaxBytesRecvByteBufAllocator)
.handler(new ChannelInitializer[NioDatagramChannel]() {
override def initChannel(ch: NioDatagramChannel): Unit = {
new ChannelImpl(ch, Promise[InetMultiAddress]())
override def initChannel(nettyChannel: NioDatagramChannel): Unit = {
nettyChannel
.pipeline()
.addLast(new channel.ChannelInboundHandlerAdapter() {
override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
val datagram = msg.asInstanceOf[DatagramPacket]
val remoteAddress = datagram.sender()
val localAddress = datagram.recipient()
val messageE: Either[DecodeFailure, M] = codec.decode(datagram.content().nioBuffer().asReadOnlyBuffer())
log.info(s"Client channel read message $messageE with remote $remoteAddress and local $localAddress")

val channelId = getChannelId(remoteAddress, localAddress)

if (!activeChannels.contains(channelId)) {
throw new IllegalStateException(s"Missing channel instance for channelId $channelId")
}

val channel = activeChannels(channelId)
messageE.foreach(message => channel.messageSubject.onNext(message))
}
})
}
})

private val serverBind: ChannelFuture = bootstrap.bind(config.bindAddress)
private val serverBootstrap = new Bootstrap()
.group(workerGroup)
.channel(classOf[NioDatagramChannel])
.option[RecvByteBufAllocator](ChannelOption.RCVBUF_ALLOCATOR, new DefaultMaxBytesRecvByteBufAllocator)
.handler(new ChannelInitializer[NioDatagramChannel]() {
override def initChannel(nettyChannel: NioDatagramChannel): Unit = {
nettyChannel
.pipeline()
.addLast(new ChannelInboundHandlerAdapter() {
override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {

val datagram = msg.asInstanceOf[DatagramPacket]
val remoteAddress = datagram.sender()
val localAddress = processAddress.inetSocketAddress //datagram.recipient()

val messageE: Either[DecodeFailure, M] = codec.decode(datagram.content().nioBuffer().asReadOnlyBuffer())

log.debug(s"Server read $messageE")
val nettyChannel: NioDatagramChannel = ctx.channel().asInstanceOf[NioDatagramChannel]
val channelId = getChannelId(remoteAddress, localAddress)

if (activeChannels.contains(channelId)) {
log.debug(s"Channel with id $channelId found in active channels table.")
val channel = activeChannels(channelId)
messageE.foreach(message => channel.messageSubject.onNext(message))
} else {
val channel = new ChannelImpl(nettyChannel, localAddress, remoteAddress, ReplaySubject[M]())
log.debug(s"Channel with id $channelId NOT found in active channels table. Creating a new one")
activeChannels.put(channelId, channel)
channelSubject.onNext(channel)
messageE.foreach(message => channel.messageSubject.onNext(message))
}
}
})
}
})

class ChannelImpl(
val nettyChannel: NioDatagramChannel,
localAddress: InetSocketAddress,
remoteAddress: InetSocketAddress,
val messageSubject: Subject[M, M]
) extends Channel[InetMultiAddress, M] {

log.debug(
s"Setting up new channel from local address $localAddress " +
s"to remote address $remoteAddress. Netty channelId is ${nettyChannel.id()}. " +
s"My channelId is ${getChannelId(remoteAddress, localAddress)}"
)

override val to: InetMultiAddress = InetMultiAddress(remoteAddress)

override def sendMessage(message: M): Task[Unit] = sendMessage(message, localAddress, remoteAddress, nettyChannel)

override def in: Observable[M] = messageSubject

override def close(): Task[Unit] = {
messageSubject.onComplete()
Task.unit
}

private def sendMessage(
message: M,
sender: InetSocketAddress,
recipient: InetSocketAddress,
nettyChannel: NioDatagramChannel
): Task[Unit] = {
val nettyBuffer = Unpooled.wrappedBuffer(codec.encode(message))
toTask(nettyChannel.writeAndFlush(new DatagramPacket(nettyBuffer, recipient, sender)))
}
}

private val serverBind: ChannelFuture = serverBootstrap.bind(config.bindAddress)

override def initialize(): Task[Unit] =
toTask(serverBind).map(_ => log.info(s"Server bound to address ${config.bindAddress}"))

override def processAddress: InetMultiAddress = config.processAddress

override def client(to: InetMultiAddress): Task[Channel[InetMultiAddress, M]] = {
val cf = bootstrap.connect(to.inetSocketAddress)
val cf = clientBootstrap.connect(to.inetSocketAddress)
val ct: Task[NioDatagramChannel] = toTask(cf).map(_ => cf.channel().asInstanceOf[NioDatagramChannel])
ct.map(nettyChannel => new ChannelImpl(nettyChannel, Promise().complete(Success(to))))
ct.map { nettyChannel =>
val localAddress = nettyChannel.localAddress()
log.debug(s"Generated local address for new client is $localAddress")
val channelId = getChannelId(to.inetSocketAddress, localAddress)
if (activeChannels.contains(channelId)) {
log.warn(s"HOUSTON, WE HAVE A MULTIPLEXING PROBLEM")
}
val channel = new ChannelImpl(nettyChannel, localAddress, to.inetSocketAddress, ReplaySubject[M]())
activeChannels.put(channelId, channel)
channel
}
}

override def server(): Observable[Channel[InetMultiAddress, M]] = channelSubject
@@ -72,77 +173,27 @@ class UDPPeerGroup[M](val config: Config)(implicit scheduler: Scheduler, codec:
} yield ()
}

private class ChannelImpl(
val nettyChannel: NioDatagramChannel,
promisedRemoteAddress: Promise[InetMultiAddress]
)(implicit codec: Codec[M])
extends ChannelInboundHandlerAdapter
with Channel[InetMultiAddress, M] {

nettyChannel.pipeline().addLast(this)

private val messageSubjectF: Future[Subject[M, M]] = for {
remoteAddress <- promisedRemoteAddress.future
} yield {
log.debug(
s"New channel created with id ${nettyChannel.id()} from ${nettyChannel.localAddress()} to $remoteAddress"
)
activeChannels.getOrElseUpdate(nettyChannel.id, ReplaySubject[M]())
}

override def to: InetMultiAddress = Await.result(promisedRemoteAddress.future, Duration.Inf)

override def sendMessage(message: M): Task[Unit] =
for {
remoteAddress <- Task.fromFuture(promisedRemoteAddress.future)
sendResult <- sendMessage(message, remoteAddress, nettyChannel)
} yield sendResult

override def in: Observable[M] = Await.result(messageSubjectF, Duration.Inf)

override def close(): Task[Unit] = {
activeChannels.remove(nettyChannel.id)
Task.unit
}

override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
val datagram = msg.asInstanceOf[DatagramPacket]
val remoteAddress = datagram.sender()
if (!promisedRemoteAddress.isCompleted) {
promisedRemoteAddress.complete(Success(new InetMultiAddress(remoteAddress)))
channelSubject.onNext(this)
}

codec.decode(datagram.content().nioBuffer().asReadOnlyBuffer()).map { m =>
messageSubjectF.foreach { messageSubscribers =>
log.debug(
s"Processing inbound message from remote address remote ${ctx.channel().remoteAddress()} " +
s"to local address ${ctx.channel().localAddress()} via channel id ChannelId ${ctx.channel().id()}."
)
messageSubscribers.onNext(m)
}
}
}

private def sendMessage(message: M, to: InetMultiAddress, nettyChannel: NioDatagramChannel): Task[Unit] = {
val nettyBuffer = Unpooled.wrappedBuffer(codec.encode(message))
toTask(nettyChannel.writeAndFlush(new DatagramPacket(nettyBuffer, to.inetSocketAddress, processAddress.inetSocketAddress)))
}

}

private def toTask(f: util.concurrent.Future[_]): Task[Unit] = {
val promisedCompletion = Promise[Unit]()
f.addListener((_: util.concurrent.Future[_]) => promisedCompletion.complete(Success(())))
Task.fromFuture(promisedCompletion.future)
}

private def getChannelId(remoteAddress: InetSocketAddress, localAddress: InetSocketAddress): Seq[Byte] = {
val b = ByteBuffer.allocate(16)
b.put(remoteAddress.getAddress.getAddress)
b.putInt(remoteAddress.getPort)
b.put(localAddress.getAddress.getAddress)
b.putInt(localAddress.getPort)
b.array().toIndexedSeq
}
}

object UDPPeerGroup {

case class Config(bindAddress: InetSocketAddress, processAddress: InetMultiAddress)

object Config {
def apply(bindAddress: InetSocketAddress): Config = Config(bindAddress, new InetMultiAddress(bindAddress))
def apply(bindAddress: InetSocketAddress): Config = Config(bindAddress, InetMultiAddress(bindAddress))
}
}
@@ -162,7 +162,7 @@ class SimplePeerGroupSpec extends FlatSpec {
// }

trait SimpleTerminalPeerGroups {
val terminalPeerGroups = List(TcpTerminalPeerGroup/* UdpTerminalPeerGroup*/)
val terminalPeerGroups = List(TcpTerminalPeerGroup /* UdpTerminalPeerGroup*/ )
}

private def withASimplePeerGroup(
Oops, something went wrong.

0 comments on commit 96ab987

Please sign in to comment.
You can’t perform that action at this time.