Permalink
Browse files

Modified #3472 to make its API more idiomatic.

  • Loading branch information...
1 parent c170646 commit 5390c25b60e79f87aca339f7428575066b0b2d08 @dom96 dom96 committed Jun 3, 2016
Showing with 71 additions and 35 deletions.
  1. +19 −7 examples/ssl/extradata.nim
  2. +52 −28 lib/pure/net.nim
@@ -1,14 +1,26 @@
# Stores extra data inside the SSL context.
import net
+let ctx = newContext()
+
# Our unique index for storing foos
-let fooIndex = getSslContextExtraDataIndex()
+let fooIndex = ctx.getExtraDataIndex()
# And another unique index for storing foos
-let barIndex = getSslContextExtraDataIndex()
+let barIndex = ctx.getExtraDataIndex()
echo "got indexes ", fooIndex, " ", barIndex
-let ctx = newContext()
-assert ctx.getExtraData(fooIndex) == nil
-let foo: int = 5
-ctx.setExtraData(fooIndex, cast[pointer](foo))
-assert cast[int](ctx.getExtraData(fooIndex)) == foo
+try:
+ discard ctx.getExtraData(fooIndex)
+ assert false
+except IndexError:
+ echo("Success")
+
+type
+ FooRef = ref object of RootRef
+ foo: int
+
+let foo = FooRef(foo: 5)
+ctx.setExtraData(fooIndex, foo)
+doAssert ctx.getExtraData(fooIndex).FooRef == foo
+
+ctx.destroyContext()
View
@@ -66,7 +66,7 @@
##
{.deadCodeElim: on.}
-import nativesockets, os, strutils, parseutils, times
+import nativesockets, os, strutils, parseutils, times, sets
export Port, `$`, `==`
export Domain, SockType, Protocol
@@ -88,7 +88,10 @@ when defineSsl:
SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23
- SslContext* = distinct SslCtx
+ SslContext* = ref object
+ context: SslCtx
+ extraInternalIndex: int
+ referencedData: HashSet[int]
SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
@@ -229,9 +232,10 @@ when defineSsl:
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()
- type SslContextExtraInternal = ref object
- serverGetPskFunc: SslServerGetPskFunc
- clientGetPskFunc: SslClientGetPskFunc
+ type
+ SslContextExtraInternal = ref object of RootRef
+ serverGetPskFunc: SslServerGetPskFunc
+ clientGetPskFunc: SslClientGetPskFunc
proc raiseSSLError*(s = "") =
## Raises a new SSL error.
@@ -245,21 +249,33 @@ when defineSsl:
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)
- proc getSslContextExtraDataIndex*(): cint =
+ proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext.
- return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
+ result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
+ if result < 0:
+ raiseSSLError()
+
+ proc getExtraData*(ctx: SSLContext, index: int): RootRef =
+ ## Retrieves arbitrary data stored inside SSLContext.
+ if index notin ctx.referencedData:
+ raise newException(IndexError, "No data with that index.")
+ let res = ctx.context.SSL_CTX_get_ex_data(index.cint)
+ if cast[int](res) == 0:
+ raiseSSLError()
+ return cast[RootRef](res)
- proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
+ proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
- if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
- raiseSSLError()
+ if index in ctx.referencedData:
+ GC_unref(getExtraData(ctx, index))
- proc getExtraData*(ctx: SSLContext, index: cint): pointer =
- ## Retrieves arbitrary data stored inside SSLContext.
- return SslCtx(ctx).SSL_CTX_get_ex_data(index)
+ if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1:
+ raiseSSLError()
- let extraInternalIndex = getSslContextExtraDataIndex()
+ if index notin ctx.referencedData:
+ ctx.referencedData.incl(index)
+ GC_ref(data)
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
@@ -323,34 +339,41 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)
- result = SSLContext(newCTX)
+ result = SSLContext(context: newCTX, extraInternalIndex: 0,
+ referencedData: initSet[int]())
+ result.extraInternalIndex = getExtraDataIndex(result)
+ # The PSK callback functions assume the internal index is 0.
+ assert result.extraInternalIndex == 0
+
let extraInternal = new(SslContextExtraInternal)
- GC_ref(extraInternal)
- result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
+ result.setExtraData(result.extraInternalIndex, extraInternal)
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
- return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
+ return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext.
- let extraInternal = ctx.getExtraInternal()
- if extraInternal != nil:
- GC_unref(extraInternal)
- SSLCTX(ctx).SSL_CTX_free()
+
+ # We assume here that OpenSSL's internal indexes increase by 1 each time.
+ # That means we can assume that the next internal index is the length of
+ # extra data indexes.
+ for i in ctx.referencedData:
+ GC_unref(getExtraData(ctx, i).RootRef)
+ ctx.context.SSL_CTX_free()
proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
## Sets the identity hint passed to server.
##
## Only used in PSK ciphersuites.
- if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
+ if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0:
raiseSSLError()
proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
return ctx.getExtraInternal().clientGetPskFunc
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
- let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
+ let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
@@ -369,13 +392,14 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
- SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
+ ctx.context.SSL_CTX_set_psk_client_callback(
+ if fun == nil: nil else: pskClientCallback)
proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
- let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
+ let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0
@@ -388,7 +412,7 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().serverGetPskFunc = fun
- SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
+ ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil
else: pskServerCallback)
proc getPskIdentity*(socket: Socket): string =
@@ -409,7 +433,7 @@ when defineSsl:
assert (not socket.isSSL)
socket.isSSL = true
socket.sslContext = ctx
- socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
+ socket.sslHandle = SSLNew(socket.sslContext.context)
socket.sslNoHandshake = false
socket.sslHasPeekChar = false
if socket.sslHandle == nil:

0 comments on commit 5390c25

Please sign in to comment.