🐛 Describe the bug
import torch
import os

from diffusers import FluxTransformer2DModel

torch.use_deterministic_algorithms(True)


def get_dummy_tensor_inputs(device=None, seed: int = 0):
    batch_size = 1
    num_latent_channels = 4
    num_image_channels = 3
    height = width = 4
    sequence_length = 48
    embedding_dim = 32

    torch.manual_seed(seed)
    hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
        device, dtype=torch.bfloat16
    )
    torch.manual_seed(seed)
    pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
    timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "pooled_projections": pooled_prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": image_ids,
        "timestep": timestep,
    }


@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
    device_module.reset_peak_memory_stats()
    device_module.empty_cache()
    model(**inputs)
    max_mem_allocated = device_module.max_memory_allocated()
    return max_mem_allocated


torch_device = "xpu" if torch.xpu.is_available() else "cuda"
if torch_device == "cuda":
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
device_module = torch.xpu if torch.xpu.is_available() else torch.cuda

model_id = "hf-internal-testing/tiny-flux-pipe"
print(f"max allocated memory before loading: {device_module.max_memory_allocated()}")
inputs = get_dummy_tensor_inputs(device=torch_device)
print(f"max allocated memory after get inputs: {device_module.max_memory_allocated()}")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=None, torch_dtype=torch.bfloat16
).to(torch_device)
print(f"max allocated memory after get model: {device_module.max_memory_allocated()}")
with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after model inference: {device_module.max_memory_allocated()}")

Outputs on XPU (Intel(R) Data Center GPU Max 1550):
max allocated memory before loading: 1024
max allocated memory after get inputs: 6656
max allocated memory after get model: 158208
max allocated memory after model inference: 246272
Outputs on CUDA:
max allocated memory before loading: 0
max allocated memory after get inputs: 5632
max allocated memory after get model: 157184
max allocated memory after model inference: 1416704
The peak allocated memory after model inference is 246272 bytes on XPU but 1416704 bytes on CUDA. We'd like to know where this difference comes from.
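Not part of the original report, but here is a minimal sketch of one way to start narrowing this down: it reuses transformer, inputs, and device_module from the repro above and registers forward hooks that record the allocator state after every submodule, so the submodules where the running peak jumps become visible. It assumes torch.xpu exposes the same memory_allocated() / max_memory_allocated() / reset_peak_memory_stats() API as torch.cuda (the repro already relies on this).

records = []

def make_hook(name):
    def hook(module, args, output):
        # Allocator state observed right after this submodule finished its forward.
        records.append((name, device_module.memory_allocated(), device_module.max_memory_allocated()))
    return hook

# Hook every named submodule of the repro's `transformer` (skip the unnamed root module).
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in transformer.named_modules()
    if name
]

device_module.reset_peak_memory_stats()
with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)

for handle in handles:
    handle.remove()

# Report the submodules after which the running peak increased.
prev_peak = 0
for name, allocated, peak in records:
    if peak > prev_peak:
        print(f"{name}: allocated={allocated} peak={peak} (+{peak - prev_peak})")
        prev_peak = peak

Running this on both backends and diffing the printed peaks should show whether the extra CUDA allocation is attributable to a specific submodule or happens outside the hooked modules (for example in workspace allocations made by the backend libraries). Comparing device_module.memory_stats() dumps before and after the forward pass is another option, if that API is available on your XPU build.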
Versions
[pip3] torch==2.10.0.dev20251008+xpu
[pip3] torchao==0.14.0.dev20251009+xpu
transformers==4.57.0
diffusers==0.36.0.dev0