Skip to content

Commit

Permalink
✨✅🚧Implemented workflow stuffs
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Jun 8, 2023
1 parent bbdb810 commit 260f7c7
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 1 deletion.
84 changes: 84 additions & 0 deletions examples/carefree_creator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cftool.misc import shallow_copy_dict
from cfcreator.common import InpaintingMode
from cflearn.misc.toolkit import new_seed
from cfcreator.sdks.apis import Workflow

from cfdraw import *

Expand Down Expand Up @@ -515,6 +516,87 @@ async def process(self, data: ISocketRequest) -> List[Image.Image]:
return await get_apis().harmonization(model)


# workflow


@register_node_validator("draw_workflow")
def validate_draw_workflow(data: ISocketRequest) -> bool:
identifier = data.nodeData.identifier
if identifier is None or identifier not in key2endpoints:
return False
return True


class DrawWorkflow(IFieldsPlugin):
@property
def settings(self) -> IPluginSettings:
return IPluginSettings(
w=240,
h=110,
src=constants.WORKFLOW_ICON,
nodeConstraintValidator="draw_workflow",
tooltip=I18N(
zh="绘制工作流",
en="Draw Workflow",
),
pluginInfo=IFieldsPluginInfo(
header=I18N(
zh="绘制工作流",
en="Draw Workflow",
),
definitions={},
),
no_offload=True,
)

async def process(self, data: ISocketRequest) -> List[Image.Image]:
workflow = trace_workflow(data.nodeData.meta)
self.set_extra_response(WORKFLOW_KEY, workflow.to_json())
return [workflow.render()]


@register_node_validator("workflow")
def validate_workflow(data: ISocketRequest) -> bool:
if data.nodeData.extra_responses is None:
return False
return WORKFLOW_KEY in data.nodeData.extra_responses


class ExecuteWorkflow(IFieldsPlugin):
@property
def settings(self) -> IPluginSettings:
return IPluginSettings(
w=240,
h=110,
src=constants.EXECUTE_WORKFLOW_ICON,
pivot=PivotType.BOTTOM,
follow=True,
nodeConstraintValidator="workflow",
tooltip=I18N(
zh="执行工作流",
en="Execute Workflow",
),
pluginInfo=IFieldsPluginInfo(
header=I18N(
zh="工作流",
en="Workflow",
),
numColumns=2,
definitions={},
),
)

async def process(self, data: ISocketRequest) -> List[Image.Image]:
def callback(step: int, num_steps: int) -> bool:
return self.send_progress(step / num_steps)

kw = dict(step_callback=callback)
workflow_json = data.nodeData.extra_responses[WORKFLOW_KEY]
workflow = Workflow.from_json(workflow_json)
results = await get_apis().execute(workflow, workflow.last.key, **kw)
return results[workflow.last.key]


# groups


Expand Down Expand Up @@ -576,6 +658,7 @@ def settings(self) -> IPluginSettings:
CaptioningKey: Captioning,
ControlNetHintKey: ControlHints,
VariationKey: Variation,
DrawWorkflowKey: DrawWorkflow,
},
),
)
Expand Down Expand Up @@ -654,4 +737,5 @@ def settings(self) -> IPluginSettings:
register_plugin("image_followers")(ImageFollowers)
register_plugin("image_and_mask_followers")(ImageAndMaskFollowers)
register_plugin("canvas_followers")(CanvasFollowers)
register_plugin(ExecuteWorkflowKey)(ExecuteWorkflow)
app = App(notification)
134 changes: 133 additions & 1 deletion examples/carefree_creator/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Any
from typing import Dict
from cfdraw import cache_resource
from collections import defaultdict
from cftool.misc import random_hash
from cfcreator.workflow import *
from cfcreator.endpoints import *
from cfcreator.sdks.apis import *


WORKFLOW_KEY = "$workflow"
DATA_MODEL_KEY = "$data_model"

UPLOAD_META_TYPE = "upload"
PYTHON_FIELDS_META_TYPE = "python.fields"

Txt2ImgKey = "txt2img"
Img2ImgKey = "img2img"
SRKey = "sr"
Expand All @@ -17,14 +26,133 @@
ControlNetHintKey = "control_net_hint"
MultiControlNetKey = "multi_control_net"
ImageHarmonizationKey = "image_harmonization"
DrawWorkflowKey = "draw_workflow"
ExecuteWorkflowKey = "execute_workflow"

key2endpoints = {
Txt2ImgKey: txt2img_sd_endpoint,
Img2ImgKey: img2img_sd_endpoint,
SRKey: img2img_sr_endpoint,
SODKey: img2img_sod_endpoint,
CaptioningKey: img2txt_caption_endpoint,
InpaintingKey: img2img_inpainting_endpoint,
SDInpaintingKey: txt2img_sd_inpainting_endpoint,
SDOutpaintingKey: txt2img_sd_outpainting_endpoint,
ControlNetHintKey: CONTROL_HINT_ENDPOINT,
MultiControlNetKey: new_control_multi_endpoint,
ImageHarmonizationKey: img2img_harmonization_endpoint,
}


@cache_resource
def get_apis() -> APIs:
return APIs()
return APIs(
focuses_endpoints=[
txt2img_sd_endpoint,
img2img_sd_endpoint,
new_control_multi_endpoint,
# control_canny_hint_endpoint,
# control_pose_hint_endpoint,
# img2img_harmonization_endpoint,
# txt2img_sd_inpainting_endpoint,
],
)


def trace_workflow(meta: Dict[str, Any]) -> Workflow:
def _get_key(k: str) -> str:
new_key = f"{k}_{counts[k]}"
counts[k] += 1
return new_key

def _convert_field_key(k: str) -> str:
# list field workaround
k_split = k.split(".")
if len(k_split) >= 3:
k_split.insert(2, "data")
k = ".".join(k_split)
return k

def _trace(meta_: Dict[str, Any]) -> None:
mtype, mdata = map(meta_.get, ["type", "data"])
if mtype is None or mdata is None:
return
alias = mdata.get("alias", random_hash())
key = alias2key.get(alias)
if key is not None and workflow.get(key) is not None:
return
if mtype == UPLOAD_META_TYPE:
if key is None:
key = _get_key(UPLOAD_META_TYPE)
alias2key[alias] = key
workflow.push(
WorkNode(
key=key,
endpoint=UPLOAD_ENDPOINT,
injections={},
data=dict(url=mdata["url"]),
)
)
elif mtype == PYTHON_FIELDS_META_TYPE:
identifier, response = map(mdata.get, ["identifier", "response"])
if identifier is None or response is None:
return
data_model_d = response.get("extra", {}).get(DATA_MODEL_KEY)
if data_model_d is None:
return
raw_injections = mdata.get("injections", {})
injections = {}
for k, v in raw_injections.items():
k = _convert_field_key(k)
v_meta = v.get("meta")
if v_meta is None:
continue
v_type, v_data = map(v_meta.get, ["type", "data"])
if v_type is None or v_data is None:
continue
v_alias = v_data.get("alias", random_hash())
v_key = alias2key.get(v_alias)
if v_type == "upload":
if v_key is None:
v_key = _get_key("upload")
alias2key[v_alias] = v_key
injections[v_key] = InjectionPack(index=0, field=k)
elif v_type == PYTHON_FIELDS_META_TYPE:
v_identifier = v_data.get("identifier")
if v_identifier is None:
continue
v_index = v_data.get("response", {}).get("index", 0)
v_injection_pack = InjectionPack(index=v_index, field=k)
if v_key is None:
v_key = _get_key(v_identifier)
alias2key[v_alias] = v_key
injections[v_key] = v_injection_pack
else:
raise ValueError(f"unknown type: {v_type}")
_trace(v_meta)
if key is None:
key = _get_key(identifier)
alias2key[alias] = key
workflow.push(
WorkNode(
key=key,
endpoint=key2endpoints[identifier],
injections=injections,
data=data_model_d,
)
)
else:
raise ValueError(f"unknown type: {mtype}")

counts = defaultdict(int)
workflow = Workflow()
alias2key = {}
_trace(meta)
return workflow


__all__ = [
"WORKFLOW_KEY",
"DATA_MODEL_KEY",
"Txt2ImgKey",
"Img2ImgKey",
Expand All @@ -38,7 +166,11 @@ def get_apis() -> APIs:
"ControlNetHintKey",
"MultiControlNetKey",
"ImageHarmonizationKey",
"DrawWorkflowKey",
"ExecuteWorkflowKey",
"get_apis",
"trace_workflow",
"key2endpoints",
"HighresModel",
"Img2TxtModel",
"Txt2ImgSDModel",
Expand Down

0 comments on commit 260f7c7

Please sign in to comment.