<a href="https://colab.research.google.com/github/chengpeip/colab_design_for_graduate/blob/main/levit_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install mne

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0


In [4]:
import gc
import os
import re
import mne
import numpy as np
import torch
from scipy.signal import stft
from skimage.transform import resize
from tqdm import tqdm

In [5]:
class EDFPretransformToDisk:
    def __init__(self, edf_folder, pte_folder, output_img_size=224, nperseg=256, target_sfreq=250, out_dir="/content/processed"):
        self.edf_folder = edf_folder
        self.pte_folder = pte_folder
        self.output_img_size = output_img_size
        self.nperseg = nperseg
        self.target_sfreq = target_sfreq
        self.out_dir = out_dir
        self.channel_names = None
        self.window_duration = 2  # 时间窗口长度（秒）
        self.dtype = np.float32
        self.edf_pattern = re.compile(r'^([hs])(\d+)\.edf$', re.IGNORECASE)
        self.pte_pattern = re.compile(r'^([hs])(\d+)_pte_window_2s\.npy$', re.IGNORECASE)
        if not os.path.exists(self.out_dir):
            os.makedirs(self.out_dir)

    def _match_files(self):
        edf_files = sorted(
            [f for f in os.listdir(self.edf_folder) if self.edf_pattern.match(f)],
            key=lambda x: int(self.edf_pattern.match(x).group(2))
        )
        valid_pairs = []
        for edf_file in edf_files:
            match = self.edf_pattern.match(edf_file)
            if not match:
                continue
            prefix = match.group(1).lower()
            number = match.group(2)
            pte_file = f"{prefix}{number}_pte_window_2s.npy"
            pte_path = os.path.join(self.pte_folder, pte_file)
            if os.path.exists(pte_path):
                valid_pairs.append((
                    os.path.join(self.edf_folder, edf_file),
                    pte_path,
                    prefix
                ))
            else:
                print(f"警告: 未找到 {edf_file} 对应的PTE窗口文件 {pte_file}")
        return valid_pairs

    def _process_edf(self, edf_path):
        try:
            raw = mne.io.read_raw_edf(edf_path, preload=True)
            if raw.info['sfreq'] != self.target_sfreq:
                raw = raw.resample(self.target_sfreq)
            return raw
        except Exception as e:
            print(f"EDF文件处理失败 {edf_path}: {str(e)}")
            return None

    def _segment_data(self, data):
        window_samples = int(self.window_duration * self.target_sfreq)
        n_samples = data.shape[1]
        n_windows = n_samples // window_samples
        truncated = data[:, :n_windows*window_samples]
        return truncated.reshape(data.shape[0], n_windows, window_samples).transpose(1, 0, 2)

    def _generate_spectrogram(self, data, sfreq):
        f, t, Zxx = stft(data.astype(self.dtype), fs=sfreq, nperseg=self.nperseg)
        Zxx = np.abs(Zxx, dtype=self.dtype)
        np.log10(Zxx + 1e-6, out=Zxx)
        np.multiply(Zxx, 10, out=Zxx)
        z_min = Zxx.min()
        z_max = Zxx.max()
        del f, t
        Zxx = (Zxx - z_min) / (z_max - z_min + 1e-6)
        img = resize(Zxx, (self.output_img_size, self.output_img_size),
                     mode='reflect', anti_aliasing=True)
        return img.astype(self.dtype)

    def process_all_files(self):
        file_pairs = self._match_files()
        sample_idx = 0
        labels = []
        for edf_path, pte_path, prefix in file_pairs:
            raw = self._process_edf(edf_path)
            if raw is None:
                continue
            if self.channel_names is None:
                self.channel_names = raw.ch_names
                print(f"检测到 {len(self.channel_names)} 个通道: {self.channel_names}")
            data = raw.get_data()  # [n_channels, n_samples]
            data_windows = self._segment_data(data)  # [n_windows, n_channels, window_samples]
            try:
                pte_windows = np.load(pte_path, mmap_mode='r')
                assert pte_windows.ndim == 3
                assert pte_windows.shape[1:] == (len(self.channel_names), len(self.channel_names))
                assert pte_windows.shape[0] == data_windows.shape[0]
            except Exception as e:
                print(f"PTE加载失败 {pte_path}: {str(e)}")
                continue
            label = 1 if prefix == 's' else 0
            for win_idx in range(data_windows.shape[0]):
                window_images = []
                for ch_idx in range(data_windows.shape[1]):
                    ch_data = data_windows[win_idx, ch_idx, :]
                    img = self._generate_spectrogram(ch_data, self.target_sfreq)
                    window_images.append(img)
                sample_img = np.stack(window_images)  # [19, output_img_size, output_img_size]
                sample_pte = pte_windows[win_idx]       # [19, 19]
                # 保存到磁盘
                np.save(os.path.join(self.out_dir, f"sample_{sample_idx}_img.npy"), sample_img)
                np.save(os.path.join(self.out_dir, f"sample_{sample_idx}_pte.npy"), sample_pte)
                labels.append(label)
                sample_idx += 1

            del raw, data, data_windows
            gc.collect()
        # 保存标签数组
        np.save(os.path.join(self.out_dir, "labels.npy"), np.array(labels))
        print(f"处理完成: 总样本数 {sample_idx}")





In [11]:
# 使用示例
edf_preprocessor = EDFPretransformToDisk(
    edf_folder="/content",
    pte_folder="/content",
    output_img_size=224,
    nperseg=256,
    target_sfreq=250,
    out_dir="/content/processed"
)

In [12]:
edf_preprocessor.process_all_files()

Extracting EDF parameters from /content/h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
检测到 19 个通道: ['Fp2', 'F8', 'T4', 'T6', 'O2', 'Fp1', 'F7', 'T3', 'T5', 'O1', 'F4', 'C4', 'P4', 'F3', 'C3', 'P3', 'Fz', 'Cz', 'Pz']
Extracting EDF parameters from /content/s01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 211249  =      0.000 ...   844.996 secs...
Extracting EDF parameters from /content/s02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 286249  =      0.000 ...  1144.996 secs...
Extracting EDF parameters from /content/h02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 227499  =      0.000 ...   909.996 secs...
Extracting EDF parameters from /content/h03.edf...
EDF file detected
Setting channel info structure...
Crea

ValueError: too many values to unpack (expected 3)

In [15]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import timm
from timm.data import resolve_model_data_config, create_transform
import torch.nn as nn
import torchvision.transforms as T
from tqdm import tqdm

class ProcessedEEGDataset(Dataset):
    def __init__(self, processed_dir, model_name="levit_128.fb_dist_in1k", projector=None):
        """
        Args:
            processed_dir (str): 存放预处理好样本的文件夹路径。该文件夹下应有:
                - 若干个 EEG 时频图像文件，命名格式：sample_{idx}_img.npy，
                  形状为 [19, H, W]（原始19通道时频图像）。
                - 若干个 PTE 矩阵文件，命名格式：sample_{idx}_pte.npy，
                  形状为 [19, 19].
                - labels.npy: 一个包含所有样本标签的数组，形状为 [N]（N = 受试者数×窗口数）。
            model_name (str): 用于解析预处理配置的模型名称。
            projector (nn.Module, optional): 将 19 通道转换为 3 通道的模块（如 ChannelProjector 实例）。
        """
        self.processed_dir = processed_dir
        # 获取所有 EEG 图像文件和 PTE 矩阵文件，按文件名排序（确保 sample_{idx} 顺序一致）
        self.img_files = sorted([f for f in os.listdir(processed_dir) if f.endswith('_img.npy')])
        self.pte_files = sorted([f for f in os.listdir(processed_dir) if f.endswith('_pte.npy')])
        self.labels = np.load(os.path.join(processed_dir, "labels.npy"))
        self.projector = projector

        # 获取 LeViT 预处理配置：例如 input_size, mean, std 等
        self.data_config = resolve_model_data_config(model_name)
        # 创建预处理流水线（is_training=False 表示不进行随机数据增强）
        self.transforms = create_transform(**self.data_config, is_training=False)

        # 检查数据长度是否一致
        assert len(self.img_files) == len(self.pte_files) == len(self.labels), \
            "EEG图像文件、PTE文件和标签数量不一致！"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # 读取 EEG 时频图像文件，形状原为 [19, H, W]
        img_path = os.path.join(self.processed_dir, self.img_files[idx])
        img = np.load(img_path)  # NumPy 数组，形状 [19, H, W]
        # 读取 PTE 矩阵文件，形状为 [19, 19]
        pte_path = os.path.join(self.processed_dir, self.pte_files[idx])
        pte_matrix = np.load(pte_path)
        # 读取标签
        label = int(self.labels[idx])

        # 转换为 PyTorch Tensor
        img = torch.FloatTensor(img)
        pte_matrix = torch.FloatTensor(pte_matrix)
        label = torch.tensor(label, dtype=torch.long)

        # 如果提供了 projector，将 19 通道转换为 3 通道
        if self.projector is not None:
            # 添加 batch 维度：从 (19, H, W) -> (1, 19, H, W)
            img = img.unsqueeze(0)
            # 应用 1x1 卷积模块
            img = self.projector(img)   # 输出 (1, 3, H, W)
            # 移除 batch 维度 -> (3, H, W)
            img = img.squeeze(0)

        # 应用 LeViT 所需的预处理：resize、归一化等
        img = self.transforms(img)

        return {'eeg_img': img, 'pte_matrix': pte_matrix, 'label': label}

# ---------------------------
# 示例：定义 1x1 卷积模块（ChannelProjector）
# ---------------------------
class ChannelProjector(nn.Module):
    def __init__(self, in_channels=19, out_channels=3):
        super(ChannelProjector, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x)

# ---------------------------
# 示例使用
# ---------------------------
# if __name__ == "__main__":
#     # 假设你的处理后数据存放在 "/content/processed" 目录下
#     processed_dir = "/content/processed"
#     # 创建 ChannelProjector，将 19 通道转换为 3 通道
#     projector = ChannelProjector(in_channels=19, out_channels=3)
#     # 创建数据集
#     dataset = ProcessedEEGDataset(processed_dir, model_name="levit_128.fb_dist_in1k", projector=projector)
#     print("Dataset length:", len(dataset))

#     # 测试 DataLoader 读取
#     from torch.utils.data import DataLoader
#     loader = DataLoader(dataset, batch_size=8, shuffle=True, pin_memory=True)
#     for batch in loader:
#         eeg_img = batch['eeg_img']         # 期望形状: (batch, 3, H, W)
#         pte_matrix = batch['pte_matrix']     # 期望形状: (batch, 19, 19)
#         label = batch['label']               # 形状: (batch,)
#         print("EEG Image shape:", eeg_img.shape)
#         print("PTE Matrix shape:", pte_matrix.shape)
#         print("Label shape:", label.shape)
#         break


In [20]:
# ---------------------------
# 1. 定义 EEG 图像编码器（基于 LeViT）
# ---------------------------


class EEGImageEncoder(nn.Module):
    def __init__(self, embed_dim=128):
        super(EEGImageEncoder, self).__init__()

        self.vit = timm.create_model('levit_128.fb_dist_in1k', pretrained=True, num_classes=0)

        # 将 ViT 的输出特征投影到目标维度 embed_dim
        self.fc = nn.Linear(self.vit.num_features, embed_dim)

    def forward(self, x):
        features = self.vit(x)
        # 投影到指定维度
        out = self.fc(features)
        return out


# ---------------------------
# 2. 定义 PTE 矩阵编码器（基于全连接网络）
# ---------------------------
class PTEEncoder(nn.Module):
    def __init__(self, num_channels=19, embed_dim=128):
        super(PTEEncoder, self).__init__()
        self.fc1 = nn.Linear(num_channels * num_channels, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, embed_dim)

    def forward(self, x):
        # x: (batch, num_channels, num_channels)
        batch_size = x.size(0)
        x = x.view(batch_size, -1)  # 扁平化
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# ---------------------------
# 3. 定义对比损失（InfoNCE / NT-Xent 损失的简单实现）
# ---------------------------
def contrastive_loss(z1, z2, temperature=0.07):
    # z1, z2: (batch, embed_dim)
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    batch_size = z1.size(0)
    # 计算相似度矩阵（余弦相似度）
    representations = torch.cat([z1, z2], dim=0)  # (2*batch, embed_dim)
    similarity_matrix = torch.matmul(representations, representations.T)  # (2*batch, 2*batch)
    # 去除自身相似度
    mask = torch.eye(2*batch_size, device=z1.device).bool()
    similarity_matrix = similarity_matrix.masked_fill(mask, -9e15)
    similarity_matrix /= temperature

    # 构造标签：同一原始样本的两个增强视角构成正对，标签为 i 与 i+batch_size 对应
    labels = torch.arange(batch_size, device=z1.device)
    loss_v2p = F.cross_entropy(similarity_matrix[:batch_size, batch_size:], labels)
    loss_p2v = F.cross_entropy(similarity_matrix[batch_size:, :batch_size], labels)
    return (loss_v2p + loss_p2v) / 2

# ---------------------------
# 4. 定义分类器
# ---------------------------
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim=64, num_classes=2):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )
    def forward(self, x):
        return self.net(x)

# ---------------------------
# 5. 定义跨模态对比学习模型
# ---------------------------
class CrossModalModel(nn.Module):
    def __init__(self, embed_dim=128, num_classes=2):
        super(CrossModalModel, self).__init__()
        self.eeg_encoder = EEGImageEncoder(embed_dim=embed_dim)
        self.pte_encoder = PTEEncoder(num_channels=19, embed_dim=embed_dim)
        # 分类器输入为拼接的两个模态特征
        self.classifier = Classifier(in_dim=2*embed_dim, num_classes=num_classes)

    def forward(self, eeg_img, pte_matrix):
        # eeg_img: (batch, 1, H, W)
        # pte_matrix: (batch, 19, 19)
        f_eeg = self.eeg_encoder(eeg_img)    # (batch, embed_dim)
        f_pte = self.pte_encoder(pte_matrix)   # (batch, embed_dim)
        fused = torch.cat([f_eeg, f_pte], dim=1)  # (batch, 2*embed_dim)
        logits = self.classifier(fused)        # (batch, num_classes)
        return f_eeg, f_pte, logits

In [17]:
# ---------------------------
# 7. 构造数据加载器
# ---------------------------
from torch.utils.data import random_split
import torch.utils.data as data
processed_dir = "/content/processed"
projector = ChannelProjector(in_channels=19, out_channels=3)
batch_size = 4
dataset = ProcessedEEGDataset(processed_dir, model_name="levit_128.fb_dist_in1k", projector=projector)

print("Dataset length:", len(dataset))
# 划分训练集和验证集（80%训练，20%验证）
train_ratio = 0.8
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])


# 创建数据加载器
train_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True  # 加速数据转移到GPU
)

val_loader = data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,    # 验证集不需要shuffle
    pin_memory=True
)



Dataset length: 14424


In [21]:
# ---------------------------
# 修改2：添加验证函数
# ---------------------------
def validate(model, loader, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in loader:
            eeg_img = batch['eeg_img'].to(device)
            pte_matrix = batch['pte_matrix'].to(device)
            labels = batch['label'].to(device)

            # 前向传播
            f_eeg, f_pte, logits = model(eeg_img, pte_matrix)

            # 计算损失
            loss_cls = criterion_cls(logits, labels)
            loss_contrast = contrastive_loss(f_eeg, f_pte)
            loss = loss_cls + alpha * loss_contrast

            # 统计指标
            total_loss += loss.item() * labels.size(0)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += labels.size(0)

    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy

In [22]:

# ---------------------------
# 8. 训练流程（含准确率计算）
# ---------------------------
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# 初始化设备、模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossModalModel(embed_dim=128, num_classes=2).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
criterion_cls = nn.CrossEntropyLoss()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/37.1M [00:00<?, ?B/s]



In [23]:

# ---------------------------
# 修改3：优化训练循环（含早停和学习率调度）
# ---------------------------
from torch.optim.lr_scheduler import ReduceLROnPlateau

# 添加学习率调度器
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',       # 监控验证准确率
    factor=0.5,       # 学习率衰减因子
    patience=2,       # 等待2个epoch无改进
    verbose=True
)

# 早停参数
best_val_acc = 0.0
patience_counter = 0
early_stop_patience = 4  # 连续4次无改进则停止

def contrastive_loss(z1, z2, temperature=0.07):
    """改进的对比损失函数"""
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)

    # 计算相似度矩阵
    sim_matrix = torch.mm(z1, z2.T) / temperature

    # 创建标签
    labels = torch.arange(z1.size(0), device=z1.device)

    # 计算交叉熵损失
    loss = F.cross_entropy(sim_matrix, labels)
    return loss

def calculate_accuracy(logits, labels):
    """计算分类准确率"""
    _, predicted = torch.max(logits, 1)
    correct = (predicted == labels).sum().item()
    return correct / labels.size(0)



# 训练参数
num_epochs = 10
alpha = 1.0  # 对比损失权重

# 训练循环（含准确率计算）
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}")

    for batch in progress_bar:
        # 数据加载（确保与数据集返回格式一致）
        eeg_img = batch['eeg_img'].to(device)
        pte_matrix = batch['pte_matrix'].to(device)
        labels = batch['label'].to(device)

        # 前向传播
        optimizer.zero_grad()
        f_eeg, f_pte, logits = model(eeg_img, pte_matrix)

        # 损失计算
        loss_cls = criterion_cls(logits, labels)
        loss_contrast = contrastive_loss(f_eeg, f_pte)
        loss = loss_cls + alpha * loss_contrast

        # 反向传播
        loss.backward()
        optimizer.step()

        # 统计指标
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += batch_size

        # 实时更新进度条显示
        progress_bar.set_postfix({
            'lr': optimizer.param_groups[0]['lr'],
            'loss': f"{loss.item():.4f}",
            'acc': f"{100 * total_correct / total_samples:.1f}%"
        })
    # 验证阶段
    val_loss, val_acc = validate(model, val_loader, device)

    # 学习率调度
    scheduler.step(val_acc)

    # 早停机制
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        # 保存最佳模型
        torch.save(model.state_dict(), "best_model.pth")
        print(f"保存最佳模型，验证准确率: {val_acc:.2f}%")
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\n早停触发！连续{early_stop_patience}个epoch验证准确率未提升")
            break

    train_loss = total_loss / total_samples
    train_acc = 100 * total_correct / total_samples
    # 打印报告
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"训练损失: {train_loss:.4f} | 训练准确率: {train_acc:.2f}%")
    print(f"验证损失: {val_loss:.4f} | 验证准确率: {val_acc:.2f}%")
    print("─" * 50)



Train Epoch 1: 100%|██████████| 2885/2885 [30:33<00:00,  1.57it/s, lr=0.0005, loss=0.7239, acc=59.8%]


保存最佳模型，验证准确率: 65.44%
Epoch 1/10
训练损失: 1.4001 | 训练准确率: 59.84%
验证损失: 1.5037 | 验证准确率: 65.44%
──────────────────────────────────────────────────


Train Epoch 2: 100%|██████████| 2885/2885 [30:34<00:00,  1.57it/s, lr=0.0005, loss=0.8287, acc=64.0%]


Epoch 2/10
训练损失: 1.1235 | 训练准确率: 64.01%
验证损失: 2.4641 | 验证准确率: 60.66%
──────────────────────────────────────────────────


Train Epoch 3: 100%|██████████| 2885/2885 [30:46<00:00,  1.56it/s, lr=0.0005, loss=1.0044, acc=65.2%]


Epoch 3/10
训练损失: 1.1849 | 训练准确率: 65.17%
验证损失: 1.4532 | 验证准确率: 63.57%
──────────────────────────────────────────────────


Train Epoch 4: 100%|██████████| 2885/2885 [30:55<00:00,  1.55it/s, lr=0.0005, loss=0.8748, acc=67.0%]


Epoch 4/10
训练损失: 1.1260 | 训练准确率: 66.99%
验证损失: 1.5745 | 验证准确率: 63.71%
──────────────────────────────────────────────────


Train Epoch 5: 100%|██████████| 2885/2885 [31:17<00:00,  1.54it/s, lr=0.00025, loss=0.6230, acc=71.3%]


保存最佳模型，验证准确率: 76.01%
Epoch 5/10
训练损失: 0.9632 | 训练准确率: 71.28%
验证损失: 1.1740 | 验证准确率: 76.01%
──────────────────────────────────────────────────


Train Epoch 6: 100%|██████████| 2885/2885 [31:28<00:00,  1.53it/s, lr=0.00025, loss=0.8177, acc=73.4%]


保存最佳模型，验证准确率: 78.37%
Epoch 6/10
训练损失: 0.9296 | 训练准确率: 73.35%
验证损失: 1.1406 | 验证准确率: 78.37%
──────────────────────────────────────────────────


Train Epoch 7: 100%|██████████| 2885/2885 [31:53<00:00,  1.51it/s, lr=0.00025, loss=0.4127, acc=74.0%]


Epoch 7/10
训练损失: 0.8893 | 训练准确率: 73.98%
验证损失: 1.0829 | 验证准确率: 74.21%
──────────────────────────────────────────────────


Train Epoch 8: 100%|██████████| 2885/2885 [33:48<00:00,  1.42it/s, lr=0.00025, loss=1.1807, acc=75.2%]


Epoch 8/10
训练损失: 0.8575 | 训练准确率: 75.16%
验证损失: 1.1477 | 验证准确率: 78.06%
──────────────────────────────────────────────────


Train Epoch 9: 100%|██████████| 2885/2885 [32:56<00:00,  1.46it/s, lr=0.00025, loss=0.4031, acc=76.4%]


Epoch 9/10
训练损失: 0.8450 | 训练准确率: 76.39%
验证损失: 2.1583 | 验证准确率: 71.44%
──────────────────────────────────────────────────


Train Epoch 10: 100%|██████████| 2885/2885 [32:37<00:00,  1.47it/s, lr=0.000125, loss=1.1294, acc=78.1%]


保存最佳模型，验证准确率: 80.24%
Epoch 10/10
训练损失: 0.7488 | 训练准确率: 78.14%
验证损失: 0.9830 | 验证准确率: 80.24%
──────────────────────────────────────────────────


In [24]:
!zip -r folder.zip /content/processed
from google.colab import files
files.download('folder.zip')


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
  adding: content/processed/sample_13852_img.npy (deflated 12%)
  adding: content/processed/sample_1291_pte.npy (deflated 12%)
  adding: content/processed/sample_4145_img.npy (deflated 11%)
  adding: content/processed/sample_9669_img.npy (deflated 11%)
  adding: content/processed/sample_1974_pte.npy (deflated 15%)
  adding: content/processed/sample_3428_img.npy (deflated 11%)
  adding: content/processed/sample_12133_img.npy (deflated 11%)
  adding: content/processed/sample_7265_pte.npy (deflated 14%)
  adding: content/processed/sample_8569_pte.npy (deflated 14%)
  adding: content/processed/sample_5685_pte.npy (deflated 13%)
  adding: content/processed/sample_982_img.npy (deflated 11%)
  adding: content/processed/sample_4397_img.npy (deflated 11%)
  adding: content/processed/sample_7586_pte.npy (deflated 16%)
  adding: content/processed/sample_10790_pte.npy (deflated 15%)
  adding: content/processed/sample_6764_img.npy (deflated 11%)
  adding: co

FileNotFoundError: Cannot find file: folder.zip

In [None]:
# ---------------------------
# 修改4：最终测试
# ---------------------------
# 加载最佳模型
model.load_state_dict(torch.load("best_model.pth"))

