In [4]:
# File: train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from thop import profile
import pickle
import numpy as np
from tqdm import tqdm
from net import *

# Prefer CUDA when PyTorch can see a GPU; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# When True, the signed connectivity CSV is loaded and results are saved with
# a ".signed.pkl" suffix instead of ".pkl".
signed = False

# Registry of experiments, run in declaration order by the __main__ loop.
# Keys read by train_experiment():
#   type      - model family to build
#   optim     - "adam", anything else selects FISTAOptimizer
#   trainable - passed to BaseRNN; also decides the default W_init
#   pruning   - None, "hungarian" or "drosophila"
#   init/ref  - how the initial / reference weight matrices are generated
#   fewshot   - train on a 5% class-balanced subset
# Commented-out entries are alternative configurations that are currently
# disabled.
experiments = {
    "Static_BaseRNN": dict(
        type="basernn", trainable=False, pruning=None, optim="adam",
    ),
    "Static_BaseRNN_random": dict(
        type="basernn", trainable=False, pruning=None, optim="adam",
        init="random",
    ),
    "Static_BaseRNN_RandSparse": dict(
        type="basernn", trainable=False, pruning=None, optim="adam",
        init="randsparse",
    ),
    "Static_BaseRNN_RandStructure": dict(
        type="basernn", trainable=False, pruning=None, optim="adam",
        init="randstructure",
    ),
    "Learnable_BaseRNN": dict(
        type="basernn", trainable=True, pruning=None, optim="adam",
    ),
    # "CWS_Droso": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "optim": "adam"},
    # "CWS_TrainC_pruning": {"type": "cwsrnn", "train_W": True, "train_C": True, "pruning": "drosophila", "optim": "adam"},
    # "CWS_FixedC_random": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "init": "random", "optim": "adam"},
    # "Static_CNN_RNN": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam"},
    # "Static_BaseRNN_fewshot": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},
    # "Static_CNN_RNN_fewshot": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},

    # Hungarian
    # "Hungarian_DrosoInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "droso"},
    # "Hungarian_DrosoInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randsparse"},
    # "Hungarian_DrosoInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randstructure"},
    "Hungarian_RandInit_DrosoRef": dict(
        type="basernn", trainable=True, pruning="hungarian", optim="adam",
        init="random", ref="droso",
    ),
    "Hungarian_RandInit_RandSparseRef": dict(
        type="basernn", trainable=True, pruning="hungarian", optim="adam",
        init="random", ref="randsparse",
    ),
    "Hungarian_RandInit_RandStructureRef": dict(
        type="basernn", trainable=True, pruning="hungarian", optim="adam",
        init="random", ref="randstructure",
    ),
    # "Hungarian_RandSparseInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "droso"},
    # "Hungarian_RandSparseInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randsparse"},
    # "Hungarian_RandSparseInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randstructure"}

    # MLPs
    "Single_MLP": dict(type="singlemlp", optim="adam"),
    "Twohidden_MLP": dict(type="twohiddenmlp", optim="adam"),
    "Static_MLP": dict(type="staticmlp", optim="adam"),
    # "Logistic_Regression": {"type": "logistic", "optim": "adam"},
}

def get_input_shape(model_type):
    """Return the shape of the dummy input used to profile ``model_type``.

    The model types listed below are profiled with an explicit channel
    axis, ``(1, 1, 28, 28)``; every other type gets ``(1, 28, 28)``.
    """
    with_channel_axis = ("cnnrnn", "singlemlp", "twohiddenmlp", "logistic")
    if model_type in with_channel_axis:
        return (1, 1, 28, 28)
    return (1, 28, 28)

def prepare_input(data, model):
    """Adapt a batch to what ``model`` is fed.

    A ``CNNRNN`` receives ``data`` unchanged; for any other model,
    dimension 1 of ``data`` is squeezed away.
    """
    if isinstance(model, CNNRNN):
        return data
    return data.squeeze(1)

def train_epoch(model, optimizer, criterion, train_loader, flops_per_sample, cumulative_flops):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, target)``.
        train_loader: Iterable yielding ``(data, target)`` batches.
        flops_per_sample: Forward-pass FLOPs for a single sample.
        cumulative_flops: FLOPs already spent before this epoch.

    Returns:
        ``(mean_loss, accuracy, flops_acc_pairs, cumulative_flops)`` where
        ``flops_acc_pairs`` holds one ``(cumulative_flops, running_accuracy)``
        tuple per 100 batches and ``cumulative_flops`` is the updated total.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    checkpoints = []

    with tqdm(train_loader, unit="batch", desc="Training") as progress:
        for step, (inputs, labels) in enumerate(progress, start=1):
            inputs = prepare_input(inputs.to(device), model)
            labels = labels.to(device)
            n_batch = inputs.size(0)

            # The forward cost is multiplied by 3 per sample (presumably
            # forward + backward is approximated as 3x forward).
            cumulative_flops += flops_per_sample * n_batch * 3

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar loss once and reuse it for stats and display.
            loss_value = loss.item()
            loss_sum += loss_value * n_batch
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += n_batch

            # Record a (FLOPs, running accuracy) checkpoint every 100 batches.
            if step % 100 == 0:
                checkpoints.append((cumulative_flops, n_correct/n_seen))

            progress.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{n_correct/n_seen:.2%}",
                'FLOPs': f"{cumulative_flops/1e9:.2f}G"
            })

    # Dynamic pruning: executed once at the end of every epoch.
    if isinstance(model, BaseRNN) and model.pruning_method == "hungarian":
        model.apply_hungarian_pruning()
    elif isinstance(model, CWSRNN):
        model.apply_drosophila_pruning()

    return loss_sum/n_seen, n_correct/n_seen, checkpoints, cumulative_flops

def evaluate(model, test_loader):
    """Compute test accuracy and, for recurrent models, mean hidden activations.

    For ``BaseRNN`` / ``CWSRNN`` models the forward pass is unrolled by hand
    so the hidden state can be recorded: one step driven by the flattened
    input batch, followed by nine steps driven by an all-zero input. Other
    models are simply called as ``model(data)``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        ``(accuracy, activations)``. ``activations`` is a NumPy array with
        one row per recurrent step holding the hidden state averaged over
        every evaluated sample, or ``None`` for non-recurrent models.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    is_recurrent = isinstance(model, (BaseRNN, CWSRNN))
    # Running per-step sum of hidden states over all samples seen so far.
    activation_sum = None

    with torch.no_grad():
        W_eff = None
        if is_recurrent:
            # The effective recurrent matrix does not change during
            # evaluation, so build it once instead of once per batch.
            W_eff = model.W if isinstance(model, BaseRNN) else (model.C * model.W * model.s.unsqueeze(1))

        for data, target in test_loader:
            data = prepare_input(data.to(device), model)
            target = target.to(device)

            if is_recurrent:
                batch_size = data.size(0)
                r_t = torch.zeros(batch_size, model.hidden_size, device=device)
                act_list = []

                # Step 1: the only step that sees the actual input.
                E_t = model.input_to_hidden(data.view(batch_size, -1))
                r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                act_list.append(r_t)

                # Steps 2-10: the network keeps running on zero input.
                zero_input = torch.zeros(batch_size, model.input_size, device=device)
                for _ in range(9):
                    E_t = model.input_to_hidden(zero_input)
                    r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                    act_list.append(r_t)

                # Sum over the batch axis rather than averaging per batch:
                # a mean of per-batch means would give the short final
                # batch the same weight as a full one.
                batch_sum = torch.stack(act_list, dim=0).sum(dim=1)
                activation_sum = batch_sum if activation_sum is None else activation_sum + batch_sum

                output = model.hidden_to_output(r_t)
            else:
                output = model(data)

            correct += output.argmax(dim=1).eq(target).sum().item()
            total += target.size(0)

    if activation_sum is not None:
        # Divide by the number of samples for a true per-sample mean.
        activations = (activation_sum / total).cpu().numpy()
    else:
        activations = None

    return correct / total, activations

def _build_loaders(config):
    """Create the MNIST train/test loaders, sub-sampling 5% for few-shot runs."""
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)

    if config.get("fewshot", False):
        # Keep 5% of the training set, drawn evenly from each class.
        num_classes = 10
        samples_per_class = int(len(train_set) * 0.05 / num_classes)
        targets = np.array(train_set.targets)
        indices = []
        for cls in range(num_classes):
            cls_indices = np.where(targets == cls)[0]
            indices.extend(np.random.choice(cls_indices, samples_per_class, replace=False))

        train_set = torch.utils.data.Subset(train_set, indices)
        print(f"Few-shot training set size: {len(train_set)} samples")

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
    return train_loader, test_loader


def _random_weights_like(kind, W_droso):
    """Return Gaussian weights masked according to ``kind``.

    ``"randsparse"`` keeps as many entries as ``W_droso`` has non-zeros, at
    uniformly random positions; ``"randstructure"`` keeps exactly the
    positions where ``W_droso`` is non-zero.
    """
    if kind == 'randsparse':
        non_zero_count = np.count_nonzero(W_droso)
        mask = torch.zeros(W_droso.shape, dtype=torch.float32)
        keep = torch.randperm(W_droso.size)[:non_zero_count]
        mask.view(-1)[keep] = 1
    elif kind == 'randstructure':
        mask = (torch.tensor(W_droso) != 0).float()
    else:
        raise ValueError(f"Unknown random weight kind: {kind!r}")
    return torch.randn(W_droso.shape) * mask


def _build_model(config, W_droso, W_init, W_ref, non_zero):
    """Instantiate the model selected by ``config['type']``."""
    model_type = config['type']
    if model_type == 'basernn':
        return BaseRNN(
            784, W_droso.shape[0], 10,
            W_init=W_init,
            W_ref=W_ref,
            trainable=config['trainable'],
            pruning_method=config.get('pruning')
        )
    if model_type == 'cwsrnn':
        return CWSRNN(
            784, W_droso.shape[0], 10, W_droso,
            train_W=config['train_W'],
            train_C=config.get('train_C', False),
            # non_zero is only set when drosophila pruning is enabled,
            # otherwise it is None.
            non_zero_count=non_zero
        )
    if model_type == 'cnnrnn':
        return CNNRNN(torch.tensor(W_droso))
    if model_type == 'singlemlp':
        return SingleMLP(784, W_droso.shape[0], 10)
    if model_type == 'twohiddenmlp':
        return TwohiddenMLP(784, 1360, 10)
    if model_type == 'staticmlp':
        return StaticMLP(784, 1360, 10)
    if model_type == 'logistic':
        return LogisticRegression(784, 10)
    # Fail loudly here instead of with an unbound `model` further down.
    raise ValueError(f"Unknown model type: {model_type!r}")


def train_experiment(exp_id, config):
    """Train and evaluate one experiment for 10 epochs and pickle its results.

    Args:
        exp_id: Experiment name; used as the stem of the output file.
        config: Experiment settings. Keys read here: ``type``, ``optim``,
            ``trainable``, ``pruning``, ``init``, ``ref``, ``fewshot``,
            ``train_W`` and ``train_C``.

    Raises:
        ValueError: If ``config['type']`` names no known model.

    Side effects:
        Downloads MNIST into ``./data`` if needed, prints per-epoch test
        accuracy and writes ``<exp_id>.pkl`` (``<exp_id>.signed.pkl`` when
        the module-level ``signed`` flag is set) to the working directory.
    """
    train_loader, test_loader = _build_loaders(config)

    non_zero = None
    csv_path = './data/signed_connectivity_matrix.csv' if signed else './data/ad_connectivity_matrix.csv'
    if config.get('pruning') == 'drosophila':
        W_droso, non_zero = load_drosophila_matrix(csv_path, apply_pruning=True, signed=signed)
    else:
        W_droso = load_drosophila_matrix(csv_path, apply_pruning=False, signed=signed)

    # Initial recurrent weights. W_init is generated before W_ref so the
    # random-number stream is consumed in the same order as before.
    init_kind = config.get('init')
    if init_kind == 'random':
        W_init = torch.randn(W_droso.shape[0], W_droso.shape[0])
    elif init_kind == 'droso':
        W_init = W_droso
    elif init_kind in ('randsparse', 'randstructure'):
        W_init = _random_weights_like(init_kind, W_droso)
    else:
        # No explicit init: static models get the connectome matrix,
        # trainable ones get None.
        W_init = W_droso if not config.get('trainable', True) else None

    # Reference matrix handed to BaseRNN as W_ref.
    ref_kind = config.get('ref')
    if ref_kind == 'droso':
        W_ref = W_droso
    elif ref_kind in ('randsparse', 'randstructure'):
        W_ref = _random_weights_like(ref_kind, W_droso)
    else:
        W_ref = None

    model = _build_model(config, W_droso, W_init, W_ref, non_zero)
    model.to(device)

    if config['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=1e-5)
    else:
        optimizer = FISTAOptimizer(model.parameters(), lr=1e-3, lambda_l1=1e-5)

    # Profile one dummy sample; the MAC count is doubled to get FLOPs.
    # NOTE(review): thop appears to count only modules it has rules for (the
    # run log shows nn.Linear / nn.ReLU being registered), so raw tensor
    # matmuls inside forward() may be missing from this figure - confirm.
    input_shape = get_input_shape(config['type'])
    dummy_input = torch.randn(input_shape).to(device)
    macs, _ = profile(model, inputs=(dummy_input,))
    flops_forward = macs * 2

    results = {
        "epoch_train_loss": [], "epoch_train_acc": [],
        "epoch_test_acc": [], "flops_acc": [],
        "total_flops": 0, "activations": None
    }
    cumulative_flops = 0
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        epoch_loss, epoch_acc, flops_pairs, cumulative_flops = train_epoch(
            model, optimizer, criterion, train_loader, flops_forward, cumulative_flops
        )
        test_acc, activations = evaluate(model, test_loader)

        results["epoch_train_loss"].append(epoch_loss)
        results["epoch_train_acc"].append(epoch_acc)
        results["epoch_test_acc"].append(test_acc)
        results["flops_acc"].extend(flops_pairs)
        # Only the activations from the most recent epoch are kept.
        results["activations"] = activations
        print(f"Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    results["total_flops"] = cumulative_flops
    filename = f"{exp_id}.signed.pkl" if signed else f"{exp_id}.pkl"
    with open(filename, "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    # Run every enabled experiment, in the order the registry declares them.
    for exp_id in experiments:
        config = experiments[exp_id]
        print(f"\n{'='*30}\nTraining {exp_id}\n{'='*30}")
        train_experiment(exp_id, config)


Training Static_BaseRNN
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:18<00:00, 51.20batch/s, loss=0.4088, acc=81.73%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 90.23%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.86batch/s, loss=0.2795, acc=90.68%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 92.03%


Training: 100%|██████████| 938/938 [00:21<00:00, 44.51batch/s, loss=0.3733, acc=92.24%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 93.10%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.39batch/s, loss=0.3051, acc=93.26%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 93.85%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.56batch/s, loss=0.1126, acc=93.97%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 94.29%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.99batch/s, loss=0.1169, acc=94.56%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 94.63%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.26batch/s, loss=0.1347, acc=95.00%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 95.05%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.09batch/s, loss=0.0508, acc=95.43%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 95.37%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.29batch/s, loss=0.1645, acc=95.80%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 95.65%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.45batch/s, loss=0.1160, acc=96.14%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 95.77%

Training Static_BaseRNN_random


  self.register_buffer('W', torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:18<00:00, 50.89batch/s, loss=0.4542, acc=82.23%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 90.48%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.97batch/s, loss=0.1883, acc=90.85%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 92.20%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.88batch/s, loss=0.0639, acc=92.31%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 93.10%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.08batch/s, loss=0.1663, acc=93.35%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 93.75%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.23batch/s, loss=0.0679, acc=94.07%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 94.22%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.85batch/s, loss=0.0835, acc=94.58%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 94.78%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.15batch/s, loss=0.1900, acc=95.08%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 95.12%


Training: 100%|██████████| 938/938 [00:14<00:00, 66.49batch/s, loss=0.0899, acc=95.47%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 95.41%


Training: 100%|██████████| 938/938 [00:13<00:00, 70.90batch/s, loss=0.0505, acc=95.81%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 95.67%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.76batch/s, loss=0.2887, acc=96.14%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 95.80%

Training Static_BaseRNN_RandSparse


  self.register_buffer('W', torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:18<00:00, 51.23batch/s, loss=0.5836, acc=82.19%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 90.24%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.36batch/s, loss=0.4006, acc=90.76%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 91.93%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.22batch/s, loss=0.1223, acc=92.25%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 93.00%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.28batch/s, loss=0.1280, acc=93.26%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 93.70%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.32batch/s, loss=0.2034, acc=94.00%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 94.35%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.20batch/s, loss=0.0850, acc=94.55%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 94.65%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.02batch/s, loss=0.1259, acc=95.00%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 94.99%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.26batch/s, loss=0.0665, acc=95.41%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 95.42%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.19batch/s, loss=0.2035, acc=95.78%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 95.71%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.25batch/s, loss=0.0999, acc=96.13%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 95.91%

Training Static_BaseRNN_RandStructure


  self.register_buffer('W', torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:18<00:00, 51.69batch/s, loss=0.5459, acc=82.28%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 90.43%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.27batch/s, loss=0.2553, acc=90.80%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 92.09%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.07batch/s, loss=0.1905, acc=92.30%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 93.05%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.44batch/s, loss=0.1574, acc=93.32%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 93.79%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.28batch/s, loss=0.1385, acc=94.05%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 94.40%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.34batch/s, loss=0.3850, acc=94.62%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 94.77%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.56batch/s, loss=0.0723, acc=95.03%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 95.13%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.16batch/s, loss=0.1515, acc=95.45%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 95.25%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.46batch/s, loss=0.0417, acc=95.79%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 95.61%


Training: 100%|██████████| 938/938 [00:18<00:00, 51.41batch/s, loss=0.1080, acc=96.12%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 95.80%

Training Learnable_BaseRNN
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:19<00:00, 47.11batch/s, loss=0.1005, acc=92.72%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 95.82%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.02batch/s, loss=0.1071, acc=96.95%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 97.38%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.00batch/s, loss=0.0370, acc=98.13%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 97.93%


Training: 100%|██████████| 938/938 [00:20<00:00, 46.79batch/s, loss=0.0192, acc=98.69%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 97.16%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.14batch/s, loss=0.0069, acc=99.15%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 97.69%


Training: 100%|██████████| 938/938 [00:20<00:00, 46.86batch/s, loss=0.0008, acc=99.35%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 97.75%


Training: 100%|██████████| 938/938 [00:18<00:00, 50.08batch/s, loss=0.0508, acc=99.52%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 98.13%


Training: 100%|██████████| 938/938 [00:14<00:00, 63.08batch/s, loss=0.0069, acc=99.63%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 98.14%


Training: 100%|██████████| 938/938 [00:14<00:00, 63.34batch/s, loss=0.0775, acc=99.67%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 97.66%


Training: 100%|██████████| 938/938 [00:15<00:00, 59.49batch/s, loss=0.0024, acc=99.68%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 97.18%

Training Hungarian_RandInit_DrosoRef


  self.W = nn.Parameter(torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:14<00:00, 62.87batch/s, loss=0.1065, acc=92.53%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 96.20%


Training: 100%|██████████| 938/938 [00:15<00:00, 60.21batch/s, loss=0.0983, acc=97.06%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 97.00%


Training: 100%|██████████| 938/938 [00:19<00:00, 46.93batch/s, loss=0.0602, acc=98.10%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 97.60%


Training: 100%|██████████| 938/938 [00:17<00:00, 52.97batch/s, loss=0.0210, acc=98.65%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 97.73%


Training: 100%|██████████| 938/938 [00:19<00:00, 46.98batch/s, loss=0.3072, acc=99.06%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 97.68%


Training: 100%|██████████| 938/938 [00:20<00:00, 46.67batch/s, loss=0.0013, acc=99.45%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 97.81%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.02batch/s, loss=0.0533, acc=99.53%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 97.86%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.10batch/s, loss=0.0053, acc=99.55%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 98.26%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.03batch/s, loss=0.0267, acc=99.65%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 98.05%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.14batch/s, loss=0.0008, acc=99.72%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 97.91%

Training Hungarian_RandInit_RandSparseRef


  self.W = nn.Parameter(torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:19<00:00, 49.04batch/s, loss=0.0299, acc=92.68%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 96.33%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.11batch/s, loss=0.0267, acc=97.13%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 97.23%


Training: 100%|██████████| 938/938 [00:20<00:00, 46.76batch/s, loss=0.0218, acc=98.13%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 97.31%


Training: 100%|██████████| 938/938 [00:19<00:00, 46.92batch/s, loss=0.0063, acc=98.72%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 97.31%


Training: 100%|██████████| 938/938 [00:20<00:00, 46.68batch/s, loss=0.0882, acc=99.06%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 97.72%


Training: 100%|██████████| 938/938 [00:19<00:00, 46.93batch/s, loss=0.0137, acc=99.42%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 97.74%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.03batch/s, loss=0.0063, acc=99.49%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 98.17%


Training: 100%|██████████| 938/938 [00:19<00:00, 47.22batch/s, loss=0.0027, acc=99.59%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 97.61%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.18batch/s, loss=0.0789, acc=99.71%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 97.95%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.00batch/s, loss=0.0005, acc=99.72%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 97.49%

Training Hungarian_RandInit_RandStructureRef


  self.W = nn.Parameter(torch.tensor(W_init * 1e-5, dtype=torch.float32))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:19<00:00, 47.63batch/s, loss=0.0484, acc=92.62%, FLOPs=8342.35G]


Epoch 1 | Test Acc: 96.41%


Training: 100%|██████████| 938/938 [00:23<00:00, 39.56batch/s, loss=0.1234, acc=97.01%, FLOPs=16684.70G]


Epoch 2 | Test Acc: 97.25%


Training: 100%|██████████| 938/938 [00:23<00:00, 39.59batch/s, loss=0.1064, acc=98.11%, FLOPs=25027.06G]


Epoch 3 | Test Acc: 97.67%


Training: 100%|██████████| 938/938 [00:15<00:00, 62.02batch/s, loss=0.0086, acc=98.69%, FLOPs=33369.41G]


Epoch 4 | Test Acc: 97.83%


Training: 100%|██████████| 938/938 [00:19<00:00, 48.81batch/s, loss=0.0196, acc=99.12%, FLOPs=41711.76G]


Epoch 5 | Test Acc: 98.00%


Training: 100%|██████████| 938/938 [00:15<00:00, 61.98batch/s, loss=0.0105, acc=99.43%, FLOPs=50054.11G]


Epoch 6 | Test Acc: 97.92%


Training: 100%|██████████| 938/938 [00:15<00:00, 61.84batch/s, loss=0.0035, acc=99.50%, FLOPs=58396.46G]


Epoch 7 | Test Acc: 97.69%


Training: 100%|██████████| 938/938 [00:15<00:00, 61.71batch/s, loss=0.0004, acc=99.55%, FLOPs=66738.82G]


Epoch 8 | Test Acc: 98.07%


Training: 100%|██████████| 938/938 [00:15<00:00, 61.87batch/s, loss=0.0004, acc=99.69%, FLOPs=75081.17G]


Epoch 9 | Test Acc: 98.07%


Training: 100%|██████████| 938/938 [00:21<00:00, 44.08batch/s, loss=0.0012, acc=99.76%, FLOPs=83423.52G]


Epoch 10 | Test Acc: 98.11%

Training Single_MLP
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


Training: 100%|██████████| 938/938 [00:12<00:00, 77.09batch/s, loss=0.4481, acc=83.18%, FLOPs=843.80G]


Epoch 1 | Test Acc: 89.44%


Training: 100%|██████████| 938/938 [00:11<00:00, 80.02batch/s, loss=0.2287, acc=89.77%, FLOPs=1687.60G]


Epoch 2 | Test Acc: 91.02%


Training: 100%|██████████| 938/938 [00:12<00:00, 72.27batch/s, loss=0.3073, acc=90.78%, FLOPs=2531.40G]


Epoch 3 | Test Acc: 91.58%


Training: 100%|██████████| 938/938 [00:12<00:00, 72.23batch/s, loss=0.1868, acc=91.35%, FLOPs=3375.20G]


Epoch 4 | Test Acc: 91.76%


Training: 100%|██████████| 938/938 [00:13<00:00, 72.08batch/s, loss=0.3405, acc=91.64%, FLOPs=4219.00G]


Epoch 5 | Test Acc: 91.93%


Training: 100%|██████████| 938/938 [00:13<00:00, 71.72batch/s, loss=0.1269, acc=91.91%, FLOPs=5062.80G]


Epoch 6 | Test Acc: 92.20%


Training: 100%|██████████| 938/938 [00:12<00:00, 77.16batch/s, loss=0.2807, acc=92.04%, FLOPs=5906.60G]


Epoch 7 | Test Acc: 92.20%


Training: 100%|██████████| 938/938 [00:10<00:00, 86.19batch/s, loss=0.4526, acc=92.24%, FLOPs=6750.40G]


Epoch 8 | Test Acc: 92.42%


Training: 100%|██████████| 938/938 [00:10<00:00, 86.17batch/s, loss=0.1148, acc=92.27%, FLOPs=7594.20G]


Epoch 9 | Test Acc: 92.38%


Training: 100%|██████████| 938/938 [00:10<00:00, 85.83batch/s, loss=0.3289, acc=92.41%, FLOPs=8438.00G]


Epoch 10 | Test Acc: 92.37%

Training Twohidden_MLP
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.


Training: 100%|██████████| 938/938 [00:11<00:00, 80.90batch/s, loss=0.3214, acc=83.80%, FLOPs=1054.60G]


Epoch 1 | Test Acc: 91.13%


Training: 100%|██████████| 938/938 [00:11<00:00, 81.48batch/s, loss=0.1322, acc=91.48%, FLOPs=2109.20G]


Epoch 2 | Test Acc: 92.58%


Training: 100%|██████████| 938/938 [00:11<00:00, 81.05batch/s, loss=0.2427, acc=93.05%, FLOPs=3163.80G]


Epoch 3 | Test Acc: 93.65%


Training: 100%|██████████| 938/938 [00:11<00:00, 81.29batch/s, loss=0.2655, acc=94.06%, FLOPs=4218.39G]


Epoch 4 | Test Acc: 94.39%


Training: 100%|██████████| 938/938 [00:11<00:00, 81.16batch/s, loss=0.3097, acc=94.71%, FLOPs=5272.99G]


Epoch 5 | Test Acc: 94.80%


Training: 100%|██████████| 938/938 [00:11<00:00, 78.68batch/s, loss=0.3028, acc=95.28%, FLOPs=6327.59G]


Epoch 6 | Test Acc: 95.48%


Training: 100%|██████████| 938/938 [00:11<00:00, 81.22batch/s, loss=0.0626, acc=95.79%, FLOPs=7382.19G]


Epoch 7 | Test Acc: 95.86%


Training: 100%|██████████| 938/938 [00:11<00:00, 80.43batch/s, loss=0.1280, acc=96.17%, FLOPs=8436.79G]


Epoch 8 | Test Acc: 95.99%


Training: 100%|██████████| 938/938 [00:11<00:00, 78.90batch/s, loss=0.0514, acc=96.61%, FLOPs=9491.39G]


Epoch 9 | Test Acc: 96.34%


Training: 100%|██████████| 938/938 [00:11<00:00, 80.93batch/s, loss=0.0812, acc=96.89%, FLOPs=10545.98G]


Epoch 10 | Test Acc: 96.52%

Training Static_MLP
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.


Training: 100%|██████████| 938/938 [00:09<00:00, 95.41batch/s, loss=0.3530, acc=79.11%, FLOPs=388.74G] 


Epoch 1 | Test Acc: 90.49%


Training: 100%|██████████| 938/938 [00:09<00:00, 94.02batch/s, loss=0.0262, acc=92.50%, FLOPs=777.48G] 


Epoch 2 | Test Acc: 93.27%


Training: 100%|██████████| 938/938 [00:10<00:00, 92.16batch/s, loss=0.0359, acc=94.99%, FLOPs=1166.23G]


Epoch 3 | Test Acc: 94.57%


Training: 100%|██████████| 938/938 [00:10<00:00, 91.57batch/s, loss=0.0312, acc=96.35%, FLOPs=1554.97G]


Epoch 4 | Test Acc: 95.50%


Training: 100%|██████████| 938/938 [00:09<00:00, 99.49batch/s, loss=0.0939, acc=97.40%, FLOPs=1943.71G] 


Epoch 5 | Test Acc: 95.89%


Training: 100%|██████████| 938/938 [00:12<00:00, 74.31batch/s, loss=0.0497, acc=98.13%, FLOPs=2332.45G] 


Epoch 6 | Test Acc: 95.84%


Training: 100%|██████████| 938/938 [00:13<00:00, 70.97batch/s, loss=0.0431, acc=98.69%, FLOPs=2721.20G]


Epoch 7 | Test Acc: 96.56%


Training: 100%|██████████| 938/938 [00:12<00:00, 72.22batch/s, loss=0.0237, acc=99.07%, FLOPs=3109.94G]


Epoch 8 | Test Acc: 96.61%


Training: 100%|██████████| 938/938 [00:12<00:00, 74.24batch/s, loss=0.0150, acc=99.40%, FLOPs=3498.68G]


Epoch 9 | Test Acc: 96.76%


Training: 100%|██████████| 938/938 [00:10<00:00, 86.17batch/s, loss=0.0207, acc=99.60%, FLOPs=3887.42G]


Epoch 10 | Test Acc: 96.92%
