<a href="https://colab.research.google.com/github/chengpeip/colab_design_for_graduate/blob/main/orginal_levit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mne

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0


In [None]:
import math
import os

import mne
import numpy as np
import timm
import torch
import torch.nn as nn
from einops import rearrange
from torch.utils.data import Dataset, DataLoader


##############################
# Step 1: 数据加载与预处理 (使用MNE)
##############################

# Electrode name -> (row, col) cell on the 9x9 grid that map_to_2d fills.
# The keys must match the channel names stored in the EDF files. Their insertion
# order is also the channel (column) order EEGDataset produces and map_to_2d
# reads, so entries must not be reordered.
electrode_positions = {
    # grid row 0
    'Fp1': (0, 3),
    'Fp2': (0, 5),
    # grid row 2
    'F7': (2, 0),
    'F3': (2, 2),
    'Fz': (2, 4),
    'F4': (2, 6),
    'F8': (2, 8),
    # grid row 4
    'T3': (4, 0),
    'C3': (4, 2),
    'Cz': (4, 4),
    'C4': (4, 6),
    'T4': (4, 8),
    # grid row 6
    'T5': (6, 0),
    'P3': (6, 2),
    'Pz': (6, 4),
    'P4': (6, 6),
    'T6': (6, 8),
    # grid row 8
    'O1': (8, 3),
    'O2': (8, 5),
}

class EEGDataset(Dataset):
    """Cut EDF recordings into fixed-size windows rendered as 2-D images.

    Every ``.edf`` file in ``data_folder`` is preprocessed as follows:

    * resampled to ``sfreq`` Hz if its rate differs;
    * restricted to the channels of ``electrode_positions``, in that order;
    * truncated or zero-padded to ``fixed_length`` samples;
    * min-max scaled to [0, 255] over the whole (padded) recording.

    Each recording is then split into ``fixed_length // window_size``
    consecutive windows. Within a window, every time point is rendered by
    ``map_to_2d`` + ``split_brain_lobes`` into a 16x16x3 tile. The
    ``window_size`` tiles are laid out on a square grid: 14x14 tiles for the
    default 196, giving a 224x224x3 image.

    Args:
        data_folder: Directory containing the ``.edf`` files.
        window_size: Time points per image. Must be a positive perfect square.
        fixed_length: Number of samples every recording is cut or padded to.
        cache_files: If True, keep each preprocessed recording in memory after
            its first use instead of re-reading the EDF on every
            ``__getitem__``. This costs about ``fixed_length * n_channels * 4``
            bytes per file, and each DataLoader worker process holds its own
            copy.
        sfreq: Sampling rate (Hz) that recordings are resampled to.

    Raises:
        ValueError: If ``window_size`` is not a positive perfect square.

    Items are ``(image, label)``. ``image`` is a float tensor of shape
    ``[3, H, W]``; ``label`` is an int.
    """

    def __init__(self, data_folder, window_size=196, fixed_length=225000,
                 cache_files=False, sfreq=250):
        # The tiles of one window are arranged on a square grid, so the
        # window length must be a perfect square (196 -> 14x14).
        if window_size <= 0 or math.isqrt(window_size) ** 2 != window_size:
            raise ValueError(
                f"window_size must be a positive perfect square, got {window_size}")

        # Sort the file list so the index -> file mapping is reproducible
        # (os.listdir order is arbitrary).
        self.files = sorted(f for f in os.listdir(data_folder) if f.endswith('.edf'))
        # Label 0 if the filename contains 'h', else 1.
        # NOTE(review): this matches 'h' anywhere in the name (e.g. 'sch01.edf'
        # would get label 0). Confirm this fits the dataset's naming scheme.
        self.labels = [0 if 'h' in f else 1 for f in self.files]
        self.channel_names = list(electrode_positions.keys())
        self.fixed_length = fixed_length
        self.window_size = window_size
        self.grid_side = math.isqrt(window_size)
        self.data_folder = data_folder
        self.sfreq = sfreq
        # file index -> preprocessed array; None disables caching.
        self._cache = {} if cache_files else None

        # Number of images each file yields.
        self.segments_per_file = fixed_length // window_size
        self.total_samples = len(self.files) * self.segments_per_file

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        file_idx, seg_idx = divmod(idx, self.segments_per_file)
        eeg_data = self._get_file_data(file_idx)

        # Take this segment's window and render each time point as a tile.
        start = seg_idx * self.window_size
        patch = eeg_data[start:start + self.window_size, :]
        frames = [split_brain_lobes(map_to_2d(patch, i)) for i in range(self.window_size)]
        image = rearrange(np.stack(frames), '(h w) p1 p2 c -> (h p1) (w p2) c',
                          h=self.grid_side, w=self.grid_side)
        image = torch.FloatTensor(image).permute(2, 0, 1)  # -> [C, H, W]

        return image, self.labels[file_idx]

    def _get_file_data(self, file_idx):
        """Return the preprocessed recording for ``file_idx``, using the cache if enabled."""
        if self._cache is not None and file_idx in self._cache:
            return self._cache[file_idx]
        eeg_data = self._load_file(file_idx)
        if self._cache is not None:
            self._cache[file_idx] = eeg_data
        return eeg_data

    def _load_file(self, file_idx):
        """Read, resample, select channels, cut/pad, and normalise one EDF file.

        Returns a float32 array of shape ``(fixed_length, n_channels)``.
        float32 loses nothing relative to the float64 pipeline, because the
        image is converted to a float32 tensor anyway.
        """
        file_path = os.path.join(self.data_folder, self.files[file_idx])
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)

        if raw.info['sfreq'] != self.sfreq:
            raw.resample(self.sfreq, npad='auto')

        raw.pick(self.channel_names)
        raw.reorder_channels(self.channel_names)

        eeg_data, _ = raw[:, :]
        eeg_data = eeg_data.T  # -> (n_times, n_channels)

        eeg_data = eeg_data[:self.fixed_length, :]
        shortfall = self.fixed_length - eeg_data.shape[0]
        if shortfall > 0:
            eeg_data = np.pad(eeg_data, ((0, shortfall), (0, 0)), 'constant')

        return self._normalize(eeg_data).astype(np.float32)

    @staticmethod
    def _normalize(eeg_data):
        """Min-max scale ``eeg_data`` to [0, 255].

        A constant array (max == min) maps to all zeros instead of NaN.
        """
        lo, hi = np.min(eeg_data), np.max(eeg_data)
        if hi == lo:
            return np.zeros(np.shape(eeg_data), dtype=float)
        return (eeg_data - lo) / (hi - lo) * 255


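##############################
# The remaining steps (image conversion, model definition, training loop)
# are unchanged from the original version.
##############################

# The image-conversion helpers used by EEGDataset.__getitem__ follow.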

##############################
# Step 2: EEG信号到图像的转换
##############################

def map_to_2d(eeg_1d, time_point):
    """Scatter one time point of a (time, channel) EEG array onto a 9x9 grid.

    Column ``k`` of ``eeg_1d`` is assumed to hold the k-th channel of
    ``electrode_positions``; EEGDataset reorders channels that way. Its value
    at ``time_point`` is written to that electrode's (row, col) cell. Every
    other cell stays 0.

    Returns:
        A float 9x9 numpy array.
    """
    grid = np.zeros((9, 9))
    for channel_idx, (row, col) in enumerate(electrode_positions.values()):
        grid[row, col] = eeg_1d[time_point, channel_idx]
    return grid

import numpy as np

def split_brain_lobes(matrix):
    """Spread a 9x9 grid over three colour channels by row band, then pad to 16x16.

    Channel 0 keeps only rows 0-2 of ``matrix``, channel 1 only rows 3-5, and
    channel 2 only rows 6-8; all other entries are zero. The resulting 9x9x3
    array is zero-padded with 3 cells before and 4 cells after each spatial
    axis (9 + 7 = 16), giving a 16x16x3 float array.

    Raises:
        AssertionError: If ``matrix`` is not 9x9.
    """
    assert matrix.shape == (9, 9), "输入必须是 9x9 的矩阵"

    lobes = np.zeros((9, 9, 3))
    for channel, first_row in enumerate((0, 3, 6)):
        band = slice(first_row, first_row + 3)
        lobes[band, :, channel] = matrix[band, :]

    # 16 - 9 = 7 is odd, so the extra cell goes after (3 before, 4 after).
    return np.pad(lobes, ((3, 4), (3, 4), (0, 0)), mode='constant')





In [None]:
import math
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
from tqdm import tqdm
import timm

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
BATCH_SIZE = 32
EPOCHS = 20
WARMUP_EPOCHS = 5
BASE_LR = 0.0005
# NOTE(review): the LeViT reference recipe scales the LR as BASE_LR * batch / 512;
# here it is divided by 10 (LR = 0.0016). Kept unchanged; confirm this is intended.
LR = BASE_LR * BATCH_SIZE / 10
WEIGHT_DECAY = 0.025
VAL_RATIO = 0.2
SPLIT_SEED = 42  # makes the train/val file assignment reproducible


def split_indices_by_file(eeg_dataset, val_ratio, seed=SPLIT_SEED):
    """Return (train_indices, val_indices), keeping each recording on one side only.

    Windows cut from the same EDF file come from the same recording. Splitting
    windows at random (``random_split`` over windows) puts parts of every
    recording into both sets and inflates validation accuracy. Instead, whole
    files are assigned to one side. Files are shuffled per label, and any
    class with at least two files keeps at least one file on each side.
    """
    rng = random.Random(seed)
    val_files = set()
    for label in sorted(set(eeg_dataset.labels)):
        file_ids = [i for i, lab in enumerate(eeg_dataset.labels) if lab == label]
        rng.shuffle(file_ids)
        n_val = round(len(file_ids) * val_ratio)
        if val_ratio > 0 and len(file_ids) > 1:
            n_val = min(max(1, n_val), len(file_ids) - 1)
        val_files.update(file_ids[:n_val])

    train_idx, val_idx = [], []
    for idx in range(len(eeg_dataset)):
        file_idx = idx // eeg_dataset.segments_per_file
        (val_idx if file_idx in val_files else train_idx).append(idx)
    return train_idx, val_idx


def lr_for_epoch(epoch):
    """LR for 0-based ``epoch``: linear warmup over WARMUP_EPOCHS, then cosine decay."""
    if epoch < WARMUP_EPOCHS:
        lr_scale = (epoch + 1) / WARMUP_EPOCHS
    else:
        progress = (epoch - WARMUP_EPOCHS) / max(1, EPOCHS - WARMUP_EPOCHS)
        lr_scale = 0.5 * (1 + math.cos(math.pi * progress))
    return LR * lr_scale


# Data
dataset = EEGDataset('/content')

# Split by recording, not by window (see split_indices_by_file).
train_indices, val_indices = split_indices_by_file(dataset, VAL_RATIO)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
print(f"Train windows: {len(train_dataset)} | Val windows: {len(val_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, moved to the selected device.
# NOTE(review): EEGDataset produces pixel values in [0, 255], while the pretrained
# weights were trained on ImageNet-normalised inputs. Consider rescaling.
model = timm.create_model('levit_128.fb_dist_in1k', pretrained=True, num_classes=2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Training loop
for epoch in range(EPOCHS):
    # Set this epoch's LR before training, so the warmup value is in effect
    # during the epoch it belongs to and the printed LR is the one actually used.
    current_lr = lr_for_epoch(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    model.train()
    total_train_loss = 0.0
    correct_train = 0
    total_train = 0

    train_bar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{EPOCHS}")
    for images, labels in train_bar:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)

        train_acc = 100 * correct_train / total_train
        train_bar.set_postfix(loss=loss.item(), acc=f"{train_acc:.2f}%", lr=current_lr)

    # Validation
    model.eval()
    correct_val = 0
    total_val = 0
    val_bar = tqdm(val_loader, desc="Validating", leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()
            val_bar.set_postfix(acc=f"{100 * correct_val / total_val:.2f}%")

    # Epoch summary (NaN rather than ZeroDivisionError if a split is empty).
    epoch_train_loss = total_train_loss / max(1, len(train_loader))
    epoch_train_acc = 100 * correct_train / total_train if total_train else float('nan')
    epoch_val_acc = 100 * correct_val / total_val if total_val else float('nan')

    print(f"[Epoch {epoch+1}/{EPOCHS}] "
          f"Train Loss: {epoch_train_loss:.4f} | "
          f"Train Acc: {epoch_train_acc:.2f}% | "
          f"Val Acc: {epoch_val_acc:.2f}% | "
          f"LR: {current_lr:.6f}")


Using device: cpu


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/37.1M [00:00<?, ?B/s]

Train Epoch 1/20: 100%|██████████| 803/803 [1:22:04<00:00,  6.13s/it, acc=80.04%, loss=0.345, lr=0.0016]


[Epoch 1/20] Train Loss: 0.4150 | Train Acc: 80.04% | Val Acc: 73.45% | LR: 0.000320


Train Epoch 2/20: 100%|██████████| 803/803 [1:23:06<00:00,  6.21s/it, acc=90.14%, loss=0.282, lr=0.00032]


[Epoch 2/20] Train Loss: 0.2249 | Train Acc: 90.14% | Val Acc: 86.94% | LR: 0.000640


Train Epoch 3/20: 100%|██████████| 803/803 [1:26:51<00:00,  6.49s/it, acc=91.19%, loss=0.0728, lr=0.00064]


[Epoch 3/20] Train Loss: 0.2029 | Train Acc: 91.19% | Val Acc: 85.22% | LR: 0.000960


Train Epoch 4/20: 100%|██████████| 803/803 [1:21:01<00:00,  6.05s/it, acc=92.41%, loss=0.158, lr=0.00096]


[Epoch 4/20] Train Loss: 0.1802 | Train Acc: 92.41% | Val Acc: 94.83% | LR: 0.001280


Train Epoch 5/20: 100%|██████████| 803/803 [1:18:15<00:00,  5.85s/it, acc=93.70%, loss=0.172, lr=0.00128]


[Epoch 5/20] Train Loss: 0.1566 | Train Acc: 93.70% | Val Acc: 87.61% | LR: 0.001600


Train Epoch 6/20: 100%|██████████| 803/803 [1:19:44<00:00,  5.96s/it, acc=93.10%, loss=0.122, lr=0.0016]


[Epoch 6/20] Train Loss: 0.1592 | Train Acc: 93.10% | Val Acc: 88.73% | LR: 0.001600


Train Epoch 7/20:  51%|█████     | 410/803 [44:03<38:49,  5.93s/it, acc=93.94%, loss=0.353, lr=0.0016]