Skip to content

Commit

Permalink
feat(nodes): use new blur_if_nsfw method
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed May 13, 2024
1 parent 4b7f4f0 commit c4069d8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 15 deletions.
16 changes: 2 additions & 14 deletions invokeai/app/invocations/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from pathlib import Path
from typing import Literal, Optional

import cv2
Expand Down Expand Up @@ -504,7 +503,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.2",
version="1.2.3",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images"""
Expand All @@ -516,23 +515,12 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

logger = context.logger
logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image = SafetyChecker.blur_if_nsfw(image)

image_dto = context.images.save(image=image)

return ImageOutput.build(image_dto)

def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets

caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))


@invocation(
"img_watermark",
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/image_util/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def has_nsfw_concept(cls, image: Image.Image) -> bool:
@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
if cls.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
Expand Down

0 comments on commit c4069d8

Please sign in to comment.