Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder #15938

Merged 3 commits on Mar 7, 2022
Changes from all commits
68 changes: 48 additions & 20 deletions tests/wav2vec2/test_modeling_flax_wav2vec2.py
@@ -37,6 +37,7 @@
import jax
import jax.numpy as jnp
import optax
from flax.traverse_util import flatten_dict
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
FlaxWav2Vec2ForCTC,
@@ -236,39 +237,66 @@ def test_freeze_feature_encoder(self):
attention_mask = inputs_dict["attention_mask"]

model = FlaxWav2Vec2ForPreTraining(config)

- outputs = model(
- input_values,
- attention_mask=attention_mask,
- freeze_feature_encoder=False,
- )
-
- outputs_frozen = model(
- input_values,
- attention_mask=attention_mask,
- freeze_feature_encoder=True,
- )
params = model.params

# dummy loss function
- def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
def compute_loss(
params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
):
outputs = model(
input_values,
attention_mask=attention_mask,
freeze_feature_encoder=freeze_feature_encoder,
params=params,
)
# compute cosine similarity of projected and projected_quantized states
- cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
cosine_sim = optax.cosine_similarity(
outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
)
loss = cosine_sim.sum()
return loss

# transform the loss function to get the gradients
grad_fn = jax.value_and_grad(compute_loss)
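# value_and_grad differentiates compute_loss with respect to its first argument
# (params) and returns both the scalar loss and a gradient pytree with the same
# structure as params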

# compute loss and gradients for unfrozen model
- loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)

# compare to loss and gradients for frozen model
- loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True)

self.assert_almost_equals(loss, loss_frozen, 1e-5)

grads = flatten_dict(grads)
grads_frozen = flatten_dict(grads_frozen)
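# flatten_dict (from flax.traverse_util) maps the nested gradient pytree to a flat
# dict keyed by tuples of path segments, e.g. ("wav2vec2", "feature_extractor", ...),
# so the feature-extractor entries can be selected below by tuple membership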

# ensure that the dicts of gradients contain the same keys
self.assertEqual(grads.keys(), grads_frozen.keys())

# ensure that the gradients of the frozen layers are precisely zero and that they differ from the gradients of the unfrozen model
feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)

for feature_extractor_grad, feature_extractor_grad_frozen in zip(
feature_extractor_grads, feature_extractor_grads_frozen
):
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)

# ensure that the gradients of all unfrozen layers remain equal between the frozen and unfrozen runs, i.e. all layers excluding the frozen 'feature_extractor'
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)

for grad, grad_frozen in zip(grads, grads_frozen):
self.assert_almost_equals(grad, grad_frozen, 1e-7)

def assert_difference(self, a, b, tol: float):
diff = jnp.abs((a - b)).min()
self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")

- self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
- self.assertEqual(grads.shape, grads_frozen.shape)
- max_diff = np.amax(np.abs(grads - grads_frozen))
- self.assertLessEqual(max_diff, 1e-5)
def assert_almost_equals(self, a, b, tol: float):
diff = jnp.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")

@slow
def test_model_from_pretrained(self):
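The exact-zero gradient check above relies on how the freeze works in Flax: presumably the model applies a stop-gradient to the feature encoder output when freeze_feature_encoder=True, which zeroes that sub-module's parameter gradients without changing the forward pass. Below is a minimal, self-contained sketch of that mechanism, not the transformers implementation; the parameter names and shapes are made up. It shows the same three properties the test asserts: equal losses, exactly-zero frozen gradients, and unchanged gradients elsewhere.

import jax
import jax.numpy as jnp

# hypothetical two-stage model: a "feature_extractor" followed by an "encoder"
params = {
    "feature_extractor": {"w": jnp.ones((4, 4))},
    "encoder": {"w": jnp.ones((4, 2))},
}

def compute_loss(params, x, freeze_feature_encoder=False):
    features = x @ params["feature_extractor"]["w"]
    if freeze_feature_encoder:
        # block backpropagation into the feature extractor only
        features = jax.lax.stop_gradient(features)
    return (features @ params["encoder"]["w"]).sum()

x = jnp.ones((3, 4))
grad_fn = jax.value_and_grad(compute_loss)

loss, grads = grad_fn(params, x, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, x, freeze_feature_encoder=True)

# stop_gradient does not change forward values, so the losses match
assert jnp.allclose(loss, loss_frozen)
# the frozen branch gets exactly zero gradients, the rest stays identical
assert (grads_frozen["feature_extractor"]["w"] == 0.0).all()
assert jnp.allclose(grads["encoder"]["w"], grads_frozen["encoder"]["w"])

As in the test, the two gradient calls share the same params and inputs and differ only in the freeze flag.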