diff --git a/docs/source/model_doc/beit.rst b/docs/source/model_doc/beit.rst
index 238d5f46a28e..af658cf60dbc 100644
--- a/docs/source/model_doc/beit.rst
+++ b/docs/source/model_doc/beit.rst
@@ -98,6 +98,13 @@ BeitForImageClassification
     :members: forward
 
 
+BeitForSemanticSegmentation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BeitForSemanticSegmentation
+    :members: forward
+
+
 FlaxBeitModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9fab897b3e2c..a64f592134c7 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -635,6 +635,7 @@
             "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
             "BeitForImageClassification",
             "BeitForMaskedImageModeling",
+            "BeitForSemanticSegmentation",
             "BeitModel",
             "BeitPreTrainedModel",
         ]
@@ -2477,6 +2478,7 @@
         BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BeitForImageClassification,
         BeitForMaskedImageModeling,
+        BeitForSemanticSegmentation,
         BeitModel,
         BeitPreTrainedModel,
     )
diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py
index b530f59d5b3a..c9e311d7cfe6 100644
--- a/src/transformers/models/beit/__init__.py
+++ b/src/transformers/models/beit/__init__.py
@@ -33,6 +33,7 @@
         "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "BeitForImageClassification",
         "BeitForMaskedImageModeling",
+        "BeitForSemanticSegmentation",
         "BeitModel",
         "BeitPreTrainedModel",
     ]
@@ -57,6 +58,7 @@
         BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BeitForImageClassification,
         BeitForMaskedImageModeling,
+        BeitForSemanticSegmentation,
         BeitModel,
         BeitPreTrainedModel,
     )
diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py
index d31f83dd3a5e..bc1aa63197f7 100644
--- a/src/transformers/models/beit/configuration_beit.py
+++ b/src/transformers/models/beit/configuration_beit.py
@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
         use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
             CLS token, before applying the classification head.
+        out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
+            Indices of the feature maps to use for semantic segmentation.
+        pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
+            Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+        use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
+            Number of channels to use in the auxiliary head.
+        auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
+            Number of convolutional layers to use in the auxiliary head.
+        auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to concatenate the output of the auxiliary head with the input before the classification layer.
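The seven new configuration fields above drive the UperNet decode head and the optional FCN auxiliary head added to modeling_beit.py further down. As a quick orientation, here is a minimal sketch (not part of the PR) of configuring BEiT-large for 640x640 ADE20k segmentation, reusing the values the conversion script below applies::

    from transformers import BeitConfig

    # Sketch only: mirrors the settings convert_beit_unilm_to_pytorch.py uses
    # for the "large ... ade20k" checkpoints later in this diff.
    config = BeitConfig(
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=16,
        image_size=640,
        use_relative_position_bias=True,
        num_labels=150,               # ADE20k has 150 semantic classes
        out_indices=[7, 11, 15, 23],  # encoder layers feeding the decode head
        pool_scales=[1, 2, 3, 6],     # PPM scales on the last feature map
        use_auxiliary_head=True,
        auxiliary_loss_weight=0.4,    # weight of the auxiliary cross-entropy loss
    )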
 
     Example::
 
@@ -117,6 +131,13 @@ def __init__(
         layer_scale_init_value=0.1,
         drop_path_rate=0.1,
         use_mean_pooling=True,
+        out_indices=[3, 5, 7, 11],
+        pool_scales=[1, 2, 3, 6],
+        use_auxiliary_head=True,
+        auxiliary_loss_weight=0.4,
+        auxiliary_channels=256,
+        auxiliary_num_convs=1,
+        auxiliary_concat_input=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -142,3 +163,12 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.drop_path_rate = drop_path_rate
         self.use_mean_pooling = use_mean_pooling
+        # decode head attributes (semantic segmentation)
+        self.out_indices = out_indices
+        self.pool_scales = pool_scales
+        # auxiliary head attributes (semantic segmentation)
+        self.use_auxiliary_head = use_auxiliary_head
+        self.auxiliary_loss_weight = auxiliary_loss_weight
+        self.auxiliary_channels = auxiliary_channels
+        self.auxiliary_num_convs = auxiliary_num_convs
+        self.auxiliary_concat_input = auxiliary_concat_input
diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
index c550a56db36f..89e85fa8fe90 100644
--- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
+++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
@@ -20,11 +20,18 @@
 from pathlib import Path
 
 import torch
+from datasets import load_dataset
 from PIL import Image
 
 import requests
 from huggingface_hub import cached_download, hf_hub_url
-from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
+from transformers import (
+    BeitConfig,
+    BeitFeatureExtractor,
+    BeitForImageClassification,
+    BeitForMaskedImageModeling,
+    BeitForSemanticSegmentation,
+)
 from transformers.utils import logging
 
 
@@ -33,27 +40,33 @@
 # here we list all keys to be renamed (original name on the left, our name on the right)
-def create_rename_keys(config, has_lm_head=False):
+def create_rename_keys(config, has_lm_head=False, is_semantic=False):
+    prefix = "backbone." if is_semantic else ""
+
     rename_keys = []
     for i in range(config.num_hidden_layers):
         # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
-        rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
-        rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
-        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
-        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
-        rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
-        rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
-        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
+        rename_keys.append(
+            (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
+        )
+        rename_keys.append(
+            (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
+        )
+        rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
 
     # projection layer + position embeddings
     rename_keys.extend(
         [
-            ("cls_token", "beit.embeddings.cls_token"),
-            ("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
-            ("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
+            (f"{prefix}cls_token", "beit.embeddings.cls_token"),
+            (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
+            (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
         ]
     )
 
@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
                 ("norm.weight", "layernorm.weight"),
                 ("norm.bias", "layernorm.bias"),
             ]
         )
+    elif is_semantic:
+        # semantic segmentation classification heads
+        rename_keys.extend(
+            [
+                ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
+                ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
+                ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
+                ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
+            ]
+        )
     else:
         # layernorm + classification head
         rename_keys.extend(
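For readers new to these conversion scripts: each (src, dest) pair is applied by popping the tensor out of the original state dict and reinserting it under the HuggingFace-style name. A self-contained sketch of that mechanic, with a hypothetical toy dict and strings standing in for tensors (the helper name matches the rename_key the script defines elsewhere)::

    state_dict = {"backbone.cls_token": "tensor-0", "decode_head.conv_seg.weight": "tensor-1"}

    def rename_key(dct, old, new):
        # pop the value stored under the mmseg-style key and re-store it
        # under the HuggingFace-style key
        dct[new] = dct.pop(old)

    rename_key(state_dict, "backbone.cls_token", "beit.embeddings.cls_token")
    rename_key(state_dict, "decode_head.conv_seg.weight", "decode_head.classifier.weight")
    assert "backbone.cls_token" not in state_dict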
@@ -89,45 +112,45 @@ def create_rename_keys(config, has_lm_head=False):
 
 
 # we split up the matrix of each encoder layer into queries, keys and values
-def read_in_q_k_v(state_dict, config, has_lm_head=False):
+def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
     for i in range(config.num_hidden_layers):
-        prefix = "beit."
+        prefix = "backbone." if is_semantic else ""
         # queries, keys and values
-        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
-        q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
-        v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")
+        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
+        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
+        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
 
-        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
+        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
             : config.hidden_size, :
         ]
-        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
-        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
+        state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
             config.hidden_size : config.hidden_size * 2, :
         ]
-        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
+        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
             -config.hidden_size :, :
         ]
-        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
+        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
 
         # gamma_1 and gamma_2
         # we call them lambda because otherwise they are renamed when using .from_pretrained
-        gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
-        gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")
+        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
+        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
 
-        state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
-        state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2
+        state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
+        state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
 
         # relative_position bias table + index
         if not has_lm_head:
             # each layer has its own relative position bias
-            table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table")
-            index = state_dict.pop(f"blocks.{i}.attn.relative_position_index")
+            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
+            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
 
             state_dict[
-                f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
+                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
            ] = table
             state_dict[
-                f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
+                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
             ] = index
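The slicing above is the standard trick for unfusing a packed qkv projection: the first hidden_size rows of the fused weight are the query projection, the next hidden_size rows the keys, and the last hidden_size rows the values (BEiT checkpoints store no key bias, which is why only q_bias and v_bias are popped). A toy check of the slicing, independent of the script::

    import torch

    hidden_size = 4  # toy value; real checkpoints use 768 (base) or 1024 (large)
    qkv_weight = torch.randn(3 * hidden_size, hidden_size)  # fused projection

    # same row slicing as read_in_q_k_v: queries, then keys, then values
    q_weight = qkv_weight[:hidden_size, :]
    k_weight = qkv_weight[hidden_size : hidden_size * 2, :]
    v_weight = qkv_weight[-hidden_size:, :]

    assert torch.equal(torch.cat([q_weight, k_weight, v_weight]), qkv_weight)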
@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
     # define default BEiT configuration
     config = BeitConfig()
     has_lm_head = False
+    is_semantic = False
     repo_id = "datasets/huggingface/label-files"
     # set config parameters based on URL
     if checkpoint_url[-9:-4] == "pt22k":
@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
             config.image_size = 384
         if "512" in checkpoint_url:
             config.image_size = 512
+    elif "ade20k" in checkpoint_url:
+        # fine-tuning
+        config.use_relative_position_bias = True
+        config.num_labels = 150
+        filename = "ade20k-id2label.json"
+        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+        id2label = {int(k): v for k, v in id2label.items()}
+        config.id2label = id2label
+        config.label2id = {v: k for k, v in id2label.items()}
+        config.image_size = 640
+        is_semantic = True
     else:
-        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'")
+        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
 
     # size of the architecture
     if "base" in checkpoint_url:
@@ -196,27 +231,48 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
         config.intermediate_size = 4096
         config.num_hidden_layers = 24
         config.num_attention_heads = 16
+        if "ade20k" in checkpoint_url:
+            config.image_size = 640
+            config.out_indices = [7, 11, 15, 23]
     else:
         raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
 
     # load state_dict of original model, remove and rename some keys
-    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
-    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
+    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
+    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
+
+    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
     for src, dest in rename_keys:
         rename_key(state_dict, src, dest)
-    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
+    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
+    if is_semantic:
+        # add prefix to decoder keys
+        for key, val in state_dict.copy().items():
+            val = state_dict.pop(key)
+            if key.startswith("backbone.fpn"):
+                key = key.replace("backbone.fpn", "fpn")
+            state_dict[key] = val
 
     # load HuggingFace model
     if checkpoint_url[-9:-4] == "pt22k":
         model = BeitForMaskedImageModeling(config)
+    elif "ade20k" in checkpoint_url:
+        model = BeitForSemanticSegmentation(config)
     else:
         model = BeitForImageClassification(config)
     model.eval()
     model.load_state_dict(state_dict)
 
     # Check outputs on an image
-    feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
-    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
+    if is_semantic:
+        feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
+        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+        image = Image.open(ds[0]["file"])
+    else:
+        feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
+        image = prepare_img()
+
+    encoding = feature_extractor(images=image, return_tensors="pt")
     pixel_values = encoding["pixel_values"]
 
     outputs = model(pixel_values)
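The expected logit shapes in the next hunk follow from the patch geometry: a 640-pixel input with 16-pixel patches gives a 40x40 token grid, and the UperNet head fuses its FPN levels at the finest level (the one fpn1 upsamples by 4), so logits come out at 1/4 of the input resolution, i.e. 160x160 for 150 ADE20k classes. A quick sanity check of that arithmetic::

    image_size, patch_size = 640, 16
    patch_resolution = image_size // patch_size  # 40x40 patch tokens
    # fpn1 upsamples the 40x40 grid by 4 and the decode head fuses everything
    # at that finest level, so logits have image_size / 4 pixels per side
    logits_resolution = patch_resolution * 4
    assert logits_resolution == image_size // 4 == 160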
@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
     elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
         expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
         expected_class_idx = 761
+    elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
+        expected_shape = (1, 150, 160, 160)
+        expected_logits = torch.tensor(
+            [
+                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
+                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
+                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
+            ]
+        )
+    elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
+        expected_shape = (1, 150, 160, 160)
+        expected_logits = torch.tensor(
+            [
+                [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
+                [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
+                [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
+            ]
+        )
     else:
         raise ValueError("Can't verify logits as model is not supported")
 
     assert logits.shape == expected_shape, "Shape of logits not as expected"
-    print("Shape of logits:", logits.shape)
     if not has_lm_head:
-        print("Predicted class idx:", logits.argmax(-1).item())
-        assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected"
-        assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
+        if is_semantic:
+            assert torch.allclose(
+                logits[0, :3, :3, :3], expected_logits, atol=1e-3
+            ), "First elements of logits not as expected"
+        else:
+            print("Predicted class idx:", logits.argmax(-1).item())
+            assert torch.allclose(
+                logits[0, :3], expected_logits, atol=1e-3
+            ), "First elements of logits not as expected"
+            assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
 
     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
     print(f"Saving model to {pytorch_dump_folder_path}")
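For context, a hypothetical way of driving the converter once the script is loaded (the URL below merely follows the unilm naming scheme the script matches on, "...ft22ktoade20k.pth", and is illustrative rather than guaranteed)::

    # Sketch: convert a base 640x640 ADE20k checkpoint and write the
    # HuggingFace model to a local folder.
    convert_beit_checkpoint(
        checkpoint_url="https://unilm.blob.core.windows.net/beit/beit_base_patch16_640_pt22k_ft22ktoade20k.pth",
        pytorch_dump_folder_path="./beit-base-finetuned-ade-640-640",
    )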
diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index 1ad3fcd1e6d1..bf496d92a7fc 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -163,6 +163,7 @@ def forward(self, pixel_values):
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
             )
         x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+
         return x
 
 
@@ -499,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, (nn.Linear, nn.Conv2d)):
+        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
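The bulk of the additions in the next hunk port three decode-head modules from mmsegmentation. The pyramid pooling piece is easiest to see with concrete shapes: each scale is adaptively pooled, projected with a 1x1 conv block, and upsampled back to the input resolution, so all branches can be concatenated channel-wise with the original map. A toy shape check under hypothetical sizes, with a plain Conv2d standing in for the conv/BN/ReLU block to keep it self-contained::

    import torch
    from torch import nn

    in_channels, channels = 8, 4
    feature = torch.randn(2, in_channels, 40, 40)  # hypothetical 40x40 feature map

    outs = []
    for scale in (1, 2, 3, 6):
        # one PPM branch: pool to scale x scale, project, upsample back
        branch = nn.Sequential(nn.AdaptiveAvgPool2d(scale), nn.Conv2d(in_channels, channels, kernel_size=1))
        pooled = branch(feature)  # (2, channels, scale, scale)
        outs.append(nn.functional.interpolate(pooled, size=feature.shape[2:], mode="bilinear", align_corners=False))

    fused = torch.cat([feature] + outs, dim=1)  # what psp_forward feeds the bottleneck
    assert fused.shape == (2, in_channels + 4 * channels, 40, 40)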
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + + return output + + +class BeitPyramidPoolingModule(nn.ModuleList): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales, in_channels, channels, align_corners): + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + BeitConvModule(self.in_channels, self.channels, kernel_size=1), + ) + ) + + def forward(self, x): + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class BeitUperHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet + `_. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config): + super().__init__() + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. 
+
+
+class BeitUperHead(nn.Module):
+    """
+    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
+    <https://arxiv.org/abs/1807.10221>`__.
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
+        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
+        self.channels = config.hidden_size
+        self.align_corners = False
+        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+        # PSP Module
+        self.psp_modules = BeitPyramidPoolingModule(
+            self.pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            align_corners=self.align_corners,
+        )
+        self.bottleneck = BeitConvModule(
+            self.in_channels[-1] + len(self.pool_scales) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
+            fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = BeitConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def psp_forward(self, inputs):
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, encoder_hidden_states):
+        # build laterals
+        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+        laterals.append(self.psp_forward(encoder_hidden_states))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            laterals[i - 1] += nn.functional.interpolate(
+                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+            )
+
+        # build outputs
+        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = nn.functional.interpolate(
+                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+            )
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.classifier(output)
+
+        return output
+
+
+class BeitFCNHead(nn.Module):
+    """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of `FCN
+    <https://arxiv.org/abs/1411.4038>`__.
+
+    Args:
+        config (BeitConfig): Configuration.
+        in_index (int): Index of the feature map to use. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """ + + def __init__(self, config, in_index=2, kernel_size=3, dilation=1): + super().__init__() + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + BeitConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + BeitConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = BeitConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states): + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@add_start_docstrings( + """ + Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + BEIT_START_DOCSTRING, +) +class BeitForSemanticSegmentation(BeitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # FPNs + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # Semantic segmentation head(s) + self.decode_head = BeitUperHead(config) + self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None + + self.init_weights() + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=255) + main_loss = loss_fct(upsampled_logits, labels) + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`): + Ground truth semantic segmentation maps for computing the loss. 
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3ac8fcbd0ed4..b1fbe83fca3a 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -600,6 +600,11 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class BeitForSemanticSegmentation:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BeitModel:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py
index 6557936d59b0..f0a89031416c 100644
--- a/tests/test_modeling_beit.py
+++ b/tests/test_modeling_beit.py
@@ -18,6 +18,8 @@
 import inspect
 import unittest
 
+from datasets import load_dataset
+
 from transformers import BeitConfig
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
 from transformers.models.auto import get_values
@@ -31,7 +33,13 @@
     import torch
     from torch import nn
 
-    from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel
+    from transformers import (
+        MODEL_MAPPING,
+        BeitForImageClassification,
+        BeitForMaskedImageModeling,
+        BeitForSemanticSegmentation,
+        BeitModel,
+    )
     from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
 
 
@@ -53,7 +61,7 @@ def __init__(
         is_training=True,
         use_labels=True,
         hidden_size=32,
-        num_hidden_layers=5,
+        num_hidden_layers=4,
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
@@ -63,6 +71,7 @@ def __init__(
         initializer_range=0.02,
         num_labels=3,
         scope=None,
+        out_indices=[0, 1, 2, 3],
     ):
         self.parent = parent
         self.vocab_size = 100
@@ -82,6 +91,7 @@ def __init__(
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.out_indices = out_indices
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -109,6 +119,7 @@ def get_config(self):
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            out_indices=self.out_indices,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
@@ -160,7 +171,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
     """
 
     all_model_classes = (
-        (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else ()
+        (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
+        if is_torch_available()
+        else ()
     )
 
     test_pruning = False
@@ -212,11 +225,14 @@ def test_training(self):
         config.return_dict = True
 
         for model_class in self.all_model_classes:
-            if model_class in get_values(MODEL_MAPPING):
-                continue
-            # we don't test BeitForMaskedImageModeling
-            if model_class.__name__ == "BeitForMaskedImageModeling":
+            if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
                 continue
+            # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+            # this can then be incorporated into _prepare_for_class in test_modeling_common.py
+            elif model_class.__name__ == "BeitForSemanticSegmentation":
+                batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
+                inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
             model = model_class(config)
             model.to(torch_device)
             model.train()
@@ -233,11 +249,17 @@ def test_training_gradient_checkpointing(self):
         config.return_dict = True
 
         for model_class in self.all_model_classes:
-            if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
-                continue
-            # we don't test BeitForMaskedImageModeling
-            if model_class.__name__ == "BeitForMaskedImageModeling":
+            if (
+                model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
+                or not model_class.supports_gradient_checkpointing
+            ):
                 continue
+            # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+            # this can then be incorporated into _prepare_for_class in test_modeling_common.py
+            elif model_class.__name__ == "BeitForSemanticSegmentation":
+                batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
+                inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
             model = model_class(config)
             model.to(torch_device)
             model.train()
@@ -298,7 +320,8 @@ def test_attention_outputs(self):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
 
             self.assertListEqual(
@@ -316,15 +339,9 @@ def test_attention_outputs(self):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
 
-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
-            else:
-                added_hidden_states = 1
-            self.assertEqual(out_len + added_hidden_states, len(outputs))
+            self.assertEqual(out_len + 1, len(outputs))
 
-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
 
             self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
@@ -472,3 +489,32 @@ def test_inference_image_classification_head_imagenet_22k(self):
 
         expected_class_idx = 2396
         self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
+
+    @slow
+    def test_inference_semantic_segmentation(self):
+        model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
+        model = model.to(torch_device)
+
+        feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
+
+        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+        image = Image.open(ds[0]["file"])
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        outputs = model(**inputs)
+        logits = outputs.logits
+
+        # verify the logits
+        expected_shape = torch.Size((1, 150, 160, 160))
+        self.assertEqual(logits.shape, expected_shape)
+
+        expected_slice = torch.tensor(
+            [
+                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
+                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
+                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
+            ]
+        ).to(torch_device)
+
+        self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
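The training tests above show the expected label format: a (batch_size, height, width) LongTensor of per-pixel class indices, with 255 reserved as the ignore index. A hypothetical single training step, using a tiny config that mirrors the tester settings above (real use would start from a pretrained checkpoint instead)::

    import torch
    from transformers import BeitConfig, BeitForSemanticSegmentation

    # tiny hypothetical config, just to exercise the loss path
    config = BeitConfig(
        image_size=64, patch_size=16, num_hidden_layers=4, hidden_size=32,
        num_attention_heads=4, intermediate_size=37, num_labels=3, out_indices=[0, 1, 2, 3],
    )
    model = BeitForSemanticSegmentation(config)

    pixel_values = torch.randn(2, 3, 64, 64)
    labels = torch.zeros(2, 64, 64, dtype=torch.long)  # per-pixel class ids; 255 = ignore

    outputs = model(pixel_values=pixel_values, labels=labels)
    outputs.loss.backward()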
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 3890198edbab..42330cc37dbf 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -88,7 +88,7 @@ def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
     for key in configs_no_init.__dict__.keys():
-        if "_range" in key or "_std" in key or "initializer_factor" in key:
+        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
             setattr(configs_no_init, key, 1e-10)
     return configs_no_init
diff --git a/utils/check_repo.py b/utils/check_repo.py
index a7839f318f02..40d980d93930 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -100,6 +100,7 @@
     # models to ignore for model xxx mapping
     "SegformerDecodeHead",
     "SegformerForSemanticSegmentation",
+    "BeitForSemanticSegmentation",
     "FlaxBeitForMaskedImageModeling",
     "BeitForMaskedImageModeling",
     "CLIPTextModel",