In [3]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder # Useful for Tiny ImageNet structure
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt # Keep if needed later
import numpy as np
import pandas as pd
import time
import copy
# from itertools import product # No longer needed for CNN configs
import os
import json
import snntorch as snn
import snntorch.surrogate as surrogate
from snntorch import utils
from snntorch import functional as SF
from PIL import Image # Needed for TinyImageNet loading potentially

# --- Configuration Class (Updated for Tiny ImageNet & ResNet) ---
class Config:
    # 数据集
    dataset_name = "TinyImageNet"
    data_root = './tiny-imagenet-200' # <<<--- IMPORTANT: SET PATH TO YOUR TINY IMAGENET FOLDER
    batch_size = 64 # May need to reduce based on GPU memory with ResNet
    input_size = 224 # Standard input size for ImageNet pre-trained models
    num_classes = 200 # Tiny ImageNet has 200 classes

    # CNN Backbone (Fixed to ResNet-18)
    backbone_name = "ResNet-18_pretrained"

    # Common SNN/Encoding Parameters
    # Projecting 512 ResNet features to 32 might be aggressive.
    # Consider increasing chaos_dim, e.g., 128 or 256, for better results.
    chaos_dim = 64 # Dimension for projection before oscillator/lorenz/SNN layers
    num_steps = 5 # Reduced num_steps initially for faster testing with ResNet

    # --- Oscillator Parameters ---
    osc_alpha = 2.0
    osc_beta = 0.1
    osc_gamma = 0.1
    osc_omega = 1.0
    osc_drive = 0.0
    # osc_delta will be set specifically
    osc_dt = 0.05

    # --- Lorenz Parameters ---
    lorenz_sigma = 10.0
    lorenz_rho = 28.0
    lorenz_beta = 8.0/3.0
    lorenz_dt = 0.05

    # SNN Decay Rate
    beta = 0.95 # SNN Leaky Neuron Beta (Decay Rate)

    # 训练
    epochs = 200 # Adjust max epochs for Tiny ImageNet & fine-tuning
    # Learning rate might need adjustment for fine-tuning ResNet
    lr = 1e-4 # Lower initial LR often better for fine-tuning
    weight_decay = 5e-4
    # Early Stopping
    patience = 15 # Maybe increase patience slightly

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spike_grad = surrogate.fast_sigmoid()

# --- OscillatorTransformFast Class (Unchanged) ---
class OscillatorTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.alpha = config.osc_alpha
        self.beta_osc = config.osc_beta
        self.gamma = config.osc_gamma
        self.delta = getattr(config, 'osc_delta', 0.0)
        self.omega = config.osc_omega
        self.dt = config.osc_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        current_delta = self.delta
        state = torch.cat([x, x*0.2, -x], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = y_cur
            dy = -self.alpha * x_cur - self.beta_osc * (x_cur**3) - current_delta * y_cur + self.gamma * z_cur
            dz = -self.omega * x_cur - current_delta * z_cur + self.gamma * x_cur * y_cur
            derivatives = torch.cat([dx, dy, dz], dim=1)
            state = state + self.dt * derivatives
            trajectories[:, t, :] = state
        return trajectories

# --- LorenzTransformFast Class (Unchanged) ---
class LorenzTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.sigma = config.lorenz_sigma
        self.rho = config.lorenz_rho
        self.lorenz_beta_param = config.lorenz_beta
        self.dt = config.lorenz_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        state = torch.cat([
            x,
            0.2*x,
            -x
        ], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = self.sigma * (y_cur - x_cur)
            dy = x_cur * (self.rho - z_cur) - y_cur
            dz = x_cur * y_cur - self.lorenz_beta_param * z_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            trajectories[:, t, :] = state
        return trajectories

# --- Helper function to get ResNet-18 backbone ---
def _get_resnet_backbone(pretrained=True):
    """Loads a pretrained ResNet-18 model and removes the final fc layer."""
    if pretrained:
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        model = torchvision.models.resnet18(weights=weights)
        print("Loaded PRETRAINED ResNet-18 weights.")
    else:
        model = torchvision.models.resnet18(weights=None)
        print("Initialized ResNet-18 weights FROM SCRATCH.")

    # Remove the final fully connected layer (classifier)
    model.fc = nn.Identity() # Replace fc layer with identity
    return model

# --- CNNOscSNN Class (Unchanged, relies on config.num_classes) ---
class CNNOscSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.oscillator = OscillatorTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.oscillator(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- CNNLorenzSNN Class (Unchanged, relies on config.num_classes) ---
class CNNLorenzSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lorenz = LorenzTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.lorenz(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BasicCSNN Class (Unchanged, relies on config.num_classes) ---
class BasicCSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        cnn_features = torch.tanh(self.proj(features))
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        total_output_mem = 0
        batch_total_spikes = 0.0
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cnn_features, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            total_output_mem += self.fc_out(spk2)
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = total_output_mem / self.num_steps
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BaseCNN Class (Unchanged, relies on config.num_classes) ---
class BaseCNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.fc1 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.fc2 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.classifier = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, **kwargs):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.classifier(x)
        return x

# --- 修改后的 Evaluate 函数 - 返回 total_spikes 而不是 average ---
def evaluate(model, loader, config):
    """Evaluates the model, returns accuracy and total spikes."""
    model.eval()
    correct = 0
    total = 0
    total_spikes_evaluated = 0.0
    is_snn = isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN))

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_snn:
                outputs, batch_spikes = model(images, return_spikes=True)
                total_spikes_evaluated += batch_spikes.item()
            else:
                outputs = model(images) # **kwargs in BaseCNN handles extra args
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total
    # 修改: 返回 total_spikes 而不是平均值
    return accuracy, total_spikes_evaluated


# --- 修改后的 Train and Evaluate with History ---
def train_and_evaluate_with_history(model, train_loader, test_loader, config):
    """Trains model with early stopping, returns best test acc, epoch history, and spikes at convergence."""
    model = model.to(device)
    # --- OPTIONAL: Differential Learning Rate ---
    # You might want different LRs for backbone and head
    # Example:
    head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone.')]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('backbone.')]
    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': config.lr * 0.1}, # Lower LR for backbone
        {'params': head_params, 'lr': config.lr} # Normal LR for head
    ], weight_decay=config.weight_decay)
    # --- Using single LR for simplicity now ---
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()

    best_test_acc_epoch = 0.0
    epochs_no_improve = 0
    patience = config.patience
    # Ensure output directory exists before saving temp model
    output_dir = f"experiment_results_{config.dataset_name}_{config.backbone_name}"
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists
    best_model_path = os.path.join(output_dir, f"temp_best_model_{time.time()}_{id(model)}.pth") # Save inside output dir
    
    # 修改: 添加 spike 记录
    history = []
    spike_counts = []  # 记录每个epoch的spikes
    best_epoch = 0  # 记录达到最佳性能的epoch

    print(f"--- Starting Training (Max Epochs: {config.epochs}, Patience: {patience}) ---")

    for epoch in range(config.epochs):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        start_epoch_time = time.time()

        for i, (images, labels) in enumerate(train_loader): # Add index i for progress printing
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images, return_spikes=False) # No need for spikes during training loop
            loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                print(f"WARNING: Loss is {loss.item()} at epoch {epoch+1}, batch {i}. Skipping backward pass.")
                continue # Skip batch if loss is invalid

            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            total_loss += loss.item()

            # Optional: Print progress within epoch
            # if (i + 1) % 100 == 0:
            #     print(f'  Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')


        end_epoch_time = time.time()
        epoch_duration = end_epoch_time - start_epoch_time

        # 修改: 在每个epoch结束后评估并记录spikes
        test_acc, epoch_spikes = evaluate(model, test_loader, config)
        history.append(test_acc)
        spike_counts.append(epoch_spikes)  # 记录总spike数
        
        avg_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = 100. * train_correct / train_total if train_total > 0 else 0

        # 修改: 打印spike信息
        print(f"Epoch [{epoch + 1}/{config.epochs}] Loss: {avg_loss:.4f} "
              f"Train Acc: {train_acc:.2f}% Test Acc: {test_acc:.2f}% "
              f"Total Spikes: {epoch_spikes:.0f} "
              f"Epoch Time: {epoch_duration:.2f}s LR: {scheduler.get_last_lr()[0]:.1e}")

        # Early Stopping Logic
        if test_acc > best_test_acc_epoch:
            best_test_acc_epoch = test_acc
            best_epoch = epoch  # 记录最佳epoch索引
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), best_model_path)
                print(f"  -> New best test accuracy: {best_test_acc_epoch:.2f}%. Model saved.")
            except Exception as e:
                print(f"  -> Error saving model: {e}")
        else:
            epochs_no_improve += 1
            print(f"  -> Test accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_test_acc_epoch:.2f}%")

        if epochs_no_improve >= patience:
            print(f"\n--- Early stopping triggered after {epoch + 1} epochs. ---")
            break

        scheduler.step()

    # Post-Training: Load Best Model and Final Evaluation
    print("\n--- Training finished. Loading best model for final evaluation. ---")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path))
            print(f"Successfully loaded best model state (Accuracy: {best_test_acc_epoch:.2f}%)")
            os.remove(best_model_path)
            # print(f"Removed temporary model file: {best_model_path}") # Optional: uncomment to confirm
        except Exception as e:
            print(f"Error loading best model state from {best_model_path}: {e}. Using model from last epoch.")
            best_test_acc_epoch = history[-1] if history else 0.0
    else:
        print("No best model was saved (or file missing). Using model from last epoch.")
        best_test_acc_epoch = history[-1] if history else 0.0

    print("--- Running final evaluation on the best performing model state ---")
    final_test_acc, final_total_spikes = evaluate(model, test_loader, config)

    # 修改: 获取收敛时的spike数
    spikes_at_convergence = spike_counts[best_epoch] if spike_counts and best_epoch < len(spike_counts) else 0

    print(f"Final Evaluation - Test Accuracy: {final_test_acc:.2f}%")
    if isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN)):
        print(f"Final Evaluation - Total Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}")
        print(f"Final Evaluation - Total Spikes in Final Evaluation: {final_total_spikes:.0f}")

    # 修改: 返回收敛时的spikes (不是平均值)
    return best_test_acc_epoch, history, spikes_at_convergence, spike_counts

# --- Tiny ImageNet Loading Function (未更改) ---
def load_tiny_imagenet(config):
    """Loads the Tiny ImageNet dataset."""
    data_dir = config.data_root
    num_workers = min(4, os.cpu_count()) if os.cpu_count() else 0
    image_size = config.input_size # Should be 224 for pre-trained ResNet

    # Standard ImageNet normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Data augmentation and normalization for training
    # Adjust augmentation based on standard practices for ImageNet fine-tuning
    train_transform = transforms.Compose([
        transforms.Resize(image_size), # Resize first
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)), # Standard crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Just normalization for validation/testing
    val_transform = transforms.Compose([
        transforms.Resize(image_size), # Ensure validation images are also resized
        transforms.CenterCrop(image_size), # Use CenterCrop for validation
        transforms.ToTensor(),
        normalize,
    ])

    # --- Dataset Loading ---
    # Tiny ImageNet structure: train/[wnid]/images/*.JPEG, val/images/*.JPEG, val/val_annotations.txt
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val', 'images') # Validation images are flat

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
         raise FileNotFoundError(f"Tiny ImageNet data not found at expected paths: {train_dir} and {val_dir}. "
                                f"Please ensure Tiny ImageNet is downloaded and extracted to '{data_dir}' "
                                "following the standard directory structure.")

    train_dataset = ImageFolder(train_dir, transform=train_transform)

    # Validation dataset requires special handling due to annotations file
    # Creating a custom Dataset class is cleaner
    class TinyImageNetVal(Dataset):
        def __init__(self, val_dir, annotations_file, class_to_idx, transform=None):
            self.val_dir = val_dir
            self.transform = transform
            self.class_to_idx = class_to_idx
            self.samples = []
            try:
                with open(annotations_file, 'r') as f:
                    for line in f:
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            img_name, wnid = parts[0], parts[1]
                            img_path = os.path.join(self.val_dir, img_name)
                            if os.path.exists(img_path) and wnid in self.class_to_idx:
                                self.samples.append((img_path, self.class_to_idx[wnid]))
                            # else:
                                # print(f"Warning: Skipping invalid validation entry: {line.strip()}")
            except FileNotFoundError:
                 raise FileNotFoundError(f"Validation annotations file not found: {annotations_file}")


        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            img_path, target = self.samples[idx]
            # Ensure images are loaded in RGB format
            try:
                with open(img_path, 'rb') as f:
                    img = Image.open(f).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                # Return a dummy image/target or handle appropriately
                return torch.zeros(3, image_size, image_size), -1 # Indicate error

            if self.transform:
                img = self.transform(img)
            return img, target

    # Need class_to_idx mapping from the training set folders
    class_to_idx = train_dataset.class_to_idx
    val_annotations_file = os.path.join(data_dir, 'val', 'val_annotations.txt')
    val_dataset = TinyImageNetVal(val_dir, val_annotations_file, class_to_idx, transform=val_transform)


    print(f"Tiny ImageNet - Found {len(train_dataset)} training images belonging to {len(train_dataset.classes)} classes.")
    print(f"Tiny ImageNet - Found {len(val_dataset)} validation images.")


    # Data Loaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    # Use validation set as the test set for Tiny ImageNet evaluation
    test_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader


# --- 修改 Experiment Logger Class ---
class ExperimentLogger:
    def __init__(self, output_dir="experiment_results_tinyimagenet_resnet18"): # Changed dir name
        self.results = {}
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # 修改: 更新CSV标题，将Average Spikes改为Spikes at Convergence
        self.csv_headers = [
            "Config Name", # Renamed from CNN Config
            "Model",
            "Delta/System",
            "Best Test Accuracy (%)",
            "Spikes at Convergence", # 修改: 更改列名
            "Training Time (s)",
            "Epochs Trained",
            "Convergence Epoch" # 添加收敛epoch的列
        ]

    def log_config(self, config_name, config):
        config_dict = {}
        config_dict['patience'] = config.patience
        for key, value in vars(config).items():
             if isinstance(value, (int, float, str, bool, list, tuple, dict, type(None))):
                 config_dict[key] = value
             # ... (rest of serialization logic from previous version)
        self.results[config_name] = {
            "config": config_dict,
            "models": {}
        }

    # 修改: 更新log_model_result参数和逻辑
    def log_model_result(self, config_name, model_name, accuracy, training_time, epochs_history, spikes_at_convergence, spike_counts, best_epoch):
        if config_name not in self.results:
            self.results[config_name] = {"config": {}, "models": {}}
            print(f"Warning: Config '{config_name}' not pre-logged. Creating entry.")
        epochs_trained = len(epochs_history)
        self.results[config_name]["models"][model_name] = {
            "accuracy": accuracy,
            "training_time": training_time,
            "epochs_history": epochs_history, # Storing history can make JSON large
            "spikes_at_convergence": spikes_at_convergence, # 修改: 改为收敛时的spikes
            "epochs_trained": epochs_trained,
            "spike_counts": spike_counts, # 添加: 存储每个epoch的spike数据
            "convergence_epoch": best_epoch + 1 # 添加: 转换为1-indexed的epoch号
        }
        print(f"Logged for {config_name} - {model_name}: Best Acc={accuracy:.2f}, Spikes at Convergence={spikes_at_convergence:.0f}, "
              f"Time={training_time:.2f}s, Epochs={epochs_trained}, Convergence Epoch={best_epoch+1}")

    def save_results(self):
        filepath = os.path.join(self.output_dir, "results.json")
        try:
            # Save only essential results to JSON to avoid large files due to history
            results_to_save = {}
            for cfg_name, cfg_data in self.results.items():
                results_to_save[cfg_name] = {"config": cfg_data["config"], "models": {}}
                for mdl_name, mdl_data in cfg_data["models"].items():
                    results_to_save[cfg_name]["models"][mdl_name] = {
                        k: v for k, v in mdl_data.items() if k not in ['epochs_history', 'spike_counts']
                    }
            with open(filepath, "w") as f:
                json.dump(results_to_save, f, indent=4, default=lambda o: '<not serializable>')
            print(f"Results saved to {filepath}")
            
            # 添加: 保存spike数据到CSV文件
            for cfg_name, cfg_data in self.results.items():
                for mdl_name, mdl_data in cfg_data["models"].items():
                    if 'spike_counts' in mdl_data and mdl_data['spike_counts']:
                        spike_df = pd.DataFrame({
                            'Epoch': range(1, len(mdl_data['spike_counts'])+1),
                            'Total Spikes': mdl_data['spike_counts']
                        })
                        spike_file = os.path.join(self.output_dir, f"{cfg_name}_{mdl_name}_spikes.csv")
                        spike_df.to_csv(spike_file, index=False)
                        print(f"Spike data saved to {spike_file}")
        except Exception as e:
            print(f"Error saving JSON results: {e}")


    def generate_summary_table(self):
        rows = []
        for config_name, data in self.results.items():
            for model_name, model_data in data["models"].items():
                delta_str = "N/A"
                # Simplified model type identification for summary
                if "Osc-SNN Delta=" in model_name:
                    delta_str = f"Osc(Δ={model_name.split('=')[-1].split(')')[0]})"
                elif "Lorenz-SNN" in model_name:
                    delta_str = "Lorenz"
                elif "CNN-SNN" in model_name and "Osc" not in model_name and "Lorenz" not in model_name:
                     delta_str = "Direct SNN"
                elif "CNN-ANN" in model_name:
                     delta_str = "ANN"

                spikes = model_data.get("spikes_at_convergence", 0.0)  # 修改: 使用收敛时的spikes
                epochs_trained = model_data.get("epochs_trained", "N/A")
                convergence_epoch = model_data.get("convergence_epoch", "N/A")  # 添加: 获取收敛epoch

                row = {
                    self.csv_headers[0]: config_name, # Use "Config Name"
                    self.csv_headers[1]: model_name,
                    self.csv_headers[2]: delta_str,
                    self.csv_headers[3]: model_data["accuracy"],
                    self.csv_headers[4]: spikes,  # 修改: 使用收敛时的spikes
                    self.csv_headers[5]: model_data["training_time"],
                    self.csv_headers[6]: epochs_trained,
                    self.csv_headers[7]: convergence_epoch  # 添加: 收敛epoch
                }
                rows.append(row)
        if not rows: return pd.DataFrame(columns=self.csv_headers)
        df = pd.DataFrame(rows)
        df = df[self.csv_headers] # Ensure column order
        try:
            df[self.csv_headers[3]] = pd.to_numeric(df[self.csv_headers[3]], errors='coerce').map('{:.2f}'.format)
            df[self.csv_headers[4]] = pd.to_numeric(df[self.csv_headers[4]], errors='coerce').map('{:.0f}'.format)  # 修改: 整数格式
            df[self.csv_headers[5]] = pd.to_numeric(df[self.csv_headers[5]], errors='coerce').map('{:.2f}'.format)
        except Exception as e: print(f"Error formatting summary table columns: {e}")
        filepath = os.path.join(self.output_dir, "summary.csv")
        try: df.to_csv(filepath, index=False); print(f"Summary table saved to {filepath}")
        except Exception as e: print(f"Error saving summary CSV: {e}")
        return df

    # --- Plotting (Requires Matplotlib) ---
    # Consider simplifying or removing plotting if matplotlib is not available/needed now
    # def plot_results(self):
    #     # ... (Plotting code from previous version - needs matplotlib)
    #     # If keeping plots, update logic to handle single config name and model types
    #     pass

# --- 修改 Main Experiment Function ---
def run_experiment():
    print(f"使用设备: {device}")
    config = Config() # Use the updated config
    config_name = f"{config.dataset_name}_{config.backbone_name}" # Single config name
    logger = ExperimentLogger(output_dir=f"experiment_results_{config_name}")

    # Define the specific oscillator deltas for the two modes
    # Use values inspired by Table D1, e.g., one expansive, one dissipative
    osc_delta_mode_b = -1.5 # High performance potential (Expansive)
    osc_delta_mode_a = 10.0 # High efficiency (Dissipative)

    print(f"\n{'=' * 60}")
    print(f"开始实验配置: {config_name}")
    print(f"Dataset: {config.dataset_name} (Root: {config.data_root}, Classes: {config.num_classes}, Input Size: {config.input_size})")
    print(f"Backbone: {config.backbone_name}")
    print(f"Max Epochs: {config.epochs}, Patience: {config.patience}, LR: {config.lr}")
    print(f"Osc Params: alpha={config.osc_alpha}, beta={config.osc_beta}, gamma={config.osc_gamma}, omega={config.osc_omega}, dt={config.osc_dt}")
    print(f"Lorenz Params: sigma={config.lorenz_sigma:.2f}, rho={config.lorenz_rho:.2f}, beta={config.lorenz_beta:.2f}, dt={config.lorenz_dt}")
    print(f"SNN Params: steps={config.num_steps}, decay_beta={config.beta:.2f}, chaos_dim={config.chaos_dim}")
    print(f"{'=' * 60}")

    logger.log_config(config_name, config)

    try:
        # Load Tiny ImageNet data
        train_loader, test_loader = load_tiny_imagenet(config)
    except Exception as e:
        print(f"无法加载数据集 {config_name}. 终止实验. 错误: {e}")
        import traceback
        traceback.print_exc()
        return None # Exit if data loading fails

    # --- Define Models to Run ---
    # Ensure pretrained_backbone=True is passed correctly
    models_to_run = {
        "Baseline (CNN-ANN)": BaseCNN(config, pretrained_backbone=True),
        "Baseline (CNN-SNN)": BasicCSNN(config, pretrained_backbone=True),
        "Proposed (CNN-Lorenz-SNN)": CNNLorenzSNN(config, pretrained_backbone=True)
    }
    # Add Oscillator SNNs with specific deltas
    osc_config_b = copy.deepcopy(config); osc_config_b.osc_delta = osc_delta_mode_b
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_b})"] = CNNOscSNN(osc_config_b, pretrained_backbone=True)

    osc_config_a = copy.deepcopy(config); osc_config_a.osc_delta = osc_delta_mode_a
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_a})"] = CNNOscSNN(osc_config_a, pretrained_backbone=True)


    # --- Train and Evaluate Models ---
    for model_name, model_instance in models_to_run.items():
        print(f"\n--- Training {model_name} for {config_name} ---")
        start_time = time.time()
        # Use a fresh copy for each training run (already done by creating new instances above)
        current_model = model_instance
        current_config = config # Default config for non-oscillator models
        if "Osc-SNN Delta=" in model_name:
             delta_val = float(model_name.split('=')[-1].split(')')[0])
             if delta_val == osc_delta_mode_a: current_config = osc_config_a
             elif delta_val == osc_delta_mode_b: current_config = osc_config_b
             else: print(f"Warning: Could not match delta for {model_name}")


        try:
            # 修改: 添加对spike_counts和best_epoch的接收
            best_acc, epochs_history, spikes_at_convergence, spike_counts = train_and_evaluate_with_history(
                current_model, train_loader, test_loader, current_config # Pass correct config
            )
            end_time = time.time()
            training_time = end_time - start_time
            
            # 计算best_epoch (0-indexed)
            best_epoch = epochs_history.index(best_acc) if best_acc in epochs_history else 0
            
            print(f"--- {model_name} finished. Best Acc: {best_acc:.2f}%, "
                 f"Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}, "
                 f"Time: {training_time:.2f}s, Epochs Trained: {len(epochs_history)} ---")
                 
            # 修改: 更新logger.log_model_result调用
            logger.log_model_result(
                config_name, model_name, best_acc, training_time, 
                epochs_history, spikes_at_convergence, spike_counts, best_epoch
            )
        except Exception as e:
            end_time = time.time()
            training_time = end_time - start_time
            print(f"!!! ERROR during training/evaluation for {model_name} on {config_name}: {e}")
            # 修改: 更新错误情况下的logger.log_model_result调用
            logger.log_model_result(config_name, model_name, 0.0, training_time, [], 0.0, [], 0)
            import traceback
            traceback.print_exc()
            # Decide whether to continue with other models upon error
            # continue


    # --- Finalize and Save Results ---
    print("\n--- All models processed for this configuration ---")
    logger.save_results()
    summary_df = logger.generate_summary_table()
    # try:
    #     logger.plot_results() # Optional plotting
    # except Exception as e:
    #     print(f"An error occurred during final plotting: {e}")
    #     import traceback
    #     traceback.print_exc()

    print("\n" + "="*60)
    print(f"实验完成! 结果保存在 '{logger.output_dir}' 目录中")
    print("="*60 + "\n")
    if not summary_df.empty:
        print("结果汇总:")
        # Configure pandas for wider output
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 2000)
        pd.set_option('display.max_colwidth', None) # Show full column width
        print(summary_df.to_string(index=False)) # Use to_string for better control
        pd.reset_option('all') # Reset pandas display options
    else:
        print("结果汇总为空.")

    return logger

# --- Run Experiment ---
if __name__ == "__main__":
    # Optional: Set seeds for reproducibility
    # seed = 42
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(seed)
    # # Note: Full determinism can impact performance and might not be guaranteed on GPU
    # # torch.backends.cudnn.deterministic = True
    # # torch.backends.cudnn.benchmark = False

    logger = run_experiment()

使用设备: cuda

开始实验配置: TinyImageNet_ResNet-18_pretrained
Dataset: TinyImageNet (Root: ./tiny-imagenet-200, Classes: 200, Input Size: 224)
Backbone: ResNet-18_pretrained
Max Epochs: 200, Patience: 15, LR: 0.0001
Osc Params: alpha=2.0, beta=0.1, gamma=0.1, omega=1.0, dt=0.05
Lorenz Params: sigma=10.00, rho=28.00, beta=2.67, dt=0.05
SNN Params: steps=5, decay_beta=0.95, chaos_dim=64
Tiny ImageNet - Found 100000 training images belonging to 200 classes.
Tiny ImageNet - Found 10000 validation images.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.

--- Training Baseline (CNN-ANN) for TinyImageNet_ResNet-18_pretrained ---
--- Starting Training (Max Epochs: 200, Patience: 15) ---
Epoch [1/200] Loss: 4.4196 Train Acc: 8.08% Test Acc: 19.44% Total Spikes: 0 Epoch Time: 50.72s LR: 1.0e-05
  -> New best test accuracy: 19.44%. Model saved.
Epoch [2/200] Loss: 3.1506

In [5]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder # Useful for Tiny ImageNet structure
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt # Keep if needed later
import numpy as np
import pandas as pd
import time
import copy
# from itertools import product # No longer needed for CNN configs
import os
import json
import snntorch as snn
import snntorch.surrogate as surrogate
from snntorch import utils
from snntorch import functional as SF
from PIL import Image # Needed for TinyImageNet loading potentially

# --- Configuration Class (Updated for Tiny ImageNet & ResNet) ---
class Config:
    # 数据集
    dataset_name = "TinyImageNet"
    data_root = './tiny-imagenet-200' # <<<--- IMPORTANT: SET PATH TO YOUR TINY IMAGENET FOLDER
    batch_size = 64 # May need to reduce based on GPU memory with ResNet
    input_size = 224 # Standard input size for ImageNet pre-trained models
    num_classes = 200 # Tiny ImageNet has 200 classes

    # CNN Backbone (Fixed to ResNet-18)
    backbone_name = "ResNet-18_pretrained"

    # Common SNN/Encoding Parameters
    # Projecting 512 ResNet features to 32 might be aggressive.
    # Consider increasing chaos_dim, e.g., 128 or 256, for better results.
    chaos_dim = 128 # Dimension for projection before oscillator/lorenz/SNN layers
    num_steps = 5 # Reduced num_steps initially for faster testing with ResNet

    # --- Oscillator Parameters ---
    osc_alpha = 2.0
    osc_beta = 0.1
    osc_gamma = 0.1
    osc_omega = 1.0
    osc_drive = 0.0
    # osc_delta will be set specifically
    osc_dt = 0.05

    # --- Lorenz Parameters ---
    lorenz_sigma = 10.0
    lorenz_rho = 28.0
    lorenz_beta = 8.0/3.0
    lorenz_dt = 0.05

    # SNN Decay Rate
    beta = 0.95 # SNN Leaky Neuron Beta (Decay Rate)

    # 训练
    epochs = 200 # Adjust max epochs for Tiny ImageNet & fine-tuning
    # Learning rate might need adjustment for fine-tuning ResNet
    lr = 1e-4 # Lower initial LR often better for fine-tuning
    weight_decay = 5e-4
    # Early Stopping
    patience = 15 # Maybe increase patience slightly

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spike_grad = surrogate.fast_sigmoid()

# --- OscillatorTransformFast Class (Unchanged) ---
class OscillatorTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.alpha = config.osc_alpha
        self.beta_osc = config.osc_beta
        self.gamma = config.osc_gamma
        self.delta = getattr(config, 'osc_delta', 0.0)
        self.omega = config.osc_omega
        self.dt = config.osc_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        current_delta = self.delta
        state = torch.cat([x, x*0.2, -x], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = y_cur
            dy = -self.alpha * x_cur - self.beta_osc * (x_cur**3) - current_delta * y_cur + self.gamma * z_cur
            dz = -self.omega * x_cur - current_delta * z_cur + self.gamma * x_cur * y_cur
            derivatives = torch.cat([dx, dy, dz], dim=1)
            state = state + self.dt * derivatives
            trajectories[:, t, :] = state
        return trajectories

# --- LorenzTransformFast Class (Unchanged) ---
class LorenzTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.sigma = config.lorenz_sigma
        self.rho = config.lorenz_rho
        self.lorenz_beta_param = config.lorenz_beta
        self.dt = config.lorenz_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        state = torch.cat([
            x,
            0.2*x,
            -x
        ], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = self.sigma * (y_cur - x_cur)
            dy = x_cur * (self.rho - z_cur) - y_cur
            dz = x_cur * y_cur - self.lorenz_beta_param * z_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            trajectories[:, t, :] = state
        return trajectories

# --- Helper function to get ResNet-18 backbone ---
def _get_resnet_backbone(pretrained=True):
    """Loads a pretrained ResNet-18 model and removes the final fc layer."""
    if pretrained:
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        model = torchvision.models.resnet18(weights=weights)
        print("Loaded PRETRAINED ResNet-18 weights.")
    else:
        model = torchvision.models.resnet18(weights=None)
        print("Initialized ResNet-18 weights FROM SCRATCH.")

    # Remove the final fully connected layer (classifier)
    model.fc = nn.Identity() # Replace fc layer with identity
    return model

# --- CNNOscSNN Class (Unchanged, relies on config.num_classes) ---
class CNNOscSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.oscillator = OscillatorTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.oscillator(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- CNNLorenzSNN Class (Unchanged, relies on config.num_classes) ---
class CNNLorenzSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lorenz = LorenzTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.lorenz(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BasicCSNN Class (Unchanged, relies on config.num_classes) ---
class BasicCSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        cnn_features = torch.tanh(self.proj(features))
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        total_output_mem = 0
        batch_total_spikes = 0.0
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cnn_features, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            total_output_mem += self.fc_out(spk2)
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = total_output_mem / self.num_steps
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BaseCNN Class (Unchanged, relies on config.num_classes) ---
class BaseCNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.fc1 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.fc2 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.classifier = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, **kwargs):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.classifier(x)
        return x

# --- 修改后的 Evaluate 函数 - 返回 total_spikes 而不是 average ---
def evaluate(model, loader, config):
    """Evaluates the model, returns accuracy and total spikes."""
    model.eval()
    correct = 0
    total = 0
    total_spikes_evaluated = 0.0
    is_snn = isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN))

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_snn:
                outputs, batch_spikes = model(images, return_spikes=True)
                total_spikes_evaluated += batch_spikes.item()
            else:
                outputs = model(images) # **kwargs in BaseCNN handles extra args
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total
    # 修改: 返回 total_spikes 而不是平均值
    return accuracy, total_spikes_evaluated


# --- 修改后的 Train and Evaluate with History ---
def train_and_evaluate_with_history(model, train_loader, test_loader, config):
    """Trains model with early stopping, returns best test acc, epoch history, and spikes at convergence."""
    model = model.to(device)
    # --- OPTIONAL: Differential Learning Rate ---
    # You might want different LRs for backbone and head
    # Example:
    head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone.')]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('backbone.')]
    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': config.lr * 0.1}, # Lower LR for backbone
        {'params': head_params, 'lr': config.lr} # Normal LR for head
    ], weight_decay=config.weight_decay)
    # --- Using single LR for simplicity now ---
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()

    best_test_acc_epoch = 0.0
    epochs_no_improve = 0
    patience = config.patience
    # Ensure output directory exists before saving temp model
    output_dir = f"experiment_results_{config.dataset_name}_{config.backbone_name}"
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists
    best_model_path = os.path.join(output_dir, f"temp_best_model_{time.time()}_{id(model)}.pth") # Save inside output dir
    
    # 修改: 添加 spike 记录
    history = []
    spike_counts = []  # 记录每个epoch的spikes
    best_epoch = 0  # 记录达到最佳性能的epoch

    print(f"--- Starting Training (Max Epochs: {config.epochs}, Patience: {patience}) ---")

    for epoch in range(config.epochs):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        start_epoch_time = time.time()

        for i, (images, labels) in enumerate(train_loader): # Add index i for progress printing
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images, return_spikes=False) # No need for spikes during training loop
            loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                print(f"WARNING: Loss is {loss.item()} at epoch {epoch+1}, batch {i}. Skipping backward pass.")
                continue # Skip batch if loss is invalid

            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            total_loss += loss.item()

            # Optional: Print progress within epoch
            # if (i + 1) % 100 == 0:
            #     print(f'  Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')


        end_epoch_time = time.time()
        epoch_duration = end_epoch_time - start_epoch_time

        # 修改: 在每个epoch结束后评估并记录spikes
        test_acc, epoch_spikes = evaluate(model, test_loader, config)
        history.append(test_acc)
        spike_counts.append(epoch_spikes)  # 记录总spike数
        
        avg_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = 100. * train_correct / train_total if train_total > 0 else 0

        # 修改: 打印spike信息
        print(f"Epoch [{epoch + 1}/{config.epochs}] Loss: {avg_loss:.4f} "
              f"Train Acc: {train_acc:.2f}% Test Acc: {test_acc:.2f}% "
              f"Total Spikes: {epoch_spikes:.0f} "
              f"Epoch Time: {epoch_duration:.2f}s LR: {scheduler.get_last_lr()[0]:.1e}")

        # Early Stopping Logic
        if test_acc > best_test_acc_epoch:
            best_test_acc_epoch = test_acc
            best_epoch = epoch  # 记录最佳epoch索引
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), best_model_path)
                print(f"  -> New best test accuracy: {best_test_acc_epoch:.2f}%. Model saved.")
            except Exception as e:
                print(f"  -> Error saving model: {e}")
        else:
            epochs_no_improve += 1
            print(f"  -> Test accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_test_acc_epoch:.2f}%")

        if epochs_no_improve >= patience:
            print(f"\n--- Early stopping triggered after {epoch + 1} epochs. ---")
            break

        scheduler.step()

    # Post-Training: Load Best Model and Final Evaluation
    print("\n--- Training finished. Loading best model for final evaluation. ---")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path))
            print(f"Successfully loaded best model state (Accuracy: {best_test_acc_epoch:.2f}%)")
            os.remove(best_model_path)
            # print(f"Removed temporary model file: {best_model_path}") # Optional: uncomment to confirm
        except Exception as e:
            print(f"Error loading best model state from {best_model_path}: {e}. Using model from last epoch.")
            best_test_acc_epoch = history[-1] if history else 0.0
    else:
        print("No best model was saved (or file missing). Using model from last epoch.")
        best_test_acc_epoch = history[-1] if history else 0.0

    print("--- Running final evaluation on the best performing model state ---")
    final_test_acc, final_total_spikes = evaluate(model, test_loader, config)

    # 修改: 获取收敛时的spike数
    spikes_at_convergence = spike_counts[best_epoch] if spike_counts and best_epoch < len(spike_counts) else 0

    print(f"Final Evaluation - Test Accuracy: {final_test_acc:.2f}%")
    if isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN)):
        print(f"Final Evaluation - Total Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}")
        print(f"Final Evaluation - Total Spikes in Final Evaluation: {final_total_spikes:.0f}")

    # 修改: 返回收敛时的spikes (不是平均值)
    return best_test_acc_epoch, history, spikes_at_convergence, spike_counts

# --- Tiny ImageNet Loading Function (未更改) ---
def load_tiny_imagenet(config):
    """Loads the Tiny ImageNet dataset."""
    data_dir = config.data_root
    num_workers = min(4, os.cpu_count()) if os.cpu_count() else 0
    image_size = config.input_size # Should be 224 for pre-trained ResNet

    # Standard ImageNet normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Data augmentation and normalization for training
    # Adjust augmentation based on standard practices for ImageNet fine-tuning
    train_transform = transforms.Compose([
        transforms.Resize(image_size), # Resize first
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)), # Standard crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Just normalization for validation/testing
    val_transform = transforms.Compose([
        transforms.Resize(image_size), # Ensure validation images are also resized
        transforms.CenterCrop(image_size), # Use CenterCrop for validation
        transforms.ToTensor(),
        normalize,
    ])

    # --- Dataset Loading ---
    # Tiny ImageNet structure: train/[wnid]/images/*.JPEG, val/images/*.JPEG, val/val_annotations.txt
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val', 'images') # Validation images are flat

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
         raise FileNotFoundError(f"Tiny ImageNet data not found at expected paths: {train_dir} and {val_dir}. "
                                f"Please ensure Tiny ImageNet is downloaded and extracted to '{data_dir}' "
                                "following the standard directory structure.")

    train_dataset = ImageFolder(train_dir, transform=train_transform)

    # Validation dataset requires special handling due to annotations file
    # Creating a custom Dataset class is cleaner
    class TinyImageNetVal(Dataset):
        def __init__(self, val_dir, annotations_file, class_to_idx, transform=None):
            self.val_dir = val_dir
            self.transform = transform
            self.class_to_idx = class_to_idx
            self.samples = []
            try:
                with open(annotations_file, 'r') as f:
                    for line in f:
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            img_name, wnid = parts[0], parts[1]
                            img_path = os.path.join(self.val_dir, img_name)
                            if os.path.exists(img_path) and wnid in self.class_to_idx:
                                self.samples.append((img_path, self.class_to_idx[wnid]))
                            # else:
                                # print(f"Warning: Skipping invalid validation entry: {line.strip()}")
            except FileNotFoundError:
                 raise FileNotFoundError(f"Validation annotations file not found: {annotations_file}")


        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            img_path, target = self.samples[idx]
            # Ensure images are loaded in RGB format
            try:
                with open(img_path, 'rb') as f:
                    img = Image.open(f).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                # Return a dummy image/target or handle appropriately
                return torch.zeros(3, image_size, image_size), -1 # Indicate error

            if self.transform:
                img = self.transform(img)
            return img, target

    # Need class_to_idx mapping from the training set folders
    class_to_idx = train_dataset.class_to_idx
    val_annotations_file = os.path.join(data_dir, 'val', 'val_annotations.txt')
    val_dataset = TinyImageNetVal(val_dir, val_annotations_file, class_to_idx, transform=val_transform)


    print(f"Tiny ImageNet - Found {len(train_dataset)} training images belonging to {len(train_dataset.classes)} classes.")
    print(f"Tiny ImageNet - Found {len(val_dataset)} validation images.")


    # Data Loaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    # Use validation set as the test set for Tiny ImageNet evaluation
    test_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader


# --- 修改 Experiment Logger Class ---
class ExperimentLogger:
    def __init__(self, output_dir="experiment_results_tinyimagenet_resnet18"): # Changed dir name
        self.results = {}
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # 修改: 更新CSV标题，将Average Spikes改为Spikes at Convergence
        self.csv_headers = [
            "Config Name", # Renamed from CNN Config
            "Model",
            "Delta/System",
            "Best Test Accuracy (%)",
            "Spikes at Convergence", # 修改: 更改列名
            "Training Time (s)",
            "Epochs Trained",
            "Convergence Epoch" # 添加收敛epoch的列
        ]

    def log_config(self, config_name, config):
        config_dict = {}
        config_dict['patience'] = config.patience
        for key, value in vars(config).items():
             if isinstance(value, (int, float, str, bool, list, tuple, dict, type(None))):
                 config_dict[key] = value
             # ... (rest of serialization logic from previous version)
        self.results[config_name] = {
            "config": config_dict,
            "models": {}
        }

    # 修改: 更新log_model_result参数和逻辑
    def log_model_result(self, config_name, model_name, accuracy, training_time, epochs_history, spikes_at_convergence, spike_counts, best_epoch):
        if config_name not in self.results:
            self.results[config_name] = {"config": {}, "models": {}}
            print(f"Warning: Config '{config_name}' not pre-logged. Creating entry.")
        epochs_trained = len(epochs_history)
        self.results[config_name]["models"][model_name] = {
            "accuracy": accuracy,
            "training_time": training_time,
            "epochs_history": epochs_history, # Storing history can make JSON large
            "spikes_at_convergence": spikes_at_convergence, # 修改: 改为收敛时的spikes
            "epochs_trained": epochs_trained,
            "spike_counts": spike_counts, # 添加: 存储每个epoch的spike数据
            "convergence_epoch": best_epoch + 1 # 添加: 转换为1-indexed的epoch号
        }
        print(f"Logged for {config_name} - {model_name}: Best Acc={accuracy:.2f}, Spikes at Convergence={spikes_at_convergence:.0f}, "
              f"Time={training_time:.2f}s, Epochs={epochs_trained}, Convergence Epoch={best_epoch+1}")

    def save_results(self):
        filepath = os.path.join(self.output_dir, "results.json")
        try:
            # Save only essential results to JSON to avoid large files due to history
            results_to_save = {}
            for cfg_name, cfg_data in self.results.items():
                results_to_save[cfg_name] = {"config": cfg_data["config"], "models": {}}
                for mdl_name, mdl_data in cfg_data["models"].items():
                    results_to_save[cfg_name]["models"][mdl_name] = {
                        k: v for k, v in mdl_data.items() if k not in ['epochs_history', 'spike_counts']
                    }
            with open(filepath, "w") as f:
                json.dump(results_to_save, f, indent=4, default=lambda o: '<not serializable>')
            print(f"Results saved to {filepath}")
            
            # 添加: 保存spike数据到CSV文件
            for cfg_name, cfg_data in self.results.items():
                for mdl_name, mdl_data in cfg_data["models"].items():
                    if 'spike_counts' in mdl_data and mdl_data['spike_counts']:
                        spike_df = pd.DataFrame({
                            'Epoch': range(1, len(mdl_data['spike_counts'])+1),
                            'Total Spikes': mdl_data['spike_counts']
                        })
                        spike_file = os.path.join(self.output_dir, f"{cfg_name}_{mdl_name}_spikes.csv")
                        spike_df.to_csv(spike_file, index=False)
                        print(f"Spike data saved to {spike_file}")
        except Exception as e:
            print(f"Error saving JSON results: {e}")


    def generate_summary_table(self):
        rows = []
        for config_name, data in self.results.items():
            for model_name, model_data in data["models"].items():
                delta_str = "N/A"
                # Simplified model type identification for summary
                if "Osc-SNN Delta=" in model_name:
                    delta_str = f"Osc(Δ={model_name.split('=')[-1].split(')')[0]})"
                elif "Lorenz-SNN" in model_name:
                    delta_str = "Lorenz"
                elif "CNN-SNN" in model_name and "Osc" not in model_name and "Lorenz" not in model_name:
                     delta_str = "Direct SNN"
                elif "CNN-ANN" in model_name:
                     delta_str = "ANN"

                spikes = model_data.get("spikes_at_convergence", 0.0)  # 修改: 使用收敛时的spikes
                epochs_trained = model_data.get("epochs_trained", "N/A")
                convergence_epoch = model_data.get("convergence_epoch", "N/A")  # 添加: 获取收敛epoch

                row = {
                    self.csv_headers[0]: config_name, # Use "Config Name"
                    self.csv_headers[1]: model_name,
                    self.csv_headers[2]: delta_str,
                    self.csv_headers[3]: model_data["accuracy"],
                    self.csv_headers[4]: spikes,  # 修改: 使用收敛时的spikes
                    self.csv_headers[5]: model_data["training_time"],
                    self.csv_headers[6]: epochs_trained,
                    self.csv_headers[7]: convergence_epoch  # 添加: 收敛epoch
                }
                rows.append(row)
        if not rows: return pd.DataFrame(columns=self.csv_headers)
        df = pd.DataFrame(rows)
        df = df[self.csv_headers] # Ensure column order
        try:
            df[self.csv_headers[3]] = pd.to_numeric(df[self.csv_headers[3]], errors='coerce').map('{:.2f}'.format)
            df[self.csv_headers[4]] = pd.to_numeric(df[self.csv_headers[4]], errors='coerce').map('{:.0f}'.format)  # 修改: 整数格式
            df[self.csv_headers[5]] = pd.to_numeric(df[self.csv_headers[5]], errors='coerce').map('{:.2f}'.format)
        except Exception as e: print(f"Error formatting summary table columns: {e}")
        filepath = os.path.join(self.output_dir, "summary.csv")
        try: df.to_csv(filepath, index=False); print(f"Summary table saved to {filepath}")
        except Exception as e: print(f"Error saving summary CSV: {e}")
        return df

    # --- Plotting (Requires Matplotlib) ---
    # Consider simplifying or removing plotting if matplotlib is not available/needed now
    # def plot_results(self):
    #     # ... (Plotting code from previous version - needs matplotlib)
    #     # If keeping plots, update logic to handle single config name and model types
    #     pass

# --- 修改 Main Experiment Function ---
def run_experiment():
    print(f"使用设备: {device}")
    config = Config() # Use the updated config
    config_name = f"{config.dataset_name}_{config.backbone_name}" # Single config name
    logger = ExperimentLogger(output_dir=f"experiment_results_{config_name}")

    # Define the specific oscillator deltas for the two modes
    # Use values inspired by Table D1, e.g., one expansive, one dissipative
    osc_delta_mode_b = -1.5 # High performance potential (Expansive)
    osc_delta_mode_a = 10.0 # High efficiency (Dissipative)

    print(f"\n{'=' * 60}")
    print(f"开始实验配置: {config_name}")
    print(f"Dataset: {config.dataset_name} (Root: {config.data_root}, Classes: {config.num_classes}, Input Size: {config.input_size})")
    print(f"Backbone: {config.backbone_name}")
    print(f"Max Epochs: {config.epochs}, Patience: {config.patience}, LR: {config.lr}")
    print(f"Osc Params: alpha={config.osc_alpha}, beta={config.osc_beta}, gamma={config.osc_gamma}, omega={config.osc_omega}, dt={config.osc_dt}")
    print(f"Lorenz Params: sigma={config.lorenz_sigma:.2f}, rho={config.lorenz_rho:.2f}, beta={config.lorenz_beta:.2f}, dt={config.lorenz_dt}")
    print(f"SNN Params: steps={config.num_steps}, decay_beta={config.beta:.2f}, chaos_dim={config.chaos_dim}")
    print(f"{'=' * 60}")

    logger.log_config(config_name, config)

    try:
        # Load Tiny ImageNet data
        train_loader, test_loader = load_tiny_imagenet(config)
    except Exception as e:
        print(f"无法加载数据集 {config_name}. 终止实验. 错误: {e}")
        import traceback
        traceback.print_exc()
        return None # Exit if data loading fails

    # --- Define Models to Run ---
    # Ensure pretrained_backbone=True is passed correctly
    models_to_run = {
        "Baseline (CNN-ANN)": BaseCNN(config, pretrained_backbone=True),
        "Baseline (CNN-SNN)": BasicCSNN(config, pretrained_backbone=True),
        "Proposed (CNN-Lorenz-SNN)": CNNLorenzSNN(config, pretrained_backbone=True)
    }
    # Add Oscillator SNNs with specific deltas
    osc_config_b = copy.deepcopy(config); osc_config_b.osc_delta = osc_delta_mode_b
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_b})"] = CNNOscSNN(osc_config_b, pretrained_backbone=True)

    osc_config_a = copy.deepcopy(config); osc_config_a.osc_delta = osc_delta_mode_a
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_a})"] = CNNOscSNN(osc_config_a, pretrained_backbone=True)


    # --- Train and Evaluate Models ---
    for model_name, model_instance in models_to_run.items():
        print(f"\n--- Training {model_name} for {config_name} ---")
        start_time = time.time()
        # Use a fresh copy for each training run (already done by creating new instances above)
        current_model = model_instance
        current_config = config # Default config for non-oscillator models
        if "Osc-SNN Delta=" in model_name:
             delta_val = float(model_name.split('=')[-1].split(')')[0])
             if delta_val == osc_delta_mode_a: current_config = osc_config_a
             elif delta_val == osc_delta_mode_b: current_config = osc_config_b
             else: print(f"Warning: Could not match delta for {model_name}")


        try:
            # 修改: 添加对spike_counts和best_epoch的接收
            best_acc, epochs_history, spikes_at_convergence, spike_counts = train_and_evaluate_with_history(
                current_model, train_loader, test_loader, current_config # Pass correct config
            )
            end_time = time.time()
            training_time = end_time - start_time
            
            # 计算best_epoch (0-indexed)
            best_epoch = epochs_history.index(best_acc) if best_acc in epochs_history else 0
            
            print(f"--- {model_name} finished. Best Acc: {best_acc:.2f}%, "
                 f"Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}, "
                 f"Time: {training_time:.2f}s, Epochs Trained: {len(epochs_history)} ---")
                 
            # 修改: 更新logger.log_model_result调用
            logger.log_model_result(
                config_name, model_name, best_acc, training_time, 
                epochs_history, spikes_at_convergence, spike_counts, best_epoch
            )
        except Exception as e:
            end_time = time.time()
            training_time = end_time - start_time
            print(f"!!! ERROR during training/evaluation for {model_name} on {config_name}: {e}")
            # 修改: 更新错误情况下的logger.log_model_result调用
            logger.log_model_result(config_name, model_name, 0.0, training_time, [], 0.0, [], 0)
            import traceback
            traceback.print_exc()
            # Decide whether to continue with other models upon error
            # continue


    # --- Finalize and Save Results ---
    print("\n--- All models processed for this configuration ---")
    logger.save_results()
    summary_df = logger.generate_summary_table()
    # try:
    #     logger.plot_results() # Optional plotting
    # except Exception as e:
    #     print(f"An error occurred during final plotting: {e}")
    #     import traceback
    #     traceback.print_exc()

    print("\n" + "="*60)
    print(f"实验完成! 结果保存在 '{logger.output_dir}' 目录中")
    print("="*60 + "\n")
    if not summary_df.empty:
        print("结果汇总:")
        # Configure pandas for wider output
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 2000)
        pd.set_option('display.max_colwidth', None) # Show full column width
        print(summary_df.to_string(index=False)) # Use to_string for better control
        pd.reset_option('all') # Reset pandas display options
    else:
        print("结果汇总为空.")

    return logger

# --- Run Experiment ---
if __name__ == "__main__":
    # Optional: Set seeds for reproducibility
    # seed = 42
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(seed)
    # # Note: Full determinism can impact performance and might not be guaranteed on GPU
    # # torch.backends.cudnn.deterministic = True
    # # torch.backends.cudnn.benchmark = False

    logger = run_experiment()

使用设备: cuda

开始实验配置: TinyImageNet_ResNet-18_pretrained
Dataset: TinyImageNet (Root: ./tiny-imagenet-200, Classes: 200, Input Size: 224)
Backbone: ResNet-18_pretrained
Max Epochs: 200, Patience: 15, LR: 0.0001
Osc Params: alpha=2.0, beta=0.1, gamma=0.1, omega=1.0, dt=0.05
Lorenz Params: sigma=10.00, rho=28.00, beta=2.67, dt=0.05
SNN Params: steps=5, decay_beta=0.95, chaos_dim=128
Tiny ImageNet - Found 100000 training images belonging to 200 classes.
Tiny ImageNet - Found 10000 validation images.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.

--- Training Baseline (CNN-ANN) for TinyImageNet_ResNet-18_pretrained ---
--- Starting Training (Max Epochs: 200, Patience: 15) ---
Epoch [1/200] Loss: 3.9273 Train Acc: 15.43% Test Acc: 32.57% Total Spikes: 0 Epoch Time: 51.82s LR: 1.0e-05
  -> New best test accuracy: 32.57%. Model saved.
Epoch [2/200] Loss: 2.48

In [6]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder # Useful for Tiny ImageNet structure
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt # Keep if needed later
import numpy as np
import pandas as pd
import time
import copy
# from itertools import product # No longer needed for CNN configs
import os
import json
import snntorch as snn
import snntorch.surrogate as surrogate
from snntorch import utils
from snntorch import functional as SF
from PIL import Image # Needed for TinyImageNet loading potentially

# --- Configuration Class (Updated for Tiny ImageNet & ResNet) ---
class Config:
    # 数据集
    dataset_name = "TinyImageNet"
    data_root = './tiny-imagenet-200' # <<<--- IMPORTANT: SET PATH TO YOUR TINY IMAGENET FOLDER
    batch_size = 64 # May need to reduce based on GPU memory with ResNet
    input_size = 224 # Standard input size for ImageNet pre-trained models
    num_classes = 200 # Tiny ImageNet has 200 classes

    # CNN Backbone (Fixed to ResNet-18)
    backbone_name = "ResNet-18_pretrained"

    # Common SNN/Encoding Parameters
    # Projecting 512 ResNet features to 32 might be aggressive.
    # Consider increasing chaos_dim, e.g., 128 or 256, for better results.
    chaos_dim = 256 # Dimension for projection before oscillator/lorenz/SNN layers
    num_steps = 5 # Reduced num_steps initially for faster testing with ResNet

    # --- Oscillator Parameters ---
    osc_alpha = 2.0
    osc_beta = 0.1
    osc_gamma = 0.1
    osc_omega = 1.0
    osc_drive = 0.0
    # osc_delta will be set specifically
    osc_dt = 0.05

    # --- Lorenz Parameters ---
    lorenz_sigma = 10.0
    lorenz_rho = 28.0
    lorenz_beta = 8.0/3.0
    lorenz_dt = 0.05

    # SNN Decay Rate
    beta = 0.95 # SNN Leaky Neuron Beta (Decay Rate)

    # 训练
    epochs = 200 # Adjust max epochs for Tiny ImageNet & fine-tuning
    # Learning rate might need adjustment for fine-tuning ResNet
    lr = 1e-4 # Lower initial LR often better for fine-tuning
    weight_decay = 5e-4
    # Early Stopping
    patience = 15 # Maybe increase patience slightly

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spike_grad = surrogate.fast_sigmoid()

# --- OscillatorTransformFast Class (Unchanged) ---
class OscillatorTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.alpha = config.osc_alpha
        self.beta_osc = config.osc_beta
        self.gamma = config.osc_gamma
        self.delta = getattr(config, 'osc_delta', 0.0)
        self.omega = config.osc_omega
        self.dt = config.osc_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        current_delta = self.delta
        state = torch.cat([x, x*0.2, -x], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = y_cur
            dy = -self.alpha * x_cur - self.beta_osc * (x_cur**3) - current_delta * y_cur + self.gamma * z_cur
            dz = -self.omega * x_cur - current_delta * z_cur + self.gamma * x_cur * y_cur
            derivatives = torch.cat([dx, dy, dz], dim=1)
            state = state + self.dt * derivatives
            trajectories[:, t, :] = state
        return trajectories

# --- LorenzTransformFast Class (Unchanged) ---
class LorenzTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.sigma = config.lorenz_sigma
        self.rho = config.lorenz_rho
        self.lorenz_beta_param = config.lorenz_beta
        self.dt = config.lorenz_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        state = torch.cat([
            x,
            0.2*x,
            -x
        ], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = self.sigma * (y_cur - x_cur)
            dy = x_cur * (self.rho - z_cur) - y_cur
            dz = x_cur * y_cur - self.lorenz_beta_param * z_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            trajectories[:, t, :] = state
        return trajectories

# --- Helper function to get ResNet-18 backbone ---
def _get_resnet_backbone(pretrained=True):
    """Loads a pretrained ResNet-18 model and removes the final fc layer."""
    if pretrained:
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        model = torchvision.models.resnet18(weights=weights)
        print("Loaded PRETRAINED ResNet-18 weights.")
    else:
        model = torchvision.models.resnet18(weights=None)
        print("Initialized ResNet-18 weights FROM SCRATCH.")

    # Remove the final fully connected layer (classifier)
    model.fc = nn.Identity() # Replace fc layer with identity
    return model

# --- CNNOscSNN Class (Unchanged, relies on config.num_classes) ---
class CNNOscSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.oscillator = OscillatorTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.oscillator(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- CNNLorenzSNN Class (Unchanged, relies on config.num_classes) ---
class CNNLorenzSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lorenz = LorenzTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.lorenz(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BasicCSNN Class (Unchanged, relies on config.num_classes) ---
class BasicCSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        cnn_features = torch.tanh(self.proj(features))
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        total_output_mem = 0
        batch_total_spikes = 0.0
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cnn_features, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            total_output_mem += self.fc_out(spk2)
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = total_output_mem / self.num_steps
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BaseCNN Class (Unchanged, relies on config.num_classes) ---
class BaseCNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.fc1 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.fc2 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.classifier = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, **kwargs):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.classifier(x)
        return x

# --- 修改后的 Evaluate 函数 - 返回 total_spikes 而不是 average ---
def evaluate(model, loader, config):
    """Evaluates the model, returns accuracy and total spikes."""
    model.eval()
    correct = 0
    total = 0
    total_spikes_evaluated = 0.0
    is_snn = isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN))

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_snn:
                outputs, batch_spikes = model(images, return_spikes=True)
                total_spikes_evaluated += batch_spikes.item()
            else:
                outputs = model(images) # **kwargs in BaseCNN handles extra args
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total
    # 修改: 返回 total_spikes 而不是平均值
    return accuracy, total_spikes_evaluated


# --- 修改后的 Train and Evaluate with History ---
def train_and_evaluate_with_history(model, train_loader, test_loader, config):
    """Trains model with early stopping, returns best test acc, epoch history, and spikes at convergence."""
    model = model.to(device)
    # --- OPTIONAL: Differential Learning Rate ---
    # You might want different LRs for backbone and head
    # Example:
    head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone.')]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('backbone.')]
    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': config.lr * 0.1}, # Lower LR for backbone
        {'params': head_params, 'lr': config.lr} # Normal LR for head
    ], weight_decay=config.weight_decay)
    # --- Using single LR for simplicity now ---
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()

    best_test_acc_epoch = 0.0
    epochs_no_improve = 0
    patience = config.patience
    # Ensure output directory exists before saving temp model
    output_dir = f"experiment_results_{config.dataset_name}_{config.backbone_name}"
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists
    best_model_path = os.path.join(output_dir, f"temp_best_model_{time.time()}_{id(model)}.pth") # Save inside output dir
    
    # 修改: 添加 spike 记录
    history = []
    spike_counts = []  # 记录每个epoch的spikes
    best_epoch = 0  # 记录达到最佳性能的epoch

    print(f"--- Starting Training (Max Epochs: {config.epochs}, Patience: {patience}) ---")

    for epoch in range(config.epochs):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        start_epoch_time = time.time()

        for i, (images, labels) in enumerate(train_loader): # Add index i for progress printing
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images, return_spikes=False) # No need for spikes during training loop
            loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                print(f"WARNING: Loss is {loss.item()} at epoch {epoch+1}, batch {i}. Skipping backward pass.")
                continue # Skip batch if loss is invalid

            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            total_loss += loss.item()

            # Optional: Print progress within epoch
            # if (i + 1) % 100 == 0:
            #     print(f'  Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')


        end_epoch_time = time.time()
        epoch_duration = end_epoch_time - start_epoch_time

        # 修改: 在每个epoch结束后评估并记录spikes
        test_acc, epoch_spikes = evaluate(model, test_loader, config)
        history.append(test_acc)
        spike_counts.append(epoch_spikes)  # 记录总spike数
        
        avg_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = 100. * train_correct / train_total if train_total > 0 else 0

        # 修改: 打印spike信息
        print(f"Epoch [{epoch + 1}/{config.epochs}] Loss: {avg_loss:.4f} "
              f"Train Acc: {train_acc:.2f}% Test Acc: {test_acc:.2f}% "
              f"Total Spikes: {epoch_spikes:.0f} "
              f"Epoch Time: {epoch_duration:.2f}s LR: {scheduler.get_last_lr()[0]:.1e}")

        # Early Stopping Logic
        if test_acc > best_test_acc_epoch:
            best_test_acc_epoch = test_acc
            best_epoch = epoch  # 记录最佳epoch索引
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), best_model_path)
                print(f"  -> New best test accuracy: {best_test_acc_epoch:.2f}%. Model saved.")
            except Exception as e:
                print(f"  -> Error saving model: {e}")
        else:
            epochs_no_improve += 1
            print(f"  -> Test accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_test_acc_epoch:.2f}%")

        if epochs_no_improve >= patience:
            print(f"\n--- Early stopping triggered after {epoch + 1} epochs. ---")
            break

        scheduler.step()

    # Post-Training: Load Best Model and Final Evaluation
    print("\n--- Training finished. Loading best model for final evaluation. ---")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path))
            print(f"Successfully loaded best model state (Accuracy: {best_test_acc_epoch:.2f}%)")
            os.remove(best_model_path)
            # print(f"Removed temporary model file: {best_model_path}") # Optional: uncomment to confirm
        except Exception as e:
            print(f"Error loading best model state from {best_model_path}: {e}. Using model from last epoch.")
            best_test_acc_epoch = history[-1] if history else 0.0
    else:
        print("No best model was saved (or file missing). Using model from last epoch.")
        best_test_acc_epoch = history[-1] if history else 0.0

    print("--- Running final evaluation on the best performing model state ---")
    final_test_acc, final_total_spikes = evaluate(model, test_loader, config)

    # 修改: 获取收敛时的spike数
    spikes_at_convergence = spike_counts[best_epoch] if spike_counts and best_epoch < len(spike_counts) else 0

    print(f"Final Evaluation - Test Accuracy: {final_test_acc:.2f}%")
    if isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN)):
        print(f"Final Evaluation - Total Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}")
        print(f"Final Evaluation - Total Spikes in Final Evaluation: {final_total_spikes:.0f}")

    # 修改: 返回收敛时的spikes (不是平均值)
    return best_test_acc_epoch, history, spikes_at_convergence, spike_counts

# --- Tiny ImageNet Loading Function (未更改) ---
def load_tiny_imagenet(config):
    """Loads the Tiny ImageNet dataset."""
    data_dir = config.data_root
    num_workers = min(4, os.cpu_count()) if os.cpu_count() else 0
    image_size = config.input_size # Should be 224 for pre-trained ResNet

    # Standard ImageNet normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Data augmentation and normalization for training
    # Adjust augmentation based on standard practices for ImageNet fine-tuning
    train_transform = transforms.Compose([
        transforms.Resize(image_size), # Resize first
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)), # Standard crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Just normalization for validation/testing
    val_transform = transforms.Compose([
        transforms.Resize(image_size), # Ensure validation images are also resized
        transforms.CenterCrop(image_size), # Use CenterCrop for validation
        transforms.ToTensor(),
        normalize,
    ])

    # --- Dataset Loading ---
    # Tiny ImageNet structure: train/[wnid]/images/*.JPEG, val/images/*.JPEG, val/val_annotations.txt
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val', 'images') # Validation images are flat

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
         raise FileNotFoundError(f"Tiny ImageNet data not found at expected paths: {train_dir} and {val_dir}. "
                                f"Please ensure Tiny ImageNet is downloaded and extracted to '{data_dir}' "
                                "following the standard directory structure.")

    train_dataset = ImageFolder(train_dir, transform=train_transform)

    # Validation dataset requires special handling due to annotations file
    # Creating a custom Dataset class is cleaner
    class TinyImageNetVal(Dataset):
        def __init__(self, val_dir, annotations_file, class_to_idx, transform=None):
            self.val_dir = val_dir
            self.transform = transform
            self.class_to_idx = class_to_idx
            self.samples = []
            try:
                with open(annotations_file, 'r') as f:
                    for line in f:
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            img_name, wnid = parts[0], parts[1]
                            img_path = os.path.join(self.val_dir, img_name)
                            if os.path.exists(img_path) and wnid in self.class_to_idx:
                                self.samples.append((img_path, self.class_to_idx[wnid]))
                            # else:
                                # print(f"Warning: Skipping invalid validation entry: {line.strip()}")
            except FileNotFoundError:
                 raise FileNotFoundError(f"Validation annotations file not found: {annotations_file}")


        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            img_path, target = self.samples[idx]
            # Ensure images are loaded in RGB format
            try:
                with open(img_path, 'rb') as f:
                    img = Image.open(f).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                # Return a dummy image/target or handle appropriately
                return torch.zeros(3, image_size, image_size), -1 # Indicate error

            if self.transform:
                img = self.transform(img)
            return img, target

    # Need class_to_idx mapping from the training set folders
    class_to_idx = train_dataset.class_to_idx
    val_annotations_file = os.path.join(data_dir, 'val', 'val_annotations.txt')
    val_dataset = TinyImageNetVal(val_dir, val_annotations_file, class_to_idx, transform=val_transform)


    print(f"Tiny ImageNet - Found {len(train_dataset)} training images belonging to {len(train_dataset.classes)} classes.")
    print(f"Tiny ImageNet - Found {len(val_dataset)} validation images.")


    # Data Loaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    # Use validation set as the test set for Tiny ImageNet evaluation
    test_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader


# --- 修改 Experiment Logger Class ---
class ExperimentLogger:
    def __init__(self, output_dir="experiment_results_tinyimagenet_resnet18"): # Changed dir name
        self.results = {}
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # 修改: 更新CSV标题，将Average Spikes改为Spikes at Convergence
        self.csv_headers = [
            "Config Name", # Renamed from CNN Config
            "Model",
            "Delta/System",
            "Best Test Accuracy (%)",
            "Spikes at Convergence", # 修改: 更改列名
            "Training Time (s)",
            "Epochs Trained",
            "Convergence Epoch" # 添加收敛epoch的列
        ]

    def log_config(self, config_name, config):
        config_dict = {}
        config_dict['patience'] = config.patience
        for key, value in vars(config).items():
             if isinstance(value, (int, float, str, bool, list, tuple, dict, type(None))):
                 config_dict[key] = value
             # ... (rest of serialization logic from previous version)
        self.results[config_name] = {
            "config": config_dict,
            "models": {}
        }

    # 修改: 更新log_model_result参数和逻辑
    def log_model_result(self, config_name, model_name, accuracy, training_time, epochs_history, spikes_at_convergence, spike_counts, best_epoch):
        if config_name not in self.results:
            self.results[config_name] = {"config": {}, "models": {}}
            print(f"Warning: Config '{config_name}' not pre-logged. Creating entry.")
        epochs_trained = len(epochs_history)
        self.results[config_name]["models"][model_name] = {
            "accuracy": accuracy,
            "training_time": training_time,
            "epochs_history": epochs_history, # Storing history can make JSON large
            "spikes_at_convergence": spikes_at_convergence, # 修改: 改为收敛时的spikes
            "epochs_trained": epochs_trained,
            "spike_counts": spike_counts, # 添加: 存储每个epoch的spike数据
            "convergence_epoch": best_epoch + 1 # 添加: 转换为1-indexed的epoch号
        }
        print(f"Logged for {config_name} - {model_name}: Best Acc={accuracy:.2f}, Spikes at Convergence={spikes_at_convergence:.0f}, "
              f"Time={training_time:.2f}s, Epochs={epochs_trained}, Convergence Epoch={best_epoch+1}")

    def save_results(self):
        filepath = os.path.join(self.output_dir, "results.json")
        try:
            # Save only essential results to JSON to avoid large files due to history
            results_to_save = {}
            for cfg_name, cfg_data in self.results.items():
                results_to_save[cfg_name] = {"config": cfg_data["config"], "models": {}}
                for mdl_name, mdl_data in cfg_data["models"].items():
                    results_to_save[cfg_name]["models"][mdl_name] = {
                        k: v for k, v in mdl_data.items() if k not in ['epochs_history', 'spike_counts']
                    }
            with open(filepath, "w") as f:
                json.dump(results_to_save, f, indent=4, default=lambda o: '<not serializable>')
            print(f"Results saved to {filepath}")
            
            # 添加: 保存spike数据到CSV文件
            for cfg_name, cfg_data in self.results.items():
                for mdl_name, mdl_data in cfg_data["models"].items():
                    if 'spike_counts' in mdl_data and mdl_data['spike_counts']:
                        spike_df = pd.DataFrame({
                            'Epoch': range(1, len(mdl_data['spike_counts'])+1),
                            'Total Spikes': mdl_data['spike_counts']
                        })
                        spike_file = os.path.join(self.output_dir, f"{cfg_name}_{mdl_name}_spikes.csv")
                        spike_df.to_csv(spike_file, index=False)
                        print(f"Spike data saved to {spike_file}")
        except Exception as e:
            print(f"Error saving JSON results: {e}")


    def generate_summary_table(self):
        rows = []
        for config_name, data in self.results.items():
            for model_name, model_data in data["models"].items():
                delta_str = "N/A"
                # Simplified model type identification for summary
                if "Osc-SNN Delta=" in model_name:
                    delta_str = f"Osc(Δ={model_name.split('=')[-1].split(')')[0]})"
                elif "Lorenz-SNN" in model_name:
                    delta_str = "Lorenz"
                elif "CNN-SNN" in model_name and "Osc" not in model_name and "Lorenz" not in model_name:
                     delta_str = "Direct SNN"
                elif "CNN-ANN" in model_name:
                     delta_str = "ANN"

                spikes = model_data.get("spikes_at_convergence", 0.0)  # 修改: 使用收敛时的spikes
                epochs_trained = model_data.get("epochs_trained", "N/A")
                convergence_epoch = model_data.get("convergence_epoch", "N/A")  # 添加: 获取收敛epoch

                row = {
                    self.csv_headers[0]: config_name, # Use "Config Name"
                    self.csv_headers[1]: model_name,
                    self.csv_headers[2]: delta_str,
                    self.csv_headers[3]: model_data["accuracy"],
                    self.csv_headers[4]: spikes,  # 修改: 使用收敛时的spikes
                    self.csv_headers[5]: model_data["training_time"],
                    self.csv_headers[6]: epochs_trained,
                    self.csv_headers[7]: convergence_epoch  # 添加: 收敛epoch
                }
                rows.append(row)
        if not rows: return pd.DataFrame(columns=self.csv_headers)
        df = pd.DataFrame(rows)
        df = df[self.csv_headers] # Ensure column order
        try:
            df[self.csv_headers[3]] = pd.to_numeric(df[self.csv_headers[3]], errors='coerce').map('{:.2f}'.format)
            df[self.csv_headers[4]] = pd.to_numeric(df[self.csv_headers[4]], errors='coerce').map('{:.0f}'.format)  # 修改: 整数格式
            df[self.csv_headers[5]] = pd.to_numeric(df[self.csv_headers[5]], errors='coerce').map('{:.2f}'.format)
        except Exception as e: print(f"Error formatting summary table columns: {e}")
        filepath = os.path.join(self.output_dir, "summary.csv")
        try: df.to_csv(filepath, index=False); print(f"Summary table saved to {filepath}")
        except Exception as e: print(f"Error saving summary CSV: {e}")
        return df

    # --- Plotting (Requires Matplotlib) ---
    # Consider simplifying or removing plotting if matplotlib is not available/needed now
    # def plot_results(self):
    #     # ... (Plotting code from previous version - needs matplotlib)
    #     # If keeping plots, update logic to handle single config name and model types
    #     pass

# --- 修改 Main Experiment Function ---
def run_experiment():
    print(f"使用设备: {device}")
    config = Config() # Use the updated config
    config_name = f"{config.dataset_name}_{config.backbone_name}" # Single config name
    logger = ExperimentLogger(output_dir=f"experiment_results_{config_name}")

    # Define the specific oscillator deltas for the two modes
    # Use values inspired by Table D1, e.g., one expansive, one dissipative
    osc_delta_mode_b = -1.5 # High performance potential (Expansive)
    osc_delta_mode_a = 10.0 # High efficiency (Dissipative)

    print(f"\n{'=' * 60}")
    print(f"开始实验配置: {config_name}")
    print(f"Dataset: {config.dataset_name} (Root: {config.data_root}, Classes: {config.num_classes}, Input Size: {config.input_size})")
    print(f"Backbone: {config.backbone_name}")
    print(f"Max Epochs: {config.epochs}, Patience: {config.patience}, LR: {config.lr}")
    print(f"Osc Params: alpha={config.osc_alpha}, beta={config.osc_beta}, gamma={config.osc_gamma}, omega={config.osc_omega}, dt={config.osc_dt}")
    print(f"Lorenz Params: sigma={config.lorenz_sigma:.2f}, rho={config.lorenz_rho:.2f}, beta={config.lorenz_beta:.2f}, dt={config.lorenz_dt}")
    print(f"SNN Params: steps={config.num_steps}, decay_beta={config.beta:.2f}, chaos_dim={config.chaos_dim}")
    print(f"{'=' * 60}")

    logger.log_config(config_name, config)

    try:
        # Load Tiny ImageNet data
        train_loader, test_loader = load_tiny_imagenet(config)
    except Exception as e:
        print(f"无法加载数据集 {config_name}. 终止实验. 错误: {e}")
        import traceback
        traceback.print_exc()
        return None # Exit if data loading fails

    # --- Define Models to Run ---
    # Ensure pretrained_backbone=True is passed correctly
    models_to_run = {
        "Baseline (CNN-ANN)": BaseCNN(config, pretrained_backbone=True),
        "Baseline (CNN-SNN)": BasicCSNN(config, pretrained_backbone=True),
        "Proposed (CNN-Lorenz-SNN)": CNNLorenzSNN(config, pretrained_backbone=True)
    }
    # Add Oscillator SNNs with specific deltas
    osc_config_b = copy.deepcopy(config); osc_config_b.osc_delta = osc_delta_mode_b
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_b})"] = CNNOscSNN(osc_config_b, pretrained_backbone=True)

    osc_config_a = copy.deepcopy(config); osc_config_a.osc_delta = osc_delta_mode_a
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_a})"] = CNNOscSNN(osc_config_a, pretrained_backbone=True)


    # --- Train and Evaluate Models ---
    for model_name, model_instance in models_to_run.items():
        print(f"\n--- Training {model_name} for {config_name} ---")
        start_time = time.time()
        # Use a fresh copy for each training run (already done by creating new instances above)
        current_model = model_instance
        current_config = config # Default config for non-oscillator models
        if "Osc-SNN Delta=" in model_name:
             delta_val = float(model_name.split('=')[-1].split(')')[0])
             if delta_val == osc_delta_mode_a: current_config = osc_config_a
             elif delta_val == osc_delta_mode_b: current_config = osc_config_b
             else: print(f"Warning: Could not match delta for {model_name}")


        try:
            # 修改: 添加对spike_counts和best_epoch的接收
            best_acc, epochs_history, spikes_at_convergence, spike_counts = train_and_evaluate_with_history(
                current_model, train_loader, test_loader, current_config # Pass correct config
            )
            end_time = time.time()
            training_time = end_time - start_time
            
            # 计算best_epoch (0-indexed)
            best_epoch = epochs_history.index(best_acc) if best_acc in epochs_history else 0
            
            print(f"--- {model_name} finished. Best Acc: {best_acc:.2f}%, "
                 f"Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}, "
                 f"Time: {training_time:.2f}s, Epochs Trained: {len(epochs_history)} ---")
                 
            # 修改: 更新logger.log_model_result调用
            logger.log_model_result(
                config_name, model_name, best_acc, training_time, 
                epochs_history, spikes_at_convergence, spike_counts, best_epoch
            )
        except Exception as e:
            end_time = time.time()
            training_time = end_time - start_time
            print(f"!!! ERROR during training/evaluation for {model_name} on {config_name}: {e}")
            # 修改: 更新错误情况下的logger.log_model_result调用
            logger.log_model_result(config_name, model_name, 0.0, training_time, [], 0.0, [], 0)
            import traceback
            traceback.print_exc()
            # Decide whether to continue with other models upon error
            # continue


    # --- Finalize and Save Results ---
    print("\n--- All models processed for this configuration ---")
    logger.save_results()
    summary_df = logger.generate_summary_table()
    # try:
    #     logger.plot_results() # Optional plotting
    # except Exception as e:
    #     print(f"An error occurred during final plotting: {e}")
    #     import traceback
    #     traceback.print_exc()

    print("\n" + "="*60)
    print(f"实验完成! 结果保存在 '{logger.output_dir}' 目录中")
    print("="*60 + "\n")
    if not summary_df.empty:
        print("结果汇总:")
        # Configure pandas for wider output
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 2000)
        pd.set_option('display.max_colwidth', None) # Show full column width
        print(summary_df.to_string(index=False)) # Use to_string for better control
        pd.reset_option('all') # Reset pandas display options
    else:
        print("结果汇总为空.")

    return logger

# --- Run Experiment ---
if __name__ == "__main__":
    # Optional: Set seeds for reproducibility
    # seed = 42
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(seed)
    # # Note: Full determinism can impact performance and might not be guaranteed on GPU
    # # torch.backends.cudnn.deterministic = True
    # # torch.backends.cudnn.benchmark = False

    logger = run_experiment()

使用设备: cuda

开始实验配置: TinyImageNet_ResNet-18_pretrained
Dataset: TinyImageNet (Root: ./tiny-imagenet-200, Classes: 200, Input Size: 224)
Backbone: ResNet-18_pretrained
Max Epochs: 200, Patience: 15, LR: 0.0001
Osc Params: alpha=2.0, beta=0.1, gamma=0.1, omega=1.0, dt=0.05
Lorenz Params: sigma=10.00, rho=28.00, beta=2.67, dt=0.05
SNN Params: steps=5, decay_beta=0.95, chaos_dim=256
Tiny ImageNet - Found 100000 training images belonging to 200 classes.
Tiny ImageNet - Found 10000 validation images.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.

--- Training Baseline (CNN-ANN) for TinyImageNet_ResNet-18_pretrained ---
--- Starting Training (Max Epochs: 200, Patience: 15) ---
Epoch [1/200] Loss: 3.4558 Train Acc: 22.07% Test Acc: 41.06% Total Spikes: 0 Epoch Time: 52.26s LR: 1.0e-05
  -> New best test accuracy: 41.06%. Model saved.
Epoch [2/200] Loss: 2.11

In [7]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder # Useful for Tiny ImageNet structure
import torchvision.transforms as transforms
# import matplotlib.pyplot as plt # Keep if needed later
import numpy as np
import pandas as pd
import time
import copy
# from itertools import product # No longer needed for CNN configs
import os
import json
import snntorch as snn
import snntorch.surrogate as surrogate
from snntorch import utils
from snntorch import functional as SF
from PIL import Image # Needed for TinyImageNet loading potentially

# --- Configuration Class (Updated for Tiny ImageNet & ResNet) ---
class Config:
    # 数据集
    dataset_name = "TinyImageNet"
    data_root = './tiny-imagenet-200' # <<<--- IMPORTANT: SET PATH TO YOUR TINY IMAGENET FOLDER
    batch_size = 64 # May need to reduce based on GPU memory with ResNet
    input_size = 224 # Standard input size for ImageNet pre-trained models
    num_classes = 200 # Tiny ImageNet has 200 classes

    # CNN Backbone (Fixed to ResNet-18)
    backbone_name = "ResNet-18_pretrained"

    # Common SNN/Encoding Parameters
    # Projecting 512 ResNet features to 32 might be aggressive.
    # Consider increasing chaos_dim, e.g., 128 or 256, for better results.
    chaos_dim = 512 # Dimension for projection before oscillator/lorenz/SNN layers
    num_steps = 5 # Reduced num_steps initially for faster testing with ResNet

    # --- Oscillator Parameters ---
    osc_alpha = 2.0
    osc_beta = 0.1
    osc_gamma = 0.1
    osc_omega = 1.0
    osc_drive = 0.0
    # osc_delta will be set specifically
    osc_dt = 0.05

    # --- Lorenz Parameters ---
    lorenz_sigma = 10.0
    lorenz_rho = 28.0
    lorenz_beta = 8.0/3.0
    lorenz_dt = 0.05

    # SNN Decay Rate
    beta = 0.95 # SNN Leaky Neuron Beta (Decay Rate)

    # 训练
    epochs = 200 # Adjust max epochs for Tiny ImageNet & fine-tuning
    # Learning rate might need adjustment for fine-tuning ResNet
    lr = 1e-4 # Lower initial LR often better for fine-tuning
    weight_decay = 5e-4
    # Early Stopping
    patience = 15 # Maybe increase patience slightly

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spike_grad = surrogate.fast_sigmoid()

# --- OscillatorTransformFast Class (Unchanged) ---
class OscillatorTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.alpha = config.osc_alpha
        self.beta_osc = config.osc_beta
        self.gamma = config.osc_gamma
        self.delta = getattr(config, 'osc_delta', 0.0)
        self.omega = config.osc_omega
        self.dt = config.osc_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        current_delta = self.delta
        state = torch.cat([x, x*0.2, -x], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = y_cur
            dy = -self.alpha * x_cur - self.beta_osc * (x_cur**3) - current_delta * y_cur + self.gamma * z_cur
            dz = -self.omega * x_cur - current_delta * z_cur + self.gamma * x_cur * y_cur
            derivatives = torch.cat([dx, dy, dz], dim=1)
            state = state + self.dt * derivatives
            trajectories[:, t, :] = state
        return trajectories

# --- LorenzTransformFast Class (Unchanged) ---
class LorenzTransformFast(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_steps = config.num_steps
        self.sigma = config.lorenz_sigma
        self.rho = config.lorenz_rho
        self.lorenz_beta_param = config.lorenz_beta
        self.dt = config.lorenz_dt

    def forward(self, x):
        batch_size, dim = x.shape
        device = x.device
        trajectories = torch.zeros(batch_size, self.num_steps, dim * 3, device=device)
        state = torch.cat([
            x,
            0.2*x,
            -x
        ], dim=1)
        trajectories[:, 0, :] = state
        for t in range(1, self.num_steps):
            x_cur = state[:, :dim]
            y_cur = state[:, dim:2 * dim]
            z_cur = state[:, 2 * dim:]
            dx = self.sigma * (y_cur - x_cur)
            dy = x_cur * (self.rho - z_cur) - y_cur
            dz = x_cur * y_cur - self.lorenz_beta_param * z_cur
            state = state + self.dt * torch.cat([dx, dy, dz], dim=1)
            trajectories[:, t, :] = state
        return trajectories

# --- Helper function to get ResNet-18 backbone ---
def _get_resnet_backbone(pretrained=True):
    """Loads a pretrained ResNet-18 model and removes the final fc layer."""
    if pretrained:
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        model = torchvision.models.resnet18(weights=weights)
        print("Loaded PRETRAINED ResNet-18 weights.")
    else:
        model = torchvision.models.resnet18(weights=None)
        print("Initialized ResNet-18 weights FROM SCRATCH.")

    # Remove the final fully connected layer (classifier)
    model.fc = nn.Identity() # Replace fc layer with identity
    return model

# --- CNNOscSNN Class (Unchanged, relies on config.num_classes) ---
class CNNOscSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.oscillator = OscillatorTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.oscillator(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- CNNLorenzSNN Class (Unchanged, relies on config.num_classes) ---
class CNNLorenzSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lorenz = LorenzTransformFast(config)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim * 3, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        encoded = self.lorenz(x)
        outputs = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        batch_total_spikes = 0.0
        for step in range(self.num_steps):
            cur = encoded[:, step]
            spk1, mem1 = self.lif1(cur, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            outputs.append(self.fc_out(spk2))
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = torch.stack(outputs).sum(0)
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BasicCSNN Class (Unchanged, relies on config.num_classes) ---
class BasicCSNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.num_steps = config.num_steps
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.lif1 = snn.Leaky(beta=config.beta)
        self.lif2 = snn.Leaky(beta=config.beta)
        self.fc_out = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, return_spikes=False):
        features = self.backbone(x)
        cnn_features = torch.tanh(self.proj(features))
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        total_output_mem = 0
        batch_total_spikes = 0.0
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cnn_features, mem1)
            spk2, mem2 = self.lif2(spk1, mem2)
            total_output_mem += self.fc_out(spk2)
            if return_spikes:
                batch_total_spikes += spk1.sum().item() + spk2.sum().item()
        final_output = total_output_mem / self.num_steps
        if return_spikes:
            return final_output, torch.tensor(batch_total_spikes, device=final_output.device)
        else:
            return final_output

# --- BaseCNN Class (Unchanged, relies on config.num_classes) ---
class BaseCNN(nn.Module):
    def __init__(self, config, pretrained_backbone=True):
        super().__init__()
        self.backbone = _get_resnet_backbone(pretrained=pretrained_backbone)
        self.proj = nn.Linear(512, config.chaos_dim)
        self.fc1 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.fc2 = nn.Linear(config.chaos_dim, config.chaos_dim)
        self.classifier = nn.Linear(config.chaos_dim, config.num_classes) # Automatically uses updated num_classes

    def forward(self, x, **kwargs):
        features = self.backbone(x)
        x = torch.tanh(self.proj(features))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.classifier(x)
        return x

# --- 修改后的 Evaluate 函数 - 返回 total_spikes 而不是 average ---
def evaluate(model, loader, config):
    """Evaluates the model, returns accuracy and total spikes."""
    model.eval()
    correct = 0
    total = 0
    total_spikes_evaluated = 0.0
    is_snn = isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN))

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_snn:
                outputs, batch_spikes = model(images, return_spikes=True)
                total_spikes_evaluated += batch_spikes.item()
            else:
                outputs = model(images) # **kwargs in BaseCNN handles extra args
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100. * correct / total
    # 修改: 返回 total_spikes 而不是平均值
    return accuracy, total_spikes_evaluated


# --- 修改后的 Train and Evaluate with History ---
def train_and_evaluate_with_history(model, train_loader, test_loader, config):
    """Trains model with early stopping, returns best test acc, epoch history, and spikes at convergence."""
    model = model.to(device)
    # --- OPTIONAL: Differential Learning Rate ---
    # You might want different LRs for backbone and head
    # Example:
    head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone.')]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('backbone.')]
    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': config.lr * 0.1}, # Lower LR for backbone
        {'params': head_params, 'lr': config.lr} # Normal LR for head
    ], weight_decay=config.weight_decay)
    # --- Using single LR for simplicity now ---
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
    criterion = nn.CrossEntropyLoss()

    best_test_acc_epoch = 0.0
    epochs_no_improve = 0
    patience = config.patience
    # Ensure output directory exists before saving temp model
    output_dir = f"experiment_results_{config.dataset_name}_{config.backbone_name}"
    os.makedirs(output_dir, exist_ok=True) # Ensure directory exists
    best_model_path = os.path.join(output_dir, f"temp_best_model_{time.time()}_{id(model)}.pth") # Save inside output dir
    
    # 修改: 添加 spike 记录
    history = []
    spike_counts = []  # 记录每个epoch的spikes
    best_epoch = 0  # 记录达到最佳性能的epoch

    print(f"--- Starting Training (Max Epochs: {config.epochs}, Patience: {patience}) ---")

    for epoch in range(config.epochs):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        start_epoch_time = time.time()

        for i, (images, labels) in enumerate(train_loader): # Add index i for progress printing
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images, return_spikes=False) # No need for spikes during training loop
            loss = criterion(outputs, labels)

            if not torch.isfinite(loss):
                print(f"WARNING: Loss is {loss.item()} at epoch {epoch+1}, batch {i}. Skipping backward pass.")
                continue # Skip batch if loss is invalid

            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
            total_loss += loss.item()

            # Optional: Print progress within epoch
            # if (i + 1) % 100 == 0:
            #     print(f'  Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')


        end_epoch_time = time.time()
        epoch_duration = end_epoch_time - start_epoch_time

        # 修改: 在每个epoch结束后评估并记录spikes
        test_acc, epoch_spikes = evaluate(model, test_loader, config)
        history.append(test_acc)
        spike_counts.append(epoch_spikes)  # 记录总spike数
        
        avg_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = 100. * train_correct / train_total if train_total > 0 else 0

        # 修改: 打印spike信息
        print(f"Epoch [{epoch + 1}/{config.epochs}] Loss: {avg_loss:.4f} "
              f"Train Acc: {train_acc:.2f}% Test Acc: {test_acc:.2f}% "
              f"Total Spikes: {epoch_spikes:.0f} "
              f"Epoch Time: {epoch_duration:.2f}s LR: {scheduler.get_last_lr()[0]:.1e}")

        # Early Stopping Logic
        if test_acc > best_test_acc_epoch:
            best_test_acc_epoch = test_acc
            best_epoch = epoch  # 记录最佳epoch索引
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), best_model_path)
                print(f"  -> New best test accuracy: {best_test_acc_epoch:.2f}%. Model saved.")
            except Exception as e:
                print(f"  -> Error saving model: {e}")
        else:
            epochs_no_improve += 1
            print(f"  -> Test accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_test_acc_epoch:.2f}%")

        if epochs_no_improve >= patience:
            print(f"\n--- Early stopping triggered after {epoch + 1} epochs. ---")
            break

        scheduler.step()

    # Post-Training: Load Best Model and Final Evaluation
    print("\n--- Training finished. Loading best model for final evaluation. ---")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path))
            print(f"Successfully loaded best model state (Accuracy: {best_test_acc_epoch:.2f}%)")
            os.remove(best_model_path)
            # print(f"Removed temporary model file: {best_model_path}") # Optional: uncomment to confirm
        except Exception as e:
            print(f"Error loading best model state from {best_model_path}: {e}. Using model from last epoch.")
            best_test_acc_epoch = history[-1] if history else 0.0
    else:
        print("No best model was saved (or file missing). Using model from last epoch.")
        best_test_acc_epoch = history[-1] if history else 0.0

    print("--- Running final evaluation on the best performing model state ---")
    final_test_acc, final_total_spikes = evaluate(model, test_loader, config)

    # 修改: 获取收敛时的spike数
    spikes_at_convergence = spike_counts[best_epoch] if spike_counts and best_epoch < len(spike_counts) else 0

    print(f"Final Evaluation - Test Accuracy: {final_test_acc:.2f}%")
    if isinstance(model, (CNNOscSNN, BasicCSNN, CNNLorenzSNN)):
        print(f"Final Evaluation - Total Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}")
        print(f"Final Evaluation - Total Spikes in Final Evaluation: {final_total_spikes:.0f}")

    # 修改: 返回收敛时的spikes (不是平均值)
    return best_test_acc_epoch, history, spikes_at_convergence, spike_counts

# --- Tiny ImageNet Loading Function (未更改) ---
def load_tiny_imagenet(config):
    """Loads the Tiny ImageNet dataset."""
    data_dir = config.data_root
    num_workers = min(4, os.cpu_count()) if os.cpu_count() else 0
    image_size = config.input_size # Should be 224 for pre-trained ResNet

    # Standard ImageNet normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Data augmentation and normalization for training
    # Adjust augmentation based on standard practices for ImageNet fine-tuning
    train_transform = transforms.Compose([
        transforms.Resize(image_size), # Resize first
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)), # Standard crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Just normalization for validation/testing
    val_transform = transforms.Compose([
        transforms.Resize(image_size), # Ensure validation images are also resized
        transforms.CenterCrop(image_size), # Use CenterCrop for validation
        transforms.ToTensor(),
        normalize,
    ])

    # --- Dataset Loading ---
    # Tiny ImageNet structure: train/[wnid]/images/*.JPEG, val/images/*.JPEG, val/val_annotations.txt
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val', 'images') # Validation images are flat

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
         raise FileNotFoundError(f"Tiny ImageNet data not found at expected paths: {train_dir} and {val_dir}. "
                                f"Please ensure Tiny ImageNet is downloaded and extracted to '{data_dir}' "
                                "following the standard directory structure.")

    train_dataset = ImageFolder(train_dir, transform=train_transform)

    # Validation dataset requires special handling due to annotations file
    # Creating a custom Dataset class is cleaner
    class TinyImageNetVal(Dataset):
        def __init__(self, val_dir, annotations_file, class_to_idx, transform=None):
            self.val_dir = val_dir
            self.transform = transform
            self.class_to_idx = class_to_idx
            self.samples = []
            try:
                with open(annotations_file, 'r') as f:
                    for line in f:
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            img_name, wnid = parts[0], parts[1]
                            img_path = os.path.join(self.val_dir, img_name)
                            if os.path.exists(img_path) and wnid in self.class_to_idx:
                                self.samples.append((img_path, self.class_to_idx[wnid]))
                            # else:
                                # print(f"Warning: Skipping invalid validation entry: {line.strip()}")
            except FileNotFoundError:
                 raise FileNotFoundError(f"Validation annotations file not found: {annotations_file}")


        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            img_path, target = self.samples[idx]
            # Ensure images are loaded in RGB format
            try:
                with open(img_path, 'rb') as f:
                    img = Image.open(f).convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                # Return a dummy image/target or handle appropriately
                return torch.zeros(3, image_size, image_size), -1 # Indicate error

            if self.transform:
                img = self.transform(img)
            return img, target

    # Need class_to_idx mapping from the training set folders
    class_to_idx = train_dataset.class_to_idx
    val_annotations_file = os.path.join(data_dir, 'val', 'val_annotations.txt')
    val_dataset = TinyImageNetVal(val_dir, val_annotations_file, class_to_idx, transform=val_transform)


    print(f"Tiny ImageNet - Found {len(train_dataset)} training images belonging to {len(train_dataset.classes)} classes.")
    print(f"Tiny ImageNet - Found {len(val_dataset)} validation images.")


    # Data Loaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    # Use validation set as the test set for Tiny ImageNet evaluation
    test_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, test_loader


# --- 修改 Experiment Logger Class ---
class ExperimentLogger:
    def __init__(self, output_dir="experiment_results_tinyimagenet_resnet18"): # Changed dir name
        self.results = {}
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # 修改: 更新CSV标题，将Average Spikes改为Spikes at Convergence
        self.csv_headers = [
            "Config Name", # Renamed from CNN Config
            "Model",
            "Delta/System",
            "Best Test Accuracy (%)",
            "Spikes at Convergence", # 修改: 更改列名
            "Training Time (s)",
            "Epochs Trained",
            "Convergence Epoch" # 添加收敛epoch的列
        ]

    def log_config(self, config_name, config):
        config_dict = {}
        config_dict['patience'] = config.patience
        for key, value in vars(config).items():
             if isinstance(value, (int, float, str, bool, list, tuple, dict, type(None))):
                 config_dict[key] = value
             # ... (rest of serialization logic from previous version)
        self.results[config_name] = {
            "config": config_dict,
            "models": {}
        }

    # 修改: 更新log_model_result参数和逻辑
    def log_model_result(self, config_name, model_name, accuracy, training_time, epochs_history, spikes_at_convergence, spike_counts, best_epoch):
        if config_name not in self.results:
            self.results[config_name] = {"config": {}, "models": {}}
            print(f"Warning: Config '{config_name}' not pre-logged. Creating entry.")
        epochs_trained = len(epochs_history)
        self.results[config_name]["models"][model_name] = {
            "accuracy": accuracy,
            "training_time": training_time,
            "epochs_history": epochs_history, # Storing history can make JSON large
            "spikes_at_convergence": spikes_at_convergence, # 修改: 改为收敛时的spikes
            "epochs_trained": epochs_trained,
            "spike_counts": spike_counts, # 添加: 存储每个epoch的spike数据
            "convergence_epoch": best_epoch + 1 # 添加: 转换为1-indexed的epoch号
        }
        print(f"Logged for {config_name} - {model_name}: Best Acc={accuracy:.2f}, Spikes at Convergence={spikes_at_convergence:.0f}, "
              f"Time={training_time:.2f}s, Epochs={epochs_trained}, Convergence Epoch={best_epoch+1}")

    def save_results(self):
        filepath = os.path.join(self.output_dir, "results.json")
        try:
            # Save only essential results to JSON to avoid large files due to history
            results_to_save = {}
            for cfg_name, cfg_data in self.results.items():
                results_to_save[cfg_name] = {"config": cfg_data["config"], "models": {}}
                for mdl_name, mdl_data in cfg_data["models"].items():
                    results_to_save[cfg_name]["models"][mdl_name] = {
                        k: v for k, v in mdl_data.items() if k not in ['epochs_history', 'spike_counts']
                    }
            with open(filepath, "w") as f:
                json.dump(results_to_save, f, indent=4, default=lambda o: '<not serializable>')
            print(f"Results saved to {filepath}")
            
            # 添加: 保存spike数据到CSV文件
            for cfg_name, cfg_data in self.results.items():
                for mdl_name, mdl_data in cfg_data["models"].items():
                    if 'spike_counts' in mdl_data and mdl_data['spike_counts']:
                        spike_df = pd.DataFrame({
                            'Epoch': range(1, len(mdl_data['spike_counts'])+1),
                            'Total Spikes': mdl_data['spike_counts']
                        })
                        spike_file = os.path.join(self.output_dir, f"{cfg_name}_{mdl_name}_spikes.csv")
                        spike_df.to_csv(spike_file, index=False)
                        print(f"Spike data saved to {spike_file}")
        except Exception as e:
            print(f"Error saving JSON results: {e}")


    def generate_summary_table(self):
        rows = []
        for config_name, data in self.results.items():
            for model_name, model_data in data["models"].items():
                delta_str = "N/A"
                # Simplified model type identification for summary
                if "Osc-SNN Delta=" in model_name:
                    delta_str = f"Osc(Δ={model_name.split('=')[-1].split(')')[0]})"
                elif "Lorenz-SNN" in model_name:
                    delta_str = "Lorenz"
                elif "CNN-SNN" in model_name and "Osc" not in model_name and "Lorenz" not in model_name:
                     delta_str = "Direct SNN"
                elif "CNN-ANN" in model_name:
                     delta_str = "ANN"

                spikes = model_data.get("spikes_at_convergence", 0.0)  # 修改: 使用收敛时的spikes
                epochs_trained = model_data.get("epochs_trained", "N/A")
                convergence_epoch = model_data.get("convergence_epoch", "N/A")  # 添加: 获取收敛epoch

                row = {
                    self.csv_headers[0]: config_name, # Use "Config Name"
                    self.csv_headers[1]: model_name,
                    self.csv_headers[2]: delta_str,
                    self.csv_headers[3]: model_data["accuracy"],
                    self.csv_headers[4]: spikes,  # 修改: 使用收敛时的spikes
                    self.csv_headers[5]: model_data["training_time"],
                    self.csv_headers[6]: epochs_trained,
                    self.csv_headers[7]: convergence_epoch  # 添加: 收敛epoch
                }
                rows.append(row)
        if not rows: return pd.DataFrame(columns=self.csv_headers)
        df = pd.DataFrame(rows)
        df = df[self.csv_headers] # Ensure column order
        try:
            df[self.csv_headers[3]] = pd.to_numeric(df[self.csv_headers[3]], errors='coerce').map('{:.2f}'.format)
            df[self.csv_headers[4]] = pd.to_numeric(df[self.csv_headers[4]], errors='coerce').map('{:.0f}'.format)  # 修改: 整数格式
            df[self.csv_headers[5]] = pd.to_numeric(df[self.csv_headers[5]], errors='coerce').map('{:.2f}'.format)
        except Exception as e: print(f"Error formatting summary table columns: {e}")
        filepath = os.path.join(self.output_dir, "summary.csv")
        try: df.to_csv(filepath, index=False); print(f"Summary table saved to {filepath}")
        except Exception as e: print(f"Error saving summary CSV: {e}")
        return df

    # --- Plotting (Requires Matplotlib) ---
    # Consider simplifying or removing plotting if matplotlib is not available/needed now
    # def plot_results(self):
    #     # ... (Plotting code from previous version - needs matplotlib)
    #     # If keeping plots, update logic to handle single config name and model types
    #     pass

# --- 修改 Main Experiment Function ---
def run_experiment():
    print(f"使用设备: {device}")
    config = Config() # Use the updated config
    config_name = f"{config.dataset_name}_{config.backbone_name}" # Single config name
    logger = ExperimentLogger(output_dir=f"experiment_results_{config_name}")

    # Define the specific oscillator deltas for the two modes
    # Use values inspired by Table D1, e.g., one expansive, one dissipative
    osc_delta_mode_b = -1.5 # High performance potential (Expansive)
    osc_delta_mode_a = 10.0 # High efficiency (Dissipative)

    print(f"\n{'=' * 60}")
    print(f"开始实验配置: {config_name}")
    print(f"Dataset: {config.dataset_name} (Root: {config.data_root}, Classes: {config.num_classes}, Input Size: {config.input_size})")
    print(f"Backbone: {config.backbone_name}")
    print(f"Max Epochs: {config.epochs}, Patience: {config.patience}, LR: {config.lr}")
    print(f"Osc Params: alpha={config.osc_alpha}, beta={config.osc_beta}, gamma={config.osc_gamma}, omega={config.osc_omega}, dt={config.osc_dt}")
    print(f"Lorenz Params: sigma={config.lorenz_sigma:.2f}, rho={config.lorenz_rho:.2f}, beta={config.lorenz_beta:.2f}, dt={config.lorenz_dt}")
    print(f"SNN Params: steps={config.num_steps}, decay_beta={config.beta:.2f}, chaos_dim={config.chaos_dim}")
    print(f"{'=' * 60}")

    logger.log_config(config_name, config)

    try:
        # Load Tiny ImageNet data
        train_loader, test_loader = load_tiny_imagenet(config)
    except Exception as e:
        print(f"无法加载数据集 {config_name}. 终止实验. 错误: {e}")
        import traceback
        traceback.print_exc()
        return None # Exit if data loading fails

    # --- Define Models to Run ---
    # Ensure pretrained_backbone=True is passed correctly
    models_to_run = {
        "Baseline (CNN-ANN)": BaseCNN(config, pretrained_backbone=True),
        "Baseline (CNN-SNN)": BasicCSNN(config, pretrained_backbone=True),
        "Proposed (CNN-Lorenz-SNN)": CNNLorenzSNN(config, pretrained_backbone=True)
    }
    # Add Oscillator SNNs with specific deltas
    osc_config_b = copy.deepcopy(config); osc_config_b.osc_delta = osc_delta_mode_b
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_b})"] = CNNOscSNN(osc_config_b, pretrained_backbone=True)

    osc_config_a = copy.deepcopy(config); osc_config_a.osc_delta = osc_delta_mode_a
    models_to_run[f"Proposed (CNN-Osc-SNN Delta={osc_delta_mode_a})"] = CNNOscSNN(osc_config_a, pretrained_backbone=True)


    # --- Train and Evaluate Models ---
    for model_name, model_instance in models_to_run.items():
        print(f"\n--- Training {model_name} for {config_name} ---")
        start_time = time.time()
        # Use a fresh copy for each training run (already done by creating new instances above)
        current_model = model_instance
        current_config = config # Default config for non-oscillator models
        if "Osc-SNN Delta=" in model_name:
             delta_val = float(model_name.split('=')[-1].split(')')[0])
             if delta_val == osc_delta_mode_a: current_config = osc_config_a
             elif delta_val == osc_delta_mode_b: current_config = osc_config_b
             else: print(f"Warning: Could not match delta for {model_name}")


        try:
            # 修改: 添加对spike_counts和best_epoch的接收
            best_acc, epochs_history, spikes_at_convergence, spike_counts = train_and_evaluate_with_history(
                current_model, train_loader, test_loader, current_config # Pass correct config
            )
            end_time = time.time()
            training_time = end_time - start_time
            
            # 计算best_epoch (0-indexed)
            best_epoch = epochs_history.index(best_acc) if best_acc in epochs_history else 0
            
            print(f"--- {model_name} finished. Best Acc: {best_acc:.2f}%, "
                 f"Spikes at Convergence (Epoch {best_epoch+1}): {spikes_at_convergence:.0f}, "
                 f"Time: {training_time:.2f}s, Epochs Trained: {len(epochs_history)} ---")
                 
            # 修改: 更新logger.log_model_result调用
            logger.log_model_result(
                config_name, model_name, best_acc, training_time, 
                epochs_history, spikes_at_convergence, spike_counts, best_epoch
            )
        except Exception as e:
            end_time = time.time()
            training_time = end_time - start_time
            print(f"!!! ERROR during training/evaluation for {model_name} on {config_name}: {e}")
            # 修改: 更新错误情况下的logger.log_model_result调用
            logger.log_model_result(config_name, model_name, 0.0, training_time, [], 0.0, [], 0)
            import traceback
            traceback.print_exc()
            # Decide whether to continue with other models upon error
            # continue


    # --- Finalize and Save Results ---
    print("\n--- All models processed for this configuration ---")
    logger.save_results()
    summary_df = logger.generate_summary_table()
    # try:
    #     logger.plot_results() # Optional plotting
    # except Exception as e:
    #     print(f"An error occurred during final plotting: {e}")
    #     import traceback
    #     traceback.print_exc()

    print("\n" + "="*60)
    print(f"实验完成! 结果保存在 '{logger.output_dir}' 目录中")
    print("="*60 + "\n")
    if not summary_df.empty:
        print("结果汇总:")
        # Configure pandas for wider output
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 2000)
        pd.set_option('display.max_colwidth', None) # Show full column width
        print(summary_df.to_string(index=False)) # Use to_string for better control
        pd.reset_option('all') # Reset pandas display options
    else:
        print("结果汇总为空.")

    return logger

# --- Run Experiment ---
if __name__ == "__main__":
    # Optional: Set seeds for reproducibility
    # seed = 42
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(seed)
    # # Note: Full determinism can impact performance and might not be guaranteed on GPU
    # # torch.backends.cudnn.deterministic = True
    # # torch.backends.cudnn.benchmark = False

    logger = run_experiment()

使用设备: cuda

开始实验配置: TinyImageNet_ResNet-18_pretrained
Dataset: TinyImageNet (Root: ./tiny-imagenet-200, Classes: 200, Input Size: 224)
Backbone: ResNet-18_pretrained
Max Epochs: 200, Patience: 15, LR: 0.0001
Osc Params: alpha=2.0, beta=0.1, gamma=0.1, omega=1.0, dt=0.05
Lorenz Params: sigma=10.00, rho=28.00, beta=2.67, dt=0.05
SNN Params: steps=5, decay_beta=0.95, chaos_dim=512
Tiny ImageNet - Found 100000 training images belonging to 200 classes.
Tiny ImageNet - Found 10000 validation images.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.
Loaded PRETRAINED ResNet-18 weights.

--- Training Baseline (CNN-ANN) for TinyImageNet_ResNet-18_pretrained ---
--- Starting Training (Max Epochs: 200, Patience: 15) ---
Epoch [1/200] Loss: 2.9028 Train Acc: 32.86% Test Acc: 52.59% Total Spikes: 0 Epoch Time: 51.71s LR: 1.0e-05
  -> New best test accuracy: 52.59%. Model saved.
Epoch [2/200] Loss: 1.65