In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
from collections import OrderedDict
import torch.nn.functional as F

In [2]:
class SE(nn.Module):
    def __init__(self, channel, out_chan, reduction_ratio =2, ):
        super(SE, self).__init__()
        ### squeeze
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        
        ### excitation
        self.excitation = nn.Sequential(
            nn.Linear(channel, channel // reduction_ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction_ratio, channel, bias=False),
            nn.Sigmoid()
        )

        ### ??
        self.c = nn.Conv2d(channel, out_chan, 1, 1)

    def forward(self, x):
        if x.ndim == 4:
            b = x.size(0)
            c = x.size(1)
        if x.ndim == 5:
            b = x.size(1)
            c = x.size(2)
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).unsqueeze(2).unsqueeze(3)
        return self.c(x*y)

In [3]:
class RefNet_basic(nn.Module):
    def __init__(self, in_chan, out_chan, se=False, maxpool=False):
        super(RefNet_basic, self).__init__()

        self.conv = nn.Conv3d(in_chan, in_chan, 1, 1)
        self.relu = nn.ReLU()

        self.se = None
        self.maxpool = None

        if se:
            # self.se = torchvision.ops.SqueezeExcitation(in_chan, out_chan)
            self.se = SE(in_chan, out_chan)
            # self.se = SpatialSELayer3D(in_chan, out_chan)
        if maxpool:
            self.maxpool = nn.MaxPool2d(2,2)
            self.conv_2 = nn.Conv2d(in_chan, out_chan, 1, 1)

    def forward(self, X):
        # X = B, C, H, W
        x = X.transpose(0,1)
        x_conv =  self.conv(X)

        t = self.relu(x_conv)
        t = self.conv(t)
        t = self.relu(t)

        y = torch.add(x_conv, t)

        if self.se is not None:
            y = y.transpose(1,0)
            y = self.se(y)
        if self.maxpool is not None:
            y = self.maxpool(y)
            y = y.transpose(1,0)
            y = self.conv_2(y)


        y = y.transpose(1,0)
        
        return y

In [4]:
class RefNet(nn.Module):
    def __init__(self, in_chan):
        super(RefNet, self).__init__()

        self.down_1 = RefNet_basic(in_chan, 32, maxpool=True)
        self.att_1 = RefNet_basic(in_chan, 32, se=True)
        
        self.down_2 = RefNet_basic(32, 64, maxpool=True)
        self.conv_1 = nn.Conv2d(64, 128, 1 , 1)
        self.att_2 = RefNet_basic(32, 64, se=True)

        self.deconv_1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)

        self.refbasic_1 = RefNet_basic(64, 64)

        self.deconv_2 = nn.ConvTranspose2d(64, 32, 4, 2, 1)

        self.refbasic_2 = RefNet_basic(32, 32)


    
    def forward(self, X):
        # X : S, B, C, H, W
        # seq_number, batch_size, input_channel, height, width = X.size()
        # X = torch.reshape(X, (-1, input_channel, height, width))
        # X = torch.reshape(X, (seq_number, batch_size, X.size(1), X.size(2), X.size(3)))

        X = X.transpose(0,1)

        print('before att:', X.size())

        res_1 = self.att_1(X) # -> 32
        down_1 = self.down_1(X) # -> 32

        res_2 = self.att_2(down_1) # -> 64
        down_2 = self.down_2(down_1) # -> 64


        down_2 = down_2.transpose(1,0)
        conv_1 = self.conv_1(down_2) # -> 128
        down_2 = down_2.transpose(1,0)
        #conv_1 = torch.add(conv_1, conv_1) # ??

        deconv_1 = self.deconv_1(conv_1) # -> 64
        deconv_1 = deconv_1.transpose(1,0)
        deconv_1 = torch.add(deconv_1, res_2) # -> 64
        deconv_1 = self.refbasic_1(deconv_1) # -> 64

        deconv_2 = self.deconv_2(deconv_1) # -> 32
        deconv_2 = deconv_2.transpose(1,0)
        deconv_2 = torch.add(deconv_2, res_1) # -> 32
        deconv_2 = self.refbasic_2(deconv_2) # -> 32

        return nn.Conv2d(32, 1, 1, 1)(deconv_2) # -> 1


In [5]:
in_chan = 1
x = torch.randn(10,1,300,300)
x = x.transpose(1,0)

In [6]:
down_1 = RefNet_basic(in_chan, 32, maxpool=True)
att_1 = RefNet_basic(in_chan, 32, se=True)
down_2 = RefNet_basic(32, 64, maxpool=True)
conv_1 = nn.Conv2d(64, 128, 1 , 1)
att_2 = RefNet_basic(32, 64, se=True)
deconv_1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
refbasic_1 = RefNet_basic(64, 64)
deconv_2 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
refbasic_2 = RefNet_basic(32, 32)



In [7]:
print(x.size())
down1 = down_1(x)
att1 = att_1(x)
print("down1:", down1.size())
print("att1:", att1.size())

torch.Size([1, 10, 300, 300])
down1: torch.Size([32, 10, 150, 150])
att1: torch.Size([32, 10, 300, 300])


In [8]:
att2 = att_2(down1) # -> 64
down2 = down_2(down1) # -> 64
print("down2:", down2.size())
print("att2:", att2.size())

down2: torch.Size([64, 10, 75, 75])
att2: torch.Size([64, 10, 150, 150])


In [9]:
down2 = down2.transpose(1,0)
conv1 = conv_1(down2) # -> 128
down2 = down2.transpose(1,0)
print(conv1.size())

torch.Size([10, 128, 75, 75])


In [10]:
deconv1 = deconv_1(conv1) # -> 64
deconv1 = deconv1.transpose(1,0)
deconv1 = torch.add(deconv1, att2) # -> 64
deconv1 = refbasic_1(deconv1) # -> 64

print(deconv1.size())

torch.Size([10, 64, 150, 150])


In [11]:
deconv2 = deconv_2(deconv1) # -> 32
deconv2 = deconv2.transpose(1,0)
deconv2 = torch.add(deconv2, att1) # -> 32
deconv2 = refbasic_2(deconv2) # -> 32

output = nn.Conv2d(32, 1, 1, 1)(deconv2) # -> 1

print(deconv2.size())
print(output.size())

torch.Size([10, 32, 300, 300])
torch.Size([10, 1, 300, 300])


In [12]:
x = torch.randn(10,1,300,300)
RefNet(1)(x).size()

before att: torch.Size([1, 10, 300, 300])


torch.Size([10, 1, 300, 300])

---

In [13]:
x = torch.randn(10,128,1,128,128)
print(x.size())
seq_number, batch_size, input_channel, height, width = x.size()
torch.reshape(x, (-1, input_channel, height, width)).size()

torch.Size([10, 128, 1, 128, 128])


torch.Size([1280, 1, 128, 128])

In [14]:
x = torch.randn(5, 128, 1, 128, 128)
seq_number, batch_size, input_channel, height, width = x.size()

print("before:", x.size())
x = torch.reshape(x, (-1, input_channel, height, width))
print("after:", x.size())
y = RefNet(1)(x)
print("output:", y.size())
y = torch.reshape(y, (seq_number, batch_size, x.size(1), x.size(2), x.size(3)))
print("after output:", y.size())

before: torch.Size([5, 128, 1, 128, 128])
after: torch.Size([640, 1, 128, 128])
before att: torch.Size([1, 640, 128, 128])
output: torch.Size([640, 1, 128, 128])
after output: torch.Size([5, 128, 1, 128, 128])
