Add BeitForSemanticSegmentation #14096

Merged (17 commits) on Nov 1, 2021
7 changes: 7 additions & 0 deletions docs/source/model_doc/beit.rst
@@ -98,6 +98,13 @@ BeitForImageClassification
:members: forward


BeitForSemanticSegmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BeitForSemanticSegmentation
:members: forward


FlaxBeitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

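For context, this is roughly how the documented head is meant to be used once a converted checkpoint is available. A minimal inference sketch; the checkpoint name below is an assumption (whatever Hub id the converted ADE20k weights are published under), not something this diff pins down:

```python
from PIL import Image
import requests

from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

# Assumed checkpoint name; substitute the actual Hub id of the converted weights.
model_name = "microsoft/beit-base-finetuned-ade-640-640"
feature_extractor = BeitFeatureExtractor.from_pretrained(model_name)
model = BeitForSemanticSegmentation.from_pretrained(model_name)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits  # shape (batch_size, num_labels, height / 4, width / 4)
```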
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -635,6 +635,7 @@
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
@@ -2477,6 +2478,7 @@
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
2 changes: 2 additions & 0 deletions src/transformers/models/beit/__init__.py
@@ -33,6 +33,7 @@
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
@@ -57,6 +58,7 @@
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
30 changes: 30 additions & 0 deletions src/transformers/models/beit/configuration_beit.py
@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
CLS token, before applying the classification head.
out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
Indices of the feature maps to use for semantic segmentation.
pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
Pooling scales used in the Pooling Pyramid Module applied on the last feature map.
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to use an auxiliary head during training.
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
Weight of the cross-entropy loss of the auxiliary head.
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
Number of channels to use in the auxiliary head.
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
Number of convolutional layers to use in the auxiliary head.
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to concatenate the output of the auxiliary head with the input before the classification layer.

Example::

@@ -117,6 +131,13 @@ def __init__(
layer_scale_init_value=0.1,
drop_path_rate=0.1,
use_mean_pooling=True,
out_indices=[3, 5, 7, 11],
pool_scales=[1, 2, 3, 6],
use_auxiliary_head=True,
auxiliary_loss_weight=0.4,
auxiliary_channels=256,
auxiliary_num_convs=1,
auxiliary_concat_input=False,
**kwargs
):
super().__init__(**kwargs)
@@ -142,3 +163,12 @@ def __init__(
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
self.use_mean_pooling = use_mean_pooling
# decode head attributes (semantic segmentation)
self.out_indices = out_indices
self.pool_scales = pool_scales
# auxiliary head attributes (semantic segmentation)
self.use_auxiliary_head = use_auxiliary_head
self.auxiliary_loss_weight = auxiliary_loss_weight
self.auxiliary_channels = auxiliary_channels
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
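The new arguments compose as follows. A minimal sketch of building a segmentation model from scratch, using the defaults documented above plus the BEiT-large `out_indices` that the conversion script below sets for ADE20k checkpoints:

```python
from transformers import BeitConfig, BeitForSemanticSegmentation

# BEiT-large has 24 layers; features are tapped at indices [7, 11, 15, 23]
config = BeitConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    image_size=640,
    num_labels=150,               # ADE20k has 150 classes
    out_indices=[7, 11, 15, 23],  # feature maps fed to the decode head
    pool_scales=[1, 2, 3, 6],     # Pooling Pyramid Module scales
    use_auxiliary_head=True,
    auxiliary_loss_weight=0.4,
)
model = BeitForSemanticSegmentation(config)  # randomly initialized
```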
166 changes: 123 additions & 43 deletions src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
@@ -20,11 +20,18 @@
from pathlib import Path

 import torch
+from datasets import load_dataset
 from PIL import Image
 
 import requests
 from huggingface_hub import cached_download, hf_hub_url
-from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
+from transformers import (
+    BeitConfig,
+    BeitFeatureExtractor,
+    BeitForImageClassification,
+    BeitForMaskedImageModeling,
+    BeitForSemanticSegmentation,
+)
from transformers.utils import logging


@@ -33,27 +40,33 @@


# here we list all keys to be renamed (original name on the left, our name on the right)
-def create_rename_keys(config, has_lm_head=False):
+def create_rename_keys(config, has_lm_head=False, is_semantic=False):
+    prefix = "backbone." if is_semantic else ""

rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))

# projection layer + position embeddings
rename_keys.extend(
[
("cls_token", "beit.embeddings.cls_token"),
("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)

@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
@@ -89,45 +112,45 @@


# we split up the matrix of each encoder layer into queries, keys and values
-def read_in_q_k_v(state_dict, config, has_lm_head=False):
+def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
     for i in range(config.num_hidden_layers):
-        prefix = "beit."
+        prefix = "backbone." if is_semantic else ""
         # queries, keys and values
-        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
-        q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
-        v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")
+        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
+        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
+        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")

state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias

# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")

state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2

# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"blocks.{i}.attn.relative_position_index")
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")

state_dict[
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
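The slicing in read_in_q_k_v is easier to follow on a toy tensor. A standalone sketch with a made-up hidden size (note that the original BEiT checkpoints store a fused qkv weight but only q_bias and v_bias; the key projection carries no bias):

```python
import torch

hidden_size = 8  # toy value; BEiT-base uses 768
# fused projection as stored in the original checkpoint: (3 * hidden_size, hidden_size)
qkv_weight = torch.randn(3 * hidden_size, hidden_size)

# the same three slices the function above writes into the HF state dict
query_weight = qkv_weight[:hidden_size, :]
key_weight = qkv_weight[hidden_size : hidden_size * 2, :]
value_weight = qkv_weight[-hidden_size:, :]

assert torch.equal(torch.cat([query_weight, key_weight, value_weight]), qkv_weight)
```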


@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "datasets/huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'")
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")

# size of the architecture
if "base" in checkpoint_url:
@@ -196,27 +231,48 @@
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")

# load state_dict of original model, remove and rename some keys
-    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
-    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
+    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
+    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
+
+    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
     for src, dest in rename_keys:
         rename_key(state_dict, src, dest)
-    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
+    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
+    if is_semantic:
+        # strip the "backbone." prefix from the FPN decoder keys
+        for key, val in state_dict.copy().items():
+            val = state_dict.pop(key)
+            if key.startswith("backbone.fpn"):
+                key = key.replace("backbone.fpn", "fpn")
+            state_dict[key] = val

# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)

# Check outputs on an image
-    feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
-    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
+    if is_semantic:
+        feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
+        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+        image = Image.open(ds[0]["file"])
+    else:
+        feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
+        image = prepare_img()
+
+    encoding = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]

outputs = model(pixel_values)
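For the ADE20k models the raw logits come out at 1/4 of the input resolution, which is why the expected shapes checked below are (1, 150, 160, 160) for 640x640 inputs. A common post-processing step, sketched here as an assumption rather than part of this script, is bilinear upsampling back to the input size followed by an argmax over the class dimension:

```python
import torch.nn.functional as F

logits = outputs.logits  # (batch_size, 150, 160, 160) for a 640x640 ADE20k input
upsampled = F.interpolate(logits, size=(640, 640), mode="bilinear", align_corners=False)
segmentation_map = upsampled.argmax(dim=1)  # (batch_size, 640, 640) class indices
```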
@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")

assert logits.shape == expected_shape, "Shape of logits not as expected"
print("Shape of logits:", logits.shape)
if not has_lm_head:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
if is_semantic:
assert torch.allclose(
logits[0, :3, :3, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
else:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(
logits[0, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"

Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
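The rename_key helper called in the loop above is defined outside the displayed hunks. Based on its call site, a plausible minimal implementation would be the following sketch (an assumption inferred from usage, not the PR's verbatim code):

```python
def rename_key(state_dict, old, new):
    # move the value stored under the old key to the new key
    val = state_dict.pop(old)
    state_dict[new] = val
```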