Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small change to Wav2Vec2 model to support Tensor-Parallelism with DeepSpeed #14298

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this comment is useful when reading the new code. It creates more confusion than help, only the next one is really important.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I can remove this.

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -257,7 +260,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a 119 char limits so you can use more horizontal space :-)

Also, I suggest the following change, more to the point:

Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be partitioned across GPUs when using tensor-parallelism.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will reformat this :)

attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -1330,7 +1333,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -259,7 +262,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -257,7 +260,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/hubert/modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -474,7 +477,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/m2m_100/modeling_m2m_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -328,7 +331,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -274,7 +277,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -263,7 +266,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/pegasus/modeling_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -274,7 +277,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/sew/modeling_sew.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -473,7 +476,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/speech_to_text/modeling_speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -341,7 +344,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -281,7 +284,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/unispeech/modeling_unispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -542,7 +545,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,10 @@ def forward(
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()

# Use the class's parameter as the hidden_state's last dimension.
# This dimension cannot be used in case of enabling tensor-parallelism.
bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
Expand Down Expand Up @@ -543,7 +546,13 @@ def forward(

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

# Use the embed_dim from class rather than hidden_state, this is due to
# the reason that attn_output can be partitioned across GPUs
# when using tensor-parallelism, in which case the embed_dimension from
# the input is not equal to the attention's last dimension after merging
# heads.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

Expand Down
Loading