76 changes: 60 additions & 16 deletions invokeai/app/services/events.py
@@ -3,7 +3,13 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.services.model_manager_service import (
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
)


class EventServiceBase:
session_event: str = "session_event"
@@ -38,7 +44,9 @@ def emit_generator_progress(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
progress_image=progress_image.dict()
if progress_image is not None
else None,
step=step,
total_steps=total_steps,
),
@@ -67,6 +75,7 @@ def emit_invocation_error(
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
@@ -76,6 +85,7 @@ def emit_invocation_error(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
)
@@ -102,13 +112,13 @@ def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
),
)

def emit_model_load_started (
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
def emit_model_load_started(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
@@ -123,13 +133,13 @@ def emit_model_load_started (
)

def emit_model_load_completed(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
@@ -145,3 +155,37 @@ def emit_model_load_completed(
precision=str(model_info.precision),
),
)

def emit_session_retrieval_error(
self,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_session_event(
event_name="session_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
)

def emit_invocation_retrieval_error(
self,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_session_event(
event_name="invocation_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)
41 changes: 32 additions & 9 deletions invokeai/app/services/processor.py
@@ -39,21 +39,41 @@ def __process(self, stop_event: Event):
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
logger.debug("Exception while getting from queue: %s" % e)
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)

if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue

graph_execution_state = (
self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
try:
graph_execution_state = (
self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
)
)
invocation = graph_execution_state.execution_graph.get_node(
queue_item.invocation_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue

try:
invocation = graph_execution_state.execution_graph.get_node(
queue_item.invocation_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue

# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -114,11 +134,13 @@ def __process(self, stop_event: Event):
graph_execution_state
)

self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)

@@ -136,11 +158,12 @@ def __process(self, stop_event: Event):
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
logger.error("Error while invoking: %s" % e)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc()
)
elif is_complete:
@@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';

export const listenerMiddleware = createListenerMiddleware();

@@ -153,6 +155,8 @@ addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadEventListener();
addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener();

// Session Created
addSessionCreatedPendingListener();
@@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
effect: (action) => {
const log = logger('session');
if (action.payload) {
const { error } = action.payload;
const { error, status } = action.payload;
const graph = parseify(action.meta.arg);
const stringifiedError = JSON.stringify(error);
log.error(
{ graph, error: serializeError(error) },
`Problem creating session: ${stringifiedError}`
{ graph, status, error: serializeError(error) },
`Problem creating session`
);
}
},
@@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
const stringifiedError = JSON.stringify(error);
log.error(
{
session_id,
error: serializeError(error),
},
`Problem invoking session: ${stringifiedError}`
`Problem invoking session`
);
}
},
@@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationRetrievalError,
socketInvocationRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';

export const addInvocationRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketInvocationRetrievalError(action.payload));
},
});
};
@@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketSessionRetrievalError,
socketSessionRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';

export const addSessionRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketSessionRetrievalError(action.payload));
},
});
};
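
The two new listener files above import `socketSessionRetrievalError`, `appSocketSessionRetrievalError`, and their invocation counterparts from `services/events/actions`, but that file is not part of this diff. Below is a minimal sketch of what those action creators and payload shapes could look like, assuming the existing socket/appSocket action pattern and payload fields matching the backend events emitted in `events.py` (graph_execution_state_id, error_type, error, plus node_id for the invocation variant). Names and types here are illustrative, not copied from the PR.

```ts
// Hypothetical sketch of the action creators the new listeners assume exist in
// services/events/actions.ts; the actual file may differ.
import { createAction } from '@reduxjs/toolkit';

type SessionRetrievalErrorEvent = {
  graph_execution_state_id: string;
  error_type: string;
  error: string;
};

type InvocationRetrievalErrorEvent = SessionRetrievalErrorEvent & {
  node_id: string;
};

// Raw socket events, dispatched when the socket.io message arrives
export const socketSessionRetrievalError = createAction<{
  data: SessionRetrievalErrorEvent;
}>('socket/socketSessionRetrievalError');

export const socketInvocationRetrievalError = createAction<{
  data: InvocationRetrievalErrorEvent;
}>('socket/socketInvocationRetrievalError');

// App-level mirrors, re-dispatched by the listeners for reducers such as systemSlice
export const appSocketSessionRetrievalError = createAction<{
  data: SessionRetrievalErrorEvent;
}>('socket/appSocketSessionRetrievalError');

export const appSocketInvocationRetrievalError = createAction<{
  data: InvocationRetrievalErrorEvent;
}>('socket/appSocketInvocationRetrievalError');
```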
62 changes: 40 additions & 22 deletions invokeai/frontend/web/src/features/system/store/systemSlice.ts
@@ -1,5 +1,5 @@
import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
@@ -16,13 +16,16 @@ import {
appSocketGraphExecutionStateComplete,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationRetrievalError,
appSocketInvocationStarted,
appSocketSessionRetrievalError,
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../util/makeToast';
import { LANGUAGES } from './constants';
import { startCase } from 'lodash-es';

export type CancelStrategy = 'immediate' | 'scheduled';

@@ -288,25 +291,6 @@ export const systemSlice = createSlice({
}
});

/**
* Invocation Error
*/
builder.addCase(appSocketInvocationError, (state) => {
state.isProcessing = false;
state.isCancelable = true;
// state.currentIteration = 0;
// state.totalIterations = 0;
state.currentStatusHasSteps = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusError';
state.progressImage = null;

state.toastQueue.push(
makeToast({ title: t('toast.serverError'), status: 'error' })
);
});

/**
* Graph Execution State Complete
*/
@@ -362,7 +346,7 @@ export const systemSlice = createSlice({
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
builder.addMatcher(isAnySessionRejected, (state) => {
builder.addMatcher(isAnySessionRejected, (state, action) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
@@ -372,7 +356,35 @@ state.progressImage = null;
state.progressImage = null;

state.toastQueue.push(
makeToast({ title: t('toast.serverError'), status: 'error' })
makeToast({
title: t('toast.serverError'),
status: 'error',
description:
action.payload?.status === 422 ? 'Validation Error' : undefined,
})
);
});

/**
* Any server error
*/
builder.addMatcher(isAnyServerError, (state, action) => {
state.isProcessing = false;
state.isCancelable = true;
// state.currentIteration = 0;
// state.totalIterations = 0;
state.currentStatusHasSteps = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusError';
state.progressImage = null;

state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
status: 'error',
description: startCase(action.payload.data.error_type),
})
);
});
},
@@ -400,3 +412,9 @@ export const {
} = systemSlice.actions;

export default systemSlice.reducer;

const isAnyServerError = isAnyOf(
appSocketInvocationError,
appSocketSessionRetrievalError,
appSocketInvocationRetrievalError
);
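
In the new isAnyServerError matcher above, the toast description comes from lodash's startCase applied to the backend error_type, which processor.py now sets to the exception class name (e.__class__.__name__). A quick illustration of how that reads in the UI; the error class names below are examples, not necessarily ones this PR produces:

```ts
import { startCase } from 'lodash-es';

// error_type is the Python exception class name sent with the event
startCase('KeyError');        // "Key Error"
startCase('ValidationError'); // "Validation Error"
startCase('RuntimeError');    // "Runtime Error"
```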