Skip to content

Commit

Permalink
Added invalidate-bad-cache to call caching options
Browse files Browse the repository at this point in the history
  • Loading branch information
cjllanwarne committed Oct 18, 2016
1 parent e5b36f8 commit ebba140
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 57 deletions.
5 changes: 5 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ workflow-options {

// Optional call-caching configuration.
call-caching {
# Allows re-use of existing results for jobs you've already run
enabled = false

# Whether to invalidate a cache result forever if we cannot reuse them.
# Disable this if you expect some users to be unable to copy some other users results:
# invalidate-bad-cache-results = true
}

google {
Expand Down
41 changes: 25 additions & 16 deletions core/src/main/scala/cromwell/core/callcaching/CallCachingMode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ sealed trait CallCachingMode {
* Return an equivalent of this call caching mode with READ disabled.
*/
val withoutRead: CallCachingMode

val withoutWrite: CallCachingMode

val readFromCache = false
Expand All @@ -19,22 +18,32 @@ case object CallCachingOff extends CallCachingMode {
override val withoutWrite = this
}

case class CallCachingActivity(readWriteMode: ReadWriteMode) extends CallCachingMode {
override val readFromCache = readWriteMode.r
override val writeToCache = readWriteMode.w
override lazy val withoutRead: CallCachingMode = if (!writeToCache) CallCachingOff else this.copy(readWriteMode = WriteCache)
override lazy val withoutWrite: CallCachingMode = if (!readFromCache) CallCachingOff else this.copy(readWriteMode = ReadCache)
override val toString = readWriteMode.toString
final case class ReadWriteMode(r: Boolean, w: Boolean)

sealed trait CallCachingActivity extends CallCachingMode {
val options: CallCachingOptions
}

sealed trait ReadWriteMode {
val r: Boolean = true
val w: Boolean = true
case class ReadCache(options: CallCachingOptions = CallCachingOptions()) extends CallCachingActivity {
override val readFromCache = true
override val writeToCache = false
override lazy val withoutRead: CallCachingMode = CallCachingOff
override lazy val withoutWrite: CallCachingMode = this
override val toString = "ReadCache"
}
case class WriteCache(options: CallCachingOptions = CallCachingOptions()) extends CallCachingActivity {
override val readFromCache = false
override val writeToCache = true
override lazy val withoutRead: CallCachingMode = this
override lazy val withoutWrite: CallCachingMode = CallCachingOff
override val toString = "WriteCache"
}
case class ReadAndWriteCache(options: CallCachingOptions = CallCachingOptions()) extends CallCachingActivity {
override val readFromCache = true
override val writeToCache = true
override lazy val withoutRead: CallCachingMode = WriteCache(options)
override lazy val withoutWrite: CallCachingMode = ReadCache(options)
override val toString = "ReadAndWriteCache"
}
case object ReadCache extends ReadWriteMode { override val w = false }
case object WriteCache extends ReadWriteMode { override val r = false }
case object ReadAndWriteCache extends ReadWriteMode

sealed trait DockerHashingType
case object HashDockerName extends DockerHashingType
case object HashDockerNameAndLookupDockerHash extends DockerHashingType
final case class CallCachingOptions(invalidateBadCacheResults: Boolean = true)
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,17 @@ object MaterializeWorkflowDescriptorActor {
}

val enabled = conf.as[Option[Boolean]]("call-caching.enabled").getOrElse(false)
val invalidateBadCacheResults = conf.as[Option[Boolean]]("call-caching.invalidate-bad-cache-results").getOrElse(true)
val callCachingOptions = CallCachingOptions(invalidateBadCacheResults)
if (enabled) {
val readFromCache = readOptionalOption(ReadFromCache)
val writeToCache = readOptionalOption(WriteToCache)

(readFromCache |@| writeToCache) map {
case (false, false) => CallCachingOff
case (true, false) => CallCachingActivity(ReadCache)
case (false, true) => CallCachingActivity(WriteCache)
case (true, true) => CallCachingActivity(ReadAndWriteCache)
case (true, false) => ReadCache(callCachingOptions)
case (false, true) => WriteCache(callCachingOptions)
case (true, true) => ReadAndWriteCache(callCachingOptions)
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,7 @@ class EngineJobExecutionActor(replyTo: ActorRef,
saveJobCompletionToJobStore(data.withSuccessResponse(response))
case Event(response: BackendJobExecutionResponse, data @ ResponsePendingData(_, _, _, Some(cacheHit))) =>
response match {
case f: BackendJobFailedResponse =>
invalidateCacheHit(cacheHit.cacheResultIds.head)
log.error(f.throwable, "Failed copying cache results for job {}, invalidating cache entry.", jobDescriptorKey)
goto(InvalidatingCacheEntry)
case f: BackendJobFailedResponse => invalidateCacheHitAndTransition(cacheHit.cacheResultIds.head, data, f.throwable)
case _ => runJob(data)
}

Expand Down Expand Up @@ -384,6 +381,22 @@ class EngineJobExecutionActor(replyTo: ActorRef,
()
}

private def invalidateCacheHitAndTransition(cacheId: CallCachingEntryId, data: ResponsePendingData, reason: Throwable) = {
val invalidationRequired = effectiveCallCachingMode match {
case CallCachingOff =>
log.error("Should not be calling invalidateCacheHit if call caching is off!")
false
case activity: CallCachingActivity => activity.options.invalidateBadCacheResults
}
if (invalidationRequired) {
log.error(reason, "Failed copying cache results for job {}, invalidating cache entry.", jobDescriptorKey)
invalidateCacheHit(cacheId)
goto(InvalidatingCacheEntry)
} else {
handleCacheInvalidatedResponse(CallCacheInvalidationUnnecessary, data)
}
}

protected def invalidateCacheHit(cacheId: CallCachingEntryId): Unit = {
val callCache = new CallCache(SingletonServicesStore.databaseInterface)
context.actorOf(CallCacheInvalidateActor.props(callCache, cacheId), s"CallCacheInvalidateActor${cacheId.id}-$tag")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ object CallCacheInvalidateActor {

sealed trait CallCacheInvalidatedResponse
case object CallCacheInvalidatedSuccess extends CallCacheInvalidatedResponse
case object CallCacheInvalidationUnnecessary extends CallCacheInvalidatedResponse
case class CallCacheInvalidatedFailure(t: Throwable) extends CallCacheInvalidatedResponse
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ class EJHADataSpec extends FlatSpec with Matchers {
val allHashKeys = Set(hashKey1, hashKey2, hashKey3, hashKey4, hashKey5)

it should "create lists appropriately in the apply method" in {
val readWriteData = EJHAData(allHashKeys, CallCachingActivity(ReadAndWriteCache))
val readWriteData = EJHAData(allHashKeys, ReadAndWriteCache())
readWriteData.remainingCacheChecks should be(allHashKeys)
readWriteData.remainingHashesNeeded should be(allHashKeys)

val readOnlyData = EJHAData(allHashKeys, CallCachingActivity(ReadCache))
val readOnlyData = EJHAData(allHashKeys, ReadCache())
readOnlyData.remainingCacheChecks should be(allHashKeys)
readOnlyData.remainingHashesNeeded should be(Set.empty[HashKey])

val writeOnlyData = EJHAData(allHashKeys, CallCachingActivity(WriteCache))
val writeOnlyData = EJHAData(allHashKeys, WriteCache())
writeOnlyData.remainingCacheChecks should be(Set.empty[HashKey])
writeOnlyData.remainingHashesNeeded should be(allHashKeys)
}


it should "accumulate new hashes" in {
val data = EJHAData(allHashKeys, CallCachingActivity(WriteCache))
val data = EJHAData(allHashKeys, WriteCache())
data.hashesKnown should be(Set.empty)
data.remainingHashesNeeded should be(allHashKeys)
data.allHashesKnown should be(false)
Expand All @@ -51,7 +51,7 @@ class EJHADataSpec extends FlatSpec with Matchers {
}

it should "intersect new cache meta info result IDs for cache hits" in {
val data = EJHAData(allHashKeys, CallCachingActivity(ReadCache))
val data = EJHAData(allHashKeys, ReadCache())
data.possibleCacheResults should be(None)
data.allCacheResultsIntersected should be(false)
data.isDefinitelyCacheHit should be(false)
Expand All @@ -75,7 +75,7 @@ class EJHADataSpec extends FlatSpec with Matchers {
}

it should "intersect new cache meta info result IDs for cache misses" in {
val data = EJHAData(allHashKeys, CallCachingActivity(ReadCache))
val data = EJHAData(allHashKeys, ReadCache())

// To save you time I'll just tell you: the intersection of all these sets is empty Set()
val cacheLookupResults: List[CacheResultMatchesForHashes] = List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@ class EngineJobHashingActorSpec extends TestKit(new CromwellTestkitSpec.TestWork

implicit val actorSystem: ActorSystem = system

val readModes = List(CallCachingActivity(ReadCache), CallCachingActivity(ReadAndWriteCache))
val writeModes = List(CallCachingActivity(WriteCache), CallCachingActivity(ReadAndWriteCache))
val allModes = List(CallCachingActivity(ReadCache), CallCachingActivity(WriteCache), CallCachingActivity(ReadAndWriteCache))
val readCache = ReadCache()
val writeCache = WriteCache()
val readAndWriteCache = ReadAndWriteCache()

val readModes = List(readCache, readAndWriteCache)
val writeModes = List(writeCache, readAndWriteCache)
val allModes = List(readCache, writeCache, readAndWriteCache)

"Engine job hashing actor" must {
allModes foreach { activity =>
val expectation = activity.readWriteMode match {
case ReadCache => "cache hit"
case WriteCache => "hashes"
case ReadAndWriteCache => "cache hit and hashes"
val expectation = activity match {
case ReadCache(_) => "cache hit"
case WriteCache(_) => "hashes"
case ReadAndWriteCache(_) => "cache hit and hashes"
}

s"Respect the CallCachingMode and report back $expectation for the ${activity.readWriteMode} activity" in {
s"Respect the CallCachingMode and report back $expectation for the $activity activity" in {
val singleCallCachingEntryIdSet = Set(CallCachingEntryId(1))
val replyTo = TestProbe()
val deathWatch = TestProbe()
Expand All @@ -57,7 +61,7 @@ class EngineJobHashingActorSpec extends TestKit(new CromwellTestkitSpec.TestWork
deathWatch.expectTerminated(ejha, 5 seconds)
}

s"Wait for requests to the FileHashingActor for the ${activity.readWriteMode} activity" in {
s"Wait for requests to the FileHashingActor for the $activity activity" in {
val singleCallCachingEntryIdSet = Set(CallCachingEntryId(1))
val replyTo = TestProbe()
val fileHashingActor = TestProbe()
Expand Down Expand Up @@ -96,7 +100,7 @@ class EngineJobHashingActorSpec extends TestKit(new CromwellTestkitSpec.TestWork
deathWatch.expectTerminated(ejha, 5 seconds)
}

s"Cache miss for bad FileHashingActor results but still return hashes in the ${activity.readWriteMode} activity" in {
s"Cache miss for bad FileHashingActor results but still return hashes in the $activity activity" in {
val singleCallCachingEntryIdSet = Set(CallCachingEntryId(1))
val replyTo = TestProbe()
val fileHashingActor = TestProbe()
Expand Down Expand Up @@ -139,7 +143,7 @@ class EngineJobHashingActorSpec extends TestKit(new CromwellTestkitSpec.TestWork
deathWatch.expectTerminated(ejha, 5 seconds)
}

s"Detect call cache misses for the ${activity.readWriteMode} activity" in {
s"Detect call cache misses for the $activity activity" in {
val singleCallCachingEntryIdSet = Set(CallCachingEntryId(1))
val replyTo = TestProbe()
val deathWatch = TestProbe()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package cromwell.engine.workflow.lifecycle.execution.ejea
import cats.data.NonEmptyList
import cromwell.engine.workflow.lifecycle.execution.EngineJobExecutionActor._
import EngineJobExecutionActorSpec._
import cromwell.core.callcaching.CallCachingMode
import cromwell.core.callcaching._
import cromwell.engine.workflow.lifecycle.execution.callcaching.EngineJobHashingActor.{CacheHit, CallCacheHashes, EJHAResponse, HashError}
import cromwell.engine.workflow.lifecycle.execution.callcaching.CallCachingEntryId

import scala.util.{Failure, Success, Try}
import cromwell.engine.workflow.lifecycle.execution.ejea.HasJobSuccessResponse.SuccessfulCallCacheHashes

Expand Down Expand Up @@ -97,20 +98,44 @@ class EjeaBackendIsCopyingCachedOutputsSpec extends EngineJobExecutionActorSpec
}
}

if (mode.readFromCache) {
s"invalidate a call for caching if backend coping failed when it was going to receive $hashComboName, if call caching is $mode" in {
ejea = ejeaInBackendIsCopyingCachedOutputsState(initialHashData, mode)
// Send the response from the copying actor
ejea ! failureNonRetryableResponse

expectInvalidateCallCacheActor(cacheId)
eventually { ejea.stateName should be(InvalidatingCacheEntry) }
ejea.stateData should be(ResponsePendingData(helper.backendJobDescriptor, helper. bjeaProps, initialHashData, cacheHit))
eventually {
ejea.stateName should be(InvalidatingCacheEntry)
}
ejea.stateData should be(ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, initialHashData, cacheHit))
}

s"not invalidate a call for caching if backend coping failed when invalidation is disabled, when it was going to receive $hashComboName, if call caching is $mode" in {
val invalidationDisabledOptions = CallCachingOptions(invalidateBadCacheResults = false)
val cacheInvalidationDisabledMode = mode match {
case r: ReadCache => ReadCache(invalidationDisabledOptions)
case rw: ReadAndWriteCache => ReadAndWriteCache(invalidationDisabledOptions)
case _ => fail(s"Mode $mode not appropriate for cache invalidation tests")
}
ejea = ejeaInBackendIsCopyingCachedOutputsState(initialHashData, cacheInvalidationDisabledMode)
// Send the response from the copying actor
ejea ! failureNonRetryableResponse

eventually {
ejea.stateName should be(RunningJob)
}
// Make sure we didn't start invalidating anything:
helper.invalidateCacheActorCreations.hasExactlyOne should be(false)
ejea.stateData should be(ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, initialHashData, None))
}

s"invalidate a call for caching if backend coping failed (preserving and received hashes) when call caching is $mode, the EJEA has $hashComboName and then gets a success result" in {
ejea = ejeaInBackendIsCopyingCachedOutputsState(initialHashData, mode)
// Send the response from the EJHA (if there was one!):
ejhaResponse foreach { ejea ! _ }
ejhaResponse foreach {
ejea ! _
}

// Nothing should happen here:
helper.jobStoreProbe.expectNoMsg(awaitAlmostNothing)
Expand All @@ -120,9 +145,12 @@ class EjeaBackendIsCopyingCachedOutputsSpec extends EngineJobExecutionActorSpec
ejea ! failureNonRetryableResponse

expectInvalidateCallCacheActor(cacheId)
eventually { ejea.stateName should be(InvalidatingCacheEntry) }
ejea.stateData should be(ResponsePendingData(helper.backendJobDescriptor, helper. bjeaProps, finalHashData, cacheHit))
eventually {
ejea.stateName should be(InvalidatingCacheEntry)
}
ejea.stateData should be(ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, finalHashData, cacheHit))
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package cromwell.engine.workflow.lifecycle.execution.ejea
import cats.data.NonEmptyList
import cromwell.engine.workflow.lifecycle.execution.EngineJobExecutionActor.{CheckingCallCache, FetchingCachedOutputsFromDatabase, ResponsePendingData, RunningJob}
import EngineJobExecutionActorSpec.EnhancedTestEJEA
import cromwell.core.callcaching.{CallCachingActivity, CallCachingOff, ReadCache}
import cromwell.core.callcaching.{CallCachingOff, ReadCache}
import cromwell.engine.workflow.lifecycle.execution.callcaching.EngineJobHashingActor.{CacheHit, CacheMiss, HashError}
import cromwell.engine.workflow.lifecycle.execution.callcaching.CallCachingEntryId
import org.scalatest.concurrent.Eventually
Expand Down Expand Up @@ -41,7 +41,7 @@ class EjeaCheckingCallCacheSpec extends EngineJobExecutionActorSpec with Eventua
}

private def createCheckingCallCacheEjea(restarting: Boolean = false): Unit = {
ejea = helper.buildEJEA(restarting = restarting, callCachingMode = CallCachingActivity(ReadCache))
ejea = helper.buildEJEA(restarting = restarting, callCachingMode = ReadCache())
ejea.setStateInline(state = CheckingCallCache, data = ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, None))
()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cromwell.engine.workflow.lifecycle.execution.ejea

import cromwell.backend.BackendCacheHitCopyingActor.CopyOutputsCommand
import cromwell.core.WorkflowId
import cromwell.core.callcaching.{CallCachingActivity, ReadAndWriteCache}
import cromwell.core.callcaching.ReadAndWriteCache
import cromwell.core.simpleton.WdlValueSimpleton
import cromwell.engine.workflow.lifecycle.execution.EngineJobExecutionActor._
import cromwell.engine.workflow.lifecycle.execution.callcaching.EngineJobHashingActor.HashError
Expand Down Expand Up @@ -79,5 +79,5 @@ class EjeaFetchingCachedOutputsFromDatabaseSpec extends EngineJobExecutionActorS
}

def initialData = ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, None)
def ejeaInFetchingCachedOutputsFromDatabaseState(restarting: Boolean = false) = helper.buildEJEA(restarting = restarting, callCachingMode = CallCachingActivity(ReadAndWriteCache)).setStateInline(data = initialData)
def ejeaInFetchingCachedOutputsFromDatabaseState(restarting: Boolean = false) = helper.buildEJEA(restarting = restarting, callCachingMode = ReadAndWriteCache()).setStateInline(data = initialData)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cromwell.engine.workflow.lifecycle.execution.ejea

import cats.data.NonEmptyList
import cromwell.core.callcaching.{CallCachingActivity, ReadCache}
import cromwell.core.callcaching.ReadCache
import cromwell.engine.workflow.lifecycle.execution.EngineJobExecutionActor._
import cromwell.engine.workflow.lifecycle.execution.callcaching.EngineJobHashingActor.CacheHit
import cromwell.engine.workflow.lifecycle.execution.callcaching.{CallCacheInvalidatedFailure, CallCacheInvalidatedSuccess, CallCachingEntryId}
Expand Down Expand Up @@ -49,5 +49,5 @@ class EjeaInvalidatingCacheEntrySpec extends EngineJobExecutionActorSpec with Ca
}

def standardResponsePendingData(hit: Option[CacheHit]) = ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, None, hit)
def ejeaInvalidatingCacheEntryState(hit: Option[CacheHit], restarting: Boolean = false) = helper.buildEJEA(restarting = restarting, callCachingMode = CallCachingActivity(ReadCache)).setStateInline(data = standardResponsePendingData(hit))
def ejeaInvalidatingCacheEntryState(hit: Option[CacheHit], restarting: Boolean = false) = helper.buildEJEA(restarting = restarting, callCachingMode = ReadCache()).setStateInline(data = standardResponsePendingData(hit))
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ class EjeaRunningJobSpec extends EngineJobExecutionActorSpec with Eventually wit
}

def initialData = ResponsePendingData(helper.backendJobDescriptor, helper.bjeaProps, None)
def ejeaInRunningState(mode: CallCachingMode = CallCachingActivity(ReadAndWriteCache)) = helper.buildEJEA(callCachingMode = mode).setStateInline(state = RunningJob, data = initialData)
def ejeaInRunningState(mode: CallCachingMode = ReadAndWriteCache()) = helper.buildEJEA(callCachingMode = mode).setStateInline(state = RunningJob, data = initialData)
}

0 comments on commit ebba140

Please sign in to comment.