Skip to content

Commit

Permalink
correcting PR issues + Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieujouffroy committed Sep 1, 2022
1 parent e842785 commit 850c8eb
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 76 deletions.
5 changes: 1 addition & 4 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",

# The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
if tf_name[0] == "cvt":
tf_name[-1] = "weight"
else:
tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")
tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")

# Remove prefix if needed
tf_name = ".".join(tf_name)
Expand Down
Loading

0 comments on commit 850c8eb

Please sign in to comment.