In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed3D(nn.Module):
    """Cut a 5D volume into non-overlapping 3D patches and embed each patch.

    A Conv3d whose kernel and stride both equal ``patch_size`` maps every
    (pD, pH, pW) patch to one ``embed_dim``-channel token. If D/H/W are not
    multiples of the patch size, the input is zero-padded at the end of each
    axis first, so the output grid is ceil(D/pD) x ceil(H/pH) x ceil(W/pW).

    Args:
        patch_size: (pD, pH, pW) patch extent along depth, height, width.
        in_chans: number of input channels.
        embed_dim: number of output channels per token.
        norm_layer: optional constructor called as ``norm_layer(embed_dim)``;
            the resulting module is applied to the (B, N, embed_dim) token
            sequence after projection.
    """

    def __init__(self,
                 patch_size = (2, 4, 4),
                 in_chans = 3,
                 embed_dim = 96,
                 norm_layer = None):
        super().__init__()

        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape (B, in_chans, D, H, W).

        Returns a tensor of shape (B, embed_dim, ceil(D/pD), ceil(H/pH), ceil(W/pW)).
        """
        depth, height, width = x.shape[2:]
        p_d, p_h, p_w = self.patch_size

        # Amount needed to round each axis up to a multiple of its patch size.
        pad_d = -depth % p_d
        pad_h = -height % p_h
        pad_w = -width % p_w
        if pad_d or pad_h or pad_w:
            # F.pad lists (low, high) pairs starting from the LAST dim:
            # (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi); only trailing pads are used.
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

        x = self.proj(x)
        if self.norm is None:
            return x

        d_out, h_out, w_out = x.shape[2:]
        tokens = self.norm(x.flatten(2).transpose(1, 2))  # (B, N, embed_dim)
        return tokens.transpose(1, 2).view(-1, self.embed_dim, d_out, h_out, w_out)

In [2]:
# Fix the global RNG so the random input below (and the Conv3d/LayerNorm
# weights created when the model is built in the next cell) are identical on
# every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dummy input: 2 single-channel volumes of 96 x 384 x 384.
sample = torch.randn(2, 1, 96, 384, 384)

In [3]:
# 1-channel 96x384x384 input, 16x64x64 patches -> a 6x6x6 token grid,
# each token projected to 96 channels and LayerNorm-ed.
model = PatchEmbed3D(
    in_chans=1,
    embed_dim=96,
    patch_size=(16, 64, 64),
    norm_layer=nn.LayerNorm,
)

In [4]:
pred = model(sample)
# Shape is (B, embed_dim, ceil(D / pD), ceil(H / pH), ceil(W / pW)):
# PatchEmbed3D pads each axis up to a multiple of its own patch size.
expected_shape = (
    sample.shape[0],
    model.embed_dim,
    *(-(-size // patch) for size, patch in zip(sample.shape[2:], model.patch_size)),
)
assert pred.shape == expected_shape, (pred.shape, expected_shape)
print(pred.shape)

torch.Size([2, 96, 6, 6, 6])


In [7]:
drop_path_rate = 0.2       # final (largest) drop-path rate in the schedule
depths = [2, 2, 6, 2]      # presumably number of blocks per stage
print(sum(depths))
# sum(depths) rates, linearly spaced from 0 to drop_path_rate inclusive.
dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
print(dpr)

12
[0.0, 0.0181818176060915, 0.036363635212183, 0.05454545468091965, 0.072727270424366, 0.09090908616781235, 0.10909091681241989, 0.12727272510528564, 0.1454545557498932, 0.16363637149333954, 0.1818181872367859, 0.20000000298023224]


In [8]:
# Unpack the channels-first token grid produced by the patch embedding.
B, C, D, H, W = pred.size()

In [9]:
def get_window_size(x_size, window_size, shift_size = None):
    """Clamp a per-axis window (and optional shift) to the input extents.

    On every axis whose extent is <= the requested window, the window is
    replaced by that extent and the shift (if given) is set to 0, because one
    window already covers the whole axis.

    Args:
        x_size: per-axis input extents, e.g. (D, H, W).
        window_size: requested per-axis window size.
        shift_size: optional per-axis shift.

    Returns:
        The clamped window as a tuple, or a (window, shift) pair of tuples
        when ``shift_size`` is given.
    """
    win = list(window_size)
    shift = None if shift_size is None else list(shift_size)

    for axis, extent in enumerate(x_size):
        if extent > window_size[axis]:
            continue
        win[axis] = extent
        if shift is not None:
            shift[axis] = 0

    if shift is None:
        return tuple(win)
    return tuple(win), tuple(shift)

In [11]:
# Requested (depth, height, width) window; the shift is half of it per axis
# (floor division).
window_size = (1, 7, 7)
shift_size = tuple(map(lambda w: w // 2, window_size))
print(window_size)
print(shift_size)

(1, 7, 7)
(0, 3, 3)


In [13]:
# Clip the requested window/shift to the (D, H, W) token grid: any axis no
# larger than its window gets window == extent and a zero shift.
window_size, shift_size = get_window_size(x_size=(D, H, W),
                                          window_size=window_size,
                                          shift_size=shift_size)
print(window_size)
print(shift_size)

(1, 6, 6)
(0, 0, 0)


In [18]:
from functools import reduce, lru_cache
from operator import mul

def window_partition(x, window_size):
    """Cut a channels-last (B, D, H, W, C) tensor into non-overlapping windows.

    D, H and W must be divisible by window_size[0], [1], [2] respectively,
    otherwise the reshaping ``view`` fails.

    Returns:
        Tensor of shape (B * nW, prod(window_size), C). Windows are ordered by
        batch, then depth-, height- and width-window index; tokens inside a
        window are in (d, h, w) row-major order.
    """
    batch, depth, height, width, channels = x.shape
    win_d, win_h, win_w = window_size[0], window_size[1], window_size[2]

    grid = x.view(batch,
                  depth // win_d, win_d,
                  height // win_h, win_h,
                  width // win_w, win_w,
                  channels)
    # Window-index axes first (after batch), then intra-window axes, so each
    # window's tokens become one contiguous run.
    grid = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    tokens_per_window = reduce(mul, window_size)
    return grid.view(-1, tokens_per_window, channels)

def compute_mask(D, H, W, window_size, shift_size, device):
    """Build a per-window mask separating regions created by a cyclic shift.

    Labels every cell of a (1, D, H, W, 1) grid with a region id, where regions
    are the blocks delimited by the window/shift boundaries on each axis. The
    grid is then cut into windows with ``window_partition``. For every pair of
    positions inside a window, the mask marks whether their ids differ.

    Args:
        D, H, W: grid extents. Each must be a multiple of the matching
            ``window_size`` entry, otherwise the ``view`` inside
            ``window_partition`` fails.
        window_size: per-axis window size (3 entries).
        shift_size: per-axis shift (3 entries).
        device: device on which the mask tensor is allocated.

    Returns:
        Float tensor of shape (nW, N, N), where N = prod(window_size) and
        nW = (D/ws0) * (H/ws1) * (W/ws2). Entries are 0.0 where both positions
        share a region id and -100.0 otherwise.
        NOTE(review): presumably added to attention logits so that
        cross-region pairs vanish after softmax; confirm at the call site.
    """
    img_mask = torch.zeros((1, D, H, W, 1), device=device)
    cnt = 0
    # Each axis is split into three slices: [:-ws], [-ws:-shift], [-shift:].
    # Assignment order matters. When shift == 0 on an axis, slice(-ws, -0) is
    # empty and slice(-0, None) spans the whole axis, so later iterations
    # overwrite earlier ones along that axis and it adds no region boundary.
    # With all shifts 0, every cell gets the same id and the mask is all zeros.
    for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):
        for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):
            for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):
                img_mask[:, d, h, w, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
    mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
    # Pairwise id difference, shape (nW, N, N); non-zero means different regions.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # Both conditions are evaluated on the unfilled difference tensor.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask

In [16]:
from einops import rearrange
import numpy as np

# Switch to channels-last (B, D, H, W, C), the layout window_partition unpacks.
x = rearrange(pred, 'b c d h w -> b d h w c')
print(x.shape)

# Round each axis up to the next multiple of its window size
# (integer ceil-division instead of a float round-trip).
Dp = -(-D // window_size[0]) * window_size[0]
Hp = -(-H // window_size[1]) * window_size[1]
Wp = -(-W // window_size[2]) * window_size[2]
print(f"Dp : {Dp}, Hp : {Hp}, Wp : {Wp}")

torch.Size([2, 6, 6, 6, 96])
Dp : 6, Hp : 6, Wp : 6


In [19]:
# compute_mask takes the padded extents in (D, H, W) order. Passing them by
# keyword keeps an H/W mix-up from slipping through on non-square inputs.
attn_mask = compute_mask(D=Dp, H=Hp, W=Wp,
                         window_size=window_size,
                         shift_size=shift_size,
                         device=x.device)

In [20]:
# window_size / shift_size were already clipped to (D, H, W) by the earlier
# get_window_size call, and clipping is idempotent. Re-clipping here is a
# no-op, so assert it instead of silently rebinding the globals a second time.
assert (window_size, shift_size) == get_window_size((D, H, W), window_size, shift_size), \
    "window/shift not yet clipped - run the get_window_size cell first"

In [21]:
# Pad the channels-last grid (B, D, H, W, C) only at the back/bottom/right so
# that D, H and W become multiples of the window size.
pad_d0 = pad_t = pad_l = 0
pad_d1 = -D % window_size[0]
pad_b = -H % window_size[1]
pad_r = -W % window_size[2]
# F.pad pairs run from the last dim backwards: C, then W, then H, then D.
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))

In [22]:
# Padded channels-last grid: (B, Dp, Hp, Wp, C).
x.size()

torch.Size([2, 6, 6, 6, 96])