Skip to content

Commit

Permalink
Added SSK test vector support, fixed setSecondaryKey bug
Browse files Browse the repository at this point in the history
  • Loading branch information
centromere committed May 23, 2016
1 parent e912c12 commit c6832ca
Show file tree
Hide file tree
Showing 7 changed files with 24,832 additions and 4,551 deletions.
12 changes: 7 additions & 5 deletions src/Crypto/Noise.hs
Expand Up @@ -108,18 +108,20 @@ handshakeComplete ns = isJust (ns ^. nsSendingCipherState) &&
-- binding. For example, the initiator might cryptographically sign this
-- value as part of some higher-level authentication scheme.
--
-- The value returned by this function is only meaningful after the
-- handshake is complete.
--
-- See section 9.4 of the protocol for details.
handshakeHash :: (Cipher c, DH d, Hash h)
=> NoiseState c d h
-> Maybe ScrubbedBytes
handshakeHash ns = either (const Nothing)
(Just . hashToBytes)
$ ns ^. nsHandshakeState . hsSymmetricState . ssh
-> ScrubbedBytes
handshakeHash ns = either id hashToBytes
$ ns ^. nsHandshakeState . hsSymmetricState . ssh

-- | Sets a secondary symmetric key. This must be 32 bytes in length.
--
-- See section 9.5 of the protocol for details.
setSecondaryKey :: (Cipher c, DH h, Hash h)
setSecondaryKey :: (Cipher c, DH d, Hash h)
=> NoiseState c d h
-> ScrubbedBytes
-> NoiseState c d h
Expand Down
48 changes: 25 additions & 23 deletions src/Crypto/Noise/Internal/NoiseState.hs
Expand Up @@ -109,6 +109,7 @@ runHandshake msg ns = runExcept $ do
Left (Request req resp) -> return (req, ns & nsHandshakeSuspension .~ (Handshake . resp))
Right _ -> do
hs <- get

let (cs1, cs2) = split (hs ^. hsSymmetricState)
ns' = if hs ^. hsOpts . hoRole == InitiatorRole
then ns & nsSendingCipherState .~ Just cs1
Expand Down Expand Up @@ -145,33 +146,34 @@ processPatternOp :: (Cipher c, DH d, Hash h)
processPatternOp opRole t next = do
hs <- get
input <- Handshake <$> request $ hs ^. hsMsgBuffer
hs' <- get

if opRole == hs ^. hsOpts . hoRole then do
put $ hs & hsMsgBuffer .~ mempty
if opRole == hs' ^. hsOpts . hoRole then do
put $ hs' & hsMsgBuffer .~ mempty
iterM (evalMsgToken opRole) $ hoistFT (return . runIdentity) t

hs' <- get
hs'' <- get

let enc = encryptAndHash (convert input) $ hs' ^. hsSymmetricState
let enc = encryptAndHash (convert input) $ hs'' ^. hsSymmetricState

(ep, ss) <- either throwError return enc

put $ hs' & hsMsgBuffer %~ (flip mappend . convert) ep
& hsSymmetricState .~ ss
put $ hs'' & hsMsgBuffer %~ (flip mappend . convert) ep
& hsSymmetricState .~ ss
else do
put $ hs & hsMsgBuffer .~ input
put $ hs' & hsMsgBuffer .~ input
iterM (evalMsgToken opRole) $ hoistFT (return . runIdentity) t

hs' <- get
hs'' <- get

let remaining = hs' ^. hsMsgBuffer
let remaining = hs'' ^. hsMsgBuffer
dec = decryptAndHash (cipherBytesToText (convert remaining))
$ hs' ^. hsSymmetricState
$ hs'' ^. hsSymmetricState

(dp, ss) <- either (const . throwError . HandshakeError $ "handshake payload failed to decrypt") return dec

put $ hs' & hsMsgBuffer .~ convert dp
& hsSymmetricState .~ ss
put $ hs'' & hsMsgBuffer .~ convert dp
& hsSymmetricState .~ ss

next

Expand Down Expand Up @@ -272,13 +274,13 @@ evalMsgToken opRole (Dhes next) = do
hs <- get

if opRole == hs ^. hsOpts . hoRole then do
let ss = hs ^. hsSymmetricState
let ss = hs ^. hsSymmetricState

rpk <- getRemoteStatic hs

~(sk, _) <- getLocalEphemeral hs
let dh = dhPerform sk rpk
ss' = mixKey dh ss
let dh = dhPerform sk rpk
ss' = mixKey dh ss

put $ hs & hsSymmetricState .~ ss'

Expand All @@ -289,13 +291,13 @@ evalMsgToken opRole (Dhse next) = do
hs <- get

if opRole == hs ^. hsOpts . hoRole then do
let ss = hs ^. hsSymmetricState
let ss = hs ^. hsSymmetricState

~(sk, _) <- getLocalStatic hs

rpk <- getRemoteEphemeral hs
let dh = dhPerform sk rpk
ss' = mixKey dh ss
let dh = dhPerform sk rpk
ss' = mixKey dh ss

put $ hs & hsSymmetricState .~ ss'

Expand All @@ -305,12 +307,12 @@ evalMsgToken opRole (Dhse next) = do
evalMsgToken _ (Dhss next) = do
hs <- get

let ss = hs ^. hsSymmetricState
let ss = hs ^. hsSymmetricState

~(sk, _) <- getLocalStatic hs
rpk <- getRemoteStatic hs
let dh = dhPerform sk rpk
ss' = mixKey dh ss
let dh = dhPerform sk rpk
ss' = mixKey dh ss

put $ hs & hsSymmetricState .~ ss'

Expand All @@ -323,7 +325,7 @@ evalPreMsgToken :: (Cipher c, DH d, Hash h)
evalPreMsgToken opRole (E next) = do
hs <- get

let ss = hs ^. hsSymmetricState
let ss = hs ^. hsSymmetricState
pk <- if opRole == hs ^. hsOpts . hoRole
then snd <$> getLocalSemiEphemeral hs
else getRemoteSemiEphemeral hs
Expand All @@ -337,7 +339,7 @@ evalPreMsgToken opRole (E next) = do
evalPreMsgToken opRole (S next) = do
hs <- get

let ss = hs ^. hsSymmetricState
let ss = hs ^. hsSymmetricState

pk <- if opRole == hs ^. hsOpts . hoRole
then snd <$> getLocalStatic hs
Expand Down
19 changes: 13 additions & 6 deletions tests/vectors/Generate.hs
Expand Up @@ -58,7 +58,7 @@ responderStatic DTCurve448 = hexToPair "a9b45971180882a79b89a3399544a425ef8136d2

mkKeys :: DH d
=> Plaintext
-> Maybe Plaintext
-> Maybe ScrubbedBytes
-> Bool
-> DHType d
-> HandshakeKeys d
Expand Down Expand Up @@ -91,18 +91,22 @@ genMessages swap = go []

genVector :: [ScrubbedBytes]
-> HandshakeType
-> Maybe Plaintext
-> Maybe ScrubbedBytes
-> Maybe ScrubbedBytes
-> SomeCipherType
-> SomeDHType
-> SomeHashType
-> IO Vector
genVector payloads pat psk cType@(WrapCipherType c) dType@(WrapDHType d) hType@(WrapHashType h) = do
genVector payloads pat psk ssk cType@(WrapCipherType c) dType@(WrapDHType d) hType@(WrapHashType h) = do
let ihk = mkKeys "John Galt" psk True d
rhk = mkKeys "John Galt" psk False d
ins = mkNoiseState ihk pat InitiatorRole c h
ins' = maybe ins (setSecondaryKey ins) ssk
rns = mkNoiseState rhk pat ResponderRole c h
rns' = maybe rns (setSecondaryKey rns) ssk
swap = not (pat == NoiseN || pat == NoiseK || pat == NoiseX)
name = maybe "Noise_" (const "NoisePSK_") psk <>
maybe "" (const "SSK_") ssk <>
show pat <>
"_" <>
show d <>
Expand All @@ -111,7 +115,7 @@ genVector payloads pat psk cType@(WrapCipherType c) dType@(WrapDHType d) hType@(
"_" <>
show h
allMsgs = (\(payload, ct) -> Message (Just payload) ct)
<$> genMessages swap ins rns payloads
<$> genMessages swap ins' rns' payloads

return
Vector { vName = name
Expand All @@ -122,13 +126,15 @@ genVector payloads pat psk cType@(WrapCipherType c) dType@(WrapDHType d) hType@(
, vFail = False
, viPrologue = hkPrologue ihk
, viPSK = hkPSK ihk
, viSSK = ssk
, viStatic = dhSecToBytes . fst <$> ins ^. nsHandshakeState . hsOpts . hoLocalStatic
, viSemiEphemeral = Nothing
, viEphemeral = dhSecToBytes . fst <$> ins ^. nsHandshakeState . hsOpts . hoLocalEphemeral
, virStatic = dhPubToBytes <$> ins ^. nsHandshakeState . hsOpts . hoRemoteStatic
, virSemiEphemeral = Nothing
, vrPrologue = hkPrologue rhk
, vrPSK = hkPSK rhk
, vrSSK = ssk
, vrStatic = dhSecToBytes . fst <$> rns ^. nsHandshakeState . hsOpts . hoLocalStatic
, vrSemiEphemeral = Nothing
, vrEphemeral = dhSecToBytes . fst <$> rns ^. nsHandshakeState . hsOpts . hoLocalEphemeral
Expand All @@ -142,11 +148,12 @@ genVectorFile :: FilePath
genVectorFile f = do
let payloads = ["Ludwig von Mises", "Murray Rothbard", "F. A. Hayek", "Carl Menger", "Jean-Baptiste Say", "Eugen Böhm von Bawerk"]
patterns = [NoiseNN, NoiseKN, NoiseNK, NoiseKK, NoiseNX, NoiseKX, NoiseXN, NoiseIN, NoiseXK, NoiseIK, NoiseXX, NoiseIX, NoiseN, NoiseK, NoiseX]
psks = [Nothing, Just "This is my Austrian perspective!"]
psks = [Nothing, Just "I don't mean to be subjective..."]
ssks = [Nothing, Just "This is my Austrian perspective!"]
ciphers = [WrapCipherType CTChaChaPoly1305, WrapCipherType CTAESGCM]
dhs = [WrapDHType DTCurve25519, WrapDHType DTCurve448]
hashes = [WrapHashType HTSHA256, WrapHashType HTSHA512, WrapHashType HTBLAKE2s, WrapHashType HTBLAKE2b]
vectors = [genVector payloads p psk c d h | p <- patterns, psk <- psks, c <- ciphers, d <- dhs, h <- hashes]
vectors = [genVector payloads p psk ssk c d h | p <- patterns, psk <- psks, ssk <- ssks, c <- ciphers, d <- dhs, h <- hashes]

vs <- mapConcurrently id vectors

Expand Down
6 changes: 6 additions & 0 deletions tests/vectors/VectorFile.hs
Expand Up @@ -39,13 +39,15 @@ data Vector =
, vFail :: Bool
, viPrologue :: ScrubbedBytes
, viPSK :: Maybe ScrubbedBytes
, viSSK :: Maybe ScrubbedBytes
, viStatic :: Maybe ScrubbedBytes
, viSemiEphemeral :: Maybe ScrubbedBytes
, viEphemeral :: Maybe ScrubbedBytes
, virStatic :: Maybe ScrubbedBytes
, virSemiEphemeral :: Maybe ScrubbedBytes
, vrPrologue :: ScrubbedBytes
, vrPSK :: Maybe ScrubbedBytes
, vrSSK :: Maybe ScrubbedBytes
, vrStatic :: Maybe ScrubbedBytes
, vrSemiEphemeral :: Maybe ScrubbedBytes
, vrEphemeral :: Maybe ScrubbedBytes
Expand All @@ -64,13 +66,15 @@ instance ToJSON Vector where
, "fail" .= vFail
, "init_prologue" .= encodeSB viPrologue
, "init_psk" .= (encodeSB <$> viPSK)
, "init_ssk" .= (encodeSB <$> viSSK)
, "init_static" .= (encodeSB <$> viStatic)
, "init_semiephemeral" .= (encodeSB <$> viSemiEphemeral)
, "init_ephemeral" .= (encodeSB <$> viEphemeral)
, "init_remote_static" .= (encodeSB <$> virStatic)
, "init_remote_semiephemeral" .= (encodeSB <$> virSemiEphemeral)
, "resp_prologue" .= encodeSB vrPrologue
, "resp_psk" .= (encodeSB <$> vrPSK)
, "resp_ssk" .= (encodeSB <$> vrSSK)
, "resp_static" .= (encodeSB <$> vrStatic)
, "resp_semiephemeral" .= (encodeSB <$> vrSemiEphemeral)
, "resp_ephemeral" .= (encodeSB <$> vrEphemeral)
Expand All @@ -93,13 +97,15 @@ instance FromJSON Vector where
<*> o .:? "fail" .!= False
<*> (decodeSB <$> o .: "init_prologue")
<*> (fmap decodeSB <$> o .:? "init_psk")
<*> (fmap decodeSB <$> o .:? "init_ssk")
<*> (fmap decodeSB <$> o .:? "init_static")
<*> (fmap decodeSB <$> o .:? "init_semiephemeral")
<*> (fmap decodeSB <$> o .:? "init_ephemeral")
<*> (fmap decodeSB <$> o .:? "init_remote_static")
<*> (fmap decodeSB <$> o .:? "init_remote_semiephemeral")
<*> (decodeSB <$> o .: "resp_prologue")
<*> (fmap decodeSB <$> o .:? "resp_psk")
<*> (fmap decodeSB <$> o .:? "resp_ssk")
<*> (fmap decodeSB <$> o .:? "resp_static")
<*> (fmap decodeSB <$> o .:? "resp_semiephemeral")
<*> (fmap decodeSB <$> o .:? "resp_ephemeral")
Expand Down
10 changes: 6 additions & 4 deletions tests/vectors/Verify.hs
Expand Up @@ -87,8 +87,10 @@ verifyVector v@Vector{..} =
(WrapCipherType c, WrapDHType d, WrapHashType h) ->
let swap = not $ vPattern == NoiseN || vPattern == NoiseK || vPattern == NoiseX
(io, ro) = mkHandshakeOpts v d
(ins, rns) = mkNoiseStates io ro c h in
go swap [] ins rns vMessages
(ins, rns) = mkNoiseStates io ro c h
ins' = maybe ins (setSecondaryKey ins) viSSK
rns' = maybe rns (setSecondaryKey rns) vrSSK in
go swap [] ins' rns' vMessages

where
stripState = join (***) (either Left (\(r, e, _) -> Right (r, e)))
Expand Down Expand Up @@ -134,8 +136,8 @@ verifyVectorFile f = do

allResults <- mapConcurrently (\v -> return (vName v, verifyVector v, vFail v)) $ vfVectors vf

let didItFail = all (== True) . fmap ((== (True, True)) . join (***) (either (const False) (uncurry (==))))
failures = filter (\(_, results, mustItFail) -> (didItFail results == mustItFail)) allResults
let didItFail = all (== True) . fmap ((== (True, True)) . join (***) (either (const False) (uncurry (==))))
failures = filter (\(_, results, mustItFail) -> (didItFail results == mustItFail)) allResults

if not (null failures) then do
putStrLn $ f <> ": The following vectors have failed:\n"
Expand Down
6 changes: 6 additions & 0 deletions tools/vector_template.jinja
Expand Up @@ -14,6 +14,9 @@
{%- if v.init_psk %}
"init_psk": {{ v.init_psk }},
{%- endif %}
{%- if v.init_ssk %}
"init_ssk": {{ v.init_ssk }},
{%- endif %}
{%- if v.init_static %}
"init_static": {{ v.init_static }},
{%- endif %}
Expand All @@ -33,6 +36,9 @@
{%- if v.resp_psk %}
"resp_psk": {{ v.resp_psk }},
{%- endif %}
{%- if v.resp_ssk %}
"resp_ssk": {{ v.resp_ssk }},
{%- endif %}
{%- if v.resp_static %}
"resp_static": {{ v.resp_static }},
{%- endif %}
Expand Down

0 comments on commit c6832ca

Please sign in to comment.