In [2]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.crossformer import Attention
from models.crossformer import DynamicPosBias

# Expose only GPU 0 to this kernel.
# NOTE(review): presumably only effective if CUDA has not been initialised yet in this
# process (e.g. by an earlier `.cuda()` call) — confirm when re-running cells out of order.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [5]:
depths = [2, 2, 18, 2]
layers = len(depths)
patches_resolution = [56, 56]
min_size = 4
total_depth = sum(depths)
# One group fraction per block, linearly spaced from min_size / resolution up to (but
# excluding) 1.0. linspace fixes the length at exactly `total_depth`; np.arange with a
# fractional step can gain or lose an element through floating-point rounding, and the
# loop below indexes this array once per block.
group_fraction = np.linspace(min_size / patches_resolution[0], 1.0, total_depth, endpoint=False)
assert len(group_fraction) == total_depth

group_sizes = []
cnt = 0
for i_layer in range(layers):
    cur_resolution = patches_resolution[0] // 2 ** i_layer  # resolution halves at every stage
    for _ in range(depths[i_layer]):
        group_size = cur_resolution * group_fraction[cnt]
        # Above half the resolution, snap to either the full map or half of it.
        if group_size > cur_resolution // 2:
            group_size = cur_resolution if group_size > cur_resolution * 3 / 4 else cur_resolution // 2
        # Never go below the configured minimum group size.
        group_sizes.append(max(min_size, int(np.ceil(group_size))))
        print(group_sizes[-1])
        cnt += 1


4
9
7
9
6
7
7
7
7
14
7
7


In [5]:
# Toy check of the padding mask: a 7x7 map padded by 2 before / 1 after on each axis,
# then split into non-overlapping 2x2 groups. Reshape sizes are derived, not hard-coded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
size, group, pad_before, pad_after = 7, 2, 2, 1
n_groups = (size + pad_before + pad_after) // group  # groups per axis
assert n_groups * group == size + pad_before + pad_after, "padded size must be a multiple of the group size"

a = torch.ones((size, size), device=device)  # 1 = real pixel
a = F.pad(a, [pad_before, pad_after, pad_before, pad_after])  # padded pixels are 0
# (n_groups, group, n_groups, group) -> (n_groups**2, group**2): one row of tokens per group
a = a.reshape((n_groups, group, n_groups, group)).permute((0, 2, 1, 3)).flatten(-2).flatten(0, 1)
print(a[-1])  # bottom-right group
a = a.unsqueeze(-1) * a.unsqueeze(-2)  # 1 only where query and key are both real
print(a[-1])
a = (1 - a) * -10000  # additive mask: 0 (kept) or -10000 (masked)
print(a[-1])

tensor([1., 0., 0., 0.], device='cuda:0')
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')
tensor([[    -0., -10000., -10000., -10000.],
        [-10000., -10000., -10000., -10000.],
        [-10000., -10000., -10000., -10000.],
        [-10000., -10000., -10000., -10000.]], device='cuda:0')


In [13]:
C = 128
n_head = 8
B = 16
G = 5
H, W = 14, 14

# Fall back to CPU so the cell also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attn = Attention(C, (G, G), n_head).to(device)

# Pad each axis independently up to a multiple of G (H and W need not be equal),
# splitting the padding as evenly as possible with the odd pixel going after.
n_h, n_w = -(-H // G), -(-W // G)  # groups per axis (integer ceil division)
pad_h, pad_w = n_h * G - H, n_w * G - W
pad_top, pad_left = pad_h // 2, pad_w // 2
pad_size = [pad_left, pad_w - pad_left, pad_top, pad_h - pad_top]  # F.pad order: (left, right, top, bottom)

# Additive mask per group: 0 where both query and key are real pixels, -10000 otherwise.
mask = F.pad(torch.ones((H, W)).to(device), pad_size)
mask = mask.reshape(n_h, G, n_w, G).permute(0, 2, 1, 3).flatten(-2).flatten(0, 1)  # (n_h*n_w, G*G)
mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)  # (n_h*n_w, G*G, G*G)
mask = (1 - mask) * -10000

x = torch.rand((B, H * W, C)).to(device)
x = x.view(B, H, W, C)
x_pad = F.pad(x, [0, 0] + pad_size)  # leading [0, 0]: the channel dim is not padded
x_pad = x_pad.reshape(B, n_h, G, n_w, G, C).permute(0, 1, 3, 2, 4, 5)
x_pad = x_pad.reshape(-1, G**2, C)  # (B*n_h*n_w, G*G, C)

print(x_pad.shape)

out_pad = attn(x_pad, mask)
out_pad = out_pad.reshape(B, n_h, n_w, G, G, C)
out_pad = out_pad.permute(0, 1, 3, 2, 4, 5).reshape(B, H + pad_h, W + pad_w, C)
out = out_pad[:, pad_top:pad_top + H, pad_left:pad_left + W, :]  # crop the padding back off
assert out.shape == (B, H, W, C)
print(out.shape)

torch.Size([144, 25, 128])
torch.Size([16, 14, 14, 128])


In [14]:
# Inspect the last (bottom-right) group: channel 0 of its tokens (padded pixels are 0),
# then the mask row of its first token scaled so masked keys show -1 and kept keys -0.
last_group_tokens = x_pad[-1, :, 0]
first_query_mask = mask[-1, 0]
print(last_group_tokens.reshape(G, G))
print(first_query_mask.reshape(G, G) / 10000)

tensor([[0.1455, 0.2386, 0.0221, 0.0086, 0.0000],
        [0.5513, 0.9299, 0.7500, 0.5527, 0.0000],
        [0.6135, 0.8574, 0.8102, 0.8803, 0.0000],
        [0.9937, 0.5681, 0.9847, 0.0302, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0')
tensor([[-0., -0., -0., -0., -1.],
        [-0., -0., -0., -0., -1.],
        [-0., -0., -0., -0., -1.],
        [-0., -0., -0., -0., -1.],
        [-1., -1., -1., -1., -1.]], device='cuda:0')


In [15]:
class Mutual_Attention(nn.Module):
    r""" Multi-head cross ("mutual") attention: queries are computed from ``x`` and
    keys/values from a second sequence ``y``, which may have a different length.

    No position bias is computed by this module; ``position_bias`` is only stored, and
    ``flops`` adds ``self.pos.flops`` solely when a ``pos`` sub-module has been attached.

    Args:
        dim (int): Number of input channels.
        group_size (tuple[int]): The height and width of the group (only reported by extra_repr).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        position_bias (bool, optional): Stored flag, see the note above. Default: True
    """

    def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 position_bias=True):

        super().__init__()
        self.dim = dim
        self.group_size = group_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.position_bias = position_bias

        self.q  = nn.Linear(dim, dim, bias=qkv_bias)       # queries from x
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)   # keys and values from y
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y, mask=None):
        """
        Args:
            x: query features with shape of (num_groups*B, N, C)
            y: key/value features with shape of (num_groups*B, L, C)
            mask: (0/-inf) additive mask with shape of (num_groups, N, L) or None
        Returns:
            Tensor of shape (num_groups*B, N, C).
        """
        B_, N, C = x.shape
        L = y.shape[1]
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B_, N, self.num_heads, head_dim).permute(0, 2, 1, 3)           # B_, nH, N, hd
        kv = self.kv(y).reshape(B_, L, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)  # 2, B_, nH, L, hd
        k, v = kv[0], kv[1]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        # @ stands for matrix multiplication
        attn = q @ k.transpose(-2, -1)  # B_, nH, N, L

        if mask is not None:
            nW = mask.shape[0]
            # The key axis has length L, not N: queries and keys come from different sequences.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, L) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, L)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'

    def flops(self, N, L=None):
        # calculate flops for 1 group with N query tokens and L key/value tokens (L defaults to N)
        L = N if L is None else L
        flops = 0
        # q = self.q(x)
        flops += N * self.dim * self.dim
        # kv = self.kv(y)
        flops += L * self.dim * 2 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * L
        #  x = (attn @ v)
        flops += self.num_heads * N * L * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        # This module never creates `self.pos`; only count it if one was attached externally.
        pos = getattr(self, "pos", None)
        if self.position_bias and pos is not None:
            flops += pos.flops(N)
        return flops

In [16]:
C, n_head, B = 128, 8, 16
G = 7
H, W = 14, 14

# Cross-attention module without position bias.
attn = Mutual_Attention(C, (G, G), n_head, position_bias=False).cuda()
# Depth-wise 3x3 convolution (groups=C: one filter per channel); padding=1 keeps H and W.
depth_conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, padding=1, groups=C).cuda()

In [17]:
x = torch.rand((B, H, W, C)).cuda()
x_chw = x.permute(0, 3, 1, 2)                          # (B, C, H, W) for Conv2d
x = (x_chw + depth_conv(x_chw)).permute(0, 2, 3, 1)    # residual conv, back to (B, H, W, C)

# Average each non-overlapping G x G group into a single key/value token.
groups = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)
y = groups.reshape(B, H * W // G**2, G**2, C).mean(dim=2)   # (B, number of groups, C)

x = x.reshape(B, H * W, C)                             # every pixel is a query token
out = attn(x, y).reshape(B, H, W, C)
print(out.shape)

torch.Size([16, 14, 14, 128])


In [11]:
class Pad_Attention(nn.Module):
    r""" Multi-head self attention applied independently to several stacked channel splits,
    with optional dynamic position bias.

    The input stacks S splits along a leading axis (S == 4 in this notebook, one per padding
    corner). All splits share the same qkv and output projection weights.

    Args:
        dim (int): Number of input channels per split.
        group_size (tuple[int]): The height and width of the group.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        position_bias (bool, optional): If True, add a DynamicPosBias relative position bias. Default: True
    """

    def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 position_bias=True):

        super().__init__()
        self.dim = dim
        self.group_size = group_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.position_bias = position_bias

        if position_bias:
            self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
            
            # generate mother-set
            position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
            position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
            biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2Ww-1
            biases = biases.flatten(1).transpose(0, 1).float()
            self.register_buffer("biases", biases)

            # get pair-wise relative position index for each token inside the group
            coords_h = torch.arange(self.group_size[0])
            coords_w = torch.arange(self.group_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.group_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (S, num_groups*B, N, C), where S is the number of
               stacked channel splits (4 in this notebook) and C the per-split channel count
            mask: (0/-inf) additive mask with shape of (S, num_groups, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor with the same shape as x.
        """
        S, B_, N, C = x.shape  # S was previously hard-coded to 4
        head_dim = C // self.num_heads
        # (S, B_, N, 3*C) -> (3, S, B_, nH, N, head_dim)
        qkv = self.qkv(x).reshape(S, B_, N, 3, self.num_heads, head_dim).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [S, B_, nH, N, head_dim]

        q = q * self.scale
        # @ stands for matrix multiplication
        attn = q @ k.transpose(-2, -1)  # S, B_, nH, N, N

        if self.position_bias:
            pos = self.pos(self.biases)  # 2Wh-1 * 2Ww-1, heads
            # select position bias
            relative_position_bias = pos[self.relative_position_index.view(-1)].view(
                self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            # broadcasts over the split and batch axes
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[1]
            # [S, B, nW, nH, N, N] + [S, 1, nW, 1, N, N]
            attn = attn.view(S, B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(2).unsqueeze(1)
            attn = attn.view(S, -1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(2, 3).reshape(S, B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 group (of a single split) with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        if self.position_bias:
            flops += self.pos.flops(N)
        return flops

In [12]:
C, n_head, B = 128, 8, 16
G = 12
H, W = 14, 14

# Each of the 4 channel splits carries C // 4 channels.
attn = Pad_Attention(C // 4, (G, G), n_head).cuda()

n_groups = int(np.ceil(H / G))          # groups per axis (H is used for both axes here)
pad_num = n_groups * G - H
pad_size = [pad_num, 0, pad_num, 0]     # F.pad order (left, right, top, bottom): pad top-left only

# Mask of real pixels padded top-left; flipping yields the top-right, bottom-left and
# bottom-right variants, stacked in that order.
valid = F.pad(torch.ones((H, W)).cuda(), pad_size)
mask = torch.stack(
    [valid, torch.fliplr(valid), torch.flipud(valid), torch.fliplr(torch.flipud(valid))],
    dim=0)
mask = mask.reshape(4, n_groups, G, n_groups, G).permute(0, 1, 3, 2, 4).flatten(-2).flatten(1, 2)
mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
mask = (1 - mask) * -10000              # 0 where query and key are both real, -10000 otherwise
print(mask.shape)

torch.Size([4, 4, 144, 144])


In [13]:
x = torch.rand((B, H * W, C)).cuda().view(B, H, W, C)

new_C = C // 4
# Channel boundaries of the 4 splits (same cut points as C//4, C//2, 3*C//4).
bounds = [0, C // 4, C // 2, 3 * C // 4, C]
# F.pad lists (before, after) pairs starting from the LAST dim: channels, then W, then H.
corner_pads = [
    [0, 0, pad_num, 0, pad_num, 0],   # split 0: pad left + top
    [0, 0, 0, pad_num, pad_num, 0],   # split 1: pad right + top
    [0, 0, pad_num, 0, 0, pad_num],   # split 2: pad left + bottom
    [0, 0, 0, pad_num, 0, pad_num],   # split 3: pad right + bottom
]
x_pad = torch.stack(
    [F.pad(x[:, :, :, bounds[i]:bounds[i + 1]], p) for i, p in enumerate(corner_pads)],
    dim=0)  # [4, B, H + pad_num, W + pad_num, C // 4]

n_h, n_w = (H + pad_num) // G, (W + pad_num) // G
x_pad = x_pad.reshape(4, B, n_h, G, n_w, G, new_C).permute(0, 1, 2, 4, 3, 5, 6)
x_pad = x_pad.reshape(4, -1, G**2, new_C)   # [4, B * n_h * n_w, G*G, C // 4]

print(x_pad.shape)

torch.Size([4, 64, 144, 32])


In [14]:
out = attn(x_pad, mask)
print(out.shape)
# Undo the grouping: [4, B*nG, G*G, c] -> [4, B, H + pad_num, W + pad_num, c]
out = out.reshape(4, B, (H + pad_num) // G, (W + pad_num) // G, G, G, new_C)
out = out.permute(0, 1, 2, 4, 3, 5, 6).reshape(4, B, H + pad_num, W + pad_num, new_C)
# Crop away the side each split was padded on, then re-join the splits along channels.
crops = (
    out[0, :, pad_num:, pad_num:, :],   # padded top + left
    out[1, :, pad_num:, :W, :],         # padded top + right
    out[2, :, :H, pad_num:, :],         # padded bottom + left
    out[3, :, :H, :W, :],               # padded bottom + right
)
out = torch.cat(crops, dim=-1)
print(out.shape)

torch.Size([4, 64, 144, 32])
torch.Size([16, 14, 14, 128])
