In [3]:
import os
from transformers import (
    SegformerForSemanticSegmentation,
    SegformerConfig,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForSemanticSegmentation,
    SegformerLayer,
)

from transformers.models.segformer.modeling_segformer import (
    SegformerAttention,
    SegformerDropPath,
    SegformerMixFFN,
)

from pascal_utils import PascalTrainer, get_dataset
import torch

In [None]:
class FC_SegformerLayer(SegformerLayer):
    """A ``SegformerLayer`` subclass that constructs its sub-modules explicitly.

    The block is pre-norm: LayerNorm -> attention -> residual add, followed by
    LayerNorm -> Mix-FFN -> residual add. A single drop-path module is shared
    by both residual branches.

    Args:
        config: Model configuration, passed through to ``SegformerAttention``
            and ``SegformerMixFFN``.
        hidden_size: Feature size used for both LayerNorms, the attention
            module and the input features of the MLP.
        num_attention_heads: Forwarded to ``SegformerAttention``.
        drop_path: Drop-path rate. Values ``> 0.0`` enable ``SegformerDropPath``;
            otherwise an ``Identity`` is used.
        sequence_reduction_ratio: Forwarded to ``SegformerAttention``.
        mlp_ratio: Multiplier applied to ``hidden_size`` (truncated with
            ``int``) to obtain the MLP hidden feature size.
    """

    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
        # The parent constructor takes the same arguments as this one; calling
        # it without them raises ``TypeError``. Forward everything so the
        # module is initialised properly before the sub-modules are
        # (re)assigned below.
        super().__init__(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            drop_path=drop_path,
            sequence_reduction_ratio=sequence_reduction_ratio,
            mlp_ratio=mlp_ratio,
        )
        # The assignments below replace whatever the parent registered under
        # the same attribute names, making this the place to customise them.
        self.layer_norm_1 = torch.nn.LayerNorm(hidden_size)
        self.attention = SegformerAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        # Stochastic depth is only instantiated when it would actually drop something.
        self.drop_path = SegformerDropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
        self.layer_norm_2 = torch.nn.LayerNorm(hidden_size)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

    def forward(self, hidden_states, height, width, output_attentions=False):
        """Run the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input features; normalised before each sub-block.
            height: Passed to the attention and MLP modules together with
                ``width`` (presumably the spatial size of the token grid —
                confirm against the caller).
            width: See ``height``.
            output_attentions: Forwarded to the attention module.

        Returns:
            A tuple whose first element is the layer output, followed by any
            extra elements the attention module returned after its first one.
        """
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
            height,
            width,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection (with stochastic depth)
        attention_output = self.drop_path(attention_output)
        hidden_states = attention_output + hidden_states

        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        # second residual connection (with stochastic depth)
        mlp_output = self.drop_path(mlp_output)
        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

In [None]:
# The model constructor requires a configuration object; calling it with no
# arguments raises TypeError. Use the library's default SegformerConfig here.
# No pretrained weights are loaded by this cell.
segformer_config = SegformerConfig()
segformer_for_semantic_segmentation = SegformerForSemanticSegmentation(segformer_config)

# Pull out the backbone and its encoder for inspection / modification.
segformer_model = segformer_for_semantic_segmentation.segformer
segformer_encoder = segformer_model.encoder