In [23]:
import torch
import torch.nn as nn

# yolov7 commons 
class Conv(nn.Module):
    '''
    Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size (int, or a per-dimension tuple/list).
        s: stride.
        p: padding. ``None`` (default) means ``k // 2`` per dimension, which
           keeps the spatial size for odd kernels at stride 1. Any other value
           is passed to ``nn.Conv2d`` unchanged.
        g: groups.
        act: ``True`` -> SiLU; an ``nn.Module`` instance -> used as-is;
             anything else -> identity (no activation).
    '''

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        if p is None:
            # Auto-pad only when the caller did not choose a padding, so an
            # explicit `p` is honoured instead of being overwritten.
            p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
        self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        '''Apply conv -> batch norm -> activation.'''
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        '''Apply conv -> activation, skipping ``self.bn``.

        Presumably meant for use after the BatchNorm has been folded into
        ``self.conv`` (as in YOLOv7) -- TODO confirm; nothing visible here
        performs that fusion.
        '''
        return self.act(self.conv(x))
    
class Concat(nn.Module):
    '''Concatenate a sequence of tensors along a fixed dimension.

    Args:
        dimension: axis handed to ``torch.cat`` (default 1).
    '''

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension  # concatenation axis

    def forward(self, x):
        '''Return ``torch.cat(x)`` along ``self.d``; ``x`` is a list/tuple of tensors.'''
        return torch.cat(x, dim=self.d)

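Backbone ELAN Module

`BBone_ELAN` runs several convolution branches on the same input and concatenates branch 0 and every odd-numbered branch along the channel axis, so the output has `(depth + 1) * chan2` channels (e.g. `chan2 = 64`, `depth = 2` → 192).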

In [24]:
class BBone_ELAN(nn.Module):
    '''
    Backbone ELAN block.

    Creates ``2 * depth`` branches that are all fed the same input. Branch 0 is
    a single 1x1 ``Conv(chan1 -> chan2)``; branch ``d >= 1`` is a 1x1
    ``Conv(chan1 -> chan2)`` followed by ``d - 1`` ``Conv(chan2 -> chan2, ker)``
    layers. The outputs of branch 0 and of the odd-numbered branches
    (``self.idx``) are concatenated along dim 1, so the result has
    ``(depth + 1) * chan2`` channels.

    NOTE(review): unlike the YOLOv7 ELAN, branches do not share their leading
    convolutions; each branch owns independent weights. The even-numbered
    branches ``d >= 2`` are still constructed (so ``state_dict`` keys are
    unchanged) but their outputs are never used, so their parameters get no
    gradients -- confirm this is intended.

    Args:
        chan1: input channel count.
        chan2: channel count of each concatenated branch (not of the block
            output; the examples below use ``chan1 // 2``).
        ker: kernel size of the stacked convolutions.
        depth: number of concatenated branches minus one; must be >= 1.

    Raises:
        ValueError: if ``depth < 1`` (there would be nothing to concatenate).
    '''

    def __init__(self, chan1, chan2, ker, depth):
        super(BBone_ELAN, self).__init__()
        if depth < 1:
            raise ValueError('BBone_ELAN depth must be >= 1, got {0}'.format(depth))

        self.chan1 = chan1  # input channel size
        self.chan2 = chan2  # channel size of every concatenated branch
        self.ker = ker  # kernel size of the stacked convs
        self.depth = depth

        # Branches to concatenate: 0 plus every odd index 1, 3, ..., 2*depth - 1.
        self.idx = [0] + list(range(1, self.depth * 2, 2))

        # Build branches in index order so module keys (and init order) stay stable.
        elans = {}
        for d in range(self.depth * 2):
            elans['BBoneELAN_{0}'.format(d + 1)] = self.construct_elan(d)

        self.elan_dict = nn.ModuleDict(elans)
        self.cat = Concat(dimension=1)  # from yolov7 modules

    def construct_elan(self, depth):
        '''
        Build branch number ``depth`` as an ``nn.Sequential``.

        Branch 0 is a single 1x1 ``Conv(chan1 -> chan2)`` named ``Conv_1``.
        Branch ``depth >= 1`` is a 1x1 ``Conv(chan1 -> chan2)`` followed by
        ``depth - 1`` ``Conv(chan2 -> chan2, ker)`` layers, named
        ``Conv_2`` ... ``Conv_{depth + 1}``.
        '''
        elan = nn.Sequential()

        if depth == 0:
            elan.add_module('Conv_1', Conv(self.chan1, self.chan2, 1, 1))
            return elan

        for i in range(depth):
            if i == 0:
                layer = Conv(self.chan1, self.chan2, 1, 1)
            else:
                layer = Conv(self.chan2, self.chan2, self.ker, 1)
            elan.add_module('Conv_{0}'.format(i + 2), layer)

        return elan

    def forward(self, x):
        '''Run the branches listed in ``self.idx`` on ``x`` and concatenate them along dim 1.'''
        # Branches not listed in self.idx contribute nothing to the output, so
        # they are skipped instead of being computed and discarded.
        branches = list(self.elan_dict.values())
        return self.cat([branches[d](x) for d in self.idx])

In [25]:
# Shape check: with chan2 = 64 the backbone ELAN should emit (depth + 1) * 64 channels
# (depth 2 -> 192, 3 -> 256, 4 -> 320, 5 -> 384) at the input's spatial size.
bbone_input = torch.randn(1, 128, 64, 64)  # random input (named to avoid shadowing builtin `input`)

with torch.no_grad():  # shape check only; no autograd graph needed
    for depth in (2, 3, 4, 5):
        block = BBone_ELAN(chan1=128, chan2=64, ker=3, depth=depth)
        out = block(bbone_input)
        print('Backbone ELAN output shape (depth = {0}) :'.format(depth), out.shape)
        assert out.shape == (1, (depth + 1) * 64, 64, 64), out.shape

Backbone EALN output shape (depth = 2) : torch.Size([1, 192, 64, 64])
Backbone EALN output shape (depth = 3) : torch.Size([1, 256, 64, 64])
Backbone EALN output shape (depth = 4) : torch.Size([1, 320, 64, 64])
Backbone EALN output shape (depth = 5) : torch.Size([1, 384, 64, 64])


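Head ELAN Module

`Head_ELAN` runs `depth + 1` convolution branches on the same input and concatenates all of them; the first two branches output `chan2` channels and each later branch `chan2 // 2`, giving `2 * chan2 + (depth - 1) * (chan2 // 2)` channels for `depth >= 1` (e.g. `chan2 = 256`, `depth = 2` → 640).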

In [26]:
class Head_ELAN(nn.Module):
    '''
    Head ELAN block.

    Builds ``depth + 1`` branches that all receive the same input and
    concatenates every branch output along dim 1.

    Branch 0 is a single 1x1 ``Conv(chan1 -> chan2)``. Branch ``d >= 1``
    starts with a 1x1 ``Conv(chan1 -> chan2)``; for ``d >= 2`` it continues
    with ``Conv(chan2 -> chan2 // 2, ker)`` and then ``d - 2`` further
    ``Conv(chan2 // 2 -> chan2 // 2, ker)`` layers.

    Output channels for ``depth >= 1``: ``2 * chan2 + (depth - 1) * (chan2 // 2)``.

    Args:
        chan1: input channel count.
        chan2: channel count of the 1x1 projections.
        ker: kernel size of the stacked convolutions.
        depth: number of branches minus one.
    '''

    def __init__(self, chan1, chan2, ker, depth):
        super(Head_ELAN, self).__init__()

        self.chan1 = chan1  # input channels
        self.chan2 = chan2  # width of the 1x1 projections
        self._chan = chan2 // 2  # width of the stacked ker x ker convs
        self.ker = ker
        self.depth = depth

        # Every branch takes part in the concatenation.
        self.idx = list(range(depth + 1))

        # Branches are created in index order (0 .. depth).
        self.elan_dict = nn.ModuleDict({
            'HeadELAN_{0}'.format(branch + 1): self.construct_elan(branch)
            for branch in range(depth + 1)
        })
        self.cat = Concat(dimension=1)

    def construct_elan(self, depth):
        '''Return branch ``depth`` as an ``nn.Sequential``.

        Layer names are ``Conv_1`` for branch 0, otherwise ``Conv_2`` ...
        ``Conv_{depth + 1}``.
        '''
        elan = nn.Sequential()

        if depth == 0:
            elan.add_module('Conv_1', Conv(self.chan1, self.chan2, 1, 1))
            return elan

        for i in range(depth):
            if i == 0:
                c_in, c_out, kernel = self.chan1, self.chan2, 1
            elif i == 1:
                c_in, c_out, kernel = self.chan2, self._chan, self.ker
            else:
                c_in, c_out, kernel = self._chan, self._chan, self.ker
            elan.add_module('Conv_{0}'.format(i + 2), Conv(c_in, c_out, kernel, 1))

        return elan

    def forward(self, x):
        '''Evaluate every branch on ``x`` and concatenate the ones listed in ``self.idx``.'''
        outputs = [branch(x) for branch in self.elan_dict.values()]
        return self.cat([outputs[d] for d in self.idx])

In [27]:
# Shape check: with chan2 = 256 the head ELAN should emit 2 * 256 + (depth - 1) * 128 channels
# (depth 2 -> 640, 3 -> 768, 4 -> 896, 5 -> 1024) at the input's spatial size.
head_input = torch.randn(1, 512, 64, 64)  # random input (named to avoid shadowing builtin `input`)

with torch.no_grad():  # shape check only; no autograd graph needed
    for depth in (2, 3, 4, 5):
        block = Head_ELAN(chan1=512, chan2=256, ker=3, depth=depth)
        out = block(head_input)
        print('Head ELAN output shape (depth = {0}) :'.format(depth), out.shape)
        assert out.shape == (1, 2 * 256 + (depth - 1) * 128, 64, 64), out.shape

Head EALN output shape (depth = 2) : torch.Size([1, 640, 64, 64])
Head EALN output shape (depth = 3) : torch.Size([1, 768, 64, 64])
Head EALN output shape (depth = 4) : torch.Size([1, 896, 64, 64])
Head EALN output shape (depth = 5) : torch.Size([1, 1024, 64, 64])
