In [1]:
from distutils.command.register import register

import timm


In [None]:
# NOTE(review): this decorator is not followed by a function definition, so the
# cell cannot run as written. `register_model` is also not imported in the
# visible cells -- presumably it is timm's model-registry decorator; confirm,
# import it, and place it directly above the model factory function it should
# register.
@register_model

In [5]:
# Collect the names of all timm models matching the "*swin*" wildcard.
# The result is deliberately NOT stored in a variable called `list`: that would
# shadow the builtin for every later cell run in the same kernel.
swin_model_names = timm.list_models("*swin*")
print(swin_model_names)

['eca_swinnext26ts_256', 'swin_base_patch4_window7_224', 'swin_base_patch4_window7_224_in22k', 'swin_base_patch4_window12_384', 'swin_base_patch4_window12_384_in22k', 'swin_large_patch4_window7_224', 'swin_large_patch4_window7_224_in22k', 'swin_large_patch4_window12_384', 'swin_large_patch4_window12_384_in22k', 'swin_small_patch4_window7_224', 'swin_tiny_patch4_window7_224', 'swinnet26t_256', 'swinnet50ts_256']


In [None]:
# save my model on timm
# NOTE(review): `model_entrypoint` is called with no arguments here. It
# presumably looks up the entrypoint of an already registered model by name
# rather than saving/registering one -- confirm against the timm registry docs
# and pass the intended model name.
timm.model_entrypoint()

In [None]:
class BertLayer(nn.Module):
    """Transformer layer: self-attention plus optional text/image co-attention.

    ``forward`` always runs ``self.attention``. If the layer was built with the
    two co-attention modules (see ``__init__``) and ``forward`` is called with
    ``mode`` equal to ``"multimodal"`` or ``"fusion"``, the self-attention
    output is additionally passed through ``coAttention_text`` and
    ``coAttention_image``. Each resulting stream then goes through
    ``self.feed_forward_chunk`` via ``apply_chunking_to_forward``.

    NOTE(review): ``feed_forward_chunk`` is not defined in the visible part of
    this class. It is presumably a method combining ``self.intermediate`` and
    ``self.output`` -- confirm it exists, otherwise ``forward`` cannot run.
    """

    def __init__(self, config, layer_num):
        """Build the sub-modules of the layer.

        Args:
            config: configuration object. Reads ``chunk_size_feed_forward`` and
                ``add_cross_attention``, plus ``fusion_layer`` when present
                (ALBEF / ALPRO style) or ``num_hidden_layers`` otherwise
                (BLIP style).
            layer_num: index of this layer; compared against ``fusion_layer``
                to decide whether co-attention modules are created.
        """
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        # compatibility for ALBEF and BLIP
        try:
            # ALBEF & ALPRO: co-attention only from `fusion_layer` onwards.
            fusion_layer = self.config.fusion_layer
            add_cross_attention = (
                fusion_layer <= layer_num and self.config.add_cross_attention
            )
            self.fusion_layer = fusion_layer
        except AttributeError:
            # BLIP: the config has no `fusion_layer` attribute.
            self.fusion_layer = self.config.num_hidden_layers
            add_cross_attention = self.config.add_cross_attention

        if add_cross_attention:
            # Two co-attention modules replace the single cross-attention module;
            # the flag value (1 = text, 2 = image) is forwarded to BertAttention.
            self.coAttention_text = BertAttention(config, is_co_attention=1)
            self.coAttention_image = BertAttention(config, is_co_attention=2)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=None,
            output_attentions=False,
            mode=None,
    ):
        """Run self-attention, optional co-attention, and the feed-forward step.

        Args:
            hidden_states: input passed to ``self.attention``.
            attention_mask: forwarded to every attention module.
            head_mask: forwarded to every attention module.
            encoder_hidden_states: encoder states for co-attention. Either a
                single object or a ``list``; for a list, the entry is chosen
                from ``layer_num - fusion_layer`` modulo the list length.
            encoder_attention_mask: mask(s) matching ``encoder_hidden_states``;
                indexed the same way when a list is given.
            past_key_value: optional cache; its first two entries are passed to
                the self-attention module.
            output_attentions: forwarded to every attention module.
            mode: co-attention runs only for ``"multimodal"`` or ``"fusion"``.

        Returns:
            ``(text_outputs, image_outputs)``. ``text_outputs`` is
            ``(text_layer_output, *extras, present_key_value)`` and
            ``image_outputs`` is ``(image_layer_output, *extras)``, where
            ``extras`` are the middle entries returned by the attention modules.
            Without co-attention both streams are derived from the
            self-attention output.

        Raises:
            AssertionError: co-attention is active but ``encoder_hidden_states``
                is ``None``.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            is_co_attention=0,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        # Defaults for layers/modes without co-attention. Previously these names
        # were only bound inside the co-attention branch, so any other call
        # failed with UnboundLocalError.
        text_attention_output = attention_output
        image_attention_output = attention_output
        text_outputs = outputs
        image_outputs = outputs

        # The modules created in __init__ are `coAttention_text` /
        # `coAttention_image`; the former check for a `crossattention`
        # attribute could never succeed.
        has_co_attention = hasattr(self, "coAttention_text") and hasattr(
            self, "coAttention_image"
        )
        if mode in ("multimodal", "fusion") and has_co_attention:
            assert (
                encoder_hidden_states is not None
            ), "encoder_hidden_states must be given for cross-attention layers"

            if isinstance(encoder_hidden_states, list):
                # Pick one encoder entry per stream; the text stream uses the
                # entry one position after the image stream.
                num_entries = len(encoder_hidden_states)
                layer_offset = self.layer_num - self.fusion_layer
                text_index = (layer_offset + 1) % num_entries
                image_index = layer_offset % num_entries
                text_encoder_states = encoder_hidden_states[text_index]
                text_encoder_mask = encoder_attention_mask[text_index]
                image_encoder_states = encoder_hidden_states[image_index]
                image_encoder_mask = encoder_attention_mask[image_index]
            else:
                text_encoder_states = image_encoder_states = encoder_hidden_states
                text_encoder_mask = image_encoder_mask = encoder_attention_mask

            co_attention_text_outputs = self.coAttention_text(
                attention_output,
                attention_mask,
                head_mask,
                text_encoder_states,
                text_encoder_mask,
                output_attentions=output_attentions,
                is_co_attention=1,
            )
            text_attention_output = co_attention_text_outputs[0]
            text_outputs = outputs + co_attention_text_outputs[1:-1]

            co_attention_image_outputs = self.coAttention_image(
                attention_output,
                attention_mask,
                head_mask,
                image_encoder_states,
                image_encoder_mask,
                output_attentions=output_attentions,
                is_co_attention=2,
            )
            image_attention_output = co_attention_image_outputs[0]
            image_outputs = outputs + co_attention_image_outputs[1:-1]

        # Each stream is fed from its own attention output (the two inputs were
        # previously crossed between the text and image streams).
        text_layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            text_attention_output,
        )
        image_layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            image_attention_output,
        )

        text_outputs = (text_layer_output,) + text_outputs
        image_outputs = (image_layer_output,) + image_outputs

        # Only the decoder returns present_key_value, so it is appended to the
        # text stream only.
        text_outputs = text_outputs + (present_key_value,)

        return text_outputs, image_outputs