In [3]:
import torch
import torch.nn as nn

# TODO

* Create an `ELANBlock` base class and have the backbone/head variants inherit its common methods (ELANBlock 상위 클래스 만들어서 기본 메소드 상속하는 방식).

In [4]:
# from yolov7 in common.py
class Conv(nn.Module):
    """Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size (int or tuple of ints).
        s: stride.
        p: padding. ``None`` (default) auto-pads with ``k // 2`` per spatial dim,
            which keeps the spatial size for odd kernels at stride 1.
        g: groups.
        act: ``True`` -> SiLU; an ``nn.Module`` -> used as-is; anything else -> identity.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        # Auto-pad only when the caller did not pass a padding. The previous code
        # unconditionally overwrote p, silently ignoring an explicit value and
        # failing for tuple kernels.
        if p is None:
            p = k // 2 if isinstance(k, int) else tuple(x // 2 for x in k)
        self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        """Apply conv, batch norm and activation in sequence."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply conv and activation, skipping ``bn``.

        NOTE(review): presumably meant for use after BN has been folded into
        ``conv`` — confirm with the caller.
        """
        return self.act(self.conv(x))
    
class Concat(nn.Module):
    """Join a sequence of tensors along a fixed dimension.

    Args:
        dimension: dimension handed to ``torch.cat`` (default 1, e.g. the
            channel axis for NCHW tensors).
    """

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Return the tensors in ``x`` concatenated along ``self.d``."""
        axis = self.d
        return torch.cat(x, dim=axis)

In [24]:
# ELANBlock for Backbone
class BBoneELANBlock(nn.Module):
    """ELAN block in the backbone layout.

    Two parallel 1x1 convs each map the input to ``c1 // 2`` channels. The
    second branch then feeds a chain of kxk convs (cv3 -> cv4 -> ... -> cv8).
    The outputs listed in ``act_idx`` are concatenated along dim 1.

    Args:
        c1: input channels (must be even).
        k: kernel size of the chained convs.
        depth: how many taps follow the first 1x1 branch, in ``[0, 4]``.
            The result has ``(depth + 1) * (c1 // 2)`` channels.
    """

    def __init__(self, c1, k, depth):
        super(BBoneELANBlock, self).__init__()
        # A negative depth would make the act_idx slice below count from the
        # end and silently pick the wrong taps, so reject it explicitly.
        assert c1 % 2 == 0 and 0 <= depth < 5
        c_ = c1 // 2

        self.depth = depth

        # All eight convs are always created so state_dict keys do not depend
        # on depth; forward() only runs the ones it needs.
        # depth 1
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        # depth 2
        self.cv3 = Conv(c_, c_, k, 1)
        self.cv4 = Conv(c_, c_, k, 1)
        # depth 3
        self.cv5 = Conv(c_, c_, k, 1)
        self.cv6 = Conv(c_, c_, k, 1)
        # depth 4
        self.cv7 = Conv(c_, c_, k, 1)
        self.cv8 = Conv(c_, c_, k, 1)

        # Output-list indices to concatenate: 0 -> cv1, 1 -> cv2, 3 -> cv4,
        # 5 -> cv6, 7 -> cv8 (every second conv of the chain is tapped).
        self.act_idx = [0, 1, 3, 5, 7][:depth + 1]

    def forward(self, x):
        """Run the block and concatenate the tapped outputs along dim 1."""
        # Stop the chain at the deepest tapped output instead of always
        # evaluating all eight convs (their results were discarded anyway).
        last = max(self.act_idx)
        outputs = [self.cv1(x)]
        if last >= 1:
            y = self.cv2(x)
            outputs.append(y)
            chain = (self.cv3, self.cv4, self.cv5, self.cv6, self.cv7, self.cv8)
            # Chain conv j produces output index j + 2.
            for conv in chain[:last - 1]:
                y = conv(y)
                outputs.append(y)
        return torch.cat([outputs[i] for i in self.act_idx], dim=1)

In [26]:
# Random 1x128x64x64 tensor and a depth-4 backbone ELAN block, used to probe the output shape.
input = torch.randn((1, 128, 64, 64))
block = BBoneELANBlock(depth=4, k=3, c1=128)

In [27]:
# Shape of the backbone block's output (depth 4 -> 5 concatenated 64-channel taps).
block(input).size()

torch.Size([1, 320, 64, 64])

In [28]:
# ELANBlock for Head
# Differs from the backbone block in cardinality (every chain conv is tapped)
# and in channel size (the chain runs at half the branch width).
class HEADELANBlock(nn.Module):
    """ELAN block in the head layout.

    Two parallel 1x1 convs each map the input to ``c_ = c1 // 2`` channels. The
    second branch then feeds a chain cv3 -> cv4 -> cv5 -> cv6 running at
    ``c_2 = c_ // 2`` channels, and every output up to ``depth`` is
    concatenated along dim 1.

    Args:
        c1: input channels (must be even).
        k: kernel size of the chained convs.
        depth: index of the last concatenated output, in ``[0, 5]``. For
            ``depth >= 1`` the result has ``2 * c_ + (depth - 1) * c_2`` channels.
    """

    def __init__(self, c1, k, depth):
        super(HEADELANBlock, self).__init__()
        # A negative depth would make the act_idx slice count from the end;
        # reject it explicitly.
        assert c1 % 2 == 0 and 0 <= depth < 6
        c_ = c1 // 2
        c_2 = c_ // 2
        self.depth = depth

        # All convs are always created so state_dict keys do not depend on
        # depth; forward() only runs the ones it needs.
        # depth 1
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        # depth 2
        self.cv3 = Conv(c_, c_2, k, 1)
        # depth 3
        self.cv4 = Conv(c_2, c_2, k, 1)
        # depth 4
        self.cv5 = Conv(c_2, c_2, k, 1)
        # depth 5
        self.cv6 = Conv(c_2, c_2, k, 1)

        # Every output 0..depth is concatenated (0 -> cv1, 1 -> cv2, 2 -> cv3, ...).
        # The previous list also held index 6, for which no output exists.
        self.act_idx = list(range(depth + 1))

    def forward(self, x):
        """Run the block and concatenate the selected outputs along dim 1."""
        # Stop the chain at the deepest selected output instead of always
        # evaluating every conv.
        last = max(self.act_idx)
        outputs = [self.cv1(x)]
        if last >= 1:
            y = self.cv2(x)
            outputs.append(y)
            # Chain conv j produces output index j + 2.
            for conv in (self.cv3, self.cv4, self.cv5, self.cv6)[:last - 1]:
                y = conv(y)
                outputs.append(y)
        return torch.cat([outputs[i] for i in self.act_idx], dim=1)

In [29]:
# Random 1x128x64x64 tensor and a depth-5 head ELAN block, used to probe the output shape.
input = torch.randn((1, 128, 64, 64))
block = HEADELANBlock(depth=5, k=3, c1=128)

In [30]:
# Shape of the head block's output (2 x 64 + 4 x 32 = 256 channels at depth 5).
block(input).size()

torch.Size([1, 256, 64, 64])

## ETC: experimental generic `ELANBlock` (work in progress)

In [51]:
# ELANBlock for Backbone (prototype: explicit forward + list-driven forward_blocks)
class ELANBlock(nn.Module):
    """Prototype backbone ELAN block with two forward paths of equal output shape.

    ``forward`` uses the explicitly named convs ``cv1``..``cv8``.
    ``forward_blocks`` uses the separately built ``self.blocks`` sequence,
    which has its own weights.

    Args:
        c1: input channels (must be even).
        k: kernel size of the chained convs (used by both paths).
        depth: number of conv pairs, in ``[1, 4]``. The output has
            ``(depth + 1) * (c1 // 2)`` channels, stored as ``self.c2``.
    """

    def __init__(self, c1, k, depth):
        super(ELANBlock, self).__init__()
        # depth == 0 produced an empty act_idx (torch.cat([]) failure), so
        # require at least one conv pair.
        assert c1 % 2 == 0 and 1 <= depth < 5
        c_ = c1 // 2
        self.c2 = c_ * (depth + 1)  # output channels of either forward path
        self.depth = depth
        # depth 1
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        # depth 2 (chained convs honour k, as in BBoneELANBlock; was hard-coded to 3)
        self.cv3 = Conv(c_, c_, k, 1)
        self.cv4 = Conv(c_, c_, k, 1)
        # depth 3
        self.cv5 = Conv(c_, c_, k, 1)
        self.cv6 = Conv(c_, c_, k, 1)
        # depth 4
        self.cv7 = Conv(c_, c_, k, 1)
        self.cv8 = Conv(c_, c_, k, 1)

        self.cat = Concat(dimension=1)

        # All convs in this block are stride 1. The stride used to come from an
        # undefined name `s`, which raised NameError on a fresh kernel.
        self.blocks, self.act_idx = self.set_blocks(c1, k, 1, depth)

    def set_blocks(self, c1, k, s, depth):
        """Build the list-driven variant of the block.

        Each Conv is tagged with ``.i`` (its index in the output list) and
        ``.f`` (the output index it consumes). The first pair has ``f == 0`` but
        reads the block input.

        Args:
            c1, k, depth: as in ``__init__``.
            s: currently unused; every conv is stride 1.

        Returns:
            ``(nn.Sequential(convs..., Concat(1)), act_idx)``, where ``act_idx``
            lists the output indices to concatenate.
        """
        c_ = c1 // 2
        layers = []
        # The concatenated outputs depend on depth: both 1x1 branches plus the
        # second conv of every later pair. `depth` pairs produce outputs
        # 0..2*depth-1, so the final tap (2*depth-1) needs depth+1 entries.
        act_idx = [0, 1, 3, 5, 7][:depth + 1]
        for i in range(depth):
            if i == 0:
                # depth 1: two parallel 1x1 convs on the block input
                m1, m2 = Conv(c1, c_, 1, 1), Conv(c1, c_, 1, 1)
                m1.i, m1.f = 0, 0
                m2.i, m2.f = 1, 0
            else:
                # later depths: two chained kxk convs continuing from the previous output
                m1, m2 = Conv(c_, c_, k, 1), Conv(c_, c_, k, 1)
                m1.i, m1.f = 2 * i, 2 * i - 1
                m2.i, m2.f = 2 * i + 1, 2 * i
            layers.extend((m1, m2))

        layers.append(Concat(1))
        return nn.Sequential(*layers), act_idx

    def forward(self, x):
        """Run the named-conv path and concatenate the tapped outputs along dim 1."""
        # Only evaluate convs up to the deepest tapped output.
        last = max(self.act_idx)
        outputs = [self.cv1(x)]
        if last >= 1:
            y = self.cv2(x)
            outputs.append(y)
            chain = (self.cv3, self.cv4, self.cv5, self.cv6, self.cv7, self.cv8)
            # Chain conv j produces output index j + 2.
            for conv in chain[:last - 1]:
                y = conv(y)
                outputs.append(y)
        return torch.cat([outputs[i] for i in self.act_idx], dim=1)

    def forward_blocks(self, x, depth=3):
        """Run the ``self.blocks`` path and concatenate the tapped outputs.

        Args:
            x: block input.
            depth: ignored; kept for backward compatibility. The depth is fixed
                when the block is constructed.
        """
        outputs = []
        # The last entry of self.blocks is the Concat; joining uses self.cat.
        for m in list(self.blocks)[:-1]:
            # The first pair (i = 0, 1) reads the block input; later convs read
            # the output named by their .f tag.
            src = x if m.i < 2 else outputs[m.f]
            outputs.append(m(src))
        return self.cat([outputs[i] for i in self.act_idx])

        
    
        

In [79]:
# Random 1x128x64x64 tensor and a depth-4 prototype ELAN block, used to probe the output shape.
input = torch.randn((1, 128, 64, 64))
block = ELANBlock(depth=4, k=3, c1=128)

In [80]:
# Shape of the prototype block's output: one 64-channel slice per entry of block.act_idx.
block(input).size()

torch.Size([1, 320, 64, 64])

5.0