In [None]:
# -*- coding: utf-8 -*-
# @Time: 10:07
# @File: MAF_CNN.py
# @Software: PyCharm
import logging

import torch
import torch.nn as nn


class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention for 1-D feature maps.

    Each channel of the input is rescaled by a learned gate in (0, 1). The
    gate is obtained by average-pooling the last (length) axis, feeding the
    per-channel means through a two-layer bottleneck MLP (no biases), and
    applying a sigmoid.

    Args:
        channel: Number of input channels; must equal ``x.size(1)`` in
            :meth:`forward`.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``max(1, channel // reduction)``.
    """

    # Shape traces go to DEBUG logging instead of stdout so training loops are
    # not flooded; enable them by configuring this module's logger.
    _logger = logging.getLogger(__name__)

    def __init__(self, channel=32, reduction=16):
        super(SELayer, self).__init__()
        # Clamp to >= 1: when channel < reduction, plain floor division gives a
        # zero-width bottleneck. Its (bias-free) output is then always 0, so
        # every gate would be frozen at sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel scaled by its attention gate.

        Args:
            x: Tensor of shape ``(batch, channel, length)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _ = x.size()
        # Squeeze: average over the length axis, (b, c, L) -> (b, c).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel gate, reshaped to (b, c, 1) to broadcast over L.
        y = self.fc(y).view(b, c, 1)
        self._logger.debug("SELayer: input %s, gate %s", tuple(x.shape), tuple(y.shape))
        return x * y.expand_as(x)


class MSA(nn.Module):
    """Multi-scale dilated-convolution block followed by SE channel attention.

    Three parallel ``Conv1d -> BatchNorm1d -> ReLU`` branches (kernel/dilation
    5/2, 4/3 and 3/4; stride 1; no padding) read the same input. Their outputs
    are concatenated along the last (length) axis, not the channel axis, and
    the result is re-weighted per channel by :class:`SELayer`.

    Args:
        in_channels: Channels of the input tensor (default 64).
        out_channels: Channels produced by each branch and gated by the SE
            layer (default 32).

    Shape:
        Input ``(batch, in_channels, L)`` gives output
        ``(batch, out_channels, (L - 8) + (L - 9) + (L - 8))``, so ``L`` must
        be at least 10 for every branch to produce output.
    """

    def __init__(self, in_channels=64, out_channels=32):
        super(MSA, self).__init__()
        # Attribute names and construction order match the original layout, so
        # state_dict keys and default initialisation are unchanged.
        self.conv1 = self._branch(in_channels, out_channels, kernel_size=5, dilation=2)
        self.conv2 = self._branch(in_channels, out_channels, kernel_size=4, dilation=3)
        self.conv3 = self._branch(in_channels, out_channels, kernel_size=3, dilation=4)
        self.se = SELayer(channel=out_channels)

    @staticmethod
    def _branch(in_channels, out_channels, kernel_size, dilation):
        """Build one ``Conv1d -> BatchNorm1d -> ReLU`` branch (stride 1, no padding)."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, (kernel_size,), (1,), dilation=(dilation,)),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the three branches, join them along the length axis, then SE-gate.

        Args:
            x: Tensor of shape ``(batch, in_channels, L)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 3 * L - 25)``.
        """
        branch_outputs = [self.conv1(x), self.conv2(x), self.conv3(x)]
        out = torch.cat(branch_outputs, dim=2)
        return self.se(out)


class MAF_CNN(nn.Module):
    """Single-channel 1-D CNN with multi-scale attention and an MLP head.

    Pipeline: ``cnn1`` (strided wide-kernel conv, max-pool, dropout, two 7-tap
    convs, max-pool) -> :class:`MSA` -> dropout -> flatten -> two-layer MLP
    (``fc``) -> linear classifier (``out``).

    Args:
        num_classes: Number of output logits.
        input_length: Length ``L`` of the ``(batch, 1, L)`` input signal. The
            width of the first ``fc`` layer is derived from it. The default
            15360 yields the original 256 input features (32 channels x 8
            positions).

    Raises:
        ValueError: If ``input_length`` is too short and the feature map
            shrinks to zero length somewhere in the network.
    """

    # Shape traces go to DEBUG logging instead of stdout; enable them by
    # configuring this module's logger.
    _logger = logging.getLogger(__name__)

    def __init__(self, num_classes, input_length=15360):
        super(MAF_CNN, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv1d(1, 64, (256,), (32,)),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.MaxPool1d(8, 8),

            nn.Dropout(),

            nn.Conv1d(64, 128, (7,), (1,)),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 64, (7,), (1,)),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.MaxPool1d(4, 4),
        )

        self.dropout = nn.Dropout()
        self.msa = MSA()
        self.ft = nn.Flatten()
        # Derived instead of hard-coded so other input lengths work. The
        # computation is pure arithmetic and consumes no RNG, so default
        # parameter initialisation is unchanged.
        flat_features = self._flat_features(input_length)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
        )
        self.out = nn.Linear(128, num_classes)

    @staticmethod
    def _out_len(layer, length):
        """Return the output length of a ``Conv1d``/``MaxPool1d`` for ``length`` inputs.

        Uses the floor formula from the PyTorch layer docs. Assumes integer
        padding and ``ceil_mode=False``, which holds for every layer built here.
        """
        def first(value):
            return value[0] if isinstance(value, tuple) else value

        k = first(layer.kernel_size)
        s = first(layer.stride)
        p = first(layer.padding)
        d = first(layer.dilation)
        return (length + 2 * p - d * (k - 1) - 1) // s + 1

    def _flat_features(self, input_length):
        """Return how many features ``ft`` yields for an input of ``input_length``."""
        length = input_length
        for layer in self.cnn1:
            if isinstance(layer, (nn.Conv1d, nn.MaxPool1d)):
                length = self._out_len(layer, length)
                if length < 1:
                    raise ValueError(
                        f"input_length={input_length} is too short for MAF_CNN: "
                        f"feature length drops to {length} after {layer}"
                    )
        # MSA concatenates its three branch outputs along the length axis, and
        # its SE layer preserves the branches' channel count.
        branch_convs = [branch[0] for branch in (self.msa.conv1, self.msa.conv2, self.msa.conv3)]
        branch_lens = [self._out_len(conv, length) for conv in branch_convs]
        if min(branch_lens) < 1:
            raise ValueError(
                f"input_length={input_length} is too short for MAF_CNN: "
                f"MSA branch lengths would be {branch_lens}"
            )
        return branch_convs[0].out_channels * sum(branch_lens)

    def forward(self, x1):
        """Run the network.

        Args:
            x1: Tensor of shape ``(batch, 1, input_length)``.

        Returns:
            Tuple ``(features, logits)``. ``features`` is the 128-wide ReLU
            output of ``fc``; ``logits`` is ``out(features)`` with
            ``num_classes`` columns.
        """
        x = self.cnn1(x1)
        self._logger.debug("after CNN: %s", tuple(x.shape))

        x = self.msa(x)
        self._logger.debug("after MSA: %s", tuple(x.shape))

        x = self.ft(self.dropout(x))
        features = self.fc(x)
        logits = self.out(features)
        return features, logits

In [None]:
import torch
from my_model import SimpleSleepPPGModel
from torchsummary import summary

# Fall back to the CPU when no GPU is available instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MAF_CNN(2).to(device)
# .to(device) already moved the parameters; `model` is kept as an alias of `net`
# (the original `net.cuda()` also returned the same module object).
model = net
# torchsummary creates its dummy input on the device named here ("cuda" or
# "cpu"; its default is "cuda"), so it must match where the model lives.
summary(model, (1, 15360), device=device.type)

after CNN:  torch.Size([2, 64, 11])
SE input (x) by concat: ([32, 3], [32, 2], [32, 3]):  torch.Size([2, 32, 8])
b, c: 2, 32
after squeeze (squeeze the last dimension): torch.Size([2, 32])
after exication (y):  torch.Size([2, 32, 1])
y.expand_as(x):  torch.Size([2, 32, 8])
after MSA (x * y.expand_as(x)):  torch.Size([2, 32, 8])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 64, 473]          16,448
       BatchNorm1d-2              [-1, 64, 473]             128
              ReLU-3              [-1, 64, 473]               0
         MaxPool1d-4               [-1, 64, 59]               0
           Dropout-5               [-1, 64, 59]               0
            Conv1d-6              [-1, 128, 53]          57,472
       BatchNorm1d-7              [-1, 128, 53]             256
              ReLU-8              [-1, 128, 53]               0
            Conv1d-9         

: 