In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Subset
from torcheeg import transforms
from torcheeg.datasets import SEEDDataset
from torcheeg.models import DGCNN

In [None]:
# Load the SEED dataset through torcheeg, extracting band differential-entropy features offline.
# TODO: make the dataset accept the data we collected ourselves.
#
# The cache directory (where the preprocessed data was previously saved) used to be a hard-coded
# absolute path; it can now be overridden with the SEED_IO_PATH environment variable. The default
# is the original path, so behaviour is unchanged when the variable is not set.
SEED_IO_PATH = os.environ.get('SEED_IO_PATH', '/Users/hanlin/Desktop/vr_locomotion/seed')

# Frequency bands in Hz used for the differential-entropy features (one feature per band).
BAND_DICT = {
    "delta": [1, 4],
    "theta": [4, 8],
    "alpha": [8, 14],
    "beta": [14, 31],
    "gamma": [31, 49],
}

dataset = SEEDDataset(io_path=SEED_IO_PATH,
                      offline_transform=transforms.BandDifferentialEntropy(band_dict=BAND_DICT),
                      online_transform=transforms.Compose([
                          transforms.ToTensor()
                      ]),
                      label_transform=transforms.Compose([
                          transforms.Select('emotion'),
                          # Shift labels by +1. NOTE(review): presumably the SEED 'emotion' labels
                          # are {-1, 0, 1}, so this yields {0, 1, 2} for a 3-class CrossEntropyLoss
                          # — confirm against the dataset metadata.
                          transforms.Lambda(lambda x: x + 1)
                      ]))

In [None]:
# Split the dataset into training, validation and test sets.
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

seed = 42
# Seeds the global RNG used for model initialisation and DataLoader shuffling.
torch.manual_seed(seed)
train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
# The test split takes the remainder so the three sizes always sum to len(dataset).
test_size = len(dataset) - train_size - val_size

# A dedicated generator keeps the split reproducible regardless of other RNG consumers.
# NOTE(review): random_split shuffles individual samples. If the samples are windows cut from
# longer trials (presumably the case for torcheeg's SEEDDataset — confirm), windows of the same
# trial/subject can land in both train and test; consider a subject- or trial-wise split.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(seed),
)

# TODO: tune hyperparameters
batch_size = 32
num_epochs = 100
l1_reg = 0.0  # coefficient of the explicit L1 penalty; 0 disables it
l2_reg = 0.0  # coefficient of the explicit (squared) L2 penalty; 0 disables it
learning_rate = 0.0001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the data loaders.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# TODO: tune model parameters
# in_channels=5 corresponds to the five frequency bands used for the DE features.
model = DGCNN(in_channels=5, num_electrodes=62, hid_channels=32, num_layers=2, num_classes=3).to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _l1_penalty(module: nn.Module) -> torch.Tensor:
    """Return the sum of absolute values of all parameters of ``module``."""
    return sum(param.abs().sum() for param in module.parameters())


def _l2_penalty(module: nn.Module) -> torch.Tensor:
    """Return the sum of squared values of all parameters of ``module``."""
    return sum(param.pow(2).sum() for param in module.parameters())


def _train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                     optimizer: optim.Optimizer, device: torch.device,
                     l1_reg: float = 0.0, l2_reg: float = 0.0) -> float:
    """Run one optimisation pass over ``loader``.

    The optimised objective is ``criterion + l1_reg * L1 + l2_reg * L2`` (penalties skipped when
    their coefficient is 0). The returned value is the sample-weighted mean of the criterion loss
    only, so it is directly comparable with the validation loss from ``_evaluate``.
    """
    # Must be re-enabled every epoch because _evaluate() switches the model to eval mode.
    model.train()
    total_loss, num_samples = 0.0, 0
    for batch_data, batch_labels in loader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        outputs = model(batch_data)
        ce_loss = criterion(outputs, batch_labels)

        loss = ce_loss
        if l1_reg:
            loss = loss + l1_reg * _l1_penalty(model)
        if l2_reg:
            loss = loss + l2_reg * _l2_penalty(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion uses reduction='mean', so scale back by batch size for an exact dataset mean.
        total_loss += ce_loss.item() * batch_labels.size(0)
        num_samples += batch_labels.size(0)
    return total_loss / num_samples if num_samples else float("nan")


def _evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
              device: torch.device) -> tuple:
    """Return ``(mean loss, accuracy in percent)`` of ``model`` on ``loader`` without training.

    Both values are ``nan`` when the loader yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch_data, batch_labels in loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            outputs = model(batch_data)
            total_loss += criterion(outputs, batch_labels).item() * batch_labels.size(0)
            correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
            total += batch_labels.size(0)
    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, 100 * correct / total


# Train the model, recording per-epoch training and validation loss.
train_losses = []
val_losses = []
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device,
                                  l1_reg=l1_reg, l2_reg=l2_reg)
    val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

# Plot the training and validation loss curves.
epochs = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_losses, 'b', label='Training Loss')
ax.plot(epochs, val_losses, 'r', label='Validation Loss')
ax.set(title=f'Training and Validation Loss, l2_reg: {l2_reg}', xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

# Evaluate on the held-out test set.
# NOTE(review): this uses the final-epoch weights; the validation loss is only plotted, not used
# for model selection / early stopping.
_, accuracy = _evaluate(model, test_loader, criterion, device)
print(f"Test Accuracy: {accuracy:.2f}%")