In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

In [8]:
class Recurrent_Conv(nn.Module):
    """Recurrent 3-D convolution unit: one shared Conv3d + ReLU applied repeatedly.

    The conv (kernel 3, stride 1, padding 1) is applied once to the input and
    then ``t`` more times to ``x + previous_output``, so ``forward`` performs
    ``t + 1`` convolutions that all share the same weights. Spatial size and
    channel count are preserved.

    Args:
        out_channels: number of input and output channels. They must match
            because the input is added to the conv output at every step.
        t: number of recurrent refinement steps. ``t == 0`` degenerates to a
            single conv + ReLU.

    Raises:
        ValueError: if ``t`` is negative.
    """

    def __init__(self, out_channels, t):
        super(Recurrent_Conv, self).__init__()
        if t < 0:
            raise ValueError(f"t must be non-negative, got {t}")
        self.t = t
        self.out_ch = out_channels
        self.conv = nn.Sequential(
            nn.Conv3d(self.out_ch, self.out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Initial state: one conv on the raw input. The previous version only
        # assigned this inside ``if i == 0`` and so crashed with
        # UnboundLocalError when t == 0.
        x1 = self.conv(x)
        for _ in range(self.t):
            # Re-inject the original input at every step (weights are shared).
            x1 = self.conv(x + x1)
        return x1
        
class R2CNN(nn.Module):
    """Recurrent-residual block: 1x1x1 projection, two recurrent convs, skip sum.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the projection and kept by both
            recurrent stages.
        t: recurrence count forwarded to each ``Recurrent_Conv``.
    """

    def __init__(self, in_channels, out_channels, t):
        super().__init__()
        self.in_ch, self.out_ch, self.t = in_channels, out_channels, t
        # The recurrent stack is registered before the projection conv so that
        # parameter registration / initialisation order is unchanged.
        stages = [Recurrent_Conv(out_channels, t) for _ in range(2)]
        self.RCNN = nn.Sequential(*stages)
        self.initial_conv = nn.Conv3d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        projected = self.initial_conv(x)
        # Residual connection around the recurrent stack.
        return projected + self.RCNN(projected)

class R2Unet(nn.Module):
    """3-D R2U-Net: encoder/decoder built from ``R2CNN`` blocks with skip connections.

    Encoder: four R2CNN stages of widths f, 2f, 4f, 8f (f = ``base_channels``)
    separated by 2x average pooling. Decoder: three transposed-conv 2x
    upsamplings, each concatenated (skip first) with the matching encoder map
    and refined by an R2CNN, followed by a 1x1x1 conv to ``out_channels``.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels of the output volume.
        t: recurrence count passed to every R2CNN block.
        base_channels: width of the first stage. The default of 16 reproduces
            the widths 16/32/64/128.

    Note:
        Each spatial dimension of the input must be divisible by 8. Three
        AvgPool3d(2) halvings are undone by three exact 2x upsamplings, and
        ``torch.cat`` requires skip and upsampled maps to have equal sizes.
    """

    def __init__(self, in_channels, out_channels, t, base_channels=16):
        super(R2Unet, self).__init__()
        self.in_ch = in_channels
        self.out_ch = out_channels
        self.t = t
        f = base_channels
        self.avgPool = nn.AvgPool3d(2)
        # Encoder. ``t`` is now honoured; it used to be ignored in favour of a
        # hard-coded 2.
        self.r2cnn1 = R2CNN(self.in_ch, f, t)
        self.r2cnn2 = R2CNN(f, 2 * f, t)
        self.r2cnn3 = R2CNN(2 * f, 4 * f, t)
        self.r2cnn4 = R2CNN(4 * f, 8 * f, t)
        # Decoder: each R2CNN input width = skip channels + upsampled channels.
        self.up_conv1 = self.up_conv_layer(8 * f, 4 * f, 3, 2, 1, 1)
        self.r2cnn5 = R2CNN(8 * f, 2 * f, t)   # 4f (skip) + 4f (upsampled)
        self.up_conv2 = self.up_conv_layer(2 * f, 2 * f, 3, 2, 1, 1)
        self.r2cnn6 = R2CNN(4 * f, f, t)       # 2f (skip) + 2f (upsampled)
        self.up_conv3 = self.up_conv_layer(f, f, 3, 2, 1, 1)
        self.r2cnn7 = R2CNN(2 * f, f, t)       # f (skip) + f (upsampled)
        self.conv11 = nn.Conv3d(f, self.out_ch, kernel_size=1, stride=1, padding=0)

    def up_conv_layer(self, in_channels, out_channels, kernel_size, stride=3, padding=1, output_padding=1, bias=True):
        """Build a ConvTranspose3d + ReLU upsampling stage.

        With kernel_size=3, stride=2, padding=1, output_padding=1 (as passed in
        ``__init__``), each spatial dimension is exactly doubled. ``bias`` is
        now forwarded; it used to be ignored and forced to True.
        """
        return nn.Sequential(
            nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                bias=bias,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        # Encoder (channels @ spatial size, D = input size)
        enc1 = self.r2cnn1(x)                           # f  @ D
        enc2 = self.r2cnn2(self.avgPool(enc1))          # 2f @ D/2
        enc3 = self.r2cnn3(self.avgPool(enc2))          # 4f @ D/4
        bottleneck = self.r2cnn4(self.avgPool(enc3))    # 8f @ D/8
        # Decoder: concatenate the skip map first, then the upsampled map.
        dec3 = self.r2cnn5(torch.cat((enc3, self.up_conv1(bottleneck)), dim=1))
        dec2 = self.r2cnn6(torch.cat((enc2, self.up_conv2(dec3)), dim=1))
        dec1 = self.r2cnn7(torch.cat((enc1, self.up_conv3(dec2)), dim=1))
        return self.conv11(dec1)
        

In [9]:
torch.manual_seed(0)  # reproducible weight init and input
model = R2Unet(1, 1, 2)
# torch.rand samples U[0, 1) directly; randn(...).uniform_(0, 1) drew normals only to overwrite them.
# Each spatial dim must be divisible by 8 for R2Unet's skip concatenations.
inputs = torch.rand(1, 1, 32, 32, 32)
model.eval()
with torch.no_grad():  # forward-only smoke test: no autograd graph needed
    output = model(inputs)
assert output.shape == (1, 1, 32, 32, 32)

In [10]:
output.shape

torch.Size([1, 1, 32, 32, 32])