In [1]:
import torch
import torch.nn.functional as F

class CNA3D(torch.nn.Module):
    def __init__(self,in_channels,out_channels,kSize,stride,padding=(1,1,1),bias=True,norm_args=None,activation_args=None):
        super().__init__()
        self.norm_args = norm_args
        self.activation_args = activation_args
        
        self.conv = torch.nn.Conv3d(in_channels,out_channels,kernel_size=kSize,stride=stride,padding=padding,bias=bias)
        
        if norm_args is not None:
            self.norm = torch.nn.InstanceNorm3d(out_channels, **norm_args)
        
        if activation_args is not None:
            self.activation = torch.nn.LeakyReLU(**activation_args)
            
    def forward(self,x):
        x = self.conv(x)
        
        if self.norm_args is not None:
            x = self.norm(x)
        if self.activation_args is not None:
            x = self.activation(x)
        return x
    
class CB3d(torch.nn.Module):
    def __init__(self,in_channels,out_channels,kSize=(3,3),stride=(1,1),padding=(1,1,1),bias=True,norm_args:tuple=(None,None),activation_args:tuple=(None,None)):
        super().__init__()
        
        self.conv1 = CNA3D(in_channels,out_channels,kSize=kSize[0],stride=stride[0],padding=padding,bias=bias,norm_args=norm_args[0],activation_args=activation_args[0])
        
        self.conv2 = CNA3D(out_channels,in_channels,kSize=kSize[1],stride=stride[1],padding=padding,bias=bias,norm_args=norm_args[1],activation_args=activation_args[1])
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        return x
    
class BasicNet(torch.nn.Module):
    norm_kwargs={"affine":True}
    activation_kwargs={"negative_slope":1e-2,"inplace":True}
    
    def __init__(self):
        super(BasicNet,self).__init__();

def FMU_sub(x2d, x3d, mode="sub"):  
        return torch.abs(x2d-x3d) 

class Downsample(BasicNet):
    def __init__(self,in_channels,out_channels,mode: tuple, downsample=True, min_z=8):
        super().__init__()
        self.mode_in,self.mode_out = mode 
        self.downsample = downsample 
        self.FMU = FMU_sub
        norm_args = (self.norm_kwargs, self.norm_kwargs)
        activation_args = (self.activation_kwargs, self.activation_kwargs)
        self.CB2d = None
        
        print(self.mode_out)
        
        if self.mode_out == '2d' or self.mode_out =='both':
            self.CB2d = CB3d(in_channels=in_channels,out_channels=out_channels,kSize=((1,3,3),(1,3,3)),stride=(1,1),padding=(0,1,1),
                             norm_args=norm_args,activation_args=activation_args)
            
        if self.mode_out=='3d' or self.mode_out=='both':
            self.CB3d = CB3d(in_channels=in_channels,out_channels=out_channels,kSize=(3,3),stride=(1,1),padding=(1,1,1),
                             norm_args=norm_args,activation_args=activation_args)
            
    def forward(self,x):
        if self.downsample:
            if self.mode_in=='both':
                x2d,x3d = x
                p2d = F.max_pool3d(x2d,kernel_size=(1,2,2),stride=(1,2,2))
                if x3.shape[2] >= self.min_z:
                    p3d=F.max_pool3d(x3d,kernel_size=(2,2,2),stride=(2,2,2))
                else:
                    p3d=F.max_pool3d(x3d,kernel_size=(1,2,2),stride=(1,2,2))
                    
                x = self.FMU(p2d,p3d)
                
            elif self.mode_in=='2d':
                x = F.max_pool3d(x,kernel_size=(1,2,2),stride=(1,2,2))
                
            elif self.mode_in=='3d':
                if x.shape[2] >= self.min_z:
                    x=F.max_pool3d(x,kernel_size=(2,2,2),stride=(2,2,2))
                else:
                    x=F.max_pool3d(x,kernel_size=(1,2,2),stride=(2,2,2))
                    
        if self.mode_out=='2d':
            return self.CB2d(x)
        elif self.mode_out=='3d':
            return self.CB3d(x)
        else: 
            return self.CB2d(x), self.CB3d(x)
                
class Upsample(BasicNet):
    def __init__(self, in_channels, out_channels, mode: tuple):
        super().__init__()
        self.mode_in,self.mode_out = mode 
        self.FMU = FMU_sub
        norm_args = (self.norm_kwargs,self.norm_kwargs)
        activation_args=(self.activation_kwargs,self.activation_kwargs)
        
        if self.mode_out=='2d' or self.mode_out=='both':
            self.CB2d=CB3d(in_channels=in_channels,out_channels=out_channels,
                           kSize=((1,3,3),(1,3,3)),stride=(1,1),padding=(0,1,1),
                           norm_args=norm_args,activation_args=activation_args)
            
        if self.mode_out=='3d' or self.mode_out=='both':
            self.CB3d=CB3d(in_channels=in_channels,out_channels=out_channels,
                           kSize=(3,3),stride=(1,1),padding=(1,1,1),
                           norm_args=norm_args,activation_args=activation_args)
    
    def forward(self,x):
        x2d, xskip2d, x3d, xskip3d = x
        
        tarSize=xskip2d.shape[2:]
        up2d=F.interpolate(x2d,size=tarSize,mode="trilinear",align_corners=False)
        up3d=F.interpolate(x3d,size=tarSize,mode="trilinear",align_corners=False)
        
        cat=torch.cat([FMU_sub(xskip2d,xskip3d,self.FMU),FMU_sub(up2d,up3d,self.FMU)],dim=1)
        
        if self.mode_out=='2d':
            return self.CB2d(cat)
        elif self.mode_out=='3d':
            return self.CB3d(cat)
        else:
            return self.CB2d(cat),self.CB3d(cat)
                
class MNet(BasicNet):
    def __init__(self, in_channels, num_classes, kn=(32, 48, 64, 80, 96), ds=True ):
        """
        Args:
            in_channels: channels of input
            num_classes: output classes
            kn: the number of kernels
            ds: deep supervision
            FMU: type of feature merging unit
        """
        super().__init__()
        self.ds = ds
        self.num_classes = num_classes

        channel_factor = {'sum': 1, 'sub': 1, 'cat': 2}
        fct = 1##= channel_factor[FMU]

        self.down11 = Downsample(in_channels, kn[0], ('/', 'both'), downsample=False)
        self.down12 = Downsample(kn[0], kn[1], ('2d', 'both'))
        self.down13 = Downsample(kn[1], kn[2], ('2d', 'both'))
        self.down14 = Downsample(kn[2], kn[3], ('2d', 'both'))
        self.bottleneck1 = Downsample(kn[3], kn[4], ('2d', '2d'))
        self.up11 = Upsample(fct * (kn[3] + kn[4]), kn[3], ('both', '2d'))
        self.up12 = Upsample(fct * (kn[2] + kn[3]), kn[2], ('both', '2d'))
        self.up13 = Upsample(fct * (kn[1] + kn[2]), kn[1], ('both', '2d'))
        self.up14 = Upsample(fct * (kn[0] + kn[1]), kn[0], ('both', 'both'))

        self.down21 = Downsample(kn[0], kn[1], ('3d', 'both'))
        self.down22 = Downsample(fct * kn[1], kn[2], ('both', 'both'))
        self.down23 = Downsample(fct * kn[2], kn[3], ('both', 'both'))
        self.bottleneck2 = Downsample(fct * kn[3], kn[4], ('both', 'both'))
        self.up21 = Upsample(fct * (kn[3] + kn[4]), kn[3], ('both', 'both'))
        self.up22 = Upsample(fct * (kn[2] + kn[3]), kn[2], ('both', 'both'))
        self.up23 = Upsample(fct * (kn[1] + kn[2]), kn[1], ('both', '3d'))

        self.down31 = Downsample(kn[1], kn[2], ('3d', 'both'))
        self.down32 = Downsample(fct * kn[2], kn[3], ('both', 'both'))
        self.bottleneck3 = Downsample(fct * kn[3], kn[4], ('both', 'both'))
        self.up31 = Upsample(fct * (kn[3] + kn[4]), kn[3], ('both', 'both'))
        self.up32 = Upsample(fct * (kn[2] + kn[3]), kn[2], ('both', '3d'))

        self.down41 = Downsample(kn[2], kn[3], ('3d', 'both'))
        self.bottleneck4 = Downsample(fct * kn[3], kn[4], ('both', 'both'))
        self.up41 = Upsample(fct * (kn[3] + kn[4]), kn[3], ('both', '3d'))

        self.bottleneck5 = Downsample(kn[3], kn[4], ('3d', '3d'))


        self.outputs = torch.nn.ModuleList(
            [torch.nn.Conv3d(c, num_classes, kernel_size=(1, 1, 1), stride=1, padding=0, bias=False)
             for c in [kn[0], kn[1], kn[1], kn[2], kn[2], kn[3], kn[3]]]
        )

    def forward(self, x):
        down11 = self.down11(x)
        down12 = self.down12(down11[0])
        down13 = self.down13(down12[0])
        down14 = self.down14(down13[0])
        bottleNeck1 = self.bottleneck1(down14[0])

        down21 = self.down21(down11[1])
        down22 = self.down22([down21[0], down12[1]])
        down23 = self.down23([down22[0], down13[1]])
        bottleNeck2 = self.bottleneck2([down23[0], down14[1]])

        down31 = self.down31(down21[1])
        down32 = self.down32([down31[0], down22[1]])
        bottleNeck3 = self.bottleneck3([down32[0], down23[1]])

        down41 = self.down41(down31[1])
        bottleNeck4 = self.bottleneck4([down41[0], down32[1]])

        bottleNeck5 = self.bottleneck5(down41[1])

        up41 = self.up41([bottleNeck4[0], down41[0], bottleNeck5, down41[1]])

        up31 = self.up31([bottleNeck3[0], down32[0], bottleNeck4[1], down32[1]])
        up32 = self.up32([up31[0], down31[0], up41, down31[1]])

        up21 = self.up21([bottleNeck2[0], down23[0], bottleNeck3[1], down23[1]])
        up22 = self.up22([up21[0], down22[0], up31[1], down22[1]])
        up23 = self.up23([up22[0], down21[0], up32, down21[1]])
        
        up11 = self.up11([bottleNeck1, down14[0], bottleNeck2[1], down14[1]])
        up12 = self.up12([up11, down13[0], up21[1], down13[1]])
        up13 = self.up13([up12, down12[0], up22[1], down12[1]])
        up14 = self.up14([up13, down11[0], up23, down11[1]])


        if self.ds:
            features = [up14[0]+up14[1], up23, up13, up32, up12, up41,up11]
            return [self.outputs[i](features[i]) for i in range(7)]
        else:
            return self.outputs[0](up14[0]+up14[1])



if __name__ == '__main__':
    MNet = MNet(1, 3, kn=(2, 2, 2, 2, 2), ds=True)
    input = torch.randn((1, 1, 19, 255,256))
    output = MNet(input)

    print([e.shape for e in output])

both
both
both
both
2d
both
both
both
both
both
both
both
both
both
3d


RuntimeError: Given groups=1, weight of size [2, 2, 1, 3, 3], expected input[1, 1, 19, 127, 128] to have 2 channels, but got 1 channels instead

In [None]:
!ls /datasets/brats20_1