In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# softmax test

# Mutex Attention Block

In [2]:
class Mutex_block(nn.Module):
    def __init__(self):
        super(Mutex_block, self).__init__()
        # self.batch = batch
        # self.H = H
        # self.C = C
        # self.W = W
        self.softmax = nn.Softmax(dim=1)
        
        
    def forward(self, Frei, Frem):
        batch, C, H, W = Frei.shape
        
        Fam = torch.subtract(Frei, Frem)
        Fam = torch.pow(Fam, 2)
        
        #reshape
        Fam = Fam.view(batch, C, H*W)
        
        #softmax
        Fam = self.softmax(Fam)

        #reshape
        Fam = Fam.view(batch, C, H, W)
        
        #multiplication
        Fam = torch.mul(Fam, Frem)
        
        return Fam

# Fusion Attention Block

In [3]:
class Fusion_block(nn.Module):
    def __init__(self, in_fc, out_fc, pool_size):
        super(Fusion_block, self).__init__()
        self.in_fc = in_fc
        self.out_fc = out_fc
        
        self.softmax = nn.Softmax(dim=1)
        self.avgpool = nn.AvgPool2d(pool_size)
        self.maxpool = nn.MaxPool2d(pool_size)
        
        self.softmax = nn.Softmax(dim=0)
        
        # Set these parameters
        self.fcC = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(in_fc, out_fc),
            nn.ReLU())
        
        self.fcC_prim = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
        self.fcM = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
        self.fcN = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
    def forward(self, Fam, Frem):
        f_temp = torch.add(Fam, Frem)
        print(f'first add: {f_temp.shape}')
        
        avg_pool = self.avgpool(f_temp).squeeze()
        max_pool = self.maxpool(f_temp).squeeze()
        f_temp = torch.add(avg_pool, max_pool)
        print(f'add after poolings: {f_temp.shape}')
        
        f_temp = f_temp.view(f_temp.size(0), -1)
        print(f'bifore first fc: {f_temp.shape}')
        
        
        f_temp = self.fcC(f_temp)
        print(f'after first fc: {f_temp.shape}')
        f_temp = self.fcC_prim(f_temp)
        print(f'after second fc: {f_temp.shape}')
        
        
        fM = self.fcM(f_temp)
        print(f'size of fM: {fM.shape}')
        fN = self.fcN(f_temp)
        print(f'size of fN: {fN.shape}')
        
        # unsqueeze
        fM = torch.unsqueeze(fM, 0)
        fN = torch.unsqueeze(fN, 0)
        print(f'unsqueeze fM: {fM.shape}\nunsqueeze fN: {fN.shape}')
        
        # concatenate
        z = torch.concat([fM, fN], dim=0)
        print(f'size of concat fM, fN: {z.shape}')
        
        # softmax
        z = self.softmax(z)
        print(f'after softmax: {z.shape}')
        
        fM = z[0,...]
        print(f'separate fM: {fM.shape}')
        fN = z[1,...]
        print(f'separate fN: {fN.shape}')
        
        
        print(f'size of fM after view: {fM.view(fM.size(0), fM.size(1), 1, 1).shape}')
        
        # fM --> [B, C]
        # Fam --> [B, C, H, W]
        Ffm_m = torch.mul(Fam, fM.view(fM.size(0), fM.size(1), 1, 1))
        Ffm_n = torch.mul(Frem, fN.view(fN.size(0), fN.size(1), 1, 1))
        
        Ffm = torch.add(Ffm_m, Ffm_n)
        
        return Ffm

# ResNet

In [4]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
              nn.Conv2d(in_planes, self.expansion*planes, 
                        kernel_size=1, stride=stride, bias=False),
              nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


In [5]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, Mutex_attention, Fusion_attention, num_classes=2):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.layer0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64)
        )
        self.mutex_block_0 = Mutex_attention()
        self.Fusion_block_0 = Fusion_attention(64, 64 , 54)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.mutex_block_1 = Mutex_attention()
        self.Fusion_block_1 = Fusion_attention(256, 256, 54)
        
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.mutex_block_2 = Mutex_attention()
        self.Fusion_block_2 = Fusion_attention(512, 512, 27)
        
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.mutex_block_3 = Mutex_attention()
        self.Fusion_block_3 = Fusion_attention(1024, 1024, 14)
        
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.mutex_block_4 = Mutex_attention()
        self.Fusion_block_4 = Fusion_attention(2048, 2048, 7)
        
        self.linear = nn.Linear(512*block.expansion, num_classes)
        # self.linear = nn.Linear(100352, num_classes)


    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x, mutex_x):
        
        
        # Res_layer_0
        print('Res_layer_0')
        Frei = F.relu(self.layer0(x))
        Frei = nn.MaxPool2d(kernel_size = 3, stride = 2)(Frei)
        
        Frem = F.relu(self.layer0(mutex_x))
        Frem = nn.MaxPool2d(kernel_size = 3, stride = 2)(Frem)
        
        Fam = self.mutex_block_0(Frei, Frem)
        
        Fom = self.Fusion_block_0(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_1
        print('Res_layer_1')
        
        Frei = F.relu(self.layer1(Fi))
        
        Frem = F.relu(self.layer1(Fm))
        
        Fam = self.mutex_block_1(Frei, Frem)
        
        Fom = self.Fusion_block_1(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_2
        print('Res_layer_2')
        
        Frei = F.relu(self.layer2(Fi))
        
        Frem = F.relu(self.layer2(Fm))
        
        Fam = self.mutex_block_2(Frei, Frem)
        
        Fom = self.Fusion_block_2(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_3
        print('Res_layer_3')
        
        Frei = F.relu(self.layer3(Fi))
        
        Frem = F.relu(self.layer3(Fm))
        
        Fam = self.mutex_block_3(Frei, Frem)
        
        Fom = self.Fusion_block_3(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_4
        print('Res_layer_4')
        
        Frei = F.relu(self.layer4(Fi))
        
        Frem = F.relu(self.layer4(Fm))
        
        Fam = self.mutex_block_4(Frei, Frem)
        
        Fom = self.Fusion_block_4(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        
        # out = self.layer1(out)
        # out = self.layer2(out)
        # out = self.layer3(out)
        # out = self.layer4(out)
        # out = F.avg_pool2d(out, 4)
        # out = out.view(out.size(0), -1)
        # # print(out.shape)
        # out = self.linear(out)
        return Fo, Fom


In [6]:
model = ResNet(Bottleneck, [3, 4, 6, 3], Mutex_block, Fusion_block)

In [7]:
Fo, Fom = model(torch.randn(16, 3, 224, 224), torch.randn(16, 3, 224, 224))
print(Fo.shape)
print(Fom.shape)

Res_layer_0
first add: torch.Size([16, 64, 54, 54])
add after poolings: torch.Size([16, 64])
bifore first fc: torch.Size([16, 64])
after first fc: torch.Size([16, 64])
after second fc: torch.Size([16, 64])
size of fM: torch.Size([16, 64])
size of fN: torch.Size([16, 64])
unsqueeze fM: torch.Size([1, 16, 64])
unsqueeze fN: torch.Size([1, 16, 64])
size of concat fM, fN: torch.Size([2, 16, 64])
after softmax: torch.Size([2, 16, 64])
separate fM: torch.Size([16, 64])
separate fN: torch.Size([16, 64])
size of fM after view: torch.Size([16, 64, 1, 1])
shape of Fo is: torch.Size([16, 64, 54, 54])
--------------------------------------------------
Res_layer_1
first add: torch.Size([16, 256, 54, 54])
add after poolings: torch.Size([16, 256])
bifore first fc: torch.Size([16, 256])
after first fc: torch.Size([16, 256])
after second fc: torch.Size([16, 256])
size of fM: torch.Size([16, 256])
size of fN: torch.Size([16, 256])
unsqueeze fM: torch.Size([1, 16, 256])
unsqueeze fN: torch.Size([1, 16, 2

In [8]:
input_names = ['ct_image', 'mutex_image']
output_names = ['Fo', 'Fom']
torch.onnx.export(model, (torch.randn(16, 3, 224, 224), torch.randn(16, 3, 224, 224)), 'final_model.onnx', input_names=input_names, output_names=output_names)

Res_layer_0
first add: torch.Size([16, 64, 54, 54])
add after poolings: torch.Size([16, 64])
bifore first fc: torch.Size([16, 64])
after first fc: torch.Size([16, 64])
after second fc: torch.Size([16, 64])
size of fM: torch.Size([16, 64])
size of fN: torch.Size([16, 64])
unsqueeze fM: torch.Size([1, 16, 64])
unsqueeze fN: torch.Size([1, 16, 64])
size of concat fM, fN: torch.Size([2, 16, 64])
after softmax: torch.Size([2, 16, 64])
separate fM: torch.Size([16, 64])
separate fN: torch.Size([16, 64])
size of fM after view: torch.Size([16, 64, 1, 1])
shape of Fo is: torch.Size([16, 64, 54, 54])
--------------------------------------------------
Res_layer_1
first add: torch.Size([16, 256, 54, 54])
add after poolings: torch.Size([16, 256])
bifore first fc: torch.Size([16, 256])
after first fc: torch.Size([16, 256])
after second fc: torch.Size([16, 256])
size of fM: torch.Size([16, 256])
size of fN: torch.Size([16, 256])
unsqueeze fM: torch.Size([1, 16, 256])
unsqueeze fN: torch.Size([1, 16, 2