In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Embedding(nn.Module):
    """Patch embedding: one strided 3-D convolution plus a learned offset.

    The convolution maps 4 input channels to 128 output channels; the three
    remaining output axes are then collapsed into a single axis and the
    learnable ``pos_emb`` of shape ``(1, 128, 42)`` is added.  The collapsed
    axis therefore has to be 42 long, which is the case for the
    ``(1, 4, 31, 10, 3)`` input this module was written for; that input
    yields a ``(1, 128, 42)`` output.
    """

    def __init__(self):
        super().__init__()
        # kernel_size / stride / padding are all given in (D, H, W) order.
        self.Conv3D_block = nn.Conv3d(
            4,
            128,
            kernel_size=(5, 5, 3),
            stride=(4, 4, 1),
            padding=(0, 0, 1),
        )

        # Learned positional term, broadcast over the batch axis.
        self.pos_emb = nn.Parameter(torch.randn(1, 128, 42))

    def forward(self, x):
        """Convolve ``x``, flatten the trailing axes and add ``pos_emb``."""
        features = self.Conv3D_block(x)
        # Merge every axis from index 2 onwards into one.
        tokens = torch.flatten(features, start_dim=2)
        # In-place add, as the convolution output is not needed afterwards.
        return tokens.add_(self.pos_emb)


class AttentionBlock(nn.Module):
    """Multi-head self-attention built on one fused Q/K/V projection.

    Args:
        d_model: Size of the last axis of the input; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability used on the attention weights and on
            the projected output.

    ``forward`` takes a ``(batch, tokens, d_model)`` tensor and returns a
    tensor of the same shape.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.dk = d_model // num_heads  # width of a single head
        self.scale = self.dk ** -0.5  # 1 / sqrt(dk)

        # A single linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Run scaled dot-product self-attention over the token axis."""
        batch, tokens, channels = x.shape

        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.dk)
        # (3, batch, heads, tokens, dk) -> three (batch, heads, tokens, dk)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        logits = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(logits, dim=-1))

        # Put the heads back next to each other on the last axis.
        context = torch.matmul(weights, values).transpose(1, 2)
        context = context.reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(context))


class Attentionblock2(nn.Module):
    """Multi-head self-attention with separate Q, K and V projections.

    Functionally the same layer as ``AttentionBlock``, but the query, key
    and value projections are three independent linear layers instead of
    one fused layer.

    Args:
        d_model: Size of the last axis of the input; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and
            to the projected output.

    ``forward`` takes a ``(batch, tokens, d_model)`` tensor and returns a
    tensor of the same shape.
    """

    def __init__(self, d_model, num_heads=6, dropout=0.1):
        super(Attentionblock2, self).__init__()
        assert d_model % num_heads == 0, "dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.dk = d_model // num_heads  # width of a single head
        self.scale = self.dk ** -0.5  # 1 / sqrt(dk)

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # Previously `dropout` was accepted but silently ignored.
        self.attn_drop = nn.Dropout(dropout)

        # Output layer that mixes the concatenated heads.
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(dropout)

    def _split_heads(self, t, B, N):
        """Reshape ``(B, N, C)`` to ``(B, H, N, dk)``."""
        return t.reshape(B, N, self.num_heads, self.dk).permute(0, 2, 1, 3)

    def forward(self, x):
        """Run scaled dot-product self-attention over the token axis."""
        B, N, C = x.shape  # batch, tokens, d_model

        # (B, N, C) -> (B, N, H, dk) -> (B, H, N, dk); all heads are
        # computed in parallel by the batched matmuls below.
        Q = self._split_heads(self.Wq(x), B, N)
        K = self._split_heads(self.Wk(x), B, N)
        V = self._split_heads(self.Wv(x), B, N)

        scores = (Q @ K.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # (B, H, N, dk) -> (B, N, H, dk) -> (B, N, C)
        x = (attn @ V).transpose(1, 2).reshape(B, N, C)
        # The output projection was defined but never applied before.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    """Two-layer feed-forward network: d_model -> 2*d_model -> d_model.

    A GELU non-linearity follows the first linear layer, and dropout with
    probability ``dropout`` follows each of the two layers.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        hidden_dim = 2 * d_model
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.fc1_drop = nn.Dropout(dropout)
        self.fc2_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply expand -> GELU -> dropout -> contract -> dropout."""
        hidden = self.fc1_drop(F.gelu(self.fc1(x)))
        return self.fc2_drop(self.fc2(hidden))

class ViT(nn.Module):
    """A single pre-LayerNorm Transformer encoder block.

    The block applies two residual sub-layers in sequence::

        x = x + Attention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    Args:
        d_model: Size of the last axis of the input.
        num_heads: Number of attention heads; must divide ``d_model``.
        num_layers: Unused.  Kept only so existing calls keep working; this
            module is one block and stacking is left to the caller.

    ``forward`` takes a ``(batch, tokens, d_model)`` tensor and returns a
    tensor of the same shape.
    """

    def __init__(self, d_model=42, num_heads=6, num_layers=4):
        super(ViT, self).__init__()

        self.Attention = AttentionBlock(d_model, num_heads, dropout=0.1)
        self.norm = nn.LayerNorm(d_model)  # normalises the attention input
        self.norm2 = nn.LayerNorm(d_model)  # normalises the MLP input
        self.mlp = MLP(d_model, dropout=0.1)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        # The attention branch used to be computed into an unused variable,
        # so it never contributed to the output; it is now added back onto
        # the un-normalised input.
        x = x + self.Attention(self.norm(x))
        x = x + self.mlp(self.norm2(x))
        return x
    

class FullNet(nn.Module):
    """Embedding -> stacked encoder blocks -> MLP trunk -> two scalar heads.

    Pipeline for a ``(B, 4, 31, 10, 3)`` input:

    1. ``Embedding`` produces ``(B, 128, 42)``, which is transposed to
       ``(B, 42, 128)`` so that the 42 positions are the tokens and the 128
       channels are the token features.
    2. ``num_layers`` independent ``ViT`` blocks are applied in sequence.
    3. The tokens are flattened to ``(B, 42 * 128)`` and passed through the
       shared ``fc1``/``fc2`` trunk.
    4. Two parallel branches ("left" and "right") each produce one scalar.

    Args:
        num_layers: Number of encoder blocks, each with its own weights.
        num_heads: Attention heads per block.  It must divide 128; the
            ``ViT`` default of 6 does not, hence the explicit default here.

    Returns (from ``forward``):
        A ``(B, 2)`` tensor; column 0 is the left output and column 1 is the
        right output.
    """

    def __init__(self, num_layers=4, num_heads=8):
        super(FullNet, self).__init__()

        self.embedding = Embedding()
        # A ModuleList of separate blocks, so the layers do not share weights
        # (one block instance used to be applied four times).
        self.ViT = nn.ModuleList(
            [ViT(d_model=128, num_heads=num_heads) for _ in range(num_layers)]
        )

        self.fc1 = nn.Linear(42 * 128, 256)
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(256, 256)

        self.left1 = nn.Linear(256, 25)
        self.left2 = nn.Linear(25, 25)
        self.outleft = nn.Linear(25, 1)

        self.right1 = nn.Linear(256, 25)
        self.right2 = nn.Linear(25, 25)
        self.outright = nn.Linear(25, 1)

    def forward(self, x):
        """Return the left/right predictions as one ``(B, 2)`` tensor."""
        # (B, 128, 42) -> (B, 42, 128): attend over positions, not channels.
        x = self.embedding(x).transpose(1, 2)

        for block in self.ViT:
            x = block(x)

        x = torch.flatten(x, start_dim=1)  # (B, 42 * 128)

        # NOTE(review): there is no activation between these linear layers,
        # so outside of dropout the trunk and heads compose into a single
        # affine map. Confirm whether a non-linearity was intended.
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)

        left = self.left1(x)
        left = self.dropout(left)
        left = self.left2(left)

        right = self.right1(x)
        right = self.dropout(right)
        right = self.right2(right)

        outleft = self.outleft(left)
        outright = self.outright(right)

        # A set has no defined order, so the caller could not tell the two
        # outputs apart; concatenate them in a fixed order instead.
        return torch.cat([outleft, outright], dim=-1)


# Smoke test: push one random input tensor of shape
# (N=1, C_in=4, D=31, H=10, W=3) through the network and print the result.
# Guarded so that importing this file does not build and run the model.
if __name__ == "__main__":
    x = torch.randn(1, 4, 31, 10, 3)

    model = FullNet()
    y = model(x)
    print(y)

# gpt修改点总结：
# 1. Embedding: pos_emb 维度应为 (1,42,128)，并在 flatten 后加 transpose(1,2)，得到 (B,42,128)。
# 2. ViT Block: d_model 改为 128（不是 42）；使用 Pre-LN 残差 (x = x + Attn(LN(x)))；MLP 结构为 128→256→128。
# 3. ViT 堆叠: 共 4 层，不共享权重。
# 4. 输出: 最终返回时用 torch.cat([...], dim=-1) 拼接成 (B,2)，不能用 set。


x: torch.Size([1, 128, 42])
ViT-3 out: torch.Size([1, 128, 42])
flatten-3 out: torch.Size([1, 5376])
{tensor([[-0.0758]], grad_fn=<AddmmBackward0>), tensor([[-0.1046]], grad_fn=<AddmmBackward0>)}


21.333333333333332