Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/models/swin/configuration_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class SwinConfig(PretrainedConfig):

attribute_map = {
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}

def __init__(
Expand Down Expand Up @@ -141,4 +142,4 @@ def __init__(
self.encoder_stride = encoder_stride
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = embed_dim * 8
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
9 changes: 4 additions & 5 deletions tests/swin/test_modeling_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(
patch_size=2,
num_channels=3,
embed_dim=16,
depths=[1],
num_heads=[2],
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
mlp_ratio=2.0,
qkv_bias=True,
Expand All @@ -73,7 +73,7 @@ def __init__(
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=2,
encoder_stride=8,
):
self.parent = parent
self.batch_size = batch_size
Expand Down Expand Up @@ -139,8 +139,7 @@ def create_and_check_model(self, config, pixel_values, labels):
model.eval()
result = model(pixel_values)

# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len = (config.image_size // config.patch_size) ** 2
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))

self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device

from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..swin.test_modeling_swin import SwinModelTester
from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
from ..vit.test_modeling_vit import ViTModelTester
Expand All @@ -35,8 +37,10 @@

from transformers import (
AutoTokenizer,
BartForCausalLM,
BertLMHeadModel,
DeiTModel,
SwinModel,
TrOCRForCausalLM,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel,
Expand Down Expand Up @@ -514,6 +518,90 @@ def prepare_config_and_inputs(self):
}


@require_torch
class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = SwinModel(config).eval()
decoder_model = BartForCausalLM(decoder_config).eval()
return encoder_model, decoder_model

def prepare_config_and_inputs(self):
model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, _ = encoder_config_and_inputs
decoder_config, decoder_inputs_dict = decoder_config_and_inputs

# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
**decoder_inputs_dict,
}

def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)

encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

# in Swin, the seq_len equals:
seq_len = encoder_model.config.window_size**2
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))

decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)

self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)

cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)

encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
)

# there are no published pretrained BART-causal checkpoints for now
def test_real_model_save_load_from_pretrained(self):
pass


@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
Expand Down