diff --git a/modeling.py b/modeling.py index fed525971..19cd745c8 100644 --- a/modeling.py +++ b/modeling.py @@ -302,18 +302,18 @@ def get_activation(activation_string): return None act = activation_string.lower() - if act == "linear": - return None - elif act == "relu": - return tf.nn.relu - elif act == "gelu": - return gelu - elif act == "tanh": - return tf.tanh - else: + mapping_dict = { + "linear": None, + "relu": tf.nn.relu, + "gelu": gelu, + "tanh": tf.tanh, + } + + try: + return mapping_dict[act] + except Exception as e: raise ValueError("Unsupported activation: %s" % act) - def get_assignment_map_from_checkpoint(tvars, init_checkpoint): """Compute the union of the current variables and checkpoint variables.""" assignment_map = {}