Skip to content

Commit

Permalink
Replace partial functions skFromBytes and vkFromBytes with total func…
Browse files Browse the repository at this point in the history
…tion versions
  • Loading branch information
newhoggy committed Oct 28, 2020
1 parent 0c2fd00 commit ea1647e
Showing 1 changed file with 30 additions and 47 deletions.
77 changes: 30 additions & 47 deletions cardano-crypto-praos/src/Cardano/Crypto/VRF/Praos.hs
@@ -1,12 +1,10 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
Expand Down Expand Up @@ -74,7 +72,7 @@ import Control.Monad (void)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Coerce (coerce)
import Data.Maybe (isJust)
import Data.Maybe (isJust, fromMaybe)
import Data.Proxy (Proxy (..))
import Foreign.C.Types
import Foreign.ForeignPtr
Expand Down Expand Up @@ -208,7 +206,7 @@ genSeed = do
copyFromByteString :: Ptr a -> ByteString -> Int -> IO ()
copyFromByteString ptr bs lenExpected =
BS.useAsCStringLen bs $ \(cstr, lenActual) ->
if (lenActual >= lenExpected) then
if lenActual >= lenExpected then
copyBytes (castPtr ptr) cstr lenExpected
else
error $ "Invalid input size, expected at least " <> show lenExpected <> ", but got " <> show lenActual
Expand Down Expand Up @@ -264,7 +262,7 @@ instance ToCBOR Proof where
encodedSizeExpr (\_ -> fromIntegral certSizeVRF) (Proxy :: Proxy ByteString)

instance FromCBOR Proof where
fromCBOR = proofFromBytes <$> fromCBOR
fromCBOR = fromCBOR >>= either fail return . proofFromBytes


instance Show SignKey where
Expand All @@ -279,8 +277,7 @@ instance ToCBOR SignKey where
encodedSizeExpr (\_ -> fromIntegral signKeySizeVRF) (Proxy :: Proxy ByteString)

instance FromCBOR SignKey where
fromCBOR = skFromBytes <$> fromCBOR

fromCBOR = fromCBOR >>= either fail return . skFromBytes

instance Show VerKey where
show = show . vkBytes
Expand All @@ -294,7 +291,7 @@ instance ToCBOR VerKey where
encodedSizeExpr (\_ -> fromIntegral verKeySizeVRF) (Proxy :: Proxy ByteString)

instance FromCBOR VerKey where
fromCBOR = vkFromBytes <$> fromCBOR
fromCBOR = fromCBOR >>= either fail return . vkFromBytes

-- | Allocate a Verification Key and attach a finalizer. The allocated memory will
-- not be initialized.
Expand All @@ -309,44 +306,37 @@ mkSignKey = fmap SignKey $ newForeignPtr finalizerFree =<< mallocBytes signKeySi
-- | Allocate a Proof and attach a finalizer. The allocated memory will
-- not be initialized.
mkProof :: IO Proof
mkProof = fmap Proof $ newForeignPtr finalizerFree =<< mallocBytes (certSizeVRF)

proofFromBytes :: ByteString -> Proof
proofFromBytes bs
| BS.length bs /= certSizeVRF
= error "Invalid proof length"
| otherwise
= unsafePerformIO $ do
mkProof = fmap Proof $ newForeignPtr finalizerFree =<< mallocBytes certSizeVRF

proofFromBytes :: ByteString -> Either String Proof
proofFromBytes bs = case BS.length bs of
bsLen -> if bsLen /= certSizeVRF
then Left $ "Invalid proof length " <> show @Int bsLen <> ", expecting " <> show @Int certSizeVRF
else Right . unsafePerformIO $ do
proof <- mkProof
withForeignPtr (unProof proof) $ \ptr ->
copyFromByteString ptr bs certSizeVRF
return proof

skFromBytes :: ByteString -> SignKey
skFromBytes bs
| bsLen /= signKeySizeVRF
= error ("Invalid sk length " <> show @Int bsLen <> ", expecting " <> show @Int signKeySizeVRF)
| otherwise
= unsafePerformIO $ do
skFromBytes :: BS.ByteString -> Either String SignKey
skFromBytes bs = case BS.length bs of
bsLen -> if bsLen /= signKeySizeVRF
then Left $ "Invalid sk length " <> show @Int bsLen <> ", expecting " <> show @Int signKeySizeVRF
else Right . unsafePerformIO $ do
sk <- mkSignKey
withForeignPtr (unSignKey sk) $ \ptr ->
copyFromByteString ptr bs signKeySizeVRF
return sk
where
bsLen = BS.length bs

vkFromBytes :: ByteString -> VerKey
vkFromBytes bs
| BS.length bs /= verKeySizeVRF
= error ("Invalid pk length " <> show @Int bsLen <> ", expecting " <> show @Int verKeySizeVRF)
| otherwise
= unsafePerformIO $ do

vkFromBytes :: BS.ByteString -> Either String VerKey
vkFromBytes bs = case BS.length bs of
bsLen -> if bsLen /= verKeySizeVRF
then Left $ "Invalid pk length " <> show @Int bsLen <> ", expecting " <> show @Int verKeySizeVRF
else return . unsafePerformIO $ do
pk <- mkVerKey
withForeignPtr (unVerKey pk) $ \ptr ->
copyFromByteString ptr bs verKeySizeVRF
return pk
where
bsLen = BS.length bs

-- | Allocate an Output and attach a finalizer. The allocated memory will
-- not be initialized.
Expand Down Expand Up @@ -450,8 +440,8 @@ instance VRFAlgorithm PraosVRF where

evalVRF = \_ msg (SignKeyPraosVRF sk) ->
let msgBS = getSignableRepresentation msg
proof = maybe (error "Invalid Key") id $ prove sk msgBS
output = maybe (error "Invalid Proof") id $ outputFromProof proof
proof = fromMaybe (error "Invalid Key") $ prove sk msgBS
output = fromMaybe (error "Invalid Proof") $ outputFromProof proof
in output `seq` proof `seq`
(OutputVRF (outputBytes output), CertPraosVRF proof)

Expand All @@ -469,17 +459,10 @@ instance VRFAlgorithm PraosVRF where
rawSerialiseVerKeyVRF (VerKeyPraosVRF pk) = vkBytes pk
rawSerialiseSignKeyVRF (SignKeyPraosVRF sk) = skBytes sk
rawSerialiseCertVRF (CertPraosVRF proof) = proofBytes proof
rawDeserialiseVerKeyVRF = fmap (VerKeyPraosVRF . vkFromBytes) . assertLength verKeySizeVRF
rawDeserialiseSignKeyVRF = fmap (SignKeyPraosVRF . skFromBytes) . assertLength signKeySizeVRF
rawDeserialiseCertVRF = fmap (CertPraosVRF . proofFromBytes) . assertLength certSizeVRF

rawDeserialiseVerKeyVRF = either (const Nothing) (Just . VerKeyPraosVRF) . vkFromBytes
rawDeserialiseSignKeyVRF = either (const Nothing) (Just . SignKeyPraosVRF) . skFromBytes
rawDeserialiseCertVRF = either (const Nothing) (Just . CertPraosVRF) . proofFromBytes
sizeVerKeyVRF _ = fromIntegral verKeySizeVRF
sizeSignKeyVRF _ = fromIntegral signKeySizeVRF
sizeCertVRF _ = fromIntegral certSizeVRF

assertLength :: Int -> ByteString -> Maybe ByteString
assertLength l bs
| BS.length bs == l
= Just bs
| otherwise
= Nothing

0 comments on commit ea1647e

Please sign in to comment.