forked from HypoX64/candock
-
Notifications
You must be signed in to change notification settings - Fork 7
/
simple_test.py
54 lines (45 loc) · 1.29 KB
/
simple_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import util
import transformer
import dataloader
from options import Options
from creatnet import CreatNet
# @hypox64
# 19/04/03

# Parse command-line options and build the network selected by opt.model_name.
opt = Options().getparse()
net = CreatNet(opt.model_name)

if not opt.no_cuda:
    net.cuda()
if not opt.no_cudnn:
    import torch.backends.cudnn as cudnn
    # Let cuDNN benchmark convolution algorithms and keep the fastest one.
    cudnn.benchmark = True
if opt.pretrained:
    # map_location='cpu' lets a checkpoint that was saved from a GPU run be
    # loaded on a CPU-only machine (no_cuda set); load_state_dict then copies
    # the weights onto whatever device the network already lives on.
    state_dict = torch.load(
        './checkpoints/pretrained/' + opt.model_name + '.pth',
        map_location='cpu',
    )
    net.load_state_dict(state_dict)

# This script only runs inference: switch the network out of its default
# training mode so layers such as dropout / batch norm (if the selected model
# has any) use their evaluation behavior.
net.eval()
# N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4
stage_map={0:'stage3',1:'stage2',2:'stage3',3:'REM',4:'Wake'}
def runmodel(eegdata):
    """Run the module-level network on one EEG recording and return its class.

    The signal is flattened into a batch of one sample, converted to the
    layout required by ``opt.model_name`` via ``transformer.ToInputShape``,
    wrapped as a tensor by ``transformer.ToTensor`` (which receives
    ``opt.no_cuda``, presumably to decide the target device), and passed
    through ``net`` with gradient tracking disabled.

    Args:
        eegdata: numpy array of EEG samples. It is reshaped to (1, -1), so any
            shape is accepted; the script expects one 30 s epoch at 100 Hz.

    Returns:
        The index of the highest-scoring class for that single sample
        (a numpy integer), usable as a key into ``stage_map``.
    """
    batch = eegdata.reshape(1, -1)
    batch = transformer.ToInputShape(batch, opt.model_name, test_flag=True)
    inputs = transformer.ToTensor(batch, no_cuda=opt.no_cuda)

    with torch.no_grad():
        scores = net(inputs)
        # torch.max along dim 1 yields (values, indices); keep the indices.
        _, predicted = torch.max(scores, 1)
        labels = predicted.data.cpu().numpy()

    return labels[0]
# You can change the input data here, but it must meet the following
# conditions:
#   1. it covers exactly one epoch (30 s);
#   2. it is sampled at fs = 100 Hz (so 30 s * 100 Hz = 3000 samples).
eegdata = np.load('./datasets/simple_test_data.npy')
print('the shape of eegdata:', eegdata.shape)
stage = runmodel(eegdata)
print('the sleep stage of this signal is:', stage_map[stage])
# runmodel flattens its input, so plot the flattened signal as well: passing
# a 2-D array such as (1, 3000) straight to plt.plot would draw one line per
# column, i.e. 3000 invisible one-point lines.
plt.plot(eegdata.reshape(-1))
plt.show()