From 79d3096908ebe11cacf7dbf2f4045747e11fc078 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 19 Feb 2025 15:40:13 +0100 Subject: [PATCH 1/2] store cls instead of an obj --- src/diffusers/models/activations.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index c61baefa08f4..f6680c02afae 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -24,12 +24,12 @@ if is_torch_npu_available(): import torch_npu -ACTIVATION_FUNCTIONS = { - "swish": nn.SiLU(), - "silu": nn.SiLU(), - "mish": nn.Mish(), - "gelu": nn.GELU(), - "relu": nn.ReLU(), +ACT2CLS = { + "swish": nn.SiLU, + "silu": nn.SiLU, + "mish": nn.Mish, + "gelu": nn.GELU, + "relu": nn.ReLU, } @@ -44,11 +44,10 @@ def get_activation(act_fn: str) -> nn.Module: """ act_fn = act_fn.lower() - if act_fn in ACTIVATION_FUNCTIONS: - return ACTIVATION_FUNCTIONS[act_fn] + if act_fn in ACT2CLS: + return ACT2CLS[act_fn]() else: - raise ValueError(f"Unsupported activation function: {act_fn}") - + raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") class FP32SiLU(nn.Module): r""" From 392704bd787c0f88b03080e98f1709c0272380d3 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 19 Feb 2025 15:47:43 +0100 Subject: [PATCH 2/2] style --- src/diffusers/models/activations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f6680c02afae..42e65d898cec 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -49,6 +49,7 @@ def get_activation(act_fn: str) -> nn.Module: else: raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") + class FP32SiLU(nn.Module): r""" SiLU activation function with input upcasted to torch.float32.