Skip to content

Commit

Permalink
kafka: fix (most of) improper behaviours of the group state machine (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Commelina committed May 9, 2024
1 parent 1f6ac0f commit 3309ba7
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 38 deletions.
20 changes: 20 additions & 0 deletions hstream-kafka/HStream/Kafka/Common/Utils.hs
Expand Up @@ -3,6 +3,7 @@

module HStream.Kafka.Common.Utils where

import Control.Concurrent
import Control.Exception (throw)
import qualified Control.Monad as M
import qualified Control.Monad.ST as ST
Expand All @@ -18,6 +19,7 @@ import qualified Data.Text.Encoding as T
import qualified Data.Vector as V
import HStream.Kafka.Common.KafkaException (ErrorCodeException (ErrorCodeException))
import qualified Kafka.Protocol.Encoding as K
import qualified System.Timeout as Timeout

type HashTable k v = H.BasicHashTable k v

Expand Down Expand Up @@ -93,3 +95,21 @@ encodeBase64 = Base64.extractBase64 . Base64.encodeBase64

decodeBase64 :: T.Text -> BS.ByteString
decodeBase64 = Base64.decodeBase64Lenient . T.encodeUtf8

-- | Perform the action when the predicate is true or timeout is reached.
-- An extra action is performed when the timeout expire, whose result will
-- be discarded.
-- Warning: The second action is always performed no matter whether the
-- timeout is reached or not.
onOrTimeout :: IO Bool -> Int -> IO a -> IO b -> IO b
onOrTimeout p timeoutMs actionOnExpire action =
Timeout.timeout (timeoutMs * 1000) loop >>= \case
Nothing -> do
M.void actionOnExpire
action
Just a -> return a
where
loop = p >>= \case
True -> action
-- FIXME: Hardcoded constant (check every 10ms)
False -> threadDelay (10 * 1000) >> loop
156 changes: 118 additions & 38 deletions hstream-kafka/HStream/Kafka/Group/Group.hs
Expand Up @@ -15,7 +15,8 @@ import Data.Int (Int32)
import qualified Data.IORef as IO
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, listToMaybe)
import Data.Maybe (fromMaybe, isJust,
listToMaybe)
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified Data.UUID as UUID
Expand Down Expand Up @@ -105,6 +106,28 @@ data GroupState
| Empty
deriving (Show, Eq)

groupStateValidPrevs :: GroupState -> [GroupState]
groupStateValidPrevs = \case
PreparingRebalance -> [Stable, Empty, CompletingRebalance]
CompletingRebalance -> [PreparingRebalance]
Stable -> [CompletingRebalance]
Dead -> [Stable, Empty, Dead, PreparingRebalance, CompletingRebalance]
Empty -> [PreparingRebalance]

-- TODO: Throw exceptions or use Either rather than logging only.
transitionTo :: Group -> GroupState -> IO ()
transitionTo group toState = do
curState <- IO.readIORef group.state
if curState `elem` groupStateValidPrevs toState
then do
IO.atomicWriteIORef group.state toState
Log.info $ "group state changed, " <> Log.buildString' curState <> " -> " <> Log.buildString' toState
<> ", group:" <> Log.build group.groupId
else do
Log.warning . Log.build $
"Invalid state transition from " <> T.pack (show curState)
<> " to " <> T.pack (show toState) <> " for group:" <> group.groupId

data Group = Group
{ lock :: C.MVar ()
, groupId :: T.Text
Expand Down Expand Up @@ -350,18 +373,41 @@ doDynamicNewMemberJoinGroup group reqCtx req newMemberId delayedResponse = do

addMemberAndRebalance :: Group -> K.RequestContext -> K.JoinGroupRequest -> T.Text -> C.MVar K.JoinGroupResponse -> IO ()
addMemberAndRebalance group reqCtx req newMemberId delayedResponse = do
isGroupEmpty <- Utils.hashtableNull group.members
member <- newMemberFromReq reqCtx req newMemberId (refineProtocols req.protocols)
addMember group member (Just delayedResponse)
-- TODO: check state
prepareRebalance group $ "add member:" <> member.memberId
-- Note: We only support dynamic join so it is OK to consider it as new init join
-- when the group was empty.
memberRebalanceTimeoutMs <- IO.readIORef member.rebalanceTimeoutMs
curState <- IO.readIORef group.state
when (curState `elem` [Empty, Stable, CompletingRebalance]) $ do
-- Note: The new-added member is always present. So this is equivalent
-- to check among members before.
prepareRebalance group
(if isGroupEmpty then return False else (haveAllMembersRejoined group))
(if isGroupEmpty then fromIntegral group.groupConfig.groupInitialRebalanceDelay
else memberRebalanceTimeoutMs)
("add member:" <> member.memberId)

updateMemberAndRebalance :: Group -> Member -> K.JoinGroupRequest -> C.MVar K.JoinGroupResponse -> IO ()
updateMemberAndRebalance group member req delayedResponse = do
updateMember group member req delayedResponse
prepareRebalance group $ "update member:" <> member.memberId

prepareRebalance :: Group -> T.Text -> IO ()
prepareRebalance group@Group{..} reason = do
-- Note: On this case, the group can not be empty because at least this member
-- is present. So we wait until all previous members have joined or timeout.
-- FIXME: How long to wait? Is 'rebalanceTimeoutMs' of the member correct?
timeout <- IO.readIORef member.rebalanceTimeoutMs
curState <- IO.readIORef group.state
when (curState `elem` [Empty, Stable, CompletingRebalance]) $ do
-- Note: The new-added member is always present. So this is equivalent
-- to check among members before.
prepareRebalance group
(haveAllMembersRejoined group)
timeout
("update member:" <> member.memberId)

prepareRebalance :: Group -> IO Bool -> Int32 -> T.Text -> IO ()
prepareRebalance group@Group{..} p timeoutMs reason = do
Log.info $ "prepare rebalance, group:" <> Log.build groupId
<> "reason:" <> Log.build reason
-- check state CompletingRebalance and cancel delayedSyncResponses
Expand All @@ -371,24 +417,19 @@ prepareRebalance group@Group{..} reason = do
-- cancel delayed sync
cancelDelayedSync group

-- isEmptyState <- (Empty ==) <$> IO.readIORef state

-- setup delayed rebalance if delayedRebalance is Nothing
IO.readIORef delayedRebalance >>= \case
Nothing -> do
delayed <- makeDelayedRebalance group group.groupConfig.groupInitialRebalanceDelay
delayed <- makeDelayedRebalance group p (fromIntegral timeoutMs)
Log.info $ "created delayed rebalance thread:" <> Log.buildString' delayed
<> ", group:" <> Log.build groupId
IO.atomicWriteIORef delayedRebalance (Just delayed)
IO.atomicWriteIORef state PreparingRebalance
transitionTo group PreparingRebalance
_ -> pure ()

-- TODO: dynamically delay with initTimeoutMs and RebalanceTimeoutMs
makeDelayedRebalance :: Group -> Int -> IO C.ThreadId
makeDelayedRebalance group rebalanceDelayMs = do
C.forkIO $ do
C.threadDelay (1000 * rebalanceDelayMs)
rebalance group
makeDelayedRebalance :: Group -> IO Bool -> Int -> IO C.ThreadId
makeDelayedRebalance group p rebalanceDelayMs =
C.forkIO $ Utils.onOrTimeout p rebalanceDelayMs (pure ()) (rebalance group)

rebalance :: Group -> IO ()
rebalance group@Group{..} = do
Expand Down Expand Up @@ -416,12 +457,6 @@ rebalance group@Group{..} = do
Just leaderMemberId -> do
doRelance group leaderMemberId

transitionTo :: Group -> GroupState -> IO ()
transitionTo group state = do
oldState <- IO.atomicModifyIORef' group.state (state, )
Log.info $ "group state changed, " <> Log.buildString' oldState <> " -> " <> Log.buildString' state
<> ", group:" <> Log.build group.groupId

doRelance :: Group -> T.Text -> IO ()
doRelance group@Group{..} leaderMemberId = do
selectedProtocolName <- computeProtocolName group
Expand Down Expand Up @@ -453,7 +488,9 @@ doRelance group@Group{..} leaderMemberId = do
_ <- C.tryPutMVar delayed resp
H.delete delayedJoinResponses memberId

rebalanceTimeoutMs <- computeRebalnceTimeoutMs group
rebalanceTimeoutMs <- computeRebalanceTimeoutMs group
-- FIXME: Is it correct to use rebalance timeout here? Or maybe session timeout?
-- The state machine here is really weird...
delayedSyncTid <- makeDelayedSync group generationId rebalanceTimeoutMs
IO.atomicWriteIORef delayedSync (Just delayedSyncTid)
Log.info $ "create delayed sync for group:" <> Log.build groupId
Expand Down Expand Up @@ -482,8 +519,11 @@ removeNotYetSyncedMembers group@Group{..} = do

makeDelayedSync :: Group -> Int32 -> Int32 -> IO C.ThreadId
makeDelayedSync group@Group{..} generationId timeoutMs = do
C.forkIO $ do
C.threadDelay (fromIntegral timeoutMs * 1000)
curMembers <- H.toList members
C.forkIO $ Utils.onOrTimeout (hasReceivedAllSyncs curMembers)
(fromIntegral timeoutMs) -- always use rebalance timeout
(pure ())
$ do
C.withMVar lock $ \() -> do
Utils.unlessIORefEq groupGenerationId generationId $ \currentGid -> do
Log.warning $ "unexpected delayed sync with wrong generationId:" <> Log.build generationId
Expand All @@ -496,14 +536,20 @@ makeDelayedSync group@Group{..} generationId timeoutMs = do

-- remove itself (to avoid killing itself in prepareRebalance)
IO.atomicWriteIORef delayedSync Nothing
prepareRebalance group $ "delayed sync timeout"
s -> do
Log.warning $ "unexpected delayed sync with wrong state:" <> Log.buildString' s
<> ", group:" <> Log.build groupId
where
hasReceivedAllSyncs :: [(T.Text, Member)] -> IO Bool
hasReceivedAllSyncs curMembers_ = do
M.foldM (\acc (mid,_) -> case acc of
False -> return False
True -> isJust <$> H.lookup group.delayedSyncResponses mid
) True curMembers_

-- select max rebalanceTimeoutMs from all members
computeRebalnceTimeoutMs :: Group -> IO Int32
computeRebalnceTimeoutMs Group{..} = do
computeRebalanceTimeoutMs :: Group -> IO Int32
computeRebalanceTimeoutMs Group{..} = do
H.foldM (\x (_, m) -> max x <$> IO.readIORef m.rebalanceTimeoutMs) 0 members

getJoinResponseMember :: T.Text -> Member -> IO K.JoinGroupResponseMember
Expand Down Expand Up @@ -702,17 +748,14 @@ leaveGroup group@Group{..} req = do
C.withMVar lock $ \() -> do
member <- getMember group req.memberId
IO.readIORef state >>= \case
Dead -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
Dead -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
Empty -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
CompletingRebalance -> removeMemberAndUpdateGroup group member
Stable -> removeMemberAndUpdateGroup group member
PreparingRebalance -> do
-- TODO: should NOT BE PASSIBLE in this version
Log.warning $ "received a leave group in PreparingRebalance state, ignored it"
<> ", groupId:" <> Log.buildString' req.groupId
<> ", memberId:" <> Log.buildString' req.memberId
throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)

Stable -> removeMemberAndUpdateGroup group member
-- Note: This is possible, of course. On this case, just try completing
-- current rebalance rather than preparing a new one. And in fact
-- we do not need to do anything because we have already watched it.
PreparingRebalance -> removeMemberAndUpdateGroup group member
return $ K.LeaveGroupResponse {errorCode=0, throttleTimeMs=0}

getMember :: Group -> T.Text -> IO Member
Expand All @@ -725,13 +768,34 @@ removeMemberAndUpdateGroup :: Group -> Member -> IO ()
removeMemberAndUpdateGroup group@Group{..} member = do
Log.info $ "member: " <> Log.build member.memberId <> " is leaving group:" <> Log.build groupId

-- Note from Kafka:
-- New members may timeout with a pending JoinGroup while the group is still rebalancing, so we have
-- to invoke the callback before removing the member. We return UNKNOWN_MEMBER_ID so that the consumer
-- will retry the JoinGroup request if is still active.
-- Note: This means returning UNKNOWN_MEMBER_ID to the JoinGroupRequest rather than the LeaveGroupRequest.
-- Note: A client may reset offsets right after leaving the group, and we have to make sure that
-- the group is stable (not rebalancing) before the client can reset offsets. Thanks to the
-- lock and rebalancing watcher, this can be atomatically met after the last consumer leaving.
cancelDelayedJoinResponse group member.memberId

removeMember group member
prepareRebalance group $ "remove member:" <> member.memberId

curState <- IO.readIORef state

if curState == PreparingRebalance then do
IO.readIORef delayedRebalance >>= \case
Nothing -> throw (ErrorCodeException K.UNKNOWN_SERVER_ERROR)
Just _ -> return ()
else if curState `elem` [Empty, Stable, CompletingRebalance] then do
groupRebalanceTimeoutMs <- H.foldM (\acc (_,thisMember) -> do
thisRBTimeoutMs <- IO.readIORef thisMember.rebalanceTimeoutMs
return (max acc thisRBTimeoutMs)
) 0 members
prepareRebalance group
(haveAllMembersRejoined group)
groupRebalanceTimeoutMs
("remove member:" <> member.memberId)
else return ()

cancelDelayedJoinResponse :: Group -> T.Text -> IO ()
cancelDelayedJoinResponse Group{..} memberId = do
Expand Down Expand Up @@ -1035,3 +1099,19 @@ makeJoinResponseError memberId errorCode =
, members = K.NonNullKaArray V.empty
, throttleTimeMs = 0
}


------------------------------ Misc -------------------------------
-- WARNING: Use the list each time this is called, rather than the list
-- when preparing a rebalance. Consider a case: c1, c2 and c3
-- leave the group in order. We should check [c2, c3] then [c3]
-- and finally [], rather than always checking [c2, c3] (this
-- will never be met!).
-- WARNING: Compare with syncing stage (hasReceivedAllSyncs), who uses
-- the list when preparing a rebalance.
haveAllMembersRejoined :: Group -> IO Bool
haveAllMembersRejoined group = do
H.foldM (\acc (mid,_) -> case acc of
False -> return False
True -> isJust <$> H.lookup group.delayedJoinResponses mid
) True group.members

0 comments on commit 3309ba7

Please sign in to comment.