diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index d74ba9ca5..00676ec31 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -32,7 +32,6 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -316,30 +315,29 @@ private Flowable run( () -> { InvocationContext invocationContext = createInvocationContext(parentContext); + Flowable mainAndAfterEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)) + .concatWith( + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .toFlowable())); + return callCallback( beforeCallbacksToFunctions( invocationContext.pluginManager(), beforeAgentCallback), invocationContext) .flatMapPublisher( - beforeEventOpt -> { + beforeEvent -> { if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); + return Flowable.just(beforeEvent); } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); + return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) + .switchIfEmpty(mainAndAfterEvents) .compose( Tracing.traceAgent( "invoke_agent " + name(), name(), description(), invocationContext)); @@ -383,13 +381,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. + * @return maybe emitting first event, or empty if none. */ - private Single> callCallback( + private Maybe callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); + return Maybe.empty(); } CallbackContext callbackContext = @@ -404,21 +402,20 @@ private Single> callCallback( .map( content -> { invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch().orElse(null)) - .actions(callbackContext.eventActions()) - .content(content) - .build()); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build(); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -429,9 +426,9 @@ private Single> callCallback( .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 6ed9ccaa3..e1afca2b1 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -170,52 +170,51 @@ private Flowable callLlm( LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) - .flatMapPublisher( - beforeResponse -> { - if (beforeResponse.isPresent()) { - return Flowable.just(beforeResponse.get()); - } - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose(Tracing.trace("call_llm")) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); - }); + .toFlowable() + .switchIfEmpty( + Flowable.defer( + () -> { + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm")) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()); + })); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Single} with the callback result or {@link Optional#empty()}. + * @return A {@link Maybe} with the callback result. */ - private Single> handleBeforeModelCallback( + private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = @@ -228,7 +227,7 @@ private Single> handleBeforeModelCallback( List callbacks = agent.canonicalBeforeModelCallbacks(); if (callbacks.isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + return pluginResult; } Maybe callbackResult = @@ -238,10 +237,7 @@ private Single> handleBeforeModelCallback( .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) .firstElement()); - return pluginResult - .switchIfEmpty(callbackResult) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + return pluginResult.switchIfEmpty(callbackResult); } /**