# In [None]:
import torch, os, gc, time, threading, subprocess, psutil, shutil
import numpy as np 
from typing import Any, Dict, List, Optional, Tuple 


from transformers import AutoTokenizer, LlamaForCausalLM, T5EncoderModel, T5Tokenizer, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel, UniPCMultistepScheduler, AutoencoderKL, BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.modeling_outputs import Transformer2DModelOutput 
from diffusers.utils import logging, deprecate, USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers 


# Module-level logger obtained through diffusers' logging helper (`logging` here is
# `diffusers.utils.logging`, not the standard-library module).
logger = logging.get_logger(__name__)

def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timesteps: torch.LongTensor = None,
    encoder_hidden_states_t5: torch.Tensor = None,
    encoder_hidden_states_llama3: torch.Tensor = None,
    pooled_embeds: torch.Tensor = None,
    img_ids: Optional[torch.Tensor] = None,
    img_sizes: Optional[List[Tuple[int, int]]] = None,
    hidden_states_masks: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    **kwargs,
):
    """Transformer forward pass with TeaCache step skipping.

    Intended to be installed as the ``forward`` of a HiDream image transformer.
    When TeaCache is enabled, the relative L1 change of the timestep embedding
    between consecutive calls is rescaled by a polynomial and accumulated; while
    the accumulated value stays below ``self.rel_l1_thresh`` the transformer
    blocks are skipped and the residual cached from the last fully computed step
    is added to the embedded input instead.

    TeaCache state lives on ``self``:
      * read: ``enable_teacache`` (treated as False when absent), ``cnt``,
        ``ret_steps``, ``num_steps``, ``coefficients``, ``rel_l1_thresh``;
      * written: ``cnt``, ``accumulated_rel_l1_distance``,
        ``previous_modulated_input``, ``previous_residual``.

    Args:
        hidden_states: Latent input. Patchified internally via ``self.patchify``
            when ``hidden_states_masks`` is None; otherwise must be 3D.
        timesteps: Passed to ``self.t_embedder``.
        encoder_hidden_states_t5: T5 text features.
        encoder_hidden_states_llama3: Llama features, indexed by
            ``self.config.llama_layers``.
        pooled_embeds: Pooled text embedding passed to ``self.p_embedder``.
        img_ids, img_sizes: Required when ``hidden_states_masks`` is given.
        hidden_states_masks: Optional mask for pre-patchified input.
        attention_kwargs: Optional dict; a ``"scale"`` entry is used as the LoRA scale.
        return_dict: Return a ``Transformer2DModelOutput`` when True, else a 1-tuple.
        **kwargs: Only the deprecated ``encoder_hidden_states`` pair is consulted.

    Returns:
        ``Transformer2DModelOutput(sample=output)`` or ``(output,)``.

    Raises:
        ValueError: If ``hidden_states_masks`` is passed without ``img_ids`` /
            ``img_sizes``, or with a non-3D ``hidden_states``.
    """
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

    if encoder_hidden_states is not None:
        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
        encoder_hidden_states_t5 = encoder_hidden_states[0]
        encoder_hidden_states_llama3 = encoder_hidden_states[1]

    if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
        deprecation_message = (
            "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
        )
        deprecate("img_ids", "0.35.0", deprecation_message)

    if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
        raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
    elif hidden_states_masks is not None and hidden_states.ndim != 3:
        raise ValueError(
            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
        )

    # Remember whether the caller supplied a scale *before* popping it; checking the
    # dict afterwards (as the code used to) can never find the key again.
    scale_was_passed = False
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        scale_was_passed = attention_kwargs.get("scale", None) is not None
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
    elif scale_was_passed:
        logger.warning(
            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
        )

    batch_size = hidden_states.shape[0]
    hidden_states_type = hidden_states.dtype

    if hidden_states_masks is None:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

    hidden_states = self.x_embedder(hidden_states)

    # Conditioning vector: timestep embedding plus pooled-text embedding.
    timesteps_embed = self.t_embedder(timesteps, hidden_states_type)
    p_embedder = self.p_embedder(pooled_embeds)
    temb = timesteps_embed + p_embedder

    encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

    if self.caption_projection is not None:
        new_encoder_hidden_states = []
        for i, enc_hidden_state in enumerate(encoder_hidden_states):
            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
            enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
            new_encoder_hidden_states.append(enc_hidden_state)
        encoder_hidden_states = new_encoder_hidden_states
        # The last projection is applied to the T5 features, which are appended last.
        encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
        encoder_hidden_states.append(encoder_hidden_states_t5)

    # Text tokens get all-zero position ids, concatenated after the image ids.
    txt_ids = torch.zeros(
        batch_size,
        encoder_hidden_states[-1].shape[1]
        + encoder_hidden_states[-2].shape[1]
        + encoder_hidden_states[0].shape[1],
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    ids = torch.cat((img_ids, txt_ids), dim=1)
    image_rotary_emb = self.pe_embedder(ids)

    block_id = 0
    initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
    initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]

    # --- TeaCache decision: run the blocks this step, or reuse the cached residual? ---
    # A missing flag means the patched forward was installed without configuration;
    # behave like the plain forward in that case.
    teacache_enabled = getattr(self, "enable_teacache", False)
    previous_residual = None
    should_calc = True
    if teacache_enabled:
        modulated_inp = timesteps_embed.clone()
        previous_inp = getattr(self, "previous_modulated_input", None)
        previous_residual = getattr(self, "previous_residual", None)
        # The cache is only usable if a previous full step left state behind and that
        # state still matches the current shapes (batch size / resolution may change
        # between runs).
        cache_usable = (
            previous_inp is not None
            and previous_residual is not None
            and previous_inp.shape == modulated_inp.shape
            and previous_residual.shape == hidden_states.shape
        )
        # Always compute at least the first step, even when ret_steps is 0.
        if self.cnt < max(self.ret_steps, 1) or not cache_usable:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            distance = ((modulated_inp - previous_inp).abs().mean() / previous_inp.abs().mean()).cpu().item()
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(distance)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        # Wrap the step counter so the next run starts with warm-up steps again.
        if self.cnt == self.num_steps:
            self.cnt = 0

    if teacache_enabled and not should_calc:
        # Cache hit: skip every transformer block. The sequence length is recorded
        # here because the mask slicing after the final layer reads it on both paths.
        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states += previous_residual
    else:
        ori_hidden_states = hidden_states.clone()
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat(
                [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
            )
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, hidden_states_masks, cur_encoder_hidden_states, temb, image_rotary_emb,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    hidden_states=hidden_states, hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=cur_encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb,
                )
            # Drop the per-block Llama tokens again; only the shared text prefix is carried on.
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if hidden_states_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=hidden_states_masks.device, dtype=hidden_states_masks.dtype,
            )
            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, hidden_states_masks, None, temb, image_rotary_emb,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states, hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=None, temb=temb, image_rotary_emb=image_rotary_emb,
                )
            # Remove the Llama tokens appended for this block.
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]

        if teacache_enabled:
            # Residual contributed by all blocks; replayed on later cache-hit steps.
            self.previous_residual = hidden_states - ori_hidden_states

    output = self.final_layer(hidden_states, temb)
    output = self.unpatchify(output, img_sizes, self.training)
    if hidden_states_masks is not None:
        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


# Output resolution in pixels.
width, height = 1024, 1024

# Sampling settings and the model repositories to load from.
cfg = 3.5
steps = 50
model_repo_id = "HiDream-ai/HiDream-I1-Full"
llama_repo_id = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"

# Scratch directory for the per-prompt embedding files.
TEMP_EMBED_DIR = "temp_embeddings"
# Reference point for the "since start" timing printed at the end of the script.
tick_begins = time.time()

# Everything is loaded onto the CUDA device in bfloat16.
dtype = torch.bfloat16
device = torch.device("cuda")

def memory():
    """Print a one-line, ANSI-coloured snapshot of GPU and system memory usage.

    GPU figures come from a single ``nvidia-smi`` query and RAM usage from
    ``psutil.virtual_memory()``. Any failure (missing tool, unexpected output,
    ...) is reported on stdout instead of being raised.
    """
    cyan, yellow = "\033[96m", "\033[93m"
    query = [
        'nvidia-smi',
        '--query-gpu=pstate,memory.used,temperature.gpu,utilization.gpu',
        '--format=csv,noheader',
    ]
    try:
        completed = subprocess.run(query, capture_output=True, text=True, check=True)
        # CSV field order follows the query: pstate, memory.used, temperature, utilization.
        fields = completed.stdout.strip().split(',')
        vram_gib = float(fields[1].strip().replace(" MiB", "")) / 1024
        pstate = fields[0].strip()
        utilization = fields[3].strip()
        temperature = fields[2].strip()
        ram = psutil.virtual_memory()
        ram_gib = ram.used / 1024**3
        print(
            f"   ...{cyan}VRAM:{yellow}{vram_gib:.1f}{cyan}/24GB "
            f"{yellow}{pstate} {cyan}{utilization} {temperature}C | RAM:{yellow}{ram_gib:.1f}{cyan}/64GB"
        )
    except Exception as e:
        print(f"   ...Could not get GPU stats: {e}")

def flush_vram():
    """Run Python's garbage collector, then empty PyTorch's CUDA cache."""
    # Collection runs first, before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def save_and_open_image(image, timestamp):
    """Save *image* as ``<timestamp>_hidream_teacache.png`` and open it in a viewer.

    Args:
        image: Object exposing ``save(path)`` (the script passes a PIL image).
        timestamp: Value used as the filename prefix.

    Opening the file is best-effort: ``os.startfile`` only exists on Windows, so
    other platforms fall back to ``open`` (macOS) or ``xdg-open``. A failure to
    launch a viewer is reported on stdout and never loses the saved image.
    """
    import sys  # local import: only needed for the platform check below

    filename = f"{timestamp}_hidream_teacache.png"
    image.save(filename)
    print(f"\n   ...Saved image to {filename}")

    try:
        if hasattr(os, "startfile"):
            os.startfile(filename)
        else:
            opener = "open" if sys.platform == "darwin" else "xdg-open"
            # Popen (not run) so the call returns immediately, like os.startfile does.
            subprocess.Popen([opener, filename], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError as e:
        print(f"   ...Could not open {filename}: {e}")

# === PHASE 1 (UNCHANGED) ===
def _clip_text_embeds(tokenizer, text_encoder, texts):
    """Tokenize *texts* to 77 tokens and return the first output of a CLIP text encoder."""
    input_ids = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
    return text_encoder(input_ids.to(device))[0]


def _masked_encoder_outputs(tokenizer, text_encoder, texts):
    """Tokenize *texts* to 512 tokens and run *text_encoder* with the attention mask."""
    tokens = tokenizer(texts, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    return text_encoder(tokens.input_ids.to(device), attention_mask=tokens.attention_mask.to(device))


def _encode_text(text, encoders):
    """Encode one text with all four encoders; return ``(t5, llama3, pooled)`` on the CPU.

    *encoders* is ``(tokenizer_1, text_encoder_1, ..., tokenizer_4, text_encoder_4)``.
    """
    tok_1, enc_1, tok_2, enc_2, tok_3, enc_3, tok_4, enc_4 = encoders
    texts = [text]
    # Pooled embedding: the two CLIP outputs concatenated along the feature axis.
    pooled = torch.cat([_clip_text_embeds(tok_1, enc_1, texts), _clip_text_embeds(tok_2, enc_2, texts)], dim=-1)
    t5_embeds = _masked_encoder_outputs(tok_3, enc_3, texts)[0]
    llama_outputs = _masked_encoder_outputs(tok_4, enc_4, texts)
    # Stack every hidden state except index 0 into one tensor.
    llama_embeds = torch.stack(llama_outputs.hidden_states[1:], dim=0)
    return t5_embeds.cpu(), llama_embeds.cpu(), pooled.cpu()


def precompute_and_save_embeddings(prompts, negative_prompt, temp_dir):
    """Encode every prompt with all text encoders and save the results to *temp_dir*.

    One ``embed_<i>.pt`` file is written per prompt. Each file holds a dict with
    the positive and negative T5, Llama and pooled CLIP embeddings (CPU tensors),
    keyed the way the generation loop later unpacks them into the pipeline call.
    The text encoders are loaded at the start and released again before
    returning, so they do not occupy VRAM during generation.

    Args:
        prompts: Sized sequence of prompt strings.
        negative_prompt: Single negative prompt shared by all prompts.
        temp_dir: Directory for the embedding files; created if missing.
    """
    print(f"{'='*10} PHASE 1: Pre-computing and Saving Embeddings {'='*10}")
    os.makedirs(temp_dir, exist_ok=True)

    print("   ...Loading all text encoders...")
    tokenizer_1 = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
    text_encoder_1 = CLIPTextModelWithProjection.from_pretrained(model_repo_id, subfolder="text_encoder", torch_dtype=dtype).to(device)
    tokenizer_2 = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer_2")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_repo_id, subfolder="text_encoder_2", torch_dtype=dtype).to(device)
    tokenizer_3 = T5Tokenizer.from_pretrained(model_repo_id, subfolder="tokenizer_3")
    text_encoder_3 = T5EncoderModel.from_pretrained(model_repo_id, subfolder="text_encoder_3", torch_dtype=dtype).to(device)
    tokenizer_4 = AutoTokenizer.from_pretrained(llama_repo_id)
    # The Llama tokenizer is given the EOS token as its padding token.
    tokenizer_4.pad_token = tokenizer_4.eos_token
    text_encoder_4 = LlamaForCausalLM.from_pretrained(llama_repo_id, output_hidden_states=True, torch_dtype=dtype).to(device)
    print("   ...All encoders loaded."); memory()

    encoders = (
        tokenizer_1, text_encoder_1, tokenizer_2, text_encoder_2,
        tokenizer_3, text_encoder_3, tokenizer_4, text_encoder_4,
    )

    with torch.inference_mode():
        # The negative prompt is identical for every prompt, so encode it once
        # instead of pushing it through all four encoders on every iteration.
        negative_t5, negative_llama3, negative_pooled = _encode_text(negative_prompt, encoders)

        for i, prompt in enumerate(prompts):
            print(f"   ...Encoding prompt {i+1}/{len(prompts)}...")
            prompt_t5, prompt_llama3, prompt_pooled = _encode_text(prompt, encoders)
            single_embeds = {
                "prompt_embeds_t5": prompt_t5, "prompt_embeds_llama3": prompt_llama3,
                "pooled_prompt_embeds": prompt_pooled, "negative_prompt_embeds_t5": negative_t5,
                "negative_prompt_embeds_llama3": negative_llama3,
                "negative_pooled_prompt_embeds": negative_pooled,
            }
            torch.save(single_embeds, os.path.join(temp_dir, f"embed_{i}.pt"))

    print("   ...Purging all text encoders from memory...")
    # Drop every reference (including the helper tuple) before emptying the CUDA cache.
    del encoders
    del text_encoder_1, text_encoder_2, text_encoder_3, text_encoder_4, tokenizer_1, tokenizer_2, tokenizer_3, tokenizer_4
    flush_vram()
    print(f"{'='*10} PHASE 1 COMPLETE {'='*10}"); memory()


# Start of the timed section reported as "Total script time" at the end.
tick_0 = time.time()

# One image is generated per entry in this list.
prompts_to_generate = [
    (
        "Gorilla in battlearena, his golden breastplate armor shines. "
        "He wields a massive black chain, it hangs down at his side, "
        "the length and weight a testament to his enormus strength. "
        "His fur rich brown, massive chest heaves with contained power, "
        "muscles rippled and buldge"
    ),
]

# Negative prompt shared by every entry above.
negative_prompt = "ugly, blurry, low quality"


# Main script: (1) pre-compute text embeddings, (2) load the transformer/VAE and
# generate one image per prompt with TeaCache enabled, (3) always remove the
# temporary embedding directory, even when an earlier phase fails.
try:
    precompute_and_save_embeddings(prompts_to_generate, negative_prompt, TEMP_EMBED_DIR)

    print(f"\n{'='*10} PHASE 2: Generation Loop {'='*10}")
    print("   ...Loading generation components (Transformer, VAE)...")
    
    quant_config_diffusers = DiffusersBitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_use_double_quant=True,bnb_4bit_compute_dtype=dtype)
    transformer = HiDreamImageTransformer2DModel.from_pretrained(model_repo_id,subfolder="transformer",quantization_config=quant_config_diffusers,torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae", torch_dtype=dtype)
    scheduler = UniPCMultistepScheduler(flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True)
    # The text encoders are deliberately omitted: embeddings come from the files
    # written in phase 1.
    pipe = HiDreamImagePipeline(
        vae=vae, transformer=transformer, scheduler=scheduler,
        text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None,
        text_encoder_3=None, tokenizer_3=None, text_encoder_4=None, tokenizer_4=None
    ).to(device)
    print("   ...Generation components loaded and ready."); memory()


    print(f"\n{'='*10} ACTIVATING TEACACHE {'='*10}")
    # Patch the class-level forward and attach the TeaCache configuration.
    HiDreamImageTransformer2DModel.forward = teacache_forward
    pipe.transformer.__class__.enable_teacache = True
    pipe.transformer.__class__.num_steps = steps
    # Warm-up steps that are always fully computed. Keep at least one: with zero
    # warm-up steps the very first forward would consult cache state that does
    # not exist yet.
    pipe.transformer.__class__.ret_steps = max(1, int(steps * 0.1))
    # --- This is the main tuning knob for speed vs. quality ---
    REL_L1_THRESH = 0.45  # 0.3 gives ~2x speedup. Higher is faster but may reduce quality.
    pipe.transformer.__class__.rel_l1_thresh = REL_L1_THRESH
    pipe.transformer.__class__.coefficients = np.array([-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01])
    print(f"   ...TeaCache enabled with threshold: {REL_L1_THRESH}")
    # =====================================================================================

    for i in range(len(prompts_to_generate)):
        loop_start_time = time.time()
        print(f"\n--- Generating image {i+1}/{len(prompts_to_generate)} ---")
        
        # Reset the TeaCache state for each new image. This must happen on the
        # instance: teacache_forward does `self.cnt += 1`, which creates an instance
        # attribute that shadows the class attribute, so resetting the class value
        # has no effect after the first forward call.
        pipe.transformer.cnt = 0
        pipe.transformer.accumulated_rel_l1_distance = 0
        
        embed_path = os.path.join(TEMP_EMBED_DIR, f"embed_{i}.pt")
        single_embeds = torch.load(embed_path)
        for key in single_embeds:
            single_embeds[key] = single_embeds[key].to(device)
        print("   ...Embeddings loaded into VRAM."); memory()
        
        with torch.inference_mode():
            latents = pipe(
                **single_embeds,
                guidance_scale=cfg, num_inference_steps=steps, height=height, width=width,
                generator=torch.Generator(device).manual_seed(0), output_type="latent",
            ).images
            print("   ...Denoising complete."); memory()
            
            # Undo the latent scaling/shift before decoding with the VAE.
            processed_latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
            image_tensor = vae.decode(processed_latents, return_dict=False)[0]
            # `vae_scale_factor` is the integer spatial down-sampling factor; the latent
            # `scaling_factor` (a magnitude scale) does not belong in it.
            # NOTE(review): the extra `* 2` is assumed to mirror how the diffusers
            # HiDream pipeline builds its own processor - confirm against the library.
            vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
            image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
            image = image_processor.postprocess(image_tensor, output_type="pil")[0]
            
            timestamp = int(time.time())
            # Save/open in a background thread so the loop can continue immediately.
            threading.Thread(target=save_and_open_image, args=(image, timestamp)).start()
            del single_embeds, latents, image_tensor, image_processor
            flush_vram()

        print(f"   ...Image {i+1} completed in {time.time() - loop_start_time:.2f} secs")

finally:
    print(f"\n{'='*10} PHASE 3: Cleaning up temporary files {'='*10}")
    if os.path.exists(TEMP_EMBED_DIR):
        shutil.rmtree(TEMP_EMBED_DIR)
        print(f"   ...Removed temporary directory: {TEMP_EMBED_DIR}")
    print(f"\n--- All generations complete. Total script time: {time.time() - tick_0:.2f} secs ---\n--- ...   {time.time() - tick_begins:.2f} since start")