In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.nn import functional as F
import timm


# Build timm's tf_efficientnet_b7 with pretrained weights; the next cell inspects its stages.
model = timm.create_model(model_name='tf_efficientnet_b7', pretrained=True)

## EfficientNet 모델의 상세 구조 파악 코드

In [3]:
# Expose each top-level stage of the backbone under its own name for inspection.
conv_stem, bn1, act1 = model.conv_stem, model.bn1, model.act1
blocks0, blocks1, blocks2, blocks3, blocks4, blocks5, blocks6 = (model.blocks[i] for i in range(7))
conv_head, bn2, act2 = model.conv_head, model.bn2, model.act2
global_pool, classifier = model.global_pool, model.classifier

## EfficientNet + DeepLabV3 모델 구현 Ver1

In [18]:
class DeepLabV3EffiB7Timm(nn.Module):
    """DeepLabV3-style segmentation model on a timm ``tf_efficientnet_b7`` backbone (Ver1).

    The backbone's final feature map (after ``conv_head`` -> ``bn2`` -> ``act2``) goes
    through ``DeepLabHead`` to produce ``n_classes`` channels. In parallel, the output
    of ``blocks3`` is projected to ``n_classes`` channels by a stride-2 1x1 conv
    (``pool_cnn``). The two are concatenated, fused by a 3x3 conv (``classifier``),
    and bilinearly resized to the input's spatial size.

    Args:
        n_classes: number of output channels (classes).
        n_blocks: not used by this model; accepted only for call compatibility.
        atrous_rates: four dilation rates for the ASPP branches.
        pretrained: passed to ``timm.create_model`` to load pretrained backbone weights.
    """

    def __init__(self, n_classes=12, n_blocks=(3, 4, 23, 3), atrous_rates=(1, 6, 12, 18), pretrained=True):
        super().__init__()
        model = timm.create_model('tf_efficientnet_b7', pretrained=pretrained)

        # Backbone layers are registered individually so forward() can tap blocks3.
        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1
        self.act1 = model.act1
        self.blocks0 = model.blocks[0]
        self.blocks1 = model.blocks[1]
        self.blocks2 = model.blocks[2]
        self.blocks3 = model.blocks[3]
        self.blocks4 = model.blocks[4]
        self.blocks5 = model.blocks[5]
        self.blocks6 = model.blocks[6]
        self.conv_head = model.conv_head
        self.bn2 = model.bn2
        self.act2 = model.act2
        # Not used by forward(); kept so the module layout stays unchanged.
        self.global_pool = model.global_pool

        # Projects the 160-channel blocks3 output to n_classes channels. stride=2 is
        # presumably there to match the head's spatial size for the concat in
        # forward() — confirm for input sizes other than the notebook's 512x512.
        # out_channels must be n_classes because `classifier` expects n_classes * 2
        # input channels. It was hard-coded to 12, which broke any n_classes != 12.
        self.pool_cnn = nn.Conv2d(in_channels=160, out_channels=n_classes, kernel_size=1, stride=2, padding=0)

        # 2560 = channels coming out of conv_head/bn2/act2.
        self.deep_head = DeepLabHead(in_ch=2560, out_ch=256, n_classes=n_classes, atrous_rates=atrous_rates)

        # Fuses [projected blocks3 features, head output] = n_classes + n_classes channels.
        self.classifier = nn.Conv2d(in_channels=n_classes * 2, out_channels=n_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return class logits with ``n_classes`` channels, resized to ``x.shape[2:]``."""
        restore_size = x.shape[2:]
        x = self.act1(self.bn1(self.conv_stem(x)))
        x = self.blocks0(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        skip = self.blocks3(x)  # intermediate features reused after the head
        x = self.blocks4(skip)
        x = self.blocks5(x)
        x = self.blocks6(x)
        x = self.act2(self.bn2(self.conv_head(x)))
        x = self.deep_head(x)
        skip = self.pool_cnn(skip)
        output = self.classifier(torch.cat([skip, x], dim=1))
        return F.interpolate(output, size=restore_size, mode="bilinear", align_corners=False)


    
class DeepLabHead(nn.Sequential):
    """Complete DeepLabV3 head: ASPP -> 3x3 conv (no bias) -> BN -> ReLU -> 1x1 class conv.

    Children are registered positionally, so their keys are "0".."4" as in a plain
    ``nn.Sequential``.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: intermediate channel width used by ASPP and the 3x3 conv.
        n_classes: channels produced by the final 1x1 conv.
        atrous_rates: dilation rates forwarded to ``ASPP``.
    """
    def __init__(self, in_ch, out_ch, n_classes, atrous_rates= [1, 6, 12, 18]):
        super().__init__(
            ASPP(in_ch, out_ch, atrous_rates),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, n_classes, kernel_size=1, stride=1),
        )
        
    
class ASPPConv(nn.Module):
    """One atrous-convolution branch of ASPP: Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Args:
        inplanes: input channel count.
        outplanes: output channel count.
        kernel_size, padding, dilation: forwarded to ``nn.Conv2d`` (stride is fixed at 1).
    """
    def __init__(self, inplanes, outplanes, kernel_size, padding, dilation):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
        self.atrous_conv = nn.Conv2d(inplanes, outplanes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))
    

class ASPPPooling(nn.Module):
    """Image-level branch of ASPP: global average pool to 1x1, then 1x1 conv -> BN -> ReLU.

    The returned tensor has spatial size 1x1.
    NOTE(review): BatchNorm2d on a 1x1 map with batch size 1 in training mode will
    likely raise in PyTorch — confirm the training batch size is > 1.
    """
    def __init__(self, inplanes, outplanes):
        super().__init__()
        self.globalavgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        pooled = self.globalavgpool(x)
        return self.relu(self.bn(self.conv(pooled)))


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: parallel atrous convs plus an image-level branch.

    Four ``ASPPConv`` branches run on the input: one 1x1 and three 3x3 with
    padding == dilation, using ``atrous_rates[0]`` .. ``atrous_rates[3]``. An
    ``ASPPPooling`` branch also runs, and its 1x1 output is bilinearly upsampled
    to the input's spatial size. The five ``outplanes``-channel maps are
    concatenated and projected back to ``outplanes`` channels by
    1x1 conv -> BN -> ReLU -> Dropout(0.5).
    """
    def __init__(self, inplanes, outplanes, atrous_rates= [1, 6, 12, 18]):
        super().__init__()
        # Branch 1: pointwise conv; padding 0 keeps H and W unchanged.
        self.aspp1 = ASPPConv(inplanes, outplanes, 1, padding=0, dilation=atrous_rates[0])
        # Branches 2-4: 3x3 dilated convs; padding == dilation keeps H and W unchanged.
        for branch_no in (2, 3, 4):
            rate = atrous_rates[branch_no - 1]
            setattr(self, f"aspp{branch_no}", ASPPConv(inplanes, outplanes, 3, padding=rate, dilation=rate))
        self.global_avg_pool = ASPPPooling(inplanes, outplanes)
        self.project = nn.Sequential(
            nn.Conv2d(outplanes * 5, outplanes, 1, bias=False),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        spatial_size = x.size()[2:]
        features = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        image_level = self.global_avg_pool(x)
        features.append(F.interpolate(image_level, size=spatial_size, mode='bilinear', align_corners=True))
        return self.project(torch.cat(features, dim=1))

In [1]:
# Smoke test for Ver1: build the model and check the output shape on a random batch.
# atrous_rates differ from the default (1, 6, 12, 18) for comparison; n_blocks is unused by the model.
model = DeepLabV3EffiB7Timm(n_classes=12, n_blocks=[3, 4, 23, 3], atrous_rates=[1, 12, 24, 36])
x = torch.randn([2, 3, 512, 512])
print(f"input shape : {x.shape}")
# eval() stops the BatchNorm layers, including the pretrained backbone's, from updating
# their running statistics on random noise, and it disables Dropout. no_grad() avoids
# keeping activations for backprop, which is costly for EfficientNet-B7 at 512x512.
# Call model.train() before using this instance for training.
model.eval()
with torch.no_grad():
    out = model(x)
print(f"output shape :  {out.size()}")
# Logits are resized back to the input's H x W and have n_classes channels.
assert out.shape == (2, 12, 512, 512), out.shape

## EfficientNet + DeepLabV3 모델 구현 Ver2

In [8]:
class DeepLabV3EffiB7Timm(nn.Module):
    """DeepLabV3 segmentation model on a timm ``tf_efficientnet_b7`` backbone (Ver2).

    Features come from the backbone's ``forward_features``. ``DeepLabHead`` turns
    them into ``n_classes`` channels, which are bilinearly resized to the input's
    spatial size.

    Args:
        n_classes: number of output channels (classes).
        n_blocks: not used by this model; accepted only for call compatibility.
        atrous_rates: four dilation rates for the ASPP branches.
        pretrained: passed to ``timm.create_model`` to load pretrained backbone weights.
    """

    def __init__(self, n_classes, n_blocks, atrous_rates=(1, 6, 12, 18), pretrained=True):
        super().__init__()
        self.pretrained_model = timm.create_model('tf_efficientnet_b7', pretrained=pretrained)
        # NOTE(review): the backbone keeps timm's own `classifier`, which forward() presumably
        # never reaches. Those parameters stay unused, which matters e.g. for DDP's
        # find_unused_parameters.

        # Bound method of the backbone (not a submodule). It presumably runs the same
        # stem -> blocks -> conv_head/bn2/act2 chain as Ver1 and yields 2560 channels,
        # matching in_ch below. Confirm against timm.
        self.extract_features = self.pretrained_model.forward_features
        # Use the caller's n_classes. It was previously hard-coded to 12, so the argument was ignored.
        self.classifier = DeepLabHead(in_ch=2560, out_ch=256, n_classes=n_classes, atrous_rates=atrous_rates)

    def forward(self, x):
        """Return class logits with ``n_classes`` channels, resized to ``x.shape[2:]``."""
        features = self.extract_features(x)
        logits = self.classifier(features)
        return F.interpolate(logits, size=x.shape[2:], mode="bilinear", align_corners=False)


    
class DeepLabHead(nn.Sequential):
    """Complete DeepLabV3 head: ASPP -> 3x3 conv (no bias) -> BN -> ReLU -> 1x1 class conv.

    Stages are registered under the keys "0".."4" in order.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: intermediate channel width used by ASPP and the 3x3 conv.
        n_classes: channels produced by the final 1x1 conv.
        atrous_rates: dilation rates forwarded to ``ASPP``.
    """
    def __init__(self, in_ch, out_ch, n_classes, atrous_rates= [1, 6, 12, 18]):
        super().__init__()
        stages = [
            ASPP(in_ch, out_ch, atrous_rates),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, n_classes, kernel_size=1, stride=1),
        ]
        for position, stage in enumerate(stages):
            self.add_module(str(position), stage)
        
    
class ASPPConv(nn.Module):
    """Single ASPP branch: bias-free Conv2d (stride 1) followed by BatchNorm2d and ReLU.

    ``kernel_size``, ``padding`` and ``dilation`` are passed straight to ``nn.Conv2d``.
    """
    def __init__(self, inplanes, outplanes, kernel_size, padding, dilation):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            outplanes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        for layer in (self.atrous_conv, self.bn, self.relu):
            out = layer(out)
        return out
    

class ASPPPooling(nn.Module):
    """Image-level ASPP branch: adaptive average pool to 1x1, then 1x1 conv -> BN -> ReLU.

    Produces a map with spatial size 1x1.
    NOTE(review): BatchNorm2d over a 1x1 map with batch size 1 in training mode will
    likely raise in PyTorch — confirm the training batch size is > 1.
    """
    def __init__(self, inplanes, outplanes):
        super().__init__()
        self.globalavgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(outplanes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        for layer in (self.globalavgpool, self.conv, self.bn, self.relu):
            out = layer(out)
        return out


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling module.

    The module has four ``ASPPConv`` branches:
    - a 1x1 branch with dilation ``atrous_rates[0]``;
    - three 3x3 branches with padding == dilation == ``atrous_rates[1..3]``.

    An ``ASPPPooling`` branch is bilinearly upsampled back to the input's H x W.
    All five ``outplanes``-channel maps are concatenated along channels and reduced
    to ``outplanes`` by 1x1 conv -> BN -> ReLU -> Dropout(0.5).
    """
    def __init__(self, inplanes, outplanes, atrous_rates= [1, 6, 12, 18]):
        super().__init__()
        self.aspp1 = ASPPConv(inplanes, outplanes, kernel_size=1, padding=0, dilation=atrous_rates[0])
        self.aspp2 = ASPPConv(inplanes, outplanes, kernel_size=3, padding=atrous_rates[1], dilation=atrous_rates[1])
        self.aspp3 = ASPPConv(inplanes, outplanes, kernel_size=3, padding=atrous_rates[2], dilation=atrous_rates[2])
        self.aspp4 = ASPPConv(inplanes, outplanes, kernel_size=3, padding=atrous_rates[3], dilation=atrous_rates[3])
        self.global_avg_pool = ASPPPooling(inplanes, outplanes)
        reduce_conv = nn.Conv2d(outplanes * 5, outplanes, 1, bias=False)
        self.project = nn.Sequential(reduce_conv, nn.BatchNorm2d(outplanes), nn.ReLU(), nn.Dropout(0.5))

    def forward(self, x):
        height_width = x.size()[2:]
        pyramid = tuple(branch(x) for branch in (self.aspp1, self.aspp2, self.aspp3, self.aspp4))
        pooled = F.interpolate(self.global_avg_pool(x), size=height_width, mode='bilinear', align_corners=True)
        merged = torch.cat(pyramid + (pooled,), dim=1)
        return self.project(merged)

In [2]:
# Smoke test for Ver2: build the model and check the output shape on a random batch.
# atrous_rates differ from the default (1, 6, 12, 18) for comparison; n_blocks is unused by the model.
model = DeepLabV3EffiB7Timm(n_classes=12, n_blocks=[3, 4, 23, 3], atrous_rates=[1, 12, 24, 36])
x = torch.randn([2, 3, 512, 512])
print(f"input shape : {x.shape}")
# eval() stops the BatchNorm layers, including the pretrained backbone's, from updating
# their running statistics on random noise, and it disables Dropout. no_grad() avoids
# keeping activations for backprop, which is costly for EfficientNet-B7 at 512x512.
# Call model.train() before using this instance for training.
model.eval()
with torch.no_grad():
    out = model(x)
print(f"output shape :  {out.size()}")
# Logits are resized back to the input's H x W and have n_classes channels.
assert out.shape == (2, 12, 512, 512), out.shape