In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import pickle

def load_connectivity_data(connectivity_path, annotation_path):
    df_annot = pd.read_csv(annotation_path)
  
    mask = (df_annot['celltype'] == 'sensory') & (df_annot['additional_annotations'] == 'visual')
    sensory_visual_ids = []
    for _, row in df_annot[mask].iterrows():
        for col in ['left_id', 'right_id']:
            if (id_str := str(row[col]).lower()) != "no pair":
                sensory_visual_ids.append(int(id_str))

    sensory_visual_ids = sorted(list(set(sensory_visual_ids)))
    print(f"Found {len(sensory_visual_ids)} sensory-visual neuron IDs")
  
    df_conn = pd.read_csv(connectivity_path, index_col=0)
    df_conn.index = df_conn.index.astype(int)
    df_conn.columns = df_conn.columns.astype(int)
  
    valid_sensory_ids = [nid for nid in sensory_visual_ids if nid in df_conn.index]
    other_ids = [nid for nid in df_conn.index if nid not in valid_sensory_ids]
  
    df_reindexed = df_conn.loc[valid_sensory_ids + other_ids, valid_sensory_ids + other_ids]
  
    adj_matrix = df_reindexed.values * 1e-3  # 统一标准化
  
    num_S = len(valid_sensory_ids)
    return {
        'W_ss': adj_matrix[:num_S, :num_S],
        'W_sr': adj_matrix[:num_S, num_S:],
        'W_rs': adj_matrix[num_S:, :num_S],
        'W_rr': adj_matrix[num_S:, num_S:],
        'sensory_ids': valid_sensory_ids
    }


In [7]:
class DrosophilaRNN(nn.Module):
    def __init__(self, input_dim, sensory_dim, residual_dim, num_classes, conn_weights):
        super().__init__()
        self.W_ss = nn.Parameter(torch.tensor(conn_weights['W_ss'], dtype=torch.float32), requires_grad=True)
        self.W_sr = nn.Parameter(torch.tensor(conn_weights['W_sr'], dtype=torch.float32), requires_grad=True)
        self.W_rs = nn.Parameter(torch.tensor(conn_weights['W_rs'], dtype=torch.float32), requires_grad=True)
        self.W_rr = nn.Parameter(torch.tensor(conn_weights['W_rr'], dtype=torch.float32), requires_grad=True)
      
        self.input_proj = nn.Linear(input_dim, sensory_dim)
        self.output_layer = nn.Linear(residual_dim, num_classes)
        self.activation = nn.ReLU()
      
        assert self.W_ss.shape == (sensory_dim, sensory_dim)
        assert self.W_sr.shape == (sensory_dim, residual_dim)
        assert self.W_rs.shape == (residual_dim, sensory_dim)
        assert self.W_rr.shape == (residual_dim, residual_dim)

    def forward(self, x, time_steps=10):
        batch_size = x.shape[0]
        device = x.device

        S = torch.zeros(batch_size, self.W_ss.shape[0], device=device)
        R = torch.zeros(batch_size, self.W_rr.shape[0], device=device)
      
        E = self.input_proj(x)  # [batch_size, sensory_dim]

        for t in range(time_steps):
            E_t = E if t % 5 == 0 else torch.zeros_like(E)
          
            S_next = self.activation(
                S @ self.W_ss +    # S->S连接
                E_t +             # 外部输入
                R @ self.W_rs     # R->S连接
            )
          
            R_next = self.activation(
                R @ self.W_rr +    # R->R连接
                S @ self.W_sr      # S->R连接
            )
          
            S, R = S_next, R_next
      
        return self.output_layer(R)

In [8]:
class VisualCNN(nn.Module):
    def __init__(self, out_dim=29):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1,1)),  # 压缩到 (batch, 32, 1, 1)
        )
        self.fc = nn.Linear(32, out_dim)

    def forward(self, x):
        # x: (batch_size, 1, 28, 28)
        z = self.conv(x)           # (batch_size, 32, 1, 1)
        z = z.view(z.size(0), -1)  # (batch_size, 32)
        return self.fc(z)          # (batch_size, out_dim)

# 2) 新的 RNN：无 W_ss，只有 W_sr, W_rs, W_rr
class DrosophilaRNNNoWss(nn.Module):
    def __init__(
        self,
        residual_dim,            # 内部脑维度 (R 的大小)
        num_classes,
        conn_weights,            # 字典，至少包含 W_sr, W_rs, W_rr
        sensory_dim=29           # 我们还想让CNN输出 29 维
    ):
        super().__init__()

        # 从 conn_weights 中读取 并注册为可学习参数(也可固定不学)
        self.W_sr = nn.Parameter(
            torch.tensor(conn_weights['W_sr'], dtype=torch.float32), 
            requires_grad=True
        ) # shape (sensory_dim, residual_dim)
        self.W_rs = nn.Parameter(
            torch.tensor(conn_weights['W_rs'], dtype=torch.float32),
            requires_grad=True
        ) # shape (residual_dim, sensory_dim)
        self.W_rr = nn.Parameter(
            torch.tensor(conn_weights['W_rr'], dtype=torch.float32),
            requires_grad=True
        ) # shape (residual_dim, residual_dim)

        # 用 CNN 取代 W_ss
        self.visual_cnn = VisualCNN(out_dim=sensory_dim)

        self.output_layer = nn.Linear(residual_dim, num_classes)
        self.activation = nn.ReLU()

    def forward(self, x, time_steps=10):
        """
        x: (batch_size, 1, 28, 28) - MNIST图像
        time_steps: 运行多少步RNN迭代
        """
        device = x.device
        batch_size = x.size(0)

        # 初始化 S, R (如果还需要 S 的初始值，也可以是0或可学习参数)
        S = torch.zeros(batch_size, self.W_sr.shape[0], device=device)
        R = torch.zeros(batch_size, self.W_rr.shape[0], device=device)

        # CNN 只看一次输入，或者每步都看，看需求
        # 这里演示：每次t都喂同一张图片
        for t in range(time_steps):
            # 把图像 CNN 出来得到 29维
            E_t = self.visual_cnn(x)  # shape (batch_size, 29)

            # 更新方程(去掉 W_ss)
            # S_{t+1} = ReLU( CNN输出 + W_sr * R_t )
            S_next = self.activation(
                E_t + (R @ self.W_sr.transpose(0,1))
            )

            # R_{t+1} = ReLU( W_rr * R_t + W_rs * S_{t+1} )
            R_next = self.activation(
                (R @ self.W_rr.transpose(0,1)) + (S_next @ self.W_rs.transpose(0,1))
            )

            S, R = S_next, R_next

        # 最终用内部脑 R 做分类
        logits = self.output_layer(R)
        return logits

In [9]:

def main():
    conn_data = load_connectivity_data(
        # connectivity_path="./data/ad_connectivity_matrix.csv",
        connectivity_path="./data/signed_connectivity_matrix.csv",
        annotation_path="./data/science.add9330_data_s2.csv"
    )
  
    model = DrosophilaRNNNoWss(
        # input_dim=784,
        # sensory_dim=len(conn_data['sensory_ids']),
        residual_dim=conn_data['W_rr'].shape[0],
        num_classes=10,
        conn_weights=conn_data
    )
    results = {
        "epoch_train_loss": [],
        "epoch_train_acc": [],
        "epoch_test_acc": [],
        "flops_acc": [],       # 保持结构一致（暂未实现FLOPs计算）
        "total_flops": 0,      # 保持结构一致
        "activations": None    # 保持结构一致
    }

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
  
    train_loader = DataLoader(
        datasets.MNIST('./data', train=True, download=True, transform=transform),
        batch_size=64, shuffle=True
    )
  
    test_loader = DataLoader(
        datasets.MNIST('./data', train=False, transform=transform),
        batch_size=64, shuffle=False
    )
  
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
  
    for epoch in range(10):
        model.train()
        total_loss, correct = 0.0, 0
      
        for images, labels in train_loader:
            images = images.view(-1, 784).to(device)
            labels = labels.to(device)
          
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
          
            total_loss += loss.item() * images.size(0)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
      
        train_loss = total_loss / len(train_loader.dataset)
        train_acc = 100. * correct / len(train_loader.dataset)
        results["epoch_train_loss"].append(train_loss)
        results["epoch_train_acc"].append(train_acc/100)  # 转换为0-1范围
      
        model.eval()
        test_correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(-1, 784).to(device)
                labels = labels.to(device)
              
                outputs = model(images)
                test_correct += (outputs.argmax(1) == labels).sum().item()
      
        test_acc = 100. * test_correct / len(test_loader.dataset)
        results["epoch_test_acc"].append(test_acc/100)  # 转换为0-1范围
      
        print(f"Epoch {epoch+1}/10:")
        print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"  Test  Acc: {test_acc:.2f}%")
        print("-" * 50)
  
    with open("Drosophila_Metrics.signed.pkl", "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    main()

Found 29 sensory-visual neuron IDs


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 784]