
Come on, come on, let's adapt the conversion script to SD 2.0 #1388

Closed
piEsposito opened this issue Nov 24, 2022 · 56 comments
Labels
stale Issues that haven't received updates

Comments

@piEsposito
Contributor

Is your feature request related to a problem? Please describe.
It would be great if we could run SD 2 with cpu_offload, attention slicing, xformers, etc...

Describe the solution you'd like
Adapt the conversion script to SD 2.0

Describe alternatives you've considered
Stability AI's repo is not as flexible.

@averad

averad commented Nov 24, 2022

🤗 Diffusers with Stable Diffusion 2 is live!

anton-l commented (#1388 (comment))
diffusers==0.9.0 with Stable Diffusion 2 is live!

Installation
pip install diffusers[torch]==0.9 transformers

Release Information
https://github.com/huggingface/diffusers/releases/tag/v0.9.0

Contributors
@kashif
@pcuenca
@patrickvonplaten
@anton-l
@patil-suraj


💭 User Story (Prior to Huggingface Diffusers 0.9.0 Release)

Stability-AI has released Stable Diffusion 2.0 models/workflow. When you run convert_original_stable_diffusion_to_diffusers.py on the new Stability-AI/stablediffusion models the following errors occur.

convert_original_stable_diffusion_to_diffusers.py --checkpoint_path="./512-inpainting-ema.ckpt" --dump_path="./512-inpainting-ema_diffusers"

Output:

Traceback (most recent call last):
File "convert_original_stable_diffusion_to_diffusers.py", line 720, in <module> 
        unet.load_state_dict(converted_unet_checkpoint)
File "lib\site-packages\torch\nn\modules\module.py", line 1667, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
        size mismatch for down_blocks.0.attentions.0.proj_in.weight: copying a param with shape torch.Size([320, 320]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
        size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([320, 768]).
        size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 1024]) from checkpoint, the shape in current model is torch.Size([320, 768]).
        size mismatch for down_blocks.0.attentions.0.proj_out.weight: copying a param with shape torch.Size([320, 320]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
.... blocks.1.attentions blocks.2.attentions etc. etc.
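
(For anyone who wants to confirm what the traceback is complaining about, a minimal sketch along these lines shows the 2-D linear weights in the SD 2.0 checkpoint that the older 4-D conv layers reject; the checkpoint path is the one from the command above, and the LDM key name is my assumption about where that layer lives in the raw state dict.)

import torch

# Load the raw Stability-AI checkpoint and inspect one of the offending layers.
ckpt = torch.load("./512-inpainting-ema.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# SD 2.0 stores the transformer proj_in/proj_out weights as 2-D linear weights
# (e.g. [320, 320]) and uses a 1024-wide cross-attention context, while the
# pre-0.9 diffusers UNet expected [320, 320, 1, 1] conv weights and 768.
key = "model.diffusion_model.input_blocks.1.1.proj_in.weight"
print(key, tuple(state_dict[key].shape))  # (320, 320) for SD 2.0 checkpoints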

@devilismyfriend

Trying to, but I likely won't be able to do it lol

@devilismyfriend

After looking at it, I'm not sure it has anything to do with the script; it seems like the UNet in diffusers needs the tensors to have 4 dimensions.

@AugmentedRealityCat

needs to have 4 dimensions

So I guess this will take time...

@devilismyfriend

needs to have 4 dimensions

So I guess this will take time...

Maybe not. I'm not that knowledgeable on the subject, but I assume the UNet2D just expects 4D weights, or maybe you can just artificially add the missing dimensions, idk.
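
(A rough sketch of what "artificially add it" would mean for the proj_in/proj_out weights above: a [out, in] linear weight can be viewed as a [out, in, 1, 1] 1x1-conv weight. This only papers over the shape mismatch; the 1024-vs-768 context dimension still needs the right model config.)

import torch

def linear_to_conv1x1(weight: torch.Tensor) -> torch.Tensor:
    # Reshape a 2-D linear weight into the 4-D 1x1-conv weight the old UNet expects.
    return weight[:, :, None, None]

w = torch.randn(320, 320)          # shape found in the SD 2.0 checkpoint
print(linear_to_conv1x1(w).shape)  # torch.Size([320, 320, 1, 1])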

@0xdevalias
Contributor

rudimentary support for stable diffusion 2.0

MrCheeze/stable-diffusion-webui@069591b

Originally posted by @152334H in AUTOMATIC1111/stable-diffusion-webui#5011 (comment)

@hafriedlander
Contributor

https://github.com/hafriedlander/diffusers/blob/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py

Notes:

  • Only tested on the two txt2img models, not inpaint / depth2img / upscaling
  • You will need to change your text embedding to use the penultimate layer too
  • It spits out a bunch of warnings about vision_model, but that's fine
  • I have no idea if this is right or not. It generates images, no guarantee beyond that. (Hence no PR - if you're patient, I'm sure the Diffusers team will do a better job than I have)

averad mentioned this issue Nov 24, 2022
@hafriedlander
Contributor

Here's an example of accessing the penultimate text embedding layer https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/b34bb27cf30940f6a6a41f4b77c5b77bea11fd76/sdgrpcserver/pipeline/text_embedding/basic_text_embedding.py#L33

@devilismyfriend

devilismyfriend commented Nov 24, 2022

https://github.com/hafriedlander/diffusers/blob/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py

Notes:

  • Only tested on the two txt2img models, not inpaint / depth2img / upscaling
  • You will need to change your text embedding to use the penultimate layer too
  • It spits out a bunch of warnings about vision_model, but that's fine
  • I have no idea if this is right or not. It generates images, no guarantee beyond that. (Hence no PR - if you're patient, I'm sure the Diffusers team will do a better job than I have)

doesn't seem to work for me on the 768-v model using the v2 config for v

TypeError: EulerDiscreteScheduler.__init__() got an unexpected keyword argument 'prediction_type'

@CoffeeVampir3

It appears I'm also getting an unexpected keyword argument error, but for a different argument:

Command:

python convert.py --checkpoint_path="models/512-base-ema.ckpt" --dump_path="outputs/" --original_config_file="v2-inference.yaml"

Result:

│ 736 │ unet = UNet2DConditionModel(**unet_config)
│ 737 │ unet.load_state_dict(converted_unet_checkpoint)
TypeError: __init__() got an unexpected keyword argument 'use_linear_projection'

I can't seem to find a resolution to this one.

@hafriedlander
Contributor

You need to use the absolute latest Diffusers and merge this PR (or use my branch which has it in it) #1386

@hafriedlander
Contributor

@patrickvonplaten

Amazing to see the excitement here! We'll merge #1386 in a bit :-)

@hafriedlander
Contributor

@patrickvonplaten the problems I've run into so far:

  • attention_slicing doesn't work when attention_head_dim is a list (maybe you have a more elegant solution than that)
  • tokenizer.model_max_length is max_long when using my converter above, so I use text_encoder.config.max_position_embeddings instead
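
(A rough sketch of the second workaround, for anyone hitting the same thing; the local dump path is a placeholder for whatever the converter wrote out, and this is just my reading of the description above, not an official fix.)

from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("/path/to/dump/tokenizer")
text_encoder = CLIPTextModel.from_pretrained("/path/to/dump/text_encoder")

# model_max_length can come back as a huge sentinel value after a manual
# conversion, so clamp it to the text encoder's position embedding size (77).
max_length = min(tokenizer.model_max_length,
                 text_encoder.config.max_position_embeddings)

text_inputs = tokenizer(
    "a photo of an astronaut riding a horse on mars",
    padding="max_length",
    max_length=max_length,
    truncation=True,
    return_tensors="pt",
)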

@patrickvonplaten
Contributor

That's super helpful @hafriedlander - thanks!

BTW, weights for the 512x512 are up:

Looking into the 768x768 model now

@hafriedlander
Contributor

Nice. Do you have a solution in mind for how to flag to the pipeline to use the penultimate layer in the CLIP model? (I just pass it in as an option at the moment)

@patrickvonplaten
Contributor

Can you send me a link? Does the pipeline not work out of the box? cc @anton-l @patil-suraj

@hafriedlander
Contributor

It works but I don't think it's correct. The Stability configuration files explicitly say to use the penultimate CLIP layer https://github.com/Stability-AI/stablediffusion/blob/33910c386eaba78b7247ce84f313de0f2c314f61/configs/stable-diffusion/v2-inference-v.yaml#L68

@hafriedlander
Contributor

It's relatively easy to get access to the penultimate layer. I do it in my custom pipeline like this:

https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/b34bb27cf30940f6a6a41f4b77c5b77bea11fd76/sdgrpcserver/pipeline/text_embedding/basic_text_embedding.py#L33

The problem is knowing when to do it and when not to.
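
(For anyone following along, the gist of the linked code, as I read it, is roughly the following; the repo id is just an example, not the only place this applies.)

import torch
from transformers import CLIPTextModel, CLIPTokenizer

repo_id = "stabilityai/stable-diffusion-2"
tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")

inputs = tokenizer("a photo of an astronaut riding a horse on mars",
                   padding="max_length", max_length=tokenizer.model_max_length,
                   truncation=True, return_tensors="pt")

with torch.no_grad():
    output = text_encoder(inputs.input_ids, output_hidden_states=True)

# hidden_states[-1] is the last layer; [-2] is the penultimate layer that the
# Stability configs ask for ("layer: penultimate"). Applying the final layer
# norm on top of the penultimate states is one common choice here.
penultimate = output.hidden_states[-2]
embeddings = text_encoder.text_model.final_layer_norm(penultimate)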

@patrickvonplaten
Contributor

I see! Thanks for the links - so they do this for both the 512x512 SD 2 and 768x768 SD 2 model?

@hafriedlander
Contributor

Both

@hafriedlander
Contributor

It's a technique NovelAI discovered FYI (https://blog.novelai.net/novelai-improvements-on-stable-diffusion-e10d38db82ac)

@hafriedlander
Contributor

@patrickvonplaten how sure are you that your conversion is correct? I'm trying to diagnose a difference I get between your 768 weights and my conversion script. There's a big difference, and in general I much prefer the results from my conversion. It seems specific to the unet - if I replace my unet with yours I get the same results.

@hafriedlander
Contributor

OK, differential diagnostic done, it's the Tokenizer. How did you create the Tokenizer at https://huggingface.co/stabilityai/stable-diffusion-2/tree/main/tokenizer? I just built a Tokenizer using AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - it seems to give much better results.

@0xdevalias
Contributor

Yes, stable_diffusion2 is working now, and the few lines of code to get inference are here: colab.research.google.com/drive/1Na9x7w7RSbk2UFbcnrnuurg7kFGeqBsa?usp=sharing

@hamzafar In one of the last cells (that sets up EulerDiscreteScheduler) the following warning is shown. I wonder if things would work differently/better if ftfy or spacy was installed alongside the other requirements?

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.

@0xdevalias
Contributor

From @pcuenca on the HF discord:

We are busy preparing a new release of diffusers to fully support Stable Diffusion 2. We are still ironing things out, but the basics already work from the main branch in github. Here's how to do it:

  • Install diffusers from github alongside its dependencies:
pip install --upgrade git+https://github.com/huggingface/diffusers.git transformers accelerate scipy
  • Use the code in this script to run your predictions:
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
import torch

repo_id = "stabilityai/stable-diffusion-2"
device = "cuda"

scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler", prediction_type="v_prediction")
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", scheduler=scheduler)
pipe = pipe.to(device)

prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, width=768, height=768, guidance_scale=9).images[0]
image.save("astronaut.png")

Originally posted by @vvvm23 in #1392 (comment)

@hafriedlander
Contributor

I've put "my" version of the Tokenizer at https://huggingface.co/halffried/sd2-laion-clipH14-tokenizer/tree/main. You can just replace the tokenizer in any pipeline to test it if you're interested.

@0xdevalias
Contributor

How did you create the Tokenizer at huggingface.co/stabilityai/stable-diffusion-2/tree/main/tokenizer?

@hafriedlander Given that is the official stabilityai repo, presumably no one here in huggingface/diffusers made it, and that was just what was released with SDv2?

@hafriedlander
Contributor

@0xdevalias not sure. @patrickvonplaten said that the penultimate layer fix was invented by @patil-suraj, who's a HuggingFace person, not a Stability person. Anyway, I'm not saying mine is correct or anything, just that, in the limited testing I've done, I like the result way more, and that's weird.

@patil-suraj
Contributor

OK, differential diagnostic done, it's the Tokenizer. How did you create the Tokenizer at https://huggingface.co/stabilityai/stable-diffusion-2/tree/main/tokenizer? I just built a Tokenizer using AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - it seems to give much better results.

Thanks, will take a look. Also, could you post some results here so we can see the differences? I compared the results with the original repo and they seemed to match; I'll take a look again.

@patil-suraj
Contributor

Also, could you post the prompts that gave you bad results?

@hafriedlander
Contributor

The whole model seems very sensitive to style shifts.

https://imgur.com/a/dUb93fD is three images with the standard tokenizer. The prompt for the first is

"A full portrait of a teenage smiling, beautiful post apocalyptic female princess, intricate, elegant, highly detailed, digital painting, artstation, smooth, sharp focus, illustration, art by krenz cushart and artem demura and alphonse mucha"

The prompt for the second is exactly the same, but with the addition of a negative prompt "bad teeth, missing teeth"

The third is the first prompt, but without the word smiling

Here is the same with my version of the tokenizer https://imgur.com/a/Wr5Sw9P

The second version with the original tokenizer is great. But I would not normally expect to see a big shift in quality from the addition of a negative prompt like that.

I'll track down another of my recent prompts where I much preferred my tokenizer, and see if adding a negative prompt helps.

@patil-suraj
Contributor

Thank you! Will also compare using these prompts.

@patil-suraj
Contributor

I noticed one difference: the original open_clip tokenizer that was used to train SD2 uses 0 as the pad_token_id, while the AutoTokenizer that you posted uses 49407. So the current tokenizer matches the original implementation; we can verify it using the code below.

import torch
from transformers import CLIPTokenizer, AutoTokenizer
from open_clip import tokenize

tok = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tok2 = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

prompt = "A full portrait of a teenage smiling, beautiful post apocalyptic female princess, intricate, elegant, highly detailed, digital painting, artstation, smooth, sharp focus, illustration, art by krenz cushart and artem demura and alphonse mucha"

tok_orig = tokenize(prompt)
tok_current = tok(prompt, padding="max_length", max_length=77, return_tensors="pt").input_ids
tok_auto = tok2(prompt, padding="max_length", max_length=77, return_tensors="pt", truncation=True).input_ids

assert torch.all(tok_orig == tok_current)  # passes: matches the open_clip tokenization
assert torch.all(tok_orig == tok_auto)     # fails: the padding ids differ

cc @patrickvonplaten
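
(The pad token difference itself is easy to spot-check; the expected ids in the comments are the ones reported above.)

from transformers import AutoTokenizer, CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tok2 = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

print(tok.pad_token, tok.pad_token_id)    # id 0, matching open_clip's zero padding
print(tok2.pad_token, tok2.pad_token_id)  # id 49407 (<|endoftext|>)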

@anton-l
Member

anton-l commented Nov 25, 2022

diffusers==0.9.0 with Stable Diffusion 2 is live! https://github.com/huggingface/diffusers/releases/tag/v0.9.0

@hamzafar

Yes, stable_diffusion2 is working now, and the few lines of code to get inference are here: colab.research.google.com/drive/1Na9x7w7RSbk2UFbcnrnuurg7kFGeqBsa?usp=sharing

@hamzafar In one of the last cells (that sets up EulerDiscreteScheduler) the following warning is shown. I wonder if things would work differently/better if ftfy or spacy was installed alongside the other requirements?

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.

@0xdevalias I have generated images with and without ftfy. I can't observe any difference in the results:
https://colab.research.google.com/drive/1Na9x7w7RSbk2UFbcnrnuurg7kFGeqBsa?usp=sharing

@patrickvonplaten
Contributor

Sorry, the warning is misleading and comes from transformers - you can safely ignore it. I'll try to fix it in Transformers.

@0xdevalias
Contributor

0xdevalias commented Nov 26, 2022

when will Dreambooth support sd2

While it's not dreambooth, this repo seems to have support for finetuning SDv2:

Originally posted by @0xdevalias in JoePenna/Dreambooth-Stable-Diffusion#112 (comment)


And looking at the huggingface/diffusers repo, there are a few issues that seem to imply people may be getting dreambooth things working with that (or at least trying to), e.g.:

Originally posted by @0xdevalias in JoePenna/Dreambooth-Stable-Diffusion#112 (comment)

@vvsotnikov
Contributor

vvsotnikov commented Nov 29, 2022

UPDATE: the issue is gone with the newer build of xformers

Hi, I'm using diffusers==0.9.0 and xformers==0.0.15.dev0+1515f77.d20221129, and for me xformers makes SD 2.0 roughly 1.5x slower than running without it (while it does save some VRAM). At the same time, SD 1.5 runs about 1.5x faster with xformers, so it's unlikely that there's something wrong with my setup :) Is this a known issue? Here are some code samples to reproduce it:

# SD2, xformers disabled -> 5.02it/s
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
repo_id = "stabilityai/stable-diffusion-2"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.disable_xformers_memory_efficient_attention()
prompt = "An oil painting of white De Tomaso Pantera parked in the forest by Ivan Shishkin"
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]  # warmup
image = pipe(prompt, guidance_scale=9, num_inference_steps=250, width=1024, height=576).images[0]
Fetching 12 files: 100%|##########| 12/12 [00:00<00:00, 52648.17it/s]
100%|##########| 25/25 [00:05<00:00,  4.70it/s]
100%|##########| 250/250 [00:49<00:00,  5.02it/s]
# SD2, xformers enabled ->  2.93it/s
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
repo_id = "stabilityai/stable-diffusion-2"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()  # explicitly enable xformers just in case
prompt = "An oil painting of white De Tomaso Pantera parked in the forest by Ivan Shishkin"
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]  # warmup
image = pipe(prompt, guidance_scale=9, num_inference_steps=250, width=1024, height=576).images[0]
Fetching 12 files: 100%|##########| 12/12 [00:00<00:00, 43804.74it/s]
100%|##########| 25/25 [00:08<00:00,  2.90it/s]
100%|##########| 250/250 [01:25<00:00,  2.93it/s]
# SD1.5, xformers disabled -> 5.66it/s
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to("cuda")
pipe.disable_xformers_memory_efficient_attention() 
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
image = pipe(prompt, width=880, num_inference_steps=150).images[0]  
Fetching 15 files: 100%|##########| 15/15 [00:00<00:00, 56987.83it/s]
100%|##########| 51/51 [00:04<00:00, 10.85it/s]
100%|##########| 151/151 [00:26<00:00,  5.66it/s]
# SD1.5, xformers enabled -> 7.94it/s
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
image = pipe(prompt, width=880, num_inference_steps=150).images[0]  
Fetching 15 files: 100%|##########| 15/15 [00:00<00:00, 54660.78it/s]
100%|##########| 51/51 [00:04<00:00, 12.42it/s]
100%|##########| 151/151 [00:19<00:00,  7.94it/s]

@alfredplpl

alfredplpl commented Nov 30, 2022

Hi @hafriedlander ,
Your conversion script is really nice!

#1388 (comment)

It was helpful for me to fine-tune the SD 2.0 model.
However, I found two weird things.

  1. The max_length of the tokenizer is too big, e.g. 1000000000000000019884624838656. I fixed it temporarily by hard-coding it.
  2. The output of the converted model and the output of the original model are very different.

Please tell me about these things if you have any ideas.

Here are the details.

I generated the former attached image from the following prompt.

girl silver hair, sharp lighting. bright color, sun shining through, sharp focus, illustration, by konstantin razumov

I generated the former image by the following command.

python script/txt2img.py --prompt "girl silver hair, sharp lighting. bright color, sun shining through, sharp focus, illustration, by konstantin razumov" --W 512 --H 512 --ckpt /path/to/checkpoints/last.ckpt

Then, I got the former image.
[image: 00501]

Next, I converted the model by the following command.

python scripts/convert_original_stable_diffusion_2_to_diffusers.py --checkpoint_path /path/to/checkpoints/last.ckpt --dump_path /path/to/dump --original_config_file /path/to/config/v2-finetune.yaml --scheduler_type ddim

I generated the latter image by the following code.

pipe = StableDiffusionPipeline.from_pretrained("dump", torch_dtype=torch.float16)
pipe = pipe.to("cuda:0")
image = pipe(prompt, num_inference_steps=50, seed=0, height=512, width=512).images[0]
image.save(out_path)
Then, I got the latter image.

[image: 000000492]

Thanks in advance.

@hafriedlander
Contributor

@alfredplpl I haven't kept the converter updated with the recent Diffusers changes, since there's now an official release of the model. The tokenizer specifically is definitely wrong, and it uses "velocity" instead of "v_predict". It's probably broken in other ways too. I'd start off by copying the tokenizer from the official SD2 version and seeing if that helps.
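
(Concretely, copying the official tokenizer over the converted one would be something like this; the dump path is a placeholder for whatever you passed to the converter.)

from transformers import CLIPTokenizer

# Fetch the tokenizer that ships with the official SD 2.0 diffusers weights and
# write it over the tokenizer folder of the locally converted model.
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tokenizer.save_pretrained("/path/to/dump/tokenizer")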

@djdookie

djdookie commented Dec 2, 2022

I finetuned stabilityai/stable-diffusion-2-base with the diffusers repo.
Then I used https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py to convert the weights to original stable diffusion ckpt to use it in AUTOMATIC1111's web-ui.
Unfortunately, given the exact same parameters, the converted checkpoint gave me completely different and wrong results in this web UI compared to inference from the finetuned diffusers weights using DDIMScheduler.

I put this yaml besides the converted checkpoint to use it in the webui:

model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

What can I do to get the same results? Do we need an updated conversion script for SD2.0?

@patrickvonplaten
Contributor

Two things here:
1 - we should try to get diffusers supported in AUTOMATIC1111 webui. The constant back and forth converting really becomes a mess
2 - it's not trivial to convert stable diffusion v2 to .ckpt format once trained with dreambooth because the text encoder is different
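
(A quick way to see that text-encoder mismatch in a converted .ckpt, sketched under the assumption that the file name below is whatever the conversion script wrote: SD 2.x checkpoints keep the OpenCLIP weights under cond_stage_model.model.*, whereas SD 1.x, and the old conversion script, use cond_stage_model.transformer.*, which FrozenOpenCLIPEmbedder in the web UI cannot load.)

import torch

sd = torch.load("finetuned-sd2-base.ckpt", map_location="cpu")["state_dict"]

# Collect the prefixes of the text encoder keys to see which layout we got.
prefixes = {".".join(k.split(".")[:2]) for k in sd if k.startswith("cond_stage_model")}
print(prefixes)  # {'cond_stage_model.model'} for SD 2.x, {'cond_stage_model.transformer'} for the SD 1.x layout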

@alfredplpl

alfredplpl commented Dec 3, 2022

@hafriedlander Kohya S., a Japanese developer, has written conversion code: https://note.com/kohya_ss/n/n374f316fe4ad
It seems to work well. The following images were generated by the converted model.

[images: 000000111, 000000118]

Filarh referenced this issue in TheLastBen/fast-stable-diffusion Dec 3, 2022
@github-actions

github-actions bot commented Jan 7, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the "stale" label (Issues that haven't received updates) on Jan 7, 2023