In [1]:
from types import SimpleNamespace

# Example configuration dict; its 'args' sub-dict is exposed below as an attribute namespace.
# NOTE(review): Transparent_Signal_Processing_Network reads `args.out_channels`, which this
# dict does not define (it has 'in_dim'/'out_dim'); the network cell imports `args` from the
# project `config` module instead — confirm which one is meant to be used.
config = dict(
    args=dict(
        skip_connection=True,
        num_classes=5,
        device='cuda',
        in_dim=4096,
        out_dim=4096,
        in_channels=1,
        scale=1,
    )
)

# Attribute-style access: args.in_dim instead of config['args']['in_dim'].
args = SimpleNamespace(**config['args'])

print(args.in_dim)  # prints 4096

4096


In [1]:
import torch

# Report the CUDA version PyTorch was built against and whether a GPU is usable, one per line.
print(torch.version.cuda, torch.cuda.is_available(), sep='\n')

12.1
True


# 绘图测试

## 网络结构

In [1]:
import torch 
import torch.nn as nn
from einops import rearrange

class SignalProcessingLayer(nn.Module):
    """One layer of signal-processing operators.

    Pipeline: instance-normalise the input along its length axis, mix channels with a
    Linear map, split the mixed channels into ``module_num`` equal-width groups, run
    group ``k`` through the ``k``-th signal-processing module, concatenate the results
    along the channel axis and, optionally, add a Linear skip path computed from the
    normalised input.

    Args:
        signal_processing_modules: mapping whose ``.values()`` are the callables applied
            to the channel groups, in iteration order. NOTE(review): if this is a plain
            ``dict`` rather than an ``nn.ModuleDict`` its parameters are not registered on
            this layer (``.cuda()`` / ``.parameters()`` will not see them) — confirm.
        input_channels: channel count of the input.
        output_channels: channel count produced by the mixing (and skip) Linear.
        skip_connection: when truthy, a ``skip_connection`` Linear is created. When falsy
            the attribute is deliberately left undefined; ``forward`` checks for it
            with ``hasattr``.
    """

    def __init__(self, signal_processing_modules, input_channels, output_channels, skip_connection=True):
        super().__init__()
        # Submodule creation order is kept (norm, mixing Linear, skip Linear) so parameter
        # initialisation and state_dict ordering are unchanged.
        self.norm = nn.InstanceNorm1d(input_channels)
        self.weight_connection = nn.Linear(input_channels, output_channels)
        self.signal_processing_modules = signal_processing_modules
        self.module_num = len(signal_processing_modules)
        if skip_connection:
            self.skip_connection = nn.Linear(input_channels, output_channels)

    def forward(self, x):
        """Apply the layer to ``x`` laid out as (batch, length, channels).

        Returns a (batch, length, channels') tensor, where channels' is the summed width
        of the module outputs.
        """
        # InstanceNorm1d wants (batch, channels, length): swap in, normalise, swap back.
        normed = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        mixed = self.weight_connection(normed)

        group_width = mixed.size(2) // self.module_num
        groups = torch.split(mixed, group_width, dim=2)
        # NOTE(review): zip() stops at the shorter sequence, so if the width is not a
        # multiple of module_num the leftover group is silently dropped.
        out = torch.cat(
            [op(chunk) for op, chunk in zip(self.signal_processing_modules.values(), groups)],
            dim=2,
        )

        if hasattr(self, 'skip_connection'):
            out = out + self.skip_connection(normed)
        return out
    
class FeatureExtractorlayer(nn.Module):
    """Channel-mixing Linear followed by a bank of feature extractors.

    Args:
        feature_extractor_modules: mapping whose ``.values()`` are applied, in order, to
            the mixed signal laid out as (batch, channels, length). NOTE(review): a plain
            ``dict`` is not registered as submodules — confirm an ``nn.ModuleDict`` is passed.
        in_channels: channel count of the incoming (batch, length, channels) tensor.
        out_channels: channel count after the mixing Linear.
    """

    def __init__(self, feature_extractor_modules, in_channels=1, out_channels=1):
        super().__init__()
        self.weight_connection = nn.Linear(in_channels, out_channels)
        self.feature_extractor_modules = feature_extractor_modules

    def norm(self, x):
        """Standardise ``x`` across the batch axis (dim 0).

        Uses ``Tensor.std``'s default (Bessel-corrected) estimator; the 1e-10 epsilon only
        guards against division by a zero std. NOTE(review): with a single-sample batch
        the corrected std is NaN, so the whole output becomes NaN.
        """
        centred = x - x.mean(dim=0, keepdim=True)
        return centred / (x.std(dim=0, keepdim=True) + 1e-10)

    def forward(self, x):
        """Mix channels, run every extractor, concatenate along dim 1, then standardise.

        ``x`` is expected as (batch, length, in_channels).
        """
        mixed = self.weight_connection(x).permute(0, 2, 1)  # -> (batch, channels, length)
        features = torch.cat(
            [extractor(mixed) for extractor in self.feature_extractor_modules.values()],
            dim=1,
        )
        return self.norm(features)

class Classifier(nn.Module):
    """Two-layer MLP head: flatten each sample, then Linear -> ReLU -> Linear.

    Args:
        in_channels: flattened feature size per sample.
        num_classes: number of output logits.
        hidden_dim: width of the hidden layer (defaults to 1024, the value that used to
            be hard-coded).
    """

    def __init__(self, in_channels, num_classes, hidden_dim=1024):  # TODO logic
        super(Classifier, self).__init__()
        self.clf = nn.Sequential(
            nn.Linear(in_channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, ...)."""
        # reshape rather than view: identical for contiguous input, but also accepts
        # non-contiguous tensors (e.g. after a transpose/permute), where view raises.
        return self.clf(x.reshape(x.size(0), -1))

class Transparent_Signal_Processing_Network(nn.Module):
    """Stack of SignalProcessingLayers, then a FeatureExtractorlayer, then a Classifier.

    Args:
        signal_processing_modules: indexable by layer position; entry ``i`` is the module
            mapping handed to the ``i``-th SignalProcessingLayer.
        feature_extractor: module mapping handed to the FeatureExtractorlayer.
        args: namespace read for ``in_channels``, ``out_channels``, ``skip_connection``,
            ``scale`` and ``num_classes``.
    """

    def __init__(self, signal_processing_modules, feature_extractor, args):
        super(Transparent_Signal_Processing_Network, self).__init__()
        self.layer_num = len(signal_processing_modules)
        self.signal_processing_modules = signal_processing_modules
        self.feature_extractor_modules = feature_extractor
        self.args = args

        # Order matters: each builder reads the channel count computed by the previous one.
        self.init_signal_processing_layers()
        self.init_feature_extractor_layers()
        self.init_classifier()

    def init_signal_processing_layers(self):
        """Build one SignalProcessingLayer per entry of ``signal_processing_modules``.

        Layer 0 maps ``args.in_channels`` -> ``args.out_channels``; every following layer
        takes the previous width and outputs ``int(previous_out * args.scale)``. The width
        produced by the last layer is stored in ``self.channel_for_feature``.
        """
        print('# build signal processing layers')
        in_channels = self.args.in_channels
        out_channels = self.args.out_channels

        self.signal_processing_layers = nn.ModuleList()
        for i in range(self.layer_num):
            self.signal_processing_layers.append(
                SignalProcessingLayer(self.signal_processing_modules[i],
                                      in_channels,
                                      out_channels,
                                      self.args.skip_connection))
            # The layer splits out_channels evenly across its modules.
            assert out_channels % self.signal_processing_layers[i].module_num == 0
            in_channels = out_channels
            out_channels = int(out_channels * self.args.scale)
        # After the loop `in_channels` is the width emitted by the last layer (or
        # args.in_channels when there are no layers). The former `out_channels // scale`
        # only recovered it for integer scales; a float scale produced a float (or a
        # truncated, wrong) channel count that nn.Linear cannot use.
        self.channel_for_feature = in_channels

    def init_feature_extractor_layers(self):
        """Build the feature-extraction stage and size the classifier input.

        NOTE(review): ``channel_for_classifier = channels * number_of_extractors`` assumes
        each extractor returns one value per channel (i.e. reduces the length axis) — confirm.
        """
        print('# build feature extractor layers')
        self.feature_extractor_layers = FeatureExtractorlayer(self.feature_extractor_modules,
                                                              self.channel_for_feature,
                                                              self.channel_for_feature)
        len_feature = len(self.feature_extractor_modules)
        self.channel_for_classifier = self.channel_for_feature * len_feature

    def init_classifier(self):
        """Build the MLP head mapping the flattened features to ``args.num_classes`` logits."""
        print('# build classifier')
        self.clf = Classifier(self.channel_for_classifier, self.args.num_classes)

    def forward(self, x):
        """Run the whole pipeline and return the classifier output.

        ``x`` is expected as (batch, length, in_channels), the layout SignalProcessingLayer
        consumes. The result is the Classifier's (batch, num_classes) logits; previously
        nothing was returned, so callers received ``None``.
        """
        for layer in self.signal_processing_layers:
            x = layer(x)
        x = self.feature_extractor_layers(x)
        return self.clf(x)

In [2]:
# Smoke test: build the network from the project `config` module and push a random
# (batch=2, length=4096, channels=2) batch through it on the GPU.
# NOTE(review): x has 2 channels, so config.args.in_channels must be 2 (the dict in the
# first cell uses 1) — confirm against the config module.
from config import args, signal_processing_modules, feature_extractor_modules

net = Transparent_Signal_Processing_Network(
    signal_processing_modules=signal_processing_modules,
    feature_extractor=feature_extractor_modules,
    args=args,
).cuda()
x = torch.randn(2, 4096, 2).cuda()
y = net(x)
print(y.size())

# build signal processing layers
# build feature extractor layers
# build classifier


  from .autonotebook import tqdm as notebook_tqdm


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

## 绘图部分

In [2]:

import networkx as nx
import matplotlib.pyplot as plt
import torch

## SP（信号处理层，signal processing layers）

In [None]:
def draw_signal_processing_layer(G, layer, layer_idx, input_nodes):
    """Add one SignalProcessingLayer's nodes and edges to graph ``G``.

    Nodes added:
      * one output node per output channel, labelled ``$x^{layer_idx+1}_{j}$``
        (attribute ``layer='output'``);
      * one node per signal-processing module, labelled ``{module.name}_{layer_idx}_{k}``
        (``k`` = 1-based module position). The layer index keeps same-named modules of
        different layers from collapsing into a single node (attribute ``layer='module'``).

    Edges added:
      * input -> module, ``weight`` = a weight_connection entry from that input to a
        channel routed into the module (channel ``j`` feeds module ``j // num_per_module``).
        ``G`` keeps one edge per node pair, so the entry of the module's last channel is
        the one that remains.
      * module -> each of its output channels, ``weight=1.0`` (structural edge; this way
        every edge carries a ``weight`` for plotting).
      * input -> output with ``skip=True`` and the skip weight, when the layer has a
        ``skip_connection`` attribute.

    Args:
        G: graph to extend in place (networkx-style ``add_nodes_from`` / ``add_edge``).
        layer: SignalProcessingLayer-like object; reads ``weight_connection``,
            ``module_num``, ``signal_processing_modules`` (each module must expose
            ``.name``) and optionally ``skip_connection``.
        layer_idx: 0-based position of the layer in the network.
        input_nodes: labels of this layer's input channel nodes, in channel order.

    Returns:
        ``(G, output_nodes)``; ``output_nodes`` can be passed as ``input_nodes`` for the
        next layer.
    """
    # .cpu() so this also works after the model was moved to the GPU
    # (Tensor.numpy() refuses CUDA tensors).
    weight = layer.weight_connection.weight.detach().cpu().numpy()  # (out_channel, in_channel)
    module_num = layer.module_num
    out_channel = weight.shape[0]
    num_per_module = out_channel // module_num  # channels routed into each module

    output_nodes = [f'$x^{layer_idx + 1}_{j}$' for j in range(out_channel)]
    G.add_nodes_from(output_nodes, layer='output')

    # module_nodes[j] is the module that processes mixed channel j.
    module_nodes = [
        f'{module.name}_{layer_idx}_{idx}'
        for idx, module in enumerate(layer.signal_processing_modules.values(), 1)
        for _ in range(num_per_module)
    ]
    G.add_nodes_from(module_nodes, layer='module')

    for i, input_node in enumerate(input_nodes):
        for j, module_node in enumerate(module_nodes):
            G.add_edge(input_node, module_node, weight=weight[j, i])
    # Module -> output edges do not depend on the input channel; add each once.
    for j, module_node in enumerate(module_nodes):
        G.add_edge(module_node, output_nodes[j], weight=1.0)

    if hasattr(layer, 'skip_connection'):
        skip_weight = layer.skip_connection.weight.detach().cpu().numpy()
        for i, input_node in enumerate(input_nodes):
            for j, output_node in enumerate(output_nodes):
                G.add_edge(input_node, output_node, weight=skip_weight[j, i], skip=True)
    return G, output_nodes

def draw_signal_processing_layers(model, input):
    """Plot the signal-processing part of ``model`` as a network graph.

    Args:
        model: object with a ``signal_processing_layers`` iterable of
            SignalProcessingLayer-like layers.
        input: example input; only ``input.shape[2]`` (the channel axis of a
            (batch, length, channels) tensor) is used, to create the ``$x^0_j$`` nodes.

    Returns:
        ``(G, output_nodes)``: the ``networkx.Graph`` that was drawn and the node labels
        of the last layer's outputs.
    """
    G = nx.Graph()
    layer_nodes = [f'$x^0_{j}$' for j in range(input.shape[2])]
    for idx, layer in enumerate(model.signal_processing_layers):
        G, layer_nodes = draw_signal_processing_layer(G, layer, idx, layer_nodes)

    pos = nx.spring_layout(G)  # any networkx layout can be substituted here
    edgelist = list(G.edges())
    # Line widths are taken from |weight| (a negative width is not meaningful);
    # edges without a 'weight' attribute (e.g. purely structural ones) are drawn at 1.
    widths = [abs(float(G[u][v].get('weight', 1.0))) for u, v in edgelist]

    fig, ax = plt.subplots()
    # `edgelist` (not `edges`, which networkx's drawing functions reject) keeps the
    # widths aligned with the edges actually drawn.
    nx.draw(G, pos, ax=ax, with_labels=True, edgelist=edgelist, width=widths)
    ax.set_title('Signal Processing Layers')
    plt.show()
    return G, layer_nodes

## feature

## classifier