In [1]:
from types import SimpleNamespace

# 假设config是你的配置字典
config = {
    'args': {
        'skip_connection': True, 
        'num_classes': 5, 
        'device': 'cuda', 
        'in_dim': 4096, 
        'out_dim': 4096, 
        'in_channels': 1, 
        'scale': 1
    }
}

args = SimpleNamespace(**config['args'])

# 现在你可以使用点符号来访问args的值
print(args.in_dim)  # 输出：4096

4096


In [1]:
import torch

print(torch.version.cuda)
print(torch.cuda.is_available())

12.1
True


# 绘图测试

## 网络结构

In [1]:
import torch 
import torch.nn as nn
from einops import rearrange
from model.TSPN import Transparent_Signal_Processing_Network


In [9]:
from config import args
from config import signal_processing_modules,feature_extractor_modules

net = Transparent_Signal_Processing_Network(signal_processing_modules,feature_extractor_modules, args).cuda()
x = torch.randn(2, 4096, 2).cuda()
y = net(x)
print(y.shape)

import scienceplots
import numpy as np
import matplotlib.pyplot as plt

font = {'family' : 'Times New Roman',
        'weight' : 'normal',
        'size'   : 16}
# plt.style.use(['science','ieee'])
plt.style.use(['science','ieee','no-latex'])

# build signal processing layers
# build feature extractor layers
# build classifier
torch.Size([2, 5])


## 绘图部分

In [7]:

import networkx as nx
import matplotlib.pyplot as plt
import torch

## SP

#### v1

In [10]:
# def draw_signal_processing_layer(G,layer, layer_idx,input_nodes): # 
#     # 
    
#     # 获取weight_connection的权重
#     weight = layer.weight_connection.weight.detach().cpu().numpy() # 获取权重
#     module_num = layer.module_num # 信号处理模块的数量
    
#     in_channel = weight.shape[1] # 输入通道
#     out_channel = weight.shape[0] # 输出通道
    
#     num_per_module = out_channel // module_num # 每个模块的输入通道数量


    
#     output_nodes = [f'$x^{layer_idx + 1}_{j}$' for j in range(out_channel)]

#     G.add_nodes_from(output_nodes, layer='output') # 不需要隐藏掉了
    

#     module_nodes = []
#     for idx, module in enumerate(layer.signal_processing_modules.values(), 1):
#         for i in range(num_per_module):
#             module_name = f'{module.name}_{idx}'
#             module_nodes.append(module_name)
#     G.add_nodes_from(module_nodes, layer='module')
#     # 添加边
#     for i, input_node in enumerate(input_nodes):
#         for j, module_node in enumerate(module_nodes):
#             # 根据权重调整边的属性   
#             G.add_edge(input_node, module_node,weight=weight[j, i])
#             G.add_edge(module_node, output_nodes[j])
    
    
#     # 如果存在skip_connection，则添加跳跃连接
#     if hasattr(layer, 'skip_connection'):
#         skip_weight = layer.skip_connection.weight.detach().cpu().numpy()
        
#         # output_nodes = [f'$O_{j}$' for j in range(out_channel)] # 输出节点 
#         for i, input_node in enumerate(input_nodes):
#             for j, output_node in enumerate(output_nodes):
#                 G.add_edge(input_node, output_node, weight = skip_weight[j, i], skip=True)
#     return G, output_nodes

# def draw_signal_processing_layers(model, input):
#     G = nx.Graph()
#     input_nodes = [f'$x^0_{j}$' for j in range(input.shape[2])]
#     for idx, layer in enumerate(model.signal_processing_layers):
#         G, input_nodes = draw_signal_processing_layer(G,layer, idx, input_nodes)    
    
#     # 使用networkx绘制图形
#     pos = nx.spring_layout(G)  # 可以根据需要选择不同的布局
#     weights = [G[u][v]['weight'] for u, v in G.edges()]
#     nx.draw(G, pos, with_labels=True, edges=G.edges(), width=weights)
#     plt.title(f'Signal Processing Layers')
#     plt.savefig('save/Signal_Processing_Layers.png')
#     plt.show()
#     return G, input_nodes



#### v2

In [11]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

def assign_positions_based_on_names(G):
    pos = {}
    layer_node_counts = {}  # 用于存储每层的节点数，以便于计算垂直位置

    for node in G.nodes:
        # 去掉特殊字符，然后根据'^'和'_'分割
        parts = node.strip('$').split('^')
        if len(parts) < 2:
            continue  # 如果格式不符合预期，则跳过该节点
        layer_part = parts[1]
        
        layer_idx, node_idx = layer_part.split('_')
        layer_idx = int(layer_idx)
        node_idx = int(node_idx)

        # 更新每层的节点数
        if layer_idx not in layer_node_counts:
            layer_node_counts[layer_idx] = 1
        else:
            layer_node_counts[layer_idx] += 1

        # 计算节点位置
        # X坐标由层级决定，Y坐标由该层中的节点顺序决定
        pos[node] = (layer_idx, -node_idx)

    # 为了让图中的节点布局更加均衡，我们可以根据每层节点数的最大值调整节点的Y坐标
    max_nodes_per_layer = max(layer_node_counts.values())
    for node, (layer_idx, node_idx) in pos.items():
        adjusted_node_idx = node_idx + max_nodes_per_layer / 2 - layer_node_counts[layer_idx] / 2
        pos[node] = (layer_idx, adjusted_node_idx)

    return pos






def draw_signal_processing_layer(G, layer, layer_idx, input_nodes):
    weight = layer.weight_connection.weight.detach().cpu().numpy()
    module_num = layer.module_num
    
    in_channel = weight.shape[1]
    out_channel = weight.shape[0]
    num_per_module = out_channel // module_num
    output_nodes = [f'$x^{layer_idx + 1}_{j}$' for j in range(out_channel)]

    G.add_nodes_from(output_nodes, layer='output')
    
    module_nodes = []
    for idx, module in enumerate(layer.signal_processing_modules.values(), 1):
        module_name = f'${module.name}^{layer_idx + 1}_{idx}$'  # f'$\phi^{layer_idx + 1}_{idx}$' #  f'{module.name}^{layer_idx + 1}_{idx}' 
        module_nodes.extend([module_name for _ in range(num_per_module)])
    G.add_nodes_from(module_nodes, layer='module')

    # 添加边，避免使用flatten
    for i, input_node in enumerate(input_nodes):
        for j, module_node in enumerate(module_nodes):
            # 直接使用权重矩阵中的值
            G.add_edge(input_node, module_node, weight=weight[j % num_per_module, i], style='solid')
    
    # 添加从模块到输出的虚线边
    for j, module_node in enumerate(module_nodes):
        G.add_edge(module_node, output_nodes[j % len(output_nodes)], style='dashed', arrows=False)
    
    if hasattr(layer, 'skip_connection'):
        skip_weight = layer.skip_connection.weight.detach().cpu().numpy()
        for i, input_node in enumerate(input_nodes):
            for j, output_node in enumerate(output_nodes):
                G.add_edge(input_node, output_node, weight=skip_weight[j, i], style='dashed')
    return G, output_nodes

def draw_signal_processing_layers(model, input):
    G = nx.DiGraph()  # 使用有向图更好地表示信息流
    input_nodes = [f'$x^{0}_{j}$' for j in range(input.shape[2])]
    for idx, layer in enumerate(model.signal_processing_layers):
        G, input_nodes = draw_signal_processing_layer(G, layer, idx, input_nodes)    
    # 使用自动分配位置的函数替换原有的pos计算
    
    # pos = nx.spring_layout(G)
    
    pos = assign_positions_based_on_names(G)
    
    edges = G.edges(data=True)
    solid_edges = [(u, v) for u, v, d in edges if d.get('style', None) == 'solid']
    dashed_edges = [(u, v) for u, v, d in edges if d.get('style', None) == 'dashed']
    
    # 绘制节点
    nx.draw_networkx_nodes(G, pos)
    nx.draw_networkx_labels(G, pos)
    
    # 绘制实线边
    nx.draw_networkx_edges(G, pos, edgelist=solid_edges, edge_color='black',width=1.0, node_size=100)
    
    # 绘制虚线边，可能需要根据您的matplotlib版本和后端调整实现方式
    nx.draw_networkx_edges(G, pos, edgelist=dashed_edges, edge_color='blue', style='dashed',width=1.0, node_size=100, arrows=False)

    plt.title('Signal Processing Layers Visualization')
    plt.axis('off')  # 隐藏坐标轴
    plt.savefig('Signal_Processing_Layers_Visualization_Updated.png', format='PNG')
    plt.show()
    return G, input_nodes


### 测试

In [12]:
G,input_nodes = draw_signal_processing_layers(net, x)

ValueError: 
$$WF$^1_2$
^
ParseException: Expected end of text, found '$'  (at char 0), (line:1, col:1)

ValueError: 
$$WF$^1_2$
^
ParseException: Expected end of text, found '$'  (at char 0), (line:1, col:1)

<Figure size 1980x1500 with 1 Axes>

### 裁剪节点和边
1. 如果weight小于阈值，删除节点和边。
2. 如果节点没有入度，删除节点和节点之后的边。

## feature

## classifier