From 977e494986153dd0e755fb362f1230959ccafcbf Mon Sep 17 00:00:00 2001 From: BenWhitehead Date: Mon, 14 Apr 2025 17:50:23 -0400 Subject: [PATCH] chore: fix a ConcurrentModificationException during BlobReadSession#close() A ConcurrentModificationException could happen while cleaning up an individual read at the same time the session had multiple child streams. --- .../cloud/storage/ObjectReadSessionImpl.java | 117 +++++++----- .../cloud/storage/ObjectReadSessionTest.java | 167 ++++++++++++++++++ 2 files changed, 242 insertions(+), 42 deletions(-) create mode 100644 google-cloud-storage/src/test/java/com/google/cloud/storage/ObjectReadSessionTest.java diff --git a/google-cloud-storage/src/main/java/com/google/cloud/storage/ObjectReadSessionImpl.java b/google-cloud-storage/src/main/java/com/google/cloud/storage/ObjectReadSessionImpl.java index a4d54a17dd..2998f4d8ff 100644 --- a/google-cloud-storage/src/main/java/com/google/cloud/storage/ObjectReadSessionImpl.java +++ b/google-cloud-storage/src/main/java/com/google/cloud/storage/ObjectReadSessionImpl.java @@ -32,12 +32,13 @@ import java.util.ArrayList; import java.util.IdentityHashMap; import java.util.Iterator; +import java.util.List; import java.util.Locale; import java.util.Map.Entry; import java.util.concurrent.ExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.locks.ReentrantLock; -import org.checkerframework.checker.lock.qual.GuardedBy; +import java.util.function.BiFunction; final class ObjectReadSessionImpl implements ObjectReadSession { @@ -49,12 +50,8 @@ final class ObjectReadSessionImpl implements ObjectReadSession { private final Object resource; private final RetryContextProvider retryContextProvider; - @GuardedBy("this.lock") - private final IdentityHashMap children; + private final ConcurrentIdentityMap children; - private final ReentrantLock lock; - - @GuardedBy("this.lock") private volatile boolean open; ObjectReadSessionImpl( @@ -69,8 +66,7 @@ final class ObjectReadSessionImpl implements ObjectReadSession { this.state = state; this.resource = state.getMetadata(); this.retryContextProvider = retryContextProvider; - this.children = new IdentityHashMap<>(); - this.lock = new ReentrantLock(); + this.children = new ConcurrentIdentityMap<>(); this.open = true; } @@ -81,45 +77,35 @@ public Object getResource() { @Override public Projection readAs(ReadProjectionConfig config) { - lock.lock(); - try { - checkState(open, "Session already closed"); - switch (config.getType()) { - case STREAM_READ: - long readId = state.newReadId(); - ObjectReadSessionStreamRead read = - config.cast().newRead(readId, retryContextProvider.create()); - registerReadInState(readId, read); - return read.project(); - case SESSION_USER: - return config.project(this, IOAutoCloseable.noOp()); - default: - throw new IllegalStateException( - String.format( - Locale.US, - "Broken java enum %s value=%s", - ProjectionType.class.getName(), - config.getType().name())); - } - } finally { - lock.unlock(); + checkState(open, "Session already closed"); + switch (config.getType()) { + case STREAM_READ: + long readId = state.newReadId(); + ObjectReadSessionStreamRead read = + config.cast().newRead(readId, retryContextProvider.create()); + registerReadInState(readId, read); + return read.project(); + case SESSION_USER: + return config.project(this, IOAutoCloseable.noOp()); + default: + throw new IllegalStateException( + String.format( + Locale.US, + "Broken java enum %s value=%s", + ProjectionType.class.getName(), + config.getType().name())); } } @Override public void close() throws IOException { - open = false; - lock.lock(); try { - Iterator> it = - children.entrySet().iterator(); - ArrayList> closing = new ArrayList<>(children.size()); - while (it.hasNext()) { - Entry next = it.next(); - ObjectReadSessionStream subStream = next.getKey(); - it.remove(); - closing.add(subStream.closeAsync()); + if (!open) { + return; } + open = false; + List> closing = + children.drainEntries((subStream, subStreamState) -> subStream.closeAsync()); stream.close(); ApiFutures.allAsList(closing).get(); } catch (ExecutionException e) { @@ -127,8 +113,6 @@ public void close() throws IOException { } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new InterruptedIOException(); - } finally { - lock.unlock(); } } @@ -152,4 +136,53 @@ private void registerReadInState(long readId, ObjectReadSessionStreamRead rea newStream.send(request); } } + + @VisibleForTesting + static final class ConcurrentIdentityMap { + private final ReentrantLock lock; + private final IdentityHashMap children; + + @VisibleForTesting + ConcurrentIdentityMap() { + lock = new ReentrantLock(); + children = new IdentityHashMap<>(); + } + + public void put(K key, V value) { + lock.lock(); + try { + children.put(key, value); + } finally { + lock.unlock(); + } + } + + public void remove(K key) { + lock.lock(); + try { + children.remove(key); + } finally { + lock.unlock(); + } + } + + public ArrayList drainEntries(BiFunction f) { + lock.lock(); + try { + Iterator> it = children.entrySet().iterator(); + ArrayList results = new ArrayList<>(children.size()); + while (it.hasNext()) { + Entry entry = it.next(); + K key = entry.getKey(); + V value = entry.getValue(); + it.remove(); + R r = f.apply(key, value); + results.add(r); + } + return results; + } finally { + lock.unlock(); + } + } + } } diff --git a/google-cloud-storage/src/test/java/com/google/cloud/storage/ObjectReadSessionTest.java b/google-cloud-storage/src/test/java/com/google/cloud/storage/ObjectReadSessionTest.java new file mode 100644 index 0000000000..2207117380 --- /dev/null +++ b/google-cloud-storage/src/test/java/com/google/cloud/storage/ObjectReadSessionTest.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.storage; + +import static com.google.cloud.storage.TestUtils.assertAll; +import static com.google.common.truth.Truth.assertThat; + +import com.google.cloud.storage.ObjectReadSessionImpl.ConcurrentIdentityMap; +import com.google.common.base.MoreObjects; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public final class ObjectReadSessionTest { + private static final AtomicInteger vCounter = new AtomicInteger(1); + + private static ListeningExecutorService exec; + + @BeforeClass + public static void beforeClass() { + exec = + MoreExecutors.listeningDecorator( + Executors.newFixedThreadPool(2, new ThreadFactoryBuilder().setDaemon(true).build())); + } + + @AfterClass + public static void afterClass() { + exec.shutdownNow(); + } + + @Test + public void concurrentIdentityMap_basic() throws Exception { + ConcurrentIdentityMap map = new ConcurrentIdentityMap<>(); + + map.put(new Key("k1"), new Value()); + map.put(new Key("k2"), new Value()); + map.put(new Key("k3"), new Value()); + map.put(new Key("k4"), new Value()); + + List strings = map.drainEntries((k, v) -> String.format("%s -> %s", k, v)); + assertThat(strings).hasSize(4); + + String joined = String.join("\n", strings); + assertAll( + () -> assertThat(joined).contains("k1"), + () -> assertThat(joined).contains("k2"), + () -> assertThat(joined).contains("k3"), + () -> assertThat(joined).contains("k4")); + } + + @Test + public void concurrentIdentityMap_multipleThreadsAdding() throws Exception { + ConcurrentIdentityMap map = new ConcurrentIdentityMap<>(); + + CountDownLatch cdl = new CountDownLatch(1); + map.put(new Key("t1k1"), new Value()); + map.put(new Key("t1k2"), new Value()); + + ListenableFuture submitted = + exec.submit( + () -> { + try { + boolean await = cdl.await(3, TimeUnit.SECONDS); + assertThat(await).isTrue(); + map.put(new Key("t2k1"), new Value()); + return true; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + BiFunction f = + (k, v) -> { + cdl.countDown(); + return String.format("%s -> %s", k, v); + }; + List strings = map.drainEntries(f); + assertThat(strings).hasSize(2); + String joined = String.join("\n", strings); + assertAll(() -> assertThat(joined).contains("t1k1"), () -> assertThat(joined).contains("t1k2")); + + submitted.get(1, TimeUnit.SECONDS); + List drain2 = map.drainEntries(f); + assertThat(drain2).hasSize(1); + } + + @Test + public void concurrentIdentityMap_removeAfterDrainClean() throws Exception { + ConcurrentIdentityMap map = new ConcurrentIdentityMap<>(); + + CountDownLatch cdl = new CountDownLatch(1); + map.put(new Key("t1k1"), new Value()); + Key t1k2 = new Key("t1k2"); + map.put(t1k2, new Value()); + + ListenableFuture submit = + exec.submit( + () -> { + try { + boolean await = cdl.await(3, TimeUnit.SECONDS); + assertThat(await).isTrue(); + map.remove(t1k2); + return true; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + BiFunction f = + (k, v) -> { + cdl.countDown(); + return String.format("%s -> %s", k, v); + }; + List strings = map.drainEntries(f); + assertThat(strings).hasSize(2); + String joined = String.join("\n", strings); + assertAll(() -> assertThat(joined).contains("t1k1"), () -> assertThat(joined).contains("t1k2")); + + assertThat(submit.get(1, TimeUnit.SECONDS)).isEqualTo(true); + } + + private static final class Key { + private final String k; + + private Key(String k) { + this.k = k; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("k", k).toString(); + } + } + + private static final class Value { + private final String v = String.format("v/%d", vCounter.getAndIncrement()); + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("v", v).toString(); + } + } +}