In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Inception(nn.Module):
    """Six-branch 3D Inception module.

    Every branch keeps the spatial size of its input (stride 1 with
    matching padding), and the branch outputs are concatenated along the
    channel axis, so the result has ``6 * outChannels`` channels, in order:

    * ``layer1``  - 1x1x1 conv
    * ``layer2``  - 1x1x1 conv, then 3x3x3 conv
    * ``layer3_1`` / ``layer3_2`` / ``layer3_3`` - a shared 3x3x3 conv
      (``layer3``) followed by a 3x1x1, 1x3x1 or 1x1x3 conv respectively
    * ``layer4_2`` - 3x3x3 conv applied to a stride-1, 3x3x3 max-pool
      (``layer4_1``) of the input

    Each conv is followed by BatchNorm3d and an in-place ReLU.

    Args:
        inChannels: number of channels of the input tensor.
        outChannels: number of channels produced by each of the six branches.
    """

    @staticmethod
    def _unit(cin, cout, kernel, pad):
        """Return ``[Conv3d, BatchNorm3d, ReLU(inplace=True)]`` as a module list."""
        return [
            nn.Conv3d(cin, cout, kernel_size=kernel, padding=pad),
            nn.BatchNorm3d(cout),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, inChannels, outChannels):
        super(Inception, self).__init__()
        unit = self._unit
        # Submodules are created in the same order as before so that seeded
        # weight initialisation draws from the RNG in the same sequence, and
        # Sequential indices (hence state_dict keys) stay unchanged.
        self.layer1 = nn.Sequential(*unit(inChannels, outChannels, 1, 0))
        self.layer2 = nn.Sequential(
            *unit(inChannels, outChannels, 1, 0),
            *unit(outChannels, outChannels, 3, 1),
        )
        self.layer3 = nn.Sequential(*unit(inChannels, outChannels, 3, 1))
        self.layer3_1 = nn.Sequential(*unit(outChannels, outChannels, [3, 1, 1], [1, 0, 0]))
        self.layer3_2 = nn.Sequential(*unit(outChannels, outChannels, [1, 3, 1], [0, 1, 0]))
        self.layer3_3 = nn.Sequential(*unit(outChannels, outChannels, [1, 1, 3], [0, 0, 1]))
        self.layer4_1 = nn.Sequential(nn.MaxPool3d(kernel_size=3, stride=1, padding=1))
        self.layer4_2 = nn.Sequential(*unit(inChannels, outChannels, 3, 1))

    def forward(self, x):
        """Run all six branches on ``x`` and concatenate them on dim 1."""
        outputs = [self.layer1(x), self.layer2(x)]
        # The three asymmetric convs all read the same 3x3x3 feature map.
        shared = self.layer3(x)
        outputs.extend(
            branch(shared) for branch in (self.layer3_1, self.layer3_2, self.layer3_3)
        )
        outputs.append(self.layer4_2(self.layer4_1(x)))
        return torch.cat(outputs, dim=1)

In [3]:
class AttentionBlock(nn.Module):
    """Parameter-free spatial attention map for 5-D (N, C, D, H, W) tensors.

    The input is max-pooled twice with kernel 2 / stride 2 (roughly 4x
    coarser per spatial dim). It is then upsampled back to the input's
    spatial size with trilinear interpolation and passed through a sigmoid.
    The returned tensor has the same shape as the input.

    Every spatial dim of the input must be at least 4; otherwise the second
    pooling produces an empty output and PyTorch raises.
    """

    def __init__(self):
        super(AttentionBlock, self).__init__()
        self.downsample = nn.MaxPool3d(kernel_size=2, stride=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid attention map for ``x`` (same shape as ``x``)."""
        out_size = x.shape[2:]
        x = self.downsample(x)
        x = self.downsample(x)
        # Two-step upsampling kept from the original design: a fixed 2x
        # upsample, then an exact resize to the input's spatial size (a plain
        # 4x upsample would not match odd input sizes). align_corners=False is
        # the current default; it is passed explicitly to avoid the
        # default-change warning some PyTorch versions emit.
        x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
        x = F.interpolate(x, size=out_size, mode='trilinear', align_corners=False)
        return self.sigmoid(x)

In [2]:
class Network(nn.Module):
    """3D Inception classifier with attention gating after early stages.

    Four Inception stages, each followed by a padded 2x max-pool. After the
    first three pooled stages, the features are re-weighted in a residual
    way as ``relu(x * attention(x) + x)`` using ``AttentionBlock``. The
    pooled output of the fourth stage goes through dropout, is flattened,
    and is fed to a 3-layer MLP head.

    Args:
        num_classes: output width of the final linear layer. Defaults to 7,
            the value previously hard-coded.

    NOTE(review): the head's input width ``36 * 7 * 8 * 7`` assumes the last
    pooled feature map is 36 x 7 x 8 x 7. This pins the input spatial size
    (e.g. a 1-channel 91 x 109 x 91 volume - TODO confirm against the data
    loader). Other input sizes fail at the first ``nn.Linear``.
    """

    def __init__(self, num_classes=7):
        super(Network, self).__init__()
        # Inception concatenates six branches of ``outChannels`` each, so the
        # channel count goes 1 -> 18 -> 36 -> 36 -> 36.
        self.inception1 = Inception(1, 3)
        self.inception2 = Inception(18, 6)
        self.inception3 = Inception(36, 6)
        self.inception4 = Inception(36, 6)
        self.maxpool = nn.MaxPool3d(kernel_size=2, stride=2, padding=1)
        self.dropLayer = nn.Dropout(p=0.5)
        self.fc = nn.Sequential(
            nn.Linear(36 * 7 * 8 * 7, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )
        self.attention = AttentionBlock()
        self.relu = nn.ReLU(inplace=True)

    def _attend(self, x):
        """Residual attention gating: ``relu(x * attention(x) + x)``."""
        return self.relu(x.mul(self.attention(x)) + x)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        # Stages 1-3: Inception -> pool -> attention-gated residual.
        for stage in (self.inception1, self.inception2, self.inception3):
            x = self._attend(self.maxpool(stage(x)))
        # Stage 4 has no attention gate.
        x = self.maxpool(self.inception4(x))
        x = self.dropLayer(x)
        x = torch.flatten(x, 1)
        return self.fc(x)