-
Notifications
You must be signed in to change notification settings - Fork 36
/
Resolver.hs
273 lines (244 loc) · 8.84 KB
/
Resolver.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
{-# LANGUAGE CPP #-}
-- | DNS Resolver and generic (lower-level) lookup functions.
module Network.DNS.Resolver (
-- * Documentation
-- ** Configuration for resolver
FileOrNumericHost(..), ResolvConf(..), defaultResolvConf
-- ** Intermediate data type for resolver
, ResolvSeed, makeResolvSeed
-- ** Type and function for resolver
, Resolver(..), withResolver
-- ** Looking up functions
, lookup, lookupAuth, lookupRaw
) where
import Control.Applicative
import Control.Exception
import Data.Char
import Data.Int
import Data.List hiding (find, lookup)
import Network.BSD
import Network.DNS.Decode
import Network.DNS.Encode
import Network.DNS.Internal
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString.Lazy
import Prelude hiding (lookup)
import System.Random
import System.Timeout
#if mingw32_HOST_OS == 1
import Network.Socket (send)
import qualified Data.ByteString.Lazy.Char8 as LB
import Control.Monad (when)
#endif
----------------------------------------------------------------
-- | Where the resolver finds its name server: either a 'FilePath' to a
-- \"resolv.conf\"-style file, or a numeric IP address written as a
-- 'String'.
--
-- /Warning/: Only numeric IP addresses are valid @RCHostName@s.
--
-- Example (using Google's public DNS cache):
--
-- >>> let cache = RCHostName "8.8.8.8"
--
data FileOrNumericHost
    = RCFilePath FilePath  -- ^ Use the first @nameserver@ entry of this file.
    | RCHostName HostName  -- ^ Use this numeric address directly.
-- | Resolver configuration. The easiest way to build one is to start from
-- 'defaultResolvConf' and override only the fields you need.
data ResolvConf = ResolvConf
    { resolvInfo    :: FileOrNumericHost
      -- ^ Where to find the name server.
    , resolvTimeout :: Int
      -- ^ How long to wait for each reply, in microseconds.
    , resolvRetry   :: Int
      -- ^ Maximum number of times a query is sent before giving up.
    , resolvBufsize :: Integer
      -- ^ Obsolete. The value is carried into the 'Resolver' but is not
      -- read by the lookup functions of this module.
    }
-- | The default 'ResolvConf':
--
-- * 'resolvInfo' is 'RCFilePath' \"\/etc\/resolv.conf\".
--
-- * 'resolvTimeout' is 3,000,000 microseconds (3 seconds).
--
-- * 'resolvRetry' is 5.
--
-- * 'resolvBufsize' is 512 (obsolete).
--
-- Example (use Google's public DNS cache instead of resolv.conf):
--
-- >>> let cache = RCHostName "8.8.8.8"
-- >>> let rc = defaultResolvConf { resolvInfo = cache }
--
defaultResolvConf :: ResolvConf
defaultResolvConf = ResolvConf
    { resolvInfo    = RCFilePath "/etc/resolv.conf"
    , resolvTimeout = threeSeconds
    , resolvRetry   = 5
    , resolvBufsize = 512
    }
  where
    -- Timeouts are expressed in microseconds.
    threeSeconds :: Int
    threeSeconds = 3 * 1000 * 1000
----------------------------------------------------------------
-- | Opaque intermediate value produced by 'makeResolvSeed' and consumed by
-- 'withResolver'. Its fields are not exported from this module.
data ResolvSeed = ResolvSeed
    { addrInfo  :: AddrInfo  -- ^ Address information of the name server.
    , rsTimeout :: Int       -- ^ Per-attempt timeout in microseconds.
    , rsRetry   :: Int       -- ^ Maximum number of send attempts.
    , rsBufsize :: Integer   -- ^ Obsolete; copied from 'resolvBufsize'.
    }
-- | A DNS resolver bound to a single connected socket, as handed to the
-- action passed to 'withResolver'.
data Resolver = Resolver
    { genId      :: IO Int   -- ^ Produces the identifier of each new query.
    , dnsSock    :: Socket   -- ^ Socket connected to the name server.
    , dnsTimeout :: Int      -- ^ Per-attempt reply timeout in microseconds.
    , dnsRetry   :: Int      -- ^ Maximum number of times a query is sent.
    , dnsBufsize :: Integer  -- ^ Obsolete; not read by the lookups here.
    }
----------------------------------------------------------------
-- | Make a 'ResolvSeed' from a 'ResolvConf'.
--
-- With 'RCHostName', the given numeric address is used as is.
-- With 'RCFilePath', the file is read in full and the address from its
-- first @nameserver@ line is used. An 'IOError' is raised if the file
-- contains no such line.
--
-- Examples:
--
-- >>> rs <- makeResolvSeed defaultResolvConf
--
makeResolvSeed :: ResolvConf -> IO ResolvSeed
makeResolvSeed conf = ResolvSeed <$> addr
                                 <*> pure (resolvTimeout conf)
                                 <*> pure (resolvRetry conf)
                                 <*> pure (resolvBufsize conf)
  where
    addr = case resolvInfo conf of
        RCHostName numhost -> makeAddrInfo numhost
        RCFilePath file    -> do
            cs <- readFile file
            -- Force the whole contents so lazy I/O reaches EOF and closes
            -- the handle now, rather than leaving it open until the unread
            -- tail is garbage-collected.
            _ <- evaluate (length cs)
            case nameservers cs of
                ns:_ -> makeAddrInfo ns
                []   -> ioError $ userError $
                    "makeResolvSeed: no nameserver entry in " ++ file
    -- Addresses from all "nameserver <addr> ..." lines, in file order.
    -- Tokenising with 'words' tolerates leading whitespace, tabs and any
    -- trailing text on the line, and skips a bare "nameserver" line.
    nameservers cs = [ ns | ("nameserver" : ns : _) <- map words (lines cs) ]
-- | Build the 'AddrInfo' for the @\"domain\"@ service at a numeric host
-- address, using a datagram socket and the @udp@ protocol number.
-- Because 'AI_NUMERICHOST' is set, the argument must be a numeric address.
makeAddrInfo :: HostName -> IO AddrInfo
makeAddrInfo host = do
    udp <- getProtocolNumber "udp"
    (info:_) <- getAddrInfo (Just (udpHints udp)) (Just host) (Just "domain")
    return info
  where
    udpHints p = defaultHints
        { addrFlags      = [AI_ADDRCONFIG, AI_NUMERICHOST, AI_PASSIVE]
        , addrSocketType = Datagram
        , addrProtocol   = p
        }
----------------------------------------------------------------
-- | Run an action with a 'Resolver' built from the 'ResolvSeed'.
--
-- A socket is created from the seed's address information and connected
-- to the name server. The action runs with a 'Resolver' using that
-- socket. The socket is closed afterwards, including when 'connect' or
-- the action throws.
--
-- Each call opens its own socket. Run 'withResolver' inside the action
-- passed to 'forkIO' so that each thread gets its own 'Resolver'. For
-- examples, see "Network.DNS.Lookup".
withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a
withResolver seed func =
    bracket open sClose $ \sock -> do
        -- 'connect' sits inside 'bracket' so a failure here still closes
        -- the freshly created socket instead of leaking it.
        connect sock (addrAddress ai)
        func $ Resolver
            { genId      = getRandom
            , dnsSock    = sock
            , dnsTimeout = rsTimeout seed
            , dnsRetry   = rsRetry seed
            , dnsBufsize = rsBufsize seed
            }
  where
    ai   = addrInfo seed
    open = socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai)
-- | Draw a query identifier in the 16-bit range @[0, 65535]@ from the
-- global standard random generator.
getRandom :: IO Int
getRandom = randomRIO (0, 65535)
----------------------------------------------------------------
-- | Look up resource records of a domain and keep those from one section
-- of the response.
--
-- The first argument is a section accessor of 'DNSFormat' such as
-- 'answer' or 'authority', selecting which part of the response to
-- inspect. Only records whose type equals the queried 'TYPE' are
-- returned. Errors from 'lookupRaw' are passed through unchanged.
lookupSection :: (DNSFormat -> [ResourceRecord])
              -> Resolver
              -> Domain
              -> TYPE
              -> IO (Either DNSError [RDATA])
lookupSection section rlv dom typ = fmap extract <$> lookupRaw rlv dom typ
  where
    -- Records are matched on type only; their owner name is not compared
    -- with @dom@. Presumably this keeps answers reached through a CNAME
    -- chain (a name-matching variant was disabled as a "CNAME hack");
    -- confirm before tightening.
    extract = map rdata . filter ((== typ) . rrtype) . section
-- | Look up resource records for a domain, returning the records of the
-- queried type found in the ANSWER section of the response.
--
-- We repeat an example from "Network.DNS.Lookup":
--
-- >>> let hostname = Data.ByteString.Char8.pack "www.example.com"
-- >>> rs <- makeResolvSeed defaultResolvConf
-- >>> withResolver rs $ \resolver -> lookup resolver hostname A
-- Right [93.184.216.119]
--
lookup :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookup rlv dom typ = lookupSection answer rlv dom typ
-- | Look up resource records for a domain, returning the records of the
-- queried type found in the AUTHORITY section of the response.
lookupAuth :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookupAuth rlv dom typ = lookupSection authority rlv dom typ
-- | Look up a name and return the entire DNS response.
--
-- The query is sent at most 'dnsRetry' times, and each attempt waits up to
-- 'dnsTimeout' microseconds for a reply. If every attempt times out, the
-- result is @Left TimeoutExpired@. This also happens when 'dnsRetry' is
-- zero or negative, in which case nothing is sent. The first reply
-- received is returned as @Right@ when its identifier matches the query's
-- identifier. Otherwise the result is @Left SequenceNumberMismatch@.
--
-- Sample output is included below, however it is /not/ tested -- the
-- sequence number is unpredictable (it has to be!).
--
-- The example code:
--
-- @
-- let hostname = Data.ByteString.Char8.pack \"www.example.com\"
-- rs <- makeResolvSeed defaultResolvConf
-- withResolver rs $ \resolver -> lookupRaw resolver hostname A
-- @
--
-- And the (formatted) expected output:
--
-- @
-- Right (DNSFormat
--         { header = DNSHeader
--                      { identifier = 1,
--                        flags = DNSFlags
--                                  { qOrR = QR_Response,
--                                    opcode = OP_STD,
--                                    authAnswer = False,
--                                    trunCation = False,
--                                    recDesired = True,
--                                    recAvailable = True,
--                                    rcode = NoErr },
--                        qdCount = 1,
--                        anCount = 1,
--                        nsCount = 0,
--                        arCount = 0},
--           question = [Question { qname = \"www.example.com.\",
--                                  qtype = A}],
--           answer = [ResourceRecord {rrname = \"www.example.com.\",
--                                     rrtype = A,
--                                     rrttl = 800,
--                                     rdlen = 4,
--                                     rdata = 93.184.216.119}],
--           authority = [],
--           additional = []})
-- @
--
lookupRaw :: Resolver -> Domain -> TYPE -> IO (Either DNSError DNSFormat)
lookupRaw rlv dom typ = do
    seqno <- genId rlv
    let query = composeQuery seqno [makeQuestion dom typ]
    loop seqno query 0
  where
    sock  = dnsSock rlv
    tm    = dnsTimeout rlv
    retry = dnsRetry rlv

    -- Every attempt resends the same query, so it keeps the same sequence
    -- number. The guard uses '>=' rather than '==': with a negative retry
    -- count an equality test never fires, and the loop would resend
    -- forever while the server stays silent.
    loop seqno query cnt
      | cnt >= retry = return $ Left TimeoutExpired
      | otherwise    = do
          sendAll sock query
          response <- timeout tm (receive sock)
          case response of
              Nothing -> loop seqno query (cnt + 1)
              Just res
                  | identifier (header res) == seqno -> return $ Right res
                  | otherwise -> return $ Left SequenceNumberMismatch

#if mingw32_HOST_OS == 1
    -- Windows does not support sendAll in Network.ByteString.Lazy.
    -- This implements sendAll with Haskell Strings, resending the unsent
    -- suffix until the whole message has been written.
    sendAll s bs = do
        sent <- send s (LB.unpack bs)
        when (sent < fromIntegral (LB.length bs)) $
            sendAll s (LB.drop (fromIntegral sent) bs)
#endif