Skip to content

Commit

Permalink
Merge pull request #104 from hnaderi/79-authentication-failure-notifi…
Browse files Browse the repository at this point in the history
…cation

79 authentication failure notification
  • Loading branch information
hnaderi committed Apr 6, 2023
2 parents 4162c51 + 88a509f commit 0e92cc9
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 8 deletions.
9 changes: 6 additions & 3 deletions modules/client/src/main/scala/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ package lepus.client

import cats.effect.*
import cats.effect.implicits.*
import cats.effect.kernel.Resource.ExitCase.*
import cats.effect.std.Queue
import cats.effect.std.QueueSink
import cats.effect.std.QueueSource
import cats.implicits.*
import fs2.Pipe
import fs2.Stream
Expand Down Expand Up @@ -115,7 +114,11 @@ object Connection {
}
case m: Frame.Method => dispatcher.invoke(m)
case Frame.Heartbeat => state.onHeartbeat
}.onFinalize(state.onClosed).interruptWhen(state.whenClosed)
}.onFinalizeCase {
case Succeeded => state.onClosed
case Errored(e) => state.onFailed(e)
case Canceled => state.onClosed
}.interruptWhen(state.whenClosed)

private[client] def lifetime[F[_]: Temporal](
config: F[NegotiatedConfig],
Expand Down
10 changes: 9 additions & 1 deletion modules/client/src/main/scala/StartupNegotiation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ import fs2.Pipe
import fs2.Pull
import fs2.Stream
import lepus.protocol.ConnectionClass
import lepus.protocol.ConnectionClass.Close
import lepus.protocol.ConnectionClass.Secure
import lepus.protocol.ConnectionClass.Start
import lepus.protocol.ConnectionClass.Tune
import lepus.protocol.Frame
import lepus.protocol.Method
import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*

trait StartupNegotiation[F[_]] {
Expand Down Expand Up @@ -125,6 +127,8 @@ object StartupNegotiation {
)
)
case msg: Tune => afterChallenge(msg)
case Close(ReplyCode.AccessRefused, details, _, _) =>
AuthenticationFailure(details).raiseError
}
private def afterChallenge
: ConnectionClass.Tune => F[NegotiationResult[F]] = {
Expand All @@ -150,7 +154,7 @@ object StartupNegotiation {
ShortString("scala-version") -> ShortString(BuildInfo.scalaVersion),
ShortString("capabilities") -> FieldTable(
ShortString("publisher_confirms") -> true,
// ShortString("authentication_failure_close") -> true,
ShortString("authentication_failure_close") -> true,
ShortString("consumer_cancel_notify") -> true,
ShortString("basic.nack") -> true
// ShortString("connection.blocked") -> true
Expand Down Expand Up @@ -187,6 +191,10 @@ type Negotiation[F[_]] = Frame => F[NegotiationResult[F]]

case object NegotiationError
extends Exception("Error while negotiating with server!")
case class AuthenticationFailure(details: String)
extends Exception(
s"Server refused connection due to authentication failure!\nDetails: $details"
)
case object NoSupportedSASLMechanism
extends Exception(
"Server does not support any of your requested SASL mechanisms!"
Expand Down
11 changes: 7 additions & 4 deletions modules/client/src/main/scala/internal/ConnectionState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ import lepus.protocol.Frame
import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*

import ConnectionState.TerminalState

private[client] trait ConnectionState[F[_]] extends Signal[F, Status] {
def onConnected(config: NegotiatedConfig): F[Unit]
def onOpened: F[Unit]
def onClosed: F[Unit]
def onFailed(ex: Throwable): F[Unit]
def onClosed: F[Unit] = onFailed(TerminalState)
def onCloseRequest: F[Unit]
def onCloseRequest(req: ConnectionClass.Close): F[Unit]
def onHeartbeat: F[Unit]
Expand Down Expand Up @@ -84,9 +87,9 @@ private[client] object ConnectionState {
)
)

override def onClosed: F[Unit] =
hasOpened.complete(Left(TerminalState)) *>
configDef.complete(Left(TerminalState)) *>
override def onFailed(ex: Throwable): F[Unit] =
hasOpened.complete(Left(ex)) *>
configDef.complete(Left(ex)) *>
output.onClose *>
dispatcher.onClose *>
underlying.set(Status.Closed)
Expand Down
11 changes: 11 additions & 0 deletions modules/client/src/test/scala/internal/ConnectionStateSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import lepus.protocol.ConnectionClass
import lepus.protocol.Frame
import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*
import org.scalacheck.Arbitrary
import org.scalacheck.Gen

import java.util.concurrent.TimeoutException
Expand Down Expand Up @@ -95,6 +96,16 @@ class ConnectionStateSuite extends InternalTestSuite {
} yield ()
}

test("config raises underlying error if closed") {
forAllF(Arbitrary.arbitrary[Throwable]) { ex =>
for {
s <- SUT
_ <- s.onFailed(ex)
_ <- s.config.attempt.assertEquals(Left(ex))
} yield ()
}
}

test("Raises error if onConnected is called more than once") {
for {
s <- SUT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ final class FakeConnectionState(
override def onClosed: IO[Unit] =
interactions.add(Interaction.Closed) >> state.set(Status.Closed)

override def onFailed(ex: Throwable): IO[Unit] =
interactions.add(Interaction.Failed(ex)) >> state.set(Status.Closed)

override def onOpened: IO[Unit] = state.set(Status.Opened) >>
openedDef.complete(Right(())) >> interactions.add(Interaction.Opened)

Expand All @@ -89,6 +92,7 @@ object FakeConnectionState {
case Connected(config: NegotiatedConfig)
case CloseRequest(close: ConnectionClass.Close)
case ClientCloseRequest, Opened, Closed, Heartbeat
case Failed(ex: Throwable)
}

def apply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,55 @@ class StartupNegotiationSuite extends InternalTestSuite {
} yield ()
}

// This is a RabbitMQ extension https://www.rabbitmq.com/auth-notification.html
check(
"Notifies authentiaction failure if server closes connection with Access refused"
) {
val serverResponses = fs2.Stream(
method(
ConnectionClass.Start(
0,
9,
FieldTable.empty,
LongString("fake1 fake2"),
locales = LongString("")
)
),
method(ConnectionClass.Secure(LongString("abc"))),
method(
ConnectionClass.Close(
ReplyCode.AccessRefused,
ShortString.empty,
ClassId(0),
MethodId(0)
)
)
)

val expected = List(
method(
ConnectionClass.StartOk(
clientProps,
mechanism = ShortString("fake1"),
response = LongString("initial"),
locale = ShortString("en-US")
)
),
method(ConnectionClass.SecureOk(LongString("abc")))
)

for {
sut <- StartupNegotiation(auth)
send <- ExpectedQueue(expected)
_ <- sut
.pipe(send.assert)(serverResponses)
.compile
.drain
.intercept[AuthenticationFailure]
_ <- sut.config.intercept[AuthenticationFailure]
} yield ()
}

private def method(value: Method) = Frame.Method(ChannelNumber(0), value)
}

Expand Down

0 comments on commit 0e92cc9

Please sign in to comment.