Skip to content
This repository has been archived by the owner on Dec 3, 2019. It is now read-only.

Commit

Permalink
Add SSL support..
Browse files Browse the repository at this point in the history
SSL is disabled by default to avoid POLA violations.
It is possible to enable and control SSL behavior via url parameters:
- `sslmode=<mode>` enable ssl (prefer/require/verify-ca/verify-full [recommended])
- `sslrootcert=<path.pem>` specifies trusted certificates (JDK cacert if missing)

Client certificate authentication is not implemented, due to lack of
time and interest, but it should be easy to add.
  • Loading branch information
alexdupre committed Mar 7, 2016
1 parent c3747b5 commit 0f9a587
Show file tree
Hide file tree
Showing 21 changed files with 364 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ object Configuration {
* @param port database port, defaults to 5432
* @param password password, defaults to no password
* @param database database name, defaults to no database
* @param ssl ssl configuration
* @param charset charset for the connection, defaults to UTF-8, make sure you know what you are doing if you
* change this
* @param maximumMessageSize the maximum size a message from the server could possibly have, this limits possible
Expand All @@ -55,6 +56,7 @@ case class Configuration(username: String,
port: Int = 5432,
password: Option[String] = None,
database: Option[String] = None,
ssl: SSLConfiguration = SSLConfiguration(),
charset: Charset = Configuration.DefaultCharset,
maximumMessageSize: Int = 16777216,
allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.github.mauricio.async.db

import java.io.File

import SSLConfiguration.Mode

/**
*
* Contains the SSL configuration necessary to connect to a database.
*
* @param mode whether and with what priority a SSL connection will be negotiated, default disabled
* @param rootCert path to PEM encoded trusted root certificates, None to use internal JDK cacerts, defaults to None
*
*/
case class SSLConfiguration(mode: Mode.Value = Mode.Disable, rootCert: Option[java.io.File] = None)

object SSLConfiguration {

object Mode extends Enumeration {
val Disable = Value("disable") // only try a non-SSL connection
val Prefer = Value("prefer") // first try an SSL connection; if that fails, try a non-SSL connection
val Require = Value("require") // only try an SSL connection, but don't verify Certificate Authority
val VerifyCA = Value("verify-ca") // only try an SSL connection, and verify that the server certificate is issued by a trusted certificate authority (CA)
val VerifyFull = Value("verify-full") // only try an SSL connection, verify that the server certificate is issued by a trusted CA and that the server host name matches that in the certificate
}

def apply(properties: Map[String, String]): SSLConfiguration = SSLConfiguration(
mode = Mode.withName(properties.get("sslmode").getOrElse("disable")),
rootCert = properties.get("sslrootcert").map(new File(_))
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.github.mauricio.async.db.postgresql.codec

import com.github.mauricio.async.db.postgresql.exceptions.{MessageTooLongException}
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
import com.github.mauricio.async.db.postgresql.messages.backend.{ServerMessage, SSLResponseMessage}
import com.github.mauricio.async.db.postgresql.parsers.{AuthenticationStartupParser, MessageParsersRegistry}
import com.github.mauricio.async.db.util.{BufferDumper, Log}
import java.nio.charset.Charset
Expand All @@ -31,15 +31,21 @@ object MessageDecoder {
val DefaultMaximumSize = 16777216
}

class MessageDecoder(charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
class MessageDecoder(sslEnabled: Boolean, charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {

import MessageDecoder.log

private val parser = new MessageParsersRegistry(charset)

private var sslChecked = false

override def decode(ctx: ChannelHandlerContext, b: ByteBuf, out: java.util.List[Object]): Unit = {

if (b.readableBytes() >= 5) {
if (sslEnabled & !sslChecked) {
val code = b.readByte()
sslChecked = true
out.add(new SSLResponseMessage(code == 'S'))
} else if (b.readableBytes() >= 5) {

b.markReaderIndex()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {

val buffer = msg match {
case SSLRequestMessage => SSLMessageEncoder.encode()
case message: StartupMessage => startupEncoder.encode(message)
case message: ClientMessage => {
val encoder = (message.kind: @switch) match {
case ServerMessage.Close => CloseMessageEncoder
case ServerMessage.Execute => this.executeEncoder
case ServerMessage.Parse => this.openEncoder
case ServerMessage.Startup => this.startupEncoder
case ServerMessage.Query => this.queryEncoder
case ServerMessage.PasswordMessage => this.credentialEncoder
case _ => throw new EncoderNotAvailableException(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.github.mauricio.async.db.postgresql.codec

import com.github.mauricio.async.db.Configuration
import com.github.mauricio.async.db.SSLConfiguration.Mode
import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry}
import com.github.mauricio.async.db.postgresql.exceptions._
import com.github.mauricio.async.db.postgresql.messages.backend._
Expand All @@ -38,6 +39,12 @@ import com.github.mauricio.async.db.postgresql.messages.backend.RowDescriptionMe
import com.github.mauricio.async.db.postgresql.messages.backend.ParameterStatusMessage
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.CodecException
import io.netty.handler.ssl.{SslContextBuilder, SslHandler}
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
import io.netty.util.concurrent.FutureListener
import javax.net.ssl.{SSLParameters, TrustManagerFactory}
import java.security.KeyStore
import java.io.FileInputStream

object PostgreSQLConnectionHandler {
final val log = Log.get[PostgreSQLConnectionHandler]
Expand Down Expand Up @@ -79,7 +86,7 @@ class PostgreSQLConnectionHandler

override def initChannel(ch: channel.Channel): Unit = {
ch.pipeline.addLast(
new MessageDecoder(configuration.charset, configuration.maximumMessageSize),
new MessageDecoder(configuration.ssl.mode != Mode.Disable, configuration.charset, configuration.maximumMessageSize),
new MessageEncoder(configuration.charset, encoderRegistry),
PostgreSQLConnectionHandler.this)
}
Expand Down Expand Up @@ -120,13 +127,61 @@ class PostgreSQLConnectionHandler
}

override def channelActive(ctx: ChannelHandlerContext): Unit = {
ctx.writeAndFlush(new StartupMessage(this.properties))
if (configuration.ssl.mode == Mode.Disable)
ctx.writeAndFlush(new StartupMessage(this.properties))
else
ctx.writeAndFlush(SSLRequestMessage)
}

override def channelRead0(ctx: ChannelHandlerContext, msg: Object): Unit = {

msg match {

case SSLResponseMessage(supported) =>
if (supported) {
val ctxBuilder = SslContextBuilder.forClient()
if (configuration.ssl.mode >= Mode.VerifyCA) {
configuration.ssl.rootCert.fold {
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val ks = KeyStore.getInstance(KeyStore.getDefaultType())
val cacerts = new FileInputStream(System.getProperty("java.home") + "/lib/security/cacerts")
try {
ks.load(cacerts, "changeit".toCharArray)
} finally {
cacerts.close()
}
tmf.init(ks)
ctxBuilder.trustManager(tmf)
} { path =>
ctxBuilder.trustManager(path)
}
} else {
ctxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
}
val sslContext = ctxBuilder.build()
val sslEngine = sslContext.newEngine(ctx.alloc(), configuration.host, configuration.port)
if (configuration.ssl.mode >= Mode.VerifyFull) {
val sslParams = sslEngine.getSSLParameters()
sslParams.setEndpointIdentificationAlgorithm("HTTPS")
sslEngine.setSSLParameters(sslParams)
}
val handler = new SslHandler(sslEngine)
ctx.pipeline().addFirst(handler)
handler.handshakeFuture.addListener(new FutureListener[channel.Channel]() {
def operationComplete(future: io.netty.util.concurrent.Future[channel.Channel]) {
if (future.isSuccess()) {
ctx.writeAndFlush(new StartupMessage(properties))
} else {
connectionDelegate.onError(future.cause())
}
}
})
} else if (configuration.ssl.mode < Mode.Require) {
ctx.writeAndFlush(new StartupMessage(properties))
} else {
connectionDelegate.onError(new IllegalArgumentException("SSL is not supported on server"))
}

case m: ServerMessage => {

(m.kind : @switch) match {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.github.mauricio.async.db.postgresql.encoders

import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled

object SSLMessageEncoder {

def encode(): ByteBuf = {
val buffer = Unpooled.buffer()
buffer.writeInt(8)
buffer.writeShort(1234)
buffer.writeShort(5679)
buffer
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ import com.github.mauricio.async.db.util.ByteBufferUtils
import java.nio.charset.Charset
import io.netty.buffer.{Unpooled, ByteBuf}

class StartupMessageEncoder(charset: Charset) extends Encoder {
class StartupMessageEncoder(charset: Charset) {

//private val log = Log.getByName("StartupMessageEncoder")

override def encode(message: ClientMessage): ByteBuf = {

val startup = message.asInstanceOf[StartupMessage]
def encode(startup: StartupMessage): ByteBuf = {

val buffer = Unpooled.buffer()
buffer.writeInt(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.github.mauricio.async.db.postgresql.messages.backend

case class SSLResponseMessage(supported: Boolean)
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ object ServerMessage {
final val Query = 'Q'
final val RowDescription = 'T'
final val ReadyForQuery = 'Z'
final val Startup = '0'
final val Sync = 'S'
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.github.mauricio.async.db.postgresql.messages.frontend

trait InitialClientMessage
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.github.mauricio.async.db.postgresql.messages.frontend

import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage

object SSLRequestMessage extends InitialClientMessage
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,4 @@

package com.github.mauricio.async.db.postgresql.messages.frontend

import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage

class StartupMessage(val parameters: List[(String, Any)]) extends ClientMessage(ServerMessage.Startup)
class StartupMessage(val parameters: List[(String, Any)]) extends InitialClientMessage
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,37 @@ object ParserURL {
val PGPORT = "port"
val PGDBNAME = "database"
val PGHOST = "host"
val PGUSERNAME = "username"
val PGUSERNAME = "user"
val PGPASSWORD = "password"

val DEFAULT_PORT = "5432"

private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?user=(.*)&password=(.*))?""".r
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/(.*)""".r
private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?(.*))?""".r
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/([^/?]*)(?:\?(.*))?""".r

def parse(connectionURL: String): Map[String, String] = {
val properties: Map[String, String] = Map()

def parseOptions(optionsStr: String): Map[String, String] =
optionsStr.split("&").map { o =>
o.span(_ != '=') match {
case (name, value) => name -> value.drop(1)
}
}.toMap

connectionURL match {
case pgurl1(protocol, server, port, dbname, username, password) => {
case pgurl1(protocol, server, port, dbname, params) => {
var result = properties
if (server != null) result += (PGHOST -> unwrapIpv6address(server))
if (dbname != null && dbname.nonEmpty) result += (PGDBNAME -> dbname)
if(port != null) result += (PGPORT -> port)
if(username != null) result = (result + (PGUSERNAME -> username) + (PGPASSWORD -> password))
if (port != null) result += (PGPORT -> port)
if (params != null) result ++= parseOptions(params)
result
}
case pgurl2(protocol, username, password, server, port, dbname) => {
properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
case pgurl2(protocol, username, password, server, port, dbname, params) => {
var result = properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
if (params != null) result ++= parseOptions(params)
result
}
case _ => {
logger.warn(s"Connection url '$connectionURL' could not be parsed.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@

package com.github.mauricio.async.db.postgresql.util

import com.github.mauricio.async.db.Configuration
import com.github.mauricio.async.db.{Configuration, SSLConfiguration}
import java.nio.charset.Charset

object URLParser {

private val Username = "username"
private val Password = "password"

import Configuration.Default

def parse(url: String,
Expand All @@ -35,11 +32,12 @@ object URLParser {
val port = properties.get(ParserURL.PGPORT).getOrElse(ParserURL.DEFAULT_PORT).toInt

new Configuration(
username = properties.get(Username).getOrElse(Default.username),
password = properties.get(Password),
username = properties.get(ParserURL.PGUSERNAME).getOrElse(Default.username),
password = properties.get(ParserURL.PGPASSWORD),
database = properties.get(ParserURL.PGDBNAME),
host = properties.getOrElse(ParserURL.PGHOST, Default.host),
port = port,
ssl = SSLConfiguration(properties),
charset = charset
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ package com.github.mauricio.async.db.postgresql

import com.github.mauricio.async.db.util.Log
import com.github.mauricio.async.db.{Connection, Configuration}
import java.io.File
import java.util.concurrent.{TimeoutException, TimeUnit}
import scala.Some
import scala.concurrent.duration._
import scala.concurrent.{Future, Await}
import com.github.mauricio.async.db.SSLConfiguration
import com.github.mauricio.async.db.SSLConfiguration.Mode

object DatabaseTestHelper {
val log = Log.get[DatabaseTestHelper]
Expand Down Expand Up @@ -54,6 +56,16 @@ trait DatabaseTestHelper {
withHandler(this.timeTestConfiguration, fn)
}

def withSSLHandler[T](mode: SSLConfiguration.Mode.Value, host: String = "localhost", rootCert: Option[File] = Some(new File("script/server.crt")))(fn: (PostgreSQLConnection) => T): T = {
val config = new Configuration(
host = host,
port = databasePort,
username = "postgres",
database = databaseName,
ssl = SSLConfiguration(mode = mode, rootCert = rootCert))
withHandler(config, fn)
}

def withHandler[T](configuration: Configuration, fn: (PostgreSQLConnection) => T): T = {

val handler = new PostgreSQLConnection(configuration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.util

class MessageDecoderSpec extends Specification {

val decoder = new MessageDecoder(CharsetUtil.UTF_8)
val decoder = new MessageDecoder(false, CharsetUtil.UTF_8)

"message decoder" should {

Expand Down
Loading

0 comments on commit 0f9a587

Please sign in to comment.