In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# softmax test

# Mutex Attention Block

In [2]:
class Mutex_block(nn.Module):
    def __init__(self):
        super(Mutex_block, self).__init__()
        # self.batch = batch
        # self.H = H
        # self.C = C
        # self.W = W
        self.softmax = nn.Softmax(dim=1)
        
        
    def forward(self, Frei, Frem):
        batch, C, H, W = Frei.shape
        
        Fam = torch.subtract(Frei, Frem)
        Fam = torch.pow(Fam, 2)
        
        #reshape
        Fam = Fam.view(batch, C, H*W)
        
        #softmax
        Fam = self.softmax(Fam)

        #reshape
        Fam = Fam.view(batch, C, H, W)
        
        #multiplication
        Fam = torch.mul(Fam, Frem)
        
        return Fam

# Fusion Attention Block

In [38]:
class Fusion_block(nn.Module):
    def __init__(self, in_fc, out_fc, pool_size):
        super(Fusion_block, self).__init__()
        self.in_fc = in_fc
        self.out_fc = out_fc
        
        self.softmax = nn.Softmax(dim=1)
        self.avgpool = nn.AvgPool2d(pool_size)
        self.maxpool = nn.MaxPool2d(pool_size)
        
        self.softmax = nn.Softmax(dim=0)
        
        # Set these parameters
        self.fcC = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(in_fc, out_fc),
            nn.ReLU())
        
        self.fcC_prim = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
        self.fcM = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
        self.fcN = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(out_fc, out_fc),
            nn.ReLU())
        
    def forward(self, Fam, Frem):
        f_temp = torch.add(Fam, Frem)
        # print(f'first add: {f_temp.shape}')
        
        avg_pool = self.avgpool(f_temp).squeeze()
        max_pool = self.maxpool(f_temp).squeeze()
        f_temp = torch.add(avg_pool, max_pool)
        # print(f'add after poolings: {f_temp.shape}')
        
        f_temp = f_temp.view(f_temp.size(0), -1)
        print(f'bifore first fc: {f_temp.shape}')
        
        
        f_temp = self.fcC(f_temp)
        print(f'after first fc: {f_temp.shape}')
        f_temp = self.fcC_prim(f_temp)
        print(f'after second fc: {f_temp.shape}')
        
        
        fM = self.fcM(f_temp)
        print(f'size of fM: {fM.shape}')
        fN = self.fcN(f_temp)
        print(f'size of fN: {fN.shape}')
        
        # unsqueeze
        fM = torch.unsqueeze(fM, 0)
        fN = torch.unsqueeze(fN, 0)
        # print(f'unsqueeze fM: {fM.shape}\nunsqueeze fN: {fN.shape}')
        
        # concatenate
        z = torch.concat([fM, fN], dim=0)
        # print(f'size of concat fM, fN: {z.shape}')
        
        # softmax
        z = self.softmax(z)
        # print(f'after softmax: {z.shape}')
        
        fM = z[0,...]
        # print(f'separate fM: {fM.shape}')
        fN = z[1,...]
        # print(f'separate fN: {fN.shape}')
        
        
        # print(f'size of fM after view: {fM.view(fM.size(0), fM.size(1), 1, 1).shape}')
        
        # fM --> [B, C]
        # Fam --> [B, C, H, W]
        Ffm_m = torch.mul(Fam, fM.view(fM.size(0), fM.size(1), 1, 1))
        Ffm_n = torch.mul(Frem, fN.view(fN.size(0), fN.size(1), 1, 1))
        
        Ffm = torch.add(Ffm_m, Ffm_n)
        
        return Ffm

# ResNet

In [39]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
              nn.Conv2d(in_planes, self.expansion*planes, 
                        kernel_size=1, stride=stride, bias=False),
              nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


In [40]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, Mutex_attention, Fusion_attention, num_classes=2):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.layer0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64)
        )
        self.mutex_block_0 = Mutex_attention()
        self.Fusion_block_0 = Fusion_attention(64, 64 , 54)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.mutex_block_1 = Mutex_attention()
        self.Fusion_block_1 = Fusion_attention(256, 256, 54)
        
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.mutex_block_2 = Mutex_attention()
        self.Fusion_block_2 = Fusion_attention(512, 512, 27)
        
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.mutex_block_3 = Mutex_attention()
        self.Fusion_block_3 = Fusion_attention(1024, 1024, 14)
        
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.mutex_block_4 = Mutex_attention()
        self.Fusion_block_4 = Fusion_attention(2048, 2048, 7)
        
        # pooling
        self.avgpool = nn.AvgPool2d(7)
        
        self.fc_o_1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU())
        self.fc_m_1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU())
        
        self.fc_o_2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 2),
            nn.ReLU())
        self.fc_m_2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 2),
            nn.ReLU())
        
        self.softmax_out = nn.Softmax(dim=1)
        
        # self.linear = nn.Linear(512*block.expansion, num_classes)
        # self.linear = nn.Linear(100352, num_classes)


    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x, mutex_x):
        
        
        # Res_layer_0
        print('Res_layer_0')
        Frei = F.relu(self.layer0(x))
        Frei = nn.MaxPool2d(kernel_size = 3, stride = 2)(Frei)
        
        Frem = F.relu(self.layer0(mutex_x))
        Frem = nn.MaxPool2d(kernel_size = 3, stride = 2)(Frem)
        
        Fam = self.mutex_block_0(Frei, Frem)
        
        Fom = self.Fusion_block_0(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        # print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_1
        print('Res_layer_1')
        
        Frei = F.relu(self.layer1(Fi))
        
        Frem = F.relu(self.layer1(Fm))
        
        Fam = self.mutex_block_1(Frei, Frem)
        
        Fom = self.Fusion_block_1(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        # print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_2
        print('Res_layer_2')
        
        Frei = F.relu(self.layer2(Fi))
        
        Frem = F.relu(self.layer2(Fm))
        
        Fam = self.mutex_block_2(Frei, Frem)
        
        Fom = self.Fusion_block_2(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        # print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_3
        print('Res_layer_3')
        
        Frei = F.relu(self.layer3(Fi))
        
        Frem = F.relu(self.layer3(Fm))
        
        Fam = self.mutex_block_3(Frei, Frem)
        
        Fom = self.Fusion_block_3(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        # print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # Res_layer_4
        print('Res_layer_4')
        
        Frei = F.relu(self.layer4(Fi))
        
        Frem = F.relu(self.layer4(Fm))
        
        Fam = self.mutex_block_4(Frei, Frem)
        
        Fom = self.Fusion_block_4(Fam, Frem)
        
        Fo = Frei
        
        # Set names
        Fi = Fo
        Fm = Fom
        
        #set Vi, Vm for calculate cosine loss
        Vi = Fo.clone()
        Vm = Fom.clone()
        # print(f'shape of Fo is: {Fo.shape}')
        print(50*'-')
        
        # ---------------------------------------------------------
        
        # global pooling
        Fo = self.avgpool(Fo).squeeze()
        Fom = self.avgpool(Fom).squeeze()
        # print(f'shape of Fo , Fm after pooling:\n \
        #         {Fo.shape}, {Fom.shape}')
        
        # ---------------------------------------------------------
        
        # final_fully_layer
        Fo = self.fc_o_1(Fo)
        Fom = self.fc_m_1(Fom)
        
        Fo = self.fc_o_2(Fo)
        Fom = self.fc_m_2(Fom)
        
        Fo = self.softmax_out(Fo)
        Fom = self.softmax_out(Fom)
        
        # print(f'Fo finel:{Fo.shape}')
        # print(f'Fm finel:{Fom.shape}')
        
        # out = self.layer1(out)
        # out = self.layer2(out)
        # out = self.layer3(out)
        # out = self.layer4(out)
        # out = F.avg_pool2d(out, 4)
        # out = out.view(out.size(0), -1)
        # # print(out.shape)
        # out = self.linear(out)
        return Fo, Fom, Vi, Vm


In [41]:
model = ResNet(Bottleneck, [3, 4, 6, 3], Mutex_block, Fusion_block)

In [42]:
Fo, Fom, Vi, Vm = model(torch.randn(16, 3, 224, 224), torch.randn(16, 3, 224, 224))
print(Vi.shape)
print(Vm.shape)
Fo.shape, Fom.shape

Res_layer_0
bifore first fc: torch.Size([16, 64])
after first fc: torch.Size([16, 64])
after second fc: torch.Size([16, 64])
size of fM: torch.Size([16, 64])
size of fN: torch.Size([16, 64])
--------------------------------------------------
Res_layer_1
bifore first fc: torch.Size([16, 256])
after first fc: torch.Size([16, 256])
after second fc: torch.Size([16, 256])
size of fM: torch.Size([16, 256])
size of fN: torch.Size([16, 256])
--------------------------------------------------
Res_layer_2
bifore first fc: torch.Size([16, 512])
after first fc: torch.Size([16, 512])
after second fc: torch.Size([16, 512])
size of fM: torch.Size([16, 512])
size of fN: torch.Size([16, 512])
--------------------------------------------------
Res_layer_3
bifore first fc: torch.Size([16, 1024])
after first fc: torch.Size([16, 1024])
after second fc: torch.Size([16, 1024])
size of fM: torch.Size([16, 1024])
size of fN: torch.Size([16, 1024])
--------------------------------------------------
Res_layer_4


(torch.Size([16, 2]), torch.Size([16, 2]))

In [34]:
input_names = ['ct_image', 'mutex_image']
output_names = ['Fo', 'Fom']
torch.onnx.export(model, (torch.randn(16, 3, 224, 224), torch.randn(16, 3, 224, 224)), 'final_model.onnx', input_names=input_names, output_names=output_names)

# train

In [12]:
device = 'cpu'

In [13]:
labels_i = torch.randint(0, 2, (16,))
labels_x = torch.randint(0, 2, (16,))
# labels_onehot = F.one_hot(labels).squeeze()
image_i = torch.rand(16, 3, 224, 224)
image_x = torch.rand(16, 3, 224, 224)

In [14]:
Fo, Fom, Vi, Vm = model(image_i, image_x)

In [15]:
Fom.shape

torch.Size([16, 2])

In [28]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, weight_decay = 0.005, momentum = 0.9)

In [17]:
# initial losses
Lce_ = nn.CrossEntropyLoss()
Lcs_ = nn.CosineSimilarity(dim=1, eps=1e-6)

In [24]:
# calculate losses
L1 = Lce_(Fo, labels_i)
L2 = Lce_(Fom, labels_x)
Lce = L1 + L2
(L1, L2, Lce)

(tensor(0.6745, grad_fn=<NllLossBackward0>),
 tensor(0.6917, grad_fn=<NllLossBackward0>),
 tensor(1.3662, grad_fn=<AddBackward0>))

In [21]:
Lcs = torch.mean(Lcs_(Vi.view(Vi.size(0), -1), Vm.view(Vm.size(0), -1)))

In [25]:
Lce_exp = torch.exp(1/Lce)
Lcs_exp = torch.exp(1/Lcs)
a1 = Lce_exp / (Lce_exp + Lcs_exp)
a2 = Lcs_exp / (Lce_exp + Lcs_exp)

In [29]:
loss = (a1 * Lce) + (a2 * Lcs)

In [30]:
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Loss Fusnctions

In [43]:
x = torch.rand(16, 20, 7, 7)
y = torch.rand(16, 20, 7, 7)

In [57]:
x = x.view(x.size(0), -1)
y = y.view(y.size(0), -1)

In [58]:
print(x.shape, y.shape)

torch.Size([16, 980]) torch.Size([16, 980])


In [54]:
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

In [60]:
cos(x, y)

tensor([0.7478, 0.7598, 0.7650, 0.7458, 0.7323, 0.7597, 0.7598, 0.7419, 0.7604,
        0.7672, 0.7420, 0.7642, 0.7558, 0.7374, 0.7600, 0.7442])

In [32]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [33]:
pytorch_total_params

47922500