Skip to content

Commit

Permalink
tls refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Jan 25, 2020
1 parent d805590 commit 0ae8801
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 436 deletions.
@@ -1,7 +1,5 @@
package com.koushikdutta.scratch.tls

expect interface SSLSession

interface HostnameVerifier {
/**
* Verify that the host name is an acceptable match with
Expand All @@ -11,11 +9,13 @@ interface HostnameVerifier {
* @param session SSLSession used on the connection to host
* @return true if the host name is acceptable
*/
fun verify(hostname: String, session: SSLSession): Boolean
fun verify(engine: SSLEngine): Boolean
}

interface AsyncTlsTrustFailureCallback {
fun handleOrRethrow(throwable: Throwable)
}

class AsyncTlsOptions(internal val hostnameVerifier: HostnameVerifier? = null, internal val trustFailureCallback: AsyncTlsTrustFailureCallback?)

expect object DefaultHostnameVerifier : HostnameVerifier
14 changes: 7 additions & 7 deletions src/commonMain/kotlin/com.koushikdutta.scratch/tls/expect.kt
@@ -1,8 +1,9 @@
package com.koushikdutta.scratch.tls

import com.koushikdutta.scratch.AsyncSocket
import com.koushikdutta.scratch.AsyncWrappingSocket
import com.koushikdutta.scratch.IOException
import com.koushikdutta.scratch.buffers.AllocationTracker
import com.koushikdutta.scratch.buffers.ByteBufferList
import com.koushikdutta.scratch.buffers.WritableBuffers

class SSLEngineResult constructor(val status: SSLEngineStatus, val handshakeStatus: SSLEngineHandshakeStatus)

Expand Down Expand Up @@ -30,7 +31,10 @@ expect abstract class SSLEngine {
}
expect fun SSLEngine.runHandshakeTask()
expect fun SSLEngine.checkHandshakeStatus(): SSLEngineHandshakeStatus
expect open class SSLException: IOException
expect fun SSLEngine.unwrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker = AllocationTracker()): SSLEngineResult
expect fun SSLEngine.wrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker = AllocationTracker()): SSLEngineResult

expect open class SSLException(message: String): IOException
expect class SSLHandshakeException : SSLException

expect interface RSAPrivateKey
Expand All @@ -42,7 +46,3 @@ expect fun SSLContext.init(pk: RSAPrivateKey, certificate: X509Certificate): SSL
expect fun SSLContext.init(certificate: X509Certificate): SSLContext
expect fun createTLSContext(): SSLContext
expect fun getDefaultSSLContext(): SSLContext

expect class AsyncTlsSocket(socket: AsyncSocket, engine: SSLEngine, options: AsyncTlsOptions?) : AsyncWrappingSocket {
internal suspend fun awaitHandshake()
}
193 changes: 193 additions & 0 deletions src/commonMain/kotlin/com.koushikdutta.scratch/tls/socket.kt
@@ -0,0 +1,193 @@
package com.koushikdutta.scratch.tls

import com.koushikdutta.scratch.*
import com.koushikdutta.scratch.buffers.AllocationTracker
import com.koushikdutta.scratch.buffers.ByteBufferList
import com.koushikdutta.scratch.buffers.ReadableBuffers
import com.koushikdutta.scratch.buffers.WritableBuffers

class AsyncTlsSocket(override val socket: AsyncSocket, val engine: SSLEngine, private val options: AsyncTlsOptions?) : AsyncWrappingSocket, AsyncAffinity by socket {
private var finishedHandshake = false
private val socketRead = InterruptibleRead(socket::read)
private val decryptAllocator = AllocationTracker()
private val decryptedRead = (socketRead::read as AsyncRead).pipe {
val unfiltered = ByteBufferList();
while (true) {
val awaitingHandshake = !finishedHandshake

while (true) {
val result = engine.unwrap(unfiltered, buffer, decryptAllocator)

if (result.status == SSLEngineStatus.BUFFER_UNDERFLOW) {
// need more data, so just break and wait for another read to come in to
// trigger the read again.
break
} else if (result.handshakeStatus == SSLEngineHandshakeStatus.NEED_WRAP) {
// this may complete the handshake
encryptedWrite(ByteBufferList())
}
else if (result.handshakeStatus == SSLEngineHandshakeStatus.NEED_UNWRAP) {
continue
}

handleHandshakeStatus(result.handshakeStatus)
// flush possibly empty buffer on handshake status change to trigger handshake completion
if (awaitingHandshake && finishedHandshake) {
flush()
break
}

// if there's no handshake, and also no data left, just bail.
if (finishedHandshake && unfiltered.isEmpty)
break
}

if (!buffer.isEmpty)
flush()

if (!it(unfiltered) && unfiltered.isEmpty)
break
}
}

private val unencryptedWriteBuffer = ByteBufferList()
private val encryptedWriteBuffer = ByteBufferList()
private val encryptAllocator = AllocationTracker()
private val encryptedWrite: AsyncWrite = write@{ buffer ->
await()

if (encryptedWriteBuffer.hasRemaining()) {
socket.write(encryptedWriteBuffer)
return@write
}

// move the unencrypted data from upstream into a working buffer.
buffer.read(unencryptedWriteBuffer)

if (finishedHandshake)
encryptAllocator.minAlloc = unencryptedWriteBuffer.remaining()

val awaitingHandshake = !finishedHandshake
while (true) {
val result = engine.wrap(unencryptedWriteBuffer, encryptedWriteBuffer, encryptAllocator)

if (encryptedWriteBuffer.hasRemaining()) {
// before the handshake is completed, ensure that all writes are fully written before
// returning to the handshake loop.
// without blocking on a
// after completion, partial writes are used for back pressure.
if (!awaitingHandshake)
socket.write(encryptedWriteBuffer)
else
socket::write.drain(encryptedWriteBuffer)
}

if (result.status == SSLEngineStatus.BUFFER_UNDERFLOW) {
// this should never happen, as it is not possible to underflow
// with application data
break
} else if (result.handshakeStatus == SSLEngineHandshakeStatus.NEED_UNWRAP) {
socketRead.interrupt()
break
} else if (result.handshakeStatus == SSLEngineHandshakeStatus.NEED_WRAP) {
continue
}

handleHandshakeStatus(result.handshakeStatus)

if (awaitingHandshake && finishedHandshake)
break
if (finishedHandshake && unencryptedWriteBuffer.isEmpty)
break
}
}

var peerCertificates: Array<X509Certificate>? = null
private set

private suspend fun handleHandshakeStatus(status: SSLEngineHandshakeStatus) {
if (status == SSLEngineHandshakeStatus.NEED_TASK)
engine.runHandshakeTask()

if (!finishedHandshake && engine.checkHandshakeStatus() == SSLEngineHandshakeStatus.FINISHED) {
if (engine.useClientMode) {
var trusted = true
var peerUnverifiedCause: Exception? = null
try {
trusted = false
val verifier: HostnameVerifier = options?.hostnameVerifier ?: DefaultHostnameVerifier
if (!verifier.verify(engine))
throw SSLException("hostname verification failed for <$engine.peerHost>")
trusted = true
}
catch (exception: Exception) {
peerUnverifiedCause = exception
}

finishedHandshake = true
if (!trusted) {
if (options?.trustFailureCallback == null)
throw peerUnverifiedCause!!
else
options.trustFailureCallback.handleOrRethrow(peerUnverifiedCause!!)
}
}
else {
finishedHandshake = true
}

// upon handshake completion, trigger a wrap/unwrap so that all pending input/output
// gets flushed
socketRead.interrupt()
encryptedWrite(ByteBufferList())
}
}

override suspend fun close() {
socket.close()
}

override suspend fun write(buffer: ReadableBuffers) {
// do not allow empty writes. this causes ssl engine to terminate on some platforms.
if (buffer.isEmpty)
return
encryptedWrite(buffer)
}

override suspend fun read(buffer: WritableBuffers): Boolean {
return reader(buffer)
}

// need a reader to catch any overflow from the handshake
private var reader: AsyncRead = decryptedRead

internal suspend fun awaitHandshake() {
val handshakeBuffer = ByteBufferList()
reader = {
reader = decryptedRead
handshakeBuffer.read(it)
true
}

while (!finishedHandshake) {
// the suspending read calls in awaitData will be interrupted when an
// unwrap is necessary.
// keep wrapping/unwrapping until the handshake finishes.

// trigger a wrap call
encryptedWrite(ByteBufferList())
// trigger an unwrap call, wait for data
// this will unsuspend once the handshake completes, even if no data is available.
// an empty data set is still a valid
if (!decryptedRead(handshakeBuffer) && !finishedHandshake)
throw SSLException("socket unexpectedly closed")
}

// some ssl implementations finish the handshake, but still have a final wrap.
// ensure that happens.
// also trigger a background read to read that final packet if sent
// from the peer.
socketRead.readTransient()
encryptedWrite(ByteBufferList())
}
}
Expand Up @@ -17,6 +17,7 @@
package com.koushikdutta.scratch.external

import com.koushikdutta.scratch.tls.HostnameVerifier
import com.koushikdutta.scratch.tls.SSLEngine
import java.security.cert.CertificateParsingException
import java.security.cert.X509Certificate
import java.util.*
Expand Down Expand Up @@ -52,9 +53,11 @@ object OkHostnameVerifier : HostnameVerifier {
private const val ALT_DNS_NAME = 2
private const val ALT_IPA_NAME = 7

override fun verify(host: String, session: SSLSession): Boolean {
override fun verify(engine: SSLEngine): Boolean {
return try {
verify(host, session.peerCertificates[0] as X509Certificate)
if (engine.peerHost == null)
return true
verify(engine.peerHost, engine.session.peerCertificates[0] as X509Certificate)
} catch (_: SSLException) {
false
}
Expand Down
7 changes: 5 additions & 2 deletions src/jvmMain/kotlin/com/koushikdutta/scratch/tls/sslengine.kt
Expand Up @@ -3,6 +3,9 @@ package com.koushikdutta.scratch.tls
import com.koushikdutta.scratch.buffers.AllocationTracker
import com.koushikdutta.scratch.buffers.ByteBufferList
import com.koushikdutta.scratch.buffers.WritableBuffers
import com.koushikdutta.scratch.external.OkHostnameVerifier

actual typealias DefaultHostnameVerifier = OkHostnameVerifier

fun javax.net.ssl.SSLEngineResult.Status.convert(): SSLEngineStatus {
return when (this) {
Expand Down Expand Up @@ -36,7 +39,7 @@ actual fun SSLEngine.checkHandshakeStatus(): SSLEngineHandshakeStatus {
}

// extension methods for unwrap with Buffers. manages overflows and allocations.
fun SSLEngine.unwrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker = AllocationTracker()): SSLEngineResult {
actual fun SSLEngine.unwrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker): SSLEngineResult {
tracker.finishTracking()
while (true) {
val unfiltered = if (src.hasRemaining()) src.readFirst() else ByteBufferList.EMPTY_BYTEBUFFER
Expand Down Expand Up @@ -70,7 +73,7 @@ fun SSLEngine.unwrap(src: ByteBufferList, dst: WritableBuffers, tracker: Allocat
}

// extension methods for wrap with Buffers. manages overflows and allocations.
fun SSLEngine.wrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker = AllocationTracker()): SSLEngineResult {
actual fun SSLEngine.wrap(src: ByteBufferList, dst: WritableBuffers, tracker: AllocationTracker): SSLEngineResult {
tracker.finishTracking()
while (true) {
val unencrypted = src.readAll()
Expand Down

0 comments on commit 0ae8801

Please sign in to comment.