Skip to content

Commit

Permalink
Modified #3472 to make its API more idiomatic.
Browse files Browse the repository at this point in the history
  • Loading branch information
dom96 committed Jun 3, 2016
1 parent c170646 commit 5390c25
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 35 deletions.
26 changes: 19 additions & 7 deletions examples/ssl/extradata.nim
@@ -1,14 +1,26 @@
# Stores extra data inside the SSL context. # Stores extra data inside the SSL context.
import net import net


let ctx = newContext()

# Our unique index for storing foos # Our unique index for storing foos
let fooIndex = getSslContextExtraDataIndex() let fooIndex = ctx.getExtraDataIndex()
# And another unique index for storing foos # And another unique index for storing foos
let barIndex = getSslContextExtraDataIndex() let barIndex = ctx.getExtraDataIndex()
echo "got indexes ", fooIndex, " ", barIndex echo "got indexes ", fooIndex, " ", barIndex


let ctx = newContext() try:
assert ctx.getExtraData(fooIndex) == nil discard ctx.getExtraData(fooIndex)
let foo: int = 5 assert false
ctx.setExtraData(fooIndex, cast[pointer](foo)) except IndexError:
assert cast[int](ctx.getExtraData(fooIndex)) == foo echo("Success")

type
FooRef = ref object of RootRef
foo: int

let foo = FooRef(foo: 5)
ctx.setExtraData(fooIndex, foo)
doAssert ctx.getExtraData(fooIndex).FooRef == foo

ctx.destroyContext()
80 changes: 52 additions & 28 deletions lib/pure/net.nim
Expand Up @@ -66,7 +66,7 @@
## ##


{.deadCodeElim: on.} {.deadCodeElim: on.}
import nativesockets, os, strutils, parseutils, times import nativesockets, os, strutils, parseutils, times, sets
export Port, `$`, `==` export Port, `$`, `==`
export Domain, SockType, Protocol export Domain, SockType, Protocol


Expand All @@ -88,7 +88,10 @@ when defineSsl:
SslProtVersion* = enum SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23 protSSLv2, protSSLv3, protTLSv1, protSSLv23


SslContext* = distinct SslCtx SslContext* = ref object
context: SslCtx
extraInternalIndex: int
referencedData: HashSet[int]


SslAcceptResult* = enum SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
Expand Down Expand Up @@ -229,9 +232,10 @@ when defineSsl:
ErrLoadBioStrings() ErrLoadBioStrings()
OpenSSL_add_all_algorithms() OpenSSL_add_all_algorithms()


type SslContextExtraInternal = ref object type
serverGetPskFunc: SslServerGetPskFunc SslContextExtraInternal = ref object of RootRef
clientGetPskFunc: SslClientGetPskFunc serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc


proc raiseSSLError*(s = "") = proc raiseSSLError*(s = "") =
## Raises a new SSL error. ## Raises a new SSL error.
Expand All @@ -245,21 +249,33 @@ when defineSsl:
var errStr = ErrErrorString(err, nil) var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr) raise newException(SSLError, $errStr)


proc getSslContextExtraDataIndex*(): cint = proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext. ## Retrieves unique index for storing extra data in SSLContext.
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
if result < 0:
raiseSSLError()

proc getExtraData*(ctx: SSLContext, index: int): RootRef =
## Retrieves arbitrary data stored inside SSLContext.
if index notin ctx.referencedData:
raise newException(IndexError, "No data with that index.")
let res = ctx.context.SSL_CTX_get_ex_data(index.cint)
if cast[int](res) == 0:
raiseSSLError()
return cast[RootRef](res)


proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
## Stores arbitrary data inside SSLContext. The unique `index` ## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex. ## should be retrieved using getSslContextExtraDataIndex.
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: if index in ctx.referencedData:
raiseSSLError() GC_unref(getExtraData(ctx, index))


proc getExtraData*(ctx: SSLContext, index: cint): pointer = if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1:
## Retrieves arbitrary data stored inside SSLContext. raiseSSLError()
return SslCtx(ctx).SSL_CTX_get_ex_data(index)


let extraInternalIndex = getSslContextExtraDataIndex() if index notin ctx.referencedData:
ctx.referencedData.incl(index)
GC_ref(data)


# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
Expand Down Expand Up @@ -323,34 +339,41 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile) newCTX.loadCertificates(certFile, keyFile)


result = SSLContext(newCTX) result = SSLContext(context: newCTX, extraInternalIndex: 0,
referencedData: initSet[int]())
result.extraInternalIndex = getExtraDataIndex(result)
# The PSK callback functions assume the internal index is 0.
assert result.extraInternalIndex == 0

let extraInternal = new(SslContextExtraInternal) let extraInternal = new(SslContextExtraInternal)
GC_ref(extraInternal) result.setExtraData(result.extraInternalIndex, extraInternal)
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))


proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))


proc destroyContext*(ctx: SSLContext) = proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext. ## Free memory referenced by SSLContext.
let extraInternal = ctx.getExtraInternal()
if extraInternal != nil: # We assume here that OpenSSL's internal indexes increase by 1 each time.
GC_unref(extraInternal) # That means we can assume that the next internal index is the length of
SSLCTX(ctx).SSL_CTX_free() # extra data indexes.
for i in ctx.referencedData:
GC_unref(getExtraData(ctx, i).RootRef)
ctx.context.SSL_CTX_free()


proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
## Sets the identity hint passed to server. ## Sets the identity hint passed to server.
## ##
## Only used in PSK ciphersuites. ## Only used in PSK ciphersuites.
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0: if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0:
raiseSSLError() raiseSSLError()


proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
return ctx.getExtraInternal().clientGetPskFunc return ctx.getExtraInternal().clientGetPskFunc


proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} = max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX) let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let hintString = if hint == nil: nil else: $hint let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len: if psk.len.cuint > max_psk_len:
Expand All @@ -369,13 +392,14 @@ when defineSsl:
## ##
## Only used in PSK ciphersuites. ## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun ctx.getExtraInternal().clientGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback) ctx.context.SSL_CTX_set_psk_client_callback(
if fun == nil: nil else: pskClientCallback)


proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc return ctx.getExtraInternal().serverGetPskFunc


proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX) let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let pskString = (ctx.serverGetPskFunc)($identity) let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len: if psk.len.cint > max_psk_len:
return 0 return 0
Expand All @@ -388,7 +412,7 @@ when defineSsl:
## ##
## Only used in PSK ciphersuites. ## Only used in PSK ciphersuites.
ctx.getExtraInternal().serverGetPskFunc = fun ctx.getExtraInternal().serverGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil
else: pskServerCallback) else: pskServerCallback)


proc getPskIdentity*(socket: Socket): string = proc getPskIdentity*(socket: Socket): string =
Expand All @@ -409,7 +433,7 @@ when defineSsl:
assert (not socket.isSSL) assert (not socket.isSSL)
socket.isSSL = true socket.isSSL = true
socket.sslContext = ctx socket.sslContext = ctx
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) socket.sslHandle = SSLNew(socket.sslContext.context)
socket.sslNoHandshake = false socket.sslNoHandshake = false
socket.sslHasPeekChar = false socket.sslHasPeekChar = false
if socket.sslHandle == nil: if socket.sslHandle == nil:
Expand Down

0 comments on commit 5390c25

Please sign in to comment.