In [None]:
import os 
# Expose only physical GPU 3 to CUDA (hard-coded index; adjust per machine).
# Assigned here, ahead of the cell that imports torch.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

In [None]:
import torch
from torch import nn

In [None]:
class zUnet(nn.Module):
    """U-Net-style 3D CNN that synthesises a depth axis from a single slice.

    The forward pass takes a tensor of shape ``(N, in_channels, H, W, 1)`` and
    returns ``(N, out_channels, H, W, depth)``.

    Shape constraints (they follow from the pooling / transposed-conv sizes):

    * The last input axis must have size 1: it is moved to the channel axis
      and fed to ``first``, a convolution with a single input channel.
    * ``H`` and ``W`` must be of the form ``8 * k + 1`` with ``k >= 1``
      (e.g. 9, 257, 513). Other sizes make the skip-connection concatenations
      fail with a size mismatch.
    * ``depth`` must be ``>= 6`` and satisfy ``depth % 4 == 2`` (e.g. 6, 50).

    Args:
        in_channels: Number of channels of the input (default 3).
        out_channels: Number of channels of the output (default 3).
        depth: Size of the generated last axis (default 50).

    Raises:
        ValueError: If ``depth`` cannot be restored by the decoder.

    With the default arguments the module layout, parameter names and
    parameter creation order are the same as for the argument-less original,
    so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels=3, out_channels=3, depth=50):
        super().__init__()
        # depth is halved twice (floor) by ``maxpool2`` and rebuilt by
        # upconv2 (d -> 2d + 1) and upconv1 (d -> 2d); the round trip only
        # returns to ``depth`` when depth = 4 * m + 2 with m >= 1.
        if depth < 6 or depth % 4 != 2:
            raise ValueError(
                f"depth must be >= 6 and satisfy depth % 4 == 2, got {depth}"
            )

        self.maxpool = nn.MaxPool3d(kernel_size=(2, 2, 1))   # halves H, W only
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))  # halves H, W and depth

        # Applied with the singleton axis as channel; its ``depth`` output
        # channels become the last axis once forward() swaps the axes back.
        self.first = nn.Conv3d(1, depth, kernel_size=(3, 3, 1), padding=(1, 1, 0))

        self.encoder1 = self._double_conv(in_channels, 16, 32)
        self.encoder2 = self._double_conv(32, 32, 64)
        self.encoder3 = self._double_conv(64, 64, 128)
        # The (2, 2, 1) kernel with padding 1 enlarges H and W by one; the
        # padding of ``upconv3`` below removes that surplus again.
        self.encoder4 = self._double_conv(128, 128, 256, first_kernel=(2, 2, 1))

        # (h + 1) -> 2 * h in H and W, depth unchanged.
        self.upconv3 = nn.ConvTranspose3d(
            256, 256, kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=(1, 1, 0)
        )
        self.decoder3 = self._double_conv(128 + 256, 128, 128)

        # h -> 2 * h in H and W; d -> 2 * d + 1 in depth (e.g. 12 -> 25).
        self.upconv2 = nn.ConvTranspose3d(128, 128, kernel_size=(2, 2, 3), stride=(2, 2, 2))
        self.decoder2 = self._double_conv(64 + 128, 64, 64)

        # h -> 2 * h + 1 in H and W (e.g. 256 -> 513); d -> 2 * d in depth.
        self.upconv1 = nn.ConvTranspose3d(64, 64, kernel_size=(3, 3, 2), stride=(2, 2, 2))
        self.decoder1 = self._double_conv(32 + 64, 32, 32)

        self.final = nn.Conv3d(32, out_channels, kernel_size=(3, 3, 1), padding=(1, 1, 0))

    @staticmethod
    def _double_conv(in_ch, mid_ch, out_ch, first_kernel=(3, 3, 1)):
        """Return conv -> ReLU -> conv -> ReLU; both convs pad H and W by 1."""
        return nn.Sequential(
            nn.Conv3d(in_ch, mid_ch, kernel_size=first_kernel, padding=(1, 1, 0)),
            nn.ReLU(),
            nn.Conv3d(mid_ch, out_ch, kernel_size=(3, 3, 1), padding=(1, 1, 0)),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map ``(N, in_channels, H, W, 1)`` to ``(N, out_channels, H, W, depth)``."""
        # Swap axes 1 and 4 (the same permutation as (0, 4, 2, 3, 1)) so that
        # ``first`` sees the singleton axis as its channel, then swap back.
        x = self.first(x.transpose(1, 4)).transpose(1, 4)

        # Contracting path; x1..x3 are kept for the skip connections.
        x1 = self.encoder1(x)
        x2 = self.encoder2(self.maxpool2(x1))
        x3 = self.encoder3(self.maxpool2(x2))
        x = self.encoder4(self.maxpool(x3))

        # Expanding path: upsample, concatenate the skip along channels, refine.
        x = self.decoder3(torch.cat([self.upconv3(x), x3], dim=1))
        x = self.decoder2(torch.cat([self.upconv2(x), x2], dim=1))
        x = self.decoder1(torch.cat([self.upconv1(x), x1], dim=1))

        return self.final(x)

In [None]:
# Random dummy batch for a shape check: sampled on the CPU with shape
# (1, 3, 513, 257, 1), then copied to the GPU.
b = torch.rand((1, 3, 513, 257, 1)).to(device='cuda')

In [None]:
# Instantiate the network with its default configuration, then move its
# parameters onto the GPU.
model = zUnet()
model = model.to(device='cuda')

In [None]:
# Smoke test: run one forward pass and display the output shape.
# Gradient tracking is switched off because only the shape is needed; with it
# on, autograd would keep the intermediate activations of this large volume
# alive for a backward pass that never happens.
with torch.no_grad():
    out_shape = model(b).shape
out_shape

torch.Size([1, 3, 513, 257, 50])