### Describe the bug
`RMSNorm` is defined in `models/normalization.py`. Its `__init__` looks like this, with two `print` calls at the end for debugging:
```python
import numbers

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        self.weight = None
        self.bias = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))
        # debugging output
        print(self.weight)
        print(self)
```

This module is used in Flux.1-dev, and possibly in other models as well.
When I print the model, the layer shows up as `(norm_q): RMSNorm()`, which I take to mean that `self.weight` is not registered with the module.
I've checked that `self.weight` does exist. When `RMSNorm` is constructed from `Attention`, `elementwise_affine` and `bias` are not passed, so they keep their default values (`True` and `False`).
The output looks like this:

```
tensor(..., device='meta', size=(128,), requires_grad=True)
RMSNorm()
```
An LLM suggested that the cause is that `self.weight` is assigned inside a conditional branch and that the module is initialized on the meta device by `from_pretrained()`. That does not appear to be the cause: I removed the branch and the problem persisted.
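For reference, `nn.Module`'s default `repr` prints only child modules plus whatever `extra_repr()` returns, so registered parameters do not appear in the printout on their own. A minimal sketch of a more direct check, assuming only `torch` is installed: `RMSNormInitOnly` is a hypothetical copy of the `__init__` quoted above (forward and debug prints dropped), built on the meta device to mimic loading, then inspected with `named_parameters()`.

```python
import numbers

import torch
import torch.nn as nn


class RMSNormInitOnly(nn.Module):
    """Hypothetical copy of the RMSNorm __init__ quoted above; forward omitted."""

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        self.weight = None
        self.bias = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))


# Construct on the meta device (requires torch >= 2.0 for the context manager).
with torch.device("meta"):
    norm = RMSNormInitOnly(128, eps=1e-6)

print(norm)  # RMSNormInitOnly()  (default repr lists child modules only)

# named_parameters() reports what is actually registered on the module.
registered = dict(norm.named_parameters())
print(list(registered))  # ['weight']
print(registered["weight"].device, tuple(registered["weight"].shape))
```

The same `named_parameters()` call can be run on `pipe.transformer` to see whether each loaded `norm_q` layer carries a `weight`, independently of how the layer is printed.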
This bug might not be fatal in most cases, but it prevents me from using FSDP with Flux.
### Reproduction

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16
).to('cuda')

prompt = "a cat holding a paper with word prompt on it"
image = pipe(
    prompt,
    height=1024,
    width=1024,
).images[0]
image.save("flux.png")
```
### Logs
```shell

```

### System Info

diffusers["torch"]==0.32.2

### Who can help?

No response