diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala index 0058c0b204e..a7f89cbd0af 100644 --- a/hail/src/main/scala/is/hail/io/fs/FS.scala +++ b/hail/src/main/scala/is/hail/io/fs/FS.scala @@ -192,9 +192,6 @@ abstract class FSSeekableInputStream extends InputStream with Seekable { } else { bb.clear() bb.limit(0) - if (bb.remaining() != 0) { - assert(false, bb.remaining().toString()) - } physicalSeek(newPos) } pos = newPos diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index 45f08874086..6d712ff0bc4 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -8,8 +8,8 @@ import com.google.cloud.storage.Storage.{BlobGetOption, BlobListOption, BlobWrit import com.google.cloud.storage.{Blob, BlobId, BlobInfo, Storage, StorageException, StorageOptions} import com.google.cloud.{ReadChannel, WriteChannel} import is.hail.io.fs.FSUtil.dropTrailingSlash -import is.hail.services.retryTransientErrors -import is.hail.utils.fatal +import is.hail.services.{retryTransientErrors, isTransientError} +import is.hail.utils._ import org.apache.log4j.Logger import java.io.{ByteArrayInputStream, FileNotFoundException, IOException} @@ -137,39 +137,27 @@ class GoogleStorageFS( } private[this] def retryIfRequesterPays[T, U]( - exc: Exception, - message: String, - code: Int, + exc: Throwable, makeRequest: Seq[U] => T, makeUserProjectOption: String => U, bucket: String ): T = { - if (message == null) { - throw exc - } - - val probablyNeedsRequesterPays = message.equals("userProjectMissing") || (code == 400 && message.contains("requester pays")) - if (!probablyNeedsRequesterPays) { + if (isRequesterPaysException(exc)) { + makeRequest(requesterPaysOptions(bucket, makeUserProjectOption)) + } else { throw exc } - - makeRequest(requesterPaysOptions(bucket, makeUserProjectOption)) } - def retryIfRequesterPays[T, U]( - exc: Throwable, - makeRequest: Seq[U] => T, - makeUserProjectOption: String => U, - bucket: String - ): T = exc match { + def isRequesterPaysException(exc: Throwable): Boolean = exc match { case exc: IOException if exc.getCause() != null => - retryIfRequesterPays(exc.getCause(), makeRequest, makeUserProjectOption, bucket) + isRequesterPaysException(exc.getCause()) case exc: StorageException => - retryIfRequesterPays(exc, exc.getMessage(), exc.getCode(), makeRequest, makeUserProjectOption, bucket) + exc.getMessage != null && (exc.getMessage.equals("userProjectMissing") || (exc.getCode == 400 && exc.getMessage.contains("requester pays"))) case exc: GoogleJsonResponseException => - retryIfRequesterPays(exc, exc.getMessage(), exc.getStatusCode(), makeRequest, makeUserProjectOption, bucket) + exc.getMessage != null && (exc.getMessage.equals("userProjectMissing") || (exc.getStatusCode == 400 && exc.getMessage.contains("requester pays"))) case exc: Throwable => - throw exc + false } private[this] def handleRequesterPays[T, U]( @@ -213,30 +201,28 @@ class GoogleStorageFS( val is: SeekableInputStream = new FSSeekableInputStream { private[this] var reader: ReadChannel = null - - private[this] def retryingRead(): Int = { - retryTransientErrors( - { reader.read(bb) }, - reset = Some({ () => reader.seek(getPosition) }) - ) - } + private[this] var options: Option[Seq[BlobSourceOption]] = None private[this] def readHandlingRequesterPays(bb: ByteBuffer): Int = { - if (reader != null) { - retryingRead() - } else { - handleRequesterPays( - { (options: Seq[BlobSourceOption]) => - reader = retryTransientErrors { - storage.reader(url.bucket, url.path, options:_*) - } + while (true) { + try { + if (reader == null) { + val opts = options.getOrElse(FastSeq()) + reader = storage.reader(url.bucket, url.path, opts:_*) reader.seek(getPosition) - retryingRead() - }, - BlobSourceOption.userProject, - url.bucket - ) + } + return reader.read(bb) + } catch { + case exc: Exception if isRequesterPaysException(exc) && options.isEmpty => + reader = null + bb.clear() + options = Some(requesterPaysOptions(url.bucket, BlobSourceOption.userProject)) + case exc: Exception if isTransientError(exc) => + reader = null + bb.clear() + } } + throw new RuntimeException("unreachable") } override def close(): Unit = {