In [ ]:
import pandas as pd
from torch.utils.data import DataLoader

from dataset import CFDataset
from model import lstm, resnet, transformer
import torch
import deep_learning as dl

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [ ]:
def _load_split(csv_path):
    """Read an index CSV and return (frame, data_file array, label array).

    The CSV's first column is used as the index; the 'data_file' and
    'label' columns are extracted as NumPy arrays.
    """
    frame = pd.read_csv(csv_path, index_col=0)
    return frame, frame['data_file'].to_numpy(), frame['label'].to_numpy()


# Per-sample file paths and labels for the P300 and R300 splits.
p300_df, p300_file_list, p300_label_list = _load_split('data/P300_1200g_norm.csv')
r300_df, r300_file_list, r300_label_list = _load_split('data/R300_1200g_norm.csv')

In [ ]:
# Wrap each split in a dataset, then in a DataLoader with batches of 256
# (DataLoader's default shuffle=False keeps evaluation order fixed).
p300_test_ds = CFDataset(p300_file_list, p300_label_list)
r300_test_ds = CFDataset(r300_file_list, r300_label_list)
p300_test_dataloader, r300_test_dataloader = (
    DataLoader(split_ds, batch_size=256) for split_ds in (p300_test_ds, r300_test_ds)
)

In [None]:
# Checkpoint name prefixes of the form '<architecture>_<dataset>'. Each one has
# N_FOLDS cross-validation checkpoints at checkpoints/<name>_cv<k>.pth (k = 1..N_FOLDS).
train_names = ['transformer_r300', 'transformer_p300', 'resnet_p300', 'resnet_r300', 'lstm_r300', 'lstm_p300']
N_FOLDS = 5
NUM_CLASSES = 10


def _build_model(name, input_size):
    """Instantiate the architecture selected by the prefix of *name*.

    Args:
        name: checkpoint name starting with 'lstm', 'resnet' or 'transformer'.
        input_size: feature size passed to the model constructor.

    Returns:
        The freshly constructed (untrained) model.

    Raises:
        ValueError: if *name* does not start with a known architecture
            (previously any unknown name silently became a transformer).
    """
    if name.startswith('lstm'):
        print('lstm model')
        return lstm.LSTM(input_size=input_size, hidden_size=100, num_layers=2, out_size=NUM_CLASSES)
    if name.startswith('resnet'):
        print('resnet model')
        return resnet.ResNet(hidden_sizes=[100] * 6, num_blocks=[2] * 6, input_dim=input_size, in_channels=64,
                             n_classes=NUM_CLASSES)
    if name.startswith('transformer'):
        print('transformer model')
        return transformer.TransformerModel(input_size=input_size, hidden_dim=128, num_classes=NUM_CLASSES,
                                            num_layers=3)
    raise ValueError('unknown model architecture in checkpoint name: {!r}'.format(name))


def _accuracy(true_list, predicted_list):
    """Return the fraction of positions where the prediction equals the true label.

    Returns NaN for an empty *true_list* instead of raising ZeroDivisionError.
    """
    if len(true_list) == 0:
        return float('nan')
    correct = sum(bool(t == p) for t, p in zip(true_list, predicted_list))
    return correct / len(true_list)


for name in train_names:
    # NOTE(review): models whose name contains 'r300' are evaluated on the P300
    # split and the others on the R300 split. This presumably is deliberate
    # cross-dataset evaluation; confirm. input_size is also taken from the
    # evaluation split, so it must match what the checkpoint was trained with.
    if 'r300' in name:
        test_dataloader = p300_test_dataloader
        input_size = p300_test_ds.input_size()
    else:
        test_dataloader = r300_test_dataloader
        input_size = r300_test_ds.input_size()

    model = _build_model(name, input_size)
    model.to(device)

    for fold in range(1, N_FOLDS + 1):
        true_list, predicted_list = dl.test(model, model_path='checkpoints/{}_cv{}.pth'.format(name, fold),
                                            test_dataloader=test_dataloader)
        acc = _accuracy(true_list, predicted_list)
        print('model: {}, cv: {}, acc:{}'.format(name, fold, acc))