In [None]:
#!pip install git+https://github.com/huggingface/diffusers  
# Credit to ASOMOZA. Additions are negative prompt embeddings and using only the full repo, which allows for guidance manipulation. Even if it won't all fit in video memory, it is still faster than the chopped-up text encoders + Q5 GGUF I tried.

In [None]:
import gc, time, os, subprocess
import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers import AutoencoderKL, HiDreamImagePipeline, HiDreamImageTransformer2DModel, UniPCMultistepScheduler
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.image_processor import VaeImageProcessor


# --- Model sources ---
repo_id = "HiDream-ai/HiDream-I1-Full"  # pipeline repo (tokenizers, encoders, transformer, VAE)
llama_repo_id = "meta-llama/Llama-3.1-8B-Instruct"  # repo of the Llama text encoder

# --- Runtime placement / precision ---
device = torch.device("cuda")
torch_dtype = torch.bfloat16

# --- Prompts ---
# Adjacent string literals are concatenated by Python into one single-line prompt.
prompt = (
    "A photorealistic close-up of a single, iridescent hummingbird hovering mid-air, "
    "its wings a blur of sapphire and emerald, drinking nectar from a bioluminescent "
    "flower that emits soft, swirling particles of golden light. "
    "In the foreground, a single dewdrop clings precariously to a spiderweb woven "
    "with threads of pure silver. "
    "The background is a hyper-detailed, otherworldly jungle at twilight, with "
    "colossal, natural crystalline trees reflecting a nebula-filled sky. "
    "The overall atmosphere should be one of serene magic and vibrant detail"
)

negative_prompt = "cartoon, anime"

# --- Sampling settings (forwarded to the pipeline call in denoise) ---
width, height = 1344, 768
guidance_scale = 7.5
num_inference_steps = 28

def flush(device):
    """Release cached memory and print the CUDA memory usage of *device*.

    Runs the Python garbage collector, empties the CUDA caching allocator,
    resets the peak-memory statistics of *device* and prints the currently
    allocated and reserved memory of that same device.

    Args:
        device: CUDA device whose statistics are reset and reported.

    When CUDA is not available only the garbage collection is performed, so
    the helper is a harmless no-op on CPU-only machines.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    # Report the device that was just flushed, not whichever device happens to be current.
    print(f"Current CUDA memory allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    print(f"Current CUDA memory reserved: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")


# Modified encode_prompt to include negative prompt handling
def encode_prompt(
    prompt, negative_prompt, pipeline_repo_id, llama_repo_id, device=device, dtype=torch_dtype
):
    """Encode positive and negative prompts with every text encoder, one at a time.

    Each encoder (CLIP 1, CLIP 2, T5, Llama 3) is loaded, used for both the
    positive and the negative prompt, and then dropped and flushed before the
    next one is loaded, so only one encoder occupies *device* at any time.
    The caller is expected to invoke this under ``torch.no_grad()``.

    Args:
        prompt: A prompt string or a sequence of prompt strings.
        negative_prompt: A single string (repeated for every prompt), a
            sequence with one entry per prompt, or ``None`` (treated as "").
        pipeline_repo_id: Repo providing the ``tokenizer*`` / ``text_encoder*``
            subfolders for the CLIP and T5 encoders.
        llama_repo_id: Repo providing the Llama tokenizer and model.
        device: Device the encoders are run on.
        dtype: Torch dtype the encoder weights are loaded in.

    Returns:
        dict with the keys ``prompt_embeds`` and ``negative_prompt_embeds``
        (each a ``[t5_embeds, llama3_embeds]`` list) and
        ``pooled_prompt_embeds`` / ``negative_pooled_prompt_embeds`` (the two
        CLIP outputs concatenated along the last dimension).

    Raises:
        ValueError: If ``negative_prompt`` is a sequence whose length differs
            from the number of prompts.
    """
    # Normalise both prompts to equally long lists of strings.
    prompt = [prompt] if isinstance(prompt, str) else list(prompt)
    if negative_prompt is None:
        negative_prompt = ""
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)
    else:
        negative_prompt = list(negative_prompt)
        if len(negative_prompt) != len(prompt):
            raise ValueError(
                f"negative_prompt has {len(negative_prompt)} entries but prompt has "
                f"{len(prompt)}; pass a single string or one negative prompt per prompt."
            )

    # --- CLIP 1 and CLIP 2 ---
    clip_embeds = []
    negative_clip_embeds = []
    for tokenizer_folder, encoder_folder in (
        ("tokenizer", "text_encoder"),
        ("tokenizer_2", "text_encoder_2"),
    ):
        tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo_id, subfolder=tokenizer_folder)
        text_encoder = CLIPTextModelWithProjection.from_pretrained(
            pipeline_repo_id, subfolder=encoder_folder, torch_dtype=dtype
        ).to(device)

        clip_embeds.append(
            get_clip_prompt_embeds(prompt, tokenizer, text_encoder).clone().detach()
        )
        negative_clip_embeds.append(
            get_clip_prompt_embeds(negative_prompt, tokenizer, text_encoder).clone().detach()
        )

        # Drop this encoder before the next one is loaded.
        text_encoder.to("cpu")
        del tokenizer
        del text_encoder
        flush(device)

    # --- Pooled Embeddings ---
    pooled_prompt_embeds = torch.cat(clip_embeds, dim=-1)
    negative_pooled_prompt_embeds = torch.cat(negative_clip_embeds, dim=-1)
    del clip_embeds, negative_clip_embeds

    # --- T5 ---
    tokenizer = T5Tokenizer.from_pretrained(pipeline_repo_id, subfolder="tokenizer_3")
    text_encoder = T5EncoderModel.from_pretrained(
        pipeline_repo_id, subfolder="text_encoder_3", torch_dtype=dtype
    ).to(device)

    input_ids, attention_mask = _tokenize_padded(tokenizer, prompt, device)
    t5_prompt_embeds = text_encoder(input_ids, attention_mask=attention_mask)[0].clone().detach()

    input_ids, attention_mask = _tokenize_padded(tokenizer, negative_prompt, device)
    t5_negative_prompt_embeds = (
        text_encoder(input_ids, attention_mask=attention_mask)[0].clone().detach()
    )

    del input_ids, attention_mask
    del text_encoder
    del tokenizer
    flush(device)

    # --- Llama3 ---
    tokenizer = PreTrainedTokenizerFast.from_pretrained(llama_repo_id)
    tokenizer.pad_token = tokenizer.eos_token
    text_encoder = LlamaForCausalLM.from_pretrained(
        llama_repo_id,
        output_hidden_states=True,
        output_attentions=True,  # Keep original settings
        torch_dtype=dtype,
    ).to(device)

    llama3_prompt_embeds = _llama_hidden_state_stack(text_encoder, tokenizer, prompt, device)
    llama3_negative_prompt_embeds = _llama_hidden_state_stack(
        text_encoder, tokenizer, negative_prompt, device
    )

    del text_encoder
    del tokenizer
    flush(device)

    # --- Assemble Embeddings ---
    return {
        "prompt_embeds": [t5_prompt_embeds, llama3_prompt_embeds],
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "negative_prompt_embeds": [t5_negative_prompt_embeds, llama3_negative_prompt_embeds],
        "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
    }


def _tokenize_padded(tokenizer, texts, device, max_length=128):
    """Tokenize *texts* padded/truncated to *max_length*; return (input_ids, attention_mask) on *device*."""
    inputs = tokenizer(
        texts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    return inputs.input_ids.to(device), inputs.attention_mask.to(device)


def _llama_hidden_state_stack(text_encoder, tokenizer, texts, device):
    """Run the Llama encoder on *texts* and stack every hidden state except the input embeddings."""
    input_ids, attention_mask = _tokenize_padded(tokenizer, texts, device)
    outputs = text_encoder(
        input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        output_attentions=True,
    )
    # hidden_states[0] is skipped (input embeddings layer); the full model output,
    # including attentions, goes out of scope as soon as this helper returns.
    return torch.stack(outputs.hidden_states[1:], dim=0).clone().detach()



def get_clip_prompt_embeds(prompt, tokenizer, text_encoder):
    """Return the first output of a CLIP text encoder for *prompt*.

    Args:
        prompt: A string or list of strings to encode.
        tokenizer: Tokenizer matching *text_encoder*; called with padding and
            truncation to 77 tokens and ``return_tensors="pt"``.
        text_encoder: Encoder called on the token ids. It must expose a
            ``device`` attribute (as Hugging Face models do).

    Returns:
        Element ``[0]`` of the encoder output. For
        ``CLIPTextModelWithProjection`` this is presumably the projected
        pooled text embedding -- confirm against the transformers docs.
    """
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the encoder's own placement rather than a module-level global, so
    # the ids always land on the device the weights are on.
    text_input_ids = text_inputs.input_ids.to(text_encoder.device)
    encoder_output = text_encoder(text_input_ids, output_hidden_states=True)

    return encoder_output[0]


def denoise(embeddings, device=device, dtype=torch_dtype):
    """Run the HiDream diffusion loop on precomputed embeddings and return latents.

    The transformer is loaded from ``repo_id``, stored in float8 with
    layerwise casting to *dtype* for compute, and group-offloaded between the
    CPU and *device*. The pipeline is built without text encoders, tokenizers
    or VAE because the embeddings are supplied directly and decoding happens
    elsewhere.

    Args:
        embeddings: dict with the keys ``prompt_embeds`` and
            ``negative_prompt_embeds`` (each a ``(t5, llama3)`` pair) plus
            ``pooled_prompt_embeds`` and ``negative_pooled_prompt_embeds``.
        device: Device the transformer blocks are onloaded to and on which
            the seeded generator is created.
        dtype: Load/compute dtype for the transformer and the pipeline.

    Returns:
        The latents produced by the pipeline (``output_type="latent"``).

    Reads the module-level ``repo_id``, ``height``, ``width``,
    ``guidance_scale`` and ``num_inference_steps``.
    """
    scheduler = UniPCMultistepScheduler(
        flow_shift=3.0,
        prediction_type="flow_prediction",
        use_flow_sigmas=True,
    )

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        repo_id, subfolder="transformer", torch_dtype=dtype
    )

    # Keep weights in float8 and upcast per layer for compute, then stream
    # individual layers between CPU and the target device.
    transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype)
    apply_group_offloading(
        transformer,
        onload_device=device,
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        low_cpu_mem_usage=True,
    )

    pipe = HiDreamImagePipeline.from_pretrained(
        repo_id,
        text_encoder=None,
        tokenizer=None,
        text_encoder_2=None,
        tokenizer_2=None,
        text_encoder_3=None,
        tokenizer_3=None,
        text_encoder_4=None,
        tokenizer_4=None,
        transformer=transformer,
        scheduler=scheduler,
        vae=None,
        torch_dtype=dtype,
    )

    prompt_embeds_t5, prompt_embeds_llama3 = embeddings["prompt_embeds"]
    negative_prompt_embeds_t5, negative_prompt_embeds_llama3 = embeddings["negative_prompt_embeds"]

    latents = pipe(
        prompt_embeds_t5=prompt_embeds_t5,
        prompt_embeds_llama3=prompt_embeds_llama3,
        pooled_prompt_embeds=embeddings["pooled_prompt_embeds"],
        negative_prompt_embeds_t5=negative_prompt_embeds_t5,
        negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
        negative_pooled_prompt_embeds=embeddings["negative_pooled_prompt_embeds"],
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator(device).manual_seed(0),
        output_type="latent",
        return_dict=False,
    )[0]

    print(pipe)
    # Drop every local reference to the model; otherwise the transformer stays
    # alive through `transformer` and flush() cannot reclaim its memory.
    del pipe, transformer, scheduler
    flush(device)

    return latents

# --- Main Execution ---
# Pipeline: encode prompts -> denoise to latents -> VAE decode -> save PNG -> report.
time_gen = time.time()

with torch.no_grad():
    # Pass negative_prompt to the encoding function
    embeddings = encode_prompt(prompt, negative_prompt, repo_id, llama_repo_id, device=device, dtype=torch_dtype)

latents = denoise(embeddings, device=device, dtype=torch_dtype)

vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch_dtype).to(device)

with torch.no_grad():
    # Undo the latent scaling/shift from the VAE config before decoding.
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    image = vae.decode(latents, return_dict=False)[0]

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image = image_processor.postprocess(image, output_type="pil")[0]

# Save first: nothing below may cost us the finished image.
filename = f"hidream_cfg{guidance_scale}_steps_{num_inference_steps}{int(time.time())}.png"
image.save(filename)

# GPU statistics are best-effort; nvidia-smi may be missing or exceed the timeout.
try:
    result = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=memory.used,temperature.gpu,utilization.gpu', '--format=csv,noheader'],
        encoding='utf-8',
        timeout=1.0,
    )
except (OSError, subprocess.SubprocessError) as err:
    result = f"unavailable ({err})"

# os.startfile exists only on Windows; skip opening the image elsewhere.
if hasattr(os, "startfile"):
    os.startfile(filename)
print(f"   ... Generated in {time.time() - time_gen:.2f} secs, mem/temp/use: {result.strip()}")