In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from .TimeDistributed import TimeDistributed
from .ConvLSTM import ConvLSTM
from .unet_parts import *

class LSTM_UNet(nn.Module):
    def __init__(self, n_channels, n_classes, num_filter = 64, bilinear=False):
        super(LSTM_UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = TimeDistributed(DoubleConv(n_channels, num_filter))
        self.down1 = TimeDistributed(Down(num_filter, num_filter*2))
        self.down2 = TimeDistributed(Down(num_filter*2, num_filter*4))
        self.down3 = TimeDistributed(Down(num_filter*4, num_filter*8))
        factor = 2 if bilinear else 1
        self.down4 = TimeDistributed(Down(num_filter*8, num_filter*16 // factor))
        
        self.lstm = ConvLSTM(input_dim = num_filter*16 // factor,
                    hidden_dim=num_filter*16 // factor,
                    kernel_size=(3,3),
                    num_layers=2,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)
        
        self.lf1x1 = nn.Conv2d(num_filter*16, num_filter*8, kernel_size=1)
        self.lstm1x1 = nn.Conv2d(num_filter*16, num_filter*8, kernel_size=1)
        
        self.up1 = (Up(num_filter*16, num_filter*8 // factor, bilinear))
        self.up2 = (Up(num_filter*8, num_filter*4 // factor, bilinear))
        self.up3 = (Up(num_filter*4, num_filter*2 // factor, bilinear))
        self.up4 = (Up(num_filter*2, num_filter, bilinear))
        self.outc = (OutConv(num_filter, n_classes))
        
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        _ , [(h, _)] = self.lstm(x5)
        
        x6 = self.lf1x1(x5[:,-1,...])
    
        x = self.lstm1x1(h)
        
        x = self.up1(torch.cat([x,x6], dim=1), x4[:,-1,...])
        x = self.up2(x, x3[:,-1,...])
        x = self.up3(x, x2[:,-1,...])
        x = self.up4(x, x1[:,-1,...])
        logits = self.outc(x)
        return logits

In [82]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from TimeDistributed import TimeDistributed
from ConvLSTM import ConvLSTM
from unet_parts import *

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_ch, out_ch, in_ch // 2)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_ch, in_ch//2)
        self.conv2 = nn.Conv2d(in_ch//2, out_ch, kernel_size=1)

    def forward(self, x):
        x = torch.cat(x, dim=1)
        x = self.conv(x)
        x = self.up(x)
        return self.conv2(x)
    
class LSTM_UNet(nn.Module):
    def __init__(self, n_channels, n_classes, num_filter=64, bilinear=False):
        
        n = [num_filter*2**i for i in range(4)] #[64,128,256,512,512]
        n.append(n[-1])
        t_filter = [2,8,32,128]
        n_filter = [n[i] - t_filter[i] for i in range(4)] #[62,120,224,384]
        
        super(LSTM_UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

#         self.inc = TimeDistributed(DoubleConv(1, n_filter[0]))
#         self.down1 = TimeDistributed(Down(n_filter[0], n_filter[1]))
#         self.down2 = TimeDistributed(Down(n_filter[1], n_filter[2]))
#         self.down3 = TimeDistributed(Down(n_filter[2], n_filter[3]))
#         factor = 2 if bilinear else 1
#         self.down4 = TimeDistributed(Down(n_filter[3], n[4]))
        
        self.inc = (DoubleConv(1, n_filter[0]))
        self.down1 = (Down(n_filter[0], n_filter[1]))
        self.down2 = (Down(n_filter[1], n_filter[2]))
        self.down3 = (Down(n_filter[2], n_filter[3]))
        factor = 2 if bilinear else 1
        self.down4 = (Down(n_filter[3], n[4]))
        
        self.inc_nograd = TimeDistributed(self.inc)
        
        
        self.lstm1 = ConvLSTM_BLock(n_filter[0], t_filter[0])
        self.lstm2 = ConvLSTM_BLock(n_filter[1], t_filter[1])
        self.lstm3 = ConvLSTM_BLock(n_filter[2], t_filter[2])
        self.lstm4 = ConvLSTM_BLock(n_filter[3], t_filter[3])
        self.lstm5 = ConvLSTM_BLock(n[4], n[4])
        
        # n = [64,128,256,512,512]
        self.up4 = (Up(n[3]*2, n[3], bilinear))
        self.up3 = (Up(n[3]*2, n[2], bilinear))
        self.up2 = (Up(n[2]*2, n[1], bilinear))
        self.up1 = (Up(n[1]*2, n[0], bilinear))
        self.outc = (OutConv(n[0]*2, 1))
        
    def forward(self, x):

        x1 = self.inc(x[:,-1,...])
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from TimeDistributed import TimeDistributed
from ConvLSTM import ConvLSTM
from unet_parts import *

class LSTM_UNet(nn.Module):
    def __init__(self, n_channels, n_classes, num_filter = 64, bilinear=False):
        super(LSTM_UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, num_filter))
        self.down1 = (Down(num_filter, num_filter*2))
        self.down2 = (Down(num_filter*2, num_filter*4))
        self.down3 = (Down(num_filter*4, num_filter*8))
        factor = 2 if bilinear else 1
        self.down4 = (Down(num_filter*8, num_filter*16 // factor))
        
        self.lstm = ConvLSTM(input_dim = num_filter*16 // factor,
                    hidden_dim=num_filter*16 // factor,
                    kernel_size=(3,3),
                    num_layers=2,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)
        
        self.lf1x1 = nn.Conv2d(num_filter*16, num_filter*8, kernel_size=1)
        self.lstm1x1 = nn.Conv2d(num_filter*16, num_filter*8, kernel_size=1)
        
        self.up1 = (Up(num_filter*16, num_filter*8 // factor, bilinear))
        self.up2 = (Up(num_filter*8, num_filter*4 // factor, bilinear))
        self.up3 = (Up(num_filter*4, num_filter*2 // factor, bilinear))
        self.up4 = (Up(num_filter*2, num_filter, bilinear))
        self.outc = (OutConv(num_filter, n_classes))
        
    def forward(self, x):
        x1 = self.inc(x[:,-1,...])
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        with torch.no_grad():
            x1_ = TimeDistributed(self.inc)(x[:,:-1,...].contiguous())
            x2_ = TimeDistributed(self.down1)(x1_)
            x3_ = TimeDistributed(self.down2)(x2_)
            x4_ = TimeDistributed(self.down3)(x3_)
            x5_ = TimeDistributed(self.down4)(x4_)
            
        _ , [(h, _)] = self.lstm(torch.cat((x5_,x5.unsqueeze(1)),1))
        
        x6 = self.lf1x1(x5)
    
        x = self.lstm1x1(h)
        
        x = self.up1(torch.cat([x,x6], dim=1), x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
        
        t1 = self.lstm1(torch.cat((x1_,x1.unsqueeze(1)),1))
        t2 = self.lstm2(torch.cat((x2_,x2.unsqueeze(1)),1))
        t3 = self.lstm3(torch.cat((x3_,x3.unsqueeze(1)),1))
        t4 = self.lstm4(torch.cat((x4_,x4.unsqueeze(1)),1))
        t5 = self.lstm5(torch.cat((x5_,x5.unsqueeze(1)),1))
        
        x = self.up4([x5,t5])
        x = self.up3([x4,t4,x])
        x = self.up2([x3,t3,x])
        x = self.up1([x2,t2,x])

        logits = self.outc(torch.cat([x1,t1,x], dim = 1))
        return logits

In [83]:
# from LSTM_v1 import LSTM_UNet
model = LSTM_UNet(n_channels=1, n_classes=1, num_filter=32)
a = torch.rand([4,10,1,448,336])

out = model(a)
out.shape

torch.Size([4, 1, 448, 336])

In [85]:
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params/1e6

45.777953

In [57]:
a = torch.rand([4,10,1,448,336])
b = a[:,:-1,:,:,:].contiguous()
l = DoubleConv(1, 32)
t = TimeDistributed(l)(b)
t.shape

torch.Size([4, 9, 32, 448, 336])

In [38]:
a.shape

torch.Size([4, 10, 1, 448, 336])

In [34]:
a = torch.rand([4,10,1,448,336])
t(a[:,:-1,...]).shape

TypeError: 'Tensor' object is not callable

In [14]:
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params/1e6

108.602731

In [15]:
out.shape

torch.Size([1, 1, 448, 336])

In [85]:
a = torch.rand([1,10,62,448,336])

l = ConvLSTM_BLock(62,2)
l(a).shape

torch.Size([1, 2, 448, 336])

In [63]:
up = ConvLSTM(input_dim = 1,
                    hidden_dim=4,
                    kernel_size=(3,3),
                    num_layers=3,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)

model_parameters = filter(lambda p: p.requires_grad, up.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params/1e6

0.003072

In [64]:
[x], _ = up(a)

In [65]:
x.shape

torch.Size([1, 10, 4, 448, 336])

In [32]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )


    def forward(self,x):
        x = self.conv(x)
        return x
    
class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x
    
class Attention_block(nn.Module):
    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
            )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return x*psi


# In[5]:


from TimeDistributed import TimeDistributed
from ConvLSTM import ConvLSTM
from unet_parts import *

class Att_LSTM_UNet(nn.Module):
    def __init__(self, n_channels, n_classes, num_filter = 32, bilinear=False):
        super(Att_LSTM_UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = TimeDistributed(DoubleConv(n_channels, num_filter))
        self.down1 = TimeDistributed(Down(num_filter, num_filter*2))
        self.down2 = TimeDistributed(Down(num_filter*2, num_filter*4))
        self.down3 = TimeDistributed(Down(num_filter*4, num_filter*8))
        factor = 2 if bilinear else 1
        self.down4 = TimeDistributed(Down(num_filter*8, num_filter*16 // factor))
        
        self.lstm = ConvLSTM(input_dim = num_filter*16 // factor,
                    hidden_dim=num_filter*16 // factor,
                    kernel_size=(3,3),
                    num_layers=2,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)
        
        self.Up5 = up_conv(ch_in=num_filter*16,ch_out=num_filter*8)
        self.Att5 = Attention_block(F_g=num_filter*8,F_l=num_filter*8,F_int=num_filter*4)
        self.Up_conv5 = DoubleConv(in_channels=num_filter*16, out_channels=num_filter*8)

        self.Up4 = up_conv(ch_in=num_filter*8,ch_out=num_filter*4)
        self.Att4 = Attention_block(F_g=num_filter*4,F_l=num_filter*4,F_int=num_filter*2)
        self.Up_conv4 = DoubleConv(in_channels=num_filter*8, out_channels=num_filter*4)
        
        self.Up3 = up_conv(ch_in=num_filter*4,ch_out=num_filter*2)
        self.Att3 = Attention_block(F_g=num_filter*2,F_l=num_filter*2,F_int=num_filter)
        self.Up_conv3 = DoubleConv(in_channels=num_filter*4, out_channels=num_filter*2)
        
        self.Up2 = up_conv(ch_in=num_filter*2,ch_out=num_filter)
        self.Att2 = Attention_block(F_g=num_filter,F_l=num_filter,F_int=num_filter//2)
        self.Up_conv2 = DoubleConv(in_channels=num_filter*2, out_channels=num_filter)

        self.Conv_1x1 = nn.Conv2d(num_filter,n_classes,kernel_size=1,stride=1,padding=0)
        
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        _ , [(h, c)] = self.lstm(x5)
        print(x5.shape, h.shape)
        # decoding + concat path
        d5 = self.Up5(c)
        x4 = self.Att5(g=d5,x=x4[:,-1,...])
        d5 = torch.cat((x4,d5),dim=1)        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3[:,-1,...])
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2[:,-1,...])
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1[:,-1,...])
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)
        
        d1 = self.Conv_1x1(d2)

        return d1

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

In [2]:
a = torch.rand([4,10,512,8,8])
num_filter=32
factor = 1
layer = ConvLSTM(input_dim = num_filter*16 // factor,
                    hidden_dim=num_filter*16 // factor,
                    kernel_size=(3,3),
                    num_layers=2,
                    batch_first=True,
                    bias=True,
                    return_all_layers=False)
[x], _ = layer(a)

In [33]:
model = Att_LSTM_UNet(n_channels=1, n_classes=1, num_filter = 32)

In [34]:
a = torch.rand([1,10,1,448,336])

out = model(a)

torch.Size([1, 10, 512, 28, 21]) torch.Size([1, 512, 28, 21])


In [35]:
import numpy as np
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params/1e6

46.475389

In [36]:
out.shape

torch.Size([1, 1, 448, 336])

In [9]:
torch.squeeze(out).shape

torch.Size([4, 10, 448, 336])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )


    def forward(self,x):
        x = self.conv(x)
        return x
    
class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x
    
class Attention_block(nn.Module):
    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
            )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return x*psi

In [None]:
a =torch.rand([4,10,1,448,336])
model = Att_LSTM_UNet(n_channels=1, n_classes=1, num_filter = 32)

In [None]:
args_load = './checkpoints_LSTM/checkpoint_epoch50.pth'

if args_load:
    state_dict = torch.load(args_load)#, map_location=device)

In [None]:
model.state_dict()

In [None]:
for name, module in model.named_children():
    if not name.startswith('params'):
        print(name)
        print(module)
        print('------')