Skip to content

Commit

Permalink
KAFKA-15481: Fix concurrency bug in RemoteIndexCache (apache#14483)
Browse files Browse the repository at this point in the history
RemoteIndexCache has a concurrency bug which leads to IOException while fetching data from remote tier.

The bug can be reproduced with the following sequence of events:

Thread 1 (cache thread): invalidates the entry, removalListener is invoked async, so the files have not been renamed to "deleted" suffix yet.
Thread 2 (fetch thread): tries to find the entry in the cache, doesn't find it because it has been removed by Thread 1, fetches the entry from S3, and writes it to the existing file (using replace-existing)
Thread 1: async removalListener is invoked, acquires a lock on old entry (which has been removed from cache), it renames the file to "deleted" and starts deleting it
Thread 2: Tries to create in-memory/mmapped index, but doesn't find the file and hence, creates a new file of size 2GB in AbstractIndex constructor. JVM returns an error as it won't allow creation of 2GB random access file.

This commit fixes the bug by using EvictionListener instead of RemovalListener to perform the eviction atomically with the file rename. It handles the manual removal (not handled by EvictionListener) by using computeIfAbsent() and enforcing atomic cache removal & file rename.

Reviewers: Luke Chen <showuon@gmail.com>, Divij Vaidya <diviv@amazon.com>, Arpit Goyal
<goyal.arpit.91@gmail.com>, Kamal Chandraprakash <kamal.chandraprakash@gmail.com>
  • Loading branch information
jeel2420 authored and mjsax committed Nov 22, 2023
1 parent bad8345 commit 10dc795
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 60 deletions.
177 changes: 130 additions & 47 deletions core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala
Expand Up @@ -23,7 +23,7 @@ import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentId, RemoteLogSegmentMetadata, RemoteResourceNotFoundException, RemoteStorageManager}
import org.apache.kafka.server.util.MockTime
import org.apache.kafka.storage.internals.log.RemoteIndexCache.{REMOTE_LOG_INDEX_CACHE_CLEANER_THREAD, remoteOffsetIndexFile, remoteOffsetIndexFileName, remoteTimeIndexFile, remoteTimeIndexFileName, remoteTransactionIndexFile, remoteTransactionIndexFileName}
import org.apache.kafka.storage.internals.log.RemoteIndexCache.{DIR_NAME, REMOTE_LOG_INDEX_CACHE_CLEANER_THREAD, remoteOffsetIndexFile, remoteOffsetIndexFileName, remoteTimeIndexFile, remoteTimeIndexFileName, remoteTransactionIndexFile, remoteTransactionIndexFileName}
import org.apache.kafka.storage.internals.log.{AbortedTxn, CorruptIndexException, LogFileUtils, OffsetIndex, OffsetPosition, RemoteIndexCache, TimeIndex, TransactionIndex}
import org.apache.kafka.test.{TestUtils => JTestUtils}
import org.junit.jupiter.api.Assertions._
Expand All @@ -32,14 +32,15 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
import org.mockito.invocation.InvocationOnMock
import org.mockito.Mockito._
import org.slf4j.{Logger, LoggerFactory}

import java.io.{File, FileInputStream, IOException, PrintWriter}
import java.nio.file.{Files, Paths}
import java.io.{File, FileInputStream, IOException, PrintWriter, UncheckedIOException}
import java.nio.file.{Files, NoSuchFileException, Paths}
import java.util
import java.util.Collections
import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
import java.util.{Collections, Optional}
import java.util.concurrent.{CountDownLatch, Executors, Future, TimeUnit}
import scala.collection.mutable

class RemoteIndexCacheTest {
Expand Down Expand Up @@ -73,9 +74,9 @@ class RemoteIndexCacheTest {
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
val timeIdx = createTimeIndexForSegmentMetadata(metadata)
val txnIdx = createTxIndexForSegmentMetadata(metadata)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
Expand Down Expand Up @@ -152,8 +153,8 @@ class RemoteIndexCacheTest {
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
val timeIdx = createTimeIndexForSegmentMetadata(metadata)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
Expand Down Expand Up @@ -262,7 +263,7 @@ class RemoteIndexCacheTest {
}

@Test
def testCacheEntryIsDeletedOnInvalidation(): Unit = {
def testCacheEntryIsDeletedOnRemoval(): Unit = {
def getIndexFileFromDisk(suffix: String) = {
Files.walk(tpDir.toPath)
.filter(Files.isRegularFile(_))
Expand All @@ -284,8 +285,8 @@ class RemoteIndexCacheTest {
// no expired entries yet
assertEquals(0, cache.expiredIndexes.size, "expiredIndex queue should be zero at start of test")

// invalidate the cache. it should async mark the entry for removal
cache.internalCache.invalidate(internalIndexKey)
// call remove function to mark the entry for removal
cache.remove(internalIndexKey)

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => cacheEntry.isMarkedForCleanup,
Expand All @@ -304,13 +305,13 @@ class RemoteIndexCacheTest {
verify(cacheEntry.txnIndex).renameTo(any(classOf[File]))

// verify no index files on disk
assertFalse(getIndexFileFromDisk(LogFileUtils.INDEX_FILE_SUFFIX).isPresent,
assertFalse(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.INDEX_FILE_SUFFIX).isPresent,
s"Offset index file should not be present on disk at ${tpDir.toPath}")
assertFalse(getIndexFileFromDisk(LogFileUtils.TXN_INDEX_FILE_SUFFIX).isPresent,
assertFalse(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.TXN_INDEX_FILE_SUFFIX).isPresent,
s"Txn index file should not be present on disk at ${tpDir.toPath}")
assertFalse(getIndexFileFromDisk(LogFileUtils.TIME_INDEX_FILE_SUFFIX).isPresent,
assertFalse(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.TIME_INDEX_FILE_SUFFIX).isPresent,
s"Time index file should not be present on disk at ${tpDir.toPath}")
assertFalse(getIndexFileFromDisk(LogFileUtils.DELETED_FILE_SUFFIX).isPresent,
assertFalse(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.DELETED_FILE_SUFFIX).isPresent,
s"Index file marked for deletion should not be present on disk at ${tpDir.toPath}")
}

Expand Down Expand Up @@ -558,16 +559,94 @@ class RemoteIndexCacheTest {
verifyFetchIndexInvocation(count = 1)
}

@Test
def testConcurrentRemoveReadForCache(): Unit = {
// Create a spy Cache Entry
val rlsMetadata = new RemoteLogSegmentMetadata(RemoteLogSegmentId.generateNew(idPartition), baseOffset, lastOffset,
time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))

val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))
val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))
val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))

val spyEntry = spy(new RemoteIndexCache.Entry(offsetIndex, timeIndex, txIndex))
cache.internalCache.put(rlsMetadata.remoteLogSegmentId().id(), spyEntry)

assertCacheSize(1)

var entry: RemoteIndexCache.Entry = null

val latchForCacheRead = new CountDownLatch(1)
val latchForCacheRemove = new CountDownLatch(1)
val latchForTestWait = new CountDownLatch(1)

var markForCleanupCallCount = 0

doAnswer((invocation: InvocationOnMock) => {
markForCleanupCallCount += 1

if (markForCleanupCallCount == 1) {
// Signal the CacheRead to unblock itself
latchForCacheRead.countDown()
// Wait for signal to start renaming the files
latchForCacheRemove.await()
// Calling the markForCleanup() actual method to start renaming the files
invocation.callRealMethod()
// Signal TestWait to unblock itself so that test can be completed
latchForTestWait.countDown()
}
}).when(spyEntry).markForCleanup()

val removeCache = (() => {
cache.remove(rlsMetadata.remoteLogSegmentId().id())
}): Runnable

val readCache = (() => {
// Wait for signal to start CacheRead
latchForCacheRead.await()
entry = cache.getIndexEntry(rlsMetadata)
// Signal the CacheRemove to start renaming the files
latchForCacheRemove.countDown()
}): Runnable

val executor = Executors.newFixedThreadPool(2)
try {
val removeCacheFuture: Future[_] = executor.submit(removeCache: Runnable)
val readCacheFuture: Future[_] = executor.submit(readCache: Runnable)

// Verify both tasks are completed without any exception
removeCacheFuture.get()
readCacheFuture.get()

// Wait for signal to complete the test
latchForTestWait.await()

// We can't determine read thread or remove thread will go first so if,
// 1. Read thread go first, cache file should not exist and cache size should be zero.
// 2. Remove thread go first, cache file should present and cache size should be one.
// so basically here we are making sure that if cache existed, the cache file should exist,
// and if cache is non-existed, the cache file should not exist.
if (getIndexFileFromRemoteCacheDir(cache, LogFileUtils.INDEX_FILE_SUFFIX).isPresent) {
assertCacheSize(1)
} else {
assertCacheSize(0)
}
} finally {
executor.shutdownNow()
}

}

@Test
def testMultipleIndexEntriesExecutionInCorruptException(): Unit = {
reset(rsm)
when(rsm.fetchIndex(any(classOf[RemoteLogSegmentMetadata]), any(classOf[IndexType])))
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
val timeIdx = createTimeIndexForSegmentMetadata(metadata)
val txnIdx = createTxIndexForSegmentMetadata(metadata)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
// Create corrupted index file
createCorruptTimeIndexOffsetFile(tpDir)
Expand Down Expand Up @@ -603,9 +682,9 @@ class RemoteIndexCacheTest {
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
val timeIdx = createTimeIndexForSegmentMetadata(metadata)
val txnIdx = createTxIndexForSegmentMetadata(metadata)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
Expand All @@ -629,13 +708,6 @@ class RemoteIndexCacheTest {
val remoteIndexCacheDir = cache.cacheDir()
val tempSuffix = ".tmptest"

def getRemoteCacheIndexFileFromDisk(suffix: String) = {
Files.walk(remoteIndexCacheDir.toPath)
.filter(Files.isRegularFile(_))
.filter(path => path.getFileName.toString.endsWith(suffix))
.findAny()
}

def renameRemoteCacheIndexFileFromDisk(suffix: String) = {
Files.walk(remoteIndexCacheDir.toPath)
.filter(Files.isRegularFile(_))
Expand All @@ -650,7 +722,7 @@ class RemoteIndexCacheTest {
Files.copy(entry.txnIndex().file().toPath(), Paths.get(Utils.replaceSuffix(entry.txnIndex().file().getPath(), "", tempSuffix)))
Files.copy(entry.timeIndex().file().toPath(), Paths.get(Utils.replaceSuffix(entry.timeIndex().file().getPath(), "", tempSuffix)))

cache.internalCache().invalidate(rlsMetadata.remoteLogSegmentId().id())
cache.remove(rlsMetadata.remoteLogSegmentId().id())

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => entry.isMarkedForCleanup,
Expand All @@ -666,9 +738,9 @@ class RemoteIndexCacheTest {
// Index Files already exist ,rsm should not fetch them again.
verifyFetchIndexInvocation(count = 1)
// verify index files on disk
assertTrue(getRemoteCacheIndexFileFromDisk(LogFileUtils.INDEX_FILE_SUFFIX).isPresent, s"Offset index file should be present on disk at ${remoteIndexCacheDir.toPath}")
assertTrue(getRemoteCacheIndexFileFromDisk(LogFileUtils.TXN_INDEX_FILE_SUFFIX).isPresent, s"Txn index file should be present on disk at ${remoteIndexCacheDir.toPath}")
assertTrue(getRemoteCacheIndexFileFromDisk(LogFileUtils.TIME_INDEX_FILE_SUFFIX).isPresent, s"Time index file should be present on disk at ${remoteIndexCacheDir.toPath}")
assertTrue(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.INDEX_FILE_SUFFIX).isPresent, s"Offset index file should be present on disk at ${remoteIndexCacheDir.toPath}")
assertTrue(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.TXN_INDEX_FILE_SUFFIX).isPresent, s"Txn index file should be present on disk at ${remoteIndexCacheDir.toPath}")
assertTrue(getIndexFileFromRemoteCacheDir(cache, LogFileUtils.TIME_INDEX_FILE_SUFFIX).isPresent, s"Time index file should be present on disk at ${remoteIndexCacheDir.toPath}")
}

@ParameterizedTest
Expand All @@ -678,9 +750,9 @@ class RemoteIndexCacheTest {
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
val timeIdx = createTimeIndexForSegmentMetadata(metadata)
val txnIdx = createTxIndexForSegmentMetadata(metadata)
val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
// Create corrupt index file return from RSM
createCorruptedIndexFile(testIndexType, tpDir)
Expand Down Expand Up @@ -725,7 +797,7 @@ class RemoteIndexCacheTest {
// verify deleted file exists on disk
assertTrue(getRemoteCacheIndexFileFromDisk(LogFileUtils.DELETED_FILE_SUFFIX).isPresent, s"Deleted Offset index file should be present on disk at ${remoteIndexCacheDir.toPath}")

cache.internalCache().invalidate(rlsMetadata.remoteLogSegmentId().id())
cache.remove(rlsMetadata.remoteLogSegmentId().id())

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => entry.isMarkedForCleanup,
Expand All @@ -748,9 +820,9 @@ class RemoteIndexCacheTest {
= RemoteLogSegmentId.generateNew(idPartition)): RemoteIndexCache.Entry = {
val rlsMetadata = new RemoteLogSegmentMetadata(remoteLogSegmentId, baseOffset, lastOffset,
time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))
val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata))
val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata))
val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata))
val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata, tpDir))
val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata, tpDir))
val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata, tpDir))
spy(new RemoteIndexCache.Entry(offsetIndex, timeIndex, txIndex))
}

Expand Down Expand Up @@ -778,8 +850,8 @@ class RemoteIndexCacheTest {
}
}

private def createTxIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TransactionIndex = {
val txnIdxFile = remoteTransactionIndexFile(tpDir, metadata)
private def createTxIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File): TransactionIndex = {
val txnIdxFile = remoteTransactionIndexFile(dir, metadata)
txnIdxFile.createNewFile()
new TransactionIndex(metadata.startOffset(), txnIdxFile)
}
Expand All @@ -800,14 +872,14 @@ class RemoteIndexCacheTest {
return new TransactionIndex(100L, txnIdxFile)
}

private def createTimeIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TimeIndex = {
private def createTimeIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File): TimeIndex = {
val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
new TimeIndex(remoteTimeIndexFile(tpDir, metadata), metadata.startOffset(), maxEntries * 12)
new TimeIndex(remoteTimeIndexFile(dir, metadata), metadata.startOffset(), maxEntries * 12)
}

private def createOffsetIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata) = {
private def createOffsetIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File) = {
val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
new OffsetIndex(remoteOffsetIndexFile(tpDir, metadata), metadata.startOffset(), maxEntries * 8)
new OffsetIndex(remoteOffsetIndexFile(dir, metadata), metadata.startOffset(), maxEntries * 8)
}

private def generateRemoteLogSegmentMetadata(size: Int,
Expand Down Expand Up @@ -860,4 +932,15 @@ class RemoteIndexCacheTest {
createCorruptTxnIndexForSegmentMetadata(dir, rlsMetadata)
}
}

private def getIndexFileFromRemoteCacheDir(cache: RemoteIndexCache, suffix: String) = {
try {
Files.walk(cache.cacheDir().toPath())
.filter(Files.isRegularFile(_))
.filter(path => path.getFileName.toString.endsWith(suffix))
.findAny()
} catch {
case e @ (_ : NoSuchFileException | _ : UncheckedIOException) => Optional.empty()
}
}
}

0 comments on commit 10dc795

Please sign in to comment.