Skip to content

Commit

Permalink
add server host name into the server data
Browse files Browse the repository at this point in the history
  • Loading branch information
lyokha committed Oct 1, 2023
1 parent 029717b commit 46b22b8
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions NgxExport/Tools/Resolve.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE TemplateHaskell, RecordWildCards, BangPatterns, NumDecimals #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections, OverloadedStrings #-}

-----------------------------------------------------------------------------
-- |
Expand Down Expand Up @@ -40,12 +40,14 @@ import NgxExport.Tools.TimeInterval

import Network.DNS
import Network.HTTP.Client
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.IORef
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Maybe
import Data.Aeson
import Data.Function
Expand Down Expand Up @@ -300,6 +302,7 @@ newtype Upconf = Upconf { upconfAddr :: (SUrl, SAddress) } deriving Read
--
-- The fields map exactly to parameters from Nginx /server/ description.
data ServerData = ServerData { sAddr :: SAddress -- ^ Server address
, sHost :: SAddress -- ^ Server host name
, sWeight :: Maybe Int -- ^ /weight/
, sMaxFails :: Maybe Int -- ^ /max_fails/
, sFailTimeout :: Maybe Int -- ^ /fail_timeout/
Expand All @@ -308,6 +311,7 @@ data ServerData = ServerData { sAddr :: SAddress -- ^ Server address
instance FromJSON ServerData where
parseJSON = withObject "server_options" $ \o -> do
sAddr <- o .: "addr"
sHost <- o .: "host"
sWeight <- o .:? "weight"
sMaxFails <- o .:? "max_fails"
sFailTimeout <- o .:? "fail_timeout"
Expand All @@ -316,6 +320,7 @@ instance FromJSON ServerData where
instance ToJSON ServerData where
toJSON ServerData {..} =
object $ catMaybes [ pure $ "addr" .= sAddr
, pure $ "host" .= sHost
, ("weight" .=) <$> sWeight
, ("max_fails" .=) <$> sMaxFails
, ("fail_timeout" .=) <$> sFailTimeout
Expand Down Expand Up @@ -364,25 +369,35 @@ collectA lTTL name = do
--
-- After getting the /SRV/ record, runs 'collectA' for each collected element.
--
-- Returns a list of IP addresses wrapped in an 'SRV' container and the minimum
-- value of their TTLs. If the list is empty, then the returned TTL value gets
-- taken from the first argument.
-- Returns a list of pairs /(Domain name, IP address)/ wrapped in an 'SRV'
-- container and the minimum value of their TTLs. If the list is empty, then
-- the returned TTL value gets taken from the first argument. Note that trailing
-- dots get removed in the returned domain names.
collectSRV
:: TTL -- ^ Fallback TTL value
-> Name -- ^ Service name
-> IO (TTL, [SRV IPv4])
-> IO (TTL, [SRV (Name, IPv4)])
collectSRV lTTL name = do
!srv <- querySRV name
!srv' <- mapConcurrently
((\s@SRV {..} -> do
(t, is) <- collectA lTTL srvTarget
return (t, map (\v -> s { srvTarget = v }) is)
return (t
,map (\v ->
s { srvTarget =
(removeTrailingDot srvTarget, v)
}
) is
)
) . snd
) srv
return (min (minimumTTL lTTL $ map fst srv)
(minimumTTL lTTL $ map fst srv')
,concatMap snd srv'
)
where removeTrailingDot (Name v) = Name $ case C8.unsnoc v of
Just (v', '.') -> v'
_ -> v

showIPv4 :: IPv4 -> String
showIPv4 (IPv4 w) =
Expand All @@ -392,15 +407,17 @@ showIPv4 (IPv4 w) =
shows ( w .&. 0xff)
""

ipv4ToServerData :: UData -> IPv4 -> ServerData
ipv4ToServerData UData {..} i =
ServerData (T.pack $ show i) Nothing (Just uMaxFails) (Just uFailTimeout)
ipv4ToServerData :: UData -> Name -> IPv4 -> ServerData
ipv4ToServerData UData {..} (Name n) i =
ServerData (T.pack $ show i) (T.decodeUtf8 n)
Nothing (Just uMaxFails) (Just uFailTimeout)

srvToServerData :: UData -> SRV IPv4 -> ServerData
srvToServerData :: UData -> SRV (Name, IPv4) -> ServerData
srvToServerData UData {..} SRV {..} =
ServerData (T.pack $ showAddr srvTarget srvPort)
(Just $ fromIntegral srvWeight) (Just uMaxFails) (Just uFailTimeout)
where showAddr i p = showIPv4 i ++ ':' : show p
let (Name n, a) = srvTarget
showAddr i p = showIPv4 i ++ ':' : show p
in ServerData (T.pack $ showAddr a srvPort) (T.decodeUtf8 n)
(Just $ fromIntegral srvWeight) (Just uMaxFails) (Just uFailTimeout)

-- | Collects server data for the given upstream configuration.
--
Expand All @@ -415,15 +432,15 @@ collectServerData
collectServerData lTTL (UData (QueryA [] _) _ _) =
return (lTTL, M.empty)
collectServerData lTTL ud@(UData (QueryA ns u) _ _) = do
a <- mapConcurrently (collectA lTTL) ns
a <- mapConcurrently (\n -> (n, ) <$> collectA lTTL n) ns
return $
minimum *** M.singleton u . concat $
foldr (\(t, s) (ts, ss) ->
foldr (\(n, (t, s)) (ts, ss) ->
-- sort is required because resolver may rotate servers
-- which means that the same data may differ after every
-- single check; this note regards to other clauses of
-- this function as well
(t : ts, sort (map (ipv4ToServerData ud) s) : ss)
(t : ts, sort (map (ipv4ToServerData ud n) s) : ss)
) ([], []) a
collectServerData lTTL ud@(UData (QuerySRV n (SinglePriority u)) _ _) = do
(wt, srv) <- collectSRV lTTL n
Expand Down

0 comments on commit 46b22b8

Please sign in to comment.