In [None]:
from scipy.io import loadmat
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import os
from sklearn.metrics.pairwise import cosine_similarity

from matplotlib.backends.backend_pdf import PdfPages


In [None]:
def filter_test_MUA_by_trails(test_MUA_oracle, image_trail_to_include):
    filtered_data = []
    labels = []
    
    for image_idx, trail_indices in image_trail_to_include.items():
        image_data = test_MUA_oracle[:, image_idx, :]  
        
        for trail_idx in trail_indices:
            if trail_idx < image_data.shape[0]:
                filtered_data.append(image_data[trail_idx, :])  
                labels.append(image_idx)  
                
    filtered_test_MUA = np.array(filtered_data)
    labels = np.array(labels)
    
    print(f"筛选后的数据形状: {filtered_test_MUA.shape}")
    print(f"标签数量: {len(labels)}")
    print(f"包含的图像数量: {len(image_trail_to_include)}")
    
    return filtered_test_MUA, labels

In [38]:
file_path = "/media/ubuntu/sda/Monkey/data/THINGS_normMUA_MonkeyF.mat"
with h5py.File(file_path, 'r') as file:
    print(list(file.keys()))
    oracle = file['oracle'][:]
    test_MUA = file['test_MUA_reps'][:]
    train_MUA = file['train_MUA'][:]
    reliab = file['reliab'][:]
    SNR = file['SNR_max'][:]
    lats = file['lats'][:]

lats = lats.mean(axis=0)
channel_inf = pd.DataFrame(oracle, columns=['oracle'])
channel_inf['region'] = 'V1'
channel_inf.loc[512:832, 'region'] = 'V4'
channel_inf.loc[832:, 'region'] = 'IT'

channel_inf['reliab'] = True
channel_inf.loc[channel_inf['oracle'] < 0.6, 'reliab'] = False
channel_inf['SNR'] = SNR
channel_inf['lats'] = lats

test_MUA_oracle = test_MUA[:, :, channel_inf[channel_inf['reliab'] == True].index]
train_MUA_oracle = train_MUA[:, channel_inf[channel_inf['reliab'] == True].index]

similarity_matrices = np.zeros((100, 30, 30))
similarity_list = []
for image_idx in range(100):
    image_data = test_MUA_oracle[:, image_idx, :]
    
    similarity_matrix = cosine_similarity(image_data)
    similarity_matrices[image_idx] = similarity_matrix
    similarity_list.append((np.sum(similarity_matrix) - 30) / (30 * 29))

similarity_list = np.array(similarity_list)

image_to_include = np.where(similarity_list >= 0.5)[0]
image_trail_to_include = {}
for image in image_to_include:
    temp = similarity_matrices[image].mean(axis=0)
    image_trail_to_include[image] = np.where(temp > 0.5)[0]

filtered_test_MUA, filtered_labels = filter_test_MUA_by_trails(test_MUA_oracle, image_trail_to_include)

['SNR', 'SNR_max', 'lats', 'oracle', 'reliab', 'tb', 'test_MUA', 'test_MUA_reps', 'train_MUA']
筛选后的数据形状: (2549, 503)
标签数量: 2549
包含的图像数量: 91


In [39]:
vmin = channel_inf['oracle'].min()
vmax = channel_inf['oracle'].max()

groups = channel_inf.groupby('region')

with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_oracle_monkeyF.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['oracle'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

vmin = channel_inf['SNR'].min()
vmax = channel_inf['SNR'].max()
with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_SNR_monkeyF.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['SNR'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

vmin = channel_inf['lats'].min()
vmax = channel_inf['lats'].max()
with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_lats_monkeyF.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['lats'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

In [19]:
np.save('/media/ubuntu/sda/Monkey/data/filtered_test_MUA_MonkeyF.npy', filtered_test_MUA)
np.save('/media/ubuntu/sda/Monkey/data/filtered_labels_MonkeyF.npy', filtered_labels)
np.save('/media/ubuntu/sda/Monkey/data/train_MUA_MonkeyF.npy', train_MUA_oracle)


In [33]:
file_path = "/media/ubuntu/sda/Monkey/data/THINGS_normMUA_MonkeyN.mat"
with h5py.File(file_path, 'r') as file:
    print(list(file.keys()))
    oracle = file['oracle'][:]
    test_MUA = file['test_MUA_reps'][:]
    train_MUA = file['train_MUA'][:]
    reliab = file['reliab'][:]
    SNR = file['SNR_max'][:]
    lats = file['lats'][:]

lats = lats.mean(axis=0)
channel_inf = pd.DataFrame(oracle, columns=['oracle'])
channel_inf['region'] = 'V1'
channel_inf.loc[512:768, 'region'] = 'V4'
channel_inf.loc[768:, 'region'] = 'IT'

channel_inf['reliab'] = True
channel_inf.loc[channel_inf['oracle'] < 0.6, 'reliab'] = False
channel_inf['SNR'] = SNR
channel_inf['lats'] = lats

test_MUA_oracle = test_MUA[:, :, channel_inf[channel_inf['reliab'] == True].index]
train_MUA_oracle = train_MUA[:, channel_inf[channel_inf['reliab'] == True].index]


similarity_matrices = np.zeros((100, 30, 30))
similarity_list = []
for image_idx in range(100):
    image_data = test_MUA_oracle[:, image_idx, :]
    
    similarity_matrix = cosine_similarity(image_data)
    similarity_matrices[image_idx] = similarity_matrix
    similarity_list.append((np.sum(similarity_matrix) - 30) / (30 * 29))

similarity_list = np.array(similarity_list)

image_to_include = np.where(similarity_list >= 0.5)[0]
image_trail_to_include = {}
for image in image_to_include:
    temp = similarity_matrices[image].mean(axis=0)
    image_trail_to_include[image] = np.where(temp > 0.5)[0]

filtered_test_MUA, filtered_labels = filter_test_MUA_by_trails(test_MUA_oracle, image_trail_to_include)


['SNR', 'SNR_max', 'lats', 'oracle', 'reliab', 'tb', 'test_MUA', 'test_MUA_reps', 'train_MUA']
筛选后的数据形状: (2834, 669)
标签数量: 2834
包含的图像数量: 98


In [49]:
n_trials, n_images, n_neurons = test_MUA_oracle.shape

max_pref_images = [] 
min_pref_images = []  

with PdfPages('/media/ubuntu/sda/Monkey/figure/neuron_scatter_plots_monkeyN.pdf') as pdf:
    for neuron_idx in range(n_neurons):
        neuron_data = test_MUA_oracle[:, :, neuron_idx]  
        image_means = np.mean(neuron_data, axis=0)  
        
        sorted_indices = np.argsort(image_means)
        
        min_pref_image = sorted_indices[0]  
        max_pref_image = sorted_indices[-1] 
        
        max_pref_images.append(max_pref_image)
        min_pref_images.append(min_pref_image)
        
        sorted_means = image_means[sorted_indices]
        
        all_values = []
        all_ranks = []
        
        for img_idx in range(n_images):
            img_values = neuron_data[:, img_idx]
            img_rank = np.where(sorted_indices == img_idx)[0][0]
            all_values.extend(img_values)
            all_ranks.extend([img_rank] * n_trials)
        
        fig, ax = plt.subplots(figsize=(6, 4))
        
        scatter = ax.scatter(all_ranks, all_values, alpha=0.6, s=25, color='gray')
        ax.set_xticks([])
        ax.set_xticklabels([])
        ax.legend('')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        
        plt.tight_layout()
        
        pdf.savefig(fig)
        plt.close(fig)  
        
        if (neuron_idx + 1) % 50 == 0:
            print(f"Processed {neuron_idx + 1}/{n_neurons} neurons...")

print(f"PDF created with {n_neurons} scatter plots")

Processed 50/503 neurons...
Processed 100/503 neurons...
Processed 150/503 neurons...
Processed 200/503 neurons...
Processed 250/503 neurons...
Processed 300/503 neurons...
Processed 350/503 neurons...
Processed 400/503 neurons...
Processed 450/503 neurons...
Processed 500/503 neurons...
PDF created with 503 scatter plots


In [52]:
min_pref_images[13]

np.int64(2)

In [53]:
max_pref_images[13]

np.int64(20)

In [37]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns

vmin = channel_inf['oracle'].min()
vmax = channel_inf['oracle'].max()

groups = channel_inf.groupby('region')

with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_oracle_monkeyN.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['oracle'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

vmin = channel_inf['SNR'].min()
vmax = channel_inf['SNR'].max()
with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_SNR_monkeyN.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['SNR'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

vmin = channel_inf['lats'].min()
vmax = channel_inf['lats'].max()
with PdfPages('/media/ubuntu/sda/Monkey/figure/region_heatmaps_lats_monkeyN.pdf') as pdf:
    for region_name, group_data in groups:
        n_blocks = len(group_data) // 64
        
        n_cols = min(4, n_blocks) 
        n_rows = (n_blocks + n_cols - 1) // n_cols 
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 2.5*n_rows + 0.5))
        fig.suptitle(f'Heatmaps for {region_name}', fontsize=16)
        
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        axes_flat = axes.flatten()
        
        for i in range(n_blocks):
            block_data = group_data.iloc[i*64:(i+1)*64]['lats'].values
            heatmap_data = block_data.reshape(8, 8)
            
            im = axes_flat[i].imshow(heatmap_data, cmap='coolwarm', vmin=vmin, vmax=vmax)
            axes_flat[i].set_xticks([])
            axes_flat[i].set_yticks([])
        
        for i in range(n_blocks, len(axes_flat)):
            axes_flat[i].set_visible(False)
        
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Oracle Value', fontsize=10)
        
        plt.subplots_adjust(right=0.9, top=0.9 if n_rows > 1 else 0.85)
        
        pdf.savefig(fig, bbox_inches='tight')
        plt.close()

In [23]:
np.save('/media/ubuntu/sda/Monkey/data/filtered_test_MUA_MonkeyN.npy', filtered_test_MUA)
np.save('/media/ubuntu/sda/Monkey/data/filtered_labels_MonkeyN.npy', filtered_labels)
np.save('/media/ubuntu/sda/Monkey/data/train_MUA_MonkeyN.npy', train_MUA_oracle)


In [24]:
# 构建分类网络的Dataset类
from random import shuffle
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np

class MUAClassificationDataset(Dataset):
    def __init__(self, mua_data, labels, transform=None):
        """
        MUA分类数据集
        
        Args:
            mua_data: MUA数据，形状为 (n_samples, n_channels)
            labels: 标签列表
            transform: 数据变换（可选）
        """
        self.mua_data = torch.tensor(mua_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform
        
    def __len__(self):
        return len(self.mua_data)
    
    def __getitem__(self, idx):
        sample = self.mua_data[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample, label

def create_label_mapping(original_labels, num_classes=91):
    """
    创建标签重新映射
    
    Args:
        original_labels: 原始标签数组
        num_classes: 目标类别数量
    
    Returns:
        mapped_labels: 重新映射后的标签
        label_mapping: 标签映射字典
    """
    unique_labels = np.unique(original_labels)
    print(f"原始标签数量: {len(unique_labels)}")
    
    # 创建映射字典
    label_mapping = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
    
    # 重新映射标签
    mapped_labels = np.array([label_mapping[label] for label in original_labels])
    
    print(f"映射后标签数量: {len(np.unique(mapped_labels))}")
    print(f"标签映射范围: {mapped_labels.min()} - {mapped_labels.max()}")
    
    return mapped_labels, label_mapping

# 创建数据集
print("=== 创建分类数据集 ===")
mapped_labels, label_mapping = create_label_mapping(filtered_labels, num_classes=91)

# 分割训练和测试集
train_indices, test_indices = train_test_split(
    range(len(filtered_test_MUA)), 
    test_size=0.2, 
    random_state=42, 
    stratify=mapped_labels
)

train_dataset = MUAClassificationDataset(
    filtered_test_MUA[train_indices], 
    mapped_labels[train_indices]
)

test_dataset = MUAClassificationDataset(
    filtered_test_MUA[test_indices], 
    mapped_labels[test_indices]
)



print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
print(f"特征维度: {filtered_test_MUA.shape[1]}")
print(f"类别数量: {len(np.unique(mapped_labels))}")


=== 创建分类数据集 ===
原始标签数量: 98
映射后标签数量: 98
标签映射范围: 0 - 97
训练集大小: 2267
测试集大小: 567
特征维度: 669
类别数量: 98


In [25]:
# 构建4层ResNet MLP模型
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """ResNet风格的残差块"""
    def __init__(self, input_dim, output_dim, dropout_rate=0.1):
        super(ResidualBlock, self).__init__()
        
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.bn1 = nn.BatchNorm1d(output_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        
        self.linear2 = nn.Linear(output_dim, output_dim)
        self.bn2 = nn.BatchNorm1d(output_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        
        # 残差连接
        self.shortcut = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        
    def forward(self, x):
        residual = self.shortcut(x)
        
        out = F.relu(self.bn1(self.linear1(x)))
        out = self.dropout1(out)
        out = self.bn2(self.linear2(out))
        out = self.dropout2(out)
        
        out += residual
        out = F.relu(out)
        
        return out

class ResNetMLPClassifier(nn.Module):
    """4层ResNet MLP分类器"""
    def __init__(self, input_dim, hidden_dims=[512, 1024, 1024, 1024], num_classes=91, dropout_rate=0.1):
        super(ResNetMLPClassifier, self).__init__()
        
        self.input_dim = input_dim
        self.num_classes = num_classes
        
        # 输入投影层
        self.input_projection = nn.Linear(input_dim, hidden_dims[0])
        self.input_bn = nn.BatchNorm1d(hidden_dims[0])
        
        # 4个ResNet块
        self.resnet_blocks = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.resnet_blocks.append(
                ResidualBlock(hidden_dims[i], hidden_dims[i+1], dropout_rate)
            )
        
        # 最终特征层（1024维）
        self.feature_layer = nn.Linear(hidden_dims[-1], 1024)
        self.feature_bn = nn.BatchNorm1d(1024)
        self.feature_dropout = nn.Dropout(dropout_rate)
        
        # 分类器层
        self.classifier = nn.Linear(1024, num_classes)
        
        # 初始化权重
        self._initialize_weights()
    
    def _initialize_weights(self):
        """初始化网络权重"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # 输入投影
        x = F.relu(self.input_bn(self.input_projection(x)))
        
        # 通过ResNet块
        for block in self.resnet_blocks:
            x = block(x)
        
        # 特征提取（1024维）
        features = F.relu(self.feature_bn(self.feature_layer(x)))
        features = self.feature_dropout(features)
        
        # 分类
        logits = self.classifier(features)
        
        return logits, features
    
    def get_features(self, x):
        """只返回特征，不进行分类"""
        with torch.no_grad():
            x = F.relu(self.input_bn(self.input_projection(x)))
            
            for block in self.resnet_blocks:
                x = block(x)
            
            features = F.relu(self.feature_bn(self.feature_layer(x)))
            return features

# 创建模型
input_dim = filtered_test_MUA.shape[1]  # 特征维度
num_classes = len(np.unique(mapped_labels))  # 实际类别数量




In [26]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # 前向传播
        logits, features = model(data)
        loss = criterion(logits, target)
        
        # 反向传播
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # 统计
        total_loss += loss.item()
        _, predicted = torch.max(logits.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """评估模型"""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    all_features = []
    
    
    for idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        logits, features = model(data)
        loss = criterion(logits, target)
        
        # 统计
        total_loss += loss.item()
        _, predicted = torch.max(logits.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # 收集预测结果和特征
        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_features.extend(features.cpu().numpy())
        
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy, all_predictions, all_targets, all_features

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
model = ResNetMLPClassifier(
    input_dim=input_dim,
    hidden_dims=[512, 1024, 1024, 1024],
    num_classes=num_classes,
    dropout_rate=0.1
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)


train_loader = DataLoader(train_dataset, shuffle = True, batch_size= 128)
val_loader = DataLoader(test_dataset, shuffle = False, batch_size= 128)

num_epochs = 20
best_val_acc = 0.0
train_losses, train_accs = [], []
val_losses, val_accs = [], []


for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)
    
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, device)
    
    scheduler.step(val_loss)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'mua_test_classifier.pth')
        print(f"save best model acc: {val_acc:.2f}%")
    
    print(f"Train loss: {train_loss:.4f}, Train acc: {train_acc:.2f}%")
    print(f"Test loss: {val_loss:.4f}, Test acc: {val_acc:.2f}%")



Epoch 1/20
--------------------------------------------------
save best model acc: 96.12%
Train loss: 1.8695, Train acc: 69.87%
Test loss: 0.1454, Test acc: 96.12%

Epoch 2/20
--------------------------------------------------
Train loss: 0.0846, Train acc: 97.75%
Test loss: 0.2042, Test acc: 96.12%

Epoch 3/20
--------------------------------------------------
save best model acc: 98.06%
Train loss: 0.0956, Train acc: 97.88%
Test loss: 0.0904, Test acc: 98.06%

Epoch 4/20
--------------------------------------------------
Train loss: 0.0502, Train acc: 99.03%
Test loss: 0.2623, Test acc: 96.47%

Epoch 5/20
--------------------------------------------------
Train loss: 0.0287, Train acc: 99.21%
Test loss: 0.1457, Test acc: 96.65%

Epoch 6/20
--------------------------------------------------
save best model acc: 98.59%
Train loss: 0.0557, Train acc: 98.99%
Test loss: 0.0863, Test acc: 98.59%

Epoch 7/20
--------------------------------------------------
Train loss: 0.0511, Train acc: 