In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from thop import profile
import pickle
import numpy as np
from tqdm import tqdm
from net import *

# Pick the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment registry: experiment name -> config dict read by train_experiment().
# Only uncommented entries are run by the __main__ loop. Config keys used below:
#   type       -- model family: basernn | cwsrnn | cnnrnn | singlemlp | twohiddenmlp | logistic
#   optim      -- "adam" selects Adam; any other value selects FISTAOptimizer
#   trainable  -- (basernn) forwarded to BaseRNN as `trainable`
#   train_W / train_C -- (cwsrnn) forwarded to CWSRNN
#   pruning    -- None | "drosophila" | "hungarian"
#   init / ref -- initial / reference recurrent matrix: droso | random | randsparse | randstructure
#   fewshot    -- train on a class-balanced 5% subset of MNIST
experiments = {
    # "Static_BaseRNN": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam"},
    # "Static_BaseRNN_random": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "random"},
    # "Static_BaseRNN_RandSparse": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "randsparse"},
    # "Static_BaseRNN_RandStructure":{"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "init": "randstructure"},
    # "Learnable_BaseRNN": {"type": "basernn", "trainable": True, "pruning": None, "optim": "adam"},
    # "CWS_Droso": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "optim": "adam"},
    # "CWS_TrainC_pruning": {"type": "cwsrnn", "train_W": True, "train_C": True, "pruning": "drosophila", "optim": "adam"},
    # "CWS_FixedC_random": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": None, "init": "random", "optim": "adam"},
    # "Static_CNN_RNN": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam"},
    # "Static_BaseRNN_fewshot": {"type": "basernn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},
    # "Static_CNN_RNN_fewshot": {"type": "cnnrnn", "trainable": False, "pruning": None, "optim": "adam", "fewshot": True},
    # "Single_MLP": {"type": "singlemlp", "optim": "adam"},
    "Twohidden_MLP": dict(type="twohiddenmlp", optim="adam"),
    # "Logistic_Regression": {"type": "logistic", "optim": "adam"},
    # --- Hungarian pruning: every init x ref combination ---
    # "Hungarian_DrosoInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "droso"},
    # "Hungarian_DrosoInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randsparse"},
    # "Hungarian_DrosoInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "droso", "ref": "randstructure"},
    # "Hungarian_RandInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "droso"},
    # "Hungarian_RandInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "randsparse"},
    # "Hungarian_RandInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "random", "ref": "randstructure"},
    # "Hungarian_RandSparseInit_DrosoRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "droso"},
    # "Hungarian_RandSparseInit_RandSparseRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randsparse"},
    # "Hungarian_RandSparseInit_RandStructureRef": {"type": "basernn", "trainable": True, "pruning": "hungarian", "optim": "adam", "init": "randsparse", "ref": "randstructure"}
}

def get_input_shape(model_type):
    """Return the dummy-input shape used to profile a model of ``model_type``.

    The cnnrnn, singlemlp, twohiddenmlp and logistic families are profiled with
    an explicit channel axis, ``(1, 1, 28, 28)``; every other family is
    profiled with ``(1, 28, 28)``.
    """
    with_channel_axis = ("cnnrnn", "singlemlp", "twohiddenmlp", "logistic")
    if model_type not in with_channel_axis:
        return (1, 28, 28)
    return (1, 1, 28, 28)

def prepare_input(data, model):
    """Adapt a batch to the layout ``model`` consumes.

    A CNNRNN receives ``data`` unchanged. Every other model receives
    ``data.squeeze(1)``, which drops dimension 1 when it has size 1 (presumably
    the MNIST channel axis of a ``(batch, 1, 28, 28)`` batch) and is otherwise
    a no-op.
    """
    if not isinstance(model, CNNRNN):
        data = data.squeeze(1)
    return data

def train_epoch(model, optimizer, criterion, train_loader, flops_per_sample, cumulative_flops):
    """Train ``model`` for one pass over ``train_loader`` while accumulating a FLOP estimate.

    Args:
        model: network to train; each batch is adapted with ``prepare_input``.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``.
        train_loader: iterable of ``(data, target)`` batches.
        flops_per_sample: estimated forward-pass FLOPs for a single sample.
        cumulative_flops: FLOP count carried over from earlier epochs.

    Returns:
        ``(mean_loss, accuracy, flops_acc_pairs, cumulative_flops)``.
        ``mean_loss`` and ``accuracy`` are sample-weighted over the epoch
        (both 0.0 if the loader yielded no samples). ``flops_acc_pairs`` holds
        ``(cumulative_flops, running_train_accuracy)`` snapshots taken every
        100 batches. ``cumulative_flops`` is the updated running total.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    flops_acc_pairs = []

    with tqdm(train_loader, unit="batch", desc="Training") as pbar:
        for batch_idx, (data, target) in enumerate(pbar):
            data = prepare_input(data.to(device), model)
            target = target.to(device)
            batch_size = data.size(0)

            # Training cost approximated as 3x the forward cost
            # (forward pass plus roughly 2x for the backward pass).
            cumulative_flops += flops_per_sample * batch_size * 3

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Read the scalar once: .item() forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size

            if (batch_idx + 1) % 100 == 0:
                flops_acc_pairs.append((cumulative_flops, correct / total))

            pbar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'acc': f"{correct/total:.2%}",
                'FLOPs': f"{cumulative_flops/1e9:.2f}G"
            })

    # Dynamic pruning: runs once, after the epoch's final optimizer step.
    if isinstance(model, BaseRNN) and model.pruning_method == "hungarian":
        model.apply_hungarian_pruning()
    elif isinstance(model, CWSRNN):
        # NOTE(review): this fires for every CWSRNN, including configs whose
        # "pruning" is None (non_zero_count=None) -- confirm that
        # apply_drosophila_pruning is a no-op in that case.
        model.apply_drosophila_pruning()

    if total == 0:
        # Empty loader: avoid ZeroDivisionError and report neutral metrics.
        return 0.0, 0.0, flops_acc_pairs, cumulative_flops
    return total_loss / total, correct / total, flops_acc_pairs, cumulative_flops

def evaluate(model, test_loader):
    """Measure accuracy on ``test_loader``; for BaseRNN/CWSRNN also average hidden activations.

    For BaseRNN and CWSRNN the recurrence is unrolled here instead of calling
    ``model(data)``, so the hidden state after each of 10 steps can be recorded.
    Step 1 feeds the flattened image; steps 2-10 feed a zero input. Every
    other model is evaluated with ``model(data)``.

    Returns:
        ``(accuracy, activations)``.
        ``accuracy`` is 0.0 for an empty loader.
        ``activations`` is a NumPy array holding the per-step hidden
        activation averaged over every test sample (presumably shape
        ``(10, hidden_size)``). It is None when no recurrent activations were
        recorded.
    """
    model.eval()
    correct, total = 0, 0
    # Running per-sample sum of activations plus the number of samples summed.
    # Averaging per-batch means instead would over-weight a smaller final batch.
    activation_sum = None
    activation_count = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = prepare_input(data.to(device), model)
            target = target.to(device)

            if isinstance(model, (BaseRNN, CWSRNN)):
                batch_size = data.size(0)
                r_t = torch.zeros(batch_size, model.hidden_size, device=device)
                act_list = []

                # Step 1: drive the network with the flattened image.
                E_t = model.input_to_hidden(data.view(batch_size, -1))
                # CWSRNN: effective recurrent weight is C * W with row i scaled
                # by s[i] (assuming s is 1-D; unsqueeze(1) makes it a column).
                W_eff = model.W if isinstance(model, BaseRNN) else (model.C * model.W * model.s.unsqueeze(1))
                # The previous state is added back in alongside the recurrent term.
                r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                act_list.append(r_t)

                # Steps 2-10: let the recurrence evolve with zero external input.
                zero_input = torch.zeros(batch_size, model.input_size, device=device)
                for _ in range(9):
                    E_t = model.input_to_hidden(zero_input)
                    r_t = torch.relu(r_t @ W_eff + E_t + r_t)
                    act_list.append(r_t)

                # Stack to (steps, batch, hidden) and sum over the batch axis.
                batch_sum = torch.stack(act_list, dim=0).sum(dim=1)
                activation_sum = batch_sum if activation_sum is None else activation_sum + batch_sum
                activation_count += batch_size

                output = model.hidden_to_output(r_t)
            else:
                output = model(data)

            correct += output.argmax(dim=1).eq(target).sum().item()
            total += target.size(0)

    if activation_sum is not None and activation_count > 0:
        activations = (activation_sum / activation_count).cpu().numpy()
    else:
        activations = None

    accuracy = correct / total if total else 0.0
    return accuracy, activations

def _load_mnist_loaders(config):
    """Return ``(train_loader, test_loader)`` over normalized MNIST.

    When ``config["fewshot"]`` is truthy, the training set is replaced by a
    class-balanced random subset holding 5% of the samples: 300 per class for
    the standard 60k-image split, drawn with the global NumPy RNG (unseeded
    here).
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)

    if config.get("fewshot", False):
        num_classes = 10
        samples_per_class = int(len(train_set) * 0.05 / num_classes)
        targets = np.array(train_set.targets)
        indices = []
        for cls in range(num_classes):
            cls_indices = np.where(targets == cls)[0]
            indices.extend(np.random.choice(cls_indices, samples_per_class, replace=False))
        train_set = torch.utils.data.Subset(train_set, indices)
        print(f"Few-shot training set size: {len(train_set)} samples")

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
    return train_loader, test_loader


def _random_sparse_like(W_droso):
    """Return a Gaussian matrix shaped like ``W_droso`` with the same nonzero count at uniformly random positions.

    Assumes ``W_droso`` is a NumPy array (``.size`` is used as an element count).
    """
    non_zero_count = np.count_nonzero(W_droso)
    mask = torch.zeros(W_droso.shape, dtype=torch.float32)
    mask.view(-1)[torch.randperm(W_droso.size)[:non_zero_count]] = 1
    return torch.randn(W_droso.shape) * mask


def _random_structure_like(W_droso):
    """Return a Gaussian matrix restricted to ``W_droso``'s nonzero pattern."""
    mask = (torch.tensor(W_droso) != 0).float()
    return torch.randn(W_droso.shape) * mask


def _build_model(config, W_droso, W_init, W_ref, non_zero):
    """Instantiate the network named by ``config['type']`` (not yet moved to a device).

    Raises:
        ValueError: if ``config['type']`` is not a known model family.
    """
    model_type = config['type']
    n_hidden = W_droso.shape[0]
    if model_type == 'basernn':
        return BaseRNN(
            784, n_hidden, 10,
            W_init=W_init,
            W_ref=W_ref,
            trainable=config['trainable'],
            pruning_method=config.get('pruning')
        )
    if model_type == 'cwsrnn':
        return CWSRNN(
            784, n_hidden, 10, W_droso,
            train_W=config['train_W'],
            train_C=config.get('train_C', False),
            non_zero_count=non_zero if config.get('pruning') == 'drosophila' else None
        )
    if model_type == 'cnnrnn':
        return CNNRNN(torch.tensor(W_droso))
    if model_type == 'singlemlp':
        return SingleMLP(784, n_hidden, 10)
    if model_type == 'twohiddenmlp':
        # Hidden width is hard-coded to 1360 rather than taken from W_droso.
        return TwohiddenMLP(784, 1360, 10)
    if model_type == 'logistic':
        return LogisticRegression(784, 10)
    raise ValueError(f"Unknown model type: {model_type!r}")


def train_experiment(exp_id, config):
    """Train one configured model on MNIST for 10 epochs and pickle its metrics.

    Args:
        exp_id: experiment name; results are written to ``f"{exp_id}.pkl"``.
        config: experiment options. Keys read: type, optim, trainable,
            train_W, train_C, pruning, init, ref, fewshot.

    The pickled dict holds:
        - per-epoch train loss, train accuracy and test accuracy;
        - ``(cumulative_flops, running_train_accuracy)`` snapshots;
        - the total FLOP estimate;
        - the activations returned by the final evaluation.

    Raises:
        ValueError: if ``config["type"]`` is not a known model family.
    """
    train_loader, test_loader = _load_mnist_loaders(config)

    # Load the connectivity CSV exactly once. With drosophila pruning the
    # loader also returns a nonzero count, which CWSRNN receives.
    droso_csv = './data/ad_connectivity_matrix.csv'
    non_zero = None
    if config.get('pruning') == 'drosophila':
        W_droso, non_zero = load_drosophila_matrix(droso_csv, apply_pruning=True)
    else:
        W_droso = load_drosophila_matrix(droso_csv)

    init_kind = config.get('init')
    if init_kind == 'random':
        W_init = torch.randn(W_droso.shape[0], W_droso.shape[0])
    elif init_kind == 'droso':
        W_init = W_droso
    elif init_kind == 'randsparse':
        W_init = _random_sparse_like(W_droso)
    elif init_kind == 'randstructure':
        W_init = _random_structure_like(W_droso)
    else:
        # No explicit init: frozen models start from the connectome; trainable
        # ones pass W_init=None to the model.
        W_init = W_droso if not config.get('trainable', True) else None

    ref_kind = config.get('ref')
    if ref_kind == 'droso':
        W_ref = W_droso
    elif ref_kind == 'randsparse':
        W_ref = _random_sparse_like(W_droso)
    elif ref_kind == 'randstructure':
        W_ref = _random_structure_like(W_droso)
    else:
        W_ref = None

    model = _build_model(config, W_droso, W_init, W_ref, non_zero)
    model.to(device)

    if config['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=1e-5)
    else:
        optimizer = FISTAOptimizer(model.parameters(), lr=1e-3, lambda_l1=1e-5)

    # thop counts multiply-accumulates for one dummy sample; 1 MAC ~= 2 FLOPs.
    dummy_input = torch.randn(get_input_shape(config['type'])).to(device)
    macs, _ = profile(model, inputs=(dummy_input,))
    flops_forward = macs * 2

    results = {
        "epoch_train_loss": [], "epoch_train_acc": [],
        "epoch_test_acc": [], "flops_acc": [],
        "total_flops": 0, "activations": None
    }
    cumulative_flops = 0
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        epoch_loss, epoch_acc, flops_pairs, cumulative_flops = train_epoch(
            model, optimizer, criterion, train_loader, flops_forward, cumulative_flops
        )
        test_acc, activations = evaluate(model, test_loader)

        results["epoch_train_loss"].append(epoch_loss)
        results["epoch_train_acc"].append(epoch_acc)
        results["epoch_test_acc"].append(test_acc)
        results["flops_acc"].extend(flops_pairs)
        # Only the most recent evaluation's activations are kept.
        results["activations"] = activations
        print(f"Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    results["total_flops"] = cumulative_flops
    with open(f"{exp_id}.pkl", "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    # Run every enabled experiment in registry order, each under a banner.
    banner = "=" * 30
    for name, cfg in experiments.items():
        print(f"\n{banner}\nTraining {name}\n{banner}")
        train_experiment(name, cfg)


Training Twohidden_MLP
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.


Training: 100%|██████████| 938/938 [00:12<00:00, 74.69batch/s, loss=0.2356, acc=83.84%, FLOPs=1054.60G]


Epoch 1 | Test Acc: 91.02%


Training: 100%|██████████| 938/938 [00:12<00:00, 76.51batch/s, loss=0.2103, acc=91.43%, FLOPs=2109.20G]


Epoch 2 | Test Acc: 92.73%


Training: 100%|██████████| 938/938 [00:11<00:00, 78.93batch/s, loss=0.1190, acc=93.02%, FLOPs=3163.80G]


Epoch 3 | Test Acc: 93.61%


Training: 100%|██████████| 938/938 [00:11<00:00, 79.31batch/s, loss=0.1063, acc=93.94%, FLOPs=4218.39G]


Epoch 4 | Test Acc: 94.35%


Training: 100%|██████████| 938/938 [00:12<00:00, 76.31batch/s, loss=0.1694, acc=94.73%, FLOPs=5272.99G]


Epoch 5 | Test Acc: 94.89%


Training: 100%|██████████| 938/938 [00:11<00:00, 78.20batch/s, loss=0.1674, acc=95.33%, FLOPs=6327.59G]


Epoch 6 | Test Acc: 95.38%


Training: 100%|██████████| 938/938 [00:11<00:00, 79.88batch/s, loss=0.1292, acc=95.75%, FLOPs=7382.19G]


Epoch 7 | Test Acc: 95.73%


Training: 100%|██████████| 938/938 [00:11<00:00, 79.94batch/s, loss=0.2426, acc=96.19%, FLOPs=8436.79G]


Epoch 8 | Test Acc: 96.16%


Training: 100%|██████████| 938/938 [00:11<00:00, 79.47batch/s, loss=0.1874, acc=96.54%, FLOPs=9491.39G]


Epoch 9 | Test Acc: 96.37%


Training: 100%|██████████| 938/938 [00:11<00:00, 78.63batch/s, loss=0.1532, acc=96.84%, FLOPs=10545.98G]


Epoch 10 | Test Acc: 96.55%
