Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change clientAuthMode to an enum #2351

Merged
merged 4 commits into from
Jan 13, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ class BlazeBuilder[F[_]](
keyManagerPassword: String,
protocol: String = "TLS",
trustStore: Option[StoreInfo] = None,
clientAuth: Boolean = false): Self = {
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self = {
val bits = KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth)
copy(sslBits = Some(bits))
}

def withSSLContext(sslContext: SSLContext, clientAuth: Boolean = false): Self =
def withSSLContext(
sslContext: SSLContext,
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self =
copy(sslBits = Some(SSLContextBits(sslContext, clientAuth)))

override def bindSocketAddress(socketAddress: InetSocketAddress): Self =
Expand Down Expand Up @@ -183,7 +185,7 @@ class BlazeBuilder[F[_]](
b.resource
}

private def getContext(): Option[(SSLContext, Boolean)] = sslBits.map {
private def getContext(): Option[(SSLContext, SSLClientAuthMode)] = sslBits.map {
case KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth) =>
val ksStream = new FileInputStream(keyStore.path)
val ks = KeyStore.getInstance("JKS")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,14 @@ class BlazeServerBuilder[F[_]](
keyManagerPassword: String,
protocol: String = "TLS",
trustStore: Option[StoreInfo] = None,
clientAuth: Boolean = false): Self = {
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self = {
val bits = KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth)
copy(sslBits = Some(bits))
}

def withSSLContext(sslContext: SSLContext, clientAuth: Boolean = false): Self =
def withSSLContext(
sslContext: SSLContext,
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self =
copy(sslBits = Some(SSLContextBits(sslContext, clientAuth)))

override def bindSocketAddress(socketAddress: InetSocketAddress): Self =
Expand Down Expand Up @@ -264,7 +266,18 @@ class BlazeServerBuilder[F[_]](
case Some((ctx, clientAuth)) =>
val engine = ctx.createSSLEngine()
engine.setUseClientMode(false)
engine.setNeedClientAuth(clientAuth)

clientAuth match {
case SSLClientAuthMode.NotRequested =>
engine.setWantClientAuth(false)
engine.setNeedClientAuth(false)

case SSLClientAuthMode.Requested =>
engine.setWantClientAuth(true)

case SSLClientAuthMode.Required =>
engine.setNeedClientAuth(true)
}

LeafBuilder(
if (isHttp2Enabled) http2Stage(engine)
Expand Down Expand Up @@ -317,7 +330,7 @@ class BlazeServerBuilder[F[_]](
})
}

private def getContext(): Option[(SSLContext, Boolean)] = sslBits.map {
private def getContext(): Option[(SSLContext, SSLClientAuthMode)] = sslBits.map {
case KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth) =>
val ksStream = new FileInputStream(keyStore.path)
val ks = KeyStore.getInstance("JKS")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import java.security.KeyStore
import cats.effect.{IO, Resource}
import javax.net.ssl._
import org.http4s.dsl.io._
import org.http4s.server.Server
import org.http4s.server.ServerRequestKeys
import org.http4s.server.{SSLClientAuthMode, Server, ServerRequestKeys}
import org.http4s.{Http4sSpec, HttpApp}

import scala.concurrent.duration._
import scala.io.Source
import scala.util.Try

/**
* Test cases for mTLS support in blaze server
Expand Down Expand Up @@ -48,13 +48,25 @@ class BlazeServerMtlsSpec extends Http4sSpec {

Ok(output)

case req @ GET -> Root / "noauth" =>
req
.attributes(ServerRequestKeys.SecureSession)
.foreach { session =>
session.sslSessionId shouldNotEqual ""
session.cipherSuite shouldNotEqual ""
session.keySize shouldNotEqual 0
session.X509Certificate.size shouldEqual 0
}

Ok("success")

case _ => NotFound()
}

val serverR: Resource[IO, Server[IO]] =
def serverR(clientAuthMode: SSLClientAuthMode): Resource[IO, Server[IO]] =
builder
.bindAny()
.withSSLContext(sslContext, clientAuth = true)
.withSSLContext(sslContext, clientAuth = clientAuthMode)
.withHttpApp(service)
.resource

Expand All @@ -77,20 +89,93 @@ class BlazeServerMtlsSpec extends Http4sSpec {
sc
}

withResource(serverR) { server =>
def get(path: String): String = {
/**
* Used for no mTLS client. Required to trust self-signed certificate.
*/
lazy val noAuthClientContext: SSLContext = {

val js = KeyStore.getInstance("JKS")
js.load(getClass.getResourceAsStream("/keystore.jks"), "password".toCharArray)

val tmf = TrustManagerFactory.getInstance("SunX509")
tmf.init(js)

val sc = SSLContext.getInstance("TLSv1.2")
sc.init(null, tmf.getTrustManagers, null)

sc
}

/**
* Test "required" auth mode
*/
withResource(serverR(SSLClientAuthMode.Required)) { server =>
def get(path: String, clientAuth: Boolean = true): String = {
val url = new URL(s"https://localhost:${server.address.getPort}$path")
val conn = url.openConnection().asInstanceOf[HttpsURLConnection]
conn.setRequestMethod("GET")
conn.setSSLSocketFactory(sslContext.getSocketFactory)

Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name).getLines.mkString
if (clientAuth) {
conn.setSSLSocketFactory(sslContext.getSocketFactory)
} else {
conn.setSSLSocketFactory(noAuthClientContext.getSocketFactory)
}

Try {
Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name).getLines.mkString
}.recover {
case ex: Throwable =>
ex.getMessage
}
.toOption
.getOrElse("")
}

"Server" should {
"send mTLS request correctly" in {
get("/dummy") shouldEqual "CN=Test,OU=Test,O=Test,L=CA,ST=CA,C=US"
}

"fail for invalid client auth" in {
get("/dummy", clientAuth = false) shouldEqual "Connection reset"
}
}
}

/**
* Test "requested" auth mode
*/
withResource(serverR(SSLClientAuthMode.Requested)) { server =>
def get(path: String, clientAuth: Boolean = true): String = {
val url = new URL(s"https://localhost:${server.address.getPort}$path")
val conn = url.openConnection().asInstanceOf[HttpsURLConnection]
conn.setRequestMethod("GET")

if (clientAuth) {
conn.setSSLSocketFactory(sslContext.getSocketFactory)
} else {
conn.setSSLSocketFactory(noAuthClientContext.getSocketFactory)
}

Try {
Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name).getLines.mkString
}.recover {
case ex: Throwable =>
ex.getMessage
}
.toOption
.getOrElse("")
}

"Server" should {

"send mTLS request correctly with optional auth" in {
get("/dummy") shouldEqual "CN=Test,OU=Test,O=Test,L=CA,ST=CA,C=US"
}

"send mTLS request correctly without clientAuth" in {
get("/noauth", clientAuth = false) shouldEqual "success"
}
}
}
}
25 changes: 21 additions & 4 deletions jetty/src/main/scala/org/http4s/server/jetty/JettyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ sealed class JettyBuilder[F[_]] private (
keyManagerPassword: String,
protocol: String = "TLS",
trustStore: Option[StoreInfo] = None,
clientAuth: Boolean = false
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested
): Self =
copy(
sslBits = Some(KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth)))

def withSSLContext(sslContext: SSLContext, clientAuth: Boolean = false): Self =
def withSSLContext(
sslContext: SSLContext,
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self =
copy(sslBits = Some(SSLContextBits(sslContext, clientAuth)))

override def bindSocketAddress(socketAddress: InetSocketAddress): Self =
Expand Down Expand Up @@ -151,8 +153,8 @@ sealed class JettyBuilder[F[_]] private (
sslContextFactory.setKeyStorePath(keyStore.path)
sslContextFactory.setKeyStorePassword(keyStore.password)
sslContextFactory.setKeyManagerPassword(keyManagerPassword)
sslContextFactory.setNeedClientAuth(clientAuth)
sslContextFactory.setProtocol(protocol)
updateClientAuth(sslContextFactory, clientAuth)

trustStore.foreach { trustManagerBits =>
sslContextFactory.setTrustStorePath(trustManagerBits.path)
Expand All @@ -164,7 +166,7 @@ sealed class JettyBuilder[F[_]] private (
case Some(SSLContextBits(sslContext, clientAuth)) =>
val sslContextFactory = new SslContextFactory()
sslContextFactory.setSslContext(sslContext)
sslContextFactory.setNeedClientAuth(clientAuth)
updateClientAuth(sslContextFactory, clientAuth)

httpsConnector(sslContextFactory)

Expand All @@ -173,6 +175,21 @@ sealed class JettyBuilder[F[_]] private (
}
}

private def updateClientAuth(
sslContextFactory: SslContextFactory,
clientAuthMode: SSLClientAuthMode): Unit =
clientAuthMode match {
case SSLClientAuthMode.NotRequested =>
sslContextFactory.setWantClientAuth(false)
sslContextFactory.setNeedClientAuth(false)

case SSLClientAuthMode.Requested =>
sslContextFactory.setWantClientAuth(true)

case SSLClientAuthMode.Required =>
sslContextFactory.setNeedClientAuth(true)
}

def resource: Resource[F, Server[F]] =
Resource(F.delay {
val jetty = new JServer(threadPool)
Expand Down
12 changes: 12 additions & 0 deletions server/src/main/scala/org/http4s/server/SSLClientAuthMode.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.http4s.server

/**
* Client Auth mode for mTLS
*/
sealed trait SSLClientAuthMode extends Product with Serializable

object SSLClientAuthMode {
case object NotRequested extends SSLClientAuthMode
case object Requested extends SSLClientAuthMode
case object Required extends SSLClientAuthMode
}
5 changes: 3 additions & 2 deletions server/src/main/scala/org/http4s/server/ServerBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ final case class KeyStoreBits(
keyManagerPassword: String,
protocol: String,
trustStore: Option[StoreInfo],
clientAuth: Boolean)
clientAuth: SSLClientAuthMode)
extends SSLConfig

final case class SSLContextBits(sslContext: SSLContext, clientAuth: Boolean) extends SSLConfig
final case class SSLContextBits(sslContext: SSLContext, clientAuth: SSLClientAuthMode)
extends SSLConfig

object SSLKeyStoreSupport {
final case class StoreInfo(path: String, password: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ sealed class TomcatBuilder[F[_]] private (
keyManagerPassword: String,
protocol: String = "TLS",
trustStore: Option[StoreInfo] = None,
clientAuth: Boolean = false): Self =
clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested): Self =
copy(
sslBits = Some(KeyStoreBits(keyStore, keyManagerPassword, protocol, trustStore, clientAuth)))

Expand Down