In [None]:
# %pip list | grep -Ew "diffusers|quanto|transformers"

```
diffusers                     0.30.0
optimum-quanto                0.2.4
sentence-transformers         2.3.1
transformers                  4.40.2
transformers-stream-generator 0.0.5
```

In [None]:
import os

import torch

from optimum.quanto import freeze, qfloat8, quantize, qint8, qint4

from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

# Local snapshot of black-forest-labs/FLUX.1-schnell.
# Point FLUX_SCHNELL_PATH at your own copy instead of editing this cell;
# the fallback is the original author's path.
fluxrepo_local = os.environ.get("FLUX_SCHNELL_PATH", "/home/g/models/FLUX.1-schnell")
# Weight dtype for every component that has weights.
dtype = torch.bfloat16

# Noise scheduler shipped with the checkpoint (config only, no weights).
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    fluxrepo_local,
    subfolder="scheduler",
)

# CLIP text encoder and its tokenizer.
text_encoder = CLIPTextModel.from_pretrained(
    fluxrepo_local,
    subfolder="text_encoder",
    torch_dtype=dtype,
)
# Tokenizers have no tensors, so torch_dtype does not apply to them.
tokenizer = CLIPTokenizer.from_pretrained(
    fluxrepo_local,
    subfolder="tokenizer",
)

# T5 text encoder and its fast tokenizer.
text_encoder_2 = T5EncoderModel.from_pretrained(
    fluxrepo_local,
    subfolder="text_encoder_2",
    torch_dtype=dtype,
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    fluxrepo_local,
    subfolder="tokenizer_2",
)

# VAE used to decode latents into images.
vae = AutoencoderKL.from_pretrained(
    fluxrepo_local,
    subfolder="vae",
    torch_dtype=dtype,
)
# The Flux denoising transformer (the largest component).
transformer = FluxTransformer2DModel.from_pretrained(
    fluxrepo_local,
    subfolder="transformer",
    torch_dtype=dtype,
)

def _quantize_fp8_on_cuda(message, module):
    """Print `message`, move `module` to CUDA, then quantize and freeze its weights.

    Weights are quantized to qfloat8 with optimum-quanto. `module.to` moves the
    module in place, and `quantize`/`freeze` also mutate it in place, so the
    caller's reference ends up pointing at the quantized, frozen module on CUDA.
    """
    print(message)
    quantize(module.to("cuda"), weights=qfloat8)
    freeze(module)


# Per the original note, a 3090 runs out of memory unless these two are quantized.
_quantize_fp8_on_cuda("quantizing transformer ..", transformer)
_quantize_fp8_on_cuda("quantizing text_encoder_2 ..", text_encoder_2)

# Assemble the pipeline from the components loaded (and partly quantized) above.
# Only transformer and text_encoder_2 were moved to CUDA during quantization;
# text_encoder (CLIP) and vae are still on the CPU. FluxPipeline feeds inputs to
# all components on a single execution device, so without this .to("cuda") the
# first pipe(...) call fails with a device-mismatch error. CLIP and the VAE are
# small, so moving them costs little VRAM.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer,
).to("cuda")

In [None]:
# Optional: release cached CUDA memory so nvidia-smi shows the VRAM actually in use (~18GB here).
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# Text prompt and RNG seed for a reproducible sample.
prompt = "photo of a cat riding a boeing 747"
seed = 42

# FLUX.1-schnell settings from the model card
# (https://huggingface.co/black-forest-labs/FLUX.1-schnell):
# no guidance, 4 denoising steps, prompt capped at 256 tokens.
schnell_kwargs = dict(
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
)
# Seeding a CPU generator keeps the initial noise reproducible across devices.
schnell_result = pipe(
    prompt,
    generator=torch.Generator("cpu").manual_seed(seed),
    **schnell_kwargs,
)
image = schnell_result.images[0]
image.save("flux-schnell.png")

# FLUX.1-dev settings from its model card
# (https://huggingface.co/black-forest-labs/FLUX.1-dev); intended for the dev
# checkpoint rather than the schnell one loaded above:
# dev_kwargs = dict(
#     height=1024,
#     width=1024,
#     guidance_scale=3.5,
#     num_inference_steps=50,
#     max_sequence_length=512,
# )
# image = pipe(
#     prompt,
#     generator=torch.Generator("cpu").manual_seed(seed),
#     **dev_kwargs,
# ).images[0]
# image.save("flux-dev.png")