-
Notifications
You must be signed in to change notification settings - Fork 0
/
Models.py
executable file
·162 lines (137 loc) · 5.62 KB
/
Models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
from torch import nn
import torch.nn.functional as F
class attention1d(nn.Module):
    """Channel attention over a 1d feature map.

    Global-average-pools the time axis, then runs a two-layer 1x1 conv
    bottleneck to produce K sigmoid-gated attention scores per sample.

    Args:
        in_planes: number of input channels.
        ratios: bottleneck ratio; hidden width is int(in_planes*ratios)+1
            (except in_planes==3, a special case that uses K directly).
        K: number of attention scores produced.
        temperature: divisor applied to the logits before the sigmoid;
            larger values flatten the attention distribution.
        init_weight: if True, apply Kaiming init to the conv layers.
    """

    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
        super(attention1d, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        if in_planes != 3:
            hidden_planes = int(in_planes * ratios) + 1
        else:
            hidden_planes = K
        self.fc1 = nn.Conv1d(in_planes, hidden_planes, 1, bias=True)
        # BUG FIX: was nn.BatchNorm2d, which cannot normalize the 3-D
        # (N, C, L) tensors this 1d stack produces. self.bn is not used in
        # forward(), so the swap is behavior-preserving for inference.
        self.bn = nn.BatchNorm1d(hidden_planes)
        self.fc2 = nn.Conv1d(hidden_planes, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()
        self.sigmoid = nn.Sigmoid()

    def _initialize_weights(self):
        """Kaiming-init conv weights; unit-scale/zero-shift the batch norm."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # accept either BN flavor so checkpoints from the old
            # BatchNorm2d version still initialize correctly
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def updata_temperature(self):
        """Anneal the temperature by 3 per call, floored at 1.

        (Name kept as-is — 'updata' — for backward compatibility with callers.)
        BUG FIX: the old `if temperature != 1: temperature -= 3` skipped 1 for
        any start value not congruent to 1 mod 3 (e.g. 35 -> 32 -> ... -> 2 -> -1),
        driving the temperature negative and flipping the sign of the logits.
        """
        if self.temperature != 1:
            self.temperature = max(1, self.temperature - 3)
            print('Change temperature to:', str(self.temperature))

    def forward(self, x):
        """x: (N, in_planes, L) -> (N, K) attention scores in (0, 1)."""
        x = self.avgpool(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x).view(x.size(0), -1)
        return torch.sigmoid(x / self.temperature)
class Wave_Lead_Conv(nn.Module):
    """
    Wave_Lead_Conv is a convolution model designed to convolve over a single scaleogram,
    of shape (1,freq,time) generated by one EEG lead. this model is the portion of
    onv2d_by_Leads model without the last linear layer that combines the output for each
    lead.
    """

    def __init__(self):
        super(Wave_Lead_Conv, self).__init__()
        # attribute names are load-bearing: they define state_dict keys
        self.conv1 = nn.Conv2d(1, 4, kernel_size=(3, 4), stride=(1, 2), padding=(1, 2))
        self.conv1_bn = nn.BatchNorm2d(4)
        self.maxPool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(4, 3), stride=(2, 2), padding=1)
        self.conv2_bn = nn.BatchNorm2d(8)
        self.maxPool2 = nn.MaxPool2d((2, 2))
        self.conv3 = nn.Conv2d(8, 16, kernel_size=(3, 4), stride=(1, 2), padding=1)
        self.conv3_bn = nn.BatchNorm2d(16)
        self.maxPool3 = nn.MaxPool2d((2, 2))
        self.conv4 = nn.Conv2d(16, 32, kernel_size=(2, 4), stride=(1, 1))
        self.conv4_bn = nn.BatchNorm2d(32)

    def forward(self, x):
        """(N, 1, 32, 250) scaleogram -> (N, 32) lead embedding.

        Each stage is conv -> relu -> batchnorm (note: relu before BN,
        matching the original ordering), with max-pooling after the first
        three stages.
        """
        for stage in (1, 2, 3):
            x = getattr(self, 'conv%d' % stage)(x)
            x = getattr(self, 'conv%d_bn' % stage)(torch.relu(x))
            x = getattr(self, 'maxPool%d' % stage)(x)
        x = self.conv4_bn(torch.relu(self.conv4(x)))
        return x.view(-1, 32)
class Wave_Fusion_Model(nn.Module):
    """
    Wave Fusion Model. Contains 61 Wave_Lead_Conv modules, one per EEG lead,
    whose 32-d embeddings are gated by an attention1d module and fused by
    three fully-connected layers into 3 class logits.

    Args:
        device: torch.device the per-lead embedding buffer is allocated on.
            BUG FIX: the old default was the `torch.device` CLASS itself
            (not an instance), which made `.to(self.device)` fail; the
            default is now an actual CPU device.
    """

    def __init__(self, device=torch.device('cpu')):
        super(Wave_Fusion_Model, self).__init__()
        self.leads = 61
        self.device = device
        self.temperature = 35
        # register one per-lead conv tower: Wave_Lead_Conv0 ... Wave_Lead_Conv60
        for i in range(self.leads):
            self.add_module('Wave_Lead_Conv' + str(i), Wave_Lead_Conv())
        # indexable view over the numbered submodules above
        self.Wave_Lead_Conv = AttrProxy(self, 'Wave_Lead_Conv')
        self.attention = attention1d(in_planes=self.leads, ratios=0.25,
                                     K=self.leads, temperature=self.temperature)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(self.leads * 32, 488)
        self.fc2 = nn.Linear(488, 128)
        self.fc3 = nn.Linear(128, 3)

    def forward(self, x):
        """Feed each EEG lead of x through its own Wave_Lead_Conv and fuse.

        Args:
            x: tensor of shape (batch, self.leads, 32, 250).
               (The original docstring said 256, but the code has always
               reshaped each lead to (bs, 1, 32, 250).)
        Returns:
            (batch, 3) tensor of class logits (no softmax applied).
        """
        bs = x.size(0)
        preds = torch.zeros((bs, self.leads, 32), device=self.device)
        for j in range(self.leads):
            # each lead goes to its own conv tower; reshape to (bs, 1, 32, 250)
            lead = x[:, j, :, :].view(bs, 1, 32, 250)
            preds[:, j, :] = self.Wave_Lead_Conv[j](lead).clone()
        # per-lead attention gate, broadcast over the 32-d embeddings
        attention = self.attention(preds).reshape(bs, self.leads, 1)
        preds = (attention * preds).view(-1, self.leads * 32)
        preds = self.dropout(preds)
        preds = torch.relu(self.fc1(preds))
        preds = torch.relu(self.fc2(preds))
        return self.fc3(preds)
################################### WAVEFUSION WLCNN CONTAINER OBJECT ###################################
class AttrProxy(object):
    """Indexable view over numbered attributes of a wrapped object.

    Lets Wave_Fusion_Model address its registered submodules
    (Wave_Lead_Conv0, Wave_Lead_Conv1, ...) as proxy[0], proxy[1], ...
    """

    def __init__(self, module, prefix):
        """
        args:
            module: the object whose prefixed attributes are exposed
                (the Wave_Fusion_Model instance).
            prefix: str, the common attribute-name prefix,
                e.g. 'Wave_Lead_Conv'.
        """
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        """Return the attribute named '<prefix><i>' on the wrapped object."""
        target = '{}{}'.format(self.prefix, i)
        return getattr(self.module, target)