Skip to content

Commit

Permalink
Added binary messages
Browse files Browse the repository at this point in the history
  • Loading branch information
emstlk committed Apr 18, 2014
1 parent 66bbe52 commit e68d9f2
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
12 changes: 9 additions & 3 deletions src/main/scala/com/twitter/finagle/HttpWebSocket.scala
Expand Up @@ -11,10 +11,16 @@ import java.net.{SocketAddress, URI}

trait WebSocketRichClient {
def open(out: Offer[String], uri: String): Future[WebSocket] =
open(out, new URI(uri))
open(out, Offer.never, new URI(uri))

def open(out: Offer[String], uri: URI): Future[WebSocket] = {
val socket = WebSocket(messages = out, uri = uri)
def open(out: Offer[String], uri: URI): Future[WebSocket] =
open(out, Offer.never, uri)

def open(out: Offer[String], binaryOut: Offer[Array[Byte]], uri: String): Future[WebSocket] =
open(out, binaryOut, new URI(uri))

def open(out: Offer[String], binaryOut: Offer[Array[Byte]], uri: URI): Future[WebSocket] = {
val socket = WebSocket(messages = out, binaryMessages = binaryOut, uri = uri)
val addr = uri.getHost + ":" + uri.getPort
HttpWebSocket.newClient(addr).toService(socket)
}
Expand Down
Expand Up @@ -8,6 +8,7 @@ import java.net.SocketAddress

case class WebSocket(
messages: Offer[String],
binaryMessages: Offer[Array[Byte]],
uri: URI,
headers: Map[String, String] = Map.empty[String, String],
remoteAddress: SocketAddress = new SocketAddress {},
Expand Down
Expand Up @@ -6,7 +6,7 @@ import com.twitter.finagle.netty3.Conversions._
import com.twitter.finagle.netty3.{Cancelled, Ok, Error}
import com.twitter.util.{Promise, Return, Throw, Try}
import java.net.{URI, InetSocketAddress}
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http.websocketx._
import org.jboss.netty.handler.codec.http.{
Expand All @@ -15,6 +15,7 @@ import scala.collection.JavaConversions._

class WebSocketHandler extends SimpleChannelHandler {
protected[this] val messagesBroker = new Broker[String]
protected[this] val binaryMessagesBroker = new Broker[Array[Byte]]
protected[this] val closer = new Promise[Unit]

protected[this] def write(
Expand All @@ -37,12 +38,22 @@ class WebSocketHandler extends SimpleChannelHandler {
}

case None =>
sock.messages { message =>
val frame = new TextWebSocketFrame(message)
val writeFuture = Channels.future(ctx.getChannel)
Channels.write(ctx, writeFuture, frame)
write(ctx, sock, Some(writeFuture.toTwitterFuture.toOffer))
}
Offer.choose(
sock.messages {
message =>
val frame = new TextWebSocketFrame(message)
val writeFuture = Channels.future(ctx.getChannel)
Channels.write(ctx, writeFuture, frame)
write(ctx, sock, Some(writeFuture.toTwitterFuture.toOffer))
},
sock.binaryMessages {
binary =>
val frame = new BinaryWebSocketFrame(ChannelBuffers.wrappedBuffer(binary))
val writeFuture = Channels.future(ctx.getChannel)
Channels.write(ctx, writeFuture, frame)
write(ctx, sock, Some(writeFuture.toTwitterFuture.toOffer))
}
)
}
awaitAck.sync()
}
Expand Down Expand Up @@ -71,6 +82,7 @@ class WebSocketServerHandler extends WebSocketHandler {

val webSocket = WebSocket(
messages = messagesBroker.recv,
binaryMessages = binaryMessagesBroker.recv,
uri = new URI(req.getUri),
headers = req.getHeaderNames().map(name => name -> req.getHeader(name)).toMap,
remoteAddress = ctx.getChannel.getRemoteAddress,
Expand All @@ -91,6 +103,11 @@ class WebSocketServerHandler extends WebSocketHandler {
ch.setReadable(false)
(messagesBroker ! frame.getText) ensure { ch.setReadable(true) }

case frame: BinaryWebSocketFrame =>
val ch = ctx.getChannel
ch.setReadable(false)
(binaryMessagesBroker ! frame.getBinaryData.array) ensure { ch.setReadable(true) }

case invalid =>
Channels.fireExceptionCaught(ctx,
new IllegalArgumentException("invalid message \"%s\"".format(invalid)))
Expand Down Expand Up @@ -136,6 +153,11 @@ class WebSocketClientHandler extends WebSocketHandler {
ch.setReadable(false)
(messagesBroker ! frame.getText) ensure { ch.setReadable(true) }

case frame: BinaryWebSocketFrame =>
val ch = ctx.getChannel
ch.setReadable(false)
(binaryMessagesBroker ! frame.getBinaryData.array) ensure { ch.setReadable(true) }

case invalid =>
Channels.fireExceptionCaught(ctx,
new IllegalArgumentException("invalid message \"%s\"".format(invalid)))
Expand All @@ -152,6 +174,7 @@ class WebSocketClientHandler extends WebSocketHandler {

val webSocket = sock.copy(
messages = messagesBroker.recv,
binaryMessages = binaryMessagesBroker.recv,
onClose = closer,
close = close)

Expand Down
22 changes: 16 additions & 6 deletions src/test/scala/com/twitter/finagle/websocket/EndToEndTest.scala
Expand Up @@ -8,22 +8,29 @@ import com.twitter.concurrent.Broker
import com.twitter.finagle.{HttpWebSocket, Service}
import com.twitter.util._
import java.net.InetSocketAddress
import scala.collection.mutable.ArrayBuffer

@RunWith(classOf[JUnitRunner])
class EndToEndTest extends FunSuite {
test("multi client") {
var result = ""
val binaryResult = ArrayBuffer.empty[Byte]
val addr = RandomSocket()
val latch = new CountDownLatch(5)
val latch = new CountDownLatch(10)

val server = HttpWebSocket.serve(addr, new Service[WebSocket, WebSocket] {
def apply(req: WebSocket): Future[WebSocket] = {
val outgoing = new Broker[String]
val socket = req.copy(messages = outgoing.recv)
val binaryOutgoing = new Broker[Array[Byte]]
val socket = req.copy(messages = outgoing.recv, binaryMessages = binaryOutgoing.recv)
req.messages foreach { msg =>
synchronized { result += msg }
latch.countDown()
}
req.binaryMessages foreach { binary =>
synchronized { binaryResult ++= binary }
latch.countDown()
}
Future.value(socket)
}
})
Expand All @@ -32,15 +39,18 @@ class EndToEndTest extends FunSuite {

val brokers = (0 until 5) map { _ =>
val out = new Broker[String]
Await.ready(HttpWebSocket.open(out.recv, target))
out
val binaryOut = new Broker[Array[Byte]]
Await.ready(HttpWebSocket.open(out.recv, binaryOut.recv, target))
(out, binaryOut)
}

brokers foreach { out =>
FuturePool.unboundedPool { out !! "1" }
brokers foreach { pair =>

This comment has been minimized.

Copy link
@sprsquish

sprsquish Apr 18, 2014

match out the pair to name the streams for clarity.

FuturePool.unboundedPool { pair._1 !! "1" }
FuturePool.unboundedPool { pair._2 !! Array[Byte](0x01) }
}

latch.within(1.second)
assert(result === "11111")
assert(binaryResult === ArrayBuffer(0x01, 0x01, 0x01, 0x01, 0x01))
}
}

0 comments on commit e68d9f2

Please sign in to comment.