From 4b334be7d030d369bbac2d0dc4d30fb27954fa2b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 23 Jul 2023 12:27:59 +1000 Subject: [PATCH] feat(nodes,ui): fix soft locks on session/invocation retrieval When a queue item is popped for processing, we need to retrieve its session from the DB. Pydantic deserializes and validates the graph at this stage. It's possible for a graph to have been made invalid during the graph preparation stage (e.g. an ancestor node executes, and its output is not valid for its successor node's input field). When this occurs, the session in the DB will fail validation, but we don't have a chance to find out until it is retrieved and parsed by Pydantic. This logic was previously not wrapped in any exception handling. Just after retrieving a session, we retrieve the specific invocation to execute from the session. It's possible that this could also have some sort of error, though it should be impossible for it to be a Pydantic validation error (that would have been caught during session validation). There was also no exception handling here. When either of these processes fails, the processor gets soft-locked because the processor's cleanup logic is never run. (I didn't dig deeper into exactly what cleanup is not happening, because the fix is to just handle the exceptions.) This PR adds exception handling to both session retrieval and invocation retrieval, and emits an event for each: `session_retrieval_error` and `invocation_retrieval_error`. These events are caught and displayed in the UI as toasts, along with the type of the Python exception (e.g. `Validation Error`). The events are also logged to the browser console. 
--- invokeai/app/services/events.py | 76 +++++++++++++++---- invokeai/app/services/processor.py | 41 +++++++--- .../middleware/listenerMiddleware/index.ts | 4 + .../listeners/sessionCreated.ts | 7 +- .../listeners/sessionInvoked.ts | 3 +- .../socketInvocationRetrievalError.ts | 20 +++++ .../socketio/socketSessionRetrievalError.ts | 20 +++++ .../src/features/system/store/systemSlice.ts | 62 +++++++++------ .../web/src/services/api/thunks/session.ts | 13 +++- .../web/src/services/events/actions.ts | 34 +++++++++ .../frontend/web/src/services/events/types.ts | 26 +++++++ .../services/events/util/setEventListeners.ts | 24 ++++++ 12 files changed, 273 insertions(+), 57 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 35003536e64..73d74de2d95 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,13 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp -from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo +from invokeai.app.services.model_manager_service import ( + BaseModelType, + ModelType, + SubModelType, + ModelInfo, +) + class EventServiceBase: session_event: str = "session_event" @@ -38,7 +44,9 @@ def emit_generator_progress( graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, - progress_image=progress_image.dict() if progress_image is not None else None, + progress_image=progress_image.dict() + if progress_image is not None + else None, step=step, total_steps=total_steps, ), @@ -67,6 +75,7 @@ def emit_invocation_error( 
graph_execution_state_id: str, node: dict, source_node_id: str, + error_type: str, error: str, ) -> None: """Emitted when an invocation has completed""" @@ -76,6 +85,7 @@ def emit_invocation_error( graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, + error_type=error_type, error=error, ), ) @@ -102,13 +112,13 @@ def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None: ), ) - def emit_model_load_started ( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + def emit_model_load_started( + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, ) -> None: """Emitted when a model is requested""" self.__emit_session_event( @@ -123,13 +133,13 @@ def emit_model_load_started ( ) def emit_model_load_completed( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, + model_info: ModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_session_event( @@ -145,3 +155,37 @@ def emit_model_load_completed( precision=str(model_info.precision), ), ) + + def emit_session_retrieval_error( + self, + graph_execution_state_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when session retrieval fails""" + self.__emit_session_event( + event_name="session_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + error_type=error_type, + error=error, + ), + ) + + def emit_invocation_retrieval_error( + self, + graph_execution_state_id: str, + node_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when 
invocation retrieval fails""" + self.__emit_session_event( + event_name="invocation_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + node_id=node_id, + error_type=error_type, + error=error, + ), + ) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e11eb84b3db..5995e4ffc3b 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -39,21 +39,41 @@ def __process(self, stop_event: Event): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() except Exception as e: - logger.debug("Exception while getting from queue: %s" % e) + self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) if not queue_item: # Probably stopping # do not hammer the queue time.sleep(0.5) continue - graph_execution_state = ( - self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id + try: + graph_execution_state = ( + self.__invoker.services.graph_execution_manager.get( + queue_item.graph_execution_state_id + ) ) - ) - invocation = graph_execution_state.execution_graph.get_node( - queue_item.invocation_id - ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) + self.__invoker.services.events.emit_session_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue + + try: + invocation = graph_execution_state.execution_graph.get_node( + queue_item.invocation_id + ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) + self.__invoker.services.events.emit_invocation_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + node_id=queue_item.invocation_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue # get the source node 
id to provide to clients (the prepared node id is not as useful) source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] @@ -114,11 +134,13 @@ def __process(self, stop_event: Event): graph_execution_state ) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=error, ) @@ -136,11 +158,12 @@ def __process(self, stop_event: Event): try: self.__invoker.invoke(graph_execution_state, invoke_all=True) except Exception as e: - logger.error("Error while invoking: %s" % e) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=traceback.format_exc() ) elif is_complete: diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 04f0ce7a0be..5adc4f5e5e4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; +import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; export const 
listenerMiddleware = createListenerMiddleware(); @@ -153,6 +155,8 @@ addSocketDisconnectedListener(); addSocketSubscribedListener(); addSocketUnsubscribedListener(); addModelLoadEventListener(); +addSessionRetrievalErrorEventListener(); +addInvocationRetrievalErrorEventListener(); // Session Created addSessionCreatedPendingListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts index 5709d87d227..e89acb75428 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => { effect: (action) => { const log = logger('session'); if (action.payload) { - const { error } = action.payload; + const { error, status } = action.payload; const graph = parseify(action.meta.arg); - const stringifiedError = JSON.stringify(error); log.error( - { graph, error: serializeError(error) }, - `Problem creating session: ${stringifiedError}` + { graph, status, error: serializeError(error) }, + `Problem creating session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts index 60009ed1945..a62f75d9572 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => { const { session_id } = action.meta.arg; if (action.payload) { const { error } = action.payload; - const stringifiedError = JSON.stringify(error); log.error( { session_id, error: 
serializeError(error), }, - `Problem invoking session: ${stringifiedError}` + `Problem invoking session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts new file mode 100644 index 00000000000..aa88457eb78 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketInvocationRetrievalError, + socketInvocationRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addInvocationRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketInvocationRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Invocation retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketInvocationRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts new file mode 100644 index 00000000000..7efb7f463ad --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketSessionRetrievalError, + socketSessionRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addSessionRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketSessionRetrievalError, + effect: (action, { dispatch 
}) => { + const log = logger('socketio'); + log.error( + action.payload, + `Session retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketSessionRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 629a4f01391..b7a5e606e29 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,5 +1,5 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { InvokeLogLevel } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; @@ -16,13 +16,16 @@ import { appSocketGraphExecutionStateComplete, appSocketInvocationComplete, appSocketInvocationError, + appSocketInvocationRetrievalError, appSocketInvocationStarted, + appSocketSessionRetrievalError, appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from '../util/makeToast'; import { LANGUAGES } from './constants'; +import { startCase } from 'lodash-es'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -288,25 +291,6 @@ export const systemSlice = createSlice({ } }); - /** - * Invocation Error - */ - builder.addCase(appSocketInvocationError, (state) => { - state.isProcessing = false; - state.isCancelable = true; - // state.currentIteration = 0; - // state.totalIterations = 0; - state.currentStatusHasSteps = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusError'; - state.progressImage = null; - - state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 
'error' }) - ); - }); - /** * Graph Execution State Complete */ @@ -362,7 +346,7 @@ export const systemSlice = createSlice({ * Session Invoked - REJECTED * Session Created - REJECTED */ - builder.addMatcher(isAnySessionRejected, (state) => { + builder.addMatcher(isAnySessionRejected, (state, action) => { state.isProcessing = false; state.isCancelable = false; state.isCancelScheduled = false; @@ -372,7 +356,35 @@ export const systemSlice = createSlice({ state.progressImage = null; state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: + action.payload?.status === 422 ? 'Validation Error' : undefined, + }) + ); + }); + + /** + * Any server error + */ + builder.addMatcher(isAnyServerError, (state, action) => { + state.isProcessing = false; + state.isCancelable = true; + // state.currentIteration = 0; + // state.totalIterations = 0; + state.currentStatusHasSteps = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusError'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: startCase(action.payload.data.error_type), + }) ); }); }, @@ -400,3 +412,9 @@ export const { } = systemSlice.actions; export default systemSlice.reducer; + +const isAnyServerError = isAnyOf( + appSocketInvocationError, + appSocketSessionRetrievalError, + appSocketInvocationRetrievalError +); diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts index 6d20b9dd334..5588f25b467 100644 --- a/invokeai/frontend/web/src/services/api/thunks/session.ts +++ b/invokeai/frontend/web/src/services/api/thunks/session.ts @@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required< >; type CreateSessionThunkConfig = { - rejectValue: { arg: CreateSessionArg; error: unknown }; + rejectValue: { arg: 
CreateSessionArg; status: number; error: unknown }; }; /** @@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk< }); if (error) { - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } return data; @@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = { rejectValue: { arg: InvokedSessionArg; error: unknown; + status: number; }; }; @@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk< if (error) { if (isErrorWithStatus(error) && error.status === 403) { - return rejectWithValue({ arg, error: (error as any).body.detail }); + return rejectWithValue({ + arg, + status: response.status, + error: (error as any).body.detail, + }); } - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } }); diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index b6316c5e95d..35ebb725cba 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -4,9 +4,11 @@ import { GraphExecutionStateCompleteEvent, InvocationCompleteEvent, InvocationErrorEvent, + InvocationRetrievalErrorEvent, InvocationStartedEvent, ModelLoadCompletedEvent, ModelLoadStartedEvent, + SessionRetrievalErrorEvent, } from 'services/events/types'; // Create actions for each socket @@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{ export const appSocketModelLoadCompleted = createAction<{ data: ModelLoadCompletedEvent; }>('socket/appSocketModelLoadCompleted'); + +/** + * Socket.IO Session Retrieval Error + * + * Do not use. Only for use in middleware. 
+ */ +export const socketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/socketSessionRetrievalError'); + +/** + * App-level Session Retrieval Error + */ +export const appSocketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/appSocketSessionRetrievalError'); + +/** + * Socket.IO Invocation Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/socketInvocationRetrievalError'); + +/** + * App-level Invocation Retrieval Error + */ +export const appSocketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/appSocketInvocationRetrievalError'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index ec1b55e3fef..37f5f24eacc 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -87,6 +87,7 @@ export type InvocationErrorEvent = { graph_execution_state_id: string; node: BaseNode; source_node_id: string; + error_type: string; error: string; }; @@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = { graph_execution_state_id: string; }; +/** + * A `session_retrieval_error` socket.io event. + * + * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... } + */ +export type SessionRetrievalErrorEvent = { + graph_execution_state_id: string; + error_type: string; + error: string; +}; + +/** + * A `invocation_retrieval_error` socket.io event. + * + * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... 
} + */ +export type InvocationRetrievalErrorEvent = { + graph_execution_state_id: string; + node_id: string; + error_type: string; + error: string; +}; + export type ClientEmitSubscribe = { session: string; }; @@ -128,6 +152,8 @@ export type ServerToClientEvents = { ) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_completed: (payload: ModelLoadCompletedEvent) => void; + session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; + invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index d44a549183e..9ebb7ffbffc 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -11,9 +11,11 @@ import { socketGraphExecutionStateComplete, socketInvocationComplete, socketInvocationError, + socketInvocationRetrievalError, socketInvocationStarted, socketModelLoadCompleted, socketModelLoadStarted, + socketSessionRetrievalError, socketSubscribed, } from '../actions'; import { ClientToServerEvents, ServerToClientEvents } from '../types'; @@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => { }) ); }); + + /** + * Session retrieval error + */ + socket.on('session_retrieval_error', (data) => { + dispatch( + socketSessionRetrievalError({ + data, + }) + ); + }); + + /** + * Invocation retrieval error + */ + socket.on('invocation_retrieval_error', (data) => { + dispatch( + socketInvocationRetrievalError({ + data, + }) + ); + }); };