In [2]:
import torch
import torch.nn as nn

class TransformerEncoderLayerWithChannels(nn.Module):
    """Post-norm Transformer encoder layer applied independently to every channel.

    The input is a 4-D tensor ``(batch_size, channels, num_tokens, embedding_dim)``.
    Each channel is treated as its own token sequence: self-attention mixes tokens
    within a channel, never across channels. All channels share the same weights.

    Args:
        embedding_dim: Size of the last input dimension (``embed_dim`` of the attention).
        num_heads: Number of attention heads; must divide ``embedding_dim``.
        dim_feedforward: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability used inside the attention and on both residual branches.
    """

    def __init__(self, embedding_dim, num_heads, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayerWithChannels, self).__init__()

        # Multi-head self-attention (default sequence-first layout: (seq, batch, dim)).
        self.self_attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, dropout=dropout)

        # Position-wise feed-forward network.
        self.feedforward = nn.Sequential(
            nn.Linear(embedding_dim, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, embedding_dim),
        )

        # Layer normalization after each residual connection (post-norm).
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Dropout on both residual branches; nn.Dropout is stateless, so one module is reused.
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Run the encoder layer on every channel of ``src``.

        Args:
            src: Tensor of shape ``(batch_size, channels, num_tokens, embedding_dim)``.

        Returns:
            Tensor with the same shape as ``src``; channel ``c`` of the output depends
            only on channel ``c`` of the input.

        Raises:
            ValueError: If ``src`` is not 4-dimensional.
        """
        if src.dim() != 4:
            raise ValueError(
                "expected src of shape (batch_size, channels, num_tokens, embedding_dim), "
                f"got {tuple(src.shape)}"
            )
        batch_size, channels, num_tokens, embedding_dim = src.shape

        # Channels are processed independently with shared weights, so fold them into
        # the batch dimension and do one batched pass instead of a Python loop.
        # reshape (not view) because src may be non-contiguous.
        x = src.reshape(batch_size * channels, num_tokens, embedding_dim)

        # The attention module is sequence-first: (num_tokens, batch * channels, embedding_dim).
        x_seq = x.transpose(0, 1)
        # Attention weights are discarded, so skip computing/averaging them.
        attn_output, _ = self.self_attn(x_seq, x_seq, x_seq, need_weights=False)

        # Residual connection + layer norm around the attention sub-layer.
        x_seq = self.norm1(x_seq + self.dropout(attn_output))

        # Back to (batch * channels, num_tokens, embedding_dim) for the feed-forward sub-layer.
        x = x_seq.transpose(0, 1)
        x = self.norm2(x + self.dropout(self.feedforward(x)))

        # Unfold the channel dimension: (batch_size, channels, num_tokens, embedding_dim).
        return x.reshape(batch_size, channels, num_tokens, embedding_dim)

# Smoke test for TransformerEncoderLayerWithChannels: the output must have the same
# shape as the input, (batch_size, channels, num_tokens, embedding_dim).
torch.manual_seed(0)  # reproducible input and weight initialization on Restart & Run All

batch_size = 16
channels = 8  # number of independent channels (sequences) per sample
num_tokens = 32
embedding_dim = 128
num_heads = 8  # must divide embedding_dim

# Input RF tensor of shape (batch_size, channels, num_tokens, embedding_dim)
RF_input = torch.randn(batch_size, channels, num_tokens, embedding_dim)

# Build the layer under test.
encoder_layer = TransformerEncoderLayerWithChannels(embedding_dim=embedding_dim, num_heads=num_heads)

# Forward pass (the module is in its default training mode, so dropout is active).
encoder_output = encoder_layer(RF_input)

# The layer must preserve the input shape exactly.
assert encoder_output.shape == RF_input.shape, encoder_output.shape
print(f"Encoder output shape: {encoder_output.shape}")  # expected: (16, 8, 32, 128)


Encoder output shape: torch.Size([16, 8, 32, 128])


In [28]:
import numpy as np

SEED = 0  # fixed so the draw below is reproducible under Restart & Run All
rng = np.random.default_rng(SEED)

# Bounds list; only the first two entries are used, as the sampling interval [a[0], a[1]).
a = [1, 2, 3, 4]
Mtot = rng.uniform(a[0], a[1])
print(Mtot)

1.4697151564184547
