In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import pickle

# --------------------- Data preprocessing ---------------------
def load_connectivity_data(connectivity_path, annotation_path, scale=1e-3):
    """Load the connectivity matrix and split it into sensory/residual blocks.

    Visual sensory neuron IDs are taken from the annotation table: rows whose
    ``celltype`` is ``'sensory'`` and whose ``additional_annotations`` is
    ``'visual'`` contribute the values of their ``left_id`` and ``right_id``
    columns. The connectivity matrix is then reordered so that those neurons
    come first, multiplied by ``scale`` and cut into four blocks.

    Args:
        connectivity_path: CSV file holding the connectivity matrix. The
            first column is used as the row index; row and column labels
            must be convertible to ``int``.
        annotation_path: CSV file containing the columns named above.
        scale: Factor applied to every matrix entry (default ``1e-3``).

    Returns:
        A dict with the NumPy blocks ``'W_ss'``, ``'W_sr'``, ``'W_rs'`` and
        ``'W_rr'`` ("s" = visual sensory neurons, "r" = every other neuron;
        the first letter names the row group, the second the column group)
        and ``'sensory_ids'``, the sorted sensory IDs present in the matrix.

    Raises:
        ValueError: If an ID cell is neither missing, "no pair", nor an
            integer literal.
    """
    # Load the annotation table and keep only the visual sensory rows.
    df_annot = pd.read_csv(annotation_path)
    mask = (df_annot['celltype'] == 'sensory') & (df_annot['additional_annotations'] == 'visual')
    visual_rows = df_annot[mask]

    # Collect the IDs column by column so each column keeps its own dtype.
    unique_ids = set()
    for col in ('left_id', 'right_id'):
        for cell in visual_rows[col]:
            if pd.isna(cell):
                # An empty cell carries no ID; skip it instead of crashing in int().
                continue
            text = str(cell).strip()
            if text.lower() == "no pair":
                continue
            unique_ids.add(int(text))

    sensory_visual_ids = sorted(unique_ids)
    print(f"Found {len(sensory_visual_ids)} sensory-visual neuron IDs")

    # Load the connectivity matrix with integer row/column labels.
    df_conn = pd.read_csv(connectivity_path, index_col=0)
    df_conn.index = df_conn.index.astype(int)
    df_conn.columns = df_conn.columns.astype(int)

    # Keep only sensory IDs that exist in the matrix; everything else is "residual".
    valid_sensory_ids = [nid for nid in sensory_visual_ids if nid in df_conn.index]
    sensory_lookup = set(valid_sensory_ids)  # O(1) membership instead of scanning a list
    other_ids = [nid for nid in df_conn.index if nid not in sensory_lookup]

    # Reorder rows and columns identically: sensory neurons first, then the rest.
    order = valid_sensory_ids + other_ids
    adj_matrix = df_conn.loc[order, order].to_numpy() * scale

    num_S = len(valid_sensory_ids)
    return {
        'W_ss': adj_matrix[:num_S, :num_S],
        'W_sr': adj_matrix[:num_S, num_S:],
        'W_rs': adj_matrix[num_S:, :num_S],
        'W_rr': adj_matrix[num_S:, num_S:],
        'sensory_ids': valid_sensory_ids
    }

# --------------------- Model definition ---------------------
class DrosophilaRNN(nn.Module):
    """Recurrent classifier whose recurrent weights start from connectivity blocks.

    The state is split into a sensory part ``S`` and a residual part ``R``.
    Four trainable matrices (``W_ss``, ``W_sr``, ``W_rs``, ``W_rr``) couple
    them; a linear layer projects the input onto the sensory units and a
    second linear layer reads the class scores from the residual units.

    Args:
        input_dim: Size of each flattened input sample.
        sensory_dim: Number of sensory units (rows/cols of ``W_ss``).
        residual_dim: Number of residual units (rows/cols of ``W_rr``).
        num_classes: Number of output scores.
        conn_weights: Mapping with the array-likes ``'W_ss'``, ``'W_sr'``,
            ``'W_rs'`` and ``'W_rr'`` used as initial values.
    """

    def __init__(self, input_dim, sensory_dim, residual_dim, num_classes, conn_weights):
        super().__init__()
        # Connectivity blocks become trainable float32 parameters.
        self.W_ss = nn.Parameter(torch.tensor(conn_weights['W_ss'], dtype=torch.float32), requires_grad=True)
        self.W_sr = nn.Parameter(torch.tensor(conn_weights['W_sr'], dtype=torch.float32), requires_grad=True)
        self.W_rs = nn.Parameter(torch.tensor(conn_weights['W_rs'], dtype=torch.float32), requires_grad=True)
        self.W_rr = nn.Parameter(torch.tensor(conn_weights['W_rr'], dtype=torch.float32), requires_grad=True)

        # Input projection, read-out and shared non-linearity.
        self.input_proj = nn.Linear(input_dim, sensory_dim)
        self.output_layer = nn.Linear(residual_dim, num_classes)
        self.activation = nn.ReLU()

        # Sanity-check that the supplied blocks match the declared sizes.
        assert self.W_ss.shape == (sensory_dim, sensory_dim), \
            f"W_ss has shape {tuple(self.W_ss.shape)}, expected {(sensory_dim, sensory_dim)}"
        assert self.W_sr.shape == (sensory_dim, residual_dim), \
            f"W_sr has shape {tuple(self.W_sr.shape)}, expected {(sensory_dim, residual_dim)}"
        assert self.W_rs.shape == (residual_dim, sensory_dim), \
            f"W_rs has shape {tuple(self.W_rs.shape)}, expected {(residual_dim, sensory_dim)}"
        assert self.W_rr.shape == (residual_dim, residual_dim), \
            f"W_rr has shape {tuple(self.W_rr.shape)}, expected {(residual_dim, residual_dim)}"

    def forward(self, x, time_steps=10, input_interval=5):
        """Unroll the recurrent dynamics and return class scores.

        Args:
            x: Input batch of shape ``[batch_size, input_dim]``.
            time_steps: Number of recurrent update steps.
            input_interval: The projected input is injected at every step
                ``t`` with ``t % input_interval == 0`` (so with the defaults
                at ``t = 0`` and ``t = 5``); other steps receive no input.

        Returns:
            Tensor of shape ``[batch_size, num_classes]`` computed from the
            residual state after the last step.

        Raises:
            ValueError: If ``input_interval`` is smaller than 1.
        """
        if input_interval < 1:
            raise ValueError("input_interval must be a positive integer")

        batch_size = x.shape[0]
        device = x.device

        # Both state vectors start at zero.
        S = torch.zeros(batch_size, self.W_ss.shape[0], device=device)
        R = torch.zeros(batch_size, self.W_rr.shape[0], device=device)

        # Project the input once; it is reused at every injection step.
        E = self.input_proj(x)  # [batch_size, sensory_dim]

        for t in range(time_steps):
            # Sensory update: S->S recurrence, optional external input, R->S feedback.
            # The summation order matches "S @ W_ss + E_t + R @ W_rs"; skipping the
            # addition on non-injection steps is equivalent to adding zeros.
            sensory_drive = S @ self.W_ss
            if t % input_interval == 0:
                sensory_drive = sensory_drive + E
            sensory_drive = sensory_drive + R @ self.W_rs
            S_next = self.activation(sensory_drive)

            # Residual update: R->R recurrence plus S->R drive (uses the old S).
            R_next = self.activation(R @ self.W_rr + S @ self.W_sr)

            S, R = S_next, R_next

        return self.output_layer(R)

# --------------------- Training pipeline ---------------------
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return (summed loss, correct count)."""
    model.train()
    total_loss, correct = 0.0, 0
    for images, labels in loader:
        images = images.view(-1, 784).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch mean is exact.
        total_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    return total_loss, correct


def _count_correct(model, loader, device):
    """Return the number of correctly classified samples in ``loader``."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.view(-1, 784).to(device)
            labels = labels.to(device)
            correct += (model(images).argmax(1) == labels).sum().item()
    return correct


def main(num_epochs=10):
    """Train the connectivity-initialised RNN on MNIST and pickle the metrics.

    Loads the connectivity blocks, builds the model, trains it with Adam and
    cross-entropy loss for ``num_epochs`` epochs, prints per-epoch statistics
    and writes the collected metrics to ``Drosophila_Metrics.signed.pkl``.

    Args:
        num_epochs: Number of training epochs (default 10).
    """
    # Load the connectivity data.
    conn_data = load_connectivity_data(
        # connectivity_path="./data/ad_connectivity_matrix.csv",
        connectivity_path="./data/signed_connectivity_matrix.csv",
        annotation_path="./data/science.add9330_data_s2.csv"
    )

    # Build the model from the loaded blocks.
    model = DrosophilaRNN(
        input_dim=784,
        sensory_dim=len(conn_data['sensory_ids']),
        residual_dim=conn_data['W_rr'].shape[0],
        num_classes=10,
        conn_weights=conn_data
    )
    results = {
        "epoch_train_loss": [],
        "epoch_train_acc": [],
        "epoch_test_acc": [],
        "flops_acc": [],       # kept for structural consistency (FLOPs not computed yet)
        "total_flops": 0,      # kept for structural consistency
        "activations": None    # kept for structural consistency
    }

    # Data loading.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_loader = DataLoader(
        datasets.MNIST('./data', train=True, download=True, transform=transform),
        batch_size=64, shuffle=True
    )

    test_loader = DataLoader(
        datasets.MNIST('./data', train=False, transform=transform),
        batch_size=64, shuffle=False
    )

    # Training setup. The model is moved to its device first so the optimizer
    # is constructed over the parameters at their final location.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    num_train = len(train_loader.dataset)
    num_test = len(test_loader.dataset)

    for epoch in range(num_epochs):
        # Training phase.
        total_loss, correct = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        train_loss = total_loss / num_train
        train_acc = 100. * correct / num_train
        results["epoch_train_loss"].append(train_loss)
        results["epoch_train_acc"].append(train_acc/100)  # stored in the 0-1 range

        # Evaluation phase.
        test_correct = _count_correct(model, test_loader, device)
        test_acc = 100. * test_correct / num_test
        results["epoch_test_acc"].append(test_acc/100)  # stored in the 0-1 range

        # Per-epoch report.
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"  Test  Acc: {test_acc:.2f}%")
        print("-" * 50)

    # Persist the collected metrics.
    with open("Drosophila_Metrics.signed.pkl", "wb") as f:
        pickle.dump(results, f)

# Run the training pipeline only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()

Found 29 sensory-visual neuron IDs
Epoch 1/10:
  Train Loss: 1.3183 | Acc: 46.00%
  Test  Acc: 82.95%
--------------------------------------------------
Epoch 2/10:
  Train Loss: 0.3221 | Acc: 90.99%
  Test  Acc: 92.73%
--------------------------------------------------
Epoch 3/10:
  Train Loss: 0.2280 | Acc: 93.60%
  Test  Acc: 94.95%
--------------------------------------------------
Epoch 4/10:
  Train Loss: 0.1867 | Acc: 94.77%
  Test  Acc: 94.44%
--------------------------------------------------
Epoch 5/10:
  Train Loss: 0.1679 | Acc: 95.26%
  Test  Acc: 95.40%
--------------------------------------------------
Epoch 6/10:
  Train Loss: 0.1497 | Acc: 95.74%
  Test  Acc: 94.06%
--------------------------------------------------
Epoch 7/10:
  Train Loss: 0.1359 | Acc: 96.10%
  Test  Acc: 95.81%
--------------------------------------------------
Epoch 8/10:
  Train Loss: 0.1253 | Acc: 96.35%
  Test  Acc: 95.83%
--------------------------------------------------
Epoch 9/10:
  Train L