In [1]:
from cmath import sqrt
from re import X
import torch
import torch.nn as nn
from torch import einsum
from typing import Tuple
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.nn import functional as F
from torch.nn import Module, Conv2d, Parameter, Softmax
from torchvision import models
from torch.nn import init
from torchinfo import summary
# from torchstat import stat

from functools import partial
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# from timm.models.layers import DropPath, trunc_normal_


In [2]:

class MLP_FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        c1: input (and output) feature size.
        c2: hidden feature size.
        drop: dropout probability applied after the activation and after the
            second linear layer. Defaults to 0.1, the value that used to be
            hard-coded.
    """

    def __init__(self, c1, c2, drop=0.1):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        # A single Dropout module is shared by both dropout sites in forward().
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        x = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(x))

In [3]:
# ----New Bridge---------
# Spatial Fuse module
class MultiScaleAtten(nn.Module):
    """Multi-head self-attention applied independently inside every window.

    The input is a 5-D tensor ``(B, grid_h, grid_w, N, C)``; attention is
    computed over the ``N`` tokens of each ``(grid_h, grid_w)`` cell.

    Args:
        dim: channel size ``C`` of the tokens.
        num_head: number of attention heads (default 8, the previously
            hard-coded value). ``dim`` must be divisible by it.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_head``.
    """

    def __init__(self, dim, num_head=8):
        super(MultiScaleAtten, self).__init__()
        if dim % num_head != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_head ({num_head})")
        self.qkv_linear = nn.Linear(dim, dim * 3)
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)
        self.num_head = num_head
        # NOTE(review): this factor is stored but never applied to the
        # attention logits in forward(); it is left unapplied so that outputs
        # stay identical to the previous implementation. Confirm whether the
        # logits were meant to be scaled.
        self.scale = (dim // self.num_head)**0.5

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        # Grid height and width are read separately so a rectangular window
        # grid is reshaped correctly instead of being treated as square.
        B, grid_h, grid_w, N, C = x.shape
        head_dim = C // self.num_head

        qkv = self.qkv_linear(x).reshape(
            B, grid_h, grid_w, N, 3, self.num_head, head_dim)
        # -> (3, B, grid_h, grid_w, head, N, head_dim)
        qkv = qkv.permute(4, 0, 1, 2, 5, 3, 6).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]

        atten = self.softmax(q @ k.transpose(-1, -2).contiguous())
        # Swap the head and token axes back, then merge the heads into C.
        atten_value = (atten @ v).transpose(-2, -3).contiguous().reshape(
            B, grid_h, grid_w, N, C)
        return self.proj(atten_value)  # (B, grid_h, grid_w, N, C)


class InterTransBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by an MLP sub-layer.

    Each sub-layer is applied to a LayerNorm-ed copy of its input and the
    result is added back to that input.
    """

    def __init__(self, dim):
        super().__init__()
        self.SlayerNorm_1 = nn.LayerNorm(dim, eps=1e-6)
        self.SlayerNorm_2 = nn.LayerNorm(dim, eps=1e-6)
        self.Attention = MultiScaleAtten(dim)
        self.mlp = MLP_FFN(dim, 4 * dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Residual 1: attention over the normalized input.
        attended = x + self.Attention(self.SlayerNorm_1(x))
        # Residual 2: feed-forward network over the normalized attention output.
        return attended + self.mlp(self.SlayerNorm_2(attended))


class SpatialAwareTrans(nn.Module):
    """Fuse four feature maps by window-wise attention across scales.

    Each of the four input maps ``(B, C_j, H_j, W_j)`` is projected to ``dim``
    channels, cut into non-overlapping windows of side ``win_size_list[j]``
    and flattened, so that every scale contributes ``win * win`` tokens to the
    same window grid. The tokens of all scales are concatenated per grid cell,
    passed through ``num_sp_layer`` ``InterTransBlock`` layers, split back per
    scale, un-windowed and projected back to ``C_j`` channels.

    Args:
        dim: shared channel size used inside the attention blocks.
        num_sp_layer: number of ``InterTransBlock`` layers.
    """

    def __init__(self, dim=64, num_sp_layer=1):
        super(SpatialAwareTrans, self).__init__()
        self.win_size_list = [8, 4, 2, 1]
        self.channels = [64, 64 * 2, 64 * 5, 64 * 8]
        self.dim = dim
        self.depth = 4

        # Per-scale projections into the shared channel size.
        self.fc1 = nn.Linear(self.channels[0], dim)
        self.fc2 = nn.Linear(self.channels[1], dim)
        self.fc3 = nn.Linear(self.channels[2], dim)
        self.fc4 = nn.Linear(self.channels[3], dim)

        # These four layers are not used by forward() (``fc_back`` is); they
        # are kept so parameter names in ``state_dict`` and the order of
        # random initialization do not change.
        self.fc1_back = nn.Linear(dim, self.channels[0])
        self.fc2_back = nn.Linear(dim, self.channels[1])
        self.fc3_back = nn.Linear(dim, self.channels[2])
        self.fc4_back = nn.Linear(dim, self.channels[3])

        # Per-scale projections back to the original channel sizes.
        self.fc_back = nn.ModuleList(
            [nn.Linear(self.dim, self.channels[i]) for i in range(self.depth)])

        self.num = num_sp_layer  # the number of layers
        self.group_attention = nn.Sequential(
            *[InterTransBlock(dim) for _ in range(self.num)])
        # Number of tokens each scale contributes to one grid cell.
        self.split_list = [win * win for win in self.win_size_list]

    def forward(self, x):
        """Fuse the feature maps and return them in their original shapes.

        Args:
            x: sequence (list or tuple) of four tensors ``(B, C_j, H_j, W_j)``
                with ``C_j`` equal to ``self.channels[j]``. The sequence is
                not modified.

        Returns:
            A new list of four tensors with the same shapes as the inputs.

        Raises:
            ValueError: if ``x`` does not contain exactly four feature maps.
        """
        projections = (self.fc1, self.fc2, self.fc3, self.fc4)
        if len(x) != len(projections):
            raise ValueError(
                f"expected {len(projections)} feature maps, got {len(x)}")

        # Patch matching: results go into a fresh list so the caller's
        # sequence is left untouched.
        windows = []
        for proj, win_size, feat in zip(projections, self.win_size_list, x):
            feat = proj(feat.permute(0, 2, 3, 1))  # (B, H, W, dim)
            B, H, W, C = feat.shape
            # (B, H/win, win, W/win, win, C) -> (B, H/win, W/win, win, win, C)
            feat = feat.reshape(
                B, H // win_size, win_size, W // win_size, win_size, C
            ).permute(0, 1, 3, 2, 4, 5).contiguous()
            windows.append(
                feat.reshape(B, H // win_size, W // win_size, win_size * win_size, C))
        fused = torch.cat(windows, dim=-2)  # (B, H // win, W // win, N, C)

        # Scale fusion
        for block in self.group_attention:
            fused = block(fused)

        # Patch reversion: grid height and width are read separately instead
        # of assuming a square grid.
        outputs = []
        for j, item in enumerate(torch.split(fused, self.split_list, dim=-2)):
            B, grid_h, grid_w, _, C = item.shape
            win_size = self.win_size_list[j]
            item = item.reshape(
                B, grid_h, grid_w, win_size, win_size, C
            ).permute(0, 1, 3, 2, 4, 5).contiguous()
            item = item.reshape(B, grid_h * win_size, grid_w * win_size, C)
            outputs.append(self.fc_back[j](item).permute(0, 3, 1, 2).contiguous())

        return outputs


In [4]:
class MixFFN_skip(nn.Module):
    """Feed-forward block with a depth-wise convolution and an inner skip.

    Computes ``fc2(GELU(LayerNorm(dwconv(h) + h)))`` where ``h = fc1(x)``.

    Args:
        c1: input (and output) feature size.
        c2: hidden feature size.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        self.norm1 = nn.LayerNorm(c2)
        # norm2 / norm3 are not used by forward(); they are kept so the
        # parameter names in ``state_dict`` do not change.
        self.norm2 = nn.LayerNorm(c2)
        self.norm3 = nn.LayerNorm(c2)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        """Apply the block to tokens ``x`` of shape ``(B, H*W, c1)``.

        ``H`` and ``W`` are forwarded to the depth-wise convolution, which
        needs them to lay the tokens out as a 2-D map.
        """
        # Project once and reuse the result for both the convolution branch
        # and the skip branch (it used to be computed twice).
        hidden = self.fc1(x)
        ax = self.act(self.norm1(self.dwconv(hidden, H, W) + hidden))
        return self.fc2(ax)
    
class DWConv(nn.Module):
    """3x3 depth-wise convolution applied to a token sequence.

    The tokens ``(B, N, C)`` are laid out as a ``(B, C, H, W)`` map, convolved
    with one 3x3 filter per channel (stride 1, padding 1) and flattened back
    to ``(B, N, C)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        batch, _, channels = x.shape
        # Channels first, then unfold the token axis into the H x W grid.
        feature_map = x.permute(0, 2, 1).view(batch, channels, H, W)
        filtered = self.dwconv(feature_map)
        # Back to (B, N, C).
        return filtered.flatten(start_dim=2).permute(0, 2, 1)



In [5]:
class Scale_reduce(nn.Module):
    """Shrink the multi-scale token sequence with strided convolutions.

    The token sequence is the concatenation of several stages. A stage with
    spatial side ``s`` and channel multiplier ``m`` occupies ``s * s * m``
    tokens of size ``dim``. Every stage except the last is reshaped to a
    ``(B, dim * m, s, s)`` map and reduced by a convolution whose kernel size
    equals its stride; the last stage is passed through unchanged. The
    concatenated result is layer-normalized.

    Args:
        dim: token size ``C``.
        reduction_ratio: sequence of 3 or 4 ratios. With 4 ratios the stages
            have sides ``base_size, base_size/2, base_size/4, base_size/8``
            and channel multipliers ``1, 2, 5, 8``; with 3 ratios the first
            stage is dropped. Ratios are consumed from the end of the list:
            the first stage present uses the last ratio.
        base_size: spatial side of the first of the four stages. Defaults to
            56, the value that used to be hard-coded.

    Raises:
        ValueError: if ``reduction_ratio`` does not have 3 or 4 entries.
    """

    # Channel multiplier of each of the four stages relative to ``dim``.
    _CHANNEL_MULTS = (1, 2, 5, 8)

    def __init__(self, dim, reduction_ratio, base_size=56):
        super().__init__()
        self.dim = dim
        self.reduction_ratio = reduction_ratio
        self.base_size = base_size

        n_scales = len(self.reduction_ratio)
        if n_scales not in (3, 4):
            # Fail at construction; previously this only surfaced in
            # forward() as an unbound local variable.
            raise ValueError(
                f"reduction_ratio must have 3 or 4 entries, got {n_scales}")

        if n_scales == 4:
            self.sr0 = nn.Conv2d(dim, dim, reduction_ratio[3], reduction_ratio[3])
            self.sr1 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
            self.sr2 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])
        else:
            self.sr0 = nn.Conv2d(dim*2, dim*2, reduction_ratio[2], reduction_ratio[2])
            self.sr1 = nn.Conv2d(dim*5, dim*5, reduction_ratio[1], reduction_ratio[1])

        self.norm = nn.LayerNorm(dim)

    def _stage_layout(self):
        """Return ``[(side, channel_mult), ...]`` for the stages in the sequence."""
        first = 4 - len(self.reduction_ratio)  # 0 with four ratios, 1 with three
        return [(self.base_size // (2 ** k), self._CHANNEL_MULTS[k])
                for k in range(first, 4)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the reduced, normalized token sequence ``(B, N_reduced, C)``."""
        B, N, C = x.shape
        layout = self._stage_layout()
        last = len(layout) - 1

        pieces = []
        start = 0
        for idx, (side, mult) in enumerate(layout):
            stop = start + side * side * mult
            tokens = x[:, start:stop, :]
            if idx < last:
                fmap = tokens.reshape(B, side, side, C * mult).permute(0, 3, 1, 2)
                reducer = getattr(self, f"sr{idx}")
                tokens = reducer(fmap).reshape(B, C, -1).permute(0, 2, 1)
            pieces.append(tokens)
            start = stop

        return self.norm(torch.cat(pieces, -2))

        


class M_EfficientSelfAtten(nn.Module):
    """Multi-head attention whose keys/values may come from a reduced sequence.

    Queries are always computed from the full token sequence. When
    ``reduction_ratio`` is not ``None`` the sequence is first passed through
    ``Scale_reduce`` and keys/values are computed from that shorter sequence.
    """

    def __init__(self, dim, head, reduction_ratio):
        super().__init__()
        self.head = head
        self.reduction_ratio = reduction_ratio  # e.g. [1, 2, 4, 8], or None
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim*2, bias=True)
        self.proj = nn.Linear(dim, dim)

        if reduction_ratio is not None:
            self.scale_reduce = Scale_reduce(dim, reduction_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_tokens, channels = x.shape
        head_dim = channels // self.head

        # (B, head, N, head_dim)
        queries = self.q(x).reshape(batch, n_tokens, self.head, head_dim).permute(0, 2, 1, 3)

        # Keys and values are taken from the (optionally) reduced sequence.
        context = x if self.reduction_ratio is None else self.scale_reduce(x)
        keys, values = (
            self.kv(context)
            .reshape(batch, -1, 2, self.head, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        weights = ((queries @ keys.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        merged = (weights @ values).transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj(merged)



In [6]:
class BridgeLayer_new(nn.Module):
    """One bridge layer: multi-scale attention followed by per-stage MixFFNs.

    The layer works on a token sequence ``(B, 6076, C)`` that concatenates four
    stages with spatial sides 56/28/14/7 and channel multipliers 1/2/5/8
    (3136 + 1568 + 980 + 392 tokens of size ``C``).

    Args:
        dims: token size ``C``.
        head: number of attention heads.
        reduction_ratios: ratios forwarded to ``M_EfficientSelfAtten``.
    """

    # (spatial side, channel multiplier) of each stage in the token sequence.
    _STAGES = ((56, 1), (28, 2), (14, 5), (7, 8))

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        self.mixffn1 = MixFFN_skip(dims, dims*4)
        self.mixffn2 = MixFFN_skip(dims*2, dims*8)
        self.mixffn3 = MixFFN_skip(dims*5, dims*20)
        self.mixffn4 = MixFFN_skip(dims*8, dims*32)

        self.scale_fuse_att = SpatialAwareTrans(dim=dims, num_sp_layer=1)

    def forward(self, inputs):
        """Run the layer.

        Args:
            inputs: either a list/tuple of four feature maps
                ``(B, C * m, s, s)``, which are first fused by
                ``scale_fuse_att`` and flattened into tokens, or an already
                flattened token tensor ``(B, N, C)``.

        Returns:
            Token tensor of shape ``(B, N, C)``.
        """
        if isinstance(inputs, (list, tuple)):
            # Hand over a shallow copy so the fusion module cannot overwrite
            # entries of the caller's sequence.
            feats = self.scale_fuse_att(list(inputs))
            B, C = feats[0].shape[:2]
            # (B, C*m, s, s) -> (B, s*s*m, C) for every stage, then concatenate.
            inputs = torch.cat(
                [f.permute(0, 2, 3, 1).reshape(B, -1, C) for f in feats], -2)
        else:
            B, _, C = inputs.shape

        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)

        ffns = (self.mixffn1, self.mixffn2, self.mixffn3, self.mixffn4)
        mixed = []
        start = 0
        for ffn, (side, mult) in zip(ffns, self._STAGES):
            stop = start + side * side * mult
            # Regroup the stage's tokens to its native channel width.
            stage = tx[:, start:stop, :].reshape(B, -1, C * mult)
            mixed.append(ffn(stage, side, side).reshape(B, -1, C))
            start = stop

        return tx1 + torch.cat(mixed, -2)



class BridgeBlock_new(nn.Module):
    """Stack of four bridge layers that returns per-stage feature maps.

    The input is passed through ``bridge_layer1`` .. ``bridge_layer4`` in
    order. The resulting token tensor ``(B, N, C)`` is then cut at the fixed
    offsets 3136 / 4704 / 5684 / 6076 and reshaped into four maps with sides
    56 / 28 / 14 / 7 and ``C``, ``2C``, ``5C``, ``8C`` channels.
    """

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        for idx in range(1, 5):
            setattr(self, f"bridge_layer{idx}",
                    BridgeLayer_new(dims, head, reduction_ratios))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Despite the annotation, the value returned is a list of four tensors.
        tokens = x
        for layer in (self.bridge_layer1, self.bridge_layer2,
                      self.bridge_layer3, self.bridge_layer4):
            tokens = layer(tokens)

        batch, _, channels = tokens.shape
        # (start token, end token, spatial side, channel multiplier)
        layout = (
            (0, 3136, 56, 1),
            (3136, 4704, 28, 2),
            (4704, 5684, 14, 5),
            (5684, 6076, 7, 8),
        )
        return [
            tokens[:, lo:hi, :].reshape(batch, side, side, channels * mult).permute(0, 3, 1, 2)
            for lo, hi, side, mult in layout
        ]


In [7]:
# Build the bridge with token size 64 and a single attention head; the ratio
# list is forwarded unchanged to every bridge layer.
reduction_ratios = [1, 2, 4, 8]
bridge = BridgeBlock_new(dims=64, head=1, reduction_ratios=reduction_ratios)

In [8]:
# Smoke run: feed random feature maps through the bridge and print the shapes
# it returns. The four maps below use these shapes:
#   (1,  64, 56, 56)
#   (1, 128, 28, 28)
#   (1, 320, 14, 14)
#   (1, 512,  7,  7)

out_enc1 = torch.rand(1, 64, 56, 56)
out_enc2 = torch.rand(1, 128, 28, 28)
out_enc3 = torch.rand(1, 320, 14, 14)
out_enc4 = torch.rand(1, 512, 7, 7)

output_enc = [out_enc1, out_enc2, out_enc3, out_enc4]

output_bridge = bridge(output_enc)
# The labels say "output_enc" but the shapes printed are those of output_bridge.
print(f"output_enc[0]:{output_bridge[0].shape}")
print(f"output_enc[1]:{output_bridge[1].shape}")
print(f"output_enc[2]:{output_bridge[2].shape}")
print(f"output_enc[3]:{output_bridge[3].shape}")

output_enc[0]:torch.Size([1, 64, 56, 56])
output_enc[1]:torch.Size([1, 128, 28, 28])
output_enc[2]:torch.Size([1, 320, 14, 14])
output_enc[3]:torch.Size([1, 512, 7, 7])
