# In [None]:
import os
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import pickle
import numpy as np
from tqdm import tqdm
from net import *
from connectome_utils import *
from sklearn.model_selection import StratifiedShuffleSplit

# Load config
with open("config.yaml", "r", encoding="utf-8") as f:
    # safe_load returns None for an empty file; fall back to an empty mapping so
    # every lookup below uses its default instead of raising AttributeError.
    config_data = yaml.safe_load(f) or {}

# Global parameters
signed = config_data.get("signed", True)
num_trials = config_data.get("num_trials", 10)
num_epochs = config_data.get("num_epochs", 10)
batch_size = config_data.get("batch_size", 64)
learning_rate = config_data.get("learning_rate", 0.001)
# `experiments:` left empty in YAML yields None; treat it as "no experiments".
experiments = config_data.get("experiments") or {}

# Few-shot settings (a null `fewshot:` section behaves like a missing one)
fewshot_config = config_data.get("fewshot") or {}
fewshot_enabled = fewshot_config.get("enabled", False)
fewshot_samples = fewshot_config.get("samples", 60)
fewshot_batch_size = fewshot_config.get("batch_size", 10)
if fewshot_enabled:
    # Replace every experiment with a few-shot variant. The id records the
    # per-class sample count ("<id>_fewshot_<samples>"). The config gains the
    # "fewshot" and "fewshot_batch_size" keys read by the training loop.
    fewshot_experiments = {
        f"{exp_id}_fewshot_{fewshot_samples}": {
            **exp_config,
            "fewshot": fewshot_samples,
            "fewshot_batch_size": fewshot_batch_size,
        }
        for exp_id, exp_config in experiments.items()
    }
    experiments = fewshot_experiments

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create stratified few-shot subset
def create_fewshot_subset(dataset, seed, samples_per_class=60):
    """Draw a class-stratified random subset of ``dataset``.

    Args:
        dataset: A dataset exposing a ``targets`` sequence (list, array or
            tensor) of class labels, e.g. ``torchvision.datasets.MNIST``.
        seed: ``random_state`` for the split; the same seed gives the same subset.
        samples_per_class: Number of examples requested per class. The total
            requested is ``samples_per_class * n_classes``, where ``n_classes``
            is the number of distinct labels in ``dataset.targets``.

    Returns:
        ``torch.utils.data.Subset`` of ``dataset``. If the requested total is
        at least the dataset size, the Subset covers the whole dataset.
    """
    targets = np.array(dataset.targets)
    n_classes = len(np.unique(targets))
    # Pass an exact integer size. A float fraction is turned back into a count
    # by flooring ``fraction * n``, which can come out one sample short.
    n_train = int(samples_per_class) * n_classes
    if n_train >= len(targets):
        # StratifiedShuffleSplit needs a non-empty complement; just take everything.
        return torch.utils.data.Subset(dataset, np.arange(len(targets)))
    sss = StratifiedShuffleSplit(n_splits=1, train_size=n_train, random_state=seed)
    indices, _ = next(sss.split(np.zeros_like(targets), targets))
    return torch.utils.data.Subset(dataset, indices)

def prepare_input(data, model):
    """Return ``data`` shaped for ``model``.

    ``CNNRNN`` models get ``data`` unchanged. Every other model gets
    ``data.squeeze(1)``, which drops axis 1 when it has size 1 and is a
    no-op otherwise.
    """
    if isinstance(model, CNNRNN):
        return data
    return data.squeeze(1)

def get_weight_matrix(base, mode, trainable=True):
    """Return the initial weight matrix for initialisation ``mode``.

    Modes:
        'droso'         -- ``base`` itself (the same object, no copy).
        'random'        -- dense standard-normal float32 matrix, shape of ``base``.
        'randsparse'    -- standard-normal values at ``count_nonzero(base)``
                           uniformly chosen positions, zeros elsewhere.
        'randstructure' -- absolute standard-normal values on ``base``'s
                           non-zero pattern, zeros elsewhere.
        anything else   -- ``None`` when ``trainable`` is truthy, else ``base``.

    The random modes draw from NumPy's global RNG.
    """
    if mode == 'droso':
        return base

    shape = base.shape

    if mode == 'random':
        return np.random.randn(*shape).astype(np.float32)

    if mode == 'randsparse':
        # Keep the connectome's number of non-zeros, but scatter them uniformly.
        n_keep = np.count_nonzero(base)
        keep_mask = np.zeros(shape, dtype=np.float32)
        chosen = np.random.permutation(keep_mask.size)[:n_keep]
        keep_mask.flat[chosen] = 1
        return np.random.randn(*shape).astype(np.float32) * keep_mask

    if mode == 'randstructure':
        # Same sparsity pattern as the connectome, fresh non-negative magnitudes.
        support = (base != 0).astype(np.float32)
        return np.abs(np.random.randn(*shape).astype(np.float32)) * support

    return None if trainable else base


def load_base_matrix(cfg_data, signed, config):
    """Load the base connectivity matrix from the signed or unsigned CSV.

    ``cfg_data["csv_paths"]`` must hold ``"signed"`` and ``"unsigned"`` paths;
    ``signed`` picks one and is also forwarded to the loader. When
    ``config['pruning']`` equals the string ``'drosophila'``, the loader is
    called with ``apply_pruning=True``. It must then return a pair, and only
    the first element is kept.
    """
    csv_path = cfg_data["csv_paths"]["signed" if signed else "unsigned"]
    # NOTE(review): initialize_model treats 'pruning' as a dict; a string value
    # like 'drosophila' would fail there for drosophilarnn models — confirm
    # whether this branch is still reachable for those configs.
    prune_on_load = config.get('pruning') == 'drosophila'
    loaded = load_drosophila_matrix(csv_path, apply_pruning=prune_on_load, signed=signed)
    if not prune_on_load:
        return loaded
    matrix, _discarded = loaded
    return matrix

def load_connectivity_info(mode, W_matrix, SIO, cfg_data):
    """Build the connectivity description handed to the model.

    For ``mode == 'droso'`` the data is read from the *signed* connectivity CSV
    plus ``cfg_data["annotation_path"]``. The SIO-specific loader is used when
    ``SIO`` is truthy; ``W_matrix`` is not used in this case. For any other mode
    the result of ``load_random_matrix(W_matrix, 29)`` is returned.
    """
    if mode != 'droso':
        # 29 presumably matches the 29 sensory-visual neuron IDs the droso
        # loaders report in the run log — confirm against load_random_matrix.
        return load_random_matrix(W_matrix, 29)

    # NOTE(review): always the signed CSV, independent of the global `signed`
    # flag — confirm this is intended.
    loader = load_sio_connectivity_data if SIO else load_connectivity_data
    return loader(
        connectivity_path=cfg_data["csv_paths"]["signed"],
        annotation_path=cfg_data["annotation_path"],
    )

def load_datasets(transform):
    """Return ``(mnist_train_set, mnist_test_loader)`` rooted at ``./data``.

    The training split is downloaded if missing. The test split is wrapped in
    an unshuffled ``DataLoader`` with batch size 256.
    """
    data_root = './data'
    # Build the train split first: it is the call that carries download=True,
    # and the test split is presumably served from those downloaded files.
    train_set = datasets.MNIST(data_root, train=True, download=True, transform=transform)
    test_split = datasets.MNIST(data_root, train=False, transform=transform)
    return train_set, torch.utils.data.DataLoader(test_split, batch_size=256, shuffle=False)

def initialize_model(config, W_droso):
    """Build the (untrained) model described by one experiment config.

    Args:
        config: Experiment config dict. ``config['type']`` selects the model.
            For ``'drosophilarnn'`` the optional keys ``init``, ``ref``,
            ``trainable`` (default True), ``SIO`` (default False) and
            ``pruning`` (a dict with an ``enable`` flag) are also read.
        W_droso: Base connectivity matrix, as returned by ``load_base_matrix``.

    Returns:
        A ``DrosophilaRNN``, ``SingleMLP``, ``TwohiddenMLP`` or ``StaticMLP``.

    Raises:
        ValueError: If ``config['type']`` is not a known model type.
    """
    model_type = config['type']

    if model_type == 'drosophilarnn':
        # Resolve optional flags once, so the weight-matrix helper and the
        # model constructor receive the same values (the constructor
        # previously got None for a missing 'trainable'/'SIO' key).
        init_mode = config.get('init')
        trainable = config.get('trainable', True)
        use_sio = config.get('SIO', False)

        W_init = get_weight_matrix(W_droso, init_mode, trainable)
        conn = load_connectivity_info(init_mode, W_init, use_sio, config_data)

        # `pruning: null` in YAML yields None; treat it like "no pruning".
        pruning_cfg = config.get('pruning') or {}
        pruning_enabled = pruning_cfg.get("enable", False)

        # The reference connectivity is only needed when pruning towards it.
        ref = None
        if pruning_enabled:
            ref_mode = config.get('ref')
            W_ref = get_weight_matrix(W_droso, ref_mode)
            ref = load_connectivity_info(ref_mode, W_ref, use_sio, config_data)

        return DrosophilaRNN(
            input_dim=784,
            num_classes=10,
            conn_weights=conn,
            ref_weights=ref,
            trainable=trainable,
            pruning=pruning_enabled,
            SIO=use_sio,
            pruning_cfg=pruning_cfg
        )
    if model_type == 'singlemlp':
        # Hidden width follows the connectome size.
        return SingleMLP(784, W_droso.shape[0], 10)
    if model_type == 'twohiddenmlp':
        return TwohiddenMLP(784, 300, 10)
    if model_type == 'staticmlp':
        return StaticMLP(784, 300, 10)
    raise ValueError(f"Unknown model type: {model_type}")

# Train one epoch
def train_epoch(model, optimizer, criterion, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    After the pass, if the model has pruning enabled, the matching constraint
    pruning method (``'structure'`` or ``'sparsity'``) is applied. Its returned
    dict is printed.

    Returns:
        ``(mean_loss, accuracy, similarity_dict)``. ``mean_loss`` is the
        sample-weighted average loss and ``accuracy`` the training accuracy over
        the epoch; both are 0.0 for an empty loader. ``similarity_dict`` is what
        the pruning method returned, or ``{}`` when no pruning ran.
    """
    model.train()
    # Use the same input layout as `evaluate`: only DrosophilaRNN gets flat
    # vectors. Every model is evaluated on that layout before its first
    # training epoch, so it is the layout each model is known to accept.
    flatten = isinstance(model, DrosophilaRNN)
    total_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(train_loader, unit="batch", desc="Training")
    for data, target in pbar:
        data = prepare_input(data.to(device), model)
        if flatten:
            data = data.view(data.size(0), -1)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_n = data.size(0)
        total_loss += loss.item() * batch_n
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += batch_n
        pbar.set_postfix(loss=f"{loss.item():.4f}", train_acc=f"{correct / total:.2%}")

    similarity_dict = {}
    if hasattr(model, "apply_structure_constraint_pruning") and model.pruning and model.pruning_constraint == 'structure':
        similarity_dict = model.apply_structure_constraint_pruning()
    elif hasattr(model, "apply_sparsity_constraint_pruning") and model.pruning and model.pruning_constraint == 'sparsity':
        similarity_dict = model.apply_sparsity_constraint_pruning()

    print(similarity_dict)
    if not total:
        # Empty loader: avoid ZeroDivisionError and report neutral metrics.
        return 0.0, 0.0, similarity_dict
    return total_loss / total, correct / total, similarity_dict

# Evaluate model accuracy on the test set
def evaluate(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    DrosophilaRNN inputs are flattened to ``(batch, -1)``. Returns 0 when the
    loader yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = prepare_input(inputs.to(device), model)
            labels = labels.to(device)
            if isinstance(model, DrosophilaRNN):
                inputs = inputs.view(inputs.size(0), -1)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)
    return n_correct / n_seen if n_seen > 0 else 0

# Run training loop and record results
def _build_train_loader(config, full_train_set, epoch, batch_size, fewshot_batch_size):
    """Return this epoch's shuffled training DataLoader (few-shot or full set)."""
    if "fewshot" not in config:
        return torch.utils.data.DataLoader(full_train_set, batch_size=batch_size, shuffle=True)
    # NOTE(review): the subset is seeded by the epoch index only, so every
    # trial sees the same sequence of few-shot subsets — confirm intended.
    subset = create_fewshot_subset(full_train_set, epoch, config["fewshot"])
    loader_batch = config.get("fewshot_batch_size", fewshot_batch_size)
    return torch.utils.data.DataLoader(subset, batch_size=loader_batch, shuffle=True)


def _count_nonzero_params(model):
    """Count non-zero parameter entries per direct child module, plus 'total'.

    Per-child counts include only parameters registered directly on that child
    (``recurse=False``). Parameters of deeper descendants, or of ``model``
    itself, appear only in ``'total'``, which covers every parameter.
    """
    counts = {
        child_name: sum(torch.count_nonzero(p).item() for p in child.parameters(recurse=False))
        for child_name, child in model.named_children()
    }
    counts['total'] = sum(torch.count_nonzero(p).item() for p in model.parameters())
    return counts


def run_training_loop(model, config, full_train_set, test_loader, trial_num, num_epochs, batch_size, fewshot_batch_size):
    """Train ``model`` for ``num_epochs`` epochs with Adam and record metrics.

    Test accuracy is recorded before training (epoch 0) and after each epoch,
    so ``epoch_test_acc`` has ``num_epochs + 1`` entries. The other lists have
    one entry per epoch. The non-zero parameter counts are kept for later FLOP
    estimation.

    Returns:
        Dict with keys ``epoch_train_loss``, ``epoch_train_acc``,
        ``epoch_test_acc``, ``submodules_nonzero`` and ``similarity_dict``.
    """
    results = {
        key: []
        for key in (
            "epoch_train_loss",
            "epoch_train_acc",
            "epoch_test_acc",
            "submodules_nonzero",
            "similarity_dict",
        )
    }
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    baseline_acc = evaluate(model, test_loader)
    results["epoch_test_acc"].append(baseline_acc)
    print(f"Trial {trial_num} | Epoch 0 | Test Acc: {baseline_acc:.2%}")

    for epoch in range(num_epochs):
        loader = _build_train_loader(config, full_train_set, epoch, batch_size, fewshot_batch_size)
        train_loss, train_acc, similarity = train_epoch(model, optimizer, criterion, loader)
        results["epoch_train_loss"].append(train_loss)
        results["epoch_train_acc"].append(train_acc)
        results["similarity_dict"].append(similarity)

        test_acc = evaluate(model, test_loader)
        nonzero_counts = _count_nonzero_params(model)
        results["submodules_nonzero"].append(nonzero_counts)
        results["epoch_test_acc"].append(test_acc)

        print(f"submodule nonzero values: {nonzero_counts}")
        print(f"Trial {trial_num} | Epoch {epoch + 1} | Test Acc: {test_acc:.2%}")

    return results

def save_results(exp_id, config, trial_num, results, signed, results_dir="results"):
    """Pickle ``results`` to ``<results_dir>/<exp_id>_trial<trial_num>[.signed].pkl``.

    Args:
        exp_id: Experiment identifier. Few-shot ids built by the module-level
            config expansion already end in ``_fewshot_<n>``, so no extra tag
            is added here.
        config: Experiment config; unused, kept for call compatibility.
        trial_num: Trial number embedded in the filename.
        results: Any picklable object (the metrics dict from training).
        signed: If truthy, ``.signed`` is inserted before ``.pkl``.
        results_dir: Output directory, created if missing. Defaults to
            ``"results"`` relative to the working directory.
    """
    os.makedirs(results_dir, exist_ok=True)
    extension = ".signed.pkl" if signed else ".pkl"
    path = os.path.join(results_dir, f"{exp_id}_trial{trial_num}{extension}")
    with open(path, "wb") as f:
        pickle.dump(results, f)

# Full experiment run
def train_experiment(exp_id, config, trial_num):
    """Run one trial of one experiment end to end and pickle the results.

    The function:
    - prints the experiment configuration;
    - seeds torch and NumPy from ``trial_num``;
    - loads MNIST and the base connectivity matrix, and builds the model;
    - trains it with ``run_training_loop`` using the module-level
      hyper-parameters;
    - writes the results with ``save_results``.
    """
    print("========================================")
    print(f"Starting Experiment: {exp_id} Trial {trial_num}")
    print("Experiment configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print("========================================\n")

    # Seed both RNGs from the trial number so each trial is reproducible.
    # NumPy drives the random weight initialisations in get_weight_matrix.
    # Calling np.random.seed() with no argument would reseed from OS entropy
    # and make those runs unrepeatable.
    torch.manual_seed(trial_num)
    np.random.seed(trial_num)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    full_train_set, test_loader = load_datasets(transform)
    W_droso = load_base_matrix(config_data, signed, config)
    model = initialize_model(config, W_droso)
    model.to(device)
    results = run_training_loop(model, config, full_train_set, test_loader, trial_num,
                                num_epochs, batch_size, fewshot_batch_size)
    save_results(exp_id, config, trial_num, results, signed)

if __name__ == "__main__":
    # Every configured experiment, each repeated for trials 1..num_trials,
    # in config order.
    trial_ids = range(1, num_trials + 1)
    runs = [(name, cfg, t) for name, cfg in experiments.items() for t in trial_ids]
    for exp_id, config, trial_num in runs:
        print(f"\n=== Training {exp_id} Trial {trial_num} ===")
        train_experiment(exp_id, config, trial_num)



=== Training Learnable_DPU_fewshot_120 Trial 1 ===
Starting Experiment: Learnable_DPU_fewshot_120 Trial 1
Experiment configuration:
  type: drosophilarnn
  trainable: True
  pruning: {'enable': True, 'constraint': 'structure', 'lambda_reg': 0.0002, 'max_iter': 200, 'fista_threshold': 0.01, 'fista_gamma': 0.7}
  init: droso
  ref: droso
  fewshot: 120
  fewshot_batch_size: 17

Found 29 sensory-visual neuron IDs
Found 29 sensory-visual neuron IDs
Trial 1 | Epoch 0 | Test Acc: 9.73%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.24batch/s, loss=1.8566, train_acc=13.17%]



 Similarity Metrics: Overlap ratio: 0.9997, Non-zero count of X: 63525, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 63525, L1 error: 266.1201, L2 error: 1.6461
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 115530}
Trial 1 | Epoch 1 | Test Acc: 10.10%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.51batch/s, loss=2.1251, train_acc=15.33%]



 Similarity Metrics: Overlap ratio: 0.9753, Non-zero count of X: 61977, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61977, L1 error: 362.9801, L2 error: 2.2061
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113982}
Trial 1 | Epoch 2 | Test Acc: 9.79%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.19batch/s, loss=2.0330, train_acc=19.75%]



 Similarity Metrics: Overlap ratio: 0.9682, Non-zero count of X: 61527, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61527, L1 error: 431.2100, L2 error: 2.6094
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113532}
Trial 1 | Epoch 3 | Test Acc: 9.79%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.17batch/s, loss=2.0923, train_acc=19.25%]



 Similarity Metrics: Overlap ratio: 0.9655, Non-zero count of X: 61353, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61353, L1 error: 483.3000, L2 error: 2.9465
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113358}
Trial 1 | Epoch 4 | Test Acc: 9.82%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.07batch/s, loss=1.9695, train_acc=19.67%]



 Similarity Metrics: Overlap ratio: 0.9649, Non-zero count of X: 61317, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61317, L1 error: 523.6900, L2 error: 3.2529
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113322}
Trial 1 | Epoch 5 | Test Acc: 9.82%


Training: 100%|██████████| 71/71 [00:06<00:00, 11.64batch/s, loss=1.8460, train_acc=21.00%]



 Similarity Metrics: Overlap ratio: 0.9647, Non-zero count of X: 61300, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61300, L1 error: 572.0800, L2 error: 3.6243
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113305}
Trial 1 | Epoch 6 | Test Acc: 9.80%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.18batch/s, loss=1.9672, train_acc=21.17%]



 Similarity Metrics: Overlap ratio: 0.9651, Non-zero count of X: 61325, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61325, L1 error: 605.8600, L2 error: 3.9145
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113330}
Trial 1 | Epoch 7 | Test Acc: 9.80%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.40batch/s, loss=2.4278, train_acc=21.33%]



 Similarity Metrics: Overlap ratio: 0.9650, Non-zero count of X: 61321, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61321, L1 error: 659.2700, L2 error: 4.3216
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113326}
Trial 1 | Epoch 8 | Test Acc: 9.80%


Training: 100%|██████████| 71/71 [00:05<00:00, 11.92batch/s, loss=2.0277, train_acc=19.83%]



 Similarity Metrics: Overlap ratio: 0.9635, Non-zero count of X: 61226, Non-zero count of W_ref: 63545, Overlap of non-zero positions: 61226, L1 error: 746.4200, L2 error: 4.8701
submodule nonzero values: {'input_proj': 22765, 'output_layer': 29240, 'activation': 0, 'total': 113231}
Trial 1 | Epoch 9 | Test Acc: 9.82%


Training: 100%|██████████| 71/71 [00:05<00:00, 12.26batch/s, loss=1.8840, train_acc=23.75%]
