In [1]:
# !pip install -U openmim

In [2]:
# !mim install mmcv-full==1.6.0

In [3]:
# !pip install mmsegmentation --force-reinstall

In [4]:
# !pip install transformers

In [5]:
# !pip install datasets

In [28]:
from transformers import SegformerFeatureExtractor, SegformerModel
import torch
import torch.nn as nn
from datasets import load_dataset

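# Build the model: wrap the pretrained SegFormer (MiT-B5) encoder so it returns its per-stage feature maps, then add a lightweight decoder head on top.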

In [32]:
class CustomSegFormerBase(nn.Module):
    """Pretrained SegFormer (MiT) encoder that returns only its per-stage feature maps.

    Loads ``SegformerModel`` with ``output_hidden_states=True`` and, on each call,
    returns the model output's ``hidden_states`` so a downstream decoder can fuse
    the multi-scale features.

    Args:
        pretrained_name: checkpoint id passed to ``SegformerModel.from_pretrained``.
            Defaults to ``"nvidia/mit-b5"`` (the checkpoint this notebook was built on).
    """

    def __init__(self, pretrained_name="nvidia/mit-b5"):
        super().__init__()
        self.model = SegformerModel.from_pretrained(pretrained_name, output_hidden_states=True)

    def forward(self, x, **inputs):
        """Run the encoder on ``x`` and return its ``hidden_states`` tuple.

        Args:
            x: pixel tensor, forwarded positionally to the wrapped model
                (presumably (batch, 3, H, W) — the probe cells use that layout).
            **inputs: extra keyword arguments forwarded unchanged to the model.

        Returns:
            The model's ``hidden_states``. For mit-b5 the recorded probe shows four
            maps with 64/128/320/512 channels (64, 32, 16, 8 px for a 255x255 input).
        """
        return self.model(x, **inputs).hidden_states

In [33]:
# Instantiate the pretrained encoder wrapper (nvidia/mit-b5 via from_pretrained).
# The "classifier.* weights were not used" warning printed below is expected: only the encoder is loaded.
bb = CustomSegFormerBase()

Some weights of the model checkpoint at nvidia/mit-b5 were not used when initializing SegformerModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [38]:
# Shape probe only: no_grad avoids building an autograd graph through the whole backbone.
with torch.no_grad():
    out = bb(torch.randn(1, 3, 255, 255))

In [39]:
# Print the shape of every feature map the backbone returned, one per stage.
for el in out:
    print(el.shape)

torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 320, 16, 16])
torch.Size([1, 512, 8, 8])


In [None]:
class CustomSegFormerPretrain(nn.Module):
    """SegFormer-style decoder with a 3x3 fusion head on top of a multi-stage backbone.

    Every backbone stage is projected to ``decoder_dim`` channels by a 1x1 conv,
    resized to the spatial size of the first (highest-resolution) stage, and the
    results are concatenated along channels. Two 3x3 convs then map the fused
    tensor to ``num_classes`` logit maps at that same resolution.

    Args:
        hidden_sizes: channel count of each backbone stage, in the order the
            backbone returns them (``[64, 128, 320, 512]`` for mit-b5). Must be non-empty.
        decoder_dim: common channel width after projection. Default 512.
        num_classes: number of output channels. Default 1.
        backbone: module whose call returns an iterable of feature maps, one per
            entry of ``hidden_sizes``. Defaults to a new ``CustomSegFormerBase()``.

    Raises:
        ValueError: if ``hidden_sizes`` is empty.
    """

    def __init__(self, hidden_sizes, decoder_dim=512, num_classes=1, backbone=None):
        super().__init__()
        if len(hidden_sizes) == 0:
            raise ValueError("hidden_sizes must contain at least one stage")
        self.backbone = backbone if backbone is not None else CustomSegFormerBase()
        self.hidden_sizes = hidden_sizes
        self.decoder_dim = decoder_dim
        self.num_classes = num_classes
        num_stages = len(hidden_sizes)

        # One 1x1 projection per stage to the common decoder width.
        self.linears = nn.ModuleList(
            nn.Conv2d(in_channels=hs, out_channels=self.decoder_dim, kernel_size=(1, 1))
            for hs in self.hidden_sizes
        )
        # Stage i is expected to be 2**i times coarser than stage 0; forward() falls
        # back to size-based resizing when that does not hold exactly.
        self.ups = nn.ModuleList(
            nn.Identity() if i == 0 else nn.Upsample(scale_factor=2 ** i)
            for i in range(num_stages)
        )

        # padding=1 keeps the 3x3 convs from shrinking the map (without it each conv
        # trims 2 px, so a 64x64 fused map came out 60x60).
        # NOTE(review): there is no activation/normalisation between the two convs, so
        # together they are a single linear map. Inserting one would shift the second
        # conv's state_dict key, so the structure is left unchanged here.
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(num_stages * decoder_dim, decoder_dim, 3, padding=1),
            nn.Conv2d(decoder_dim, num_classes, 3, padding=1),
        )

    def forward(self, x, **kwargs):
        """Decode ``x`` into per-class logit maps.

        Args:
            x: input forwarded to the backbone together with ``**kwargs``.

        Returns:
            Tensor of shape (batch, num_classes, H0, W0), where (H0, W0) is the
            spatial size of the first backbone feature map. Feature maps are
            assumed to be laid out as (batch, C, H, W).

        Raises:
            ValueError: if the backbone returns a different number of feature
                maps than ``hidden_sizes`` has entries.
        """
        features = list(self.backbone(x, **kwargs))
        if len(features) != len(self.linears):
            raise ValueError(
                f"backbone returned {len(features)} feature maps but hidden_sizes "
                f"has {len(self.linears)} entries"
            )

        target_size = features[0].shape[-2:]
        aligned = []
        for feat, proj, up in zip(features, self.linears, self.ups):
            projected = proj(feat)
            resized = up(projected)
            if resized.shape[-2:] != target_size:
                # Fixed scale factors only line up when every stage exactly halves the
                # previous one; otherwise resize straight onto the stage-0 grid.
                resized = nn.functional.interpolate(projected, size=target_size, mode="nearest")
            aligned.append(resized)

        fused = torch.cat(aligned, dim=1)
        return self.to_segmentation(fused)

In [63]:
class CustomSegFormerSegmentation(nn.Module):
    """SegFormer-style decoder with a 1x1 (per-pixel MLP) fusion head.

    Same decoder as ``CustomSegFormerPretrain`` except that the fusion head uses
    1x1 convs. Every backbone stage is projected to ``decoder_dim`` channels,
    resized to the spatial size of the first (highest-resolution) stage, and the
    results are concatenated and mapped to ``num_classes`` logit maps.

    Args:
        hidden_sizes: channel count of each backbone stage, in the order the
            backbone returns them (``[64, 128, 320, 512]`` for mit-b5). Must be non-empty.
        decoder_dim: common channel width after projection. Default 512.
        num_classes: number of output channels. Default 1.
        backbone: module whose call returns an iterable of feature maps, one per
            entry of ``hidden_sizes``. Defaults to a new ``CustomSegFormerBase()``.

    Raises:
        ValueError: if ``hidden_sizes`` is empty.
    """

    def __init__(self, hidden_sizes, decoder_dim=512, num_classes=1, backbone=None):
        super().__init__()
        if len(hidden_sizes) == 0:
            raise ValueError("hidden_sizes must contain at least one stage")
        self.backbone = backbone if backbone is not None else CustomSegFormerBase()
        self.hidden_sizes = hidden_sizes
        self.decoder_dim = decoder_dim
        self.num_classes = num_classes
        num_stages = len(hidden_sizes)

        # One 1x1 projection per stage to the common decoder width.
        self.linears = nn.ModuleList(
            nn.Conv2d(in_channels=hs, out_channels=self.decoder_dim, kernel_size=(1, 1))
            for hs in self.hidden_sizes
        )
        # Stage i is expected to be 2**i times coarser than stage 0; forward() falls
        # back to size-based resizing when that does not hold exactly.
        self.ups = nn.ModuleList(
            nn.Identity() if i == 0 else nn.Upsample(scale_factor=2 ** i)
            for i in range(num_stages)
        )

        # NOTE(review): no activation/normalisation between the two 1x1 convs, so the
        # head is a single linear per-pixel map. Inserting one would shift the second
        # conv's state_dict key, so the structure is left unchanged here.
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(num_stages * decoder_dim, decoder_dim, 1),
            nn.Conv2d(decoder_dim, num_classes, 1),
        )

    def forward(self, x, **kwargs):
        """Decode ``x`` into per-class logit maps.

        Args:
            x: input forwarded to the backbone together with ``**kwargs``.

        Returns:
            Tensor of shape (batch, num_classes, H0, W0), where (H0, W0) is the
            spatial size of the first backbone feature map. Feature maps are
            assumed to be laid out as (batch, C, H, W).

        Raises:
            ValueError: if the backbone returns a different number of feature
                maps than ``hidden_sizes`` has entries.
        """
        features = list(self.backbone(x, **kwargs))
        if len(features) != len(self.linears):
            raise ValueError(
                f"backbone returned {len(features)} feature maps but hidden_sizes "
                f"has {len(self.linears)} entries"
            )

        target_size = features[0].shape[-2:]
        aligned = []
        for feat, proj, up in zip(features, self.linears, self.ups):
            projected = proj(feat)
            resized = up(projected)
            if resized.shape[-2:] != target_size:
                # Fixed scale factors only line up when every stage exactly halves the
                # previous one; otherwise resize straight onto the stage-0 grid.
                resized = nn.functional.interpolate(projected, size=target_size, mode="nearest")
            aligned.append(resized)

        fused = torch.cat(aligned, dim=1)
        return self.to_segmentation(fused)

In [66]:
# hidden_sizes must list the backbone's per-stage channel counts in order (64/128/320/512 in the probe output above).
# No backbone is passed, so this builds its own CustomSegFormerBase — a second copy of the mit-b5 weights alongside `bb`
# (hence the repeated checkpoint warning below).
ftuner = CustomSegFormerPretrain(hidden_sizes=[64, 128, 320, 512], decoder_dim=512)

Some weights of the model checkpoint at nvidia/mit-b5 were not used when initializing SegformerModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [68]:
# Smoke test of the full decoder. Use a named result instead of `_`, which would
# shadow IPython's last-output variable. no_grad: this is a shape check only.
with torch.no_grad():
    seg_logits = ftuner(torch.randn(1, 3, 256, 256))
seg_logits.shape

???? torch.Size([1, 64, 64, 64])
???? torch.Size([1, 128, 32, 32])
???? torch.Size([1, 320, 16, 16])
???? torch.Size([1, 512, 8, 8])
torch.Size([1, 1, 64, 64])
