In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [83]:
class MFA(nn.Module):
    def __init__(self, input_channel, input_dim):
        super(MFA, self).__init__()
        
        self.width = input_channel * input_dim
        
        self.gap = nn.AdaptiveAvgPool3d((None, None,1))
        self.flatten = nn.Flatten()
        self.layer = nn.Linear(self.width, self.width)
    
        self.flatten_tdnn = nn.Flatten(1,2)
        
        #TDNN with Conv1d
        self.cnn = nn.Conv1d(self.width, input_channel, kernel_size=1)
        self.relu = nn.ReLU()       
        self.bn = nn.BatchNorm1d(input_channel)
        
    
    def forward(self, x):
        x_ = x
        x = self.gap(x)
        
        x = x.squeeze(-1)
        b,c,d = x.size()
        
        x = self.flatten(x)
        x = self.layer(x)
        
        x = x.reshape(-1, c, d)
        x = x.unsqueeze(-1)*x_
        
        x = self.flatten_tdnn(x)
        
        x = self.cnn(x)
        
        x = self.relu(x)
        x = self.bn(x)
        #print(x.shape)
        x = x.unsqueeze(2)
        #print(x.shape)
        return x

In [84]:
mfa = MFA(8, 80)
mfa.eval()

MFA(
  (gap): AdaptiveAvgPool3d(output_size=(None, None, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer): Linear(in_features=640, out_features=640, bias=True)
  (flatten_tdnn): Flatten(start_dim=1, end_dim=2)
  (cnn): Conv1d(640, 8, kernel_size=(1,), stride=(1,))
  (relu): ReLU()
  (bn): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [85]:
test = torch.rand(1, 8, 80, 200)
output = mfa(test)

In [108]:
class DpMsModule(nn.Module):
    def __init__(self, scale, channel_output, input_dim):
        super(DpMsModule, self).__init__()
        self.cnn1 = nn.Conv2d(1, channel_output, kernel_size=3, padding=(1,1))
        self.cnn2 = nn.Conv2d(channel_output, channel_output, kernel_size=3, padding=(1,1))
        self.width = channel_output//scale
        self.scale = scale

        self.mfa1 = MFA(self.width, input_dim)
        self.mfa2 = MFA(self.width, input_dim)
        self.mfa3 = MFA(self.width, input_dim)
        self.mfa4 = MFA(self.width, input_dim)

        self.cnn3_1 = nn.Conv2d(self.width, self.width, kernel_size=3, padding='same')
        self.cnn3_2 = nn.Conv2d(self.width, self.width, kernel_size=3, padding='same')
        self.cnn3_3 = nn.Conv2d(self.width, self.width, kernel_size=3, padding='same')

        self.flatten = nn.Flatten(1,2)
        
        self.cnn4 = nn.Conv1d(channel_output, channel_output, kernel_size=1)

    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        x1, x2, x3, x4 = torch.split(x, self.width, dim=1)

        x1 = self.mfa1(x1)
        x2 = self.cnn3_1(x2)
        x2_ = x2
        x2 = self.mfa2(x1*x2)
        
        x3 = self.cnn3_2(x2_+x3)
        x3_ = x3
        x3 = self.mfa3(x2*x3)
        
        x4 = self.cnn3_3(x3_+x4)
        x4 = self.mfa4(x3*x4)


        y = torch.cat((x1,x2,x3,x4), 1)
        print(y.size())
        y = self.flatten(y)
        print(y.size())
        y_ = y
        y= self.cnn4(y)
        y= y+y_

        #batch, c, d, l = x1.size()
        
        return y

In [109]:
test = torch.rand(1, 1,80,200)

In [110]:
check = DpMsModule(4,32, 80)
check.eval()

DpMsModule(
  (cnn1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (cnn2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (mfa1): MFA(
    (gap): AdaptiveAvgPool3d(output_size=(None, None, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (layer): Linear(in_features=640, out_features=640, bias=True)
    (flatten_tdnn): Flatten(start_dim=1, end_dim=2)
    (cnn): Conv1d(640, 8, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mfa2): MFA(
    (gap): AdaptiveAvgPool3d(output_size=(None, None, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (layer): Linear(in_features=640, out_features=640, bias=True)
    (flatten_tdnn): Flatten(start_dim=1, end_dim=2)
    (cnn): Conv1d(640, 8, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mfa3): MFA(
    (g

In [111]:
output = check(test)

torch.Size([1, 8, 1, 200])
torch.Size([1, 8, 1, 200])
torch.Size([1, 8, 1, 200])
torch.Size([1, 8, 1, 200])
torch.Size([1, 32, 1, 200])
torch.Size([1, 32, 200])


In [238]:
check = mfa(output)

torch.Size([1, 8, 76])
torch.Size([1, 8, 76, 196])
torch.Size([1, 8, 76, 196])
torch.Size([1, 608, 196])
torch.Size([1, 608, 196])


In [160]:
test2 = nn.AdaptiveAvgPool3d((None, None,1))

In [161]:
test3 = test2(output)

In [162]:
print(test3.shape)
test3 = test3.squeeze(-1)
print(test3.shape)

torch.Size([1, 8, 76, 1])
torch.Size([1, 8, 76])


In [163]:
layer = nn.Flatten()

In [164]:
output = layer(test3)

In [165]:
output.size()

torch.Size([1, 608])

In [166]:
layer2 = nn.Linear(608, 608)
output = layer2(output)
output = output.reshape(-1,8, 76)

In [169]:
print(output.size())
print(output_.size())
final = output.unsqueeze(-1)*output_
print(final.shape)

torch.Size([1, 8, 76])
torch.Size([1, 8, 76, 196])
torch.Size([1, 8, 76, 196])


In [153]:
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(
        self,
        input_channel,
        scaled_channel=128
    ):
        super(SEBlock, self).__init__()
        self.block = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(input_channel, scaled_channel, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(scaled_channel,input_channel, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(
        self,
        input
    ):
        print(input.shape)
        output = self.block(input)
        print(output.shape)
        return output*input

In [154]:
test = torch.rand(1, 80,150)
check = SEBlock(80)

In [155]:
output = check(test)

torch.Size([1, 80, 150])
torch.Size([1, 80, 1])
