In [88]:
import torch
import torch.nn as nn
# from models.swish import Swish
# from models.ndInterp import NDLinearInterpolation

class LinearBlock(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(LinearBlock,self).__init__()
        self.Linear = nn.Linear(in_channel,out_channel)
    #		self.Act = Swish()
        self.Act_3c = nn.ReLU()
    def forward(self,x):
        x = self.Linear(x)
        x = self.Act_3c(x)
        return x

class ConvBlock(nn.Module):
    def __init__(self,in_channel,out_channel,kernel_size,padding_mode='replicate'):
        super(ConvBlock,self).__init__()
        self.Conv = nn.Conv3d(in_channel,out_channel,kernel_size,padding=kernel_size//2,padding_mode=padding_mode)
        self.BatchNorm = nn.BatchNorm3d(out_channel)
    #		self.Act = Swish()
        self.Act_3c = nn.ReLU()

    def forward(self,x):
        x = self.Conv(x)
        x = self.BatchNorm(x)
        x = self.Act_3c(x)
        return x

class ResBlock(nn.Module):
    def __init__(self,in_channel,out_channel,padding_mode='replicate'):
        super(ResBlock,self).__init__()
        self.shortcut = nn.Conv3d(in_channel,out_channel,1)
        self.Conv_1 = ConvBlock(in_channel,in_channel,1,padding_mode=padding_mode)
        self.Conv_2 = ConvBlock(in_channel,in_channel,3,padding_mode=padding_mode)
        self.Conv_3a = nn.Conv3d(in_channel,out_channel,1)
        self.BatchNorm_3b = nn.BatchNorm3d(out_channel)
    #		self.Act_3c = Swish()
        self.Act_3c = nn.ReLU()

    def forward(self,x):
        y = self.Conv_1(x)
        y = self.Conv_2(y)
        y = self.Conv_3a(y)
        y = self.BatchNorm_3b(y)
        y = self.shortcut(x)+y
        y = self.Act_3c(y)
        return y 

class SamplingBlock(nn.Module):	
    def __init__(self,in_channel,out_channel,mode,padding_mode='replicate'):
        super(SamplingBlock,self).__init__()	
        if mode == 'Up':
            self.conv = nn.Sequential(*[nn.Upsample(scale_factor=2),ResBlock(in_channel,out_channel)])
        if mode == 'Down':
            self.conv = nn.Sequential(*[ResBlock(in_channel,out_channel),nn.MaxPool3d(2)])

    def forward(self,x):
            return self.conv(x)

class Encoder3D(nn.Module):
    def __init__(self,in_channel,out_channel,n_pairs,padding_mode='replicate'):
        super(Encoder3D,self).__init__()	
        self.Res1 = ResBlock(in_channel,out_channel)
        self.n_pairs = n_pairs
        self.Down_array = nn.Sequential(*[SamplingBlock(2**i*out_channel,2**(i+1)*out_channel,'Down') for i in range(n_pairs)])

    def forward(self,x):
        y = self.Res1(x)
        for i in range(self.n_pairs):
            y = self.Down_array[i](y)
        y = y.permute(0, 2, 3, 4, 1).squeeze(1).squeeze(1).squeeze(1)
        return y

class SR(nn.Module):
    def __init__(self,in_channel,out_channel,n_layers):
        super(SR,self).__init__()
        self.contextEncoder = Encoder3D(in_channel,out_channel,n_layers)
        self.output = nn.Linear(512*5,1)

    def forward(self, context):
        outputs = []
        for i in range(context.shape[0]):
            print(context[i].shape)
            outputs.append(self.contextEncoder(context[i].unsqueeze(0)))
        print(outputs[0].shape)
        combine = torch.cat(outputs, dim=1)
        print(combine.shape)
        output = self.output(combine)
        return output


In [89]:
sr = SR(1, 16, 5)

In [90]:
sr(torch.rand((5, 1, 32, 32, 32)), )

torch.Size([32, 32, 32])
torch.Size([32, 32, 32])
torch.Size([32, 32, 32])
torch.Size([32, 32, 32])
torch.Size([32, 32, 32])
torch.Size([1, 512])
torch.Size([1, 2560])


tensor([[-0.2495]], grad_fn=<AddmmBackward>)

In [86]:
summary(sr, (5, 1, 32, 32, 32))

torch.Size([5, 1, 32, 32, 32])


RuntimeError: Expected 5-dimensional input for 5-dimensional weight 1 1 1 1 1 140327380454600, but got 6-dimensional input of size [1, 5, 1, 32, 32, 32] instead

In [34]:
enc = Encoder3D(1, 16, 5)

In [35]:
from torchsummary import summary

In [36]:
summary(enc, (1, 32, 32, 32))

torch.Size([2, 512, 1, 1, 1])
torch.Size([2, 512])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 1, 32, 32, 32]               2
       BatchNorm3d-2        [-1, 1, 32, 32, 32]               2
              ReLU-3        [-1, 1, 32, 32, 32]               0
         ConvBlock-4        [-1, 1, 32, 32, 32]               0
            Conv3d-5        [-1, 1, 32, 32, 32]              28
       BatchNorm3d-6        [-1, 1, 32, 32, 32]               2
              ReLU-7        [-1, 1, 32, 32, 32]               0
         ConvBlock-8        [-1, 1, 32, 32, 32]               0
            Conv3d-9       [-1, 16, 32, 32, 32]              32
      BatchNorm3d-10       [-1, 16, 32, 32, 32]              32
           Conv3d-11       [-1, 16, 32, 32, 32]              32
             ReLU-12       [-1, 16, 32, 32, 32]               0
         ResBlock-13       [-1, 16, 32, 32, 32]     