In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [64]:
class URNetDownCNNBlock(nn.Module):
    """Encoder (down-sampling) stage of the URNet.

    Two Conv1d -> BatchNorm1d -> ReLU layers, then a GRU run along the
    sequence-length axis, then max pooling of the GRU features.

    Args:
        input_size (int): number of input channels.
        channel_size (int): output channels of both convolutions; also the
            input and hidden size of the GRU.
        conv_kernel (int): kernel size of both convolutions.
        max_kernel (int): kernel size of the max-pooling layer (MaxPool1d's
            stride defaults to the kernel size).
        stride (int): convolution stride. PyTorch only accepts
            ``padding='same'`` together with ``stride=1``.
        padding (str or int): convolution padding; ``'same'`` by default.
    """

    def __init__(self, input_size, channel_size, conv_kernel, max_kernel, stride = 1, padding = 'same'):
        super().__init__()

        self.conv1 = nn.Conv1d(input_size, channel_size, conv_kernel, stride, padding)
        self.conv2 = nn.Conv1d(channel_size, channel_size, conv_kernel, stride, padding)
        self.batchnorm1 = nn.BatchNorm1d(channel_size)
        self.batchnorm2 = nn.BatchNorm1d(channel_size)
        # Default GRU (batch_first=False): consumes [len, batch, channels].
        self.rnn = nn.GRU(channel_size, channel_size)
        self.maxpool = nn.MaxPool1d(max_kernel)

    def forward(self, x):
        """Args:
            x (tensor): shape [batch, channels, len]

        Returns:
            cnv_out (tensor): conv features, shape [batch, channel_size, len]
            rnn_out (tensor): GRU features, shape [batch, channel_size, len]
                (permuted back from the GRU's [len, batch, channels] layout)
            max_out (tensor): max-pooled GRU features,
                shape [batch, channel_size, len // max_kernel]
        """

        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        cnv_out = F.relu(x)

        # [batch, channels, len] -> [len, batch, channels] so the GRU steps over len.
        rnn_out, _ = self.rnn(cnv_out.permute(2, 0, 1))

        # Back to [batch, channels, len] so it can be pooled and used as a skip.
        rnn_out = rnn_out.permute(1, 2, 0)
        max_out = self.maxpool(rnn_out)

        return cnv_out, rnn_out, max_out

class URNetFlatCNNBlock(nn.Module):
    """Bottleneck stage: two stacked Conv1d -> BatchNorm1d -> ReLU layers,
    with no down- or up-sampling of its own.

    Args:
        input_size (int): number of input channels.
        channel_size (int): output channels of both convolutions.
        conv_kernel (int): kernel size of both convolutions.
        stride (int): convolution stride (PyTorch only accepts
            ``padding='same'`` with stride 1).
        padding (str or int): convolution padding; ``'same'`` by default.
    """

    def __init__(self, input_size, channel_size, conv_kernel, stride = 1, padding = 'same'):
        super().__init__()

        conv_opts = dict(stride=stride, padding=padding)
        self.conv1 = nn.Conv1d(input_size, channel_size, conv_kernel, **conv_opts)
        self.conv2 = nn.Conv1d(channel_size, channel_size, conv_kernel, **conv_opts)
        self.batchnorm1 = nn.BatchNorm1d(channel_size)
        self.batchnorm2 = nn.BatchNorm1d(channel_size)

    def forward(self, x):
        """Apply conv -> batchnorm -> ReLU twice.

        Args:
            x (tensor): shape [batch, channels, len]

        Returns:
            tensor of shape [batch, channel_size, len]
        """
        layer_pairs = (
            (self.conv1, self.batchnorm1),
            (self.conv2, self.batchnorm2),
        )
        for conv, norm in layer_pairs:
            x = F.relu(norm(conv(x)))
        return x

class URNetUpCNNBlock(nn.Module):
    """Decoder (up-sampling) stage of the URNet.

    Up-samples ``x`` with a ConvTranspose1d, concatenates it on the channel
    axis with the conv and GRU skip features, then applies two
    Conv1d -> BatchNorm1d -> ReLU layers.

    Args:
        input_size (int): channels of ``x`` entering the transposed convolution.
        channel_size (int): output channels of every layer. ``cnn_in`` and
            ``rnn_in`` must each carry ``channel_size`` channels, because the
            first convolution expects ``3 * channel_size`` input channels.
        conv_kernel (int): kernel size of the two convolutions.
        up_kernel (int): kernel size of the transposed convolution.
        up_stride (int): stride of the transposed convolution.
        conv_stride (int): stride of the two convolutions (PyTorch only
            accepts ``padding='same'`` with stride 1).
        padding (str or int): padding of the two convolutions.
    """

    def __init__(self, input_size, channel_size, conv_kernel, up_kernel, up_stride, conv_stride = 1, padding = 'same'):
        super().__init__()

        # Input = upsampled x + conv skip + GRU skip, each with channel_size channels.
        self.conv1 = nn.Conv1d(3 * channel_size, channel_size, conv_kernel, conv_stride, padding)
        self.conv2 = nn.Conv1d(channel_size, channel_size, conv_kernel, conv_stride, padding)
        self.batchnorm1 = nn.BatchNorm1d(channel_size)
        self.batchnorm2 = nn.BatchNorm1d(channel_size)
        self.upconv = nn.ConvTranspose1d(input_size, channel_size, up_kernel, up_stride)

    def forward(self, x, cnn_in, rnn_in):
        """Args:
            x (tensor): shape [batch, input_size, len_in]
            cnn_in (tensor): conv skip features, shape [batch, channel_size, len]
            rnn_in (tensor): GRU skip features, shape [batch, channel_size, len]

        Returns:
            out (tensor): shape [batch, channel_size, len] (length of the skips)
        """

        x = self.upconv(x)

        # ConvTranspose1d yields (len_in - 1) * up_stride + up_kernel samples, and an
        # encoder max-pool floors odd lengths, so the upsampled tensor need not match
        # the skip length. Zero-pad or center-crop it so the concat below succeeds.
        target_len = cnn_in.size(-1)
        diff = target_len - x.size(-1)
        if diff > 0:
            x = F.pad(x, (diff // 2, diff - diff // 2))
        elif diff < 0:
            start = (-diff) // 2
            x = x[..., start:start + target_len]

        x = torch.cat([x, cnn_in, rnn_in], dim = 1) # concatenate on channel dim
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        out = F.relu(x)

        return out

In [58]:
# Fix the RNG so the dummy input, and any weights initialised after this cell, are reproducible.
SEED = 0
torch.manual_seed(SEED)
# Dummy batch laid out as [batch, channels, len] = [4, 1, 200], as the block docstrings expect.
x = torch.randn((4, 1, 200))

In [67]:
# Encoder: channels 1 -> 16 -> 32 -> 64; MaxPool1d(2) halves (floors) the length at each stage.
DBlock1 = URNetDownCNNBlock(1, 16, conv_kernel=11, max_kernel=2)
DBlock2 = URNetDownCNNBlock(16, 32, conv_kernel=3, max_kernel=2)
DBlock3 = URNetDownCNNBlock(32, 64, conv_kernel=3, max_kernel=2)
# Bottleneck: 64 -> 128 channels at the lowest resolution.
FBlock = URNetFlatCNNBlock(64, 128, conv_kernel=3)
# Decoder: ConvTranspose1d with kernel 2 / stride 2 doubles the length at each stage.
UBlock1 = URNetUpCNNBlock(128, 64, conv_kernel=3, up_kernel=2, up_stride=2)
UBlock2 = URNetUpCNNBlock(64, 32, conv_kernel=3, up_kernel=2, up_stride=2)
UBlock3 = URNetUpCNNBlock(32, 16, conv_kernel=3, up_kernel=2, up_stride=2)

In [68]:
# Encoder pass. Each down block returns (conv features, GRU features, pooled GRU features);
# the first two are kept as skip connections for the matching decoder block.
cnv_out1, rnn_out1, max_out1 = DBlock1(x)
cnv_out2, rnn_out2, max_out2 = DBlock2(max_out1)
cnv_out3, rnn_out3, max_out3 = DBlock3(max_out2)

# The up blocks concatenate the conv and GRU skips on the channel axis, so their shapes must agree.
assert all(
    cnv.shape == rnn.shape
    for cnv, rnn in [(cnv_out1, rnn_out1), (cnv_out2, rnn_out2), (cnv_out3, rnn_out3)]
), "conv / GRU skip shapes differ"

In [69]:
# Bottleneck at the lowest resolution; FBlock uses the default 'same' padding, so the length is unchanged.
f_out = FBlock(max_out3)
assert f_out.shape[-1] == max_out3.shape[-1], "bottleneck changed the sequence length"

In [70]:
# Decoder pass: each up block upsamples and fuses the encoder skips of the same depth (deepest first).
up1_out = UBlock1(f_out, cnv_out3, rnn_out3)
up2_out = UBlock2(up1_out, cnv_out2, rnn_out2)
up3_out = UBlock3(up2_out, cnv_out1, rnn_out1)

# The decoder should restore the batch size and sequence length of the input.
assert up3_out.shape[0] == x.shape[0] and up3_out.shape[-1] == x.shape[-1], (up3_out.shape, x.shape)

torch.Size([4, 64, 50])
torch.Size([4, 32, 100])
torch.Size([4, 16, 200])


In [71]:
# Final decoder output, laid out as [batch, channels, len].
up3_out.size()

torch.Size([4, 16, 200])