Skip to content

Commit

Permalink
Change signature of pipeline factories to accept a SocketConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
bryce-anderson committed May 26, 2014
1 parent 760ca2d commit dee6f09
Show file tree
Hide file tree
Showing 14 changed files with 105 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.http4s.blaze.channel

import java.net.{InetAddress, InetSocketAddress, SocketAddress}

/**
* Created by Bryce Anderson on 5/26/14.
*/
trait SocketConnection {
/** Return the SocketAddress of the remote connection */
def remote: SocketAddress

/** Return the local SocketAddress associated with the connection */
def local: SocketAddress

/** Return of the connection is currently open */
def isOpen: Boolean

/** Close this Connection */
def close(): Unit

final def remoteInetAddress: Option[InetAddress] = remote match {
case addr: InetSocketAddress => Option(addr.getAddress)
case _ => None
}

final def localInetAddress: Option[InetAddress] = local match {
case addr: InetSocketAddress => Option(addr.getAddress)
case _ => None
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.http4s.blaze.channel.nio1

import java.nio.channels.{SelectableChannel, SocketChannel}
import org.http4s.blaze.channel.SocketConnection
import java.net.SocketAddress

/**
* Created by Bryce Anderson on 5/26/14.
*/

object NIO1Connection {
def apply(connection: SelectableChannel): SocketConnection = connection match {
case ch: SocketChannel => NIO1SocketConnection(ch)
case _ =>
// We don't know what type this is, so implement what we can
new SocketConnection {
override def remote: SocketAddress = local

override def local: SocketAddress = new SocketAddress {}

override def close(): Unit = connection.close()

override def isOpen: Boolean = connection.isOpen
}
}

}

case class NIO1SocketConnection(connection: SocketChannel) extends SocketConnection {

override def remote: SocketAddress = connection.getRemoteAddress

override def local: SocketAddress = connection.getLocalAddress

override def isOpen: Boolean = connection.isConnected

override def close(): Unit = connection.close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.RejectedExecutionException

import org.http4s.blaze.pipeline._
import org.http4s.blaze.channel.BufferPipelineBuilder

/**
* @author Bryce Anderson
Expand Down Expand Up @@ -174,7 +175,7 @@ final class SelectorLoop(selector: Selector, bufferSize: Int)

def wakeup(): Unit = selector.wakeup()

def initChannel(builder: () => LeafBuilder[ByteBuffer], ch: SelectableChannel, mkStage: SelectionKey => NIO1HeadStage) {
def initChannel(builder: BufferPipelineBuilder, ch: SelectableChannel, mkStage: SelectionKey => NIO1HeadStage) {
enqueTask( new Runnable {
def run() {
try {
Expand All @@ -185,7 +186,7 @@ final class SelectorLoop(selector: Selector, bufferSize: Int)
key.attach(head)

// construct the pipeline
builder().base(head)
builder(NIO1Connection(ch)).base(head)

head.inboundCommand(Command.Connect)
logger.trace("Started channel.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ class SocketServerChannelFactory(pipeFactory: BufferPipelineBuilder, pool: Selec

import SocketServerChannelFactory.brokePipeMessages

def this(pipeFactory: BufferPipelineBuilder, workerThreads: Int = 8, bufferSize: Int = 4*1024) = {
def this(pipeFactory: BufferPipelineBuilder, workerThreads: Int = 8, bufferSize: Int = 4*1024) =
this(pipeFactory, new FixedArraySelectorPool(workerThreads, bufferSize))
}

//////////////// End of constructors /////////////////////////////////////////////////////////

def doBind(address: SocketAddress): ServerSocketChannel = ServerSocketChannel.open().bind(address)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import com.typesafe.scalalogging.slf4j.LazyLogging
* @author Bryce Anderson
* Created on 1/4/14
*/
class NIO2ServerChannelFactory(pipeFactory: () => LeafBuilder[ByteBuffer], group: AsynchronousChannelGroup = null)
class NIO2ServerChannelFactory(pipeFactory: BufferPipelineBuilder, group: AsynchronousChannelGroup = null)
extends ServerChannelFactory[AsynchronousServerSocketChannel] with LazyLogging {

// Intended to be overridden in order to allow the reject of connections
Expand Down Expand Up @@ -48,7 +48,7 @@ class NIO2ServerChannelFactory(pipeFactory: () => LeafBuilder[ByteBuffer], group
}
else {
logger.trace(s"Connection to ${ch.getRemoteAddress} accepted at ${new Date}")
pipeFactory().base(new ByteBufferHead(ch)).sendInboundCommand(Connect)
pipeFactory(NIO2SocketConnection(ch)).base(new ByteBufferHead(ch)).sendInboundCommand(Connect)
}

} catch {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.http4s.blaze.channel.nio2

import org.http4s.blaze.channel.SocketConnection
import java.nio.channels.AsynchronousSocketChannel
import java.net.SocketAddress

/**
* Created by Bryce Anderson on 5/26/14.
*/
case class NIO2SocketConnection(connection: AsynchronousSocketChannel) extends SocketConnection {

override def remote: SocketAddress = connection.getRemoteAddress

override def local: SocketAddress = connection.getLocalAddress

override def isOpen: Boolean = connection.isOpen

override def close(): Unit = connection.close()
}
2 changes: 1 addition & 1 deletion core/src/main/scala/org/http4s/blaze/channel/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import org.http4s.blaze.pipeline.LeafBuilder
* Created on 1/5/14
*/
package object channel {
type BufferPipelineBuilder = () => LeafBuilder[ByteBuffer]
type BufferPipelineBuilder = SocketConnection => LeafBuilder[ByteBuffer]
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ class ChannelSpec extends Specification {
"Channels" should {

"Bind the port and then be closed" in {
val channel = new BasicServer(() => new EchoStage).prepare(new InetSocketAddress(0))
val channel = new BasicServer(_ => new EchoStage).prepare(new InetSocketAddress(0))
channel.close()
true should_== true
}

"Execute shutdown hooks" in {
val i = new AtomicInteger(0)
val channel = new BasicServer(() => new EchoStage).prepare(new InetSocketAddress(0))
val channel = new BasicServer(_ => new EchoStage).prepare(new InetSocketAddress(0))
channel.addShutdownHook{ () => i.incrementAndGet() }
val t = channel.runAsync()
channel.close()
Expand All @@ -53,7 +53,7 @@ class ChannelSpec extends Specification {

"Execute shutdown hooks when one throws an exception" in {
val i = new AtomicInteger(0)
val channel = new BasicServer(() => new EchoStage).prepare(new InetSocketAddress(0))
val channel = new BasicServer(_ => new EchoStage).prepare(new InetSocketAddress(0))
channel.addShutdownHook{ () => i.incrementAndGet() }
channel.addShutdownHook{ () => sys.error("Foo") }
channel.addShutdownHook{ () => i.incrementAndGet() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.typesafe.scalalogging.slf4j.StrictLogging
class EchoServer extends StrictLogging {

def prepare(address: InetSocketAddress): ServerChannel = {
val f: BufferPipelineBuilder = () => new EchoStage
val f: BufferPipelineBuilder = _ => new EchoStage

val factory = new NIO2ServerChannelFactory(f)
factory.bind(address)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.http4s.blaze.pipeline.LeafBuilder
*/
class HttpServer(port: Int) {

private val f: BufferPipelineBuilder = () => LeafBuilder(new ExampleHttpServerStage(10*1024))
private val f: BufferPipelineBuilder = _ => LeafBuilder(new ExampleHttpServerStage(10*1024))

val group = AsynchronousChannelGroup.withFixedThreadPool(10, java.util.concurrent.Executors.defaultThreadFactory())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.http4s.blaze.pipeline.LeafBuilder
*/
class NIO1HttpServer(port: Int) {

private val f: BufferPipelineBuilder = () => LeafBuilder(new ExampleHttpServerStage(10*1024))
private val f: BufferPipelineBuilder = _ => LeafBuilder(new ExampleHttpServerStage(10*1024))

private val factory = new SocketServerChannelFactory(f, workerThreads = 6)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SSLHttpServer(port: Int) {
}


private val f: BufferPipelineBuilder = { () =>
private val f: BufferPipelineBuilder = { _ =>
val eng = sslContext.createSSLEngine()
eng.setUseClientMode(false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.http4s.blaze.channel.nio2.NIO2ServerChannelFactory
* Created on 1/18/14
*/
class WebSocketServer(port: Int) {
private val f: BufferPipelineBuilder = () => LeafBuilder(new ExampleWebSocketHttpServerStage)
private val f: BufferPipelineBuilder = _ => LeafBuilder(new ExampleWebSocketHttpServerStage)

val group = AsynchronousChannelGroup.withFixedThreadPool(10, java.util.concurrent.Executors.defaultThreadFactory())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SpdyServer(port: Int) {
}


private val f: BufferPipelineBuilder = { () =>
private val f: BufferPipelineBuilder = { _ =>
val eng = sslContext.createSSLEngine()
eng.setUseClientMode(false)

Expand Down

0 comments on commit dee6f09

Please sign in to comment.