In [None]:
from diffusers import CogView4Pipeline, CogView4Transformer2DModel #, BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import GlmModel, BitsAndBytesConfig as TransformersBitsAndBytesConfig
import torch, time, gc, os

def flush():
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache."""
    # Collect first so the cache is emptied only after unreachable objects are gone.
    gc.collect()
    torch.cuda.empty_cache()
def bytes_to_giga_bytes(bytes):
    """Convert a byte count into gibibytes (units of 1024**3 bytes).

    Args:
        bytes: Size in bytes.

    Returns:
        The size divided by 1024**3, as a float.
    """
    bytes_per_gib = 1024 ** 3
    return bytes / bytes_per_gib

# --- Run configuration -------------------------------------------------------
id = "THUDM/CogView4-6B"  # checkpoint identifier handed to every from_pretrained call
shortid = "cogview4"
device = "cuda"
dtype = torch.bfloat16

# --- Stage 1: text encoder only ----------------------------------------------
# NOTE(review): `dtype` is defined above but no torch_dtype is passed to either
# load below, so both use the library's default precision — confirm that is intended.
text_encoder = GlmModel.from_pretrained(id, subfolder="text_encoder", device_map="auto")

# The transformer and VAE are explicitly disabled here; this pipeline instance is
# built around the text encoder loaded above.
pipeline = CogView4Pipeline.from_pretrained(
    id,
    text_encoder=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)

In [None]:

# Prompt text for the positive and negative conditioning.
prompt = "A photorealistic close-up of a single, iridescent hummingbird hovering mid-air, its wings a blur of sapphire and emerald, drinking nectar from a luminous, bioluminescent flower that emits soft, swirling particles of golden light. The background is a hyper-detailed, otherworldly jungle at twilight, with colossal, crystalline trees reflecting a nebula-filled sky. In the foreground, a single dewdrop clings precariously to a spiderweb woven with threads of pure silver. The overall atmosphere should be one of serene magic and vibrant detail."
negative_prompt = "poor quality, poor clarity, ugly, jpeg artifacts, cropped, lowres, error, out of frame, watermark"

# Encode both prompts with gradient tracking disabled.
with torch.no_grad():
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
    )

# Only the embeddings are used from here on: drop the encoder and the pipeline
# that wraps it, then run the cleanup helper.
del text_encoder, pipeline
flush()


In [None]:

# Stage 2: reload the pipeline without a text encoder or tokenizer; the prompt
# embeddings computed earlier are fed in directly instead.
pipeline = CogView4Pipeline.from_pretrained(
    id,
    text_encoder=None,
    tokenizer=None,
    torch_dtype=dtype,
)
# Deliberately no `.to("cuda")` before this call: model CPU offload handles device
# placement itself. Moving the whole pipeline to the GPU first would load every
# sub-model at once, raising peak memory and undoing the benefit of offloading.
pipeline.enable_model_cpu_offload(device=device)

# Put the cached embeddings on the execution device in the pipeline's dtype.
prompt_embeds_gen = prompt_embeds.to(device=device, dtype=dtype)
negative_prompt_embeds_gen = negative_prompt_embeds.to(device=device, dtype=dtype)


In [None]:
# Generate from the precomputed embeddings and keep the first returned image.
image = pipeline(
    prompt_embeds=prompt_embeds_gen,
    negative_prompt_embeds=negative_prompt_embeds_gen,
    guidance_scale=3.5,
    num_inference_steps=50,
    width=1344,
    height=768,
).images[0]

# Name the output file after the current Unix time in whole seconds.
timestamp = str(int(time.time()))
filename = f"testembed_{timestamp}.png"
image.save(filename)

# os.startfile is only available on Windows. Guard it so that on other platforms
# the cell still finishes and prints its report instead of raising AttributeError.
if hasattr(os, "startfile"):
    os.startfile(filename)

print(f"Image saved as {filename}")
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")

# Last expression of the cell, so the notebook renders the result inline on any platform.
image