# In [3]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder # Useful for Tiny ImageNet structure
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt # Keep if needed later
import numpy as np
import pandas as pd
import time
import copy
# from itertools import product # No longer needed for CNN configs
import os
import json
import snntorch as snn
import snntorch.surrogate as surrogate
from snntorch import utils
from snntorch import functional as SF
from PIL import Image # Needed for TinyImageNet loading potentially
import shutil # For cleaning up directories if needed

# --- Configuration Class (Updated for Tiny ImageNet & ResNet) ---
class Config:
    """Hyper-parameters for the Tiny ImageNet / ResNet-18 experiments.

    Every setting is a class attribute. ``osc_delta`` is deliberately not
    declared here: the experiment driver assigns it on a deep-copied instance
    for each oscillator variant, and ``OscillatorTransformFast`` falls back to
    0.0 when the attribute is missing.
    """

    # --- Dataset ---
    dataset_name = "TinyImageNet"
    data_root = './tiny-imagenet-200'  # IMPORTANT: point this at your extracted Tiny ImageNet folder
    batch_size = 64  # reduce if ResNet-18 at this resolution does not fit in GPU memory
    input_size = 224  # standard input resolution for ImageNet pre-trained models
    num_classes = 200  # Tiny ImageNet has 200 classes

    # --- CNN backbone (fixed to ResNet-18) ---
    backbone_name = "ResNet-18_pretrained"

    # --- Shared SNN / encoding settings ---
    chaos_dim = 128  # width of the projection feeding the oscillator / Lorenz / SNN layers
    num_steps = 5  # simulation time steps; kept small for faster ResNet experiments

    # --- Oscillator (OscillatorTransformFast) ---
    osc_alpha = 2.0
    osc_beta = 0.1
    osc_gamma = 0.1
    osc_omega = 1.0
    osc_drive = 0.0
    # osc_delta is assigned per model variant (see the class docstring)
    osc_dt = 0.05

    # --- Lorenz system (LorenzTransformFast) ---
    lorenz_sigma = 10.0
    lorenz_rho = 28.0
    lorenz_beta = 8.0 / 3.0
    lorenz_dt = 0.05

    # --- Spiking neurons ---
    beta = 0.95  # membrane decay rate of every snn.Leaky neuron

    # --- Training ---
    epochs = 200  # upper bound on epochs; early stopping may end training sooner
    lr = 1e-4  # head learning rate; the backbone parameter group uses 0.1x this
    weight_decay = 5e-4
    patience = 15  # epochs without test-accuracy improvement before early stopping

# Use the CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# snntorch fast-sigmoid surrogate gradient, shared by every snn.Leaky layer below.
spike_grad = surrogate.fast_sigmoid()

# --- OscillatorTransformFast Class (Unchanged) ---
class OscillatorTransformFast(nn.Module):
    """Encode a feature vector as a short trajectory of a nonlinear oscillator.

    Each input feature ``x_i`` seeds a 3-variable state
    ``(x, y, z) = (x_i, 0.2 * x_i, -x_i)``, which is advanced ``num_steps - 1``
    times with explicit (forward) Euler steps of size ``dt``::

        dx/dt = y
        dy/dt = -alpha * x - beta * x**3 - delta * y + gamma * z
        dz/dt = -omega * x - delta * z + gamma * x * y

    ``delta`` multiplies the damping terms: positive values dissipate,
    negative values make the system expansive. The module has no learnable
    parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.alpha = config.osc_alpha
        self.beta_osc = config.osc_beta
        self.gamma = config.osc_gamma
        # ``osc_delta`` is optional on the config; without it there is no damping.
        self.delta = getattr(config, 'osc_delta', 0.0)
        self.omega = config.osc_omega
        self.dt = config.osc_dt

    def forward(self, x):
        """Integrate the oscillator starting from ``x``.

        Args:
            x: 2-D tensor of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, num_steps, 3 * dim)`` with the dtype and
            device of ``x``; step ``t`` holds the concatenated ``[x, y, z]``
            state (step 0 is the initial state).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if self.num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {self.num_steps}")
        _, dim = x.shape  # also enforces a 2-D input
        delta = self.delta
        state = torch.cat([x, x * 0.2, -x], dim=1)
        # Collect the states and stack them once. A torch.zeros buffer would
        # default to float32 and silently cast float64/half inputs, and it would
        # need in-place writes.
        states = [state]
        for _ in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = y_cur
            dy = -self.alpha * x_cur - self.beta_osc * (x_cur ** 3) - delta * y_cur + self.gamma * z_cur
            dz = -self.omega * x_cur - delta * z_cur + self.gamma * x_cur * y_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            states.append(state)
        return torch.stack(states, dim=1)

# --- LorenzTransformFast Class (Unchanged) ---
class LorenzTransformFast(nn.Module):
    """Encode a feature vector as a short Lorenz-system trajectory.

    Each input feature ``x_i`` seeds a state ``(x, y, z) = (x_i, 0.2 * x_i, -x_i)``,
    which is advanced ``num_steps - 1`` times with explicit (forward) Euler
    steps of size ``dt``::

        dx/dt = sigma * (y - x)
        dy/dt = x * (rho - z) - y
        dz/dt = x * y - beta * z

    The module has no learnable parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.sigma = config.lorenz_sigma
        self.rho = config.lorenz_rho
        self.lorenz_beta_param = config.lorenz_beta
        self.dt = config.lorenz_dt

    def forward(self, x):
        """Integrate the Lorenz system starting from ``x``.

        Args:
            x: 2-D tensor of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, num_steps, 3 * dim)`` with the dtype and
            device of ``x``; step ``t`` holds the concatenated ``[x, y, z]``
            state (step 0 is the initial state).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if self.num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {self.num_steps}")
        _, dim = x.shape  # also enforces a 2-D input
        state = torch.cat([
            x,
            0.2 * x,  # initial y
            -x,       # initial z
        ], dim=1)
        # Stack the collected states once. A preallocated torch.zeros buffer
        # would default to float32 and silently cast float64/half inputs.
        states = [state]
        for _ in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = self.sigma * (y_cur - x_cur)
            dy = x_cur * (self.rho - z_cur) - y_cur
            dz = x_cur * y_cur - self.lorenz_beta_param * z_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            states.append(state)
        return torch.stack(states, dim=1)

# --- Helper function to get ResNet-18 backbone ---
def _get_resnet_backbone(pretrained=True):
    """Return a torchvision ResNet-18 whose final ``fc`` layer is an identity.

    With ``fc`` replaced by ``nn.Identity()``, the model outputs the features
    that would have fed the classifier. The models in this file project them
    with ``nn.Linear(512, ...)``.

    Args:
        pretrained: If true, load the ``IMAGENET1K_V1`` weights; otherwise the
            network is randomly initialised.
    """
    resnet_weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    backbone = torchvision.models.resnet18(weights=resnet_weights)
    if not pretrained:
        print("Initialized ResNet-18 weights FROM SCRATCH.")
    else:
        print("Loaded PRETRAINED ResNet-18 weights.")
    backbone.fc = nn.Identity()
    return backbone

# --- CNNOscSNN Class (Unchanged, relies on config.num_classes) ---
class CNNOscSNN(nn.Module):
    """ResNet-18 features -> oscillator encoding -> two-layer leaky SNN classifier.

    Backbone features are projected to ``config.chaos_dim`` and squashed with
    ``tanh``. The result seeds ``OscillatorTransformFast``, whose per-step
    ``3 * chaos_dim`` states are fed as input current to two ``snn.Leaky``
    layers. Class scores are the sum over time steps of a linear readout of
    the second layer's spikes.
    """

    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.oscillator = OscillatorTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes)

    def forward(self, x, return_spikes=False):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by the ResNet-18 backbone.
            return_spikes: Also return the total spike count.

        Returns:
            ``logits`` of shape ``(batch, num_classes)``. If ``return_spikes`` is
            true, returns ``(logits, spikes)`` instead. ``spikes`` is a 0-d
            float32 tensor on the logits' device holding the number of spikes
            emitted by both LIF layers over all steps and the whole batch.
        """
        features = self.backbone(x)
        seed = torch.tanh(self.proj(features))
        encoded = self.oscillator(seed)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        # Count spikes on-device. Calling .item() at every step forced a
        # host/device synchronisation per step. Detach so the count never keeps
        # the autograd graph alive, and sum in float32 so counts stay exact even
        # if spikes are half precision.
        spike_total = torch.zeros((), device=encoded.device)
        for step in range(self.num_steps):
            spk1, mem1 = self.lif1(encoded[:, step], mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                spike_total += spk1.detach().sum(dtype=torch.float32) + spk2.detach().sum(dtype=torch.float32)
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, spike_total
        return final_output

# --- CNNLorenzSNN Class (Unchanged, relies on config.num_classes) ---
class CNNLorenzSNN(nn.Module):
    """ResNet-18 features -> Lorenz-system encoding -> two-layer leaky SNN classifier.

    Backbone features are projected to ``config.chaos_dim`` and squashed with
    ``tanh``. The result seeds ``LorenzTransformFast``, whose per-step
    ``3 * chaos_dim`` states are fed as input current to two ``snn.Leaky``
    layers. Class scores are the sum over time steps of a linear readout of
    the second layer's spikes.
    """

    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lorenz = LorenzTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes)

    def forward(self, x, return_spikes=False):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by the ResNet-18 backbone.
            return_spikes: Also return the total spike count.

        Returns:
            ``logits`` of shape ``(batch, num_classes)``. If ``return_spikes`` is
            true, returns ``(logits, spikes)`` instead. ``spikes`` is a 0-d
            float32 tensor on the logits' device holding the number of spikes
            emitted by both LIF layers over all steps and the whole batch.
        """
        features = self.backbone(x)
        seed = torch.tanh(self.proj(features))
        encoded = self.lorenz(seed)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        # Count spikes on-device. Calling .item() at every step forced a
        # host/device synchronisation per step. Detach so the count never keeps
        # the autograd graph alive, and sum in float32 so counts stay exact even
        # if spikes are half precision.
        spike_total = torch.zeros((), device=encoded.device)
        for step in range(self.num_steps):
            spk1, mem1 = self.lif1(encoded[:, step], mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                spike_total += spk1.detach().sum(dtype=torch.float32) + spk2.detach().sum(dtype=torch.float32)
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, spike_total
        return final_output

# --- BasicCSNN Class (Unchanged, relies on config.num_classes) ---
class BasicCSNN(nn.Module):
    """ResNet-18 features fed directly (no dynamical encoding) into a two-layer leaky SNN.

    The ``tanh``-squashed projection of the backbone features is presented
    unchanged as input current at every time step. Class scores are the sum
    over steps of a linear readout of the second layer's spikes.
    """

    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lif1 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=config.beta, spike_grad=spike_grad)
        self.fc_out = nn.Linear(config.chaos_dim, config.num_classes)

    def forward(self, x, return_spikes=False):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by the ResNet-18 backbone.
            return_spikes: Also return the total spike count.

        Returns:
            ``logits`` of shape ``(batch, num_classes)``. If ``return_spikes`` is
            true, returns ``(logits, spikes)`` instead. ``spikes`` is a 0-d
            float32 tensor on the logits' device holding the number of spikes
            emitted by both LIF layers over all steps and the whole batch.
        """
        features = self.backbone(x)
        cnn_features = torch.tanh(self.proj(features))
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        total_output_mem = 0
        # Count spikes on-device rather than calling .item() (a host/device sync)
        # at every step. Detach and sum in float32 so counts stay exact.
        spike_total = torch.zeros((), device=cnn_features.device)
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cnn_features, mem1)  # input current is static over steps
            spk2, mem2 = self.lif2(spk1, mem2)
            total_output_mem += self.fc_out(spk2)
            if return_spikes:
                spike_total += spk1.detach().sum(dtype=torch.float32) + spk2.detach().sum(dtype=torch.float32)
        final_output = total_output_mem
        if return_spikes:
            return final_output, spike_total
        return final_output

# --- BaseCNN Class (Unchanged, relies on config.num_classes) ---
class BaseCNN(nn.Module):
    """Non-spiking baseline: ResNet-18 features followed by an MLP head.

    The head starts with the same ``512 -> chaos_dim`` projection and ``tanh``
    as the SNN models. It then applies two ReLU hidden layers of width
    ``chaos_dim`` and a final linear classifier.
    """

    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        width = config.chaos_dim
        # Registration order matches the parameter-initialisation order.
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, width)
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)
        self.classifier = nn.Linear(width, config.num_classes)

    def forward(self, x, **kwargs):
        """Return class logits for an image batch.

        Extra keyword arguments (such as ``return_spikes``) are accepted and
        ignored, so this model can be called like the SNN variants.
        """
        hidden = torch.tanh(self.proj(self.backbone(x)))
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.classifier(hidden)

# --- Evaluate function - returns the total spike count rather than an average ---
def evaluate(model, loader, config):
    """Measure top-1 accuracy and, for SNN models, the total spike count.

    Args:
        model: A model from this file. ``CNNOscSNN``, ``BasicCSNN`` and
            ``CNNLorenzSNN`` are called with ``return_spikes=True``.
        loader: Iterable of ``(images, labels)`` tensor batches.
        config: Unused; kept so existing call sites keep working.

    Returns:
        tuple: ``(accuracy_percent, total_spikes)``. ``total_spikes`` is the
        sum over all evaluated batches (not an average) and is 0.0 for non-SNN
        models. Samples with a negative label are excluded from both values:
        the Tiny ImageNet validation dataset returns label -1 for images it
        could not load.

    Note:
        Leaves ``model`` in eval mode.
    """
    model.eval()
    correct = 0
    total = 0
    total_spikes_evaluated = 0.0
    is_snn = isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN))

    with torch.no_grad():
        for images, labels in loader:
            # Drop placeholder samples before moving to the device, so they are
            # neither scored as misclassified nor counted for spikes.
            valid = labels >= 0
            if not valid.all():
                images, labels = images[valid], labels[valid]
                if labels.numel() == 0:
                    continue
            images, labels = images.to(device), labels.to(device)
            if is_snn:
                outputs, batch_spikes = model(images, return_spikes=True)
                total_spikes_evaluated += batch_spikes.item()
            else:
                outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total if total > 0 else 0.0
    return accuracy, total_spikes_evaluated


# --- Train and Evaluate with History ---
def train_and_evaluate_with_history(model, train_loader, test_loader, config, current_run_output_dir):
    """Fine-tune ``model`` with early stopping and return its per-epoch history.

    Parameters whose names start with ``backbone.`` are trained with
    ``config.lr * 0.1``; every other trainable parameter uses ``config.lr``.
    Both groups use Adam with cosine annealing over ``config.epochs``. After
    each epoch the model is evaluated on ``test_loader``. The state with the
    best test accuracy is checkpointed to a temporary file in
    ``current_run_output_dir`` and reloaded at the end. The temporary file is
    always removed, even if loading it or training fails.

    Args:
        model: Model to train; moved to the global ``device``.
        train_loader: Training batches ``(images, labels)``.
        test_loader: Evaluation batches passed to ``evaluate``.
        config: Provides ``lr``, ``weight_decay``, ``epochs`` and ``patience``.
        current_run_output_dir: Existing directory for the temporary checkpoint.

    Returns:
        tuple: ``(best_test_acc, history, spikes_at_convergence, spike_counts,
        best_epoch_idx)``. ``history`` and ``spike_counts`` hold the per-epoch
        test accuracy and total evaluation spike count (``[0.0]`` if no epoch
        ran), and ``best_epoch_idx`` is 0-based.
    """
    model = model.to(device)
    head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone.') and p.requires_grad]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('backbone.') and p.requires_grad]

    # The pretrained backbone is fine-tuned 10x more gently than the new head.
    optimizer_params = []
    if backbone_params:
        optimizer_params.append({'params': backbone_params, 'lr': config.lr * 0.1})
    if head_params:
        optimizer_params.append({'params': head_params, 'lr': config.lr})

    if not optimizer_params:  # e.g. a fully frozen model
        print("Warning: No parameters to optimize. Skipping training.")
        test_acc_final, spikes_final = evaluate(model, test_loader, config)
        return test_acc_final, [test_acc_final], spikes_final, [spikes_final], 0

    optimizer = torch.optim.Adam(optimizer_params, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()

    best_test_acc_epoch = 0.0
    epochs_no_improve = 0
    patience = config.patience
    # Time- and id-based name so each model's checkpoint gets its own file.
    best_model_path = os.path.join(current_run_output_dir, f"temp_best_model_{time.time()}_{id(model)}.pth")

    history = []
    spike_counts = []
    best_epoch_idx = 0

    print(f"--- Starting Training (Max Epochs: {config.epochs}, Patience: {patience}) ---")

    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0
            num_batches = 0  # batches that were actually optimised
            train_correct = 0
            train_total = 0
            start_epoch_time = time.time()

            for i, (images, labels) in enumerate(train_loader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images, return_spikes=False)
                loss = criterion(outputs, labels)

                if not torch.isfinite(loss):
                    print(f"WARNING: Loss is {loss.item()} at epoch {epoch+1}, batch {i}. Skipping backward pass.")
                    continue

                loss.backward()
                optimizer.step()

                _, predicted = outputs.max(1)
                train_correct += (predicted == labels).sum().item()
                train_total += labels.size(0)
                total_loss += loss.item()
                num_batches += 1

            epoch_duration = time.time() - start_epoch_time
            test_acc, epoch_spikes = evaluate(model, test_loader, config)
            history.append(test_acc)
            spike_counts.append(epoch_spikes)

            # Average over the batches that contributed to total_loss; skipped
            # non-finite batches must not dilute the mean.
            avg_loss = total_loss / num_batches if num_batches > 0 else 0
            train_acc = 100. * train_correct / train_total if train_total > 0 else 0

            # Index 0 is the first parameter group (the backbone when it has trainable params).
            current_lr = scheduler.get_last_lr()[0] if scheduler.get_last_lr() else config.lr
            print(f"Epoch [{epoch + 1}/{config.epochs}] Loss: {avg_loss:.4f} "
                  f"Train Acc: {train_acc:.2f}% Test Acc: {test_acc:.2f}% "
                  f"Total Spikes: {epoch_spikes:.0f} "
                  f"Epoch Time: {epoch_duration:.2f}s LR: {current_lr:.1e}")

            if test_acc > best_test_acc_epoch:
                best_test_acc_epoch = test_acc
                best_epoch_idx = epoch
                epochs_no_improve = 0
                try:
                    torch.save(model.state_dict(), best_model_path)
                    print(f"  -> New best test accuracy: {best_test_acc_epoch:.2f}%. Model saved.")
                except Exception as e:
                    print(f"  -> Error saving model: {e}")
            else:
                epochs_no_improve += 1
                print(f"  -> Test accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_test_acc_epoch:.2f}%")

            if epochs_no_improve >= patience:
                print(f"\n--- Early stopping triggered after {epoch + 1} epochs. ---")
                break
            scheduler.step()

        print("\n--- Training finished. Loading best model for final evaluation. ---")
        if os.path.exists(best_model_path):
            try:
                model.load_state_dict(torch.load(best_model_path, map_location=device))
                print(f"Successfully loaded best model state (Accuracy: {best_test_acc_epoch:.2f}%)")
            except Exception as e:
                print(f"Error loading best model state from {best_model_path}: {e}. Using model from last epoch.")
                best_test_acc_epoch = history[-1] if history else 0.0
                best_epoch_idx = len(history) - 1 if history else 0
        else:
            print("No best model was saved (or file missing). Using model from last epoch.")
            best_test_acc_epoch = history[-1] if history else 0.0
            best_epoch_idx = len(history) - 1 if history else 0
    finally:
        # Remove the temporary checkpoint whether loading succeeded, failed, or
        # training raised, so no stray .pth files accumulate in the run directory.
        if os.path.exists(best_model_path):
            try:
                os.remove(best_model_path)
            except OSError as e:
                print(f"Warning: could not remove temporary checkpoint {best_model_path}: {e}")

    print("--- Running final evaluation on the best performing model state ---")
    final_test_acc, final_total_spikes_on_best_model = evaluate(model, test_loader, config)

    # Spikes at convergence are those recorded in the epoch that reached the best accuracy.
    spikes_at_convergence = spike_counts[best_epoch_idx] if spike_counts and best_epoch_idx < len(spike_counts) else 0

    print(f"Final Evaluation on Best Model - Test Accuracy: {final_test_acc:.2f}%")
    if isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN)):
        print(f"Final Evaluation - Spikes at Convergence (Epoch {best_epoch_idx+1}): {spikes_at_convergence:.0f}")
        # Re-evaluated spike count of the reloaded best state; expected to be
        # close to spikes_at_convergence.
        print(f"Final Evaluation - Total Spikes (re-evaluating best model): {final_total_spikes_on_best_model:.0f}")

    # Always return non-empty histories, even if no epoch ran.
    if not history:
        history = [0.0]
    if not spike_counts:
        spike_counts = [0.0]

    return best_test_acc_epoch, history, spikes_at_convergence, spike_counts, best_epoch_idx


# --- Tiny ImageNet Loading Function (Unchanged) ---
class TinyImageNetVal(Dataset):
    """Tiny ImageNet validation split, labelled from ``val_annotations.txt``.

    Defined at module level rather than inside ``load_tiny_imagenet`` so that
    DataLoader worker processes started with the ``spawn`` method (the default
    on Windows and macOS) can pickle it.

    Args:
        val_dir: Directory containing the validation images.
        annotations_file: Tab-separated file; the first two columns are the
            image file name and the class wnid.
        class_to_idx: Mapping from wnid to integer label. Rows with an unknown
            wnid or a missing image file are skipped.
        transform: Optional callable applied to each loaded PIL image.
        image_size: Side length of the all-zero placeholder returned (with
            label -1) when an image cannot be read.

    Raises:
        FileNotFoundError: if ``annotations_file`` does not exist.
    """

    def __init__(self, val_dir, annotations_file, class_to_idx, transform=None, image_size=224):
        self.val_dir = val_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.image_size = image_size
        self.samples = []
        try:
            with open(annotations_file, 'r') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        img_name, wnid = parts[0], parts[1]
                        img_path = os.path.join(self.val_dir, img_name)
                        if os.path.exists(img_path) and wnid in self.class_to_idx:
                            self.samples.append((img_path, self.class_to_idx[wnid]))
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Validation annotations file not found: {annotations_file}") from err

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, target = self.samples[idx]
        try:
            with open(img_path, 'rb') as f:
                img = Image.open(f).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Placeholder sample; label -1 marks it as invalid for consumers.
            return torch.zeros(3, self.image_size, self.image_size), -1

        if self.transform:
            img = self.transform(img)
        return img, target


def load_tiny_imagenet(config):
    """Build the Tiny ImageNet train and validation DataLoaders.

    Expects this layout under ``config.data_root``:

    - ``train/``: one sub-folder per class, read with ``ImageFolder``.
    - ``val/images/``: the validation images.
    - ``val/val_annotations.txt``: the validation labels.

    Images are resized to ``config.input_size`` and normalised with the
    ImageNet mean/std. Training additionally uses random resized crops and
    horizontal flips.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles
        and drops the last incomplete batch.

    Raises:
        FileNotFoundError: if the train/val image directories or the validation
            annotations file are missing.
    """
    data_dir = config.data_root
    cpu_count = os.cpu_count()
    num_workers = min(4, cpu_count) if cpu_count else 0
    image_size = config.input_size

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val', 'images')

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
        raise FileNotFoundError(f"Tiny ImageNet data not found at expected paths: {train_dir} and {val_dir}. "
                                f"Please ensure Tiny ImageNet is downloaded and extracted to '{data_dir}' "
                                "following the standard directory structure.")

    train_dataset = ImageFolder(train_dir, transform=train_transform)

    # Reuse the training split's wnid -> label mapping so labels agree.
    class_to_idx = train_dataset.class_to_idx
    val_annotations_file = os.path.join(data_dir, 'val', 'val_annotations.txt')
    val_dataset = TinyImageNetVal(val_dir, val_annotations_file, class_to_idx,
                                  transform=val_transform, image_size=image_size)

    print(f"Tiny ImageNet - Found {len(train_dataset)} training images belonging to {len(train_dataset.classes)} classes.")
    print(f"Tiny ImageNet - Found {len(val_dataset)} validation images.")

    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True  # avoid a tiny final batch
    )
    test_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    return train_loader, test_loader


# --- Experiment Logger Class ---
class ExperimentLogger:
    """Collect per-model results for one experiment run and export them.

    Results are kept in ``self.results`` as
    ``{config_name: {"config": {...}, "models": {model_name: {...}}}}``.
    They are written as JSON, per-epoch spike CSVs and a summary CSV, all
    inside ``output_dir`` (created on construction).
    """

    def __init__(self, output_dir):
        self.results = {}
        self.output_dir = output_dir  # one directory per run / logger instance
        os.makedirs(self.output_dir, exist_ok=True)
        self.csv_headers = [
            "Config Name",
            "Model",
            "Delta/System",
            "Best Test Accuracy (%)",
            "Spikes at Convergence",
            "Training Time (s)",
            "Epochs Trained",
            "Convergence Epoch"
        ]

    def log_config(self, config_name, config):
        """Record the plain-valued settings of ``config`` under ``config_name``.

        Captures both class-level attributes (how ``Config`` declares its
        settings) and instance attributes (e.g. an ``osc_delta`` assigned after
        construction); instance values win. Names starting with ``_`` and
        values that are not int/float/str/bool/list/tuple/dict/None are
        skipped. Any existing entry for ``config_name`` is replaced, including
        its model results.
        """
        # vars(config) alone sees only instance attributes, which are empty for a
        # plain Config(). Walk the class hierarchy (base first), then the instance.
        attributes = {}
        for klass in reversed(type(config).__mro__):
            attributes.update(vars(klass))
        attributes.update(getattr(config, '__dict__', {}))
        config_dict = {
            key: value
            for key, value in attributes.items()
            if not key.startswith('_')
            and isinstance(value, (int, float, str, bool, list, tuple, dict, type(None)))
        }
        self.results[config_name] = {
            "config": config_dict,
            "models": {}
        }

    def log_model_result(self, config_name, model_name, accuracy, training_time, epochs_history, spikes_at_convergence, spike_counts_all_epochs, best_epoch_idx):
        """Store one model's results. ``best_epoch_idx`` is 0-based; it is saved 1-based."""
        if config_name not in self.results:
            self.results[config_name] = {"config": {}, "models": {}}
            print(f"Warning: Config '{config_name}' not pre-logged. Creating entry.")
        epochs_trained = len(epochs_history)
        self.results[config_name]["models"][model_name] = {
            "accuracy": accuracy,
            "training_time": training_time,
            "epochs_history": epochs_history,
            "spikes_at_convergence": spikes_at_convergence,
            "epochs_trained": epochs_trained,
            "spike_counts_all_epochs": spike_counts_all_epochs,
            "convergence_epoch": best_epoch_idx + 1  # 1-indexed
        }
        print(f"Logged for {config_name} - {model_name}: Best Acc={accuracy:.2f}, Spikes at Convergence={spikes_at_convergence:.0f}, "
              f"Time={training_time:.2f}s, Epochs={epochs_trained}, Convergence Epoch={best_epoch_idx+1}")

    def save_results_json(self):
        """Write ``run_results.json``, omitting the per-epoch lists. Errors are printed, not raised."""
        filepath = os.path.join(self.output_dir, "run_results.json")
        try:
            results_to_save = {}
            for cfg_name, cfg_data in self.results.items():
                results_to_save[cfg_name] = {"config": cfg_data["config"], "models": {}}
                for mdl_name, mdl_data in cfg_data["models"].items():
                    results_to_save[cfg_name]["models"][mdl_name] = {
                        k: v for k, v in mdl_data.items() if k not in ['epochs_history', 'spike_counts_all_epochs']
                    }
            with open(filepath, "w") as f:
                json.dump(results_to_save, f, indent=4, default=lambda o: '<not serializable>')
            print(f"Run results saved to {filepath}")
        except Exception as e:
            print(f"Error saving JSON results for run: {e}")

    def save_spike_data_csv(self):
        """Write one ``<config>_<model>_spikes_per_epoch.csv`` per model that has spike data."""
        for cfg_name, cfg_data in self.results.items():
            for mdl_name, mdl_data in cfg_data["models"].items():
                if 'spike_counts_all_epochs' in mdl_data and mdl_data['spike_counts_all_epochs']:
                    spike_df = pd.DataFrame({
                        'Epoch': range(1, len(mdl_data['spike_counts_all_epochs'])+1),
                        'Total Spikes': mdl_data['spike_counts_all_epochs']
                    })
                    spike_file = os.path.join(self.output_dir, f"{cfg_name}_{mdl_name}_spikes_per_epoch.csv")
                    spike_df.to_csv(spike_file, index=False)
                    print(f"Spike data per epoch saved to {spike_file}")

    def generate_summary_table(self):
        """Build the per-model summary DataFrame and write it to ``run_summary.csv``.

        Returns the numeric (unformatted) DataFrame with columns
        ``self.csv_headers``; it is empty when nothing has been logged.
        """
        rows = []
        for config_name, data in self.results.items():
            for model_name, model_data in data["models"].items():
                # Derive the encoding label from the model's display name.
                delta_str = "N/A"
                if "Osc-SNN Delta=" in model_name:
                    delta_val_str = model_name.split('=')[-1].split(')')[0]
                    delta_str = f"Osc(Δ={delta_val_str})"
                elif "Lorenz-SNN" in model_name:
                    delta_str = "Lorenz"
                elif "CNN-SNN" in model_name and "Osc" not in model_name and "Lorenz" not in model_name:
                    delta_str = "Direct SNN"
                elif "CNN-ANN" in model_name:
                    delta_str = "ANN"

                row = {
                    self.csv_headers[0]: config_name,
                    self.csv_headers[1]: model_name,
                    self.csv_headers[2]: delta_str,
                    self.csv_headers[3]: model_data.get("accuracy", 0.0),
                    self.csv_headers[4]: model_data.get("spikes_at_convergence", 0.0),
                    self.csv_headers[5]: model_data.get("training_time", 0.0),
                    self.csv_headers[6]: model_data.get("epochs_trained", "N/A"),
                    self.csv_headers[7]: model_data.get("convergence_epoch", "N/A")
                }
                rows.append(row)
        if not rows:
            return pd.DataFrame(columns=self.csv_headers)

        df = pd.DataFrame(rows)
        df = df[self.csv_headers]  # enforce column order

        try:
            df[self.csv_headers[3]] = pd.to_numeric(df[self.csv_headers[3]], errors='coerce')  # accuracy
            df[self.csv_headers[4]] = pd.to_numeric(df[self.csv_headers[4]], errors='coerce')  # spikes
            df[self.csv_headers[5]] = pd.to_numeric(df[self.csv_headers[5]], errors='coerce')  # training time
            # Epoch columns may hold "N/A"; coerce those to 0 so the column is integral.
            df[self.csv_headers[6]] = pd.to_numeric(df[self.csv_headers[6]], errors='coerce').fillna(0).astype(int)
            df[self.csv_headers[7]] = pd.to_numeric(df[self.csv_headers[7]], errors='coerce').fillna(0).astype(int)
        except Exception as e:
            print(f"Error formatting summary table columns for run: {e}")

        filepath = os.path.join(self.output_dir, "run_summary.csv")
        try:
            # Format only the saved copy; the returned DataFrame stays numeric.
            df_to_save = df.copy()
            df_to_save[self.csv_headers[3]] = df_to_save[self.csv_headers[3]].map('{:.2f}'.format)
            df_to_save[self.csv_headers[4]] = df_to_save[self.csv_headers[4]].map('{:.0f}'.format)
            df_to_save[self.csv_headers[5]] = df_to_save[self.csv_headers[5]].map('{:.2f}'.format)
            df_to_save.to_csv(filepath, index=False)
            print(f"Run summary table saved to {filepath}")
        except Exception as e:
            print(f"Error saving summary CSV for run: {e}")
        return df


# --- Modified Main Experiment Function for a single run ---
def run_single_experiment(run_idx, base_output_dir):
    """Train and evaluate every model variant once on Tiny ImageNet.

    Results are written under ``<base_output_dir>/run_<run_idx + 1>`` through
    ``ExperimentLogger``.

    Args:
        run_idx: 0-based index of this run (used in messages and in the output
            directory name).
        base_output_dir: Parent directory for all runs.

    Returns:
        pandas.DataFrame: the run's summary table. It is empty (with the
        summary columns) if the dataset or the models could not be created.
    """
    import traceback

    print(f"\n{'=' * 70}")
    print(f"开始第 {run_idx + 1} 次实验运行")
    print(f"{'=' * 70}")

    config = Config()
    config_name = f"{config.dataset_name}_{config.backbone_name}"
    current_run_output_dir = os.path.join(base_output_dir, f"run_{run_idx+1}")
    logger = ExperimentLogger(output_dir=current_run_output_dir)

    osc_delta_mode_b = -1.5  # Expansive
    osc_delta_mode_a = 10.0  # Dissipative

    print(f"使用设备: {device}")
    print(f"实验配置: {config_name}")
    print(f"结果将保存在: {current_run_output_dir}")
    # The remaining configuration values are recorded by logger.log_config below.

    logger.log_config(config_name, config)

    try:
        train_loader, test_loader = load_tiny_imagenet(config)
    except Exception as e:
        print(f"无法加载数据集 {config_name} (运行 {run_idx+1}). 终止此次运行. 错误: {e}")
        traceback.print_exc()
        return pd.DataFrame(columns=logger.csv_headers)

    # --- Define models to run ---
    # Maps display name -> (model, config used for its training). Each model
    # carries its config explicitly; the delta is not re-parsed from the name.
    # Construction order matches the original, so parameter initialisation
    # consumes the global RNG in the same sequence for a given seed.
    models_to_run = {}
    try:
        models_to_run["Baseline (CNN-ANN)"] = (BaseCNN(config, pretrained_backbone=True), config)
        models_to_run["Baseline (CNN-SNN)"] = (BasicCSNN(config, pretrained_backbone=True), config)
        models_to_run["Proposed (CNN-Lorenz-SNN)"] = (CNNLorenzSNN(config, pretrained_backbone=True), config)
        for osc_delta in (osc_delta_mode_b, osc_delta_mode_a):
            osc_config = copy.deepcopy(config)
            osc_config.osc_delta = osc_delta
            models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta})"] = (
                CNNOscSNN(osc_config, pretrained_backbone=True), osc_config
            )
    except Exception as model_init_e:
        print(f"模型初始化失败 (运行 {run_idx+1}): {model_init_e}")
        return pd.DataFrame(columns=logger.csv_headers)

    for model_name, (current_model, current_config_for_model) in models_to_run.items():
        print(f"\n--- 正在为 {config_name} 训练 {model_name} (运行 {run_idx+1}) ---")
        start_time = time.time()
        try:
            best_acc, epochs_history, spikes_at_convergence, spike_counts_all_epochs, best_epoch_idx = train_and_evaluate_with_history(
                current_model, train_loader, test_loader, current_config_for_model, current_run_output_dir
            )
            training_time = time.time() - start_time

            print(f"--- {model_name} 完成. 最佳准确率: {best_acc:.2f}%, "
                  f"收敛时脉冲数 (Epoch {best_epoch_idx+1}): {spikes_at_convergence:.0f}, "
                  f"训练时间: {training_time:.2f}s, 训练轮数: {len(epochs_history)} ---")

            logger.log_model_result(
                config_name, model_name, best_acc, training_time,
                epochs_history, spikes_at_convergence, spike_counts_all_epochs, best_epoch_idx
            )
        except Exception as e:
            training_time = time.time() - start_time
            print(f"!!! {model_name} 在 {config_name} (运行 {run_idx+1}) 的训练/评估过程中发生错误: {e}")
            logger.log_model_result(config_name, model_name, 0.0, training_time, [], 0.0, [], 0)
            traceback.print_exc()
        finally:
            # models_to_run keeps every model alive until this function returns;
            # move the finished one off the accelerator so later models get the memory.
            current_model.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"\n--- 第 {run_idx+1} 次运行的所有模型已处理完毕 ---")
    logger.save_results_json()
    logger.save_spike_data_csv()
    summary_df_for_run = logger.generate_summary_table()

    print(f"第 {run_idx+1} 次运行完成! 结果保存在 '{logger.output_dir}' 目录中")
    if not summary_df_for_run.empty:
        print(f"第 {run_idx+1} 次运行结果汇总:")
        print(summary_df_for_run.to_string(index=False))
    else:
        print(f"第 {run_idx+1} 次运行结果汇总为空.")
    return summary_df_for_run


# --- New function to run multiple experiments and aggregate ---
def _seed_everything(seed):
    """Seed torch (CPU, plus all CUDA devices when available) and numpy."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _summarize_across_runs(combined_summary_df):
    """Aggregate per-run summaries into mean/std per (config, model, delta/system).

    Args:
        combined_summary_df: Concatenation of the per-run summary tables. It must
            contain the grouping columns and the metric columns listed below.

    Returns:
        A flat DataFrame holding the grouping columns, then "<short> Mean" and
        "<short> Std" for every metric. Std is NaN for groups with one run only
        (pandas uses the sample standard deviation, ddof=1).
    """
    # Explicit short names. The previous `metric.split(" ")[0]` turned
    # "Best Test Accuracy (%)" into "Best", not the intended "Accuracy".
    metric_short_names = {
        "Best Test Accuracy (%)": "Accuracy",
        "Spikes at Convergence": "Spikes",
        "Training Time (s)": "Training Time",
    }
    metrics_to_analyze = list(metric_short_names)
    grouping_columns = ["Config Name", "Model", "Delta/System"]

    # Work on a copy so the caller's frame is not mutated.
    df = combined_summary_df.copy()
    for metric in metrics_to_analyze:
        df[metric] = pd.to_numeric(df[metric], errors='coerce')

    # dropna=False: keep groups whose key is missing (e.g. a model without a
    # delta value) instead of silently discarding them.
    stats_summary = (
        df.groupby(grouping_columns, dropna=False)[metrics_to_analyze]
        .agg(['mean', 'std'])
        .reset_index()
    )

    # Flatten the (column, statistic) MultiIndex; the grouping columns come
    # back as (name, "") and strip to just their name.
    stats_summary.columns = [' '.join(col).strip() for col in stats_summary.columns.values]

    renames = {}
    for metric, short in metric_short_names.items():
        renames[f'{metric} mean'] = f'{short} Mean'
        renames[f'{metric} std'] = f'{short} Std'
    return stats_summary.rename(columns=renames)


def run_multiple_experiments(num_runs, base_output_dir_name="all_experiment_runs"):
    """Run the whole experiment ``num_runs`` times and aggregate the results.

    Run ``i`` (0-based) is seeded with ``42 + i`` and delegated to
    ``run_single_experiment``.

    Outputs, written to the base output directory:
        - ``combined_all_runs_summary.csv``: every non-empty per-run summary,
          tagged with a 1-based "Run Index".
        - ``statistical_summary_across_runs.csv``: mean/std of accuracy, spikes
          at convergence and training time, per (config, model, delta/system).

    Args:
        num_runs: Number of independent runs.
        base_output_dir_name: Directory created next to this script (or in the
            current working directory when ``__file__`` is undefined, as in a
            notebook) that receives all run outputs.

    Returns:
        The statistics DataFrame, or ``None`` if no run produced a summary.
    """
    all_run_summaries = []  # one summary DataFrame per successful run

    script_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
    base_output_dir = os.path.join(script_dir, base_output_dir_name)
    os.makedirs(base_output_dir, exist_ok=True)
    print(f"所有实验运行的基础目录: {base_output_dir}")

    for i in range(num_runs):
        # Distinct but reproducible seed per run.
        current_seed = 42 + i
        _seed_everything(current_seed)
        print(f"Run {i+1}/{num_runs} with seed {current_seed}")

        summary_df = run_single_experiment(run_idx=i, base_output_dir=base_output_dir)
        if summary_df is not None and not summary_df.empty:
            summary_df['Run Index'] = i + 1  # tag rows for later aggregation
            all_run_summaries.append(summary_df)
        else:
            print(f"警告: 第 {i+1} 次运行没有返回有效的结果摘要。")

    if not all_run_summaries:
        print("所有实验运行均未产生有效结果。无法进行统计分析。")
        return None

    combined_summary_df = pd.concat(all_run_summaries, ignore_index=True)
    combined_summary_path = os.path.join(base_output_dir, "combined_all_runs_summary.csv")
    combined_summary_df.to_csv(combined_summary_path, index=False)
    print(f"\n所有运行的合并摘要已保存到: {combined_summary_path}")

    print("\n" + "="*70)
    print("统计分析 (均值 ± 标准差)")
    print("="*70)

    stats_summary = _summarize_across_runs(combined_summary_df)

    print("统计结果:")
    # Scope the display options to this print. The old set_option +
    # reset_option('all') wiped every pandas option in the process.
    with pd.option_context(
        'display.max_rows', None,
        'display.max_columns', None,
        'display.width', 2000,
        'display.max_colwidth', None,
    ):
        print(stats_summary.to_string(index=False))

    stats_summary_path = os.path.join(base_output_dir, "statistical_summary_across_runs.csv")
    stats_summary.to_csv(stats_summary_path, index=False)
    print(f"\n统计摘要已保存到: {stats_summary_path}")
    print("="*70)
    print("多次实验运行及分析完成!")
    print("="*70)
    return stats_summary


# --- Run Experiment ---
if __name__ == "__main__":
    # No global seeding here: run_multiple_experiments seeds torch/numpy per
    # run (42 + run index). For stricter GPU determinism, also consider
    # torch.backends.cudnn.deterministic = True and
    # torch.backends.cudnn.benchmark = False, at some cost in speed.

    # Number of independent repetitions of the full experiment.
    NUMBER_OF_RUNS = 5
    # Name of the directory that collects the outputs of every run.
    BASE_OUTPUT_DIRECTORY_NAME = "TinyImageNet_ResNet18_ChaosSNN_Runs_128"

    run_multiple_experiments(
        num_runs=NUMBER_OF_RUNS,
        base_output_dir_name=BASE_OUTPUT_DIRECTORY_NAME,
    )



所有实验运行的基础目录: /root/TinyImageNet_ResNet18_ChaosSNN_Runs_128
Run 1/5 with seed 42

开始第 1 次实验运行
使用设备: cuda
实验配置: TinyImageNet_ResNet-18_pretrained
结果将保存在: /root/TinyImageNet_ResNet18_ChaosSNN_Runs_128/run_1
Tiny ImageNet - Found 100000 training images belonging to 200 classes.
Tiny ImageNet - Found 10000 validation images.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.

--- 正在为 TinyImageNet_ResNet-18_pretrained 训练 Baseline (CNN-ANN) (运行 1) ---
--- Starting Training (Max Epochs: 200, Patience: 15) ---
Epoch [1/200] Loss: 4.0024 Train Acc: 13.30% Test Acc: 27.82% Total Spikes: 0 Epoch Time: 71.83s LR: 1.0e-05
  -> New best test accuracy: 27.82%. Model saved.
Epoch [2/200] Loss: 2.6472 Train Acc: 35.66% Test Acc: 44.80% Total Spikes: 0 Epoch Time: 70.30s LR: 1.0e-05
  -> New best test accuracy: 44.80%. Model saved.
Epoch [3/200] Loss: 2.1103 Train Acc: 47.