From 46b22b89ec30555774fe5b834f2aca98896c5eb4 Mon Sep 17 00:00:00 2001 From: Alexey Radkov Date: Sun, 1 Oct 2023 04:43:49 +0400 Subject: [PATCH] add server host name into the server data --- NgxExport/Tools/Resolve.hs | 49 +++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/NgxExport/Tools/Resolve.hs b/NgxExport/Tools/Resolve.hs index c4708d0..f8e77a5 100644 --- a/NgxExport/Tools/Resolve.hs +++ b/NgxExport/Tools/Resolve.hs @@ -1,5 +1,5 @@ {-# LANGUAGE TemplateHaskell, RecordWildCards, BangPatterns, NumDecimals #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections, OverloadedStrings #-} ----------------------------------------------------------------------------- -- | @@ -40,12 +40,14 @@ import NgxExport.Tools.TimeInterval import Network.DNS import Network.HTTP.Client +import qualified Data.ByteString.Char8 as C8 import qualified Data.ByteString.Lazy as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.IORef import Data.Text (Text) import qualified Data.Text as T +import qualified Data.Text.Encoding as T import Data.Maybe import Data.Aeson import Data.Function @@ -300,6 +302,7 @@ newtype Upconf = Upconf { upconfAddr :: (SUrl, SAddress) } deriving Read -- -- The fields map exactly to parameters from Nginx /server/ description. data ServerData = ServerData { sAddr :: SAddress -- ^ Server address + , sHost :: SAddress -- ^ Server host name , sWeight :: Maybe Int -- ^ /weight/ , sMaxFails :: Maybe Int -- ^ /max_fails/ , sFailTimeout :: Maybe Int -- ^ /fail_timeout/ @@ -308,6 +311,7 @@ data ServerData = ServerData { sAddr :: SAddress -- ^ Server address instance FromJSON ServerData where parseJSON = withObject "server_options" $ \o -> do sAddr <- o .: "addr" + sHost <- o .: "host" sWeight <- o .:? "weight" sMaxFails <- o .:? "max_fails" sFailTimeout <- o .:? "fail_timeout" @@ -316,6 +320,7 @@ instance FromJSON ServerData where instance ToJSON ServerData where toJSON ServerData {..} = object $ catMaybes [ pure $ "addr" .= sAddr + , pure $ "host" .= sHost , ("weight" .=) <$> sWeight , ("max_fails" .=) <$> sMaxFails , ("fail_timeout" .=) <$> sFailTimeout @@ -364,25 +369,35 @@ collectA lTTL name = do -- -- After getting the /SRV/ record, runs 'collectA' for each collected element. -- --- Returns a list of IP addresses wrapped in an 'SRV' container and the minimum --- value of their TTLs. If the list is empty, then the returned TTL value gets --- taken from the first argument. +-- Returns a list of pairs /(Domain name, IP address)/ wrapped in an 'SRV' +-- container and the minimum value of their TTLs. If the list is empty, then +-- the returned TTL value gets taken from the first argument. Note that trailing +-- dots get removed in the returned domain names. collectSRV :: TTL -- ^ Fallback TTL value -> Name -- ^ Service name - -> IO (TTL, [SRV IPv4]) + -> IO (TTL, [SRV (Name, IPv4)]) collectSRV lTTL name = do !srv <- querySRV name !srv' <- mapConcurrently ((\s@SRV {..} -> do (t, is) <- collectA lTTL srvTarget - return (t, map (\v -> s { srvTarget = v }) is) + return (t + ,map (\v -> + s { srvTarget = + (removeTrailingDot srvTarget, v) + } + ) is + ) ) . snd ) srv return (min (minimumTTL lTTL $ map fst srv) (minimumTTL lTTL $ map fst srv') ,concatMap snd srv' ) + where removeTrailingDot (Name v) = Name $ case C8.unsnoc v of + Just (v', '.') -> v' + _ -> v showIPv4 :: IPv4 -> String showIPv4 (IPv4 w) = @@ -392,15 +407,17 @@ showIPv4 (IPv4 w) = shows ( w .&. 0xff) "" -ipv4ToServerData :: UData -> IPv4 -> ServerData -ipv4ToServerData UData {..} i = - ServerData (T.pack $ show i) Nothing (Just uMaxFails) (Just uFailTimeout) +ipv4ToServerData :: UData -> Name -> IPv4 -> ServerData +ipv4ToServerData UData {..} (Name n) i = + ServerData (T.pack $ show i) (T.decodeUtf8 n) + Nothing (Just uMaxFails) (Just uFailTimeout) -srvToServerData :: UData -> SRV IPv4 -> ServerData +srvToServerData :: UData -> SRV (Name, IPv4) -> ServerData srvToServerData UData {..} SRV {..} = - ServerData (T.pack $ showAddr srvTarget srvPort) - (Just $ fromIntegral srvWeight) (Just uMaxFails) (Just uFailTimeout) - where showAddr i p = showIPv4 i ++ ':' : show p + let (Name n, a) = srvTarget + showAddr i p = showIPv4 i ++ ':' : show p + in ServerData (T.pack $ showAddr a srvPort) (T.decodeUtf8 n) + (Just $ fromIntegral srvWeight) (Just uMaxFails) (Just uFailTimeout) -- | Collects server data for the given upstream configuration. -- @@ -415,15 +432,15 @@ collectServerData collectServerData lTTL (UData (QueryA [] _) _ _) = return (lTTL, M.empty) collectServerData lTTL ud@(UData (QueryA ns u) _ _) = do - a <- mapConcurrently (collectA lTTL) ns + a <- mapConcurrently (\n -> (n, ) <$> collectA lTTL n) ns return $ minimum *** M.singleton u . concat $ - foldr (\(t, s) (ts, ss) -> + foldr (\(n, (t, s)) (ts, ss) -> -- sort is required because resolver may rotate servers -- which means that the same data may differ after every -- single check; this note regards to other clauses of -- this function as well - (t : ts, sort (map (ipv4ToServerData ud) s) : ss) + (t : ts, sort (map (ipv4ToServerData ud n) s) : ss) ) ([], []) a collectServerData lTTL ud@(UData (QuerySRV n (SinglePriority u)) _ _) = do (wt, srv) <- collectSRV lTTL n