In [None]:
import os
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import pickle
import numpy as np
from tqdm import tqdm
from net import *
from connectome_utils import *
from sklearn.model_selection import StratifiedShuffleSplit

# ---------------------------------------------------------------------------
# Configuration: read once from config.yaml when this module is executed.
# Every value below falls back to the shown default when its key is absent.
# ---------------------------------------------------------------------------
with open("config.yaml", "r") as f:
    config_data = yaml.safe_load(f)

# Global run parameters
signed = config_data.get("signed", True)
sio = config_data.get("sio", True)
num_trials = config_data.get("num_trials", 10)
num_epochs = config_data.get("num_epochs", 10)
batch_size = config_data.get("batch_size", 64)
learning_rate = config_data.get("learning_rate", 0.001)
experiments = config_data.get("experiments", {})

# Few-shot settings live under the optional "fewshot" mapping.
fewshot_config = config_data.get("fewshot", {})
fewshot_enabled = fewshot_config.get("enabled", False)
fewshot_samples = fewshot_config.get("samples", 60)
fewshot_batch_size = fewshot_config.get("batch_size", 10)

if fewshot_enabled:
    # Replace every experiment by a few-shot variant: a shallow copy of its
    # config with the few-shot keys added, stored under a suffixed id.
    fewshot_experiments = {
        f"{name}_fewshot_{fewshot_samples}": {
            **spec,
            "fewshot": fewshot_samples,
            "fewshot_batch_size": fewshot_batch_size,
        }
        for name, spec in experiments.items()
    }
    experiments = fewshot_experiments

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create stratified few-shot subset
def create_fewshot_subset(dataset, seed, samples_per_class=60, num_classes=None):
    """Return a class-stratified random subset of ``dataset``.

    Draws ``samples_per_class * num_classes`` indices with sklearn's
    ``StratifiedShuffleSplit``, so class proportions of ``dataset.targets`` are
    kept. The same ``seed`` always yields the same subset.

    Args:
        dataset: dataset exposing a ``targets`` sequence/tensor of class labels
            (converted with ``np.asarray`` -- assumes CPU labels; confirm for
            non-torchvision datasets).
        seed: ``random_state`` passed to ``StratifiedShuffleSplit``.
        samples_per_class: target number of samples per class.
        num_classes: number of classes; defaults to the number of distinct
            labels in ``dataset.targets`` (10 for MNIST).

    Returns:
        ``torch.utils.data.Subset`` over the selected indices. If the request
        covers the whole dataset, every index is returned.
    """
    targets = np.asarray(dataset.targets)
    if num_classes is None:
        num_classes = len(np.unique(targets))
    # Pass an exact integer count: a float ratio is floored by sklearn and can
    # lose a sample to rounding.
    n_train = int(samples_per_class) * int(num_classes)
    if n_train >= len(targets):
        # sklearn refuses an empty test split; the "subset" is everything.
        return torch.utils.data.Subset(dataset, np.arange(len(targets)))
    sss = StratifiedShuffleSplit(n_splits=1, train_size=n_train, random_state=seed)
    indices, _ = next(sss.split(np.zeros_like(targets), targets))
    return torch.utils.data.Subset(dataset, indices)

def get_weight_matrix(base, mode):
    """Build an initial weight matrix shaped like the connectome ``base``.

    Modes:
        'random'         -- dense Gaussian scaled by 1/sqrt(base.shape[0]),
                            cast to float32. (This is LeCun-style scaling;
                            He init for ReLU would use sqrt(2/fan_in).)
        'droso'          -- ``base`` itself, returned unchanged (same object).
        'permuted_droso' -- the non-zero values of ``base`` (float32), shuffled
                            and scattered to uniformly random positions.
        'randsparse'     -- Gaussian values with std sqrt(density) on a random
                            mask holding as many non-zeros as ``base``.
    Any other mode (including None) returns None.

    All randomness comes from the global ``np.random`` state.
    """
    if mode == 'random':
        dense = np.random.randn(*base.shape) / np.sqrt(base.shape[0])
        return dense.astype(np.float32)

    if mode == 'droso':
        return base

    if mode == 'permuted_droso':
        # Same multiset of weights as the connectome, wiring pattern destroyed.
        weights = base[base != 0].astype(np.float32)
        np.random.shuffle(weights)
        slots = np.random.choice(base.size, len(weights), replace=False)
        scattered = np.zeros(base.size, dtype=np.float32)
        scattered[slots] = weights
        return scattered.reshape(base.shape)

    if mode == 'randsparse':
        n_keep = np.count_nonzero(base)
        keep = np.zeros(base.shape, dtype=np.float32)
        keep.flat[np.random.permutation(keep.size)[:n_keep]] = 1
        std = np.sqrt(n_keep / base.size)  # normalization factor (sqrt of density)
        gaussian = (np.random.randn(*base.shape) * std).astype(np.float32)
        return gaussian * keep

    return None

def load_connectivity_info(cfg_data):
    """Load connectome data described by the top-level config mapping.

    The module-level ``sio`` flag selects ``load_sio_connectivity_data`` over
    ``load_connectivity_data``. Either loader receives the signed CSV path
    (``cfg_data["csv_paths"]["signed"]``), the annotation path, and
    ``rescale_factor``, which defaults to 4e-2.
    """
    loader = load_sio_connectivity_data if sio else load_connectivity_data
    return loader(
        connectivity_path=cfg_data["csv_paths"]["signed"],
        annotation_path=cfg_data["annotation_path"],
        rescale_factor=cfg_data.get('rescale_factor', 4e-2),
    )

def load_datasets(transform):
    """Return ``(train_set, test_loader)`` for MNIST stored under ``./data``.

    The training split is downloaded if missing and returned as a dataset, so
    callers can build their own (e.g. few-shot) loaders. The test split is
    wrapped in a non-shuffling loader with batch size 256. It is opened without
    ``download=True``, presumably relying on the training-split download having
    fetched the test files too (confirm against torchvision's MNIST).
    """
    root = './data'
    train_set = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=False, transform=transform),
        batch_size=256,
        shuffle=False,
    )
    return train_set, test_loader

def initialize_model(config):
    """Construct the model selected by ``config['type']``.

    'basicrnn':
        Loads the connectome via the module-level ``config_data``. Builds the
        recurrent init with ``get_weight_matrix(conn['W'], config['init'])``.
        Sizes the sensory/internal/output populations from ``conn['W_ss']``,
        ``conn['W_rr']`` and ``conn['W_oo']``. Optional LoRA settings come from
        ``config['lora']`` (enabled=False, rank=8, alpha=16 by default).
    'threehiddenmlp':
        A ThreeHiddenMLP with layer sizes 784-29-147-400-10 and the optional
        ``freeze`` flag.

    Raises:
        ValueError: for any other model type.
    """
    model_type = config['type']

    if model_type == 'basicrnn':
        conn = load_connectivity_info(config_data)
        # conn['W'] is the rearranged connectivity matrix
        W_init = get_weight_matrix(conn['W'], config.get('init'))

        lora = config.get('lora', {})
        use_lora = lora.get('enabled', False)
        lora_rank = lora.get('rank', 8)
        lora_alpha = lora.get('alpha', 16)

        sensory_dim = conn['W_ss'].shape[0]
        internal_dim = conn['W_rr'].shape[0]
        output_dim = conn['W_oo'].shape[0]
        return BasicRNN(
            W_init=W_init,
            input_dim=784,
            sensory_dim=sensory_dim,
            internal_dim=internal_dim,
            output_dim=output_dim,
            num_classes=10,
            trainable=config.get('trainable'),
            pruning=config.get('pruning'),
            target_nonzeros=np.count_nonzero(W_init),
            lambda_l1=config.get('lambda_l1'),
            use_lora=use_lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
        )

    if model_type == 'threehiddenmlp':
        return ThreeHiddenMLP(784, 29, 147, 400, 10, config.get('freeze', False))

    raise ValueError(f"Unknown model type: {model_type}")

# Train one epoch
def train_epoch(model, optimizer, criterion, train_loader):
    """Train ``model`` for a single pass over ``train_loader``.

    Each batch is moved to the module-level ``device`` and run through the
    model exactly once. The resulting logits feed both the loss and the
    accuracy. If the model has a truthy ``pruning`` attribute and a non-None
    ``lambda_l1``, ``lambda_l1 * model.get_l1_loss()`` is added to the loss to
    preserve the sparsity level. After the pass, models with pruning enabled
    that expose ``enforce_sparsity`` get it called. The non-zero counts of
    ``model.W`` are printed before and after.

    Args:
        model: module to train; switched to train mode.
        optimizer: optimizer over the model's parameters.
        criterion: classification loss, called as ``criterion(logits, target)``.
        train_loader: iterable of ``(data, target)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. ``mean_loss`` includes
        any L1 penalty. Both are 0.0 for an empty loader.
    """
    model.train()
    # Models lacking these attributes simply train without the L1 penalty.
    pruning = getattr(model, "pruning", False)
    lambda_l1 = getattr(model, "lambda_l1", None)

    total_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(train_loader, unit="batch", desc="Training")
    for data, target in pbar:
        # Batches go to the model in the shape the loader yields
        # (presumably (B, 1, 28, 28) MNIST images -- confirm against the model).
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()

        # One forward pass; reused for both the loss and the accuracy.
        output = model(data)
        loss = criterion(output, target)
        if pruning and lambda_l1 is not None:
            # L1-penalised training loss to preserve the sparsity level.
            loss = loss + lambda_l1 * model.get_l1_loss()
        loss.backward()
        optimizer.step()

        n = data.size(0)
        total_loss += loss.item() * n
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += n
        train_acc = correct / total if total else 0
        pbar.set_postfix(loss=f"{loss.item():.4f}", train_acc=f"{train_acc:.2%}")

    if pruning and hasattr(model, "enforce_sparsity"):
        print("enforce sparsity start, nonzeros: ", torch.count_nonzero(model.W).item())
        model.enforce_sparsity()
        print("enforce sparsity end, nonzeros: ", torch.count_nonzero(model.W).item())

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total

# Evaluate model accuracy on a held-out loader
def evaluate(model, test_loader):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Inputs and targets are moved to the module-level ``device`` before the
    forward pass.

    Returns:
        ``correct / total`` as a float, or 0 when the loader yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            preds = model(data).argmax(dim=1)
            correct += preds.eq(target).sum().item()
            total += target.size(0)
    return correct / total if total > 0 else 0

def _count_nonzero_per_submodule(model):
    """Map each direct child's name to the non-zeros in its own parameters, plus a model-wide 'total'."""
    counts = {
        name: sum(torch.count_nonzero(p).item() for p in child.parameters(recurse=False))
        for name, child in model.named_children()
    }
    counts['total'] = sum(torch.count_nonzero(p).item() for p in model.parameters())
    return counts


# Run training loop and record results
def run_training_loop(model, config, full_train_set, test_loader, trial_num, num_epochs, batch_size, fewshot_batch_size):
    """Train ``model`` for ``num_epochs`` epochs and collect per-epoch metrics.

    Uses Adam at the module-level ``learning_rate`` and cross-entropy loss.
    Test accuracy is recorded once before training ("epoch 0") and after every
    epoch, so ``epoch_test_acc`` has ``num_epochs + 1`` entries. The train
    lists have ``num_epochs`` entries each. When ``config`` contains
    "fewshot", each epoch trains on a fresh stratified subset of
    ``config["fewshot"]`` samples per class. That subset uses batch size
    ``config["fewshot_batch_size"]``, falling back to ``fewshot_batch_size``.
    Otherwise each epoch trains on the full set with ``batch_size``.

    NOTE(review): the few-shot subset is seeded with the epoch index only, so
    every trial sees the same sequence of subsets -- confirm this is intended.
    NOTE(review): 'similarity_dict' is initialised but never filled here.

    Returns:
        dict of metric lists: epoch_train_loss, epoch_train_acc,
        epoch_test_acc, submodules_nonzero, similarity_dict.
    """
    results = {
        "epoch_train_loss": [],
        "epoch_train_acc": [],
        "epoch_test_acc": [],
        'submodules_nonzero': [],
        'similarity_dict': [],
    }
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    baseline_acc = evaluate(model, test_loader)
    results["epoch_test_acc"].append(baseline_acc)
    print(f"Trial {trial_num} | Epoch 0 | Test Acc: {baseline_acc:.2%}")

    is_fewshot = "fewshot" in config
    for epoch in range(num_epochs):
        if is_fewshot:
            subset = create_fewshot_subset(full_train_set, epoch, config["fewshot"])
            epoch_batch = config.get("fewshot_batch_size", fewshot_batch_size)
            loader = torch.utils.data.DataLoader(subset, batch_size=epoch_batch, shuffle=True)
        else:
            loader = torch.utils.data.DataLoader(full_train_set, batch_size=batch_size, shuffle=True)

        train_loss, train_acc = train_epoch(model, optimizer, criterion, loader)
        results["epoch_train_loss"].append(train_loss)
        results["epoch_train_acc"].append(train_acc)

        test_acc = evaluate(model, test_loader)

        # Non-zero counts are saved for later FLOPs calculation.
        nonzero_counts = _count_nonzero_per_submodule(model)
        results['submodules_nonzero'].append(nonzero_counts)
        results["epoch_test_acc"].append(test_acc)

        print(f"submodule nonzero values: {nonzero_counts}")
        print(f"Trial {trial_num} | Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    return results

def save_results(exp_id, config, trial_num, results, signed, results_dir="results"):
    """Pickle ``results`` to ``<results_dir>/<exp_id>_trial<trial_num>[.signed].pkl``.

    Args:
        exp_id: experiment identifier used as the filename stem.
        config: experiment configuration. It is accepted for interface
            compatibility but does not influence the filename.
        trial_num: trial index appended as ``_trial<trial_num>``.
        results: any picklable object (typically the metrics dict).
        signed: when truthy, ``.signed`` is inserted before ``.pkl``.
        results_dir: output directory, created if missing (default "results").

    Returns:
        The path of the written file.
    """
    os.makedirs(results_dir, exist_ok=True)
    filename = f"{exp_id}_trial{trial_num}"
    if signed:
        filename += ".signed"
    path = os.path.join(results_dir, filename + ".pkl")
    with open(path, "wb") as f:
        pickle.dump(results, f)
    return path

# Full experiment run
def train_experiment(exp_id, config, trial_num):
    """Run one trial of an experiment end to end and save its results.

    Prints the configuration and seeds torch and NumPy with ``trial_num``.
    Then it loads MNIST (normalised with mean 0.1307 / std 0.3081), builds the
    model described by ``config`` on the module-level ``device``, and trains
    it. Finally the metrics are pickled with ``save_results``.

    Args:
        exp_id: experiment identifier, used in logs and the result filename.
        config: experiment configuration mapping.
        trial_num: trial index (the main loop numbers trials from 1); also
            the RNG seed for this trial.
    """
    print("========================================")
    print(f"Starting Experiment: {exp_id} Trial {trial_num}")
    print("Experiment configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print("========================================\n")

    # Seed both RNGs from the trial number so trials are reproducible.
    # NumPy drives get_weight_matrix's 'random', 'permuted_droso' and
    # 'randsparse' initialisations. np.random.seed() with no argument reseeds
    # from OS entropy, which would make those inits unrepeatable.
    torch.manual_seed(trial_num)
    np.random.seed(trial_num)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    full_train_set, test_loader = load_datasets(transform)
    model = initialize_model(config)
    model.to(device)
    results = run_training_loop(model, config, full_train_set, test_loader, trial_num,
                                num_epochs, batch_size, fewshot_batch_size)
    save_results(exp_id, config, trial_num, results, signed)

if __name__ == "__main__":
    # Run every configured experiment num_trials times, numbering trials from 1.
    trial_numbers = range(1, num_trials + 1)
    for name, exp_cfg in experiments.items():
        for trial in trial_numbers:
            print(f"\n=== Training {name} Trial {trial} ===")
            train_experiment(name, exp_cfg, trial)



=== Training Learnable_RNN_No_Sparsity_fewshot_300 Trial 1 ===
Starting Experiment: Learnable_RNN_No_Sparsity_fewshot_300 Trial 1
Experiment configuration:
  type: basicrnn
  trainable: True
  init: droso
  fewshot: 300
  fewshot_batch_size: 17

Annotation file: Found 29 sensory neuron IDs
Annotation file: Found 400 output neuron IDs
Connectivity matrix contains 2952 neurons
After filtering, found 29 sensory neurons in matrix
After filtering, found 400 output neurons in matrix
Remaining 2523 neurons classified as internal
BasicRNN init: trainable=True, pruning=None, target_nonzeros=63545, lambda_l1=None
LoRA config: use_lora=False, rank=8, alpha=16
Trial 1 | Epoch 0 | Test Acc: 9.31%


Training: 100%|██████████| 177/177 [00:23<00:00,  7.68batch/s, loss=1.5462, train_acc=18.17%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5573278}
Trial 1 | Epoch 1 | Test Acc: 19.66%


Training: 100%|██████████| 177/177 [00:22<00:00,  7.80batch/s, loss=1.6001, train_acc=24.47%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5587360}
Trial 1 | Epoch 2 | Test Acc: 31.02%


Training: 100%|██████████| 177/177 [00:22<00:00,  7.74batch/s, loss=1.1505, train_acc=31.07%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5593251}
Trial 1 | Epoch 3 | Test Acc: 36.35%


Training: 100%|██████████| 177/177 [00:23<00:00,  7.66batch/s, loss=1.6133, train_acc=38.80%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5593953}
Trial 1 | Epoch 4 | Test Acc: 43.95%


Training: 100%|██████████| 177/177 [00:22<00:00,  7.91batch/s, loss=1.1844, train_acc=45.40%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5595055}
Trial 1 | Epoch 5 | Test Acc: 49.79%


Training: 100%|██████████| 177/177 [00:25<00:00,  7.03batch/s, loss=1.1358, train_acc=54.17%]


submodule nonzero values: {'input_proj': 22765, 'output_layer': 4010, 'activation': 0, 'total': 5597388}
Trial 1 | Epoch 6 | Test Acc: 52.97%


Training:  69%|██████▉   | 122/177 [00:16<00:07,  7.68batch/s, loss=0.5592, train_acc=62.34%]