Skip to content

Commit

Permalink
fix load weights (#8528)
Browse files Browse the repository at this point in the history
* fix load weights

* delete line
  • Loading branch information
patrickvonplaten committed Nov 13, 2020
1 parent f6f4da8 commit f6cdafd
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/transformers/modeling_t5.py
Expand Up @@ -108,19 +108,38 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
continue
pointer = model
array = tf_weights[txt_name]

for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] in ["kernel", "scale", "embedding"]:
pointer = getattr(pointer, "weight")
elif scope_names[0] == "self_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[0]
elif scope_names[0] == "enc_dec_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[1]
elif scope_names[0] == "dense_relu_dense":
pointer = getattr(pointer, "layer")
pointer = pointer[2]
elif scope_names[0] == "rms_norm":
if hasattr(pointer, "layer_norm"):
pointer = getattr(pointer, "layer_norm")
elif hasattr(pointer, "final_layer_norm"):
pointer = getattr(pointer, "final_layer_norm")
elif scope_names[0] == "scale":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
elif scope_names[0] == "decoder" and name[1] == "logits":
continue
elif scope_names[0] == "logits":
pointer = getattr(pointer, "lm_head")
else:
try:
pointer = getattr(pointer, scope_names[0])
Expand Down

0 comments on commit f6cdafd

Please sign in to comment.