Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kafka: fix (most of) improper behaviours of the group state machine #1813

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions hstream-kafka/HStream/Kafka/Common/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

module HStream.Kafka.Common.Utils where

import Control.Concurrent
import Control.Exception (throw)
import qualified Control.Monad as M
import qualified Control.Monad.ST as ST
Expand All @@ -18,6 +19,7 @@ import qualified Data.Text.Encoding as T
import qualified Data.Vector as V
import HStream.Kafka.Common.KafkaException (ErrorCodeException (ErrorCodeException))
import qualified Kafka.Protocol.Encoding as K
import qualified System.Timeout as Timeout

-- | Alias for 'H.BasicHashTable' with keys of type @k@ and values of type @v@.
-- NOTE(review): @H@ is presumably the IO interface of the hashtables package;
-- confirm against the import list.
type HashTable k v = H.BasicHashTable k v

Expand Down Expand Up @@ -93,3 +95,21 @@ encodeBase64 = Base64.extractBase64 . Base64.encodeBase64

-- | Decode Base64 text into raw bytes.
--
-- The text is first encoded as UTF-8 and then handed to
-- 'Base64.decodeBase64Lenient'.
decodeBase64 :: T.Text -> BS.ByteString
decodeBase64 encoded = Base64.decodeBase64Lenient (T.encodeUtf8 encoded)

-- | Poll a predicate until it holds or a timeout expires, then run an action.
--
-- The predicate @p@ is re-evaluated every 10ms. As soon as it returns 'True',
-- or once @timeoutMs@ milliseconds have elapsed, @action@ is run and its
-- result is returned. When the wait ended because of the timeout,
-- @actionOnExpire@ is run first and its result is discarded.
--
-- Warning: @action@ is always performed, no matter whether the timeout is
-- reached or not.
--
-- Only the polling is subject to the timeout. @action@ runs outside of it, so
-- a slow @action@ is never interrupted half-way and then started a second time.
onOrTimeout :: IO Bool -> Int -> IO a -> IO b -> IO b
onOrTimeout p timeoutMs actionOnExpire action = do
  Timeout.timeout (timeoutMs * 1000) waitUntilReady >>= \case
    Nothing -> M.void actionOnExpire
    Just () -> return ()
  action
  where
    -- Returns only once the predicate holds; the enclosing timeout bounds it.
    waitUntilReady = p >>= \case
      True  -> return ()
      -- FIXME: Hardcoded constant (check every 10ms)
      False -> threadDelay (10 * 1000) >> waitUntilReady
156 changes: 118 additions & 38 deletions hstream-kafka/HStream/Kafka/Group/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import Data.Int (Int32)
import qualified Data.IORef as IO
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, listToMaybe)
import Data.Maybe (fromMaybe, isJust,
listToMaybe)
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified Data.UUID as UUID
Expand Down Expand Up @@ -105,6 +106,28 @@ data GroupState
| Empty
deriving (Show, Eq)

-- | For a target state, list the states a group may be in immediately
-- before moving to that target.
groupStateValidPrevs :: GroupState -> [GroupState]
groupStateValidPrevs PreparingRebalance  = [Stable, Empty, CompletingRebalance]
groupStateValidPrevs CompletingRebalance = [PreparingRebalance]
groupStateValidPrevs Stable              = [CompletingRebalance]
groupStateValidPrevs Dead                = [Stable, Empty, Dead, PreparingRebalance, CompletingRebalance]
groupStateValidPrevs Empty               = [PreparingRebalance]

-- | Move the group to @toState@ when its current state is one of
-- 'groupStateValidPrevs' of @toState@, logging the change. Otherwise the
-- state is left untouched and a warning is logged.
--
-- TODO: Throw exceptions or use Either rather than logging only.
transitionTo :: Group -> GroupState -> IO ()
transitionTo group toState = do
  curState <- IO.readIORef group.state
  let isAllowed = curState `elem` groupStateValidPrevs toState
  if isAllowed then commit curState else reject curState
  where
    -- Store the new state and report the old -> new transition.
    commit fromState = do
      IO.atomicWriteIORef group.state toState
      Log.info $ "group state changed, " <> Log.buildString' fromState
              <> " -> " <> Log.buildString' toState
              <> ", group:" <> Log.build group.groupId
    -- Report the refused transition; the stored state is not modified.
    reject fromState =
      Log.warning . Log.build $
        "Invalid state transition from " <> T.pack (show fromState)
          <> " to " <> T.pack (show toState)
          <> " for group:" <> group.groupId

data Group = Group
{ lock :: C.MVar ()
, groupId :: T.Text
Expand Down Expand Up @@ -350,18 +373,41 @@ doDynamicNewMemberJoinGroup group reqCtx req newMemberId delayedResponse = do

-- | Add a newly joining member to the group and, if the group is in the
-- 'Empty', 'Stable' or 'CompletingRebalance' state, start preparing a
-- rebalance.
--
-- Whether the group was empty is sampled before the member is added:
--
-- * group was empty: the wait predicate is constantly 'False' and the wait
--   time is @groupInitialRebalanceDelay@ from the group config;
-- * otherwise: the wait predicate is 'haveAllMembersRejoined' and the wait
--   time is the new member's @rebalanceTimeoutMs@.
--
-- In any other state nothing beyond adding the member is done here.
addMemberAndRebalance :: Group -> K.RequestContext -> K.JoinGroupRequest -> T.Text -> C.MVar K.JoinGroupResponse -> IO ()
addMemberAndRebalance group reqCtx req newMemberId delayedResponse = do
  -- Must be read before 'addMember', after which the table is never empty.
  isGroupEmpty <- Utils.hashtableNull group.members
  member <- newMemberFromReq reqCtx req newMemberId (refineProtocols req.protocols)
  addMember group member (Just delayedResponse)
  -- Note: We only support dynamic join so it is OK to consider it as new init join
  --       when the group was empty.
  memberRebalanceTimeoutMs <- IO.readIORef member.rebalanceTimeoutMs
  curState <- IO.readIORef group.state
  when (curState `elem` [Empty, Stable, CompletingRebalance]) $ do
    -- Note: The new-added member is always present. So this is equivalent
    --       to check among members before.
    prepareRebalance group
      (if isGroupEmpty then return False else haveAllMembersRejoined group)
      (if isGroupEmpty then fromIntegral group.groupConfig.groupInitialRebalanceDelay
                       else memberRebalanceTimeoutMs)
      ("add member:" <> member.memberId)

-- | Update an existing member from its join request and, if the group is in
-- the 'Empty', 'Stable' or 'CompletingRebalance' state, start preparing a
-- rebalance that waits until 'haveAllMembersRejoined' holds or the member's
-- @rebalanceTimeoutMs@ elapses.
updateMemberAndRebalance :: Group -> Member -> K.JoinGroupRequest -> C.MVar K.JoinGroupResponse -> IO ()
updateMemberAndRebalance group member req delayedResponse = do
  updateMember group member req delayedResponse
  -- Note: On this case, the group can not be empty because at least this member
  --       is present. So we wait until all previous members have joined or timeout.
  -- FIXME: How long to wait? Is 'rebalanceTimeoutMs' of the member correct?
  timeout <- IO.readIORef member.rebalanceTimeoutMs
  curState <- IO.readIORef group.state
  when (curState `elem` [Empty, Stable, CompletingRebalance]) $
    prepareRebalance group
      (haveAllMembersRejoined group)
      timeout
      ("update member:" <> member.memberId)

prepareRebalance :: Group -> IO Bool -> Int32 -> T.Text -> IO ()
prepareRebalance group@Group{..} p timeoutMs reason = do
Log.info $ "prepare rebalance, group:" <> Log.build groupId
<> "reason:" <> Log.build reason
-- check state CompletingRebalance and cancel delayedSyncResponses
Expand All @@ -371,24 +417,19 @@ prepareRebalance group@Group{..} reason = do
-- cancel delayed sync
cancelDelayedSync group

-- isEmptyState <- (Empty ==) <$> IO.readIORef state

-- setup delayed rebalance if delayedRebalance is Nothing
IO.readIORef delayedRebalance >>= \case
Nothing -> do
delayed <- makeDelayedRebalance group group.groupConfig.groupInitialRebalanceDelay
delayed <- makeDelayedRebalance group p (fromIntegral timeoutMs)
Log.info $ "created delayed rebalance thread:" <> Log.buildString' delayed
<> ", group:" <> Log.build groupId
IO.atomicWriteIORef delayedRebalance (Just delayed)
IO.atomicWriteIORef state PreparingRebalance
transitionTo group PreparingRebalance
_ -> pure ()

-- | Fork a thread that runs 'rebalance' on the group as soon as the predicate
-- @p@ holds or @rebalanceDelayMs@ milliseconds have passed, whichever comes
-- first (see 'Utils.onOrTimeout'). No extra action is taken on expiry.
--
-- Returns the id of the forked thread.
makeDelayedRebalance :: Group -> IO Bool -> Int -> IO C.ThreadId
makeDelayedRebalance group p rebalanceDelayMs =
  C.forkIO $ Utils.onOrTimeout p rebalanceDelayMs (pure ()) (rebalance group)

rebalance :: Group -> IO ()
rebalance group@Group{..} = do
Expand Down Expand Up @@ -416,12 +457,6 @@ rebalance group@Group{..} = do
Just leaderMemberId -> do
doRelance group leaderMemberId

-- | Unconditionally replace the group's state with @state@ and log the
-- change. The swap is one 'IO.atomicModifyIORef'' call that also returns the
-- previous state, which is used only for the log line.
--
-- NOTE(review): no check is made that the transition is legal. This looks
-- like the pre-change definition removed by the diff; a validating
-- 'transitionTo' built on 'groupStateValidPrevs' appears earlier in this
-- listing. Confirm that only one definition remains in the module.
transitionTo :: Group -> GroupState -> IO ()
transitionTo group state = do
oldState <- IO.atomicModifyIORef' group.state (state, )
Log.info $ "group state changed, " <> Log.buildString' oldState <> " -> " <> Log.buildString' state
<> ", group:" <> Log.build group.groupId

doRelance :: Group -> T.Text -> IO ()
doRelance group@Group{..} leaderMemberId = do
selectedProtocolName <- computeProtocolName group
Expand Down Expand Up @@ -453,7 +488,9 @@ doRelance group@Group{..} leaderMemberId = do
_ <- C.tryPutMVar delayed resp
H.delete delayedJoinResponses memberId

rebalanceTimeoutMs <- computeRebalnceTimeoutMs group
rebalanceTimeoutMs <- computeRebalanceTimeoutMs group
-- FIXME: Is it correct to use rebalance timeout here? Or maybe session timeout?
-- The state machine here is really weird...
delayedSyncTid <- makeDelayedSync group generationId rebalanceTimeoutMs
IO.atomicWriteIORef delayedSync (Just delayedSyncTid)
Log.info $ "create delayed sync for group:" <> Log.build groupId
Expand Down Expand Up @@ -482,8 +519,11 @@ removeNotYetSyncedMembers group@Group{..} = do

makeDelayedSync :: Group -> Int32 -> Int32 -> IO C.ThreadId
makeDelayedSync group@Group{..} generationId timeoutMs = do
C.forkIO $ do
C.threadDelay (fromIntegral timeoutMs * 1000)
curMembers <- H.toList members
C.forkIO $ Utils.onOrTimeout (hasReceivedAllSyncs curMembers)
(fromIntegral timeoutMs) -- always use rebalance timeout
(pure ())
$ do
C.withMVar lock $ \() -> do
Utils.unlessIORefEq groupGenerationId generationId $ \currentGid -> do
Log.warning $ "unexpected delayed sync with wrong generationId:" <> Log.build generationId
Expand All @@ -496,14 +536,20 @@ makeDelayedSync group@Group{..} generationId timeoutMs = do

-- remove itself (to avoid killing itself in prepareRebalance)
IO.atomicWriteIORef delayedSync Nothing
prepareRebalance group $ "delayed sync timeout"
s -> do
Log.warning $ "unexpected delayed sync with wrong state:" <> Log.buildString' s
<> ", group:" <> Log.build groupId
where
hasReceivedAllSyncs :: [(T.Text, Member)] -> IO Bool
hasReceivedAllSyncs curMembers_ = do
M.foldM (\acc (mid,_) -> case acc of
False -> return False
True -> isJust <$> H.lookup group.delayedSyncResponses mid
) True curMembers_

-- | Select the largest @rebalanceTimeoutMs@ among all members of the group.
-- Returns 0 when the group has no members.
computeRebalanceTimeoutMs :: Group -> IO Int32
computeRebalanceTimeoutMs Group{..} =
  H.foldM (\acc (_, m) -> max acc <$> IO.readIORef m.rebalanceTimeoutMs) 0 members

-- | Misspelled former name of 'computeRebalanceTimeoutMs', kept so that call
-- sites still using it keep compiling. Prefer the correctly spelled name.
computeRebalnceTimeoutMs :: Group -> IO Int32
computeRebalnceTimeoutMs = computeRebalanceTimeoutMs

getJoinResponseMember :: T.Text -> Member -> IO K.JoinGroupResponseMember
Expand Down Expand Up @@ -702,17 +748,14 @@ leaveGroup group@Group{..} req = do
C.withMVar lock $ \() -> do
member <- getMember group req.memberId
IO.readIORef state >>= \case
Dead -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
Dead -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
Empty -> throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)
CompletingRebalance -> removeMemberAndUpdateGroup group member
Stable -> removeMemberAndUpdateGroup group member
PreparingRebalance -> do
-- TODO: should NOT BE PASSIBLE in this version
Log.warning $ "received a leave group in PreparingRebalance state, ignored it"
<> ", groupId:" <> Log.buildString' req.groupId
<> ", memberId:" <> Log.buildString' req.memberId
throw (ErrorCodeException K.UNKNOWN_MEMBER_ID)

Stable -> removeMemberAndUpdateGroup group member
-- Note: This is possible, of course. On this case, just try completing
-- current rebalance rather than preparing a new one. And in fact
-- we do not need to do anything because we have already watched it.
PreparingRebalance -> removeMemberAndUpdateGroup group member
return $ K.LeaveGroupResponse {errorCode=0, throttleTimeMs=0}

getMember :: Group -> T.Text -> IO Member
Expand All @@ -725,13 +768,34 @@ removeMemberAndUpdateGroup :: Group -> Member -> IO ()
-- Remove a leaving member from the group and make sure a rebalance is under
-- way afterwards:
--
--   * 'PreparingRebalance': a delayed rebalance must already exist; if none
--     is registered, UNKNOWN_SERVER_ERROR is thrown.
--   * 'Empty', 'Stable', 'CompletingRebalance': a rebalance is prepared that
--     waits until 'haveAllMembersRejoined' holds or the largest
--     @rebalanceTimeoutMs@ of the remaining members elapses.
--   * any other state: nothing more is done.
removeMemberAndUpdateGroup group@Group{..} member = do
  Log.info $ "member: " <> Log.build member.memberId <> " is leaving group:" <> Log.build groupId

  -- Note from Kafka:
  --   New members may timeout with a pending JoinGroup while the group is still rebalancing, so we have
  --   to invoke the callback before removing the member. We return UNKNOWN_MEMBER_ID so that the consumer
  --   will retry the JoinGroup request if is still active.
  -- Note: This means returning UNKNOWN_MEMBER_ID to the JoinGroupRequest rather than the LeaveGroupRequest.
  -- Note: A client may reset offsets right after leaving the group, and we have to make sure that
  --       the group is stable (not rebalancing) before the client can reset offsets. Thanks to the
  --       lock and rebalancing watcher, this can be automatically met after the last consumer leaving.
  cancelDelayedJoinResponse group member.memberId

  removeMember group member

  curState <- IO.readIORef state
  case curState of
    -- A rebalance is already pending; only check that its watcher exists.
    PreparingRebalance ->
      IO.readIORef delayedRebalance >>= \case
        Nothing -> throw (ErrorCodeException K.UNKNOWN_SERVER_ERROR)
        Just _  -> return ()
    _ | curState `elem` [Empty, Stable, CompletingRebalance] -> do
          -- Computed after the removal, so only remaining members count.
          groupRebalanceTimeoutMs <- computeRebalanceTimeoutMs group
          prepareRebalance group
            (haveAllMembersRejoined group)
            groupRebalanceTimeoutMs
            ("remove member:" <> member.memberId)
      | otherwise -> return ()

cancelDelayedJoinResponse :: Group -> T.Text -> IO ()
cancelDelayedJoinResponse Group{..} memberId = do
Expand Down Expand Up @@ -1035,3 +1099,19 @@ makeJoinResponseError memberId errorCode =
, members = K.NonNullKaArray V.empty
, throttleTimeMs = 0
}


------------------------------ Misc -------------------------------
-- | Check whether every member currently in @group.members@ has an entry in
-- @group.delayedJoinResponses@. An empty member table yields 'True'.
--
-- WARNING: Use the list each time this is called, rather than the list
--          when preparing a rebalance. Consider a case: c1, c2 and c3
--          leave the group in order. We should check [c2, c3] then [c3]
--          and finally [], rather than always checking [c2, c3] (this
--          will never be met!).
-- WARNING: Compare with syncing stage (hasReceivedAllSyncs), who uses
--          the list when preparing a rebalance.
haveAllMembersRejoined :: Group -> IO Bool
haveAllMembersRejoined group = H.foldM step True group.members
  where
    -- After the first miss, the remaining members are no longer looked up.
    step allSoFar (memberId, _)
      | allSoFar  = isJust <$> H.lookup group.delayedJoinResponses memberId
      | otherwise = return False
Loading