<a href="https://colab.research.google.com/github/chengpeip/colab_design_for_graduate/blob/main/levit_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mne

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m41.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0


In [None]:
import os
import re
import mne
import numpy as np
import torch
from scipy.signal import stft

class EDFPretransform:
    """Turn paired EDF recordings and PTE matrices into model-ready tensors.

    Every ``h<N>.edf`` / ``s<N>.edf`` in ``edf_folder`` that has a matching
    ``<prefix><N>_pte.npy`` in ``pte_folder`` is resampled to ``target_sfreq``,
    cut into non-overlapping ``window_duration``-second windows, and each
    channel of each window is converted into an ``output_img_size`` x
    ``output_img_size`` log-magnitude STFT image. ``s`` files are labelled 1,
    ``h`` files 0.
    """

    def __init__(self, edf_folder, pte_folder, output_img_size=224, nperseg=256, target_sfreq=250):
        self.edf_folder = edf_folder
        self.pte_folder = pte_folder
        self.output_img_size = output_img_size
        self.nperseg = nperseg
        self.target_sfreq = target_sfreq
        # Reference channel order, fixed by the first EDF that is read successfully.
        self.channel_names = None
        self.window_duration = 2  # window length in seconds

        # Pre-compiled filename patterns: group(1) = 'h'/'s' prefix, group(2) = subject number.
        self.edf_pattern = re.compile(r'^([hs])(\d+)\.edf$', re.IGNORECASE)
        # Currently unused within this class.
        self.pte_pattern = re.compile(r'^([hs])(\d+)_pte\.npy$', re.IGNORECASE)

    def _match_files(self):
        """Pair each EDF file with its PTE window file.

        Returns:
            List of ``(edf_path, pte_path, prefix)`` tuples, with ``prefix`` the
            lower-cased ``'h'`` or ``'s'``. EDF files without a PTE file are
            reported and skipped.
        """
        # Sort by subject number, then by file name. Without the name tie-break,
        # h/s files sharing a number keep os.listdir order, which is not
        # guaranteed; a fixed order keeps sample order (and any downstream
        # random split) reproducible.
        edf_files = sorted(
            (f for f in os.listdir(self.edf_folder) if self.edf_pattern.match(f)),
            key=lambda f: (int(self.edf_pattern.match(f).group(2)), f.lower()))

        valid_pairs = []
        for edf_file in edf_files:
            match = self.edf_pattern.match(edf_file)
            prefix = match.group(1).lower()
            number = match.group(2)

            # By convention the PTE file shares the prefix and subject number.
            pte_file = f"{prefix}{number}_pte.npy"
            pte_path = os.path.join(self.pte_folder, pte_file)

            if os.path.exists(pte_path):
                valid_pairs.append((os.path.join(self.edf_folder, edf_file), pte_path, prefix))
            else:
                print(f"警告: 未找到 {edf_file} 对应的PTE窗口文件 {pte_file}")

        return valid_pairs

    def _process_edf(self, edf_path):
        """Read an EDF file with MNE and resample it to ``target_sfreq`` if needed.

        Returns the loaded Raw object, or ``None`` (after printing the error)
        if reading or resampling fails.
        """
        try:
            raw = mne.io.read_raw_edf(edf_path, preload=True)
            if raw.info['sfreq'] != self.target_sfreq:
                raw = raw.resample(self.target_sfreq)
            return raw
        except Exception as e:
            print(f"EDF文件处理失败 {edf_path}: {str(e)}")
            return None

    def _align_channels(self, raw, data, edf_path):
        """Return ``data`` with its rows in the reference channel order.

        The first file seen fixes ``self.channel_names``. Every later file must
        contain all reference channels. Its rows are reordered, and any extra
        channels dropped, so that channel i is the same electrode in every
        sample and matches row/column i of the PTE matrices. Returns ``None``
        (with a warning) if a reference channel is missing.
        """
        if self.channel_names is None:
            self.channel_names = list(raw.ch_names)
            print(f"通道配置: {len(self.channel_names)}个通道 ({', '.join(self.channel_names)})")
            return data
        if list(raw.ch_names) == self.channel_names:
            return data
        missing = [ch for ch in self.channel_names if ch not in raw.ch_names]
        if missing:
            print(f"警告: {edf_path} 缺少通道 {missing}，已跳过")
            return None
        order = [raw.ch_names.index(ch) for ch in self.channel_names]
        return data[order]

    def _segment_data(self, data):
        """Split ``data`` of shape (n_channels, n_samples) into time windows.

        Returns an array of shape (n_windows, n_channels, window_samples).
        Trailing samples that do not fill a whole window are dropped.
        """
        window_samples = int(self.window_duration * self.target_sfreq)
        n_samples = data.shape[1]
        n_windows = n_samples // window_samples

        # Truncate to a whole number of windows.
        truncated = data[:, :n_windows*window_samples]
        return truncated.reshape(data.shape[0], n_windows, window_samples).transpose(1, 0, 2)

    def _generate_spectrogram(self, data, sfreq):
        """Return a float32 (output_img_size, output_img_size) spectrogram image of a 1-D signal.

        The image is the STFT log-magnitude (dB), min-max scaled per image, then
        resized with anti-aliasing.
        """
        f, t, Zxx = stft(data, fs=sfreq, nperseg=self.nperseg)
        Zxx = 10 * np.log10(np.abs(Zxx) + 1e-6)
        Zxx = (Zxx - Zxx.min()) / (Zxx.max() - Zxx.min() + 1e-6)

        from skimage.transform import resize
        return resize(Zxx, (self.output_img_size, self.output_img_size),
                      mode='reflect', anti_aliasing=True).astype(np.float32)

    def process_all_files(self):
        """Run the full pipeline over every matched EDF/PTE pair.

        Returns:
            images_tensor: float32 tensor (N, n_channels, S, S), S = output_img_size.
            pte_tensor: float32 tensor (N, n_channels, n_channels).
            labels_tensor: int64 tensor (N,), 0 for 'h' files and 1 for 's' files.

        Files are reported and skipped if:
            - the EDF cannot be read,
            - it lacks a reference channel, or
            - its PTE array fails validation.
        """
        file_pairs = self._match_files()

        all_images = []
        all_pte = []
        labels = []

        for edf_path, pte_path, prefix in file_pairs:
            raw = self._process_edf(edf_path)
            if raw is None:
                continue

            data = self._align_channels(raw, raw.get_data(), edf_path)
            if data is None:
                continue
            data_windows = self._segment_data(data)  # (n_windows, n_channels, window_samples)

            try:
                pte_windows = np.load(pte_path)  # expected (n_windows, n_channels, n_channels)
                n_ch = len(self.channel_names)
                # Explicit raises instead of `assert`, which is removed under `python -O`.
                if pte_windows.ndim != 3:
                    raise ValueError("PTE数据应为三维数组")
                if pte_windows.shape[1:] != (n_ch, n_ch):
                    raise ValueError("PTE维度与通道数不符")
                if pte_windows.shape[0] != data_windows.shape[0]:
                    raise ValueError("PTE窗口数与数据窗口数不匹配")
            except Exception as e:
                print(f"PTE加载失败 {pte_path}: {str(e)}")
                continue

            # Label from the filename prefix (h=0, s=1).
            label = 1 if prefix == 's' else 0

            for win_idx in range(data_windows.shape[0]):
                window_images = [
                    self._generate_spectrogram(data_windows[win_idx, ch_idx, :], self.target_sfreq)
                    for ch_idx in range(data_windows.shape[1])
                ]
                all_images.append(np.stack(window_images))  # (n_channels, S, S)
                all_pte.append(pte_windows[win_idx])        # (n_channels, n_channels)
                labels.append(label)

        n_ch = len(self.channel_names) if self.channel_names else 0
        size = self.output_img_size
        if all_images:
            images_tensor = torch.FloatTensor(np.array(all_images))
            pte_tensor = torch.FloatTensor(np.array(all_pte))
        else:
            # np.array([]) would produce a 1-D tensor; keep the documented ranks
            # so downstream shape checks still hold for an empty result.
            images_tensor = torch.empty((0, n_ch, size, size))
            pte_tensor = torch.empty((0, n_ch, n_ch))
        labels_tensor = torch.LongTensor(labels)

        print(f"\n处理完成: 总样本数 {len(labels)}")
        print(f"时频图张量: {images_tensor.shape}")
        print(f"PTE张量: {pte_tensor.shape}")
        print(f"标签分布: 健康={labels.count(0)}, 患者={labels.count(1)}")

        return images_tensor, pte_tensor, labels_tensor

In [None]:
# Colab layout: the EDF recordings and their *_pte.npy window files are both read from /content.
edf_processor = EDFPretransform(edf_folder="/content",pte_folder="/content")

In [None]:
# Build per-window spectrogram images, PTE matrices and labels (h=0, s=1).
images, pte, labels = edf_processor.process_all_files()

Extracting EDF parameters from /content/h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
通道配置: 19个通道 (Fp2, F8, T4, T6, O2, Fp1, F7, T3, T5, O1, F4, C4, P4, F3, C3, P3, Fz, Cz, Pz)
Extracting EDF parameters from /content/s01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 211249  =      0.000 ...   844.996 secs...
Extracting EDF parameters from /content/s02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 286249  =      0.000 ...  1144.996 secs...
Extracting EDF parameters from /content/h02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 227499  =      0.000 ...   909.996 secs...
Extracting EDF parameters from /content/h03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 .

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import timm
from timm.data import resolve_model_data_config, create_transform
import torchvision.transforms as T
from tqdm import tqdm

class ChannelProjector(nn.Module):
    """Mix ``in_channels`` feature maps into ``out_channels`` with a learnable 1x1 convolution.

    Input ``(batch, in_channels, H, W)`` -> output ``(batch, out_channels, H, W)``;
    the spatial size is unchanged because the kernel is 1x1.
    """

    def __init__(self, in_channels=19, out_channels=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

# ---------------------------
# 2. 定义 EEG Spectrogram 数据集
# ---------------------------
class EEGSpectrogramDataset(data.Dataset):
    """Dataset of (spectrogram image, PTE matrix, label) samples for the cross-modal model.

    ``__getitem__`` returns ``{'eeg_img': ..., 'pte_matrix': ..., 'label': ...}``.
    """

    def __init__(self, images, pte, labels, model_name="levit_128.fb_dist_in1k", projector=None):
        """
        Args:
            images: spectrogram images, shape [N, C, H, W] (one image per EEG channel).
            pte: PTE matrices, shape [N, C, C].
            labels: sample labels, shape [N].
            model_name: timm model name -- or an already-built timm model -- whose
                pretrained data config (input size, mean/std, interpolation,
                crop) defines the eval transform.
            projector: optional module mapping the C image channels to the
                channel count the backbone expects (e.g. a ChannelProjector).
        """
        # Convert inputs to tensors unless they already are.
        self.images = torch.tensor(images, dtype=torch.float32) if not torch.is_tensor(images) else images
        self.pte = torch.tensor(pte, dtype=torch.float32) if not torch.is_tensor(pte) else pte
        self.labels = torch.tensor(labels, dtype=torch.long) if not torch.is_tensor(labels) else labels
        self.projector = projector

        # resolve_model_data_config reads `pretrained_cfg` from a model object;
        # a bare name string has none, so timm would fall back to its generic
        # defaults. Build the architecture (no weight download) to use the
        # model's real config.
        config_source = (timm.create_model(model_name, pretrained=False)
                         if isinstance(model_name, str) else model_name)
        self.data_config = resolve_model_data_config(config_source)
        self.transforms = create_transform(**self.data_config, is_training=False)

        # Sanity-check ranks.
        assert self.images.ndim == 4, f"images 期望形状 [N, C, H, W], got {self.images.shape}"
        assert self.pte.ndim == 3, f"pte 期望形状 [N, C, C], got {self.pte.shape}"
        assert self.labels.ndim == 1, f"labels 期望形状 [N], got {self.labels.shape}"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self.images[idx]  # (C, H, W)
        if self.projector is not None:
            # Run the projector without autograd. In this notebook it is not part
            # of the trained model, so a recorded graph would only make
            # loss.backward() accumulate gradients into weights that no
            # optimizer steps or zeroes. NOTE(review): this also means the
            # projection stays at its random initialisation.
            with torch.no_grad():
                img = self.projector(img.unsqueeze(0)).squeeze(0)  # (out_channels, H, W)
        # Backbone-specific preprocessing (resize / crop / normalize).
        img = self.transforms(img)

        pte_matrix = self.pte[idx]
        label = self.labels[idx]
        return {'eeg_img': img, 'pte_matrix': pte_matrix, 'label': label}
# ---------------------------
# 1. 定义 EEG 图像编码器（基于 LeViT）
# ---------------------------


class EEGImageEncoder(nn.Module):
    """timm image backbone followed by a linear projection to ``embed_dim``.

    Args:
        embed_dim: output embedding size.
        model_name: timm model name; defaults to the LeViT-128 checkpoint the
            dataset's eval transform is configured for.
        pretrained: whether timm should load pretrained weights.
    """

    def __init__(self, embed_dim=128, model_name='levit_128.fb_dist_in1k', pretrained=True):
        super().__init__()
        # num_classes=0 asks timm to drop the classifier head so the backbone
        # returns feature vectors of size `self.vit.num_features`.
        self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.fc = nn.Linear(self.vit.num_features, embed_dim)

    def forward(self, x):
        """Map an image batch (layout expected by the backbone) to (batch, embed_dim)."""
        return self.fc(self.vit(x))


# ---------------------------
# 2. 定义 PTE 矩阵编码器（基于全连接网络）
# ---------------------------
class PTEEncoder(nn.Module):
    """Two-layer MLP that embeds a flattened ``num_channels x num_channels`` PTE matrix.

    Args:
        num_channels: side length of the (square) PTE matrix.
        embed_dim: output embedding size.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, num_channels=19, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(num_channels * num_channels, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        """x: (batch, num_channels, num_channels) -> (batch, embed_dim)."""
        # flatten(1) copies when needed. view() would raise on non-contiguous
        # input (e.g. a transposed matrix) and on an empty batch.
        x = x.flatten(1)
        return self.fc2(self.relu(self.fc1(x)))

# ---------------------------
# 3. 定义对比损失（InfoNCE / NT-Xent 损失的简单实现）
# ---------------------------
def contrastive_loss(z1, z2, temperature=0.07):
    """Symmetric cross-modal InfoNCE (CLIP-style) loss.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair. Every
    other row of the opposite modality in the batch is a negative.

    Args:
        z1, z2: embeddings of shape (batch, embed_dim).
        temperature: scale applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean of the z1->z2 and z2->z1 cross-entropy losses.

    NOTE(review): a later cell redefines ``contrastive_loss`` (one-directional),
    and that later definition is the one called during training/validation.
    """
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    # Only cross-modal similarities are used, so compute the (batch, batch) block
    # directly instead of a full (2*batch, 2*batch) matrix. This also removes the
    # -9e15 self-mask: it only touched unused diagonal blocks, and the value lies
    # outside the float16 range, which makes masked_fill fail for half inputs.
    logits = torch.matmul(z1, z2.T) / temperature
    targets = torch.arange(z1.size(0), device=z1.device)
    loss_v2p = F.cross_entropy(logits, targets)
    loss_p2v = F.cross_entropy(logits.T, targets)
    return (loss_v2p + loss_p2v) / 2

# ---------------------------
# 4. 定义分类器
# ---------------------------
class Classifier(nn.Module):
    """MLP head: Linear -> ReLU -> Dropout(0.5) -> Linear, producing ``num_classes`` logits."""

    def __init__(self, in_dim, hidden_dim=64, num_classes=2):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes),
        ]
        # Kept as a Sequential named `net` so state_dict keys stay net.0.* / net.3.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

# ---------------------------
# 5. 定义跨模态对比学习模型
# ---------------------------
class CrossModalModel(nn.Module):
    """Two-branch model: a spectrogram-image encoder and a PTE encoder, fused for classification.

    ``forward`` returns ``(f_eeg, f_pte, logits)``. Callers can combine a
    classification loss on ``logits`` with a contrastive loss between the two
    embeddings.

    Args:
        embed_dim: embedding size of each branch.
        num_classes: number of output classes.
        num_channels: side length of the PTE matrices.
    """

    def __init__(self, embed_dim=128, num_classes=2, num_channels=19):
        super().__init__()
        self.eeg_encoder = EEGImageEncoder(embed_dim=embed_dim)
        self.pte_encoder = PTEEncoder(num_channels=num_channels, embed_dim=embed_dim)
        # The classifier sees both modality embeddings concatenated.
        self.classifier = Classifier(in_dim=2 * embed_dim, num_classes=num_classes)

    def forward(self, eeg_img, pte_matrix):
        """
        Args:
            eeg_img: image batch for the backbone. In this notebook it is the
                projector/transform output, presumably (batch, 3, H, W).
            pte_matrix: (batch, num_channels, num_channels).
        Returns:
            f_eeg, f_pte: (batch, embed_dim) embeddings; logits: (batch, num_classes).
        """
        f_eeg = self.eeg_encoder(eeg_img)
        f_pte = self.pte_encoder(pte_matrix)
        fused = torch.cat([f_eeg, f_pte], dim=1)  # (batch, 2*embed_dim)
        logits = self.classifier(fused)
        return f_eeg, f_pte, logits





In [None]:
# ---------------------------
# 7. Build the data loaders
# ---------------------------
from torch.utils.data import random_split

SEED = 42
batch_size = 4
train_ratio = 0.7
val_ratio = 0.15  # the remainder (~15%) becomes a held-out test set

# Seed before creating the randomly initialised channel projector (and before the
# loaders shuffle) so runs are reproducible.
torch.manual_seed(SEED)
dataset = EEGSpectrogramDataset(images, pte, labels, projector=ChannelProjector())

# NOTE(review): windows are split at random, so windows cut from the same
# recording can land in train, val and test alike. Labels come from the file
# prefix, so this leaks recording identity and inflates val/test accuracy. A
# recording-level split needs each window's source file, which
# process_all_files does not return.
train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Pinned host memory only helps host->GPU copies.
pin = torch.cuda.is_available()

train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin)
val_loader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin)
# The test loader covers only the held-out split. Iterating the full dataset would
# include the training windows and overstate test accuracy.
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin)

In [None]:
# ---------------------------
# 修改2：添加验证函数
# ---------------------------
def validate(model, loader, device, criterion=None, contrast_fn=None, contrast_weight=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The loss mirrors the training objective: classification loss plus
    ``contrast_weight`` times the contrastive loss between the two embeddings.

    Args:
        model: module returning ``(f_eeg, f_pte, logits)``.
        loader: iterable of dict batches with 'eeg_img', 'pte_matrix', 'label'.
        device: device each batch is moved to.
        criterion: classification loss; defaults to the notebook-level ``criterion_cls``.
        contrast_fn: contrastive loss; defaults to the notebook-level
            ``contrastive_loss``, looked up at call time (so the most recent
            definition is used).
        contrast_weight: weight of the contrastive term; defaults to the
            notebook-level ``alpha``.

    Returns:
        ``(avg_loss, accuracy_percent)``; both are NaN if ``loader`` yields no samples.
    """
    # Fall back to notebook globals so existing 3-argument calls keep working.
    criterion = criterion_cls if criterion is None else criterion
    contrast_fn = contrastive_loss if contrast_fn is None else contrast_fn
    contrast_weight = alpha if contrast_weight is None else contrast_weight

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in loader:
            eeg_img = batch['eeg_img'].to(device)
            pte_matrix = batch['pte_matrix'].to(device)
            labels = batch['label'].to(device)

            f_eeg, f_pte, logits = model(eeg_img, pte_matrix)

            loss = criterion(logits, labels) + contrast_weight * contrast_fn(f_eeg, f_pte)

            n = labels.size(0)
            total_loss += loss.item() * n
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += n

    if total_samples == 0:
        return float('nan'), float('nan')
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    return avg_loss, accuracy

In [None]:

# ---------------------------
# 8. 训练流程（含准确率计算）
# ---------------------------
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# Set up the device, model, optimizer and classification loss.
# Only `model.parameters()` are optimized. The ChannelProjector handed to the
# dataset is not part of `model`, so its 1x1 conv weights are never updated here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossModalModel(embed_dim=128, num_classes=2).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
criterion_cls = nn.CrossEntropyLoss()



In [None]:

# ---------------------------
# 修改3：优化训练循环（含早停和学习率调度）
# ---------------------------
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Learning-rate scheduler: halve the LR when validation accuracy has not improved
# for 2 epochs (mode='max' because the monitored metric is an accuracy).
# `verbose` is omitted: it is deprecated in recent PyTorch releases, and the
# current LR is already shown in the training progress bar.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
)

# Early-stopping state: stop after `early_stop_patience` consecutive epochs
# without a new best validation accuracy.
best_val_acc = 0.0
patience_counter = 0
early_stop_patience = 4

def contrastive_loss(z1, z2, temperature=0.07):
    """Improved contrastive loss: one-directional InfoNCE from ``z1`` to ``z2``.

    Cosine similarities between every row of ``z1`` and every row of ``z2`` are
    divided by ``temperature``. Row ``i`` of ``z1`` must pick row ``i`` of ``z2``
    among the batch, scored with cross-entropy.

    This rebinds the name ``contrastive_loss``. Calls made after this cell runs,
    including those inside ``validate``, use this version.
    """
    anchors = F.normalize(z1, p=2, dim=1)
    candidates = F.normalize(z2, p=2, dim=1)
    logits = anchors @ candidates.T / temperature
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(logits, targets)

def calculate_accuracy(logits, labels):
    """Return the fraction (0-1, not a percentage) of rows whose argmax equals ``labels``.

    Not called in the visible notebook; the training loop computes accuracy inline.
    """
    predicted = logits.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return n_correct / labels.size(0)



# Training hyper-parameters
num_epochs = 10
alpha = 1.0  # weight of the contrastive loss term

# Training loop with accuracy tracking, LR scheduling and early stopping.
# Per-batch names (batch_labels, n) deliberately differ from the notebook-level
# `labels` / `batch_size`, which earlier cells use to build the dataset and
# loaders. Reusing those names here would overwrite them with the last batch.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}")

    for batch in progress_bar:
        eeg_img = batch['eeg_img'].to(device)
        pte_matrix = batch['pte_matrix'].to(device)
        batch_labels = batch['label'].to(device)

        # Forward pass
        optimizer.zero_grad()
        f_eeg, f_pte, logits = model(eeg_img, pte_matrix)

        # Classification loss plus weighted cross-modal contrastive loss
        loss_cls = criterion_cls(logits, batch_labels)
        loss_contrast = contrastive_loss(f_eeg, f_pte)
        loss = loss_cls + alpha * loss_contrast

        # Backward pass
        loss.backward()
        optimizer.step()

        # Running metrics
        n = batch_labels.size(0)
        total_loss += loss.item() * n
        total_correct += (logits.argmax(1) == batch_labels).sum().item()
        total_samples += n

        progress_bar.set_postfix({
            'lr': optimizer.param_groups[0]['lr'],
            'loss': f"{loss.item():.4f}",
            'acc': f"{100 * total_correct / total_samples:.1f}%"
        })

    # Validation
    val_loss, val_acc = validate(model, val_loader, device)

    # LR scheduling on validation accuracy
    scheduler.step(val_acc)

    # Report before the early-stopping check. Otherwise the `break` below
    # would skip logging the final epoch.
    train_loss = total_loss / total_samples
    train_acc = 100 * total_correct / total_samples
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"训练损失: {train_loss:.4f} | 训练准确率: {train_acc:.2f}%")
    print(f"验证损失: {val_loss:.4f} | 验证准确率: {val_acc:.2f}%")
    print("─" * 50)

    # Keep the best checkpoint; stop after `early_stop_patience` epochs without improvement.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")
        print(f"保存最佳模型，验证准确率: {val_acc:.2f}%")
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\n早停触发！连续{early_stop_patience}个epoch验证准确率未提升")
            break



Train Epoch 1: 100%|██████████| 6/6 [00:04<00:00,  1.41it/s, lr=0.0005, loss=0.7394, acc=45.5%]


保存最佳模型，验证准确率: 66.67%
Epoch 1/10
训练损失: 0.7064 | 训练准确率: 45.45%
验证损失: 0.5616 | 验证准确率: 66.67%
──────────────────────────────────────────────────


Train Epoch 2: 100%|██████████| 6/6 [00:02<00:00,  2.55it/s, lr=0.0005, loss=0.5561, acc=81.8%]


Epoch 2/10
训练损失: 0.4807 | 训练准确率: 81.82%
验证损失: 0.5659 | 验证准确率: 66.67%
──────────────────────────────────────────────────


Train Epoch 3: 100%|██████████| 6/6 [00:02<00:00,  2.60it/s, lr=0.0005, loss=0.3264, acc=86.4%]


Epoch 3/10
训练损失: 0.4570 | 训练准确率: 86.36%
验证损失: 0.7676 | 验证准确率: 50.00%
──────────────────────────────────────────────────


Train Epoch 4: 100%|██████████| 6/6 [00:02<00:00,  2.58it/s, lr=0.0005, loss=0.4046, acc=81.8%]


Epoch 4/10
训练损失: 0.5074 | 训练准确率: 81.82%
验证损失: 0.8072 | 验证准确率: 33.33%
──────────────────────────────────────────────────


Train Epoch 5: 100%|██████████| 6/6 [00:03<00:00,  1.74it/s, lr=0.00025, loss=0.0601, acc=95.5%]



早停触发！连续4个epoch验证准确率未提升


In [None]:
# ---------------------------
# Final evaluation
# ---------------------------
# Load the best checkpoint.
# - map_location lets a checkpoint written on GPU load on a CPU-only runtime.
# - weights_only=True restricts unpickling to tensors; the file holds only a state_dict.
model.load_state_dict(torch.load("best_model.pth", map_location=device, weights_only=True))

# Evaluate on the test loader
test_loss, test_acc = validate(model, test_loader, device)
print(f"\n最终测试结果：损失 {test_loss:.4f} | 准确率 {test_acc:.2f}%")