diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index c61baefa08f4..42e65d898cec 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -24,12 +24,12 @@
 if is_torch_npu_available():
     import torch_npu
 
-ACTIVATION_FUNCTIONS = {
-    "swish": nn.SiLU(),
-    "silu": nn.SiLU(),
-    "mish": nn.Mish(),
-    "gelu": nn.GELU(),
-    "relu": nn.ReLU(),
+ACT2CLS = {
+    "swish": nn.SiLU,
+    "silu": nn.SiLU,
+    "mish": nn.Mish,
+    "gelu": nn.GELU,
+    "relu": nn.ReLU,
 }
 
 
@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
     """
     act_fn = act_fn.lower()
 
-    if act_fn in ACTIVATION_FUNCTIONS:
-        return ACTIVATION_FUNCTIONS[act_fn]
+    if act_fn in ACT2CLS:
+        return ACT2CLS[act_fn]()
     else:
-        raise ValueError(f"Unsupported activation function: {act_fn}")
+        raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
 
 
 class FP32SiLU(nn.Module):
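
A minimal sketch of the behavioral change, not part of the patch, using only the names that appear in the diff: the old mapping stored shared nn.Module instances, so every call to get_activation("silu") returned the same object, while the new class mapping instantiates a fresh module on each lookup. This matters because per-instance state (e.g., a forward hook registered on a shared activation) would otherwise leak across every model that fetched it.

    import torch.nn as nn

    # New pattern from the diff: store activation *classes* and
    # instantiate on lookup, so each caller owns its own module.
    ACT2CLS = {"silu": nn.SiLU}

    a = ACT2CLS["silu"]()
    b = ACT2CLS["silu"]()
    assert a is not b  # two independent instances

    # Old pattern for contrast: a dict of pre-built *instances* hands
    # the same nn.SiLU object to every caller.
    ACTIVATION_FUNCTIONS = {"silu": nn.SiLU()}
    assert ACTIVATION_FUNCTIONS["silu"] is ACTIVATION_FUNCTIONS["silu"]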