Raise exceptions instead of using asserts in modeling_openai #12789 #14386

Merged
merged 2 commits on Nov 14, 2021
49 changes: 23 additions & 26 deletions src/transformers/models/openai/modeling_openai.py
@@ -83,13 +83,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
     # del init_params[1]
     init_params = [arr.squeeze() for arr in init_params]
 
-    try:
-        assert model.tokens_embed.weight.shape == init_params[1].shape
-        assert model.positions_embed.weight.shape == init_params[0].shape
-    except AssertionError as e:
-        e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
-        e.args += (model.positions_embed.weight.shape, init_params[0].shape)
-        raise
+    # Check that the token and position embeddings weight dimensions match those of the init parameters.
+    if model.tokens_embed.weight.shape != init_params[1].shape:
+        raise ValueError(
+            f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
+        )
+
+    if model.positions_embed.weight.shape != init_params[0].shape:
+        raise ValueError(
+            f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
+        )
 
     model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
     model.positions_embed.weight.data = torch.from_numpy(init_params[0])
@@ -100,7 +103,8 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
 
     for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]):
         name = name[6:]  # skip "model/"
-        assert name[-2:] == ":0"
+        if name[-2:] != ":0":
+            raise ValueError(f"Layer {name} does not end with :0")
         name = name[:-2]
         name = name.split("/")
         pointer = model
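
Review note (not part of the diff): TensorFlow variable names carry an output-index suffix such as ":0", which is what this check validates before stripping. A hedged sketch of the loop's name handling, using "model/we:0" as an illustrative variable name:

    # "model/we:0" is a hypothetical TF variable name, shown only for illustration
    name = "model/we:0"
    name = name[6:]  # skip "model/" -> "we:0"
    if name[-2:] != ":0":
        raise ValueError(f"Layer {name} does not end with :0")
    name = name[:-2]  # strip the TF output-index suffix -> "we"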
@@ -120,20 +124,11 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
             if len(scope_names) >= 2:
                 num = int(scope_names[1])
                 pointer = pointer[num]
-        try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
+
+        # Ensure that the pointer and array have compatible shapes.
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+
         logger.info(f"Initialize PyTorch weight {name}")
         pointer.data = torch.from_numpy(array)
     return model
@@ -147,7 +142,8 @@ def __init__(self, nx, n_positions, config, scale=False):
         super().__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
-        assert n_state % config.n_head == 0
+        if n_state % config.n_head != 0:
+            raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
         self.register_buffer(
             "bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
         )
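
Review note (not part of the diff): this check enforces the usual multi-head attention constraint that the hidden size splits evenly across heads. A quick sketch, assuming the default GPT values n_embd=768 and n_head=12:

    # each head receives n_state // n_head dimensions, so the division must be exact
    n_state, n_head = 768, 12
    if n_state % n_head != 0:
        raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {n_head}")
    head_dim = n_state // n_head  # 64 dimensions per head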
@@ -804,9 +800,10 @@ def forward(
         else:
             batch_size, sequence_length = inputs_embeds.shape[:2]
 
-        assert (
-            self.config.pad_token_id is not None or batch_size == 1
-        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        # Ensure that a padding token is defined when the batch size is larger than 1.
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
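
Review note (not part of the diff): the motivation for replacing assert with ValueError is that CPython strips assert statements when run with the -O flag, so assertion-based validation silently disappears in optimized mode, while a raised ValueError always fires. A minimal sketch, not part of this PR:

    # shapes.py -- hypothetical demo, not part of this PR
    def check(pointer_shape, array_shape):
        # old style: compiled away entirely under `python -O shapes.py`
        assert pointer_shape == array_shape, "mismatched"
        # new style: enforced regardless of interpreter flags
        if pointer_shape != array_shape:
            raise ValueError(f"Pointer shape {pointer_shape} and array shape {array_shape} mismatched")

    check((768, 512), (512, 768))  # with -O the assert is skipped, but the ValueError is still raised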