Minor suggestions #1

Closed

woctezuma opened this issue Dec 3, 2023 · 7 comments

Comments

woctezuma commented Dec 3, 2023

I wanted to try the official code at https://huggingface.co/CompVis/stable-diffusion-safety-checker with:

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

model = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").cuda()

But I have noticed that this requires an additional input, which is clip_input in the code below.

@torch.no_grad()
def forward(self, clip_input, images):
    # [...]
    return images, has_nsfw_concepts

So I was a bit confused by this argument (possibly a text input?) for now... and discovered your repository right afterwards. For reference, here is how the diffusers pipeline prepares that clip_input in run_safety_checker:


def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept

The code runs fine, but I have a few suggestions:

  1. is it possible to output a safety score (float) instead of a yes/no answer (bool)?
  2. is it possible to apply the model to a batch of images, e.g. 8 images, so as to make the whole process faster?

Thank you for your attention!

iyume (Owner) commented Dec 3, 2023

  1. I'm sorry, I'm still researching that too. I'm not an expert on CLIP.
  2. Yes! Allowed types are described here: https://github.com/huggingface/transformers/blob/2c658b5a4282f2e824b4e23dc3bcda7ef27d5827/src/transformers/image_utils.py#L59

For 2, I typed the hint in my code as Union[Image.Image, List[Image.Image]], which is a mistake; I'll fix it soon.
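
For reference, a minimal batch sketch based on the snippets in this thread. Anything not already shown above is an assumption: in particular the openai/clip-vit-large-patch14 checkpoint for the image processor (any CLIP ViT-L/14 preprocessor config should work) and the placeholder black images.

import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessor that produces the clip_input expected by forward().
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
).to(device)

# A batch of 8 PIL images (placeholder black images here).
pil_images = [Image.new("RGB", (512, 512)) for _ in range(8)]

# clip_input: pixel values for the CLIP vision tower, shape (8, 3, 224, 224).
clip_input = processor(pil_images, return_tensors="pt").pixel_values.to(device)
# images: the arrays that get blacked out when a concept is flagged.
np_images = [np.array(img) for img in pil_images]

checked_images, has_nsfw = checker(images=np_images, clip_input=clip_input)
print(has_nsfw)  # list of 8 booleans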

woctezuma (Author) commented Dec 3, 2023

  1. You could probably use the cosine similarity (between -1 and +1). cf. https://en.wikipedia.org/wiki/Cosine_similarity
  2. Thanks!

woctezuma (Author) commented Dec 4, 2023

Actually, len(res["bad_concepts"]) should do the trick based on https://github.com/huggingface/diffusers/blob/d486f0e84669447b178569ad499eeb86c739b99e/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L84C11-L84C77

has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
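
To get a float score instead, here is a minimal sketch that reuses checker and clip_input from the batch sketch above. It relies on internal attribute names from the linked safety_checker.py (vision_model, visual_projection, concept_embeds, concept_embeds_weights), which are not a public API, and it skips the special-care adjustment from the original forward():

import torch
import torch.nn.functional as F

@torch.no_grad()
def concept_scores(checker, clip_input):
    # Same embedding path as StableDiffusionSafetyChecker.forward().
    pooled_output = checker.vision_model(clip_input)[1]
    image_embeds = checker.visual_projection(pooled_output)
    # Cosine similarity between each image and each NSFW concept embedding.
    cos_sim = F.normalize(image_embeds) @ F.normalize(checker.concept_embeds).t()
    # Positive values mean the per-concept threshold is exceeded.
    return cos_sim - checker.concept_embeds_weights

scores = concept_scores(checker, clip_input)   # float tensor, shape (batch, num_concepts)
has_nsfw = (scores > 0).any(dim=1)             # bool per image
num_bad_concepts = (scores > 0).sum(dim=1)     # same count as len(res["bad_concepts"])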

iyume (Owner) commented Dec 4, 2023

The problem is that the network can't precisely distinguish R15, R16, R17, and R18. The current safety checker is trained for R15 detection, and nobody can promise that it works well for R18 detection. Either way, I need more test results, or I need to train another model.

woctezuma (Author) commented Dec 4, 2023

In the end, I have reused some code from another one of my projects, which extracts image features:

to build a similar tool, which reports the IDs of the bad concepts:
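
(For illustration only, not the linked tool: with the scores tensor from the earlier sketch, the flagged concept indices can be listed like this.)

# Indices of concepts whose threshold is exceeded, per image.
bad_concept_ids = [torch.where(row > 0)[0].tolist() for row in scores]
print(bad_concept_ids)  # e.g. [[], [3, 10], ...] for a batch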

iyume (Owner) commented Dec 5, 2023

Thanks for sharing!

In conclusion, the safety checker works well for R15. For R18, I would recommend nsfw_model, which I'm currently using in another project.
