In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb
print_tensor = lambda n, x: print(n, type(x), x.shape, x.min(), x.max())
class WindowAttentionTest(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        print_tensor('pos_bias_tb', self.relative_position_bias_table)
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        
        print_tensor('coords', coords)
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        print_tensor('relative coords', relative_coords)

        pdb.set_trace()
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0

        relative_coords[:, :, 1] += self.window_size[1] - 1

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

        
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        


In [None]:
windatt = WindowAttentionTest(32, (3, 3), 3)

In [7]:

class WindowAttention3d(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window. example [7, 7, 7]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww, Wd
        self.window_volume = np.prod(self.window_size)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  
            # 2*Wh-1 * 2*Ww-1 * 2*Wd-1 nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords_d = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_d]))  # 3, Wh, Ww, Wd
        coords_flatten = torch.flatten(coords, 1)  # 3, Wh*Ww*Wd
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wh*Ww*Wd, Wh*Ww*Wd

        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww*Wd, Wh*Ww*Wd, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) #TODO: how to adapt here 
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww*Wd, Wh*Ww*Wd
        self.register_buffer("relative_position_index", relative_position_index)
        
        pdb.set_trace()
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout3d(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout3d(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1) # 

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C) #  # x_windows: nW*B, window_size^3, C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None # attn_mask: nW, window_size^3, window_size^3;
        """
        B_, N, C = x.shape # 

In [None]:
windatt3d = WindowAttention3d(32, (2, 2, 2), 1)

> [0;32m<ipython-input-7-4b19298a9d19>[0m(50)[0;36m__init__[0;34m()[0m
[0;32m     48 [0;31m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m        [0mself[0m[0;34m.[0m[0mqkv[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mLinear[0m[0;34m([0m[0mdim[0m[0;34m,[0m [0mdim[0m [0;34m*[0m [0;36m3[0m[0;34m,[0m [0mbias[0m[0;34m=[0m[0mqkv_bias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m        [0mself[0m[0;34m.[0m[0mattn_drop[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mDropout3d[0m[0;34m([0m[0mattn_drop[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m        [0mself[0m[0;34m.[0m[0mproj[0m [0;34m=[0m [0mnn[0m[0;34m.[0m[0mLinear[0m[0;34m([0m[0mdim[0m[0;34m,[0m [0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> relative_position_index
tensor([[13, 12, 10,  9,  4,  3,  1,  0]