In [None]:
# File: train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from thop import profile
import pickle
import numpy as np
from tqdm import tqdm
from net import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
signed = True
fewshot = True
sample = False

experiments = {
    # "Static_BaseRNN": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam"},
    # "Static_BaseRNN_random": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "random"},
    # "Static_BaseRNN_RandSparse": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "randsparse"},
    # "Static_BaseRNN_RandStructure":{"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "randstructure"},
    # "Learnable_BaseRNN": {"type": "basernn", "trainable": True, "pruning": None, "optim": "adam"},
    "Learnable_DrosophilaRNN": {"type": "drosophilarnn", "trainable": True, "pruning": None, "optim": "adam"},
    # "CWS_Droso": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "optim": "adam"},
    # "CWS_TrainC_pruning": {"type": "cwsrnn", "train_W": True, "train_C": True, "pruning": "drosophila", "optim": "adam"},
    # "CWS_FixedC_random": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "init": "random", "optim": "adam"},
    # "Static_CNN_RNN": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam"},
    # "Static_BaseRNN_fewshot": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},
    # "Static_CNN_RNN_fewshot": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},
  

    # Hungarian
    # "Hungarian_DrosoInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "droso"},
    # # "Hungarian_DrosoInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randsparse"},
    # # "Hungarian_DrosoInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randstructure"},
    # # "Hungarian_RandInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "droso"},
    # "Hungarian_RandInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "randsparse"},
    # "Hungarian_RandInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "randstructure"},

    # "Hungarian_RandSparseInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "droso"},
    # "Hungarian_RandSparseInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randsparse"},
    # "Hungarian_RandSparseInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randstructure"}
  
    # MLPs
    # "Single_MLP": {"type": "singlemlp", "optim": "adam"},
    # "Twohidden_MLP": {"type": "twohiddenmlp", "optim": "adam"},
    # "Static_MLP": {"type": "staticmlp", "optim": "adam"},
    # "Logistic_Regression": {"type": "logistic", "optim": "adam"},
}

# 自动生成FewShot版本
if fewshot:
    fewshot_experiments = {}
    for exp_id, config in experiments.items():
        # 原始版本
        # fewshot_experiments[exp_id] = config
        # FewShot版本
        fewshot_config = config.copy()
        fewshot_config["fewshot"] = True
        fewshot_experiments[f"{exp_id}_fewshot"] = fewshot_config
    experiments = fewshot_experiments


def get_input_shape(model_type):
    return (1, 1, 28, 28) if model_type in ["cnnrnn", "singlemlp", "twohiddenmlp", "logistic"] else (1, 28, 28)

def prepare_input(data, model):
    return data if isinstance(model, CNNRNN) else data.squeeze(1)

def create_fewshot_subset(dataset, epoch_seed):
    """创建每个epoch的fewshot子集"""
    num_classes = 10
    samples_per_class = 120
    rng = np.random.default_rng(epoch_seed)
    indices = []
    targets = np.array(dataset.targets)
    for cls in range(num_classes):
        cls_indices = np.where(targets == cls)[0]
        sampled_indices = rng.choice(cls_indices, samples_per_class, replace=False)
        indices.extend(sampled_indices)
    return torch.utils.data.Subset(dataset, indices)


def train_epoch(model, optimizer, criterion, train_loader, test_loader, 
               flops_per_sample, cumulative_flops, cumulative_batches):
    """训练一个epoch，同时记录两种指标"""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    flops_acc_pairs = []
    samples_acc_pairs = []

    with tqdm(train_loader, unit="batch", desc="Training") as pbar:
        for batch_idx, (data, target) in enumerate(pbar):
            # ========== 数据准备 ==========
            data = prepare_input(data.to(device), model)
            target = target.to(device)
            batch_size = data.size(0)
          
            # ========== 计算指标 ==========
            cumulative_flops += flops_per_sample * batch_size * 3  # 前向+反向x3
            cumulative_batches += 1
            current_samples = cumulative_batches // 10

            # ========== 训练步骤 ==========
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # ========== 记录指标 ==========
            total_loss += loss.item() * batch_size
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += batch_size

            # 每10个batch记录一次
            if (batch_idx + 1) % 10 == 0:
                # 记录FLOPs相关数据（使用训练准确率）
                flops_acc_pairs.append((cumulative_flops, correct/total))
              
                # 按需记录样本数相关数据（使用测试准确率）
                if sample:
                    test_acc, _ = evaluate(model, test_loader)
                    samples_acc_pairs.append((current_samples, test_acc))

            # ========== 更新进度条 ==========
            pbar_info = {
                'loss': f"{loss.item():.4f}",
                'acc': f"{correct/total:.2%}",
                'FLOPs': f"{cumulative_flops/1e9:.2f}G"
            }
            if sample:
                pbar_info['Samples'] = f"{current_samples}"
            pbar.set_postfix(pbar_info)

    # ========== 剪枝操作 ==========
    if isinstance(model, BaseRNN) and model.pruning_method == "hungarian":
        model.apply_hungarian_pruning()
    elif isinstance(model, CWSRNN):
        model.apply_drosophila_pruning()

    return (
        total_loss/total, 
        correct/total, 
        flops_acc_pairs,
        samples_acc_pairs,
        cumulative_flops,
        cumulative_batches
    )

def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    activations_list = []
    with torch.no_grad():
        for data, target in test_loader:
            data = prepare_input(data.to(device), model)
            target = target.to(device)
          
            if isinstance(model, (BaseRNN, CWSRNN)):
                batch_size = data.size(0)
                r_t = torch.zeros(batch_size, model.hidden_size, device=device)
                act_list = []
              
                E_t = model.input_to_hidden(data.view(batch_size, -1))
                W_eff = model.W if isinstance(model, BaseRNN) else (model.C * model.W * model.s.unsqueeze(1))
                r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                act_list.append(r_t)
              
                zero_input = torch.zeros(batch_size, model.input_size, device=device)
                for _ in range(9):
                    E_t = model.input_to_hidden(zero_input)
                    r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                    act_list.append(r_t)
              
                batch_activations = torch.stack(act_list, dim=0)
                batch_mean_activations = batch_activations.mean(dim=1)
                activations_list.append(batch_mean_activations)
              
                output = model.hidden_to_output(r_t)
            else:
                output = model(data)
          
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += target.size(0)
  
    if activations_list:
        activations = torch.stack(activations_list, dim=0).mean(dim=0).cpu().numpy()
    else:
        activations = None
  
    return correct / total, activations

def train_experiment(exp_id, config):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    full_train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)
    
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)

    non_zero = None
    csv_path = './data/signed_connectivity_matrix.csv' if signed else './data/ad_connectivity_matrix.csv'
    if config.get('pruning') == 'drosophila':
        W_droso, non_zero = load_drosophila_matrix(csv_path, apply_pruning=True, signed=signed)
    else:
        W_droso = load_drosophila_matrix(csv_path, apply_pruning=False, signed=signed)

    W_init = None
    if config.get('init') == 'random':
        W_init = torch.randn(W_droso.shape[0], W_droso.shape[0])
    elif config.get('init') == 'droso':
        W_init = W_droso
    elif config.get('init') == 'randsparse':
        non_zero_count = np.count_nonzero(W_droso)
        total_elements = W_droso.size
        mask = torch.zeros(W_droso.shape, dtype=torch.float32)
        indices = torch.randperm(total_elements)[:non_zero_count]
        mask.view(-1)[indices] = 1
        W_init = torch.randn(W_droso.shape) * mask
    elif config.get('init') == 'randstructure':
        mask = (torch.tensor(W_droso) != 0).float()
        W_init = torch.randn(W_droso.shape) * mask
    else:
        W_init = W_droso if not config.get('trainable', True) else None

    W_ref = None
    if config.get('ref') == 'droso':
        W_ref = W_droso
    elif config.get('ref') == 'randsparse':
        non_zero_count = np.count_nonzero(W_droso)
        total_elements = W_droso.size
        mask = torch.zeros(W_droso.shape, dtype=torch.float32)
        indices = torch.randperm(total_elements)[:non_zero_count]
        mask.view(-1)[indices] = 1
        W_ref = torch.randn(W_droso.shape) * mask
    elif config.get('ref') == 'randstructure':
        mask = (torch.tensor(W_droso) != 0).float()
        W_ref = torch.randn(W_droso.shape) * mask

    if config['type'] == 'basernn':
        model = BaseRNN(
            784, W_droso.shape[0], 10,
            W_init=W_init,
            W_ref=W_ref,
            trainable=config['trainable'],
            pruning_method=config.get('pruning')
        )
    if config['type'] == 'drosophilarnn':
        model = DrosophilaRNN(
            input_dim=784,
            sensory_dim=len(conn_data['sensory_ids']),
            residual_dim=conn_data['W_rr'].shape[0],
            num_classes=10,
            conn_weights=conn_data
        )
    elif config['type'] == 'cwsrnn':
        model = CWSRNN(
            784, W_droso.shape[0], 10, W_droso,
            train_W=config['train_W'],
            train_C=config.get('train_C', False),
            non_zero_count=non_zero if config.get('pruning') == 'drosophila' else None
        )
    elif config['type'] == 'cnnrnn':
        model = CNNRNN(torch.tensor(W_droso))
    elif config['type'] == 'singlemlp':
        model = SingleMLP(784, W_droso.shape[0], 10)
    elif config['type'] == 'twohiddenmlp':
        model = TwohiddenMLP(784, 1360, 10)
    elif config['type'] == 'staticmlp':
        model = StaticMLP(784, 1360, 10)
    elif config['type'] == 'logistic':
        model = LogisticRegression(784, 10)
    model.to(device)

    if config['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=1e-5)
    else:
        optimizer = FISTAOptimizer(model.parameters(), lr=1e-3, lambda_l1=1e-5)

    input_shape = get_input_shape(config['type'])
    dummy_input = torch.randn(input_shape).to(device)
    macs, _ = profile(model, inputs=(dummy_input,))
    flops_per_sample = macs * 2 

    results = {
        "epoch_train_loss": [], "epoch_train_acc": [],
        "epoch_test_acc": [], "flops_acc": [],
        "samples_acc": [] if sample else None,  #
        "total_flops": 0, "activations": None
    }
    cumulative_flops = 0
    cumulative_batches = 0 
    criterion = nn.CrossEntropyLoss()

    # 初始评估（epoch 0）
    initial_test_acc, initial_activations = evaluate(model, test_loader)
    results["epoch_test_acc"].append(initial_test_acc)
    results["flops_acc"].append((0, initial_test_acc))
    if sample:
        results["samples_acc"].append((0, initial_test_acc))  # 样本数从0开始
    print(f"Initial Epoch 0 | Test Acc: {initial_test_acc:.2%}")

    for epoch in range(10):
        # 动态创建训练集
        if config.get("fewshot"):
            train_subset = create_fewshot_subset(full_train_set, epoch_seed=epoch)
            train_loader = torch.utils.data.DataLoader(train_subset, batch_size=30, shuffle=True)
        else:
            train_loader = torch.utils.data.DataLoader(full_train_set, batch_size=64, shuffle=True)

        # 训练epoch
        epoch_loss, epoch_acc, flops_pairs, samples_pairs, cumulative_flops, cumulative_batches = train_epoch(
            model, optimizer, criterion, train_loader, test_loader,
            flops_per_sample, cumulative_flops, cumulative_batches
        )
      
        # 记录结果
        results["epoch_train_loss"].append(epoch_loss)
        results["epoch_train_acc"].append(epoch_acc)
        results["flops_acc"].extend(flops_pairs)
        if sample:
            results["samples_acc"].extend(samples_pairs)
      
        # Epoch结束后的评估
        test_acc, _ = evaluate(model, test_loader)
        results["epoch_test_acc"].append(test_acc)
        print(f"Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    # ========== 保存结果 ==========
    filename_suffix = []
    if signed: filename_suffix.append("signed")
    if sample: filename_suffix.append("sample")
    suffix = ".".join(filename_suffix) if filename_suffix else ""
    filename = f"{exp_id}.{suffix}.pkl" if suffix else f"{exp_id}.pkl"
  
    with open(filename, "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    for exp_id, config in experiments.items():
        print(f"\n{'='*30}\nTraining {exp_id}\n{'='*30}")
        train_experiment(exp_id, config)


Training Single_MLP_fewshot
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Initial Epoch 0 | Test Acc: 6.73%


Training: 100%|██████████| 40/40 [00:00<00:00, 119.46batch/s, loss=1.9150, acc=27.83%, FLOPs=16.88G]


Epoch 1 | Test Acc: 53.33%


Training: 100%|██████████| 40/40 [00:00<00:00, 128.06batch/s, loss=1.5557, acc=59.50%, FLOPs=33.75G]


Epoch 2 | Test Acc: 70.20%


Training: 100%|██████████| 40/40 [00:00<00:00, 130.12batch/s, loss=1.3587, acc=72.33%, FLOPs=50.63G]


Epoch 3 | Test Acc: 76.37%


Training: 100%|██████████| 40/40 [00:00<00:00, 117.53batch/s, loss=1.1899, acc=76.67%, FLOPs=67.50G]


Epoch 4 | Test Acc: 79.95%


Training: 100%|██████████| 40/40 [00:00<00:00, 131.98batch/s, loss=1.0971, acc=77.75%, FLOPs=84.38G]


Epoch 5 | Test Acc: 81.75%


Training: 100%|██████████| 40/40 [00:00<00:00, 130.35batch/s, loss=0.9479, acc=80.83%, FLOPs=101.26G]


Epoch 6 | Test Acc: 83.26%


Training: 100%|██████████| 40/40 [00:00<00:00, 130.27batch/s, loss=0.8016, acc=83.17%, FLOPs=118.13G]


Epoch 7 | Test Acc: 84.48%


Training: 100%|██████████| 40/40 [00:00<00:00, 132.56batch/s, loss=0.9278, acc=83.83%, FLOPs=135.01G]


Epoch 8 | Test Acc: 85.18%


Training: 100%|██████████| 40/40 [00:00<00:00, 119.24batch/s, loss=0.7875, acc=83.92%, FLOPs=151.88G]


Epoch 9 | Test Acc: 85.60%


Training: 100%|██████████| 40/40 [00:00<00:00, 133.11batch/s, loss=0.7642, acc=83.17%, FLOPs=168.76G]


Epoch 10 | Test Acc: 86.34%


In [9]:
def train_epoch(model, optimizer, criterion, train_loader, test_loader, cumulative_batch_count):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    samples_acc_pairs = []

    with tqdm(train_loader, unit="batch", desc="Training") as pbar:
        for batch_idx, (data, target) in enumerate(pbar):
            data = prepare_input(data.to(device), model)
            target = target.to(device)
      
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
      
            total_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += data.size(0)
      
            # 更新累计的batch计数
            cumulative_batch_count += 1
            current_samples = cumulative_batch_count // 10
      
            # 每处理10个batch进行一次评估
            if cumulative_batch_count % 10 == 0:
                test_acc, _ = evaluate(model, test_loader)
                samples_acc_pairs.append((current_samples, test_acc))
      
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{correct/total:.2%}",
                'Samples': f"{current_samples}"
            })

    # 动态剪枝：每个 epoch 结束后执行
    if isinstance(model, BaseRNN) and model.pruning_method == "hungarian":
        model.apply_hungarian_pruning()
    elif isinstance(model, CWSRNN):
        model.apply_drosophila_pruning()

    return total_loss/total, correct/total, samples_acc_pairs, cumulative_batch_count

def train_experiment(exp_id, config):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    full_train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)
  
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)

    # ...（中间保持模型初始化部分不变）...

    # 初始评估（epoch 0）
    initial_test_acc, initial_activations = evaluate(model, test_loader)
    results = {
        "epoch_train_loss": [],
        "epoch_train_acc": [],
        "epoch_test_acc": [initial_test_acc],
        "samples_acc": [(0, initial_test_acc)],  # 初始样本数为0
        "activations": None
    }
    cumulative_batch_count = 0  # 初始化累计batch计数
    print(f"Initial Epoch 0 | Test Acc: {initial_test_acc:.2%}")

    for epoch in range(10):
        # 动态创建训练集
        if config.get("fewshot"):
            train_subset = create_fewshot_subset(full_train_set, epoch_seed=epoch)
            train_loader = torch.utils.data.DataLoader(train_subset, batch_size=64, shuffle=True)
        else:
            train_loader = torch.utils.data.DataLoader(full_train_set, batch_size=64, shuffle=True)
    
        # 训练过程
        epoch_loss, epoch_acc, samples_pairs, cumulative_batch_count = train_epoch(
            model, optimizer, criterion, train_loader, test_loader, cumulative_batch_count
        )
        results["samples_acc"].extend(samples_pairs)
        results["epoch_train_loss"].append(epoch_loss)
        results["epoch_train_acc"].append(epoch_acc)
    
        # 每个epoch结束后的测试
        test_acc, activations = evaluate(model, test_loader)
        results["epoch_test_acc"].append(test_acc)
        print(f"Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    # 保存结果
    filename = f"{exp_id}.signed.pkl" if signed else f"{exp_id}.pkl"
    with open(filename, "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    for exp_id, config in experiments.items():
        print(f"\n{'='*30}\nTraining {exp_id}\n{'='*30}")
        train_experiment(exp_id, config)


Training Single_MLP_fewshot


NameError: name 'model' is not defined