In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import random_split
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary

In [12]:
class EEGNet(nn.Module):
    """Small 2-D CNN classifier for multichannel EEG windows.

    Each input window is treated as a one-channel image. Two stages of
    conv -> batch-norm -> ReLU -> max-pool -> dropout are applied; every
    kernel has height 1, so they only slide along the last input axis.
    A two-layer fully connected head follows and produces ``num_classes``
    logits (no softmax is applied).

    Args:
        num_classes: Number of output logits produced by ``fc2``. Defaults to 6.

    Input shape:
        ``(batch, H, W)``. ``fc1`` is hard-wired to 1024 input features, which
        holds for ``H=32, W=128``:
        (B,1,32,128) -> conv1 (B,16,32,63) -> pool1 (B,16,32,15)
        -> conv2 (B,32,32,7) -> pool2 (B,32,32,1) -> flatten (B,1024).
        Other input sizes will need a different ``fc1`` width.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        # Stage 1: stride-2 conv along the last axis, then 4x downsampling.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 4), stride=(1, 2))
        self.bn1 = nn.BatchNorm2d(16)  # eps/momentum/affine/track_running_stats left at PyTorch defaults
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout1 = nn.Dropout(p=0.25)
        # Stage 2: another stride-2 conv along the last axis, then 4x downsampling.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout2 = nn.Dropout(p=0.25)
        # Classifier head; 1024 = flattened size for (batch, 32, 128) inputs (see class docstring).
        self.fc1 = nn.Linear(1024, 128)
        self.dropout3 = nn.Dropout(p=0.5)
        # Previously hard-coded to 6, silently ignoring num_classes.
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``(batch, H, W)``."""
        x = x.unsqueeze(1)  # add a singleton channel dim -> (batch, 1, H, W)

        x = self.conv1(x)
        x = torch.relu(self.bn1(x))
        x = self.dropout1(self.pool1(x))

        x = self.conv2(x)
        x = torch.relu(self.bn2(x))
        x = self.dropout2(self.pool2(x))

        x = torch.flatten(x, 1)  # keep batch dim, flatten the rest
        x = self.dropout3(torch.relu(self.fc1(x)))
        return self.fc2(x)

In [13]:
# Rebuild the architecture and load the trained weights from the checkpoint.
# Raw string: the Windows backslashes must not be interpreted as escape sequences.
save_path = r'C:\Users\a1882\Desktop\EEG\normal\model\cnn_128_100e_198.pt'

loaded_model = EEGNet()
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a CPU-only one;
# weights_only=True avoids unpickling arbitrary objects (the file is a state_dict).
load_result = loaded_model.load_state_dict(
    torch.load(save_path, map_location='cpu', weights_only=True)
)
loaded_model.eval()  # inference mode: dropout off, batch-norm uses running statistics

# Alias so cells that save/export `model` operate on the trained weights,
# not on a freshly initialised (random) network.
model = loaded_model
load_result

<All keys matched successfully>

In [14]:
# import onnx
# import onnxruntime
# from torch import nn
# onnx_file_name = "C:\\Users\\a1882\Desktop\EEG\\normal\model\\cnn_conv3_128_100e_2_97.onnx"
# torch.load(save_path, map_location=torch.device('cpu'))

In [15]:
# Save the whole pickled module (architecture + trained weights), in eval mode.
# Loading this file later requires the EEGNet class to be importable.
# Raw string keeps the Windows path backslashes literal.
torch.save(loaded_model.eval(), r'C:\Users\a1882\Desktop\EEG\normal\model\cnn.pt')

In [16]:
# Export the trained network to ONNX by tracing it on one example window.
# Without dynamic_axes, the exported graph has the fixed input shape (1, 32, 128);
# the random values themselves do not matter for tracing.
loaded_model.eval()
dummy_input = torch.randn(1, 32, 128)
torch.onnx.export(
    loaded_model,  # trained model being exported
    dummy_input,   # example input used for tracing
    r'C:\Users\a1882\Desktop\EEG\normal\model\cnn.onnx',
)


verbose: False, log level: Level.ERROR

