In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import time
import joblib
import model as NN

In [2]:
# ---- Experiment settings ----
seed = 18          # global random seed; also embedded in data / checkpoint / log file names
total_fold = 10    # number of cross-validation folds (10-fold CV)

# ---- Deep-learning hyperparameters ----
input_size = 16
hidden_size = 128
num_layers_lstm, num_layers_bilstm = 1, 2
num_classes = 2
batch_size = 64
num_epochs = 30
learning_rate = 1e-3   # an earlier run used 3e-4

# Wall-clock reference point for the whole run.
start = time.perf_counter()
# NOTE(review): `name` is an alias of the notebook's global namespace dict and is not used in this
# cell — confirm whether a later cell relies on it; otherwise it can be removed.
name = locals()
NN.seed_everything(seed)

srate = 250  # presumably the EEG sampling rate in Hz (matches the "250hz" folder names) — confirm
writer = SummaryWriter(f'./runs/250hz_{seed}')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def ensure_dir(directory):
    """Create ``directory`` (including any missing parents) unless it already exists.

    ``exist_ok=True`` folds the existence check into the creation call. This avoids the race
    between a separate ``os.path.exists`` check and ``os.makedirs``, for example when two runs
    write to the same output folder at once.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory. Failing here is better
            than failing later when something tries to write into it.
    """
    os.makedirs(directory, exist_ok=True)

def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader that uses the notebook's shared batching settings.

    Args:
        dataset: Dataset object loaded from a fold ``.pth`` file.
        shuffle: Whether to reshuffle samples every epoch (wanted for training only).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      # NOTE(review): drop_last=True discards the final partial batch, so validation
                      # metrics cover slightly fewer samples than the fold contains. It is kept because
                      # NN.STCGRU / NN.model_training may assume full batches. Confirm before switching
                      # the validation loader to drop_last=False.
                      drop_last=True,
                      # Pinned host memory only helps host->GPU copies; it is pointless on CPU.
                      pin_memory=device.type == 'cuda',
                      num_workers=8)


# Input and checkpoint folders are keyed by the sampling rate from the config cell
# (resolves to the same "250hz" paths as before).
data_dir = f"EEGData/{srate}hz"
save_dir = f"stcgru/{srate}hz"
ensure_dir(save_dir)

# Time only the cross-validation loop. The config cell's `start` would also count any idle time
# between running that cell and this one.
train_start = time.perf_counter()
for i in range(total_fold):
    fold = i + 1  # 1-based fold number used in file names and messages
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects (presumably
    # needed because the files hold whole dataset objects). Only load fold files you created yourself.
    train_data_combine = torch.load(f"{data_dir}/TrainData/train_data_{fold}_fold_with_seed_{seed}.pth",
                                    weights_only=False)
    valid_data_combine = torch.load(f"{data_dir}/ValidData/valid_data_{fold}_fold_with_seed_{seed}.pth",
                                    weights_only=False)

    # Fresh model, loss, optimizer and scheduler for every fold, so no state carries over between folds.
    model = NN.STCGRU().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.05)
    # NOTE(review): how often scheduler.step() runs is decided inside NN.model_training. If it runs
    # once per epoch, step_size=30 == num_epochs means the LR decay never takes effect during training.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.8)
    print('开始第%d次训练，共%d次' % (i + 1, total_fold))

    train_loader = _make_loader(train_data_combine, shuffle=True)
    # Sample order does not change validation metrics. Not shuffling keeps the evaluated
    # (drop_last-trimmed) subset identical across epochs, so epochs are comparable.
    valid_loader = _make_loader(valid_data_combine, shuffle=False)

    for epoch in range(num_epochs):
        # Training pass for this epoch; its return values replace `model` and `optimizer`.
        model, optimizer = NN.model_training(writer, i, type='train', num_epochs=num_epochs,
                                             epoch=epoch, loader=train_loader, neural_network=model,
                                             criterion=criterion, optimizer=optimizer)
        # Validation pass; this call is the one that receives the LR scheduler.
        optimizer, lr_list = NN.model_training(writer, i, type='validation', epoch=epoch,
                                               loader=valid_loader, neural_network=model, criterion=criterion,
                                               optimizer=optimizer, scheduler=scheduler)

    # Checkpoint the final-epoch weights of this fold.
    torch.save(model.state_dict(), f"{save_dir}/{fold}_fold_model_parameter_with_seed_{seed}.pth")
    print("stcgru" + "模型第" + str(i + 1) + "次训练结果保存成功")

# Make sure every logged TensorBoard event is written to disk.
writer.flush()
end = time.perf_counter()
print("训练及验证运行时间为", round(end - train_start), 'seconds')

开始第1次训练，共10次
Epoch: [  1/30] Train loss: 0.6054      Train accuracy: 0.6849
                 Validation loss: 0.4407 Validation accuracy: 0.9021
Epoch: [  2/30] Train loss: 0.3885      Train accuracy: 0.9465
                 Validation loss: 0.3705 Validation accuracy: 0.9781
Epoch: [  3/30] Train loss: 0.3523      Train accuracy: 0.9851
                 Validation loss: 0.3405 Validation accuracy: 0.9969
Epoch: [  4/30] Train loss: 0.3413      Train accuracy: 0.9969
                 Validation loss: 0.3411 Validation accuracy: 0.9990
Epoch: [  5/30] Train loss: 0.3375      Train accuracy: 0.9985
                 Validation loss: 0.3438 Validation accuracy: 0.9979
Epoch: [  6/30] Train loss: 0.3387      Train accuracy: 0.9980
                 Validation loss: 0.3356 Validation accuracy: 1.0000
Epoch: [  7/30] Train loss: 0.3383      Train accuracy: 0.9993
                 Validation loss: 0.3383 Validation accuracy: 1.0000
Epoch: [  8/30] Train loss: 0.3395      Train accuracy: 0.9995
