Skip to content

Commit

Permalink
Merge pull request #3 from yihuang/master
Browse files Browse the repository at this point in the history
Make this library full capable of decoding and encoding.
  • Loading branch information
kazu-yamamoto committed Oct 24, 2011
2 parents 1cc2f6a + 89d6ab5 commit 04c4afa
Show file tree
Hide file tree
Showing 9 changed files with 494 additions and 50 deletions.
46 changes: 45 additions & 1 deletion Network/DNS/Internal.hs
Expand Up @@ -151,7 +151,15 @@ defaultQuery :: DNSFormat
defaultQuery = DNSFormat {
header = DNSHeader {
identifier = 0
, flags = undefined
, flags = DNSFlags {
qOrR = QR_Query
, opcode = OP_STD
, authAnswer = False
, trunCation = False
, recDesired = True
, recAvailable = False
, rcode = NoErr
}
, qdCount = 0
, anCount = 0
, nsCount = 0
Expand All @@ -162,3 +170,39 @@ defaultQuery = DNSFormat {
, authority = []
, additional = []
}

defaultResponse :: DNSFormat
defaultResponse =
let hd = header defaultQuery
flg = flags hd
in defaultQuery {
header = hd {
flags = flg {
qOrR = QR_Response
, authAnswer = True
, recAvailable = True
}
}
}

responseA :: Int -> Question -> IPv4 -> DNSFormat
responseA ident q ip =
let hd = header defaultResponse
dom = qname q
an = ResourceRecord dom A 300 4 (RD_A ip)
in defaultResponse {
header = hd { identifier=ident, qdCount = 1, anCount = 1 }
, question = [q]
, answer = [an]
}

responseAAAA :: Int -> Question -> IPv6 -> DNSFormat
responseAAAA ident q ip =
let hd = header defaultResponse
dom = qname q
an = ResourceRecord dom AAAA 300 16 (RD_AAAA ip)
in defaultResponse {
header = hd { identifier=ident, qdCount = 1, anCount = 1 }
, question = [q]
, answer = [an]
}
124 changes: 100 additions & 24 deletions Network/DNS/Query.hs
@@ -1,19 +1,26 @@
module Network.DNS.Query (composeQuery) where
{-# LANGUAGE RecordWildCards #-}
module Network.DNS.Query (composeQuery, composeDNSFormat) where

import qualified Data.ByteString.Lazy.Char8 as BL (ByteString)
import qualified Data.ByteString as BS (unpack)
import qualified Data.ByteString.Char8 as BS (length, split, null)
import qualified Data.ByteString.Char8 as BS (length, null, break, drop)
import Network.DNS.StateBinary
import Network.DNS.Internal
import Data.Monoid
import Control.Monad.State
import Data.Bits
import Data.Word
import Data.IP

(+++) :: Monoid a => a -> a -> a
(+++) = mappend

----------------------------------------------------------------

composeDNSFormat :: DNSFormat -> BL.ByteString
composeDNSFormat fmt = runSPut (encodeDNSFormat fmt)

composeQuery :: Int -> [Question] -> BL.ByteString
composeQuery idt qs = runSPut (encodeQuery qry)
composeQuery idt qs = composeDNSFormat qry
where
hdr = header defaultQuery
qry = defaultQuery {
Expand All @@ -26,12 +33,18 @@ composeQuery idt qs = runSPut (encodeQuery qry)

----------------------------------------------------------------

encodeQuery :: DNSFormat -> SPut
encodeQuery fmt = encodeHeader hdr
+++ encodeQuestion qs
encodeDNSFormat :: DNSFormat -> SPut
encodeDNSFormat fmt = encodeHeader hdr
+++ mconcat (map encodeQuestion qs)
+++ mconcat (map encodeRR an)
+++ mconcat (map encodeRR au)
+++ mconcat (map encodeRR ad)
where
hdr = header fmt
qs = question fmt
an = answer fmt
au = authority fmt
ad = additional fmt

encodeHeader :: DNSHeader -> SPut
encodeHeader hdr = encodeIdentifier (identifier hdr)
Expand All @@ -48,27 +61,90 @@ encodeHeader hdr = encodeIdentifier (identifier hdr)
decodeArCount = putInt16

encodeFlags :: DNSFlags -> SPut
encodeFlags _ = put16 0x0100 -- xxx

encodeQuestion :: [Question] -> SPut
encodeQuestion qs = encodeDomain dom
+++ putInt16 (typeToInt typ)
+++ put16 1
encodeFlags DNSFlags{..} = put16 word
where
q = head qs
dom = qname q
typ = qtype q
word16 :: Enum a => a -> Word16
word16 = toEnum . fromEnum

set :: Word16 -> State Word16 ()
set byte = modify (.|. byte)

st :: State Word16 ()
st = sequence_
[ set (word16 rcode)
, when recAvailable $ set (bit 7)
, when recDesired $ set (bit 8)
, when trunCation $ set (bit 9)
, when authAnswer $ set (bit 10)
, set (word16 opcode `shiftL` 11)
, when (qOrR==QR_Response) $ set (bit 15)
]

word = execState st 0

encodeQuestion :: Question -> SPut
encodeQuestion Question{..} =
encodeDomain qname
+++ putInt16 (typeToInt qtype)
+++ put16 1

encodeRR :: ResourceRecord -> SPut
encodeRR ResourceRecord{..} =
mconcat
[ encodeDomain rrname
, putInt16 (typeToInt rrtype)
, put16 1
, putInt32 rrttl
, putInt16 rdlen
, encodeRDATA rdata
]

encodeRDATA :: RDATA -> SPut
encodeRDATA rd = case rd of
(RD_A ip) -> mconcat $ map putInt8 (fromIPv4 ip)
(RD_AAAA ip) -> mconcat $ map putInt16 (fromIPv6 ip)
(RD_NS dom) -> encodeDomain dom
(RD_CNAME dom) -> encodeDomain dom
(RD_PTR dom) -> encodeDomain dom
(RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom]
(RD_TXT txt) -> putByteString txt
(RD_OTH bytes) -> mconcat $ map putInt8 bytes
(RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $
[ encodeDomain d1
, encodeDomain d2
, putInt32 serial
, putInt32 refresh
, putInt32 retry
, putInt32 expire
, putInt32 min'
]
(RD_SRV prio weight port dom) -> mconcat $
[ putInt16 prio
, putInt16 weight
, putInt16 port
, encodeDomain dom
]

----------------------------------------------------------------

encodeDomain :: Domain -> SPut
encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss
encodeDomain dom | BS.null dom = put8 0
encodeDomain dom = do
mpos <- wsPop dom
cur <- gets wsPosition
case mpos of
Just pos -> encodePointer pos
Nothing -> wsPush dom cur >>
mconcat [ encodePartialDomain hd
, encodeDomain tl
]
where
ss = filter (not . BS.null) $ BS.split '.' dom
ls = map BS.length ss
(hd, tl') = BS.break (=='.') dom
tl = if BS.null tl' then tl' else BS.drop 1 tl'

encodeSubDomain :: (Int, Domain) -> SPut
encodeSubDomain (len,sub) = putInt8 len
+++ foldr ((+++) . put8) mempty ss
where
ss = BS.unpack sub
encodePointer :: Int -> SPut
encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w

encodePartialDomain :: Domain -> SPut
encodePartialDomain sub = putInt8 (BS.length sub)
+++ putByteString sub
4 changes: 2 additions & 2 deletions Network/DNS/Resolver.hs
Expand Up @@ -21,7 +21,7 @@ module Network.DNS.Resolver (
-- ** Intermediate data type for resolver
, ResolvSeed, makeResolvSeed
-- ** Type and function for resolver
, Resolver, withResolver
, Resolver(..), withResolver
-- ** Looking up functions
, lookup, lookupRaw
) where
Expand Down Expand Up @@ -128,7 +128,7 @@ makeAddrInfo addr = do
argument. 'withResolver' should be passed to 'forkIO'.
-}

withResolver :: ResolvSeed -> (Resolver -> IO ()) -> IO ()
withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a
withResolver seed func = do
let ai = addrInfo seed
sock <- socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai)
Expand Down
20 changes: 14 additions & 6 deletions Network/DNS/Response.hs
@@ -1,6 +1,6 @@
{-# LANGUAGE OverloadedStrings #-}

module Network.DNS.Response (responseIter, parseResponse) where
module Network.DNS.Response (responseIter, parseResponse, runDNSFormat, runDNSFormat_) where

import Control.Applicative
import Control.Monad
Expand All @@ -12,13 +12,21 @@ import Network.DNS.Internal
import Network.DNS.StateBinary
import Data.Enumerator (Enumerator, Iteratee, run_, ($$))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as BL

responseIter :: Iteratee ByteString IO (DNSFormat, PState)
responseIter = runSGet decodeResponse
runDNSFormat :: BL.ByteString -> Either String (DNSFormat, PState)
runDNSFormat bs = runSGet decodeResponse bs

parseResponse :: Enumerator ByteString IO (a,b)
-> Iteratee ByteString IO (a,b)
-> IO a
runDNSFormat_ :: BL.ByteString -> Either String DNSFormat
runDNSFormat_ bs = fst <$> runDNSFormat bs

responseIter :: Monad m => Iteratee ByteString m (DNSFormat, PState)
responseIter = iterSGet decodeResponse

parseResponse :: (Functor m, Monad m)
=> Enumerator ByteString m (a,b)
-> Iteratee ByteString m (a,b)
-> m a
parseResponse enum iter = fst <$> run_ (enum $$ iter)

----------------------------------------------------------------
Expand Down
68 changes: 57 additions & 11 deletions Network/DNS/StateBinary.hs
@@ -1,42 +1,85 @@
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Network.DNS.StateBinary where

import Blaze.ByteString.Builder
import Control.Applicative
import Control.Monad.State
import Data.Monoid
import Data.Attoparsec
import Data.Attoparsec.Enumerator
import qualified Data.Attoparsec.Lazy as AL
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (unpack)
import qualified Data.ByteString as BS (unpack, length)
import qualified Data.ByteString.Lazy as BL (ByteString)
import Data.Enumerator (Iteratee)
import Data.Int
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM (insert, lookup, empty)
import Data.Map (Map)
import qualified Data.Map as M (insert, lookup, empty)
import Data.Word
import Network.DNS.Types
import Prelude hiding (lookup, take)

----------------------------------------------------------------

type SPut = Write
type SPut = State WState Write

data WState = WState {
wsDomain :: Map Domain Int
, wsPosition :: Int
}

initialWState :: WState
initialWState = WState M.empty 0

instance Monoid SPut where
mempty = return mempty
mappend a b = mconcat <$> sequence [a, b]

put8 :: Word8 -> SPut
put8 = writeWord8
put8 = fixedSized 1 writeWord8

put16 :: Word16 -> SPut
put16 = writeWord16be
put16 = fixedSized 2 writeWord16be

put32 :: Word32 -> SPut
put32 = writeWord32be
put32 = fixedSized 4 writeWord32be

putInt8 :: Int -> SPut
putInt8 = writeInt8 . fromIntegral
putInt8 = fixedSized 1 (writeInt8 . fromIntegral)

putInt16 :: Int -> SPut
putInt16 = writeInt16be . fromIntegral
putInt16 = fixedSized 2 (writeInt16be . fromIntegral)

putInt32 :: Int -> SPut
putInt32 = writeInt32be . fromIntegral
putInt32 = fixedSized 4 (writeInt32be . fromIntegral)

putByteString :: ByteString -> SPut
putByteString = writeSized BS.length writeByteString

addPositionW :: Int -> State WState ()
addPositionW n = do
(WState m cur) <- get
put $ WState m (cur+n)

fixedSized :: Int -> (a -> Write) -> a -> SPut
fixedSized n f a = do addPositionW n
return (f a)

writeSized :: Show a => (a -> Int) -> (a -> Write) -> a -> SPut
writeSized n f a = do addPositionW (n a)
return (f a)

wsPop :: Domain -> State WState (Maybe Int)
wsPop dom = do
doms <- gets wsDomain
return $ M.lookup dom doms

wsPush :: Domain -> Int -> State WState ()
wsPush dom pos = do
(WState m cur) <- get
put $ WState (M.insert dom pos m) cur

----------------------------------------------------------------

Expand Down Expand Up @@ -114,8 +157,11 @@ getNByteString n = lift (take n) <* addPosition n
initialState :: PState
initialState = PState IM.empty 0

runSGet :: SGet a -> Iteratee ByteString IO (a, PState)
runSGet parser = iterParser (runStateT parser initialState)
iterSGet :: Monad m => SGet a -> Iteratee ByteString m (a, PState)
iterSGet parser = iterParser (runStateT parser initialState)

runSGet :: SGet a -> BL.ByteString -> Either String (a, PState)
runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs

runSPut :: SPut -> BL.ByteString
runSPut = toLazyByteString . fromWrite
runSPut = toLazyByteString . fromWrite . flip evalState initialWState

0 comments on commit 04c4afa

Please sign in to comment.