Shared mutable state and concurrency #49

Closed
Strydom opened this issue Nov 2, 2022 · 8 comments

Strydom commented Nov 2, 2022

Hello,

I've replicated a bare-bones version of this library so that I can quickly extend it and test new functionality such as streams, SSL and AUTH.
The issue is that I'm seeing unsafe access to the read channel when multiple clients are reading and writing at the same time: clients calling XREAD, INFO, etc. receive the responses meant for the writes (XADD, EXPIRE) instead of their own.

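One plausible mechanism for this symptom, offered as a hypothesis rather than a confirmed diagnosis: every reply lands in one shared queue and is matched to a caller only by arrival order, so a caller that writes a command and then stops before reading its reply (for example, because its coroutine was cancelled or timed out) leaves that reply behind for the next caller. The sketch below simulates this in plain Kotlin; `SimulatedConnection` is a hypothetical stand-in, not part of Netty or the library.

```kotlin
import java.util.concurrent.LinkedBlockingQueue

// Hypothetical stand-in for one connection: the "server" answers commands
// strictly in order, and every reply lands in a single shared queue,
// much like readChannel in the bare-bones client.
class SimulatedConnection {
    private val replies = LinkedBlockingQueue<String>()

    // Writing a command makes the fake server queue "<command> reply".
    fun write(command: String) {
        replies.put("$command reply")
    }

    // Returns the oldest unread reply, whichever caller it was meant for.
    fun read(): String = replies.take()
}

// Returns (reply seen by caller B, reply left in the queue afterwards).
fun simulateAbandonedRead(): Pair<String, String> {
    val conn = SimulatedConnection()

    // Caller A sends XADD, then gives up before reading its reply.
    conn.write("XADD")

    // Caller B sends INFO and reads what it believes is the INFO reply...
    conn.write("INFO")
    val seenByB = conn.read()   // ...but receives the XADD reply A left behind.
    val leftover = conn.read()  // The INFO reply now waits for whoever reads next.
    return seenByB to leftover
}
```

Here caller B receives `XADD reply`, and every later read on that connection stays shifted by one.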
Here is my bare-bones code. Hopefully I've just copied something wrong or missed something, and it isn't an underlying issue with the implementation 🤞

Bare-bones code:
@Suppress("TooGenericExceptionThrown", "TooManyFunctions")
@Service
class NettyRedisClient(
    @Value("\${redis.host}") private val host: String,
    @Value("\${redis.port}") private val port: Int,
    @Value("\${redis.password}") private val password: String
) : ExclusiveObject, DisposableBean {
    private val logger = KotlinLogging.logger {}

    final override val mutex: Mutex = Mutex()
    override val key: ReentrantMutexContextKey = ReentrantMutexContextKey(mutex)

    private val group = NioEventLoopGroup()
    private val bootstrap = Bootstrap()
        .group(group)
        .remoteAddress(host, port)
        .option(ChannelOption.SO_KEEPALIVE, true)
        .channel(NioSocketChannel::class.java)

    val sslContext: SslContext = SslContextBuilder
        .forClient()
        .trustManager(InsecureTrustManagerFactory.INSTANCE)
        .build()

    private var writeChannel: SocketChannel? = null
    private var readChannel: KChannel<RedisMessage>? = null

    fun pipelined(): Pipelined = Pipelined(this)

    suspend fun connect() = withReentrantLock {
        if (!isConnected()) {
            val newReadChannel = KChannel<RedisMessage>(KChannel.UNLIMITED)
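            // Bootstrap keeps only one handler, so a second .handler() call replaces
            // the first; to log traffic, add a LoggingHandler inside channelInitializer().
            val newWriteChannel = bootstrap
                .handler(channelInitializer(newReadChannel))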
                .connect()
                .suspendableAwait() as SocketChannel

            readChannel = newReadChannel
            writeChannel = newWriteChannel

            auth(newWriteChannel)
        }
    }

    private fun channelInitializer(newReadChannel: KChannel<RedisMessage>): ChannelInitializer<SocketChannel> {
        return object : ChannelInitializer<SocketChannel>() {
            override fun initChannel(channel: SocketChannel) {
                val pipeline: ChannelPipeline = channel.pipeline()
                pipeline.addLast(sslContext.newHandler(channel.alloc(), host, port))
                pipeline.addLast(RedisDecoder())
                pipeline.addLast(RedisBulkStringAggregator())
                pipeline.addLast(RedisArrayAggregator())
                pipeline.addLast(RedisEncoder())
                pipeline.addLast(commandHandler(newReadChannel))
            }
        }
    }

    private fun commandHandler(newReadChannel: KChannel<RedisMessage>) = object : ChannelDuplexHandler() {
        override fun write(
            handlerContext: ChannelHandlerContext,
            message: Any,
            promise: ChannelPromise
        ) {
            val commands = (message as String)
                .trim()
                .split(Regex("\\s+"))
                .map { command ->
                    FullBulkStringRedisMessage(
                        ByteBufUtil.writeUtf8(
                            handlerContext.alloc(),
                            command
                        )
                    )
                }
            val request: RedisMessage = ArrayRedisMessage(commands)
            handlerContext.write(request, promise)
        }

        override fun channelRead(handlerContext: ChannelHandlerContext, message: Any) {
            message as RedisMessage
            newReadChannel.trySend(message)
        }

        override fun exceptionCaught(handlerContext: ChannelHandlerContext, cause: Throwable) {
            handlerContext.close()
            newReadChannel.close(cause)
        }
    }

    suspend fun executeCommands(commands: List<String>): List<RedisMessage> = withReentrantLock {
        connect()

        commands.forEach {
            write(it)
        }

        flush()

        commands.map {
            read()
        }
    }

    suspend fun executeCommand(command: String): RedisMessage = withReentrantLock {
        connect()
        writeAndFlush(command)
        read()
    }

    suspend fun info(): String {
        return decode(executeCommand("INFO"))
    }

    suspend fun flushAll(): String {
        return decode(executeCommand("FLUSHALL"))
    }

    suspend fun dbSize(): Long {
        return decode(executeCommand("DBSIZE")).toLong()
    }

    suspend fun xread(streamNames: List<String>, block: Long?): String {
        val streamNamesString = streamNames.joinToString(" ")
        val streamOffsetsString = streamNames.joinToString(" ") { "0-0" }
        val string = "$streamNamesString $streamOffsetsString"
        val command = if (block == null) {
            "XREAD STREAMS $string"
        } else {
            "XREAD BLOCK $block STREAMS $string"
        }

        return decode(executeCommand((command)))
    }

    private suspend fun auth(writeChannel: SocketChannel) = withReentrantLock {
        writeChannel.writeAndFlush("AUTH $password")
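        // A wrong password returns an ErrorRedisMessage, so check the reply type
        // instead of casting (the cast would throw ClassCastException).
        val response = read()
        if (response is SimpleStringRedisMessage && response.content() == "OK") {
            logger.info("AUTH successful")
        } else {
            throw RuntimeException("AUTH failed")
        }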
    }

    private suspend fun isConnected(): Boolean = withReentrantLock {
        if (writeChannel == null || readChannel == null) {
            false
        } else {
            writeChannel!!.isActive
        }
    }

    private suspend fun write(message: String): Unit = withReentrantLock {
        if (!isConnected()) {
            throw RuntimeException("Not yet connected")
        } else {
            writeChannel!!.write(message)
        }
    }

    private suspend fun writeAndFlush(message: String): Unit = withReentrantLock {
        if (!isConnected()) {
            throw RuntimeException("Not yet connected")
        } else {
            writeChannel!!.writeAndFlush(message)
        }
    }

    private suspend fun flush(): Unit = withReentrantLock {
        if (!isConnected()) {
            throw RuntimeException("Not yet connected")
        } else {
            writeChannel!!.flush()
        }
    }

    private suspend fun read(): RedisMessage = withReentrantLock {
        if (!isConnected()) {
            throw RuntimeException("Not yet connected")
        } else {
            readChannel!!.receive()
        }
    }

    override fun destroy() {
        runBlocking {
            withReentrantLock {
                readChannel?.close()
                writeChannel?.close()
                group.shutdownGracefully()
            }
        }
    }
}

@Suppress("TooGenericExceptionThrown")
class Pipelined(private val client: NettyRedisClient) : ExclusiveObject {
    override val mutex: Mutex = Mutex()
    override val key: ReentrantMutexContextKey = ReentrantMutexContextKey(mutex)

    private var done = false
    private val responseFlow = MutableSharedFlow<List<String>>(1)
    private val sharedResponseFlow: Flow<List<String>> = responseFlow.asSharedFlow()
    private val commands = mutableListOf<String>()
    private val commandResponse = mutableListOf<String>()

    suspend fun xadd(streamName: String, keyValues: Map<String, String>): Response {
        val keyValuesString = keyValues.map { "${it.key} ${it.value}" }.joinToString(" ")
        val command = "XADD $streamName * $keyValuesString"

        return add(command)
    }

    suspend fun expire(streamName: String, seconds: Long): Response {
        return add("EXPIRE $streamName $seconds")
    }

    suspend fun execute(): Unit = withReentrantLock {
        if (!done) {
            commandResponse.addAll(executePipeline(commands))
            done = true
            responseFlow.tryEmit(commandResponse.toMutableList())
        }
    }

    private suspend fun add(command: String): Response = withReentrantLock {
        commands.add(command)
        Response(sharedResponseFlow, commands.lastIndex)
    }

    private suspend fun executePipeline(commands: List<String>): List<String> = withReentrantLock {
        val responseMessages = client.executeCommands(commands)

        responseMessages.map { message ->
            decode(message)
        }
    }
}

@Suppress("TooGenericExceptionThrown")
internal fun decode(message: RedisMessage): String {
    return when (message) {
        is ErrorRedisMessage -> message.content()
        is SimpleStringRedisMessage -> message.content()
        is IntegerRedisMessage -> message.value().toString()
        is FullBulkStringRedisMessage -> {
            if (message.isNull) {
                throw RuntimeException("Stream response is null")
            } else {
                // Commands are written as UTF-8, so decode replies as UTF-8, not the platform default.
                message.content().toString(Charsets.UTF_8)
            }
        }

        is ArrayRedisMessage -> {
            message.children().joinToString(" ") { child ->
                decode(child)
            }
        }

        else -> throw NotImplementedError("Message type not implemented")
    }
}

@Suppress("TooGenericExceptionThrown")
class Response internal constructor(
    private val flow: Flow<List<String>>,
    private val index: Int
) {
    suspend operator fun invoke(): String {
        return flow.first().ifEmpty { throw RuntimeException("Operation was cancelled.") }[index]
    }

    suspend fun get(): String = invoke()
}

internal interface ExclusiveObject {
    val mutex: Mutex
    val key: ReentrantMutexContextKey
}

data class ReentrantMutexContextKey(val mutex: Mutex) : CoroutineContext.Key<ReentrantMutexContextElement>
internal class ReentrantMutexContextElement(override val key: ReentrantMutexContextKey) : CoroutineContext.Element

internal suspend inline fun <R> ExclusiveObject.withReentrantLock(crossinline block: suspend () -> R): R {
    if (coroutineContext[key] != null) return block()

    return withContext(ReentrantMutexContextElement(key)) {
        this@withReentrantLock.mutex.withLock {
            block()
        }
    }
}

internal suspend fun ChannelFuture.suspendableAwait(): Channel {
    return suspendCoroutine { continuation ->
        addListener(object : ChannelFutureListener {
            override fun operationComplete(future: ChannelFuture) {
                if (future.isDone && future.isSuccess) {
                    continuation.resume(future.channel())
                } else {
                    continuation.resumeWithException(future.cause())
                }
            }
        })
    }
}
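
A common way to make reply matching robust (a sketch, not this library's implementation) is to register one pending future per command at write time, under the same lock as the write, and to complete the oldest pending future from `channelRead`. The handler then consumes every reply itself, so a caller that gives up only abandons its own future instead of shifting replies onto later callers. `ReplyCorrelator` is a hypothetical name; in coroutine code the futures can be awaited with `kotlinx.coroutines.future.await()`, or `CompletableDeferred` can be used instead.

```kotlin
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentLinkedQueue

// Hypothetical helper that matches replies to requests on one connection.
// Redis answers commands on a connection in the order they were sent, so the
// oldest pending future always belongs to the next reply that arrives.
class ReplyCorrelator<T> {
    private val pending = ConcurrentLinkedQueue<CompletableFuture<T>>()

    // Call under the same lock as the socket write, so queue order matches wire order.
    fun register(): CompletableFuture<T> {
        val future = CompletableFuture<T>()
        pending.add(future)
        return future
    }

    // Call from the I/O thread (e.g. channelRead) for every decoded reply.
    fun onReply(reply: T) {
        val future = pending.poll() ?: error("Reply received with no pending request")
        // Completing an abandoned (cancelled) future is a no-op, but the reply is
        // still consumed here, so it cannot be handed to the next caller.
        future.complete(reply)
    }

    // Call from exceptionCaught/channelInactive to fail every waiting caller.
    fun failAll(cause: Throwable) {
        while (true) {
            val future = pending.poll() ?: break
            future.completeExceptionally(cause)
        }
    }
}
```

In the bare-bones client, `register()` would sit next to each `writeChannel!!.write(...)`, `onReply()` would replace `newReadChannel.trySend(message)`, and `failAll()` would go in `exceptionCaught`.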

Strydom commented Nov 2, 2022

I'm also curious: what was the reasoning for changing the actor in this commit?
e3b1594

crackthecodeabhi (owner) commented Nov 2, 2022

The Mutex is not reentrant; that commit uses a reentrant version of it, which makes coroutine-safe code easier to reason about and implement.

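The reentrancy the owner refers to can be sketched without coroutines: the lock records a marker for the current call chain and skips acquiring when the marker is already present, which is what `withReentrantLock` does with `ReentrantMutexContextElement` in the coroutine context. In this hypothetical `MarkerReentrantLock`, a `Semaphore` plays the role of the non-reentrant `Mutex` and an explicit `held` set plays the role of the context.

```kotlin
import java.util.concurrent.Semaphore

// Hypothetical sketch of marker-based reentrancy. The Semaphore is non-reentrant:
// a second acquire from the same caller would block forever, like a plain Mutex.
// `held` stands in for the coroutine context that carries the lock's marker.
class MarkerReentrantLock {
    private val permit = Semaphore(1)

    fun <R> withLock(held: Set<MarkerReentrantLock>, block: (Set<MarkerReentrantLock>) -> R): R {
        // Already held further up this call chain: run without re-acquiring.
        if (this in held) return block(held)
        permit.acquire()
        try {
            // Pass the marker down so nested calls see the lock as held.
            return block(held + this)
        } finally {
            permit.release()
        }
    }

    fun isFree(): Boolean = permit.availablePermits() == 1
}
```

One caveat of marker-based reentrancy: anything that inherits the marker skips the lock. In `withReentrantLock`, that includes coroutines launched inside `block`, which inherit the `ReentrantMutexContextElement` and could therefore run locked sections concurrently with each other.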

Strydom commented Nov 3, 2022

Any thoughts on why I'm seeing results that imply the lock is not working as intended?


Strydom commented Jan 20, 2023

I didn't realise you needed to create a new connection for every request. That was the cause of the problem.

crackthecodeabhi (owner) commented

@Strydom, a new connection is not created for every request.


Strydom commented Jan 20, 2023

.use closes the connection when it's done, so as far as I know every HTTP request coming into the app will open a new one?

crackthecodeabhi (owner) commented

That is just for the example. You can close the connection when your web app is exiting.
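
The approach the owner describes, one long-lived connection for the whole application rather than one per HTTP request, can be sketched as follows. `FakeConnection` and `SharedClient` are hypothetical; in the bare-bones client, the Spring `@Service` singleton with its `destroy()` method already plays the `SharedClient` role. Sharing one connection is only correct if replies are matched to the requests that caused them.

```kotlin
// Hypothetical stand-in for a Redis connection.
class FakeConnection(val id: Int) : AutoCloseable {
    var closed = false
        private set

    fun execute(command: String): String {
        check(!closed) { "connection $id is closed" }
        return "$command ok"
    }

    override fun close() {
        closed = true
    }
}

// One connection for the whole application: opened lazily on first use,
// shared by every request, and closed once at shutdown.
class SharedClient(open: () -> FakeConnection) : AutoCloseable {
    // The default SYNCHRONIZED mode makes concurrent first requests open only one connection.
    private val connection = lazy(open)

    fun handleRequest(command: String): String = connection.value.execute(command)

    override fun close() {
        // Avoid opening a connection just to close it.
        if (connection.isInitialized()) connection.value.close()
    }
}
```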


Strydom commented Jan 20, 2023

I've tried that, and it didn't work for my bare-bones version of your library; maybe I messed something up. Can you see anything wrong with my code?
The only reason I'm using a bare-bones version is that I needed AUTH, SSL support, and stream commands such as XREAD and XADD.

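Because stream commands such as XADD carry user data, one limitation of the bare-bones `commandHandler` is worth noting: it splits each command string on whitespace, so a field value containing a space turns into extra arguments, which changes the meaning of the command or makes it invalid. A sketch, assuming commands are passed around as argument lists instead of single strings; `encodeCommand` and `xaddArgs` are hypothetical helpers that build the RESP wire format (an array of bulk strings) directly.

```kotlin
import java.io.ByteArrayOutputStream

// Hypothetical helper: encodes one command as a RESP array of bulk strings.
// Each argument is sent as-is, so values containing spaces stay intact.
fun encodeCommand(args: List<String>): ByteArray {
    val out = ByteArrayOutputStream()
    fun writeAscii(text: String) = out.write(text.toByteArray(Charsets.US_ASCII))

    writeAscii("*${args.size}\r\n")            // array header: number of arguments
    for (arg in args) {
        val bytes = arg.toByteArray(Charsets.UTF_8)
        writeAscii("\$${bytes.size}\r\n")      // bulk string length in bytes, not chars
        out.write(bytes)
        writeAscii("\r\n")
    }
    return out.toByteArray()
}

// Hypothetical helper: builds the argument list for XADD <stream> * field value ...
fun xaddArgs(stream: String, fields: Map<String, String>): List<String> =
    listOf("XADD", stream, "*") + fields.flatMap { (field, value) -> listOf(field, value) }
```

With Netty's `RedisEncoder` in the pipeline, the equivalent change is to pass a `List<String>` to `write()` and map each element to a `FullBulkStringRedisMessage`, as the handler already does after splitting.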
2 participants