Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to strict ByteString #59

Merged
merged 1 commit into from Jul 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 11 additions & 10 deletions Network/DNS/Decode.hs
Expand Up @@ -16,10 +16,11 @@ import Control.Monad (replicateM)
import Control.Monad.Trans.Resource (ResourceT, runResourceT)
import qualified Control.Exception as ControlException
import Data.Bits ((.&.), shiftR, testBit)
import Data.Char (ord)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Conduit (($$), ($$+), ($$+-), (=$), Source)
import Data.Conduit.Network (sourceSocket)
import qualified Data.Conduit.Binary as CB
Expand Down Expand Up @@ -58,7 +59,7 @@ receive = receiveDNSFormat . sourceSocket
receiveVC :: Socket -> IO DNSMessage
receiveVC sock = runResourceT $ do
(src, lenbytes) <- sourceSocket sock $$+ CB.take 2
let len = case map fromIntegral $ BL.unpack lenbytes of
let len = case map ord $ LBS.unpack lenbytes of
[hi, lo] -> 256 * hi + lo
_ -> 0
fmap fst (src $$+- CB.isolate len =$ sinkSGet getResponse)
Expand All @@ -67,33 +68,33 @@ receiveVC sock = runResourceT $ do

-- | Parsing DNS data.

decode :: BL.ByteString -> Either String DNSMessage
decode :: ByteString -> Either String DNSMessage
decode bs = fst <$> runSGet getResponse bs

-- | Parse many length-encoded DNS records, for example, from TCP traffic.

decodeMany :: BL.ByteString -> Either String ([DNSMessage], BL.ByteString)
decodeMany :: ByteString -> Either String ([DNSMessage], ByteString)
decodeMany bs = do
((bss, _), leftovers) <- runSGetWithLeftovers lengthEncoded bs
msgs <- mapM decode bss
return (msgs, leftovers)
where
-- Read a list of length-encoded lazy bytestrings
lengthEncoded :: SGet [BL.ByteString]
lengthEncoded :: SGet [ByteString]
lengthEncoded = many $ do
len <- getInt16
fmap BL.fromStrict (getNByteString len)
getNByteString len

decodeDNSFlags :: BL.ByteString -> Either String DNSFlags
decodeDNSFlags :: ByteString -> Either String DNSFlags
decodeDNSFlags bs = fst <$> runSGet getDNSFlags bs

decodeDNSHeader :: BL.ByteString -> Either String DNSHeader
decodeDNSHeader :: ByteString -> Either String DNSHeader
decodeDNSHeader bs = fst <$> runSGet getHeader bs

decodeDomain :: BL.ByteString -> Either String Domain
decodeDomain :: ByteString -> Either String Domain
decodeDomain bs = fst <$> runSGet getDomain bs

decodeResourceRecord :: BL.ByteString -> Either String ResourceRecord
decodeResourceRecord :: ByteString -> Either String ResourceRecord
decodeResourceRecord bs = fst <$> runSGet getResourceRecord bs

----------------------------------------------------------------
Expand Down
8 changes: 4 additions & 4 deletions Network/DNS/Encode.hs
Expand Up @@ -17,8 +17,8 @@ import Data.Binary (Word16)
import Data.Bits ((.|.), bit, shiftL, setBit)
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import Data.ByteString.Lazy.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.ByteString.Char8 (ByteString)
import Data.IP (IP(..),fromIPv4, fromIPv6b)
import Data.List (dropWhileEnd)
import Data.Monoid ((<>))
Expand Down Expand Up @@ -68,7 +68,7 @@ encode = runSPut . putDNSMessage

encodeVC :: ByteString -> ByteString
encodeVC query =
let len = BB.toLazyByteString $ BB.int16BE $ fromIntegral $ BL.length query
let len = LBS.toStrict . BB.toLazyByteString $ BB.int16BE $ fromIntegral $ BS.length query
in len <> query

encodeDNSFlags :: DNSFlags -> ByteString
Expand Down Expand Up @@ -162,7 +162,7 @@ putResourceRecord rr =
putResourceRData rd = do
addPositionW 2 -- "simulate" putInt16
rDataBuilder <- putRData rd
let rdataLength = fromIntegral . BL.length . BB.toLazyByteString $ rDataBuilder
let rdataLength = fromIntegral . LBS.length . BB.toLazyByteString $ rDataBuilder
let rlenBuilder = BB.int16BE rdataLength
return $ rlenBuilder <> rDataBuilder

Expand Down
3 changes: 1 addition & 2 deletions Network/DNS/Internal.hs
Expand Up @@ -4,7 +4,6 @@ module Network.DNS.Internal where

import Control.Exception (Exception)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import Data.IP (IP, IPv4, IPv6)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
Expand All @@ -16,7 +15,7 @@ import Data.Word (Word8, Word16, Word32)
type Domain = ByteString

-- | Return type of composeQuery from Encode, needed in Resolver
type Query = L.ByteString
type Query = ByteString

----------------------------------------------------------------

Expand Down
6 changes: 3 additions & 3 deletions Network/DNS/Resolver.hs
Expand Up @@ -43,10 +43,10 @@ import Control.Applicative ((<$>), (<*>), pure)

#if mingw32_HOST_OS == 1
import Network.Socket (send)
import qualified Data.ByteString.Lazy.Char8 as LB
import qualified Data.ByteString.Char8 as BS
import Control.Monad (when)
#else
import Network.Socket.ByteString.Lazy (sendAll)
import Network.Socket.ByteString (sendAll)
#endif

----------------------------------------------------------------
Expand Down Expand Up @@ -441,7 +441,7 @@ tcpLookup query peer tm (Just vc) = do
Just res -> return $ Right res

#if mingw32_HOST_OS == 1
-- Windows does not support sendAll in Network.ByteString.Lazy.
-- Windows does not support sendAll in Network.ByteString.
-- This implements sendAll with Haskell Strings.
sendAll sock bs = do
sent <- send sock (LB.unpack bs)
Expand Down
23 changes: 12 additions & 11 deletions Network/DNS/StateBinary.hs
Expand Up @@ -5,13 +5,13 @@ import Control.Monad.State (State, StateT)
import qualified Control.Monad.State as ST
import Control.Monad.Trans.Resource (ResourceT)
import qualified Data.Attoparsec.ByteString as A
import qualified Data.Attoparsec.ByteString.Lazy as AL

import qualified Data.Attoparsec.Types as T
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Conduit (Sink)
import Data.Conduit.Attoparsec (sinkParser)
import Data.IntMap (IntMap)
Expand Down Expand Up @@ -165,15 +165,16 @@ initialState = PState IM.empty 0
sinkSGet :: SGet a -> Sink ByteString (ResourceT IO) (a, PState)
sinkSGet parser = sinkParser (ST.runStateT parser initialState)

runSGet :: SGet a -> BL.ByteString -> Either String (a, PState)
runSGet parser bs = AL.eitherResult $ AL.parse (ST.runStateT parser initialState) bs
runSGet :: SGet a -> ByteString -> Either String (a, PState)
runSGet parser bs = A.eitherResult $ A.parse (ST.runStateT parser initialState) bs

runSGetWithLeftovers :: SGet a -> BL.ByteString -> Either String ((a, PState), BL.ByteString)
runSGetWithLeftovers parser bs = toResult $ AL.parse (ST.runStateT parser initialState) bs
runSGetWithLeftovers :: SGet a -> ByteString -> Either String ((a, PState), ByteString)
runSGetWithLeftovers parser bs = toResult $ A.parse (ST.runStateT parser initialState) bs
where
toResult :: AL.Result r -> Either String (r, BL.ByteString)
toResult (AL.Done i r) = Right (r, i)
toResult (AL.Fail _ _ err) = Left err
toResult :: A.Result r -> Either String (r, ByteString)
toResult (A.Done i r) = Right (r, i)
toResult (A.Partial f) = toResult $ f BS.empty
toResult (A.Fail _ _ err) = Left err

runSPut :: SPut -> BL.ByteString
runSPut = BB.toLazyByteString . flip ST.evalState initialWState
runSPut :: SPut -> ByteString
runSPut = LBS.toStrict . BB.toLazyByteString . flip ST.evalState initialWState
22 changes: 7 additions & 15 deletions test/DecodeSpec.hs
Expand Up @@ -6,7 +6,6 @@ import Data.ByteString.Internal (ByteString(..), unsafeCreate)
#if !MIN_VERSION_bytestring(0,10,0)
import qualified Data.ByteString as BS
#endif
import qualified Data.ByteString.Lazy as BL
import Data.Word8
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (plusPtr)
Expand All @@ -16,23 +15,23 @@ import Test.Hspec

----------------------------------------------------------------

test_doublePointer :: BL.ByteString
test_doublePointer :: ByteString
test_doublePointer = "f7eb8500000100010007000404736563330561706e696303636f6d0000010001c00c0001000100001c200004ca0c1c8cc0110002000100001c20000f036e73310561706e6963036e657400c0300002000100001c200006036e7333c040c0300002000100001c200006036e7334c040c0300002000100001c20001004736563310561706e696303636f6d00c0300002000100001c20001704736563310761757468646e730472697065036e657400c0300002000100001c20001004736563320561706e696303636f6d00c0300002000100001c2000070473656333c0bfc07b0001000100001c200004ca0c1d3bc07b001c000100001c20001020010dc02001000a4608000000000059c0ba0001000100001c200004ca0c1d3cc0d6001c000100001c20001020010dc0000100004777000000000140"
-- DNSMessage {header = DNSHeader {identifier = 63467, flags = DNSFlags {qOrR = QR_Response, opcode = OP_STD, authAnswer = True, trunCation = False, recDesired = True, recAvailable = False, rcode = NoErr, authenData = False}}, question = [Question {qname = "sec3.apnic.com.", qtype = A}], answer = [ResourceRecord {rrname = "sec3.apnic.com.", rrtype = A, rrttl = 7200, rdata = 202.12.28.140}], authority = [ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = ns1.apnic.net.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = ns3.apnic.net.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = ns4.apnic.net.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = sec1.apnic.com.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = sec1.authdns.ripe.net.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = sec2.apnic.com.},ResourceRecord {rrname = "apnic.com.", rrtype = NS, rrttl = 7200, rdata = sec3.apnic.com.}], additional = [ResourceRecord {rrname = "sec1.apnic.com.", rrtype = A, rrttl = 7200, rdata = 202.12.29.59},ResourceRecord {rrname = "sec1.apnic.com.", rrtype = AAAA, rrttl = 7200, rdata = 2001:dc0:2001:a:4608::59},ResourceRecord {rrname = "sec2.apnic.com.", rrtype = A, rrttl = 7200, rdata = 202.12.29.60},ResourceRecord {rrname = "sec3.apnic.com.", rrtype = AAAA, rrttl = 7200, rdata = 2001:dc0:1:0:4777::140}]})

test_txt :: BL.ByteString
test_txt :: ByteString
test_txt = "463181800001000100000000076e69636f6c6173046b766462076e647072696d6102696f0000100001c00c0010000100000e10000c6e69636f6c61732e6b766462"
-- DNSMessage {header = DNSHeader {identifier = 17969, flags = DNSFlags {qOrR = QR_Response, opcode = OP_STD, authAnswer = False, trunCation = False, recDesired = True, recAvailable = True, rcode = NoErr, authenData = False}}
-- , question = [Question {qname = "nicolas.kvdb.ndprima.io.", qtype = TXT}]
-- , answer = [ResourceRecord {rrname = "nicolas.kvdb.ndprima.io.", rrtype = TXT, rrttl = 3600, rdata = icolas.kvdb}]
-- , authority = []
-- , additional = []})

test_dname :: BL.ByteString
test_dname :: ByteString
test_dname = "b3c0818000010005000200010377777706376b616e616c02636f02696c0000010001c0100027000100000003000c0769737261656c3702727500c00c0005000100000003000603777777c02ec046000500010000255b0002c02ec02e000100010000003d000451daf938c02e000100010000003d0004c33ce84ac02e000200010005412b000c036e7332026137036f726700c02e000200010005412b0006036e7331c08a0000291000000000000000"
-- DNSMessage {header = DNSHeader {identifier = 46016, flags = DNSFlags {qOrR = QR_Response, opcode = OP_STD, authAnswer = False, trunCation = False, recDesired = True, recAvailable = True, rcode = NoErr, authenData = False}}, question = [Question {qname = "www.7kanal.co.il.", qtype = A}], answer = [ResourceRecord {rrname = "7kanal.co.il.", rrtype = DNAME, rrttl = 3, rdata = israel7.ru.},ResourceRecord {rrname = "www.7kanal.co.il.", rrtype = CNAME, rrttl = 3, rdata = www.israel7.ru.},ResourceRecord {rrname = "www.israel7.ru.", rrtype = CNAME, rrttl = 9563, rdata = israel7.ru.},ResourceRecord {rrname = "israel7.ru.", rrtype = A, rrttl = 61, rdata = 81.218.249.56},ResourceRecord {rrname = "israel7.ru.", rrtype = A, rrttl = 61, rdata = 195.60.232.74}], authority = [ResourceRecord {rrname = "israel7.ru.", rrtype = NS, rrttl = 344363, rdata = ns2.a7.org.},ResourceRecord {rrname = "israel7.ru.", rrtype = NS, rrttl = 344363, rdata = ns1.a7.org.}], additional = [OptRecord {orudpsize = 4096, ordnssecok = False, orversion = 0, rdata = []}]})

test_mx :: BL.ByteString
test_mx :: ByteString
test_mx = "f03681800001000100000001036d6577036f726700000f0001c00c000f000100000df10009000a046d61696cc00c0000291000000000000000"
-- DNSMessage {header = DNSHeader {identifier = 61494, flags = DNSFlags {qOrR = QR_Response, opcode = OP_STD, authAnswer = False, trunCation = False, recDesired = True, recAvailable = True, rcode = NoErr, authenData = False}}
-- , question = [Question {qname = "mew.org.", qtype = MX}]
Expand All @@ -55,7 +54,7 @@ spec = do
tripleDecodeTest test_mx


tripleDecodeTest :: BL.ByteString -> IO ()
tripleDecodeTest :: ByteString -> IO ()
tripleDecodeTest hexbs =
ecase (decode $ fromHexString hexbs) fail $ \ x1 ->
ecase (decode $ encode x1) fail $ \ x2 ->
Expand All @@ -68,15 +67,8 @@ ecase (Right b) _ g = g b

----------------------------------------------------------------

fromHexString :: BL.ByteString -> BL.ByteString
#if MIN_VERSION_bytestring(0,10,0)
fromHexString = BL.fromStrict . fromHexString' . BL.toStrict
#else
fromHexString = BL.pack . BS.unpack . fromHexString' . BS.pack . BL.unpack
#endif

fromHexString' :: ByteString -> ByteString
fromHexString' (PS fptr off len) = unsafeCreate size $ \dst ->
fromHexString :: ByteString -> ByteString
fromHexString (PS fptr off len) = unsafeCreate size $ \dst ->
withForeignPtr fptr $ \src -> go (src `plusPtr` off) dst 0
where
size = len `div` 2
Expand Down