In [23]:
import torch

x1_raw = torch.randn(1, 64, 256, 256)
b, c, h1, w1 = x1_raw.size()

# Step 1: Reshape to [b, c, h1/8*w1/8, 8*8]
x1_step1 = x1_raw.view(b, c, -1, 8, 8).contiguous().view(b, c, -1, 8*8).contiguous()
print("Step 1 - x1_raw and x1_step1 are equal:", torch.allclose(x1_raw, x1_step1.view(b, c, h1, w1), atol=1e-6))

# Step 2: Reshape to [b, c, h1/8, w1/8, 8, 8]
x1_step2 = x1_step1.view(b, c, h1 // 8, w1 // 8, 8, 8).contiguous()
print("Step 2 - x1_step1 and x1_step2 are equal:", torch.allclose(x1_raw, x1_step2.view(b, c, -1, 8*8).view(b, c, h1, w1), atol=1e-6))

# Step 4: Reshape to [b, c, h1, w1]
x1_step3 = x1_step2.view(b, c, h1, w1).contiguous()
# print("Step 4 - x1_step3 and x1_step4 are equal:", torch.allclose(x1_step3.view(b, c, h1 // 8, w1 // 8, 8, 8), x1_step4.view(b, c, h1 // 8, w1 // 8, 8, 8), atol=1e-6))

# Final check
print("x1 input and output are equal:", torch.allclose(x1_raw, x1_step3, atol=1e-6))


Step 1 - x1_raw and x1_step1 are equal: True
Step 2 - x1_step1 and x1_step2 are equal: True
x1 input and output are equal: True


In [None]:
x2 = x2_raw.view(b, c, -1, 4, 4).view(b, c, -1, 4*4)  # [b, C, h1/8*w1/8, 4*4]
x3 = x3_raw.view(b, c, -1, 2, 2).view(b, c, -1, 2*2)  # [b, C, h1/8*w1/8, 2*2]
x4 = x4_raw.view(b, c, -1, 1, 1).view(b, c, -1, 1*1)  # [b, C, h1/8*w1/8, 1*1]

# Concatenate along the last dimension
x = torch.cat((x1, x2, x3, x4), dim=-1)  # [b, C, h1/8*w1/8, 8*8 + 4*4 + 2*2 + 1*1]

# Flatten for MLP
x = x.view(-1, 8*8 + 4*4 + 2*2 + 1*1)  # [b*C,  (8*8 + 4*4 + 2*2 + 1*1)]

# Apply MLP
print(x.shape,"153")
# x = self.mlp(x)  # [b, C * (8*8 + 4*4 + 2*2 + 1*1)]

# Reshape back to original dimensions
x = x.view(b, c, -1, 8*8 + 4*4 + 2*2 + 1*1)  # [b, C, h1/8*w1/8, 8*8 + 4*4 + 2*2 + 1*1]

# Split back into individual tensors
x1, x2, x3, x4 = x.split([8*8, 4*4, 2*2, 1*1], dim=-1)

# Reshape back to original spatial dimensions
x1 = x1.view(b, c, h1 // 8, w1 // 8, 8, 8).permute(0,1,2,4,3,5).contiguous().view(b, c, h1, w1)  # [b, C, h1, w1] 
x2 = x2.view(b, c, h1 // 8, w1 // 8, 4, 4).permute(0,1,2,4,3,5).contiguous().view(b, c, h1 // 2, w1 // 2)  # [b, C, h1/2, w1/2]
x3 = x3.view(b, c, h1 // 8, w1 // 8, 2, 2).permute(0,1,2,4,3,5).contiguous().view(b, c, h1 // 4, w1 // 4)  # [b, C, h1/2, w1/4]
x4 = x4.view(b, c, h1 // 8, w1 // 8, 1, 1).permute(0,1,2,4,3,5).contiguous().view(b, c, h1 // 8, w1 // 8)  # [b, C, h1/8, w1/8]
print(x1_raw-x1)


torch.Size([65536, 85]) 153
tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.5856, -1.2384,  0.6905],
          [-1.1543,  0.0142,  0.9163,  ...,  1.8587,  1.0796, -0.8682],
          [ 0.2547, -1.3659, -0.2858,  ..., -0.9787,  0.5877,  0.1279],
          ...,
          [-0.6535, -0.2089,  2.4657,  ...,  1.0804, -1.5527,  0.9178],
          [ 1.6886,  1.2923, -0.1877,  ..., -1.6766,  1.6754,  0.4312],
          [ 1.4011,  0.8134,  0.3454,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  1.3305, -0.4923,  1.6294],
          [ 0.6098, -1.1154,  0.4179,  ...,  0.4997, -0.3057, -0.3300],
          [-0.9073,  0.8674, -0.8304,  ...,  0.0151, -1.0819, -0.5116],
          ...,
          [-0.1777, -0.7293,  0.9552,  ...,  2.2984,  1.5865,  2.3561],
          [ 1.7521, -1.7325, -0.0658,  ...,  0.5796, -1.5226,  0.7169],
          [-1.4531,  0.7956, -0.0077,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ..., -0.7846,  0.6745, -0.7153],
  

In [4]:
import torch

# 创建维度为[256, 1]的张量
tensor1 = torch.randn(256, 1,1)

# 创建维度为[8, 256, 16, 16]的张量
tensor2 = torch.randn(8, 256, 16, 16)

# 对tensor2进行转置操作，使其满足矩阵乘法的维度匹配规则
# tensor2_transposed = tensor2.permute(0, 2, 3, 1).view(8, -1, 256)

# 进行矩阵乘法
result = tensor1*tensor2



print(result.shape)  

torch.Size([8, 256, 16, 16])


In [None]:


class CMUNeXtBlock_MK(nn.Module):  # 根据SCSA灵感，直接在一个block中为不同的channel设置不同的卷积核，同时引入通道注意力
    def __init__(self, ch_in, ch_out, depth=1, k=3): # ch_in 需要是4的倍数
        super(CMUNeXtBlock, self).__init__()
        self.block = nn.Sequential(
            *[nn.Sequential(
                Residual(nn.Sequential(
                    # deep wise
                    nn.Conv2d(ch_in, ch_in, kernel_size=(k, k), groups=ch_in, padding=(k // 2, k // 2)),
                    nn.GELU(),
                    nn.BatchNorm2d(ch_in)
                )),
                nn.Conv2d(ch_in, ch_in * 4, kernel_size=(1, 1)),
                nn.GELU(),
                nn.BatchNorm2d(ch_in * 4),
                nn.Conv2d(ch_in * 4, ch_in, kernel_size=(1, 1)),
                nn.GELU(),
                nn.BatchNorm2d(ch_in)
            ) for i in range(depth)]
        )
        self.up = conv_block(ch_in, ch_out)
        self.group_chans = group_chans = self.dim // 4


    def forward(self, x):
        b, c, h_, w_ = x.size()
        l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x, self.group_chans, dim=1)
        x = self.block(x)
        x = self.up(x)
        return x
