ViT and Swin symbolic tracing with torch.fx (#17182)
* Support tracing for ViT

* Swin support

* Fix copies

* Fix type annotation issue

* Removed unused import
michaelbenayoun committed May 12, 2022
1 parent 1a68870 commit 8c7481f
Showing 11 changed files with 70 additions and 35 deletions.
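
For context, a minimal usage sketch of what this commit enables (the model choice and the keyword-style call below are illustrative assumptions, not part of the diff):

from transformers import ViTConfig, ViTForImageClassification
from transformers.utils.fx import symbolic_trace

# Build a small ViT and trace it with the HF fx tracer; dummy "pixel_values" are
# generated from config.image_size and config.num_channels by _generate_dummy_input
# (see the src/transformers/utils/fx.py hunks below).
model = ViTForImageClassification(ViTConfig()).eval()
traced = symbolic_trace(model, input_names=["pixel_values"])
print(traced.code[:300])  # generated forward() of the resulting torch.fx GraphModule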
4 changes: 2 additions & 2 deletions src/transformers/models/deit/modeling_deit.py
@@ -168,7 +168,7 @@ def __init__(self, config: DeiTConfig) -> None:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -200,7 +200,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

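
Aside: the view(*shape) rewrites above (and their copies in the files below) avoid star-unpacking a traced shape. A hedged illustration with vanilla torch.fx, where unpacking a Proxy raises a TraceError; passing the tuple whole keeps the call traceable:

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        new_shape = x.size()[:-1] + (2, 3)
        return x.view(*new_shape)  # star-unpacking iterates a Proxy during tracing

try:
    fx.symbolic_trace(Toy())
except Exception as e:  # typically torch.fx.proxy.TraceError
    print(type(e).__name__, e)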
4 changes: 2 additions & 2 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -177,7 +177,7 @@ def __init__(self, config: DPTConfig) -> None:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -209,7 +209,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

6 changes: 3 additions & 3 deletions src/transformers/models/maskformer/modeling_maskformer.py
@@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
@@ -697,7 +697,7 @@ def __init__(self, config, dim, num_heads):

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -750,7 +750,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

8 changes: 4 additions & 4 deletions src/transformers/models/swin/modeling_swin.py
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
@@ -435,7 +435,7 @@ def __init__(self, config, dim, num_heads):

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -488,7 +488,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

@@ -1071,7 +1071,7 @@ def forward(
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output.transpose(1, 2)
batch_size, num_channels, sequence_length = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

# Reconstruct pixel values
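
A small numeric check (illustrative only, not from the diff) of the batch-size arithmetic in window_reverse above: windows.shape[0] equals batch_size * (height/window_size) * (width/window_size), so dividing it back out and flooring recovers batch_size exactly.

import math

batch_size, height, width, window_size = 2, 56, 56, 7
num_windows = batch_size * (height // window_size) * (width // window_size)
assert math.floor(num_windows / (height * width / window_size / window_size)) == batch_size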
6 changes: 3 additions & 3 deletions src/transformers/models/vit/modeling_vit.py
@@ -213,7 +213,7 @@ def __init__(self, config: ViTConfig) -> None:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -245,7 +245,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

@@ -687,7 +687,7 @@ def forward(
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

# Reconstruct pixel values
4 changes: 2 additions & 2 deletions src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -342,7 +342,7 @@ def __init__(self, config: ViTMAEConfig) -> None:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -374,7 +374,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

4 changes: 2 additions & 2 deletions src/transformers/models/yolos/modeling_yolos.py
@@ -280,7 +280,7 @@ def __init__(self, config: YolosConfig) -> None:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
@@ -312,7 +312,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

58 changes: 49 additions & 9 deletions src/transformers/utils/fx.py
@@ -14,12 +14,12 @@
# limitations under the License.

import builtins
+ import collections
import functools
import inspect
import math
import random
import warnings
- from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

import torch
@@ -31,6 +31,7 @@
CONFIG_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -71,6 +72,7 @@ def _generate_supported_model_classes(
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
}

@@ -100,6 +102,8 @@ def _generate_supported_model_classes(
"gpt_neo",
"t5",
"roberta",
+ "vit",
+ "swin",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "layoutlm",
# "xlnet",
@@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index):
return torch_tensor_index_select(self, dim, index)


+ def torch_roll(input, shifts, dims=None):
+ return input
+
+
+ def torch_nn_conv2d(self, input):
+ h_in, w_in = input.shape[-2:]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ h_out = math.floor(
+ (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ shape[-2:] = [h_out, w_out]
+ shape[-3] = self.out_channels
+ return torch.empty(shape, device="meta")
+
+
def torch_nn_mseloss(self, input, target):
if self.reduction == "none":
shape = target.shape
@@ -317,9 +346,11 @@ def torch_nn_bcewithlogitsloss(self, input, target):
torch.Tensor.mul: torch_tensor_mul_override,
torch.matmul: torch_matmul_override,
torch.Tensor.repeat: torch_tensor_repeat_override,
+ torch.roll: torch_roll,
# TODO: those might not be needed.
# torch.index_select: torch_index_select,
# torch.Tensor.index_select: torch_tensor_index_select,
+ torch.nn.Conv2d: torch_nn_conv2d,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -368,6 +399,9 @@ def __getattr__(self, k):
# we peephole optimize to the method invocation
return HFAttribute(self, k)

+ def __setitem__(self, indices, values):
+ return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+
def __contains__(self, key):
# To handle cases such as :
# `"some_key" in kwargs`
@@ -521,6 +555,15 @@ def _generate_dummy_input(
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
raise NotImplementedError(f"{model_class} not supported yet.")
+ elif "pixel_values" in input_name:
+ batch_size = shape[0]
+ image_size = model.config.image_size
+ if not isinstance(image_size, collections.abc.Iterable):
+ image_size = (image_size, image_size)
+ height, width = image_size
+ inputs_dict[input_name] = torch.zeros(
+ batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+ )

elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
@@ -663,6 +706,11 @@ def trace(
else:
self.graph.erase_node(node)

+ # TODO: solves GraphModule creation.
+ # Without this, return type annotation "Tuple" is causing code execution failure.
+ if node.op == "output":
+ node.type = None
+
return self.graph

def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
@@ -761,12 +809,4 @@ def symbolic_trace(
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)

- # Copy all the original attributes to the traced GraphModule.
- regular_module_attributes = dir(nn.Module())
- for name in dir(model):
- attr = getattr(model, name)
- if name.startswith("_") or name in regular_module_attributes:
- continue
- setattr(traced, name, deepcopy(attr))
-
return traced
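
As a sanity check (not part of the commit), the output-shape arithmetic used by the new torch_nn_conv2d meta-override matches what a real nn.Conv2d produces; the layer hyperparameters below are arbitrary.

import math
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=7, stride=4, padding=2)
h_in = w_in = 224
h_out = math.floor((h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1)
w_out = math.floor((w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1)
out = conv(torch.randn(1, 3, h_in, w_in))
assert out.shape == (1, conv.out_channels, h_out, w_out)  # (1, 8, 56, 56)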
1 change: 1 addition & 0 deletions tests/models/swin/test_modeling_swin.py
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
+ fx_compatible = True

test_pruning = False
test_resize_embeddings = False
1 change: 1 addition & 0 deletions tests/models/vit/test_modeling_vit.py
@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
+ fx_compatible = True

test_pruning = False
test_resize_embeddings = False
9 changes: 1 addition & 8 deletions tests/test_modeling_common.py
@@ -738,8 +738,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
- input_names = ["input_ids", "attention_mask", "token_type_ids"]
- input_ids = inputs["input_ids"]
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]

labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
@@ -756,12 +755,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa

model_output = model(**filtered_inputs)

- rank = len(input_ids.shape)
- if rank not in [2, 3]:
- raise NotImplementedError(
- f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
- )
-
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)

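
The widened input_names list works because it is intersected with the inputs actually prepared for each model, so text models end up tracing with input_ids/attention_mask while vision models trace with pixel_values. A hedged sketch of that idea (the helper below is illustrative, not the test's actual code):

def select_trace_inputs(candidate_names, prepared_inputs):
    # Keep only the candidate names that exist in this model's prepared inputs.
    filtered = {name: prepared_inputs[name] for name in candidate_names if name in prepared_inputs}
    return list(filtered.keys()), filtered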
