In [1]:
import torch
import torch.nn as nn

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cpu


In [2]:
class CBR(nn.Module):
    """Convolution -> BatchNorm -> PReLU block.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        k_size: side length of the square convolution kernel.
        stride: convolution stride (default 1).
    """

    def __init__(self, n_in, n_out, k_size, stride=1):
        super().__init__()
        # Bias-free convolution padded by (k_size - 1) // 2 on every side.
        self.conv = nn.Conv2d(
            n_in,
            n_out,
            kernel_size=k_size,
            stride=stride,
            padding=(k_size - 1) // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(n_out)
        # PReLU with one learnable slope per output channel.
        self.act = nn.PReLU(n_out)

    def forward(self, x):
        """Return ``act(bn(conv(x)))``."""
        return self.act(self.bn(self.conv(x)))


In [3]:
class BR(nn.Module):
    """BatchNorm -> PReLU block (no convolution).

    Args:
        n_out: number of channels of the tensor being normalised.
    """

    def __init__(self, n_out):
        super().__init__()
        self.bn = nn.BatchNorm2d(n_out)
        # PReLU with one learnable slope per channel.
        self.act = nn.PReLU(n_out)

    def forward(self, x):
        """Return ``act(bn(x))``."""
        normalised = self.bn(x)
        return self.act(normalised)

In [4]:
class C(nn.Module):
    """Plain bias-free 2-D convolution (no normalisation, no activation).

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        k_size: side length of the square convolution kernel.
        stride: convolution stride (default 1).
    """

    def __init__(self, n_in, n_out, k_size, stride=1):
        super().__init__()
        # Padding of (k_size - 1) // 2 on every side.
        self.conv = nn.Conv2d(
            n_in,
            n_out,
            kernel_size=k_size,
            stride=stride,
            padding=(k_size - 1) // 2,
            bias=False,
        )

    def forward(self, x):
        """Return the convolution of ``x``."""
        return self.conv(x)


In [5]:
class InputProjection(nn.Module):
    """Downsample a tensor by applying ``n_sampling`` average-pooling stages.

    Args:
        n_sampling: number of pooling stages to chain; 0 leaves the input as is.
    """

    def __init__(self, n_sampling):
        super().__init__()
        # Each stage is a 3x3 average pool with stride 2 and padding 1,
        # i.e. one stage maps W, H -> W/2, H/2 for even sizes.
        self.pool = nn.ModuleList(
            [nn.AvgPool2d(3, stride=2, padding=1) for _ in range(n_sampling)]
        )

    def forward(self, x):
        """Run ``x`` through every pooling stage in order and return the result."""
        for stage in self.pool:
            x = stage(x)
        return x

In [6]:
class Split_K(nn.Module):
    """Bias-free dilated 2-D convolution with stride 1.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        k_size: side length of the square convolution kernel.
        dilation: spacing between kernel elements (default 1).
    """

    def __init__(self, n_in, n_out, k_size, dilation=1):
        super().__init__()
        # The base padding (k_size - 1) // 2 is scaled by the dilation rate.
        half_kernel = (k_size - 1) // 2
        self.dilated_conv = nn.Conv2d(
            n_in,
            n_out,
            kernel_size=k_size,
            stride=1,
            padding=half_kernel * dilation,
            dilation=dilation,
            bias=False,
        )

    def forward(self, x):
        """Return the dilated convolution of ``x``."""
        return self.dilated_conv(x)


In [7]:
class ESP_Module(nn.Module):
    """ESP block: reduce -> split & transform -> hierarchical fusion -> merge.

    Summary of the channel flow:
        1. Reduce: M -> d = N // K channels (1x1 conv, or a 3x3 stride-2 conv
           when ``downsample`` is set).
        2. Split & transform: five parallel dilated convs (dilation 1, 2, 4,
           8, 16) applied to the reduced map.
        3. HFF & merge: branches 2..5 are summed cumulatively, then all five
           are concatenated into N channels.

    Args:
        M: number of input feature channels.
        N: number of output feature channels.
        K: divisor used to derive the reduced width d = N // K. The number of
            branches is always five, regardless of K.
        n: kernel size of the dilated convolutions.
        add: if True, the block input is added to the merged map before
            BatchNorm/PReLU, so the two tensors must be shape-compatible.
        downsample: if True, the reduce step is a 3x3 stride-2 convolution.
    """

    def __init__(self, M, N, K=5, n=3, add=True, downsample=False):
        super().__init__()

        branch_ch = N // K
        # N may not divide evenly by K; the first branch takes whatever is
        # left so that the five branches together produce exactly N channels.
        first_ch = N - 4 * branch_ch

        # Reduce
        if downsample:
            reduce_kernel, reduce_stride = 3, 2
        else:
            reduce_kernel, reduce_stride = 1, 1
        self.reduce = C(n_in=M, n_out=branch_ch, k_size=reduce_kernel, stride=reduce_stride)

        # Split & Transform: dilation rates 1, 2, 4, 8, 16
        self.split_1 = Split_K(n_in=branch_ch, n_out=first_ch, k_size=n, dilation=1)
        self.split_2 = Split_K(n_in=branch_ch, n_out=branch_ch, k_size=n, dilation=2)
        self.split_3 = Split_K(n_in=branch_ch, n_out=branch_ch, k_size=n, dilation=4)
        self.split_4 = Split_K(n_in=branch_ch, n_out=branch_ch, k_size=n, dilation=8)
        self.split_5 = Split_K(n_in=branch_ch, n_out=branch_ch, k_size=n, dilation=16)

        self.bn = nn.BatchNorm2d(N)
        self.act = nn.PReLU(N)

        self.add = add
        self.downsample = downsample

    def forward(self, identity):
        """Apply the ESP block to ``identity`` and return the activated output."""
        # Reduce
        reduced = self.reduce(identity)

        # Split & Transform + HFF (hierarchical feature fusion):
        # the first branch is kept as is; every later branch is added to the
        # running sum of the branches before it (starting from branch 2).
        pieces = [self.split_1(reduced)]
        running = None
        for branch in (self.split_2, self.split_3, self.split_4, self.split_5):
            current = branch(reduced)
            running = current if running is None else running + current
            pieces.append(running)

        # Merge along the channel axis
        merged = torch.cat(pieces, 1)

        if self.add:
            merged = identity + merged

        return self.act(self.bn(merged))

In [8]:
class ESPNet_C(nn.Module):
    """ESPNet-C encoder: produces per-class score maps at 1/8 input resolution.

    Args:
        classes: number of output channels of the final 1x1 convolution.
        a2: number of ESP blocks stacked at the 1/4-resolution level.
        a3: number of ESP blocks stacked at the 1/8-resolution level.
            Both ``a2`` and ``a3`` may be 0, in which case the corresponding
            downsampling ESP output is used directly.
    """

    def __init__(self, classes=20, a2=5, a3=8):
        super().__init__()

        # Conv-3: 3 -> 16 channels, stride 2 (1/2 resolution)
        self.conv_3 = CBR(n_in=3, n_out=16, k_size=3, stride=2)

        # Concat with the input image downsampled once (1/2 resolution)
        self.input_downsample_1 = InputProjection(1)
        self.br1 = BR(19)  # 19 = 16 + 3

        # ESP (downsampling): 19 -> 64 channels, 1/4 resolution
        self.esp_downsample_1 = ESP_Module(19, 64, add=False, downsample=True)

        # ESP x a2
        self.esp_a2 = nn.ModuleList()
        for _ in range(a2):
            self.esp_a2.append(ESP_Module(64, 64))

        # Concat with the input image downsampled twice (1/4 resolution)
        self.input_downsample_2 = InputProjection(2)
        self.br2 = BR(131)  # 131 = 64 * 2 + 3

        # ESP (downsampling): 131 -> 128 channels, 1/8 resolution
        self.esp_downsample_2 = ESP_Module(131, 128, add=False, downsample=True)

        # ESP x a3
        self.esp_a3 = nn.ModuleList()
        for _ in range(a3):
            self.esp_a3.append(ESP_Module(128, 128))

        # Concat
        self.br3 = BR(256)  # 256 = 128 * 2

        # Conv-1: 256 -> classes
        self.conv1 = C(256, classes, k_size=1, stride=1)

    def forward(self, img):
        """Encode ``img`` (3-channel image batch) into ``classes`` score maps."""
        # Conv-3
        x_0 = self.conv_3(img)

        # Concat
        img_downsample_1 = self.input_downsample_1(img)
        concat_1 = self.br1(torch.cat([x_0, img_downsample_1], 1))

        # ESP
        x_1_0 = self.esp_downsample_1(concat_1)

        # ESP x a2: start from the downsample output so the result is
        # defined even when the stack is empty (a2 == 0).
        x_1 = x_1_0
        for esp in self.esp_a2:
            x_1 = esp(x_1)

        # Concat
        img_downsample_2 = self.input_downsample_2(img)
        concat_2 = self.br2(torch.cat([x_1, x_1_0, img_downsample_2], 1))

        # ESP
        x_2_0 = self.esp_downsample_2(concat_2)

        # ESP x a3: same pattern as above, safe for a3 == 0.
        x_2 = x_2_0
        for esp in self.esp_a3:
            x_2 = esp(x_2)

        # Concat
        concat_3 = self.br3(torch.cat([x_2, x_2_0], 1))

        # Conv-1
        out = self.conv1(concat_3)

        return out

In [9]:
class ESPNet(nn.Module):
    """ESPNet: ESPNet-C style encoder followed by a light-weight decoder.

    The decoder upsamples the encoder scores three times by a factor of 2
    (transposed convolutions with kernel 2, stride 2) and fuses them with
    1x1-projected encoder features from the 1/4 and 1/2 resolution levels.

    Args:
        classes: number of output channels (score maps).
        a2: number of ESP blocks stacked at the 1/4-resolution level.
        a3: number of ESP blocks stacked at the 1/8-resolution level.
            Both ``a2`` and ``a3`` may be 0, in which case the corresponding
            downsampling ESP output is used directly.
    """

    def __init__(self, classes=20, a2=5, a3=8):
        super().__init__()

        # Conv-3: 3 -> 16 channels, stride 2 (1/2 resolution)
        self.conv_3 = CBR(n_in=3, n_out=16, k_size=3, stride=2)

        # Concat with the input image downsampled once (1/2 resolution)
        self.input_downsample_1 = InputProjection(1)
        self.br1 = BR(19)  # 19 = 16 + 3

        # Conv-1: 19 -> C (skip connection used by the decoder)
        self.conv1_1 = C(19, classes, 1, 1)

        # ESP (downsampling) followed by ESP x a2
        self.esp_downsample_1 = ESP_Module(19, 64, add=False, downsample=True)
        self.esp_a2 = nn.ModuleList()
        for _ in range(a2):
            self.esp_a2.append(ESP_Module(64, 64))

        # Concat with the input image downsampled twice (1/4 resolution)
        self.input_downsample_2 = InputProjection(2)
        self.br2 = BR(131)  # 131 = 64 * 2 + 3

        # Conv-1: 131 -> C (skip connection used by the decoder)
        self.conv1_2 = C(131, classes, 1, 1)

        # ESP (downsampling) followed by ESP x a3
        self.esp_downsample_2 = ESP_Module(131, 128, add=False, downsample=True)
        self.esp_a3 = nn.ModuleList()
        for _ in range(a3):
            self.esp_a3.append(ESP_Module(128, 128))

        # Concat
        self.br3 = BR(256)  # 256 = 128 * 2

        # Conv-1: 256 -> C
        self.conv1_3 = C(256, classes, k_size=1, stride=1)

        # <----------- Up: Encoder (ESPNet-C) / Down: Light-weight Decoder -----------> #

        # Deconv (1/8 -> 1/4), preceded by BatchNorm on the encoder scores
        self.bn = nn.BatchNorm2d(classes)
        self.deconv_1 = nn.ConvTranspose2d(classes, classes, kernel_size=2, stride=2,
                                           padding=0, output_padding=0, bias=False)

        # Concat
        self.br4 = BR(2 * classes)

        # ESP: 2C -> C
        self.esp = ESP_Module(2 * classes, classes, add=False)

        # Deconv (1/4 -> 1/2)
        self.deconv_2 = nn.ConvTranspose2d(classes, classes, kernel_size=2, stride=2,
                                           padding=0, output_padding=0, bias=False)

        # Concat
        self.br5 = BR(2 * classes)

        # Conv-1: 2C -> C
        self.conv1_4 = C(2 * classes, classes, k_size=1, stride=1)

        # Deconv (1/2 -> full resolution)
        self.deconv_3 = nn.ConvTranspose2d(classes, classes, kernel_size=2, stride=2,
                                           padding=0, output_padding=0, bias=False)

    def forward(self, img):
        """Return ``classes`` score maps for ``img`` (3-channel image batch).

        NOTE(review): the decoder doubles the spatial size three times and
        concatenates with encoder features, so H and W presumably need to be
        multiples of 8 for the skip shapes to line up - confirm.
        """
        # <---------------- Encoder (ESPNet-C) ----------------> #
        # Conv-3
        x_0 = self.conv_3(img)

        # Concat
        img_downsample_1 = self.input_downsample_1(img)
        concat_1 = self.br1(torch.cat([x_0, img_downsample_1], 1))

        # ESP: start from the downsample output so the result is defined
        # even when the stack is empty (a2 == 0).
        x_1_0 = self.esp_downsample_1(concat_1)
        x_1 = x_1_0
        for esp in self.esp_a2:
            x_1 = esp(x_1)

        # Concat
        img_downsample_2 = self.input_downsample_2(img)
        concat_2 = self.br2(torch.cat([x_1, x_1_0, img_downsample_2], 1))

        # ESP: same pattern as above, safe for a3 == 0.
        x_2_0 = self.esp_downsample_2(concat_2)
        x_2 = x_2_0
        for esp in self.esp_a3:
            x_2 = esp(x_2)

        # Concat
        concat_3 = self.br3(torch.cat([x_2, x_2_0], 1))
        out_encoder = self.conv1_3(concat_3)

        # <----------- Up: Encoder (ESPNet-C) / Down: Light-weight Decoder -----------> #
        # Deconv
        x_3_0 = self.deconv_1(self.bn(out_encoder))

        # Concat with the projected 1/4-resolution encoder features
        x_3 = self.br4(torch.cat([self.conv1_2(concat_2), x_3_0], 1))

        # ESP
        x_3 = self.esp(x_3)

        # Deconv
        x_4_0 = self.deconv_2(x_3)

        # Concat with the projected 1/2-resolution encoder features
        x_5 = self.br5(torch.cat([self.conv1_1(concat_1), x_4_0], 1))

        # Conv-1: 2C->C
        x_5 = self.conv1_4(x_5)

        # Deconv
        out = self.deconv_3(x_5)

        return out

In [10]:
# Shape smoke test for the encoder: one random 3-channel 1024x2048 batch.
inp = torch.randn(1, 3, 1024, 2048).to(device)
print(inp.shape)

net = ESPNet_C().to(device)
# Only the output shape is inspected, so skip autograd bookkeeping;
# this avoids keeping every intermediate activation alive for this input.
with torch.no_grad():
    out = net(inp)

print(out.shape)

# The encoder emits 20 score maps (default `classes`) at 1/8 resolution.
assert out.shape == (1, 20, inp.shape[2] // 8, inp.shape[3] // 8), out.shape

torch.Size([1, 3, 1024, 2048])
torch.Size([1, 20, 128, 256])


In [11]:
# Shape smoke test for the full network: one random 3-channel 1024x2048 batch.
inp = torch.randn(1, 3, 1024, 2048).to(device)
print(inp.shape)

net = ESPNet().to(device)
# Only the output shape is inspected, so skip autograd bookkeeping;
# this avoids keeping every intermediate activation alive for this input.
with torch.no_grad():
    out = net(inp)

print(out.shape)

# The decoder restores the input resolution with 20 score maps (default `classes`).
assert out.shape == (1, 20, inp.shape[2], inp.shape[3]), out.shape

torch.Size([1, 3, 1024, 2048])
torch.Size([1, 20, 1024, 2048])
