From b17363284553c5ff9038e4014fd43a9fe7207f3e Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Fri, 15 Oct 2021 10:01:31 +0200
Subject: [PATCH 01/17] Add first draft

---
 src/transformers/models/beit/modeling_beit.py | 264 ++++++++++++++++++
 1 file changed, 264 insertions(+)

diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index 1ad3fcd1e6d1..3d6e21c94ff1 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -851,3 +851,267 @@ def forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+class ConvModule(nn.Module):
+    def __init__(self):
+        super(ConvModule).__init__()
+
+    def forward(self):
+        return -1
+
+
+class PPM(nn.ModuleList):
+    """
+    Pooling Pyramid Module used in PSPNet.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+    """
+
+    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, act_cfg, align_corners, **kwargs):
+        super(PPM, self).__init__()
+        self.pool_scales = pool_scales
+        self.align_corners = align_corners
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        for pool_scale in pool_scales:
+            self.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(pool_scale),
+                    ConvModule(
+                        self.in_channels,
+                        self.channels,
+                        1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg,
+                        **kwargs,
+                    ),
+                )
+            )
+
+    def forward(self, x):
+        """Forward function."""
+        ppm_outs = []
+        for ppm in self:
+            ppm_out = ppm(x)
+            upsampled_ppm_out = nn.functional.interpolate(
+                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
+            )
+            ppm_outs.append(upsampled_ppm_out)
+        return ppm_outs
+
+
+class UPerHead(nn.Module):
+    """
+    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
+    <https://arxiv.org/abs/1807.10221>`_.
+    """
+
+    def __init__(self, config):
+        super(UPerHead).__init__()
+
+        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
+        self.in_channels = config.hidden_sizes  # e.g. 
[768, 768, 768, 768] + self.in_index = [0, 1, 2, 3] + self.channels = config.hidden_size + self.conv_cfg = None + self.norm_cfg = dict(type="SyncBN", requires_grad=True) + self.act_cfg = dict(type="ReLU") + self.align_corners = False + + # PSP Module + self.psp_modules = PPM( + self.pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners, + ) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False, + ) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False, + ) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, hidden_states): + # build laterals + laterals = [lateral_conv(hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += nn.functional.interpolate( + laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners + ) + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = nn.functional.interpolate( + fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.cls_seg(output) + return output + + +@add_start_docstrings( + """ + Beit Model transformer with an semantic segmentation head on top e.g. for ADE20k. 
+ """, + BEIT_START_DOCSTRING, +) +class BeitForSemanticSegmentation(BeitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=True) + + # Semantic segmentation head + self.head = UPerHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import BeitFeatureExtractor, BeitForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224') + >>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.head(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) From bb52b18c6f5602b8b811642daab3eebc2ea4fc20 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 18 Oct 2021 16:50:30 +0200 Subject: [PATCH 02/17] Make forward pass work --- src/transformers/__init__.py | 2 + src/transformers/models/beit/__init__.py | 2 + .../models/beit/configuration_beit.py | 9 ++ src/transformers/models/beit/modeling_beit.py | 97 +++++++++++-------- src/transformers/models/beit/test_semantic.py | 14 +++ src/transformers/utils/dummy_pt_objects.py | 5 + 6 files changed, 87 insertions(+), 42 deletions(-) create mode 100644 src/transformers/models/beit/test_semantic.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9fab897b3e2c..a64f592134c7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -635,6 +635,7 
@@ "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST", "BeitForImageClassification", "BeitForMaskedImageModeling", + "BeitForSemanticSegmentation", "BeitModel", "BeitPreTrainedModel", ] @@ -2477,6 +2478,7 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, BeitForImageClassification, BeitForMaskedImageModeling, + BeitForSemanticSegmentation, BeitModel, BeitPreTrainedModel, ) diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py index b530f59d5b3a..c9e311d7cfe6 100644 --- a/src/transformers/models/beit/__init__.py +++ b/src/transformers/models/beit/__init__.py @@ -33,6 +33,7 @@ "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST", "BeitForImageClassification", "BeitForMaskedImageModeling", + "BeitForSemanticSegmentation", "BeitModel", "BeitPreTrainedModel", ] @@ -57,6 +58,7 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, BeitForImageClassification, BeitForMaskedImageModeling, + BeitForSemanticSegmentation, BeitModel, BeitPreTrainedModel, ) diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index d31f83dd3a5e..853a3d6dbfb8 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -78,6 +78,10 @@ class BeitConfig(PretrainedConfig): use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the CLS token, before applying the classification head. + out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`): + Indices of the feature maps to use for semantic segmentation. + pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 6)`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. Example:: @@ -117,6 +121,8 @@ def __init__( layer_scale_init_value=0.1, drop_path_rate=0.1, use_mean_pooling=True, + out_indices=[3, 5, 7, 11], + pool_scales=(1, 2, 3, 6), **kwargs ): super().__init__(**kwargs) @@ -142,3 +148,6 @@ def __init__( self.layer_scale_init_value = layer_scale_init_value self.drop_path_rate = drop_path_rate self.use_mean_pooling = use_mean_pooling + # semantic segmentation attributes + self.out_indices = out_indices + self.pool_scales = pool_scales diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3d6e21c94ff1..a934c3493e35 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -163,6 +163,7 @@ def forward(self, pixel_values): f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x @@ -854,11 +855,26 @@ def forward( class ConvModule(nn.Module): - def __init__(self): - super(ConvModule).__init__() + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). 
+ """ - def forward(self): - return -1 + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvModule, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding + ) + self.norm = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input): + output = self.conv(input) + print("Shape of output after conv:", output.shape) + output = self.norm(output) + output = self.activation(output) + + return output class PPM(nn.ModuleList): @@ -870,21 +886,15 @@ class PPM(nn.ModuleList): Module. in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. - conv_cfg (dict|None): Config of conv layers. - norm_cfg (dict|None): Config of norm layers. - act_cfg (dict): Config of activation layers. align_corners (bool): align_corners argument of F.interpolate. """ - def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, act_cfg, align_corners, **kwargs): + def __init__(self, pool_scales, in_channels, channels, align_corners): super(PPM, self).__init__() self.pool_scales = pool_scales self.align_corners = align_corners self.in_channels = in_channels self.channels = channels - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg for pool_scale in pool_scales: self.append( nn.Sequential( @@ -893,10 +903,6 @@ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, act_c self.in_channels, self.channels, 1, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, - **kwargs, ), ) ) @@ -920,25 +926,19 @@ class UPerHead(nn.Module): """ def __init__(self, config): - super(UPerHead).__init__() + super(UPerHead, self).__init__() self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) - self.in_channels = config.hidden_sizes # e.g. [768, 768, 768, 768] - self.in_index = [0, 1, 2, 3] + self.in_channels = [config.hidden_size] * 4 # e.g. 
[768, 768, 768, 768] self.channels = config.hidden_size - self.conv_cfg = None - self.norm_cfg = dict(type="SyncBN", requires_grad=True) - self.act_cfg = dict(type="ReLU") self.align_corners = False + self.classifier = nn.Conv2d(config.hidden_size, config.num_labels, kernel_size=1) # PSP Module self.psp_modules = PPM( self.pool_scales, self.in_channels[-1], self.channels, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, align_corners=self.align_corners, ) self.bottleneck = ConvModule( @@ -946,9 +946,6 @@ def __init__(self, config): self.channels, 3, padding=1, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, ) # FPN Module self.lateral_convs = nn.ModuleList() @@ -958,20 +955,14 @@ def __init__(self, config): in_channels, self.channels, 1, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, - inplace=False, + # inplace=False, ) fpn_conv = ConvModule( self.channels, self.channels, 3, padding=1, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, - inplace=False, + # inplace=False, ) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) @@ -981,9 +972,6 @@ def __init__(self, config): self.channels, 3, padding=1, - conv_cfg=self.conv_cfg, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, ) def psp_forward(self, inputs): @@ -1021,7 +1009,8 @@ def forward(self, hidden_states): ) fpn_outs = torch.cat(fpn_outs, dim=1) output = self.fpn_bottleneck(fpn_outs) - output = self.cls_seg(output) + output = self.classifier(output) + return output @@ -1038,6 +1027,19 @@ def __init__(self, config): self.num_labels = config.num_labels self.beit = BeitModel(config, add_pooling_layer=True) + # FPNs + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + # Semantic segmentation head self.head = UPerHead(config) @@ -1087,13 +1089,24 @@ def forward( pixel_values, head_mask=head_mask, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + output_hidden_states=True, # we need the intermediate hidden states return_dict=return_dict, ) - pooled_output = outputs.pooler_output if return_dict else outputs[1] + # only keep certain features, and reshape + features = [feature for idx, feature in enumerate(outputs.hidden_states) if idx + 1 in self.config.out_indices] + batch_size = pixel_values.shape[0] + patch_resolution = self.config.image_size // self.config.patch_size + features = [ + x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features + ] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) - logits = self.head(pooled_output) + logits = self.head(features) loss = None if labels is not None: diff --git a/src/transformers/models/beit/test_semantic.py b/src/transformers/models/beit/test_semantic.py new file mode 100644 index 000000000000..3f85ff9b9bd1 --- /dev/null +++ b/src/transformers/models/beit/test_semantic.py @@ -0,0 +1,14 @@ +import torch + +from transformers import BeitConfig, BeitForSemanticSegmentation + + +config = 
BeitConfig(image_size=512)
+model = BeitForSemanticSegmentation(config)
+model.eval()
+
+pixel_values = torch.randn((1, 3, 512, 512))
+
+outputs = model(pixel_values)
+
+print("Shape of logits:", outputs.logits.shape)
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3ac8fcbd0ed4..b1fbe83fca3a 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -600,6 +600,11 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class BeitForSemanticSegmentation:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BeitModel:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
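The key move in patch 02 is the reshaping step in the new forward pass: the BEiT encoder returns token sequences, so the hidden states selected via out_indices have to be turned back into 2D feature maps before the FPN ops and the UPerHead can consume them. As a standalone sketch (not part of the patch series, tensor names illustrative only), the snippet below replays that transformation for a single hidden state, using the 512x512 sizes from test_semantic.py above:

    import torch

    batch_size, hidden_size = 1, 768
    image_size, patch_size = 512, 16
    patch_resolution = image_size // patch_size  # 32 patches per side

    # (batch_size, 1 + num_patches, hidden_size), as returned by the BEiT encoder
    hidden_state = torch.randn(batch_size, 1 + patch_resolution**2, hidden_size)

    # drop the [CLS] token, then (batch, seq, hidden) -> (batch, hidden, height, width)
    feature_map = hidden_state[:, 1:, :].permute(0, 2, 1).reshape(
        batch_size, -1, patch_resolution, patch_resolution
    )
    print(feature_map.shape)  # torch.Size([1, 768, 32, 32])
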
From 81beac4cce17a90f85c96f2e1b19b4b466c740d7 Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Tue, 19 Oct 2021 10:34:17 +0200
Subject: [PATCH 03/17] Improve conversion script

---
 .../beit/convert_beit_unilm_to_pytorch.py     | 113 ++++++++++++------
 src/transformers/models/beit/modeling_beit.py |  12 +-
 2 files changed, 84 insertions(+), 41 deletions(-)

diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
index c550a56db36f..22304f125e71 100644
--- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
+++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
@@ -24,7 +24,13 @@
 import requests
 from huggingface_hub import cached_download, hf_hub_url
-from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
+from transformers import (
+    BeitConfig,
+    BeitFeatureExtractor,
+    BeitForImageClassification,
+    BeitForMaskedImageModeling,
+    BeitForSemanticSegmentation,
+)
 from transformers.utils import logging
 
 
@@ -33,27 +39,33 @@
 # here we list all keys to be renamed (original name on the left, our name on the right)
-def create_rename_keys(config, has_lm_head=False):
+def create_rename_keys(config, has_lm_head=False, is_semantic=False):
+    prefix = "backbone." if is_semantic else ""
+
     rename_keys = []
     for i in range(config.num_hidden_layers):
         # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
-        rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
-        rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
-        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
-        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
-        rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
-        rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
-        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
-        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
+        rename_keys.append(
+            (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
+        )
+        rename_keys.append(
+            (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
+        )
+        rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
+        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
 
     # projection layer + position embeddings
     rename_keys.extend(
         [
-            ("cls_token", "beit.embeddings.cls_token"),
-            ("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
-            ("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
+            (f"{prefix}cls_token", "beit.embeddings.cls_token"),
+            (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
+            (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
         ]
     )
 
@@ -74,6 +86,14 @@ def create_rename_keys(config, has_lm_head=False):
             ("norm.bias", "layernorm.bias"),
         ]
     )
+    elif is_semantic:
+        # semantic segmentation head + classification head
+        rename_keys.extend(
+            [
+                ("decode_head.conv_seg.weight", "head.classifier.weight"),
+                ("decode_head.conv_seg.bias", "head.classifier.bias"),
+            ]
+        )
     else:
         # layernorm + classification head
         rename_keys.extend(
             [
@@ -89,45 +109,45 @@
 # we split up the matrix of each encoder layer into queries, keys and values
-def read_in_q_k_v(state_dict, config, has_lm_head=False):
+def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
     for i in range(config.num_hidden_layers):
-        prefix = 
"beit." + prefix = "backbone." if is_semantic else "" # queries, keys and values - in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") - q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias") - v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias") + in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight") + q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias") + v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias") - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ : config.hidden_size, : ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias + state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ config.hidden_size : config.hidden_size * 2, : ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ -config.hidden_size :, : ] - state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias + state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias # gamma_1 and gamma_2 # we call them lambda because otherwise they are renamed when using .from_pretrained - gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1") - gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2") + gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1") + gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2") - state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1 - state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2 + state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1 + state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2 # relative_position bias table + index if not has_lm_head: # each layer has its own relative position bias - table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table") - index = state_dict.pop(f"blocks.{i}.attn.relative_position_index") + table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table") + index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index") state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table" + f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table" ] = table state_dict[ - f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index" + f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index" ] = index @@ -152,6 +172,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): # define default BEiT configuration config = BeitConfig() has_lm_head = False + is_semantic = False repo_id = "datasets/huggingface/label-files" # set config parameters based on URL if checkpoint_url[-9:-4] == "pt22k": @@ -185,8 +206,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): config.image_size = 384 if "512" in checkpoint_url: config.image_size = 512 + elif "ade20k" in checkpoint_url: + # fine-tuning + config.use_relative_position_bias = True + config.num_labels = 150 + filename = "ade20k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), 
"r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + config.image_size = 640 + is_semantic = True else: - raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'") + raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'") # size of the architecture if "base" in checkpoint_url: @@ -196,19 +228,26 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): config.intermediate_size = 4096 config.num_hidden_layers = 24 config.num_attention_heads = 16 + if "ade20k" in checkpoint_url: + config.image_size = 640 + config.out_indices = [7, 11, 15, 23] else: raise ValueError("Should either find 'base' or 'large' in checkpoint URL") # load state_dict of original model, remove and rename some keys - state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"] - rename_keys = create_rename_keys(config, has_lm_head=has_lm_head) + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True) + state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"] + + rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic) for src, dest in rename_keys: rename_key(state_dict, src, dest) - read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head) + read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic) # load HuggingFace model if checkpoint_url[-9:-4] == "pt22k": model = BeitForMaskedImageModeling(config) + elif "ade20k" in checkpoint_url: + model = BeitForSemanticSegmentation(config) else: model = BeitForImageClassification(config) model.eval() diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index a934c3493e35..6526b58d83ff 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -863,14 +863,17 @@ class ConvModule(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, padding=0): super(ConvModule, self).__init__() self.conv = nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, ) self.norm = nn.BatchNorm2d(out_channels) self.activation = nn.ReLU() def forward(self, input): output = self.conv(input) - print("Shape of output after conv:", output.shape) output = self.norm(output) output = self.activation(output) @@ -879,7 +882,7 @@ def forward(self, input): class PPM(nn.ModuleList): """ - Pooling Pyramid Module used in PSPNet. + Pyramid Pooling Module (PPM) used in PSPNet. 
Args: pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid @@ -1025,7 +1028,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.beit = BeitModel(config, add_pooling_layer=True) + self.beit = BeitModel(config, add_pooling_layer=False) # FPNs self.fpn1 = nn.Sequential( @@ -1094,6 +1097,7 @@ def forward( ) # only keep certain features, and reshape + # note that we do +1 as outputs.hidden_states also includes the initial embeddings features = [feature for idx, feature in enumerate(outputs.hidden_states) if idx + 1 in self.config.out_indices] batch_size = pixel_values.shape[0] patch_resolution = self.config.image_size // self.config.patch_size From f0bd61e15d08bdc9758745b03dce137871998ec7 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 19 Oct 2021 12:02:06 +0200 Subject: [PATCH 04/17] Add notebook that checks if it works --- .../beit/convert_beit_unilm_to_pytorch.py | 23 ++++++++++++++++--- src/transformers/models/beit/modeling_beit.py | 4 ++-- src/transformers/models/beit/test.ipynb | 0 src/transformers/models/beit/test_semantic.py | 22 +++++++++++------- 4 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 src/transformers/models/beit/test.ipynb diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 22304f125e71..88e7735e1cad 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -242,6 +242,22 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): for src, dest in rename_keys: rename_key(state_dict, src, dest) read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic) + if is_semantic: + # add prefix to decoder keys + for key, val in state_dict.copy().items(): + print("Old key:", key) + val = state_dict.pop(key) + if key.startswith("decode_head"): + key = key.replace("decode_head", "head") + elif key.startswith("backbone.fpn"): + key = key.replace("backbone.fpn", "fpn") + + if "auxiliary_head" in key: + # we skip the auxiliary head for now + pass + else: + print("Setting new key:", key) + state_dict[key] = val # load HuggingFace model if checkpoint_url[-9:-4] == "pt22k": @@ -296,12 +312,13 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"): expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852]) expected_class_idx = 761 + elif is_semantic: + expected_shape = (1, 150, 160, 160) else: raise ValueError("Can't verify logits as model is not supported") - + assert logits.shape == expected_shape, "Shape of logits not as expected" - print("Shape of logits:", logits.shape) - if not has_lm_head: + if not has_lm_head and not is_semantic: print("Predicted class idx:", logits.argmax(-1).item()) assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected" diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 6526b58d83ff..a8fa00e70180 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -869,12 +869,12 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0): padding=padding, bias=False, ) - self.norm = 
nn.BatchNorm2d(out_channels)
+        self.bn = nn.BatchNorm2d(out_channels)
         self.activation = nn.ReLU()
 
     def forward(self, input):
         output = self.conv(input)
-        output = self.norm(output)
+        output = self.bn(output)
         output = self.activation(output)
 
         return output
diff --git a/src/transformers/models/beit/test.ipynb b/src/transformers/models/beit/test.ipynb
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/transformers/models/beit/test_semantic.py b/src/transformers/models/beit/test_semantic.py
index 3f85ff9b9bd1..fbba9cc0b2d9 100644
--- a/src/transformers/models/beit/test_semantic.py
+++ b/src/transformers/models/beit/test_semantic.py
@@ -1,14 +1,20 @@
-import torch
+from datasets import load_dataset
+from PIL import Image
+from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
 
-from transformers import BeitConfig, BeitForSemanticSegmentation
+# load image + ground truth map
+ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+image = Image.open(ds[0]["file"])
+segmentation_map = Image.open(ds[1]["file"])
 
+# load model
+model_name = "nielsr/beit-test"
+feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
+model = BeitForSemanticSegmentation.from_pretrained(model_name)
 
-config = BeitConfig(image_size=512)
-model = BeitForSemanticSegmentation(config)
-model.eval()
-
-pixel_values = torch.randn((1, 3, 512, 512))
-
+pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
 outputs = model(pixel_values)
+logits = outputs.logits
 
 print("Shape of logits:", outputs.logits.shape)
+print("First elements of logits:", logits[0,:3,:3,:3])
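The updated test script stops at printing raw logits, which come out at a quarter of the input resolution (hence the (1, 150, 160, 160) expected shape for a 640x640 input). As a standalone sketch (not part of the patch series), turning those logits into an actual segmentation map is a matter of bilinear upsampling followed by an argmax over the class dimension:

    import torch

    logits = torch.randn(1, 150, 160, 160)  # stand-in for outputs.logits

    # upsample to the input resolution, then pick the highest-scoring class per pixel
    upsampled = torch.nn.functional.interpolate(
        logits, size=(640, 640), mode="bilinear", align_corners=False
    )
    segmentation_map = upsampled.argmax(dim=1)  # (1, 640, 640), one class id per pixel
    print(segmentation_map.shape)
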
From c0e093d9bb205896e832ed221d9c5659e22aede3 Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Tue, 19 Oct 2021 14:35:20 +0200
Subject: [PATCH 05/17] Add BeitForSemanticSegmentation to the tests

---
 .../models/beit/configuration_beit.py         |  4 +-
 .../beit/convert_beit_unilm_to_pytorch.py     |  5 +-
 src/transformers/models/beit/modeling_beit.py | 14 ++--
 src/transformers/models/beit/test_semantic.py |  6 +-
 tests/test_modeling_beit.py                   | 68 ++++++++++++++++---
 5 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py
index 853a3d6dbfb8..32347b6e6ac3 100644
--- a/src/transformers/models/beit/configuration_beit.py
+++ b/src/transformers/models/beit/configuration_beit.py
@@ -80,7 +80,7 @@ class BeitConfig(PretrainedConfig):
             CLS token, before applying the classification head.
         out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
             Indices of the feature maps to use for semantic segmentation.
-        pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 6)`):
+        pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
             Pooling scales used in Pooling Pyramid Module applied on the last feature map.
 
     Example::
@@ -122,7 +122,7 @@ def __init__(
         drop_path_rate=0.1,
         use_mean_pooling=True,
         out_indices=[3, 5, 7, 11],
-        pool_scales=(1, 2, 3, 6),
+        pool_scales=[1, 2, 3, 6],
         **kwargs
     ):
         super().__init__(**kwargs)
diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
index 88e7735e1cad..51e2e4f95093 100644
--- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
+++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
@@ -245,18 +245,15 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
     if is_semantic:
         # add prefix to decoder keys
         for key, val in state_dict.copy().items():
-            print("Old key:", key)
             val = state_dict.pop(key)
             if key.startswith("decode_head"):
                 key = key.replace("decode_head", "head")
             elif key.startswith("backbone.fpn"):
                 key = key.replace("backbone.fpn", "fpn")
-
             if "auxiliary_head" in key:
                 # we skip the auxiliary head for now
                 pass
             else:
-                print("Setting new key:", key)
                 state_dict[key] = val
 
     # load HuggingFace model
@@ -316,7 +313,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
         expected_shape = (1, 150, 160, 160)
     else:
         raise ValueError("Can't verify logits as model is not supported")
-    
+
     assert logits.shape == expected_shape, "Shape of logits not as expected"
     if not has_lm_head and not is_semantic:
         print("Predicted class idx:", logits.argmax(-1).item())
diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index a8fa00e70180..df5d8ec41657 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -1114,13 +1114,15 @@ def forward(
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
+            if self.config.num_labels == 1:
+                raise ValueError("The number of labels should be greater than one")
             else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                # upsample logits to the images' original size
+                upsampled_logits = nn.functional.interpolate(
+                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+                )
+                loss_fct = CrossEntropyLoss(ignore_index=255)
+                loss = loss_fct(upsampled_logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
diff --git a/src/transformers/models/beit/test_semantic.py b/src/transformers/models/beit/test_semantic.py
index fbba9cc0b2d9..ffc34346bc54 100644
--- a/src/transformers/models/beit/test_semantic.py
+++ b/src/transformers/models/beit/test_semantic.py
@@ -1,14 +1,16 @@
 from datasets import load_dataset
 from PIL import Image
+
 from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
 
+
 # load image + ground truth map
 ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
 image = Image.open(ds[0]["file"])
 segmentation_map = Image.open(ds[1]["file"])
 
 # load model
-model_name = "nielsr/beit-test"
+model_name = "nielsr/beit-base-finetuned-ade20k"
 feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
 model = BeitForSemanticSegmentation.from_pretrained(model_name)
 
@@ -17,4 +19,4 @@
 logits = outputs.logits
 
 print("Shape of logits:", outputs.logits.shape)
-print("First elements of logits:", logits[0,:3,:3,:3])
+print("First elements of logits:", logits[0, :3, :3, :3])
diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py
index 
6557936d59b0..1381afc6af1b 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -31,7 +31,13 @@ import torch from torch import nn - from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel + from transformers import ( + MODEL_MAPPING, + BeitForImageClassification, + BeitForMaskedImageModeling, + BeitForSemanticSegmentation, + BeitModel, + ) from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple @@ -53,7 +59,7 @@ def __init__( is_training=True, use_labels=True, hidden_size=32, - num_hidden_layers=5, + num_hidden_layers=4, num_attention_heads=4, intermediate_size=37, hidden_act="gelu", @@ -63,6 +69,7 @@ def __init__( initializer_range=0.02, num_labels=3, scope=None, + out_indices=[0,1,2,3], ): self.parent = parent self.vocab_size = 100 @@ -82,6 +89,7 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range self.scope = scope + self.out_indices = out_indices def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -109,6 +117,7 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + out_indices=self.out_indices, ) def create_and_check_model(self, config, pixel_values, labels): @@ -159,8 +168,10 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): attention_mask and seq_length. """ + maxDiff = None + all_model_classes = ( - (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else () + (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation) if is_torch_available() else () ) test_pruning = False @@ -212,11 +223,14 @@ def test_training(self): config.return_dict = True for model_class in self.all_model_classes: - if model_class in get_values(MODEL_MAPPING): - continue # we don't test BeitForMaskedImageModeling - if model_class.__name__ == "BeitForMaskedImageModeling": + if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]: continue + # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + # this can then be incorporated into _prepare_for_class in test_modeling_common.py + elif model_class.__name__ == "BeitForSemanticSegmentation": + batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape + inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() model = model_class(config) model.to(torch_device) model.train() @@ -233,11 +247,14 @@ def test_training_gradient_checkpointing(self): config.return_dict = True for model_class in self.all_model_classes: - if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: - continue # we don't test BeitForMaskedImageModeling - if model_class.__name__ == "BeitForMaskedImageModeling": + if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling] or not model_class.supports_gradient_checkpointing: continue + # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + # this can then be incorporated into _prepare_for_class in test_modeling_common.py + elif model_class.__name__ == "BeitForSemanticSegmentation": + batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape + inputs_dict["labels"] = 
torch.zeros([self.model_tester.batch_size, height, width]).long() model = model_class(config) model.to(torch_device) model.train() @@ -378,7 +395,7 @@ def test_for_masked_lm(self): def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) - + @slow def test_model_from_pretrained(self): for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -472,3 +489,34 @@ def test_inference_image_classification_head_imagenet_22k(self): expected_class_idx = 2396 self.assertEqual(logits.argmax(-1).item(), expected_class_idx) + + @slow + def test_inference_semantic_segmentation(self): + # TODO rename nielsr to microsoft + model = BeitForSemanticSegmentation.from_pretrained("nielsr/beit-base-finetuned-ade20k").to(torch_device) + + feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + image = Image.open(ds[0]["file"]) + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + outputs = model(**inputs) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 150, 160, 160)) + self.assertEqual(logits.shape, expected_shape) + + expected_slice = torch.tensor( + [ + [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]], + [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]], + [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]], + ] + ).to(torch_device) + + self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4)) From fd4c18c1db239ce80f390db19b3b399c7ea918a7 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 19 Oct 2021 15:14:38 +0200 Subject: [PATCH 06/17] More improvements --- .../beit/convert_beit_unilm_to_pytorch.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 51e2e4f95093..e5dc1fad9f10 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -24,6 +24,7 @@ import requests from huggingface_hub import cached_download, hf_hub_url +from datasets import load_dataset from transformers import ( BeitConfig, BeitFeatureExtractor, @@ -267,8 +268,15 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): model.load_state_dict(state_dict) # Check outputs on an image - feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False) - encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + if is_semantic: + feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False) + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + image = Image.open(ds[0]["file"]) + else: + feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False) + image = prepare_img() + + encoding = feature_extractor(images=image, return_tensors="pt") pixel_values = encoding["pixel_values"] outputs = model(pixel_values) @@ -309,16 +317,28 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): elif 
checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"): expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852]) expected_class_idx = 761 - elif is_semantic: + elif checkpoint_url[-4:].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): expected_shape = (1, 150, 160, 160) + expected_logits = torch.tensor( + [ + [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]], + [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]], + [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]], + ] + ) + elif checkpoint_url[-4:].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): + raise NotImplementedError("To do") else: raise ValueError("Can't verify logits as model is not supported") assert logits.shape == expected_shape, "Shape of logits not as expected" - if not has_lm_head and not is_semantic: - print("Predicted class idx:", logits.argmax(-1).item()) - assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" - assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected" + if not has_lm_head: + if is_semantic: + assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" + else: + print("Predicted class idx:", logits.argmax(-1).item()) + assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" + assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected" Path(pytorch_dump_folder_path).mkdir(exist_ok=True) print(f"Saving model to {pytorch_dump_folder_path}") From 4b7c6c1f79b84c512e14c4854989cbfd200fa08d Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 19 Oct 2021 15:37:00 +0200 Subject: [PATCH 07/17] Make BeitForSemanticSegmentation consistent with Segformer --- .../models/beit/convert_beit_unilm_to_pytorch.py | 14 ++++++-------- src/transformers/models/beit/modeling_beit.py | 16 +++++++++------- src/transformers/models/beit/test_semantic.py | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index e5dc1fad9f10..86965718b205 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -91,8 +91,8 @@ def create_rename_keys(config, has_lm_head=False, is_semantic=False): # semantic segmentation head + classification head rename_keys.extend( [ - ("decode_head.conv_seg.weight", "head.classifier.weight"), - ("decode_head.conv_seg.bias", "head.classifier.bias"), + ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), + ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), ] ) else: @@ -247,9 +247,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): # add prefix to decoder keys for key, val in state_dict.copy().items(): val = state_dict.pop(key) - if key.startswith("decode_head"): - key = key.replace("decode_head", "head") - elif key.startswith("backbone.fpn"): + if key.startswith("backbone.fpn"): key = key.replace("backbone.fpn", "fpn") if "auxiliary_head" in key: # we skip the auxiliary head for now @@ -317,7 +315,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"): expected_logits = 
torch.tensor([-0.3062, 0.7261, 0.4852]) expected_class_idx = 761 - elif checkpoint_url[-4:].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): + elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): expected_shape = (1, 150, 160, 160) expected_logits = torch.tensor( [ @@ -326,14 +324,14 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]], ] ) - elif checkpoint_url[-4:].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): + elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"): raise NotImplementedError("To do") else: raise ValueError("Can't verify logits as model is not supported") assert logits.shape == expected_shape, "Shape of logits not as expected" if not has_lm_head: - if is_semantic: + if is_semantic and "base" in checkpoint_url: assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" else: print("Predicted class idx:", logits.argmax(-1).item()) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index df5d8ec41657..db143b476eaa 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -987,11 +987,11 @@ def psp_forward(self, inputs): return output - def forward(self, hidden_states): + def forward(self, encoder_hidden_states): # build laterals - laterals = [lateral_conv(hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] - laterals.append(self.psp_forward(hidden_states)) + laterals.append(self.psp_forward(encoder_hidden_states)) # build top-down path used_backbone_levels = len(laterals) @@ -1044,7 +1044,7 @@ def __init__(self, config): self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) # Semantic segmentation head - self.head = UPerHead(config) + self.decode_head = UPerHead(config) self.init_weights() @@ -1096,9 +1096,11 @@ def forward( return_dict=return_dict, ) + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + # only keep certain features, and reshape - # note that we do +1 as outputs.hidden_states also includes the initial embeddings - features = [feature for idx, feature in enumerate(outputs.hidden_states) if idx + 1 in self.config.out_indices] + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] batch_size = pixel_values.shape[0] patch_resolution = self.config.image_size // self.config.patch_size features = [ @@ -1110,7 +1112,7 @@ def forward( for i in range(len(features)): features[i] = ops[i](features[i]) - logits = self.head(features) + logits = self.decode_head(features) loss = None if labels is not None: diff --git a/src/transformers/models/beit/test_semantic.py b/src/transformers/models/beit/test_semantic.py index ffc34346bc54..d67f96321925 100644 --- a/src/transformers/models/beit/test_semantic.py +++ b/src/transformers/models/beit/test_semantic.py @@ -10,7 +10,7 @@ segmentation_map = Image.open(ds[1]["file"]) # load model -model_name = "nielsr/beit-base-finetuned-ade20k" +model_name = "/Users/NielsRogge/Documents/BEIT/beit-base-finetuned-ade20k" feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) model = 
BeitForSemanticSegmentation.from_pretrained(model_name)

From 2b795110094592a2c86e4c5cc7a6aa93b4b49f5d Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Tue, 19 Oct 2021 16:06:53 +0200
Subject: [PATCH 08/17] Small bug fix

---
 src/transformers/models/beit/modeling_beit.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index db143b476eaa..17cd2be04733 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -922,14 +922,14 @@ def forward(self, x):
         return ppm_outs
 
 
-class UPerHead(nn.Module):
+class BeitUperHead(nn.Module):
     """
     Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.
     """
 
     def __init__(self, config):
-        super(UPerHead, self).__init__()
+        super(BeitUperHead, self).__init__()
 
         self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
         self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
@@ -1044,7 +1044,7 @@ def __init__(self, config):
         self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
 
         # Semantic segmentation head
-        self.decode_head = UPerHead(config)
+        self.decode_head = BeitUperHead(config)
 
         self.init_weights()
 
@@ -1096,7 +1096,7 @@ def forward(
             return_dict=return_dict,
         )
 
-        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
 
         # only keep certain features, and reshape
         # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
         features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
@@ -1125,7 +1125,7 @@ def forward(
             )
             loss_fct = CrossEntropyLoss(ignore_index=255)
             loss = loss_fct(upsampled_logits, labels)
-        
+
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

From cb52ae9ec0a1c22da68f9a8ed285c45b853d230e Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Tue, 19 Oct 2021 17:01:02 +0200
Subject: [PATCH 09/17] Add BeitForSemanticSegmentation to docs

---
 docs/source/model_doc/beit.rst                |  7 ++++++
 .../beit/convert_beit_unilm_to_pytorch.py     | 14 +++++-----
 src/transformers/models/beit/modeling_beit.py | 22 +++++++++----------
 tests/test_modeling_beit.py                   | 17 +++++++++-----
 utils/check_repo.py                           |  1 +
 5 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/docs/source/model_doc/beit.rst b/docs/source/model_doc/beit.rst
index 238d5f46a28e..af658cf60dbc 100644
--- a/docs/source/model_doc/beit.rst
+++ b/docs/source/model_doc/beit.rst
@@ -98,6 +98,13 @@ BeitForImageClassification
     :members: forward
 
 
+BeitForSemanticSegmentation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. 
autoclass:: transformers.BeitForSemanticSegmentation + :members: forward + + FlaxBeitModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 86965718b205..73e290a3366b 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -20,11 +20,11 @@ from pathlib import Path import torch +from datasets import load_dataset from PIL import Image import requests from huggingface_hub import cached_download, hf_hub_url -from datasets import load_dataset from transformers import ( BeitConfig, BeitFeatureExtractor, @@ -269,11 +269,11 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): if is_semantic: feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False) ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") - image = Image.open(ds[0]["file"]) + image = Image.open(ds[0]["file"]) else: feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False) image = prepare_img() - + encoding = feature_extractor(images=image, return_tensors="pt") pixel_values = encoding["pixel_values"] @@ -332,10 +332,14 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): assert logits.shape == expected_shape, "Shape of logits not as expected" if not has_lm_head: if is_semantic and "base" in checkpoint_url: - assert torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" + assert torch.allclose( + logits[0, :3, :3, :3], expected_logits, atol=1e-3 + ), "First elements of logits not as expected" else: print("Predicted class idx:", logits.argmax(-1).item()) - assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" + assert torch.allclose( + logits[0, :3], expected_logits, atol=1e-3 + ), "First elements of logits not as expected" assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected" Path(pytorch_dump_folder_path).mkdir(exist_ok=True) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 17cd2be04733..3a5111586708 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -1019,7 +1019,7 @@ def forward(self, encoder_hidden_states): @add_start_docstrings( """ - Beit Model transformer with an semantic segmentation head on top e.g. for ADE20k. + Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. """, BEIT_START_DOCSTRING, ) @@ -1060,31 +1060,29 @@ def forward( return_dict=None, ): r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): - Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., - config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`): + Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. 
If :obj:`config.num_labels > 1`, a classification loss is computed + (Cross-Entropy). Returns: Examples:: - >>> from transformers import BeitFeatureExtractor, BeitForImageClassification + >>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation >>> from PIL import Image >>> import requests >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' >>> image = Image.open(requests.get(url, stream=True).raw) - >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224') - >>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224') + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade20k') + >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade20k') >>> inputs = feature_extractor(images=image, return_tensors="pt") >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height/4, width/4) >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1125,7 +1123,7 @@ def forward( ) loss_fct = CrossEntropyLoss(ignore_index=255) loss = loss_fct(upsampled_logits, labels) - + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 1381afc6af1b..1ebdaa791cba 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -69,7 +69,7 @@ def __init__( initializer_range=0.02, num_labels=3, scope=None, - out_indices=[0,1,2,3], + out_indices=[0, 1, 2, 3], ): self.parent = parent self.vocab_size = 100 @@ -169,9 +169,11 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): """ maxDiff = None - + all_model_classes = ( - (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation) if is_torch_available() else () + (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation) + if is_torch_available() + else () ) test_pruning = False @@ -248,7 +250,10 @@ def test_training_gradient_checkpointing(self): for model_class in self.all_model_classes: # we don't test BeitForMaskedImageModeling - if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling] or not model_class.supports_gradient_checkpointing: + if ( + model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling] + or not model_class.supports_gradient_checkpointing + ): continue # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING # this can then be incorporated into _prepare_for_class in test_modeling_common.py @@ -395,7 +400,7 @@ def test_for_masked_lm(self): def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) - + @slow def test_model_from_pretrained(self): for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -496,7 +501,7 @@ def test_inference_semantic_segmentation(self): model = BeitForSemanticSegmentation.from_pretrained("nielsr/beit-base-finetuned-ade20k").to(torch_device) feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) - + from datasets import 
load_dataset ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") diff --git a/utils/check_repo.py b/utils/check_repo.py index a7839f318f02..40d980d93930 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -100,6 +100,7 @@ # models to ignore for model xxx mapping "SegformerDecodeHead", "SegformerForSemanticSegmentation", + "BeitForSemanticSegmentation", "FlaxBeitForMaskedImageModeling", "BeitForMaskedImageModeling", "CLIPTextModel", From 81c098f06943c2043b01c51869e12aff298b868f Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 20 Oct 2021 11:21:22 +0200 Subject: [PATCH 10/17] Make sure model doesn't output hidden states when the user doesn't want to --- src/transformers/models/beit/modeling_beit.py | 10 ++++++++-- tests/test_modeling_beit.py | 13 ++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3a5111586708..df1151b4f19c 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -1085,6 +1085,9 @@ def forward( >>> logits = outputs.logits """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) outputs = self.beit( pixel_values, @@ -1125,12 +1128,15 @@ def forward( loss = loss_fct(upsampled_logits, labels) if not return_dict: - output = (logits,) + outputs[2:] + if output_hidden_states: + output = (logits,) + outputs[2:] + else: + output = (logits,) + outputs[3:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, - hidden_states=outputs.hidden_states, + hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions, ) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 1ebdaa791cba..7bf975b2040e 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -320,7 +320,8 @@ def test_attention_outputs(self): model.eval() with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + attentions = outputs.attentions self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) self.assertListEqual( @@ -338,15 +339,9 @@ def test_attention_outputs(self): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) + self.assertEqual(out_len + 1, len(outputs)) - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self_attentions = outputs.attentions self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) self.assertListEqual( From b7bb9926e3a9e90ecb809a5661c0158531f954b6 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 20 Oct 2021 11:24:35 +0200 Subject: [PATCH 11/17] Make it possible to convert the large model --- .../models/beit/convert_beit_unilm_to_pytorch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 73e290a3366b..2af7ad956b7b 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -325,17 +325,18 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): ] ) elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"): - raise NotImplementedError("To do") + expected_shape = (1, 150, 160, 160) else: raise ValueError("Can't verify logits as model is not supported") assert logits.shape == expected_shape, "Shape of logits not as expected" + print("First elements of logits:", logits[0, :3, :3, :3]) if not has_lm_head: if is_semantic and "base" in checkpoint_url: assert torch.allclose( logits[0, :3, :3, :3], expected_logits, atol=1e-3 ), "First elements of logits not as expected" - else: + elif not is_semantic: print("Predicted class idx:", logits.argmax(-1).item()) assert torch.allclose( logits[0, :3], expected_logits, atol=1e-3 From 28afb5a9b33ff98be08f364cf0224417f700d02c Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 20 Oct 2021 11:41:22 +0200 Subject: [PATCH 12/17] Fix issue --- src/transformers/models/beit/convert_beit_unilm_to_pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 2af7ad956b7b..30ceb20da881 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -330,6 +330,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): raise ValueError("Can't verify logits as model is not supported") assert logits.shape == expected_shape, "Shape of logits not as expected" + print("Shape of logits:", logits.shape) print("First elements of logits:", logits[0, :3, :3, :3]) if not has_lm_head: if is_semantic and "base" in checkpoint_url: From 59174a8032a9c0ee2872cf9208bb6590a04964ad Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 20 Oct 2021 13:10:59 +0200 Subject: [PATCH 13/17] Fix conversion script for large model --- .../models/beit/convert_beit_unilm_to_pytorch.py | 13 +++++++++---- src/transformers/models/beit/modeling_beit.py | 4 ++-- tests/test_modeling_beit.py | 3 +-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 30ceb20da881..27fd73b657c8 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -326,18 +326,23 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): ) elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"): expected_shape = (1, 150, 160, 160) + expected_logits = torch.tensor( + [ + [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]], + [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]], + [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]], + ] + ) else: raise ValueError("Can't verify logits as model is not supported") assert logits.shape == expected_shape, "Shape of logits not as expected" - print("Shape of logits:", logits.shape) - print("First elements of logits:", logits[0, :3, :3, 
:3]) if not has_lm_head: - if is_semantic and "base" in checkpoint_url: + if is_semantic: assert torch.allclose( logits[0, :3, :3, :3], expected_logits, atol=1e-3 ), "First elements of logits not as expected" - elif not is_semantic: + else: print("Predicted class idx:", logits.argmax(-1).item()) assert torch.allclose( logits[0, :3], expected_logits, atol=1e-3 diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index df1151b4f19c..b83078415400 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -1076,8 +1076,8 @@ def forward( >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' >>> image = Image.open(requests.get(url, stream=True).raw) - >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade20k') - >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade20k') + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640') + >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640') >>> inputs = feature_extractor(images=image, return_tensors="pt") >>> outputs = model(**inputs) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 7bf975b2040e..31acf40ffc0e 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -492,8 +492,7 @@ def test_inference_image_classification_head_imagenet_22k(self): @slow def test_inference_semantic_segmentation(self): - # TODO rename nielsr to microsoft - model = BeitForSemanticSegmentation.from_pretrained("nielsr/beit-base-finetuned-ade20k").to(torch_device) + model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640").to(torch_device) feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) From 85582b2147cd8cec4e1a87f91e56c7e6e6020195 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 20 Oct 2021 15:46:27 +0200 Subject: [PATCH 14/17] Add auxiliary_head option to semantic segmentation model --- .../models/beit/configuration_beit.py | 23 +++- .../beit/convert_beit_unilm_to_pytorch.py | 10 +- src/transformers/models/beit/modeling_beit.py | 112 ++++++++++++++++-- tests/test_modeling_beit.py | 4 +- 4 files changed, 131 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 32347b6e6ac3..477042db8842 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -82,6 +82,16 @@ class BeitConfig(PretrainedConfig): Indices of the feature maps to use for semantic segmentation. pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`): Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to use an auxiliary head during training. + loss_weight (:obj:`float`, `optional`, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + channels (:obj:`int`, `optional`, defaults to 256): + Number of channels to use in the auxiliary head. + num_convs (:obj:`int`, `optional`, defaults to 1): + Number of convolutional layers to use in the auxiliary head. 
+ concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. Example:: @@ -123,6 +133,11 @@ def __init__( use_mean_pooling=True, out_indices=[3, 5, 7, 11], pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + loss_weight=0.4, + channels=256, + num_convs=1, + concat_input=False, **kwargs ): super().__init__(**kwargs) @@ -148,6 +163,12 @@ def __init__( self.layer_scale_init_value = layer_scale_init_value self.drop_path_rate = drop_path_rate self.use_mean_pooling = use_mean_pooling - # semantic segmentation attributes + # decode head attributes (semantic segmentation) self.out_indices = out_indices self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.loss_weight = loss_weight + self.channels = channels + self.num_convs = num_convs + self.concat_input = concat_input diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py index 27fd73b657c8..89e85fa8fe90 100644 --- a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -88,11 +88,13 @@ def create_rename_keys(config, has_lm_head=False, is_semantic=False): ] ) elif is_semantic: - # semantic segmentation head + classification head + # semantic segmentation classification heads rename_keys.extend( [ ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), + ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), + ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), ] ) else: @@ -249,11 +251,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): val = state_dict.pop(key) if key.startswith("backbone.fpn"): key = key.replace("backbone.fpn", "fpn") - if "auxiliary_head" in key: - # we skip the auxiliary head for now - pass - else: - state_dict[key] = val + state_dict[key] = val # load HuggingFace model if checkpoint_url[-9:-4] == "pt22k": diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index b83078415400..61134817fc37 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -858,16 +858,19 @@ class ConvModule(nn.Module): """ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. """ - def __init__(self, in_channels, out_channels, kernel_size, padding=0): + def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1): super(ConvModule, self).__init__() self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, - bias=False, + bias=bias, + dilation=dilation, ) self.bn = nn.BatchNorm2d(out_channels) self.activation = nn.ReLU() @@ -890,6 +893,8 @@ class PPM(nn.ModuleList): in_channels (int): Input channels. channels (int): Channels after modules, before conv_seg. align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
""" def __init__(self, pool_scales, in_channels, channels, align_corners): @@ -926,6 +931,8 @@ class BeitUperHead(nn.Module): """ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet `_. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. """ def __init__(self, config): @@ -935,7 +942,7 @@ def __init__(self, config): self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] self.channels = config.hidden_size self.align_corners = False - self.classifier = nn.Conv2d(config.hidden_size, config.num_labels, kernel_size=1) + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) # PSP Module self.psp_modules = PPM( @@ -1017,6 +1024,75 @@ def forward(self, encoder_hidden_states): return output +class BeitFCNHead(nn.Module): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented of `FCNNet + `_. + + Args: + config (BeitConfig): Configuration. + in_channels + kernel_size (int): The kernel size for convs in the head. Default: 3. + dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config, in_index=2, kernel_size=3, dilation=1): + super(BeitFCNHead, self).__init__() + self.in_channels = config.hidden_size + self.channels = config.channels + self.num_convs = config.num_convs + self.concat_input = config.concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + ConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + ) + ) + for i in range(self.num_convs - 1): + convs.append( + ConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states): + """Forward function.""" + # just take the relevant feature maps + x = encoder_hidden_states[self.in_index] + output = self.convs(x) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.classifier(output) + return output + + @add_start_docstrings( """ Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. 
@@ -1043,11 +1119,29 @@ def __init__(self, config): self.fpn3 = nn.Identity() self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) - # Semantic segmentation head + # Semantic segmentation head(s) self.decode_head = BeitUperHead(config) + self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None self.init_weights() + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=255) + loss = loss_fct(upsampled_logits, labels) + self.config.loss_weight * loss_fct( + upsampled_auxiliary_logits, labels + ) + + return loss + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1114,18 +1208,16 @@ def forward( features[i] = ops[i](features[i]) logits = self.decode_head(features) + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) loss = None if labels is not None: if self.config.num_labels == 1: raise ValueError("The number of labels should be greater than one") else: - # upsample logits to the images' original size - upsampled_logits = nn.functional.interpolate( - logits, size=labels.shape[-2:], mode="bilinear", align_corners=False - ) - loss_fct = CrossEntropyLoss(ignore_index=255) - loss = loss_fct(upsampled_logits, labels) + loss = self.compute_loss(logits, auxiliary_logits, labels) if not return_dict: if output_hidden_states: diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 31acf40ffc0e..5bc4b01ea856 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -492,7 +492,9 @@ def test_inference_image_classification_head_imagenet_22k(self): @slow def test_inference_semantic_segmentation(self): - model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640").to(torch_device) + model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640").to( + torch_device + ) feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) From dfae213f961c60038f26ec9c68f30b660f1f3f69 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 25 Oct 2021 11:42:52 +0200 Subject: [PATCH 15/17] Apply suggestions from @sgugger's review --- src/transformers/models/beit/modeling_beit.py | 81 ++++++------------- src/transformers/models/beit/test.ipynb | 0 src/transformers/models/beit/test_semantic.py | 22 ----- tests/test_modeling_beit.py | 9 +-- 4 files changed, 30 insertions(+), 82 deletions(-) delete mode 100644 src/transformers/models/beit/test.ipynb delete mode 100644 src/transformers/models/beit/test_semantic.py diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 61134817fc37..acd9d4433059 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -854,7 +854,7 @@ def forward( ) -class ConvModule(nn.Module): +class BeitConvModule(nn.Module): """ A convolutional block that bundles conv/norm/activation layers. 
This block simplifies the usage of convolution layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). @@ -863,7 +863,7 @@ class ConvModule(nn.Module): """ def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1): - super(ConvModule, self).__init__() + super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -883,7 +883,7 @@ def forward(self, input): return output -class PPM(nn.ModuleList): +class BeitPyramidPoolingModule(nn.ModuleList): """ Pyramid Pooling Module (PPM) used in PSPNet. @@ -898,7 +898,7 @@ class PPM(nn.ModuleList): """ def __init__(self, pool_scales, in_channels, channels, align_corners): - super(PPM, self).__init__() + super().__init__() self.pool_scales = pool_scales self.align_corners = align_corners self.in_channels = in_channels @@ -907,16 +907,11 @@ def __init__(self, pool_scales, in_channels, channels, align_corners): self.append( nn.Sequential( nn.AdaptiveAvgPool2d(pool_scale), - ConvModule( - self.in_channels, - self.channels, - 1, - ), + BeitConvModule(self.in_channels, self.channels, kernel_size=1), ) ) def forward(self, x): - """Forward function.""" ppm_outs = [] for ppm in self: ppm_out = ppm(x) @@ -936,7 +931,7 @@ class BeitUperHead(nn.Module): """ def __init__(self, config): - super(BeitUperHead, self).__init__() + super().__init__() self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] @@ -945,47 +940,35 @@ def __init__(self, config): self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) # PSP Module - self.psp_modules = PPM( + self.psp_modules = BeitPyramidPoolingModule( self.pool_scales, self.in_channels[-1], self.channels, align_corners=self.align_corners, ) - self.bottleneck = ConvModule( + self.bottleneck = BeitConvModule( self.in_channels[-1] + len(self.pool_scales) * self.channels, self.channels, - 3, + kernel_size=3, padding=1, ) # FPN Module self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() for in_channels in self.in_channels[:-1]: # skip the top layer - l_conv = ConvModule( - in_channels, - self.channels, - 1, - # inplace=False, - ) - fpn_conv = ConvModule( - self.channels, - self.channels, - 3, - padding=1, - # inplace=False, - ) + l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1) + fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1) self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) - self.fpn_bottleneck = ConvModule( + self.fpn_bottleneck = BeitConvModule( len(self.in_channels) * self.channels, self.channels, - 3, + kernel_size=3, padding=1, ) def psp_forward(self, inputs): - """Forward function of PSP module.""" x = inputs[-1] psp_outs = [x] psp_outs.extend(self.psp_modules(x)) @@ -1040,7 +1023,7 @@ class BeitFCNHead(nn.Module): """ def __init__(self, config, in_index=2, kernel_size=3, dilation=1): - super(BeitFCNHead, self).__init__() + super().__init__() self.in_channels = config.hidden_size self.channels = config.channels self.num_convs = config.num_convs @@ -1050,22 +1033,14 @@ def __init__(self, config, in_index=2, kernel_size=3, dilation=1): conv_padding = (kernel_size // 2) * dilation convs = [] convs.append( - ConvModule( - self.in_channels, - self.channels, - kernel_size=kernel_size, - padding=conv_padding, - dilation=dilation, + BeitConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, 
dilation=dilation ) ) for i in range(self.num_convs - 1): convs.append( - ConvModule( - self.channels, - self.channels, - kernel_size=kernel_size, - padding=conv_padding, - dilation=dilation, + BeitConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation ) ) if self.num_convs == 0: @@ -1073,22 +1048,18 @@ def __init__(self, config, in_index=2, kernel_size=3, dilation=1): else: self.convs = nn.Sequential(*convs) if self.concat_input: - self.conv_cat = ConvModule( - self.in_channels + self.channels, - self.channels, - kernel_size=kernel_size, - padding=kernel_size // 2, + self.conv_cat = BeitConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 ) self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) def forward(self, encoder_hidden_states): - """Forward function.""" # just take the relevant feature maps - x = encoder_hidden_states[self.in_index] - output = self.convs(x) + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) if self.concat_input: - output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) output = self.classifier(output) return output @@ -1136,9 +1107,9 @@ def compute_loss(self, logits, auxiliary_logits, labels): ) # compute weighted loss loss_fct = CrossEntropyLoss(ignore_index=255) - loss = loss_fct(upsampled_logits, labels) + self.config.loss_weight * loss_fct( - upsampled_auxiliary_logits, labels - ) + main_loss = loss_fct(upsampled_logits, labels) + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss = main_loss + self.config.loss_weight * auxiliary_loss return loss diff --git a/src/transformers/models/beit/test.ipynb b/src/transformers/models/beit/test.ipynb deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/transformers/models/beit/test_semantic.py b/src/transformers/models/beit/test_semantic.py deleted file mode 100644 index d67f96321925..000000000000 --- a/src/transformers/models/beit/test_semantic.py +++ /dev/null @@ -1,22 +0,0 @@ -from datasets import load_dataset -from PIL import Image - -from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation - - -# load image + ground truth map -ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") -image = Image.open(ds[0]["file"]) -segmentation_map = Image.open(ds[1]["file"]) - -# load model -model_name = "/Users/NielsRogge/Documents/BEIT/beit-base-finetuned-ade20k" -feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) -model = BeitForSemanticSegmentation.from_pretrained(model_name) - -pixel_values = feature_extractor(image, return_tensors="pt").pixel_values -outputs = model(pixel_values) -logits = outputs.logits - -print("Shape of logits:", outputs.logits.shape) -print("First elements of logits:", logits[0, :3, :3, :3]) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 5bc4b01ea856..338d37b009bc 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -18,6 +18,8 @@ import inspect import unittest +from datasets import load_dataset + from transformers import BeitConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available from transformers.models.auto import get_values @@ -492,14 +494,11 @@ def test_inference_image_classification_head_imagenet_22k(self): @slow def 
test_inference_semantic_segmentation(self): - model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640").to( - torch_device - ) + model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + model = model.to(torch_device) feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") image = Image.open(ds[0]["file"]) inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) From e61b5807621e8c8d483195bbce495d5695467fd8 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 28 Oct 2021 16:42:38 +0200 Subject: [PATCH 16/17] Apply suggestions from code review --- .../models/beit/configuration_beit.py | 24 +++++++++---------- src/transformers/models/beit/modeling_beit.py | 8 +++---- tests/test_modeling_beit.py | 2 -- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 477042db8842..bc1aa63197f7 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -84,13 +84,13 @@ class BeitConfig(PretrainedConfig): Pooling scales used in Pooling Pyramid Module applied on the last feature map. use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to use an auxiliary head during training. - loss_weight (:obj:`float`, `optional`, defaults to 0.4): + auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4): Weight of the cross-entropy loss of the auxiliary head. - channels (:obj:`int`, `optional`, defaults to 256): + auxiliary_channels (:obj:`int`, `optional`, defaults to 256): Number of channels to use in the auxiliary head. - num_convs (:obj:`int`, `optional`, defaults to 1): + auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1): Number of convolutional layers to use in the auxiliary head. - concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`): + auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to concatenate the output of the auxiliary head with the input before the classification layer. 
Example:: @@ -134,10 +134,10 @@ def __init__( out_indices=[3, 5, 7, 11], pool_scales=[1, 2, 3, 6], use_auxiliary_head=True, - loss_weight=0.4, - channels=256, - num_convs=1, - concat_input=False, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, **kwargs ): super().__init__(**kwargs) @@ -168,7 +168,7 @@ def __init__( self.pool_scales = pool_scales # auxiliary head attributes (semantic segmentation) self.use_auxiliary_head = use_auxiliary_head - self.loss_weight = loss_weight - self.channels = channels - self.num_convs = num_convs - self.concat_input = concat_input + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index acd9d4433059..bb6df3f2a5b6 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -1025,9 +1025,9 @@ class BeitFCNHead(nn.Module): def __init__(self, config, in_index=2, kernel_size=3, dilation=1): super().__init__() self.in_channels = config.hidden_size - self.channels = config.channels - self.num_convs = config.num_convs - self.concat_input = config.concat_input + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input self.in_index = in_index conv_padding = (kernel_size // 2) * dilation @@ -1109,7 +1109,7 @@ def compute_loss(self, logits, auxiliary_logits, labels): loss_fct = CrossEntropyLoss(ignore_index=255) main_loss = loss_fct(upsampled_logits, labels) auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) - loss = main_loss + self.config.loss_weight * auxiliary_loss + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss return loss diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 338d37b009bc..f0a89031416c 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -170,8 +170,6 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): attention_mask and seq_length. 
""" - maxDiff = None - all_model_classes = ( (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation) if is_torch_available() From 89fe3d50bf20bc05cbbf8bcd165e799df517ef35 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 1 Nov 2021 10:48:47 -0400 Subject: [PATCH 17/17] Fix failing test --- src/transformers/models/beit/modeling_beit.py | 2 +- tests/test_modeling_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index bb6df3f2a5b6..bf496d92a7fc 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -500,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3890198edbab..42330cc37dbf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -88,7 +88,7 @@ def _config_zero_init(config): configs_no_init = copy.deepcopy(config) for key in configs_no_init.__dict__.keys(): - if "_range" in key or "_std" in key or "initializer_factor" in key: + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: setattr(configs_no_init, key, 1e-10) return configs_no_init