
RMSNorm's weight not registered as a submodule when initializing #11938

@Darkbblue


Describe the bug

`RMSNorm`, defined in `models/normalization.py`, looks like this:

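```python
import numbers

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # Turn an integer dim into a 1-tuple so torch.Size accepts it.
        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        self.weight = None
        self.bias = None

        # Learnable scale (and optional shift), created only when elementwise_affine=True.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))

        # Debug prints added for this report (not part of the diffusers source).
        print(self.weight)
        print(self)
```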

This module is used in FLUX.1-dev, and possibly in other models.
When I print the model, the module appears as `(norm_q): RMSNorm()`, which suggests that `self.weight` is not registered as a submodule.
I've checked that `self.weight` does exist. When `RMSNorm` is constructed from `Attention`, `elementwise_affine` and `bias` are never overridden, so they keep their default values.
The two debug prints produce:

```
tensor(..., device='meta', size=(128,), requires_grad=True)
RMSNorm()
```
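Note that `print` alone may not settle whether the weight is registered: `nn.Module.__repr__` lists only child modules plus whatever `extra_repr()` returns, and parameters are not child modules. A minimal sketch, assuming a trimmed local copy of the `__init__` above (defined here, not imported from diffusers, with `forward` omitted), checks registration through `named_parameters()` and `state_dict()` instead:

```python
import numbers

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Local, trimmed copy of the __init__ quoted above (not imported from diffusers)."""

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.weight = None
        self.bias = None
        if elementwise_affine:
            # nn.Module.__setattr__ sees an nn.Parameter and moves it into
            # self._parameters, replacing the plain `None` attribute.
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))


norm = RMSNorm(128, eps=1e-6)

# No child modules and no extra_repr(), so the printout is empty.
print(norm)  # RMSNorm()

# Registration check: the parameter appears in both places.
print([name for name, _ in norm.named_parameters()])  # ['weight']
print(list(norm.state_dict().keys()))  # ['weight']
```

Built-in layers such as `nn.Linear` show their sizes only because they override `extra_repr`.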

An LLM suggested that this happens because `self.weight` is assigned inside a conditional branch and `from_pretrained()` initializes it on the meta device. That doesn't appear to be the cause: I removed the branch and the problem persisted.
This may be harmless in most cases, but it prevents me from using FSDP with Flux.
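The meta-device part of that hypothesis can be tested in isolation. The output above shows the weight on `device='meta'`, and constructing a module under the `torch.device("meta")` context manager (PyTorch 2.0+) produces that state without downloading anything, although it may not be exactly how `from_pretrained()` gets there. A sketch, assuming a hypothetical minimal module `ConditionalNorm` that keeps only the `None`-then-conditional-`nn.Parameter` pattern:

```python
import torch
import torch.nn as nn


class ConditionalNorm(nn.Module):
    """Hypothetical minimal module: only the None-then-conditional-Parameter pattern."""

    def __init__(self, dim: int, elementwise_affine: bool = True):
        super().__init__()
        self.weight = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))


# Tensors created inside this context live on the meta device (shape only, no data).
with torch.device("meta"):
    meta_norm = ConditionalNorm(128)

print(meta_norm.weight.device)  # meta
print([name for name, _ in meta_norm.named_parameters()])  # ['weight']
print(meta_norm)  # ConditionalNorm()

# With the branch not taken, nothing is registered and weight stays a plain None.
plain = ConditionalNorm(128, elementwise_affine=False)
print(list(plain.named_parameters()))  # []
```

If this holds in your environment, the parameter is registered whether it lives on the meta device or not. Only the printout stays empty.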

Reproduction

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Printing the transformer shows entries such as `(norm_q): RMSNorm()`.
print(pipe.transformer)

prompt = "a cat holding a paper with word prompt on it"
image = pipe(
    prompt,
    height=1024,
    width=1024,
).images[0]
image.save("flux.png")
```
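To see what the printout of the full transformer leaves out, a small helper can walk the model and list the parameters each `RMSNorm` instance owns directly. The helper name `rmsnorm_param_report` and the match on the class name `"RMSNorm"` are assumptions for illustration. Here it is exercised on a toy model so it runs without downloading FLUX.1-dev:

```python
import torch
import torch.nn as nn


def rmsnorm_param_report(model: nn.Module) -> dict:
    """Hypothetical helper: map each RMSNorm submodule's qualified name to the
    (parameter name, shape, device type) tuples it owns directly."""
    report = {}
    for name, module in model.named_modules():
        if type(module).__name__ == "RMSNorm":
            report[name] = [
                (p_name, tuple(p.shape), p.device.type)
                for p_name, p in module.named_parameters(recurse=False)
            ]
    return report


# Toy stand-ins; this RMSNorm is a local class, not the diffusers one.
class RMSNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))


class ToyAttention(nn.Module):
    def __init__(self, dim_head: int = 128):
        super().__init__()
        self.norm_q = RMSNorm(dim_head)
        self.norm_k = RMSNorm(dim_head)


print(rmsnorm_param_report(ToyAttention()))
# {'norm_q': [('weight', (128,), 'cpu')], 'norm_k': [('weight', (128,), 'cpu')]}
```

On the real pipeline the same call would be `rmsnorm_param_report(pipe.transformer)`. An entry with an empty list would indicate an `RMSNorm` that genuinely owns no parameters.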


System Info

diffusers["torch"]==0.32.2

Who can help?

No response

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests