diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py
index 8749ed3d754f..9956482b9ab7 100644
--- a/src/transformers/models/swin/configuration_swin.py
+++ b/src/transformers/models/swin/configuration_swin.py
@@ -94,6 +94,7 @@ class SwinConfig(PretrainedConfig):
 
     attribute_map = {
         "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
     }
 
     def __init__(
@@ -141,4 +142,4 @@ def __init__(
         self.encoder_stride = encoder_stride
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
-        self.hidden_size = embed_dim * 8
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py
index e79b12533deb..25acbb724f68 100644
--- a/tests/swin/test_modeling_swin.py
+++ b/tests/swin/test_modeling_swin.py
@@ -56,8 +56,8 @@ def __init__(
         patch_size=2,
         num_channels=3,
         embed_dim=16,
-        depths=[1],
-        num_heads=[2],
+        depths=[1, 2, 1],
+        num_heads=[2, 2, 4],
         window_size=2,
         mlp_ratio=2.0,
         qkv_bias=True,
@@ -73,7 +73,7 @@ def __init__(
         scope=None,
         use_labels=True,
         type_sequence_label_size=10,
-        encoder_stride=2,
+        encoder_stride=8,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -139,8 +139,7 @@ def create_and_check_model(self, config, pixel_values, labels):
         model.eval()
         result = model(pixel_values)
 
-        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
-        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
         expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
diff --git a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 311161158845..d318eb4286de 100644
--- a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -22,8 +22,10 @@
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
+from ..bart.test_modeling_bart import BartModelTester
 from ..bert.test_modeling_bert import BertModelTester
 from ..deit.test_modeling_deit import DeiTModelTester
+from ..swin.test_modeling_swin import SwinModelTester
 from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
 from ..vit.test_modeling_vit import ViTModelTester
@@ -35,8 +37,10 @@
 
     from transformers import (
         AutoTokenizer,
+        BartForCausalLM,
         BertLMHeadModel,
         DeiTModel,
+        SwinModel,
         TrOCRForCausalLM,
         VisionEncoderDecoderConfig,
         VisionEncoderDecoderModel,
@@ -514,6 +518,90 @@ def prepare_config_and_inputs(self):
         }
 
 
+@require_torch
+class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
+    def get_encoder_decoder_model(self, config, decoder_config):
+        encoder_model = SwinModel(config).eval()
+        decoder_model = BartForCausalLM(decoder_config).eval()
+        return encoder_model, decoder_model
+
+    def prepare_config_and_inputs(self):
+        model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
+        model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
+        encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+        decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
+        config, pixel_values, _ = encoder_config_and_inputs
+        decoder_config, decoder_inputs_dict = decoder_config_and_inputs
+
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
+        # disable cache for now
+        decoder_config.use_cache = False
+        return {
+            "config": config,
+            "pixel_values": pixel_values,
+            "decoder_config": decoder_config,
+            **decoder_inputs_dict,
+        }
+
+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels=None,
+        pixel_values=None,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+        )
+
+        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+        # in Swin, attention is computed within local windows, so the seq_len equals window_size squared
+        seq_len = encoder_model.config.window_size**2
+        self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))
+
+        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+        num_decoder_layers = (
+            decoder_config.num_decoder_layers
+            if hasattr(decoder_config, "num_decoder_layers")
+            else decoder_config.num_hidden_layers
+        )
+        self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+        )
+
+        cross_attentions = outputs_encoder_decoder["cross_attentions"]
+        self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+        encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1]
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
+        )
+
+    # there are no published pretrained BART-causal checkpoints for now
+    def test_real_model_save_load_from_pretrained(self):
+        pass
+
+
 @require_torch
 class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
     def get_encoder_decoder_model(self, config, decoder_config):
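
As a sanity check of what the configuration change enables, here is a minimal sketch (not part of the diff above) pairing a randomly initialized SwinModel encoder with a BartForCausalLM decoder inside VisionEncoderDecoderModel. The classes are the real transformers APIs touched by this diff; the tiny dimensions are illustrative assumptions, chosen so that window_size divides every stage resolution and the decoder width matches the encoder's new hidden_size.

import torch
from transformers import BartConfig, BartForCausalLM, SwinConfig, SwinModel, VisionEncoderDecoderModel

encoder_config = SwinConfig(
    image_size=32, patch_size=2, embed_dim=16, depths=[1, 2, 1], num_heads=[2, 2, 4], window_size=2
)
# hidden_size now tracks the channel dimension after the last stage:
# int(embed_dim * 2 ** (len(depths) - 1)) = int(16 * 2 ** 2) = 64
# (the old `embed_dim * 8` was only correct for the default 4-stage layout)
assert encoder_config.hidden_size == 64

# keep the decoder tiny and match d_model to the encoder width so no projection layer is inserted
decoder_config = BartConfig(d_model=64, decoder_layers=2, decoder_attention_heads=4, add_cross_attention=True)
model = VisionEncoderDecoderModel(encoder=SwinModel(encoder_config), decoder=BartForCausalLM(decoder_config))

pixel_values = torch.randn(1, 3, 32, 32)
outputs = model(pixel_values=pixel_values, decoder_input_ids=torch.tensor([[0, 1, 2]]))

# patch merging halves each spatial dimension once per stage after the first, so the encoder
# emits ((32 // 2) ** 2) // (4 ** (len(depths) - 1)) = 256 // 16 = 16 tokens of width 64
assert outputs.encoder_last_hidden_state.shape == (1, 16, 64)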
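
The attention-shape checks in the new Swin2BartModelTest lean on two Swin specifics: the encoder reports one attention tensor per stage (at least, that is what the test asserts via len(encoder_attentions) == config.num_hidden_layers, which the new attribute_map entry makes resolvable), and self-attention runs inside fixed non-overlapping windows, so every attention map is window_size**2 by window_size**2 with the window count folded into the batch dimension. Continuing the sketch above:

encoder_outputs = SwinModel(encoder_config).eval()(pixel_values, output_attentions=True)
# one attention tensor per stage: num_hidden_layers -> num_layers = len(depths) = 3
assert len(encoder_outputs.attentions) == encoder_config.num_hidden_layers == 3
# trailing dims: (num_heads of stage 0, window_size**2, window_size**2)
assert encoder_outputs.attentions[0].shape[-3:] == (2, 4, 4)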