# UNETR

### Architecture의 주요 hyperparameters

- image size = (H:224, W:224, D:224)
- Layers = 12  
- Multi-head = 8
- Embedding length = 768  
- Patch size = 16  

![nn](./unetr_architecture.png)

기본적인 conv3d block을 선언해준다.  
위 그림에서 초록색, 파란색, 노란색, 회색 notation에 해당한다.

### Implementation Details:    

Batch size: 6  
Optimizer: AdamW  
Learning rate: 0.0001  
Iteration: 20000  
L=12, K=768, P=16×16×16  
Shift intensity  
Ensemble: Five-fold cross-validation  

In [1]:
    import copy
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    from einops import rearrange, reduce, repeat
    from einops.layers.torch import Rearrange, Reduce
    from torchsummary import summary
    from torch import Tensor

    class SingleDeconv3DBlock(nn.Module):
        """Learned 2x upsampling of a volumetric feature map.

        Wraps a single ``nn.ConvTranspose3d`` whose kernel size and stride are
        both 2, so each spatial extent of the input is doubled.

        Args:
            in_planes: number of channels of the incoming tensor.
            out_planes: number of channels of the produced tensor.
        """

        def __init__(self, in_planes, out_planes):
            super().__init__()
            # kernel == stride and no padding: output windows never overlap.
            self.block = nn.ConvTranspose3d(
                in_channels=in_planes,
                out_channels=out_planes,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            )

        def forward(self, x):
            """Return ``x`` upsampled by the transposed convolution."""
            upsampled = self.block(x)
            return upsampled


    class SingleConv3DBlock(nn.Module):
        """Stride-1 3-D convolution padded by ``(kernel_size - 1) // 2``.

        For odd kernel sizes this padding keeps the spatial extents unchanged.

        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            kernel_size: edge length of the cubic kernel.
        """

        def __init__(self, in_planes, out_planes, kernel_size):
            super().__init__()
            half_kernel = (kernel_size - 1) // 2
            self.block = nn.Conv3d(
                in_channels=in_planes,
                out_channels=out_planes,
                kernel_size=kernel_size,
                stride=1,
                padding=half_kernel,
            )

        def forward(self, x):
            """Return the convolved tensor."""
            convolved = self.block(x)
            return convolved


    class Conv3DBlock(nn.Module):
        """Convolution -> batch normalisation -> ReLU over a volumetric tensor.

        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            kernel_size: cubic kernel edge handed to ``SingleConv3DBlock``.
        """

        def __init__(self, in_planes, out_planes, kernel_size=3):
            super().__init__()
            stages = [
                SingleConv3DBlock(in_planes, out_planes, kernel_size),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True),
            ]
            self.block = nn.Sequential(*stages)

        def forward(self, x):
            """Run ``x`` through the three stages in order."""
            activated = self.block(x)
            return activated


    class Deconv3DBlock(nn.Module):
        """2x transposed-conv upsampling followed by conv -> batch norm -> ReLU.

        Args:
            in_planes: number of input channels.
            out_planes: number of channels after upsampling and refinement.
            kernel_size: cubic kernel edge of the refining convolution.
        """

        def __init__(self, in_planes, out_planes, kernel_size=3):
            super().__init__()
            stages = [
                SingleDeconv3DBlock(in_planes, out_planes),
                SingleConv3DBlock(out_planes, out_planes, kernel_size),
                nn.BatchNorm3d(out_planes),
                nn.ReLU(inplace=True),
            ]
            self.block = nn.Sequential(*stages)

        def forward(self, x):
            """Return the upsampled and refined feature map."""
            refined = self.block(x)
            return refined

In [2]:
# Build a 1 -> 64 channel conv block with kernel size 3 and print its
# per-layer summary for a (1, 224, 224, 224) input on the CPU.
CB = Conv3DBlock(1, 64, 3)
summary(CB, (1,224,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 64, 224, 224, 224]           1,792
 SingleConv3DBlock-2    [-1, 64, 224, 224, 224]               0
       BatchNorm3d-3    [-1, 64, 224, 224, 224]             128
              ReLU-4    [-1, 64, 224, 224, 224]               0
Total params: 1,920
Trainable params: 1,920
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 42.88
Forward/backward pass size (MB): 21952.00
Params size (MB): 0.01
Estimated Total Size (MB): 21994.88
----------------------------------------------------------------


## 1. Embedding  

1. 인풋 이미지들을 패치로 나눠주고 
2. position embedding과 더해준다.

In [3]:
class Embeddings(nn.Module):
    """Split a volume into cubic patches, project them, add position embeddings.

    The projection is a ``Conv3d`` whose kernel size and stride both equal
    ``patch_size``, so every output voxel is the embedding of one patch.

    Args:
        input_shape: ``(..., channels, H, W, D)``; only the last four entries
            are read, so a ``torch.Size`` without the batch axis works.
        patch_size: edge length of each cubic patch.
        embed_dim: length of each patch embedding vector.
        dropout: dropout probability applied after adding position embeddings.
    """

    def __init__(self, input_shape, patch_size=16, embed_dim=768, dropout=0.):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = input_shape[-4]
        # The strided convolution floors every spatial axis independently, so
        # the token count is the product of the per-axis patch counts.
        # Dividing the total volume by patch_size**3 over-counts for shapes
        # that are not multiples of the patch size, and the oversized position
        # table would then silently broadcast against the real tokens.
        patches_per_axis = [int(extent) // patch_size for extent in input_shape[-3:]]
        self.n_patches = patches_per_axis[0] * patches_per_axis[1] * patches_per_axis[2]
        self.embed_dim = embed_dim
        self.patch_embeddings = nn.Conv3d(in_channels=self.in_channels, out_channels=self.embed_dim,
                                          kernel_size=self.patch_size, stride=self.patch_size)
        # Learnable position table, initialised to zeros.
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, self.embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return patch tokens of shape ``(batch, n_patches, embed_dim)``."""
        # batch, embed_dim, height/patch, width/patch, depth/patch
        x = self.patch_embeddings(x)
        # Flatten the patch grid (h, w, d order) into one token axis and move
        # the embedding axis last: "b n h w d -> b (h w d) n".
        x = x.flatten(2).transpose(1, 2)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

In [4]:
## input ##
# Example input: a batch of one single-channel 224 x 224 x 224 volume filled
# with uniform random values.
shape = (1,1,224,224,224)
x = torch.rand(shape)
# NOTE: the cells visible here never pass this variable on; Embeddings is
# later constructed with its own default patch size.
patch_size = 16

In [5]:
# Build the embedding layer from the input shape without its batch axis and
# print the layer summary on the CPU.
E = Embeddings(x.shape[1:])
summary(E, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 768, 14, 14, 14]       3,146,496
           Dropout-2            [-1, 2744, 768]               0
Total params: 3,146,496
Trainable params: 3,146,496
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 42.88
Forward/backward pass size (MB): 32.16
Params size (MB): 12.00
Estimated Total Size (MB): 87.03
----------------------------------------------------------------


In [6]:
# Embed the example volume and show the resulting token tensor shape.
embedding_x = E(x)
print(embedding_x.shape)

torch.Size([1, 2744, 768])


## 2. Multi Head Attention Block (MHA)

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with one fused query/key/value projection.

    Args:
        emb_size: length of every input and output token vector; must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the token axis of ``x``.

        Args:
            x: tensor of shape ``(batch, tokens, emb_size)``.
            mask: optional boolean tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``; positions that are
                ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, tokens, _ = x.shape
        # The fused feature axis is laid out as (head, head_dim, qkv) with qkv
        # innermost; split it and move qkv to the front:
        # "b n (h d qkv) -> (qkv) b h n d".
        fused = self.qkv(x).reshape(batch, tokens, self.num_heads, -1, 3)
        queries, keys, values = fused.permute(4, 0, 2, 1, 3).unbind(0)
        # batch, num_heads, query_len, key_len
        energy = queries @ keys.transpose(-2, -1)
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept; take the
            # fill value from energy's own dtype so reduced precision works.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): this scales by sqrt(emb_size) rather than the per-head
        # sqrt(emb_size / num_heads); kept as-is so existing weights behave
        # identically - confirm whether this is intended.
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values for every query position
        out = att @ values
        # merge the heads back: "b h n d -> b n (h d)"
        out = out.transpose(1, 2).reshape(batch, tokens, self.emb_size)
        out = self.projection(out)
        return out

In [8]:
# Instantiate the attention block with its defaults and print its layer
# summary for the embedded token shape on the CPU.
MHA = MultiHeadAttention()
summary(MHA, embedding_x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 2744, 2304]       1,771,776
           Dropout-2        [-1, 8, 2744, 2744]               0
            Linear-3            [-1, 2744, 768]         590,592
Total params: 2,362,368
Trainable params: 2,362,368
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.04
Forward/backward pass size (MB): 523.88
Params size (MB): 9.01
Estimated Total Size (MB): 540.93
----------------------------------------------------------------


In [9]:
# Apply attention to the tokens; this rebinds embedding_x to the attention
# output, which the following cells then consume.
embedding_x = MHA(embedding_x)
print(embedding_x.shape)

torch.Size([1, 2744, 768])


## 3. Feed Forward Block (FF)

In [10]:
class FeedForwardBlock(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: length of the input and output token vectors.
        expansion: the hidden layer has ``expansion * emb_size`` units.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int = 768, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        widen = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        regularise = nn.Dropout(drop_p)
        narrow = nn.Linear(hidden_size, emb_size)
        # Positional registration keeps the sub-module indices 0..3.
        super().__init__(widen, activation, regularise, narrow)

In [11]:
# Instantiate the feed-forward block with its defaults and print its layer
# summary for the token shape on the CPU.
FF = FeedForwardBlock()
summary(FF, embedding_x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 2744, 3072]       2,362,368
              GELU-2           [-1, 2744, 3072]               0
           Dropout-3           [-1, 2744, 3072]               0
            Linear-4            [-1, 2744, 768]       2,360,064
Total params: 4,722,432
Trainable params: 4,722,432
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.04
Forward/backward pass size (MB): 209.02
Params size (MB): 18.01
Estimated Total Size (MB): 235.07
----------------------------------------------------------------


In [12]:
# Apply the feed-forward block; embedding_x is rebound to its output.
embedding_x = FF(embedding_x)
print(embedding_x.shape)

torch.Size([1, 2744, 768])


## Transformer Block  

MHA 블럭과 FF 블럭을 하나로 묶어준다.

한편, 3, 6, 9, 12번째 transformer block의 feature를 뽑아줘야 하기 때문에, 해당 layer들의 feature map을 리스트 형태로 반환한다.


In [13]:
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input before handing it to a wrapped callable.

    Args:
        dim: size of the trailing axis that ``LayerNorm`` normalises.
        fn: callable invoked on the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; keyword arguments pass through."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
    
class TransformerBlock(nn.Module):
    """Stack of pre-norm transformer layers that returns selected layer outputs.

    Every layer is an attention sub-block followed by a feed-forward
    sub-block, each wrapped in ``PreNorm`` and added back to its input.

    Args:
        embed_dim: length of every token vector.
        num_heads: attention heads per layer.
        depth: number of stacked layers.
        dropout: dropout probability handed to the attention sub-blocks.
        extract_layers: 1-based layer indices whose outputs are returned.
    """

    def __init__(self, embed_dim=768, num_heads=8, depth=12, dropout=0., extract_layers=(3, 6, 9, 12)):
        super().__init__()
        # NOTE(review): dropout reaches only the attention sub-block; the
        # feed-forward block keeps its own default - confirm this is intended.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(embed_dim, MultiHeadAttention(embed_dim, num_heads, dropout)),
                PreNorm(embed_dim, FeedForwardBlock(embed_dim, expansion=4))
            ])
            for _ in range(depth)
        ])
        # Immutable default plus a private copy: instances never share (or
        # accidentally mutate) one list object.
        self.extract_layers = list(extract_layers)

    def forward(self, x):
        """Return the outputs of the layers named in ``extract_layers``.

        The returned list is ordered by layer depth and holds one tensor per
        requested layer that exists in the stack.
        """
        extracted = []

        for layer_number, (attn, ff) in enumerate(self.layers, start=1):
            x = attn(x) + x
            x = ff(x) + x
            if layer_number in self.extract_layers:
                extracted.append(x)

        return extracted

In [14]:
# Instantiate the transformer stack with its defaults and print its layer
# summary for the token shape on the CPU.
TB = TransformerBlock()
summary(TB, embedding_x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1            [-1, 2744, 768]           1,536
            Linear-2           [-1, 2744, 2304]       1,771,776
           Dropout-3        [-1, 8, 2744, 2744]               0
            Linear-4            [-1, 2744, 768]         590,592
MultiHeadAttention-5            [-1, 2744, 768]               0
           PreNorm-6            [-1, 2744, 768]               0
         LayerNorm-7            [-1, 2744, 768]           1,536
            Linear-8           [-1, 2744, 3072]       2,362,368
              GELU-9           [-1, 2744, 3072]               0
          Dropout-10           [-1, 2744, 3072]               0
           Linear-11            [-1, 2744, 768]       2,360,064
          PreNorm-12            [-1, 2744, 768]               0
        LayerNorm-13            [-1, 2744, 768]           1,536
           Linear-14           [-1, 274

In [15]:
# Run the transformer stack and print the shape of every extracted feature map.
embedding_x_list = TB(embedding_x)
for i in embedding_x_list:
    print(i.shape)

torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])
torch.Size([1, 2744, 768])


## UNETR

위에서 선언해준 block들을 통해 UNETR을 빌드한다.  
옵션에 `light_r`를 추가해주었는데(conv 채널 수를 `light_r`로 나눈다), 논문 그대로 설정하면 메모리 문제(OOM)로 빌드가 안 된다.

In [16]:
class UNETR(nn.Module):
    """Transformer encoder with a U-Net style convolutional decoder for volumes.

    The input is embedded patch-wise, run through a 12-layer transformer, and
    the outputs of layers 3, 6, 9 and 12 are reshaped back onto the patch grid
    and progressively upsampled and merged with skip connections.

    Args:
        img_shape: spatial size ``(H, W, D)`` of the input; every extent must
            be a multiple of ``patch_size``.
        input_dim: number of input channels.
        output_dim: number of output channels.
        embed_dim: token length used by the embedding and the transformer.
        patch_size: edge length of the cubic patches. The decoder performs
            four 2x upsamplings, so only ``patch_size=16`` brings the features
            back to the input resolution.
        num_heads: attention heads per transformer layer.
        dropout: dropout probability forwarded to the embedding and to the
            attention weights. Earlier revisions stored this value without
            using it.
        light_r: divisor applied to the decoder channel widths to save memory.

    Raises:
        ValueError: if an extent of ``img_shape`` is not a multiple of
            ``patch_size``.
    """

    def __init__(self, img_shape=(224, 224, 224), input_dim=3, output_dim=3,
                 embed_dim=768, patch_size=16, num_heads=8, dropout=0.1, light_r=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]

        # A partial patch would make the token grid disagree with the skip
        # connections; fail early with a clear message.
        if any(int(extent) % patch_size for extent in img_shape):
            raise ValueError(
                f"img_shape {tuple(img_shape)} must be divisible by patch_size {patch_size}"
            )
        self.patch_dim = [int(extent) // patch_size for extent in img_shape]
        self.conv_channels = [int(i/light_r) for i in [32, 64, 128, 256, 512, 1024]]

        # Forward the configured sizes instead of relying on the sub-module
        # defaults, so non-default arguments actually take effect.
        self.embedding = Embeddings((input_dim, *img_shape), patch_size=patch_size,
                                    embed_dim=embed_dim, dropout=dropout)

        # Transformer Encoder
        self.transformer = \
            TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                depth=self.num_layers,
                dropout=dropout,
                extract_layers=self.ext_layers,
            )

        # U-Net Decoder
        # Full-resolution skip path computed directly from the raw input.
        self.decoder0 = \
            nn.Sequential(
                Conv3DBlock(input_dim, self.conv_channels[0], 3),
                Conv3DBlock(self.conv_channels[0], self.conv_channels[1], 3)
            )

        # Layer-3 features: three 2x upsamplings of the patch grid.
        self.decoder3 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        # Layer-6 features: two 2x upsamplings.
        self.decoder6 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[3]),
                Deconv3DBlock(self.conv_channels[3], self.conv_channels[3]),
            )

        # Layer-9 features: one 2x upsampling.
        self.decoder9 = \
            Deconv3DBlock(embed_dim, self.conv_channels[4])

        # Layer-12 features: one plain transposed convolution.
        self.decoder12_upsampler = \
            SingleDeconv3DBlock(embed_dim, self.conv_channels[4])

        # Each *_upsampler consumes two concatenated branches, hence an input
        # width that is twice the width of a single branch.
        self.decoder9_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[5], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                SingleDeconv3DBlock(self.conv_channels[3], self.conv_channels[3])
            )

        self.decoder6_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[4], self.conv_channels[2]),
                Conv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                SingleDeconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        self.decoder3_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[3], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleDeconv3DBlock(self.conv_channels[1], self.conv_channels[1])
            )

        # Final head: two conv blocks, then a 1x1x1 convolution to output_dim.
        self.decoder0_header = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[2], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleConv3DBlock(self.conv_channels[1], output_dim, 1)
            )

    def forward(self, x):
        """Map ``(batch, input_dim, H, W, D)`` to ``(batch, output_dim, H, W, D)``."""
        z0 = x
        tokens = self.embedding(x)
        # Put the token axis back onto the 3-D patch grid for every extracted
        # layer: (b, tokens, embed) -> (b, embed, h', w', d').
        z3, z6, z9, z12 = (
            feature.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
            for feature in self.transformer(tokens)
        )

        # Decode from the deepest features upwards; each stage concatenates
        # the current skip branch with the previous stage along channels.
        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(torch.cat([z0, z3], dim=1))
        return output

In [17]:
from torchsummary import summary

# Alternative input sizes are left commented out; the 128^3 volume is the one
# actually used for the summary below.
# x = torch.rand(1,1,224,224,224)
x = torch.rand(1,1,128,128,128)
# x = torch.rand(1,1,64,64,64)
# Derive the spatial size and channel count from x, request 4 output channels.
model = UNETR(img_shape=x.shape[2:], input_dim=x.shape[1], output_dim=4, 
              embed_dim=768, patch_size=16, num_heads=8, dropout=0., light_r=4)
summary(model, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1         [-1, 768, 8, 8, 8]       3,146,496
           Dropout-2             [-1, 512, 768]               0
        Embeddings-3             [-1, 512, 768]               0
         LayerNorm-4             [-1, 512, 768]           1,536
            Linear-5            [-1, 512, 2304]       1,771,776
           Dropout-6          [-1, 8, 512, 512]               0
            Linear-7             [-1, 512, 768]         590,592
MultiHeadAttention-8             [-1, 512, 768]               0
           PreNorm-9             [-1, 512, 768]               0
        LayerNorm-10             [-1, 512, 768]           1,536
           Linear-11            [-1, 512, 3072]       2,362,368
             GELU-12            [-1, 512, 3072]               0
          Dropout-13            [-1, 512, 3072]               0
           Linear-14             [-1, 5

# FULL code

In [18]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from torch import Tensor

class SingleDeconv3DBlock(nn.Module):
    """Learned 2x upsampling of a volumetric feature map.

    Wraps a single ``nn.ConvTranspose3d`` whose kernel size and stride are
    both 2, so each spatial extent of the input is doubled.

    Args:
        in_planes: number of channels of the incoming tensor.
        out_planes: number of channels of the produced tensor.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        # kernel == stride and no padding: output windows never overlap.
        self.block = nn.ConvTranspose3d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=2,
            stride=2,
            padding=0,
            output_padding=0,
        )

    def forward(self, x):
        """Return ``x`` upsampled by the transposed convolution."""
        upsampled = self.block(x)
        return upsampled


class SingleConv3DBlock(nn.Module):
    """Stride-1 3-D convolution padded by ``(kernel_size - 1) // 2``.

    For odd kernel sizes this padding keeps the spatial extents unchanged.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: edge length of the cubic kernel.
    """

    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        half_kernel = (kernel_size - 1) // 2
        self.block = nn.Conv3d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=1,
            padding=half_kernel,
        )

    def forward(self, x):
        """Return the convolved tensor."""
        convolved = self.block(x)
        return convolved


class Conv3DBlock(nn.Module):
    """Convolution -> batch normalisation -> ReLU over a volumetric tensor.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: cubic kernel edge handed to ``SingleConv3DBlock``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        stages = [
            SingleConv3DBlock(in_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the three stages in order."""
        activated = self.block(x)
        return activated
    
    
class Deconv3DBlock(nn.Module):
    """2x transposed-conv upsampling followed by conv -> batch norm -> ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels after upsampling and refinement.
        kernel_size: cubic kernel edge of the refining convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        stages = [
            SingleDeconv3DBlock(in_planes, out_planes),
            SingleConv3DBlock(out_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled and refined feature map."""
        refined = self.block(x)
        return refined

class Embeddings(nn.Module):
    """Split a volume into cubic patches, project them, add position embeddings.

    The projection is a ``Conv3d`` whose kernel size and stride both equal
    ``patch_size``, so every output voxel is the embedding of one patch.

    Args:
        input_shape: ``(..., channels, H, W, D)``; only the last four entries
            are read, so a ``torch.Size`` without the batch axis works.
        patch_size: edge length of each cubic patch.
        embed_dim: length of each patch embedding vector.
        dropout: dropout probability applied after adding position embeddings.
    """

    def __init__(self, input_shape, patch_size=16, embed_dim=768, dropout=0.):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = input_shape[-4]
        # The strided convolution floors every spatial axis independently, so
        # the token count is the product of the per-axis patch counts.
        # Dividing the total volume by patch_size**3 over-counts for shapes
        # that are not multiples of the patch size, and the oversized position
        # table would then silently broadcast against the real tokens.
        patches_per_axis = [int(extent) // patch_size for extent in input_shape[-3:]]
        self.n_patches = patches_per_axis[0] * patches_per_axis[1] * patches_per_axis[2]
        self.embed_dim = embed_dim
        self.patch_embeddings = nn.Conv3d(in_channels=self.in_channels, out_channels=self.embed_dim,
                                          kernel_size=self.patch_size, stride=self.patch_size)
        # Learnable position table, initialised to zeros.
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, self.embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return patch tokens of shape ``(batch, n_patches, embed_dim)``."""
        # batch, embed_dim, height/patch, width/patch, depth/patch
        x = self.patch_embeddings(x)
        # Flatten the patch grid (h, w, d order) into one token axis and move
        # the embedding axis last: "b n h w d -> b (h w d) n".
        x = x.flatten(2).transpose(1, 2)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with one fused query/key/value projection.

    Args:
        emb_size: length of every input and output token vector; must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the token axis of ``x``.

        Args:
            x: tensor of shape ``(batch, tokens, emb_size)``.
            mask: optional boolean tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``; positions that are
                ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, tokens, _ = x.shape
        # The fused feature axis is laid out as (head, head_dim, qkv) with qkv
        # innermost; split it and move qkv to the front:
        # "b n (h d qkv) -> (qkv) b h n d".
        fused = self.qkv(x).reshape(batch, tokens, self.num_heads, -1, 3)
        queries, keys, values = fused.permute(4, 0, 2, 1, 3).unbind(0)
        # batch, num_heads, query_len, key_len
        energy = queries @ keys.transpose(-2, -1)
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept; take the
            # fill value from energy's own dtype so reduced precision works.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): this scales by sqrt(emb_size) rather than the per-head
        # sqrt(emb_size / num_heads); kept as-is so existing weights behave
        # identically - confirm whether this is intended.
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values for every query position
        out = att @ values
        # merge the heads back: "b h n d -> b n (h d)"
        out = out.transpose(1, 2).reshape(batch, tokens, self.emb_size)
        out = self.projection(out)
        return out

class FeedForwardBlock(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: length of the input and output token vectors.
        expansion: the hidden layer has ``expansion * emb_size`` units.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int = 768, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        widen = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        regularise = nn.Dropout(drop_p)
        narrow = nn.Linear(hidden_size, emb_size)
        # Positional registration keeps the sub-module indices 0..3.
        super().__init__(widen, activation, regularise, narrow)

class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input before handing it to a wrapped callable.

    Args:
        dim: size of the trailing axis that ``LayerNorm`` normalises.
        fn: callable invoked on the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; keyword arguments pass through."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
    
class TransformerBlock(nn.Module):
    """Stack of pre-norm transformer layers that returns selected layer outputs.

    Every layer is an attention sub-block followed by a feed-forward
    sub-block, each wrapped in ``PreNorm`` and added back to its input.

    Args:
        embed_dim: length of every token vector.
        num_heads: attention heads per layer.
        depth: number of stacked layers.
        dropout: dropout probability handed to the attention sub-blocks.
        extract_layers: 1-based layer indices whose outputs are returned.
    """

    def __init__(self, embed_dim=768, num_heads=8, depth=12, dropout=0., extract_layers=(3, 6, 9, 12)):
        super().__init__()
        # NOTE(review): dropout reaches only the attention sub-block; the
        # feed-forward block keeps its own default - confirm this is intended.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(embed_dim, MultiHeadAttention(embed_dim, num_heads, dropout)),
                PreNorm(embed_dim, FeedForwardBlock(embed_dim, expansion=4))
            ])
            for _ in range(depth)
        ])
        # Immutable default plus a private copy: instances never share (or
        # accidentally mutate) one list object.
        self.extract_layers = list(extract_layers)

    def forward(self, x):
        """Return the outputs of the layers named in ``extract_layers``.

        The returned list is ordered by layer depth and holds one tensor per
        requested layer that exists in the stack.
        """
        extracted = []

        for layer_number, (attn, ff) in enumerate(self.layers, start=1):
            x = attn(x) + x
            x = ff(x) + x
            if layer_number in self.extract_layers:
                extracted.append(x)

        return extracted

class UNETR(nn.Module):
    """Transformer encoder with a U-Net style convolutional decoder for volumes.

    The input is embedded patch-wise, run through a 12-layer transformer, and
    the outputs of layers 3, 6, 9 and 12 are reshaped back onto the patch grid
    and progressively upsampled and merged with skip connections.

    Args:
        img_shape: spatial size ``(H, W, D)`` of the input; every extent must
            be a multiple of ``patch_size``.
        input_dim: number of input channels.
        output_dim: number of output channels.
        embed_dim: token length used by the embedding and the transformer.
        patch_size: edge length of the cubic patches. The decoder performs
            four 2x upsamplings, so only ``patch_size=16`` brings the features
            back to the input resolution.
        num_heads: attention heads per transformer layer.
        dropout: dropout probability forwarded to the embedding and to the
            attention weights. Earlier revisions stored this value without
            using it.
        light_r: divisor applied to the decoder channel widths to save memory.

    Raises:
        ValueError: if an extent of ``img_shape`` is not a multiple of
            ``patch_size``.
    """

    def __init__(self, img_shape=(224, 224, 224), input_dim=3, output_dim=3,
                 embed_dim=768, patch_size=16, num_heads=8, dropout=0.1, light_r=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]

        # A partial patch would make the token grid disagree with the skip
        # connections; fail early with a clear message.
        if any(int(extent) % patch_size for extent in img_shape):
            raise ValueError(
                f"img_shape {tuple(img_shape)} must be divisible by patch_size {patch_size}"
            )
        self.patch_dim = [int(extent) // patch_size for extent in img_shape]
        self.conv_channels = [int(i/light_r) for i in [32, 64, 128, 256, 512, 1024]]

        # Forward the configured sizes instead of relying on the sub-module
        # defaults, so non-default arguments actually take effect.
        self.embedding = Embeddings((input_dim, *img_shape), patch_size=patch_size,
                                    embed_dim=embed_dim, dropout=dropout)

        # Transformer Encoder
        self.transformer = \
            TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                depth=self.num_layers,
                dropout=dropout,
                extract_layers=self.ext_layers,
            )

        # U-Net Decoder
        # Full-resolution skip path computed directly from the raw input.
        self.decoder0 = \
            nn.Sequential(
                Conv3DBlock(input_dim, self.conv_channels[0], 3),
                Conv3DBlock(self.conv_channels[0], self.conv_channels[1], 3)
            )

        # Layer-3 features: three 2x upsamplings of the patch grid.
        self.decoder3 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                Deconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        # Layer-6 features: two 2x upsamplings.
        self.decoder6 = \
            nn.Sequential(
                Deconv3DBlock(embed_dim, self.conv_channels[3]),
                Deconv3DBlock(self.conv_channels[3], self.conv_channels[3]),
            )

        # Layer-9 features: one 2x upsampling.
        self.decoder9 = \
            Deconv3DBlock(embed_dim, self.conv_channels[4])

        # Layer-12 features: one plain transposed convolution.
        self.decoder12_upsampler = \
            SingleDeconv3DBlock(embed_dim, self.conv_channels[4])

        # Each *_upsampler consumes two concatenated branches, hence an input
        # width that is twice the width of a single branch.
        self.decoder9_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[5], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                Conv3DBlock(self.conv_channels[3], self.conv_channels[3]),
                SingleDeconv3DBlock(self.conv_channels[3], self.conv_channels[3])
            )

        self.decoder6_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[4], self.conv_channels[2]),
                Conv3DBlock(self.conv_channels[2], self.conv_channels[2]),
                SingleDeconv3DBlock(self.conv_channels[2], self.conv_channels[2])
            )

        self.decoder3_upsampler = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[3], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleDeconv3DBlock(self.conv_channels[1], self.conv_channels[1])
            )

        # Final head: two conv blocks, then a 1x1x1 convolution to output_dim.
        self.decoder0_header = \
            nn.Sequential(
                Conv3DBlock(self.conv_channels[2], self.conv_channels[1]),
                Conv3DBlock(self.conv_channels[1], self.conv_channels[1]),
                SingleConv3DBlock(self.conv_channels[1], output_dim, 1)
            )

    def forward(self, x):
        """Map ``(batch, input_dim, H, W, D)`` to ``(batch, output_dim, H, W, D)``."""
        z0 = x
        tokens = self.embedding(x)
        # Put the token axis back onto the 3-D patch grid for every extracted
        # layer: (b, tokens, embed) -> (b, embed, h', w', d').
        z3, z6, z9, z12 = (
            feature.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)
            for feature in self.transformer(tokens)
        )

        # Decode from the deepest features upwards; each stage concatenates
        # the current skip branch with the previous stage along channels.
        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(torch.cat([z0, z3], dim=1))
        return output