In [2]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class GDFN_1(nn.Module):
    """Gated feed-forward block built from 1x1 and depth-wise 3x3 convolutions.

    The input is widened to twice the hidden width with a point-wise conv,
    filtered by a depth-wise 3x3 conv, and split in two along the channel
    axis. The GELU of the first half gates the second half element-wise, and
    a final point-wise conv maps the result back to ``channels``.

    Args:
        channels: number of input (and output) channels.
        expansion_factor: multiplier for the hidden width; the product
            ``channels * expansion_factor`` is truncated with ``int``.
    """

    def __init__(self, channels, expansion_factor):
        super().__init__()

        width = int(channels * expansion_factor)
        # Both gating branches are produced by a single pair of convolutions,
        # hence the doubled width before the split.
        doubled = width * 2

        # Point-wise expansion of the feature channels.
        self.project_in = nn.Conv2d(channels, doubled, kernel_size=1, bias=False)

        # Depth-wise 3x3: groups equals the channel count, so every channel
        # is filtered independently. padding=1 keeps H and W unchanged.
        self.conv = nn.Conv2d(
            doubled, doubled, kernel_size=3, padding=1, groups=doubled, bias=False
        )

        # Point-wise reduction back to the input channel count.
        self.project_out = nn.Conv2d(width, channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Map a (N, C, H, W) tensor to a tensor of the same shape."""
        expanded = self.project_in(x)
        filtered = self.conv(expanded)
        gate, value = torch.chunk(filtered, 2, dim=1)
        # Gating: element-wise product of the two parallel branches.
        return self.project_out(F.gelu(gate) * value)
    
    
class MDTA_1(nn.Module):
    """Multi-head channel attention producing intra- and inter-feature outputs.

    Queries, keys and values for both inputs come from one shared 1x1 conv
    followed by one shared depth-wise 3x3 conv. Attention is computed across
    channels: each head's attention map is (C/num_heads) x (C/num_heads)
    rather than HW x HW. Queries and keys are L2-normalized along the
    flattened spatial axis and the logits are scaled by a learnable per-head
    temperature before the softmax.

    Args:
        channels: number of channels of both inputs; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads (positive).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``channels``.
    """

    def __init__(self, channels, num_heads):
        super(MDTA_1, self).__init__()
        # Fail early: otherwise the head split in forward() dies with an
        # opaque reshape error.
        if num_heads < 1 or channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.qkv_conv = nn.Conv2d(channels * 3, channels * 3, kernel_size=3, padding=1,
                                  groups=channels * 3, bias=False)  # depth-wise conv

        # One output projection per returned feature map.
        self.project_out_1 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.project_out_2 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.project_out_3 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.project_out_4 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def _project(self, feat):
        """Return per-head (q, k, v) of shape (N, heads, C/heads, HW); q and k are normalized."""
        b, _, h, w = feat.shape
        q, k, v = (
            part.reshape(b, self.num_heads, -1, h * w)
            for part in self.qkv_conv(self.qkv(feat)).chunk(3, dim=1)
        )
        return F.normalize(q, dim=-1), F.normalize(k, dim=-1), v

    def _attend(self, q, k, v, h, w):
        """Apply channel attention of ``q`` against ``k`` to ``v``; returns (N, C, h, w)."""
        logits = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        weights = torch.softmax(logits, dim=-1)  # (N, heads, C/heads, C/heads)
        return torch.matmul(weights, v).reshape(v.shape[0], -1, h, w)

    def forward(self, x, y):
        """Compute self- and cross-attention features for two feature maps.

        Args:
            x, y: tensors of identical shape (N, C, H, W).

        Returns:
            Tuple ``(intra_x, intra_y, inter_xy, inter_yx)``, each (N, C, H, W):
            self-attention on ``x``, self-attention on ``y``, attention with
            queries from ``x`` and keys/values from ``y``, and attention with
            queries from ``y`` and keys/values from ``x``.

        Raises:
            ValueError: if ``x`` and ``y`` differ in shape.
        """
        # A differently shaped y with the same element count per sample would
        # otherwise be silently mis-split between channel and spatial axes.
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
            )
        _, _, h, w = x.shape

        q_x, k_x, v_x = self._project(x)
        q_y, k_y, v_y = self._project(y)

        # Intra-feature (self) attention.
        intra_x = self.project_out_1(self._attend(q_x, k_x, v_x, h, w))
        intra_y = self.project_out_2(self._attend(q_y, k_y, v_y, h, w))

        # Inter-feature (cross) attention. All four maps keep their batch
        # axis, so the result does not rely on broadcasting when N == 1.
        inter_xy = self.project_out_3(self._attend(q_x, k_y, v_y, h, w))
        inter_yx = self.project_out_4(self._attend(q_y, k_x, v_x, h, w))

        return intra_x, intra_y, inter_xy, inter_yx

In [3]:
class TransformerBlock(nn.Module):
    """Two-stream block returning refined inputs plus both cross-attention features.

    Both inputs are layer-normalized over channels and passed to ``MDTA_1``.
    The self-attention outputs are added back to their own stream, and each
    of the four resulting feature maps then goes through its own
    LayerNorm + ``GDFN_1`` with a residual connection.

    NOTE(review): a later cell in this notebook defines another class named
    ``TransformerBlock``; once that cell runs, it shadows this definition.

    Args:
        channels: channel count of both inputs.
        num_heads: number of heads passed to ``MDTA_1``.
        expansion_factor: hidden-width multiplier passed to each ``GDFN_1``.
    """

    def __init__(self, channels, num_heads, expansion_factor):
        super().__init__()

        # Pre-attention norms, one per input stream.
        self.norm1_1 = nn.LayerNorm(channels)
        self.norm1_2 = nn.LayerNorm(channels)

        self.attn = MDTA_1(channels, num_heads)

        # Pre-FFN norms: x, y, cross(x->y), cross(y->x).
        self.norm2_1 = nn.LayerNorm(channels)
        self.norm2_2 = nn.LayerNorm(channels)
        self.norm2_3 = nn.LayerNorm(channels)
        self.norm2_4 = nn.LayerNorm(channels)

        # One gated feed-forward network per feature map.
        self.ffn_1 = GDFN_1(channels, expansion_factor)
        self.ffn_2 = GDFN_1(channels, expansion_factor)
        self.ffn_3 = GDFN_1(channels, expansion_factor)
        self.ffn_4 = GDFN_1(channels, expansion_factor)

    @staticmethod
    def _channel_norm(norm, feat, dims):
        """Apply ``norm`` over the channel axis of ``feat`` viewed as ``dims`` = (b, c, h, w)."""
        b, c, h, w = dims
        # LayerNorm acts on the last axis, so move channels there and back.
        tokens = feat.reshape(b, c, -1).transpose(-2, -1).contiguous()
        return norm(tokens).transpose(-2, -1).contiguous().reshape(b, c, h, w)

    def forward(self, x, y):
        """Return ``(x, y, ca_xy, ca_yx)``, each shaped like the input ``x``."""
        b, c, h, w = x.shape
        dims = (b, c, h, w)  # x's dimensions are used for every reshape below

        normed_x = self._channel_norm(self.norm1_1, x, dims)
        normed_y = self._channel_norm(self.norm1_2, y, dims)

        # Self-attention on x, self-attention on y,
        # cross-attention (query=x, key=y), cross-attention (query=y, key=x).
        sa_x, sa_y, cross_xy, cross_yx = self.attn(normed_x, normed_y)

        # Only the self-attention outputs get a residual from the inputs.
        x = x + sa_x
        y = y + sa_y

        # Gated feed-forward networks, each with its own residual.
        x = x + self.ffn_1(self._channel_norm(self.norm2_1, x, dims))
        y = y + self.ffn_2(self._channel_norm(self.norm2_2, y, dims))
        ca_xy = cross_xy + self.ffn_3(self._channel_norm(self.norm2_3, cross_xy, dims))
        ca_yx = cross_yx + self.ffn_4(self._channel_norm(self.norm2_4, cross_yx, dims))

        return x, y, ca_xy, ca_yx

In [4]:
class TransformerBlock(nn.Module):
    """Two-stream block that fuses each input with cross-attention from the other.

    Both inputs are layer-normalized over channels and passed to ``MDTA_1``.
    Its self-attention outputs are discarded; each cross-attention output is
    added to the stream that supplied its queries, and each stream then goes
    through LayerNorm + ``GDFN_1`` with a residual connection.

    NOTE(review): this cell reuses the name ``TransformerBlock`` from an
    earlier cell (the four-output variant); running it shadows that class.

    Args:
        channels: channel count of both inputs.
        num_heads: number of heads passed to ``MDTA_1``.
        expansion_factor: hidden-width multiplier passed to each ``GDFN_1``.
    """

    def __init__(self, channels, num_heads, expansion_factor):
        super().__init__()

        # Pre-attention norms, one per input stream.
        self.norm1_1 = nn.LayerNorm(channels)
        self.norm1_2 = nn.LayerNorm(channels)

        self.attn = MDTA_1(channels, num_heads)

        # Pre-FFN norms, one per stream.
        self.norm2_1 = nn.LayerNorm(channels)
        self.norm2_2 = nn.LayerNorm(channels)

        # One gated feed-forward network per stream.
        self.ffn_1 = GDFN_1(channels, expansion_factor)
        self.ffn_2 = GDFN_1(channels, expansion_factor)

    @staticmethod
    def _channel_norm(norm, feat, dims):
        """Apply ``norm`` over the channel axis of ``feat`` viewed as ``dims`` = (b, c, h, w)."""
        b, c, h, w = dims
        # LayerNorm acts on the last axis, so move channels there and back.
        tokens = feat.reshape(b, c, -1).transpose(-2, -1).contiguous()
        return norm(tokens).transpose(-2, -1).contiguous().reshape(b, c, h, w)

    def forward(self, x, y):
        """Return ``(ca_xy, ca_yx)``, each shaped like the input ``x``."""
        b, c, h, w = x.shape
        dims = (b, c, h, w)  # x's dimensions are used for every reshape below

        normed_x = self._channel_norm(self.norm1_1, x, dims)
        normed_y = self._channel_norm(self.norm1_2, y, dims)

        # Keep only the cross-attention features:
        # (query=x, key=y) and (query=y, key=x).
        _, _, cross_xy, cross_yx = self.attn(normed_x, normed_y)

        x = x + cross_xy
        y = y + cross_yx

        # Gated feed-forward networks, each with its own residual.
        ca_xy = x + self.ffn_1(self._channel_norm(self.norm2_1, x, dims))
        ca_yx = y + self.ffn_2(self._channel_norm(self.norm2_2, y, dims))

        return ca_xy, ca_yx