In [1]:
import os
cwd = os.getcwd()
import sys
path = os.path.join(cwd, "..\\..\\")
sys.path.append(path)

In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

import logging
logging.getLogger('lightning').setLevel(0)

import warnings
warnings.filterwarnings('ignore')

import pytorch_lightning
pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR)

from splearn.data import MultipleSubjects, PyTorchDataset, PyTorchDataset2Views, HSSSVEP
from splearn.filter.butterworth import butter_bandpass_filter
from splearn.filter.notch import notch_filter
from splearn.filter.channels import pick_channels
from splearn.nn.models import CompactEEGNet
from splearn.utils import Logger, Config
from splearn.nn.base import LightningModelClassifier

In [3]:
config = {
    "run_name": "eeg_hsssvep_run2",
    "data": {
        "load_subject_ids": np.arange(1,36),
        # "selected_channels": ["PO8", "PZ", "PO7", "PO4", "POz", "PO3", "O2", "Oz", "O1"], # AA paper
        "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], # hsssvep paper
    },
    "training": {
        "num_epochs": 500,
        "num_warmup_epochs": 50,
        "learning_rate": 0.03,
        "gpus": [0],
        "batchsize": 256,
    },
    "model": {
        "optimizer": "adamw",
        "scheduler": "cosine_with_warmup",
    },
    "testing": {
        "test_subject_ids": np.arange(33,34),
        "kfolds": np.arange(0,3),
    },
    "seed": 1234
}

main_logger = Logger(filename_postfix=config["run_name"])
main_logger.write_to_log("Config")
main_logger.write_to_log(config)

config = Config(config)

seed_everything(config.seed)

Global seed set to 1234


1234

In [4]:
def func_preprocessing(data):
    data_x = data.data
    data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels)
    # data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0)
    data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6)
    start_t = 160
    end_t = start_t + 250
    data_x = data_x[:,:,:,start_t:end_t]
    data.set_data(data_x)

data = MultipleSubjects(
    dataset=HSSSVEP, 
    root=os.path.join(path, "../data/hsssvep"), 
    subject_ids=config.data.load_subject_ids, 
    func_preprocessing=func_preprocessing,
    verbose=True, 
)

Load subject: 1
Load subject: 2
Load subject: 3
Load subject: 4
Load subject: 5
Load subject: 6
Load subject: 7
Load subject: 8
Load subject: 9
Load subject: 10
Load subject: 11
Load subject: 12
Load subject: 13
Load subject: 14
Load subject: 15
Load subject: 16
Load subject: 17
Load subject: 18
Load subject: 19
Load subject: 20
Load subject: 21
Load subject: 22
Load subject: 23
Load subject: 24
Load subject: 25
Load subject: 26
Load subject: 27
Load subject: 28
Load subject: 29
Load subject: 30
Load subject: 31
Load subject: 32
Load subject: 33
Load subject: 34
Load subject: 35


In [5]:
print("Final data shape:", data.data.shape)

num_channel = data.data.shape[2]
num_classes = 40
signal_length = data.data.shape[3]

Final data shape: (35, 240, 9, 250)


In [6]:
test_subject_id=1
kfold_k=1

train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k)

In [10]:
train_dataset.data.shape

(320, 9, 250)

In [79]:
def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0):
    
    ## init data
    
    # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k)
    train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k)
    train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False)

    ## init model

    eegnet = CompactEEGNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length)

    model = LightningModelClassifier(
        optimizer=config.model.optimizer,
        scheduler=config.model.scheduler,
        optimizer_learning_rate=config.training.learning_rate,
        scheduler_warmup_epochs=config.training.num_warmup_epochs,
    )
    
    model.build_model(model=eegnet)

    ## train

    sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k)
    logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.run_name, sub_dir=sub_dir)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor])
    trainer.fit(model, train_loader, val_loader)
    
    ## test
    
    result = trainer.test(dataloaders=test_loader, verbose=False)
    test_acc = result[0]['test_acc_epoch']
    
    return test_acc

In [80]:
main_logger.write_to_log("Begin", break_line=True)

test_results_acc = {}
means = []

def k_fold_train_test_all_subjects():
    
    for test_subject_id in config.testing.test_subject_ids:
        print()
        print("running test_subject_id:", test_subject_id)
        
        if test_subject_id not in test_results_acc:
            test_results_acc[test_subject_id] = []
        
        test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=0)
            
        test_results_acc[test_subject_id].append(test_acc)
        means.append(test_acc)
        
        this_result = {
            "test_subject_id": test_subject_id,
            "acc": test_results_acc[test_subject_id],
        }        
        print(this_result)
        main_logger.write_to_log(this_result)

        
k_fold_train_test_all_subjects()

mean_acc = np.mean(means)
print()
print("mean all", mean_acc)
main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 1234



running test_subject_id: 1


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


{'test_subject_id': 1, 'acc': [0.5916666388511658]}

mean all 0.5916666388511658


In [6]:
from splearn.nn.modules.conv2d import Conv2d


In [7]:
# x = torch.randn(320, 9, 250)

# print(x.shape)


# model = CompactEEGNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length)
# y = model(x)
# print(y.shape)

# model = Model()
# y = model(x)
# print(y.shape)

In [8]:
# class SlowFast(nn.Module):
#     def __init__(self, block=None, layers=[3, 4, 6, 3], class_num=10, dropout=0.5):
#         super(SlowFast, self).__init__()

#         in_channels = 9
#         filters = [32, 64, 128]
#         kernel_size = (1, 5)

#         self.fast_conv1 = Conv2d(
#             in_channels, filters[0], kernel_size=kernel_size, bias=False)
#         self.fast_bn1 = nn.BatchNorm2d(filters[0])
#         self.fast_conv2 = Conv2d(
#             filters[0], filters[1], kernel_size=kernel_size, bias=False)
#         self.fast_bn2 = nn.BatchNorm2d(filters[1])
#         self.fast_conv3 = Conv2d(
#             filters[1], filters[2], kernel_size=kernel_size, bias=False)
#         self.fast_bn3 = nn.BatchNorm2d(filters[2])

#         self.fast_relu = nn.ReLU(inplace=True)
#         self.fast_maxpool = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 1))

#         self.lateral_p1 = Conv2d(
#             filters[0], filters[0], kernel_size=(1, 1), stride=(1, 2), bias=False)
#         self.lateral_p2 = Conv2d(
#             filters[1], filters[1], kernel_size=(1, 1), stride=(1, 2), bias=False)
#         self.lateral_p3 = Conv2d(
#             filters[2], filters[2], kernel_size=(1, 1), stride=(1, 2), bias=False)
        
#         self.identity1 = Conv2d(
#             in_channels, filters[0], kernel_size=(1, 1), stride=(1, 1), bias=False)
#         self.identity2 = Conv2d(
#             filters[0], filters[1], kernel_size=(1, 1), stride=(1, 1), bias=False)
#         self.identity3 = Conv2d(
#             filters[1], filters[2], kernel_size=(1, 1), stride=(1, 1), bias=False)

#         self.slow_conv1 = Conv2d(
#             in_channels, filters[0], kernel_size=kernel_size, stride=(1, 2), padding=(0, 3), bias=False)
#         self.slow_bn1 = nn.BatchNorm2d(filters[0])
#         self.slow_conv2 = Conv2d(
#             filters[1], filters[1], kernel_size=kernel_size, stride=(1, 2), padding=(0, 3), bias=False)
#         self.slow_bn2 = nn.BatchNorm2d(filters[1])
#         self.slow_conv3 = Conv2d(
#             filters[2], filters[2], kernel_size=kernel_size, stride=(1, 2), padding=(0, 3), bias=False)
#         self.slow_bn3 = nn.BatchNorm2d(filters[2])

#         self.slow_relu = nn.ReLU(inplace=True)
#         self.slow_maxpool = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 1))

#     def forward(self, input):
#         input = torch.unsqueeze(input, 2)
#         fast, lateral = self.FastPath(input)
#         slow = self.SlowPath(input, lateral)
#         return fast, slow

#     def SlowPath(self, input, lateral):
#         x = self.slow_conv1(input)
#         x = self.slow_bn1(x)
#         x = self.slow_relu(x)
#         x = self.slow_maxpool(x)
#         # print("slow x", x.shape, lateral[0].shape)
#         x = torch.cat([x, lateral[0]], dim=1)
        
#         # print("slow x", x.shape)
#         x = self.slow_conv2(x)
#         x = self.slow_bn2(x)
#         x = self.slow_relu(x)
#         x = self.slow_maxpool(x)
#         # print("slow x", x.shape, lateral[1].shape)
#         x = torch.cat([x, lateral[1]], dim=1)

#         # print("slow x", x.shape)
#         x = self.slow_conv3(x)
#         x = self.slow_bn3(x)
#         x = self.slow_relu(x)
#         x = self.slow_maxpool(x)
#         # print("slow x", x.shape, lateral[2].shape)
#         x = torch.cat([x, lateral[2]], dim=1)

#         return x

#     def FastPath(self, input):
#         lateral = []
#         x1 = self.fast_conv1(input)
#         x1 = self.fast_bn1(x1)
#         x1 = self.fast_relu(x1)
#         x1 = self.identity1(input) + x1
#         # pool1 = self.fast_maxpool(x1)
#         # print("pool1", pool1.shape)
#         # print("x1", x1.shape)
#         lateral_p1 = self.lateral_p1(x1)
#         lateral.append(lateral_p1)
#         # print("lateral_p1", lateral_p1.shape)

#         x2 = self.fast_conv2(lateral_p1)
#         x2 = self.fast_bn2(x2)
#         x2 = self.fast_relu(x2)
#         x2 = self.identity2(lateral_p1) + x2
#         # print(lateral_p1.shape, x2.shape)
#         # x2 = lateral_p1 + x2
#         # pool2 = self.fast_maxpool(x2)
#         # print("pool2", pool2.shape)
#         # print("x2", x2.shape)
#         lateral_p2 = self.lateral_p2(x2)
#         # print("lateral_p2", lateral_p2.shape)
#         lateral.append(lateral_p2)

#         x3 = self.fast_conv3(lateral_p2)
#         x3 = self.fast_bn3(x3)
#         x3 = self.fast_relu(x3)
#         x3 = self.identity3(lateral_p2) + x3
#         # x3 = lateral_p2 + x3
#         # pool3 = self.fast_maxpool(x3)
#         # print("pool3", pool3.shape)
#         # print("x3", x3.shape)
#         lateral_p3 = self.lateral_p3(x3)
#         # print("lateral_p3", lateral_p3.shape)
#         lateral.append(lateral_p3)

#         return lateral_p3, lateral


# model = SlowFast()
# fast, slow = model(x)
# print("fast", fast.shape)
# print("slow", slow.shape)


In [23]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(Block, self).__init__()
        
        self.conv = Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, input):
        x = self.conv(input)
        x = self.bn(x)
        x = self.relu(x)
        
        return x

class ResBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_sizes):
        super(ResBlock, self).__init__()
        
        self.conv1 = Block(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_sizes[0])
        self.conv2 = Block(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=kernel_sizes[1])
        self.conv3 = Block(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_sizes[2])
        self.conv_fusion = Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1), bias=False)
        
    def forward(self, input):
        x = self.conv1(input)
        # print("ResBlock 1", x.shape)
        x = self.conv2(x)
        # print("ResBlock 2", x.shape)
        x = self.conv3(x)
        # print("ResBlock 3", x.shape)
        
        shortcut = self.conv_fusion(input)
        # print("shortcut", shortcut.shape)
        
        x = x + shortcut
        return x

class Fusion(nn.Module):
    def __init__(self, fusion_dim_in, conv_kernel_size, conv_stride, slowfast_channel_reduction_ratio=8, conv_fusion_channel_ratio=8):
        super(Fusion, self).__init__()
        
        conv_dim_in = fusion_dim_in // slowfast_channel_reduction_ratio
        norm_eps = 1e-5
        norm_momentum = 0.1
        
        self.conv_fast_to_slow = nn.Conv2d(
            conv_dim_in,
            int(conv_dim_in * conv_fusion_channel_ratio),
            kernel_size=conv_kernel_size,
            stride=conv_stride,
            padding=[k_size // 2 for k_size in conv_kernel_size],
            bias=False,
        )
        
        self.bn = nn.BatchNorm2d(
            num_features=conv_dim_in * conv_fusion_channel_ratio,
            eps=norm_eps,
            momentum=norm_momentum,
        )
        self.activation = nn.ReLU()
        
    def forward(self, x):
        x_s = x[0]
        x_f = x[1]
        # print(888, x_f.shape) # 888 torch.Size([320, 64, 32, 56])
        fuse = self.conv_fast_to_slow(x_f)
        fuse = self.bn(fuse)
        fuse = self.activation(fuse)
        x_s_fuse = torch.cat([x_s, fuse], 1)
        return x_s_fuse


class SlowFast(nn.Module):
    def __init__(self):
        super(SlowFast, self).__init__()
        in_channels = 1
        
        self.fast_conv1 = Block(in_channels=in_channels, out_channels=8, kernel_size=(5,7), stride=(2,2))
        self.fast_maxpool = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2), padding=(0,1))
        
        self.fast_conv2 = ResBlock(in_channels=8, hidden_channels=8, out_channels=32, kernel_sizes=[(3,1),(1,3),(1,1)])
        self.fast_conv3 = ResBlock(in_channels=32, hidden_channels=16, out_channels=64, kernel_sizes=[(3,1),(1,3),(1,1)])
        
        self.slow_conv1 = Block(in_channels=in_channels, out_channels=32, kernel_size=(1,7), stride=(16,2))
        self.slow_maxpool = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2), padding=(0,1))
        
        self.slow_conv2 = ResBlock(in_channels=64, hidden_channels=64, out_channels=256, kernel_sizes=[(1,1),(1,3),(1,1)])
        self.slow_conv3 = ResBlock(in_channels=256, hidden_channels=128, out_channels=512, kernel_sizes=[(1,1),(1,3),(1,1)])
        
        self.fusion1 = Fusion(fusion_dim_in=32, conv_kernel_size=(1,3), conv_stride=(8,1))
        self.fusion2 = Fusion(fusion_dim_in=128, conv_kernel_size=(1,3), conv_stride=(8,1))
        self.fusion3 = Fusion(fusion_dim_in=256, conv_kernel_size=(1,3), conv_stride=(8,1))
        
        
    def forward(self, input):
        # input = torch.unsqueeze(input, 3)
        # print("input.shape", input.shape)
        fast, lateral = self.FastPath(input)        
        slow = self.SlowPath(input, lateral)
        return fast, slow

    
    def FastPath(self, input):
        lateral = []
        x1 = self.fast_conv1(input)
        # x1 = self.fast_maxpool(x1)
        # print("fast x1", x1.shape)
        lateral.append(x1)
        
        x2 = self.fast_conv2(x1)
        # print("fast x2", x2.shape)
        lateral.append(x2)
        
        x3 = self.fast_conv3(x2)
        # print("fast x3", x3.shape)
        lateral.append(x3)
        
        return x3, lateral
    
    def SlowPath(self, input, lateral):
        x1 = self.slow_conv1(input)
        # x1 = self.slow_maxpool(x1)
        print("slow 1", x1.shape)
        
        print("slow fusion1",x1.shape, lateral[0].shape)
        x1 = self.fusion1([x1, lateral[0]])
        
        
        x2 = self.slow_conv2(x1)
        print("slow fusion2", x2.shape, lateral[1].shape)
        x2 = self.fusion2([x2,lateral[1]])
        
        
        x3 = self.slow_conv3(x2)
        print("slow fusion3", x3.shape, lateral[2].shape)
        x3 = self.fusion3([x3,lateral[2]])
        

        return x3



x = torch.randn(320, 1, 64, 224)

model = SlowFast()
fast, slow = model(x)
print()
print("fast", fast.shape)
print("slow", slow.shape)


slow 1 torch.Size([320, 32, 4, 112])
slow fusion1 torch.Size([320, 32, 4, 112]) torch.Size([320, 8, 32, 112])


RuntimeError: Given groups=1, weight of size [32, 4, 1, 3], expected input[320, 8, 32, 112] to have 4 channels, but got 8 channels instead

In [10]:
# class Detection(nn.Module):

#     # def __init__(self, pooler_mode: Pooler.Mode, hidden: nn.Module, num_hidden_out: int, num_classes: int, proposal_smooth_l1_loss_beta: float):
#     def __init__(self):
#         super().__init__()
#         num_hidden_out = 12288
#         num_classes = 40
#         self._proposal_class = nn.Linear(num_hidden_out, num_classes)

#     def forward(self, fast_feature, slow_feature):
#         batch_size = fast_feature.shape[0]
        
#         fast_feature = nn.AvgPool2d(kernel_size=(
#             fast_feature.shape[2], 1))(fast_feature).squeeze(2)
#         # print(fast_feature.shape)
#         slow_feature = nn.AvgPool2d(kernel_size=(
#             slow_feature.shape[2], 1))(slow_feature).squeeze(2)
#         # print(slow_feature.shape)
#         feature = torch.cat([fast_feature, slow_feature], dim=1)
#         # print(feature.shape)

#         out = feature.view(feature.shape[0],-1)#.cuda()
#         # out = torch.flatten(feature, start_dim=1)
#         # out = torch.reshape(feature,(feature.shape[0],-1))
#         # print(out.shape)
#         proposal_classes = self._proposal_class(out)
#         # print(proposal_classes.shape)
        
#         return proposal_classes


# detection = Detection()#.cuda()
# fast_feature = torch.randn(320, 128, 1, 32)
# slow_feature = torch.randn(320, 256, 1, 32)
# y = detection(fast_feature, slow_feature)
# y.shape

class PoolConcatPathway(nn.Module):
    def __init__(
        self,
        pool,
        dim: int = 1,
    ) -> None:
        super().__init__()
        self.pool = pool
        
    def forward(self, x) -> torch.Tensor:
        output = []
        for ind in range(len(x)):
            if x[ind] is not None:
                if self.pool is not None and self.pool[ind] is not None:
                    x[ind] = self.pool[ind](x[ind])
                # print(99, x[ind].shape)
                output.append(x[ind])
        return torch.cat(output, 1)


_num_pathway=2
head_pool_kernel_sizes = ((111, 1), (2, 1))
pool_model = [
    nn.AvgPool2d(
        kernel_size=head_pool_kernel_sizes[idx],
        stride=(1, 1),
        padding=(0, 0),
    )
    for idx in range(_num_pathway)
]
poolconcat = PoolConcatPathway(pool_model)
fast_feature = torch.randn(320, 64, 125, 1)
slow_feature = torch.randn(320, 1024, 16, 1)
# fast_feature = torch.randn(320, 64, 32, 56)
# slow_feature = torch.randn(320, 1024, 4, 56)

y = poolconcat([fast_feature, slow_feature])
print(y.shape)

# torch.Size([320, 256, 32, 7])

    
class ResNetBasicHead(nn.Module):
    def __init__(self):
        super().__init__()
        dropout_rate=0.5
        in_features=1088
        out_features=40
        
        self.dropout = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(in_features, out_features)
        self.outputpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.dropout(x)
    
        x = x.permute((0, 2, 3, 1))
        x = self.proj(x)
        x = x.permute((0, 3, 1, 2))
        
        x = self.outputpool(x)
        x = x.squeeze()
        return x
    
pooled = torch.randn(320, 1088, 15, 1)
head = ResNetBasicHead()
out = head(pooled)
print(out.shape)

# detection = Detection()
# fast_feature = torch.randn(320, 64, 125, 1)
# slow_feature = torch.randn(320, 1024, 16, 1)
# y = detection(fast_feature, slow_feature)
# y.shape

torch.Size([320, 1088, 15, 1])
torch.Size([320, 40])


In [11]:
# class Model(nn.Module):

#     def __init__(self):
#         super().__init__()
#         self.backbone = SlowFast()
#         self.detection = Detection()

#     def forward(self, input):
#         fast_feature, slow_feature = self.backbone(input)
#         # print(99, fast_feature.shape, slow_feature.shape)
#         y = self.detection(fast_feature, slow_feature)
#         return y

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = SlowFast()
        # self.detection = Detection()
        
        _num_pathway=2
        head_pool_kernel_sizes = ((111, 1), (2, 1))
        pool_model = [
            nn.AvgPool2d(
                kernel_size=head_pool_kernel_sizes[idx],
                stride=(1, 1),
                padding=(0, 0),
            )
            for idx in range(_num_pathway)
        ]
        self.poolconcat = PoolConcatPathway(pool_model)
        self.head = ResNetBasicHead()


    def forward(self, input):
        fast_feature, slow_feature = self.backbone(input)
        # print(99, fast_feature.shape, slow_feature.shape)
        # y = self.detection(fast_feature, slow_feature)
        
        y = self.poolconcat([fast_feature, slow_feature])
        out = self.head(y)
        return out
    
model = Model()
y = model(x)
y.shape

torch.Size([320, 40])

In [12]:
def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0):
    
    ## init data
    
    # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k)
    train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k)
    train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False)

    ## init model

    eegnet = Model()

    model = LightningModelClassifier(
        optimizer=config.model.optimizer,
        scheduler=config.model.scheduler,
        optimizer_learning_rate=config.training.learning_rate,
        scheduler_warmup_epochs=config.training.num_warmup_epochs,
    )
    
    model.build_model(model=eegnet)

    ## train

    sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k)
    logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.run_name, sub_dir=sub_dir)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor])
    trainer.fit(model, train_loader, val_loader)
    
    ## test
    
    result = trainer.test(dataloaders=test_loader, verbose=False)
    test_acc = result[0]['test_acc_epoch']
    
    return test_acc

In [13]:
main_logger.write_to_log("Begin", break_line=True)

test_results_acc = {}
means = []

def k_fold_train_test_all_subjects():
    
    for test_subject_id in config.testing.test_subject_ids:
        print()
        print("running test_subject_id:", test_subject_id)
        
        if test_subject_id not in test_results_acc:
            test_results_acc[test_subject_id] = []
        
        test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=0)
            
        test_results_acc[test_subject_id].append(test_acc)
        means.append(test_acc)
        
        this_result = {
            "test_subject_id": test_subject_id,
            "acc": test_results_acc[test_subject_id],
        }        
        print(this_result)
        main_logger.write_to_log(this_result)

        
k_fold_train_test_all_subjects()

mean_acc = np.mean(means)
print()
print("mean all", mean_acc)
main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True)


running test_subject_id: 33


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 1234
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


{'test_subject_id': 33, 'acc': [0.02916666679084301]}

mean all 0.02916666679084301
