Skip to content

Commit

Permalink
add domain compress
Browse files Browse the repository at this point in the history
  • Loading branch information
yihuang committed Oct 23, 2011
1 parent f16d70a commit 89d6ab5
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 31 deletions.
32 changes: 20 additions & 12 deletions Network/DNS/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
module Network.DNS.Query (composeQuery, composeDNSFormat) where

import qualified Data.ByteString.Lazy.Char8 as BL (ByteString)
import qualified Data.ByteString as BS (unpack)
import qualified Data.ByteString.Char8 as BS (length, split, null)
import Blaze.ByteString.Builder.ByteString (writeByteString)
import qualified Data.ByteString.Char8 as BS (length, null, break, drop)
import Network.DNS.StateBinary
import Network.DNS.Internal
import Data.Monoid
Expand Down Expand Up @@ -109,7 +107,7 @@ encodeRDATA rd = case rd of
(RD_CNAME dom) -> encodeDomain dom
(RD_PTR dom) -> encodeDomain dom
(RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom]
(RD_TXT txt) -> writeByteString txt
(RD_TXT txt) -> putByteString txt
(RD_OTH bytes) -> mconcat $ map putInt8 bytes
(RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $
[ encodeDomain d1
Expand All @@ -130,13 +128,23 @@ encodeRDATA rd = case rd of
----------------------------------------------------------------

encodeDomain :: Domain -> SPut
encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss
encodeDomain dom | BS.null dom = put8 0
encodeDomain dom = do
mpos <- wsPop dom
cur <- gets wsPosition
case mpos of
Just pos -> encodePointer pos
Nothing -> wsPush dom cur >>
mconcat [ encodePartialDomain hd
, encodeDomain tl
]
where
ss = filter (not . BS.null) $ BS.split '.' dom
ls = map BS.length ss
(hd, tl') = BS.break (=='.') dom
tl = if BS.null tl' then tl' else BS.drop 1 tl'

encodeSubDomain :: (Int, Domain) -> SPut
encodeSubDomain (len,sub) = putInt8 len
+++ foldr ((+++) . put8) mempty ss
where
ss = BS.unpack sub
encodePointer :: Int -> SPut
encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w

encodePartialDomain :: Domain -> SPut
encodePartialDomain sub = putInt8 (BS.length sub)
+++ putByteString sub
60 changes: 51 additions & 9 deletions Network/DNS/StateBinary.hs
Original file line number Diff line number Diff line change
@@ -1,43 +1,85 @@
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Network.DNS.StateBinary where

import Blaze.ByteString.Builder
import Control.Applicative
import Control.Monad.State
import Data.Monoid
import Data.Attoparsec
import Data.Attoparsec.Enumerator
import qualified Data.Attoparsec.Lazy as AL
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (unpack)
import qualified Data.ByteString as BS (unpack, length)
import qualified Data.ByteString.Lazy as BL (ByteString)
import Data.Enumerator (Iteratee)
import Data.Int
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM (insert, lookup, empty)
import Data.Map (Map)
import qualified Data.Map as M (insert, lookup, empty)
import Data.Word
import Network.DNS.Types
import Prelude hiding (lookup, take)

----------------------------------------------------------------

type SPut = Write
type SPut = State WState Write

data WState = WState {
wsDomain :: Map Domain Int
, wsPosition :: Int
}

initialWState :: WState
initialWState = WState M.empty 0

instance Monoid SPut where
mempty = return mempty
mappend a b = mconcat <$> sequence [a, b]

put8 :: Word8 -> SPut
put8 = writeWord8
put8 = fixedSized 1 writeWord8

put16 :: Word16 -> SPut
put16 = writeWord16be
put16 = fixedSized 2 writeWord16be

put32 :: Word32 -> SPut
put32 = writeWord32be
put32 = fixedSized 4 writeWord32be

putInt8 :: Int -> SPut
putInt8 = writeInt8 . fromIntegral
putInt8 = fixedSized 1 (writeInt8 . fromIntegral)

putInt16 :: Int -> SPut
putInt16 = writeInt16be . fromIntegral
putInt16 = fixedSized 2 (writeInt16be . fromIntegral)

putInt32 :: Int -> SPut
putInt32 = writeInt32be . fromIntegral
putInt32 = fixedSized 4 (writeInt32be . fromIntegral)

putByteString :: ByteString -> SPut
putByteString = writeSized BS.length writeByteString

addPositionW :: Int -> State WState ()
addPositionW n = do
(WState m cur) <- get
put $ WState m (cur+n)

fixedSized :: Int -> (a -> Write) -> a -> SPut
fixedSized n f a = do addPositionW n
return (f a)

writeSized :: Show a => (a -> Int) -> (a -> Write) -> a -> SPut
writeSized n f a = do addPositionW (n a)
return (f a)

wsPop :: Domain -> State WState (Maybe Int)
wsPop dom = do
doms <- gets wsDomain
return $ M.lookup dom doms

wsPush :: Domain -> Int -> State WState ()
wsPush dom pos = do
(WState m cur) <- get
put $ WState (M.insert dom pos m) cur

----------------------------------------------------------------

Expand Down Expand Up @@ -122,4 +164,4 @@ runSGet :: SGet a -> BL.ByteString -> Either String (a, PState)
runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs

runSPut :: SPut -> BL.ByteString
runSPut = toLazyByteString . fromWrite
runSPut = toLazyByteString . fromWrite . flip evalState initialWState
20 changes: 10 additions & 10 deletions TestProtocol.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,35 @@ import Test.HUnit hiding (Test)
tests :: [Test]
tests =
[ testGroup "Test case"
[ testCase "QueryA" (test_Format queryA)
, testCase "QueryAAAA" (test_Format queryAAAA)
, testCase "ResponseA" (test_Format responseA)
[ testCase "QueryA" (test_Format testQueryA)
, testCase "QueryAAAA" (test_Format testQueryAAAA)
, testCase "ResponseA" (test_Format $ testResponseA)
]
]

defaultHeader :: DNSHeader
defaultHeader = header defaultQuery

queryA :: DNSFormat
queryA = defaultQuery
testQueryA :: DNSFormat
testQueryA = defaultQuery
{ header = defaultHeader
{ identifier = 1000
, qdCount = 1
}
, question = [makeQuestion "www.mew.org." A]
}

queryAAAA :: DNSFormat
queryAAAA = defaultQuery
testQueryAAAA :: DNSFormat
testQueryAAAA = defaultQuery
{ header = defaultHeader
{ identifier = 1000
, qdCount = 1
}
, question = [makeQuestion "www.mew.org." AAAA]
}

responseA :: DNSFormat
responseA = DNSFormat { header = DNSHeader { identifier = 61046
testResponseA :: DNSFormat
testResponseA = DNSFormat { header = DNSHeader { identifier = 61046
, flags = DNSFlags { qOrR = QR_Response
, opcode = OP_STD
, authAnswer = False
Expand Down Expand Up @@ -157,7 +157,7 @@ test_Format fmt = do
assertEqual "fail" fmt fmt'
where
bs = composeDNSFormat fmt
result = runResponse_ bs
result = runDNSFormat_ bs

main :: IO ()
main = defaultMain tests

0 comments on commit 89d6ab5

Please sign in to comment.