Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions invokeai/app/invocations/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Any, Literal, Optional, Union

import numpy as np

from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
Expand All @@ -12,7 +14,9 @@
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL

SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
Expand Down Expand Up @@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):

# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)

(width, height) = image.size
width *= 8
height *= 8

dataURL = image_to_dataURL(image, image_format="JPEG")

context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
float(step) / float(self.steps),
self.steps,
)

def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)

# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
Expand Down
11 changes: 8 additions & 3 deletions invokeai/app/services/events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Dict
from typing import Any, Dict, TypedDict

ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)

class EventServiceBase:
session_event: str = "session_event"
Expand All @@ -23,17 +26,19 @@ def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
percent: float,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
percent=percent,
total_steps=total_steps,
),
)

Expand Down
18 changes: 18 additions & 0 deletions invokeai/app/services/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,24 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)

# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}

def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""

Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ def repaste_and_color_correct(
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result

def sample_to_lowres_estimated_image(self, samples):
@staticmethod
def sample_to_lowres_estimated_image(samples):
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

Expand Down
16 changes: 16 additions & 0 deletions invokeai/backend/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import multiprocessing as mp
import os
import re
import io
import base64

from collections import abc
from inspect import isfunction
from pathlib import Path
Expand Down Expand Up @@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None


def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64