Skip to content

Commit

Permalink
switch to ord map
Browse files Browse the repository at this point in the history
Use ordered maps for aeson processing as well as for the JWT
`unregisteredClaims` field (necessitating a major version bump).
This change makes *jose* secure from hash key collision DoS attacks.
  • Loading branch information
frasertweedale committed Oct 11, 2021
1 parent 0a2f59d commit ffc0cde
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 40 deletions.
3 changes: 3 additions & 0 deletions cabal.project
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
packages: .
constraints:
aeson +ordered-keymap
4 changes: 1 addition & 3 deletions jose.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ common common

build-depends:
base >= 4.8 && < 5
, aeson >= 0.11.1.0 && < 2.0
, aeson >= 2.0.1.0 && < 3
, bytestring == 0.10.*
, lens >= 4.16
, mtl >= 2
Expand Down Expand Up @@ -86,7 +86,6 @@ library
, memory >= 0.7
, monad-time >= 0.1
, template-haskell >= 2.4
, unordered-containers == 0.2.*
, time >= 1.5
, network-uri >= 2.6
, QuickCheck >= 2.9
Expand Down Expand Up @@ -122,7 +121,6 @@ test-suite tests
, cryptonite
, memory
, monad-time
, unordered-containers
, time
, network-uri
, x509
Expand Down
35 changes: 22 additions & 13 deletions src/Crypto/JOSE/Header.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ import Data.Proxy (Proxy(..))
import Control.Lens (Lens', Getter, review, to)
import Data.Aeson (FromJSON(..), Object, Value, encode, object)
import Data.Aeson.Types (Pair, Parser)
import qualified Data.Aeson.Key as Key
import qualified Data.Aeson.KeyMap as M
import qualified Data.ByteString.Lazy as L
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T

import qualified Crypto.JOSE.JWA.JWS as JWA.JWS
Expand Down Expand Up @@ -228,14 +229,16 @@ headerOptional
-> Maybe Object
-> Maybe Object
-> Parser (Maybe (HeaderParam p a))
headerOptional k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show k
headerOptional kText hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show kText
(Just v, Nothing) -> Just . HeaderParam getProtected <$> parseJSON v
(Nothing, Just v) -> maybe
(fail "unprotected header not supported")
(\p -> Just . HeaderParam p <$> parseJSON v)
getUnprotected
(Nothing, Nothing) -> pure Nothing
where
k = Key.fromText kText

-- | Parse an optional parameter that, if present, MUST be carried
-- in the protected header.
Expand All @@ -246,11 +249,13 @@ headerOptionalProtected
-> Maybe Object
-> Maybe Object
-> Parser (Maybe a)
headerOptionalProtected k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show k
(_, Just _) -> fail $ "header must be protected: " ++ show k
headerOptionalProtected kText hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show kText
(_, Just _) -> fail $ "header must be protected: " ++ show kText
(Just v, _) -> Just <$> parseJSON v
_ -> pure Nothing
where
k = Key.fromText kText

-- | Parse a required parameter that may be carried in either
-- the protected or the unprotected header.
Expand All @@ -261,14 +266,16 @@ headerRequired
-> Maybe Object
-> Maybe Object
-> Parser (HeaderParam p a)
headerRequired k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show k
headerRequired kText hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show kText
(Just v, Nothing) -> HeaderParam getProtected <$> parseJSON v
(Nothing, Just v) -> maybe
(fail "unprotected header not supported")
(\p -> HeaderParam p <$> parseJSON v)
getUnprotected
(Nothing, Nothing) -> fail $ "missing required header " ++ show k
where
k = Key.fromText kText

-- | Parse a required parameter that MUST be carried
-- in the protected header.
Expand All @@ -279,11 +286,13 @@ headerRequiredProtected
-> Maybe Object
-> Maybe Object
-> Parser a
headerRequiredProtected k hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show k
(_, Just _) -> fail $ "header must be protected: " <> show k
headerRequiredProtected kText hp hu = case (hp >>= M.lookup k, hu >>= M.lookup k) of
(Just _, Just _) -> fail $ "duplicate header " ++ show kText
(_, Just _) -> fail $ "header must be protected: " <> show kText
(Just v, _) -> parseJSON v
_ -> fail $ "missing required protected header: " <> show k
_ -> fail $ "missing required protected header: " <> show kText
where
k = Key.fromText kText


critObjectParser
Expand All @@ -292,7 +301,7 @@ critObjectParser
critObjectParser reserved exts o s
| s `elem` reserved = Fail.fail "crit key is reserved"
| s `notElem` exts = Fail.fail "crit key is not understood"
| not (s `M.member` o) = Fail.fail "crit key is not present in headers"
| not (Key.fromText s `M.member` o) = Fail.fail "crit key is not present in headers"
| otherwise = pure s

-- | Parse a "crit" header param
Expand Down
3 changes: 1 addition & 2 deletions src/Crypto/JOSE/JWA/JWE.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ module Crypto.JOSE.JWA.JWE where

import Data.Maybe (catMaybes)

import qualified Data.HashMap.Strict as M

import Crypto.JOSE.JWK
import Crypto.JOSE.TH
import Crypto.JOSE.Types
import Crypto.JOSE.Types.Internal (insertToObject)

import Data.Aeson
import qualified Data.Aeson.KeyMap as M


-- | RFC 7518 §4. Cryptographic Algorithms for Key Management
Expand Down
2 changes: 1 addition & 1 deletion src/Crypto/JOSE/JWA/JWK.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ import qualified Crypto.PubKey.Ed25519 as Ed25519
import qualified Crypto.PubKey.Curve25519 as Curve25519
import Crypto.Random
import Data.Aeson
import qualified Data.Aeson.KeyMap as M
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.HashMap.Strict as M
import Data.List.NonEmpty (NonEmpty)
import qualified Data.Text as T
import Data.X509 as X509
Expand Down
2 changes: 1 addition & 1 deletion src/Crypto/JOSE/JWS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ import Control.Lens.Cons.Extras (recons)
import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError, unless)
import Data.Aeson
import qualified Data.Aeson.KeyMap as M
import qualified Data.ByteString as B
import qualified Data.HashMap.Strict as M
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
Expand Down
4 changes: 2 additions & 2 deletions src/Crypto/JOSE/Types/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ import Control.Lens
import Control.Lens.Cons.Extras
import Crypto.Number.Basic (log2)
import Data.Aeson.Types
import qualified Data.Aeson.KeyMap as M
import qualified Data.ByteString as B
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Base64.URL as B64U
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as E

-- | Insert the given key and value into the given @Value@, which
-- is expected to be an @Object@. If the value is not an @Object@,
-- this is a no-op.
--
insertToObject :: ToJSON v => T.Text -> v -> Value -> Value
insertToObject :: ToJSON v => Key -> v -> Value -> Value
insertToObject k v (Object o) = Object $ M.insert k (toJSON v) o
insertToObject _ _ v = v

Expand Down
61 changes: 45 additions & 16 deletions src/Crypto/JWT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ import Data.Foldable (traverse_)
import Data.Functor.Identity
import Data.Maybe
import qualified Data.String
import Data.Semigroup ((<>))

import Control.Lens (
makeClassy, makeClassyPrisms, makePrisms,
Expand All @@ -132,7 +133,10 @@ import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (ReaderT, asks, runReaderT)
import Data.Aeson
import qualified Data.HashMap.Strict as M
import qualified Data.Aeson.Key as Key
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
Expand Down Expand Up @@ -251,7 +255,7 @@ data ClaimsSet = ClaimsSet
, _claimNbf :: Maybe NumericDate
, _claimIat :: Maybe NumericDate
, _claimJti :: Maybe T.Text
, _unregisteredClaims :: M.HashMap T.Text Value
, _unregisteredClaims :: M.Map T.Text Value
}
deriving (Eq, Show)

Expand Down Expand Up @@ -320,7 +324,7 @@ claimJti f h@ClaimsSet{ _claimJti = a} =
fmap (\a' -> h { _claimJti = a' }) (f a)

-- | Claim Names can be defined at will by those using JWTs.
unregisteredClaims :: Lens' ClaimsSet (M.HashMap T.Text Value)
unregisteredClaims :: Lens' ClaimsSet (M.Map T.Text Value)
unregisteredClaims f h@ClaimsSet{ _unregisteredClaims = a} =
fmap (\a' -> h { _unregisteredClaims = a' }) (f a)

Expand All @@ -333,9 +337,30 @@ emptyClaimsSet = ClaimsSet n n n n n n n M.empty where n = Nothing
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim k v = over unregisteredClaims (M.insert k v)

filterUnregistered :: M.HashMap T.Text Value -> M.HashMap T.Text Value
filterUnregistered = M.filterWithKey (\k _ -> k `notElem` registered) where
registered = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
registeredClaims :: S.Set T.Text
registeredClaims = S.fromDistinctAscList
[ "aud"
, "exp"
, "iat"
, "iss"
, "jti"
, "nbf"
, "sub"
]

filterUnregistered :: M.Map T.Text Value -> M.Map T.Text Value
filterUnregistered m =
#if MIN_VERSION_containers(0,5,8)
m `M.withoutKeys` registeredClaims
#else
m `M.difference` M.fromSet (const ()) registeredClaims
#endif

toKeyMap :: M.Map T.Text Value -> KeyMap.KeyMap Value
toKeyMap = KeyMap.fromMap . M.mapKeysMonotonic Key.fromText

fromKeyMap :: KeyMap.KeyMap Value -> M.Map T.Text Value
fromKeyMap = M.mapKeysMonotonic Key.toText . KeyMap.toMap

instance FromJSON ClaimsSet where
parseJSON = withObject "JWT Claims Set" (\o -> ClaimsSet
Expand All @@ -346,18 +371,22 @@ instance FromJSON ClaimsSet where
<*> o .:? "nbf"
<*> o .:? "iat"
<*> o .:? "jti"
<*> pure (filterUnregistered o))
<*> pure (filterUnregistered . fromKeyMap $ o)
)

instance ToJSON ClaimsSet where
toJSON (ClaimsSet iss sub aud exp' nbf iat jti o) = object $ catMaybes [
fmap ("iss" .=) iss
, fmap ("sub" .=) sub
, fmap ("aud" .=) aud
, fmap ("exp" .=) exp'
, fmap ("nbf" .=) nbf
, fmap ("iat" .=) iat
, fmap ("jti" .=) jti
] ++ M.toList (filterUnregistered o)
toJSON (ClaimsSet iss sub aud exp' nbf iat jti o) = Object $
( KeyMap.fromMap . M.fromDistinctAscList $ catMaybes
[ fmap ("aud" .=) aud
, fmap ("exp" .=) exp'
, fmap ("iat" .=) iat
, fmap ("iss" .=) iss
, fmap ("jti" .=) jti
, fmap ("nbf" .=) nbf
, fmap ("sub" .=) sub
]
)
<> toKeyMap (filterUnregistered o)


data JWTValidationSettings = JWTValidationSettings
Expand Down
2 changes: 0 additions & 2 deletions test/JWT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import Control.Monad.State (execState)
import Control.Monad.Time (MonadTime(..))
import Data.Aeson hiding ((.=))
import Data.Functor.Identity (runIdentity)
import Data.HashMap.Strict (insert)
import qualified Data.Set as S
import Data.Time
import Network.URI (parseURI)
Expand All @@ -49,7 +48,6 @@ exampleClaimsSet :: ClaimsSet
exampleClaimsSet = emptyClaimsSet
& claimIss .~ preview stringOrUri ("joe" :: String)
& claimExp .~ intDate "2011-03-22 18:43:00"
& over unregisteredClaims (insert "http://example.com/is_root" (Bool True))
& addClaim "http://example.com/is_root" (Bool True)

#if ! MIN_VERSION_monad_time(0,3,0)
Expand Down

0 comments on commit ffc0cde

Please sign in to comment.