In [None]:
class ModelConfig:
    """Central container for all hyperparameters of the model.

    Constructor arguments are grouped into per-component dicts
    (``transformer``, ``convolution``, ``residual``, ``masking``) plus a few
    top-level attributes (``input_dim``, ``time_steps``, ``num_classes``,
    ``positional_encoding``).

    Args:
        input_neuron: number of input neurons; stored as ``input_dim`` and used
            as the input width of the model's input projection.
        time_bins: number of time steps; stored as ``time_steps``
            (NOTE(review): not read by TimeTransformerConvModel).
        d_model: Transformer model dimension.
        nhead: number of attention heads.
        num_transformer_layers: number of Transformer encoder layers.
        conv_channels: output channels of every conv block.
        num_conv_blocks: number of conv blocks.
        num_classes: number of classification classes.
        residual_dims: output widths of the successive residual MLP blocks.
            The sequence is copied into a new list, so neither the default nor
            a caller-supplied list is shared with the config.
        use_positional_encoding: whether to add positional encoding.
        dim_feedforward_ratio: Transformer feed-forward width as a multiple
            of ``d_model``.
        activation: Transformer feed-forward activation name.
        use_neuron_masking: whether to enable random input masking.
        mask_ratio: per-element masking probability.
        mask_replacement: masking fill mode (``'zero'`` or ``'random'``).
    """
    def __init__(self,
                 input_neuron=256,
                 time_bins=100,
                 d_model=128,
                 nhead=4,
                 num_transformer_layers=2,
                 conv_channels=64,
                 num_conv_blocks=3,
                 num_classes=10,
                 # Immutable default: a list default would be shared (and
                 # mutable) across every ModelConfig() call.
                 residual_dims=(256, 512, 1024),
                 use_positional_encoding=True,
                 dim_feedforward_ratio=4,
                 activation='relu',
                 use_neuron_masking=True,
                 mask_ratio=0.15,
                 mask_replacement='random'):

        # Transformer encoder settings.
        self.transformer = {
            'd_model': d_model,
            'nhead': nhead,
            'num_layers': num_transformer_layers,
            'dim_feedforward': d_model * dim_feedforward_ratio,
            'activation': activation
        }

        # CNN settings; kernel and pooling sizes are fixed here, not arguments.
        self.convolution = {
            'channels': conv_channels,
            'num_blocks': num_conv_blocks,
            'kernel_size': (3, 3),
            'pool_size': (2, 2)
        }

        # Residual MLP head. Copy the dims so later mutation of the config
        # cannot leak into other configs or back into the caller's list.
        self.residual = {
            'dims': list(residual_dims),
            # NOTE(review): not read by TimeTransformerConvModel;
            # ResidualLinearBlock always adds its skip path.
            'skip_connection': True
        }

        # Training-time input masking settings.
        self.masking = {
            'enabled': use_neuron_masking,
            'ratio': mask_ratio,
            'replacement': mask_replacement
        }

        self.input_dim = input_neuron
        self.time_steps = time_bins
        self.num_classes = num_classes
        self.positional_encoding = use_positional_encoding

In [None]:
import torch
import torch.nn as nn
import math

class NeuronMasker(nn.Module):
    """Randomly mask individual input elements during training.

    In training mode each element of ``x`` is selected independently with
    probability ``mask_ratio`` and replaced according to ``replacement``.
    In eval mode the input is returned unchanged.

    Args:
        mask_ratio: per-element masking probability.
        replacement: ``'zero'`` fills masked elements with 0; ``'random'``
            fills them with Gaussian noise scaled by ``noise_std``.
        noise_std: standard deviation of the noise used by ``'random'``.

    Raises:
        ValueError: if ``replacement`` is not a supported mode.
    """

    _REPLACEMENTS = ('zero', 'random')

    def __init__(self, mask_ratio=0.15, replacement='zero', noise_std=0.02):
        super().__init__()
        if replacement not in self._REPLACEMENTS:
            raise ValueError(
                f"replacement must be one of {self._REPLACEMENTS}, got {replacement!r}"
            )
        self.mask_ratio = mask_ratio
        self.replacement = replacement
        self.noise_std = noise_std

    def forward(self, x):
        """Return ``x`` with a random subset of elements replaced (training only)."""
        # Masking is a training-time augmentation; in eval mode pass the input
        # through instead of falling off the end and returning None.
        if not self.training:
            return x

        mask = torch.rand_like(x) < self.mask_ratio
        if self.replacement == 'zero':
            return x.masked_fill(mask, 0)
        if self.replacement == 'random':
            noise = torch.randn_like(x) * self.noise_std
            # Position-aligned replacement (masked_scatter consumes the source
            # in sequence order instead).
            return torch.where(mask, noise, x)
        # Reached only if `replacement` was changed after construction.
        raise ValueError(
            f"replacement must be one of {self._REPLACEMENTS}, got {self.replacement!r}"
        )
        
class ResidualLinearBlock(nn.Module):
    """Linear -> LayerNorm -> ReLU branch summed with a shortcut path.

    The shortcut is a learned ``nn.Linear`` projection when the input and
    output widths differ, and the identity when they match.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Submodules are created in this order so parameter initialisation and
        # state_dict layout stay stable.
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.activation = nn.ReLU()
        if input_dim == output_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        shortcut = self.downsample(x)
        branch = self.activation(self.norm(self.linear(x)))
        return branch + shortcut

class TimeTransformerConvModel(nn.Module):
    """Transformer over time followed by a 2-D CNN over the (time, d_model) map.

    Pipeline:
        masker -> input_proj -> positional encoding -> TransformerEncoder
        -> single-channel image -> conv blocks -> global average pool
        -> (classifier logits, residual-MLP features)

    Args:
        config: ModelConfig holding all hyperparameters.

    NOTE(review): each conv block ends in MaxPool2d(pool_size), whose stride
    defaults to the kernel size, so the time and d_model axes are both reduced
    per block; short inputs may shrink to zero size — confirm T is large enough.
    """

    # Width of the `features` output. If the last residual dim differs, an
    # extra Linear + LayerNorm projects to this size.
    FEATURE_DIM = 1024

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        d_model = config.transformer['d_model']

        self._init_masking()
        self.input_proj = nn.Linear(config.input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model) if config.positional_encoding else nn.Identity()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config.transformer['nhead'],
            dim_feedforward=config.transformer['dim_feedforward'],
            # Honour the configured activation (it was previously ignored and
            # PyTorch's default 'relu' was always used).
            activation=config.transformer['activation'],
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, config.transformer['num_layers'])

        # Build order matches the attribute order below, so parameter init and
        # state_dict keys are unchanged.
        self.conv_blocks = self._build_conv_blocks()
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(config.convolution['channels'], config.num_classes)
        self.residual_layers = self._build_residual_layers()

    def _build_conv_blocks(self):
        """Stack of Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d blocks."""
        conv = self.config.convolution
        blocks = []
        in_channels = 1  # the transformer output enters as a 1-channel image
        for _ in range(conv['num_blocks']):
            blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, conv['channels'],
                              kernel_size=conv['kernel_size'], padding='same'),
                    nn.BatchNorm2d(conv['channels']),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=conv['pool_size'])
                )
            )
            in_channels = conv['channels']
        return nn.Sequential(*blocks)

    def _build_residual_layers(self):
        """Residual MLP mapping pooled conv features to FEATURE_DIM."""
        layers = []
        current_dim = self.config.convolution['channels']
        for dim in self.config.residual['dims']:
            layers.append(ResidualLinearBlock(current_dim, dim))
            current_dim = dim
        if current_dim != self.FEATURE_DIM:
            layers.append(nn.Linear(current_dim, self.FEATURE_DIM))
            layers.append(nn.LayerNorm(self.FEATURE_DIM))
        return nn.Sequential(*layers)

    def _init_masking(self):
        """Set `self.masker`: a NeuronMasker if masking is enabled, else identity."""
        if self.config.masking['enabled']:
            self.masker = NeuronMasker(
                mask_ratio=self.config.masking['ratio'],
                replacement=self.config.masking['replacement']
            )
        else:
            self.masker = nn.Identity()

    def forward(self, x):
        """Run the model.

        Args:
            x: assumed ``[batch, time, input_dim]`` (the encoder is batch-first)
                — TODO confirm against the data loader.

        Returns:
            (logits, features): ``[batch, num_classes]`` and
            ``[batch, FEATURE_DIM]``.
        """
        x = self.masker(x)
        x = self.input_proj(x)      # [B, T, d_model]
        x = self.pos_encoder(x)
        x = self.transformer(x)     # [B, T, d_model]

        x = x.unsqueeze(1)          # [B, 1, T, d_model]
        x = self.conv_blocks(x)
        x = self.adaptive_pool(x)   # [B, C, 1, 1]
        pooled = x.flatten(1)       # [B, C]

        logits = self.classifier(pooled)
        features = self.residual_layers(pooled)
        return logits, features

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = exp(-2k * ln(10000) / d_model)``. The table
    is registered as a buffer of shape ``[max_len, d_model]``.

    Args:
        d_model: feature dimension; odd values are supported.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions.

        Assumes ``x`` is ``[batch, seq_len, d_model]``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return x + self.pe[:seq_len]

In [1]:
import numpy as np

# Load one saved segment file and display its array shape.
# NOTE(review): absolute, machine-specific path; consider a configurable DATA_DIR.
SAMPLE_SEGMENT_PATH = "/media/ubuntu/sda/data/filter_neuron/mouse_6/natural_image/seg_fr/021322_2_4.npy"
a = np.load(SAMPLE_SEGMENT_PATH)
a.shape

(22, 50)

In [7]:
# Fraction of the loaded segment's elements represented by 277. The
# denominator comes from `a` instead of the hard-coded 22 * 50, which equals
# a.size for the (22, 50) array shown above.
# NOTE(review): the numerator 277 is not derived anywhere visible — confirm what it counts.
277 / a.size

0.25181818181818183