In [34]:
import torch
import torch.nn as nn
import scipy.io
import torch.nn.functional as F
from torch.autograd import Variable

In [35]:
def to_var(x, device=None):
    """Return ``x`` detached from autograd and placed on a device.

    Replaces the deprecated ``torch.autograd.Variable(x)`` wrapper. Since
    PyTorch 0.4 that wrapper just returns a detached tensor that does not
    require grad, which is exactly what ``x.detach()`` produces.

    Args:
        x: tensor to convert.
        device: optional target device. When ``None`` (the default, matching
            the original behaviour) the tensor is moved to CUDA if a GPU is
            available and otherwise left where it is.

    Returns:
        A tensor with ``requires_grad=False`` on the selected device.
    """
    if device is None:
        if torch.cuda.is_available():
            x = x.cuda()
    else:
        x = x.to(device)
    return x.detach()

In [36]:
# Reproducible dummy batch: 10 samples x 23 channels x 129 raw features.
# Seeding here also makes the model initialisation in later cells repeatable
# when the notebook is run top to bottom.
torch.manual_seed(0)
# Draw on the CPU generator first, then move: CPU and CUDA RNGs produce
# different streams, so this keeps `dd` identical with or without a GPU,
# and avoids crashing on CPU-only machines.
dd = torch.randn(10, 23, 129)
dd = dd.to('cuda' if torch.cuda.is_available() else 'cpu')

In [37]:
class AE_0(nn.Module):
    """2-D convolutional autoencoder over the whole ``channel x feature`` map.

    Encoder: conv(5x3, stride 2) -> ReLU -> maxpool(2, stride 2)
    -> conv(3x3, stride 2) -> ReLU -> maxpool(2, stride 1) -> flatten
    -> linear -> ReLU.  The decoder mirrors it, reusing the encoder's pooling
    indices for max-unpooling and ending in a sigmoid.

    The flattened size ``8 * 2 * 17`` is hard-coded. It is the encoder output
    for a ``1 x 23 x 129`` input (23x129 -> 10x66 -> 5x33 -> 3x18 -> 2x17),
    so inputs of other sizes will not fit ``fc1``.

    Args:
        input_size: stored on the module only; no layer size depends on it.
        hidden_size: size of the bottleneck code.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        flat_features = 8 * 2 * 17  # 8 feature maps of 2 x 17, flattened

        # Encoder layers
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(5, 3), stride=2, padding=(0, 2))
        self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=(1, 2))
        self.pool2 = nn.MaxPool2d(2, stride=1, return_indices=True)
        self.fc1 = nn.Linear(flat_features, self.hidden_size)

        # Decoder layers (mirror image of the encoder)
        self.fc2 = nn.Linear(self.hidden_size, flat_features)
        self.unpool2 = nn.MaxUnpool2d(2, stride=1)
        self.deconv2 = nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=(1, 2))
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        self.deconv1 = nn.ConvTranspose2d(16, 1, kernel_size=(5, 3), stride=2, padding=(0, 2))

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: tensor of shape ``(batch, 1, channel, raw_feature)``; the
               hard-coded sizes require ``(batch, 1, 23, 129)``.

        Returns:
            ``(code, reconstruction)``. ``code`` is ``(batch, hidden_size)``
            after a ReLU. ``reconstruction`` has the same shape as ``x`` and
            has been passed through a sigmoid.
        """
        # Encoder: keep pre-pool sizes and pooling indices for the decoder.
        h = self.relu(self.conv1(x))
        pre_pool1 = h.size()
        h, idx1 = self.pool1(h)
        h = self.relu(self.conv2(h))
        pre_pool2 = h.size()
        h, idx2 = self.pool2(h)
        pooled_shape = h.size()
        code = self.relu(self.fc1(h.view(h.size(0), -1)))

        # Decoder: undo each encoder step in reverse order.
        r = self.relu(self.fc2(code)).view(pooled_shape)
        r = self.unpool2(r, idx2, output_size=pre_pool2)
        r = self.relu(self.deconv2(r))
        r = self.unpool1(r, idx1, output_size=pre_pool1)
        reconstruction = self.sigmoid(self.deconv1(r))
        return code, reconstruction

In [38]:
class AE_1(nn.Module):
    """1-D convolutional autoencoder applied to a single channel's signal.

    Encoder: conv(3, stride 2) -> ReLU -> maxpool(2, stride 2)
    -> conv(3, stride 2) -> ReLU -> maxpool(2, stride 1) -> flatten
    -> linear -> ReLU.  The decoder mirrors it with max-unpooling (reusing
    the encoder's indices) and transposed convolutions, ending in a sigmoid.

    The flattened size ``8 * 17`` is hard-coded. It is the encoder output for
    a 129-sample signal (129 -> 66 -> 33 -> 18 -> 17).

    Args:
        input_size: stored on the module only; no layer size depends on it.
        hidden_size: size of the bottleneck code.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        flat_features = 8 * 17  # 8 feature maps of length 17, flattened

        # Encoder layers
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=2)
        self.pool1 = nn.MaxPool1d(2, stride=2, return_indices=True)
        self.conv2 = nn.Conv1d(16, 8, kernel_size=3, stride=2, padding=2)
        self.pool2 = nn.MaxPool1d(2, stride=1, return_indices=True)
        self.fc1 = nn.Linear(flat_features, self.hidden_size)

        # Decoder layers (mirror image of the encoder)
        self.fc2 = nn.Linear(self.hidden_size, flat_features)
        self.unpool2 = nn.MaxUnpool1d(2, stride=1)
        self.deconv2 = nn.ConvTranspose1d(8, 16, kernel_size=3, stride=2, padding=2)
        self.unpool1 = nn.MaxUnpool1d(2, stride=2)
        self.deconv1 = nn.ConvTranspose1d(16, 1, kernel_size=3, stride=2, padding=2)

    def forward(self, x):
        """Encode one channel's signal and reconstruct it.

        Args:
            x: tensor of shape ``(batch, 1, raw_feature)``; the hard-coded
               sizes require ``raw_feature == 129``.

        Returns:
            ``(code, reconstruction)``. ``code`` is ``(batch, hidden_size)``.
            ``reconstruction`` is flattened to ``(batch, raw_feature)``.

        ``output_size`` is not passed to the unpool layers. For 129-sample
        inputs the default unpooled lengths (18 and 66) already equal the
        encoder's pre-pool lengths.
        """
        # Encoder (shapes shown for a 129-sample input)
        h = self.relu(self.conv1(x))   # batch x 16 x 66
        h, idx1 = self.pool1(h)        # batch x 16 x 33
        h = self.relu(self.conv2(h))   # batch x 8 x 18
        h, idx2 = self.pool2(h)        # batch x 8 x 17
        pooled_shape = h.size()
        code = self.relu(self.fc1(h.view(h.size(0), -1)))  # batch x hidden_size

        # Decoder
        r = self.relu(self.fc2(code)).view(pooled_shape)
        r = self.unpool2(r, idx2)      # batch x 8 x 18
        r = self.relu(self.deconv2(r)) # batch x 16 x 33
        r = self.unpool1(r, idx1)      # batch x 16 x 66
        r = self.sigmoid(self.deconv1(r))  # batch x 1 x 129
        return code, r.view(r.size(0), -1)

In [70]:
class Net(nn.Module):
    """Two-branch autoencoder classifier.

    * A global branch (``AE_0``) encodes the whole ``channel x raw_feature``
      map as a single one-channel image.
    * A per-channel branch (``AE_1``, with weights shared across channels)
      encodes each channel's 1-D signal separately.

    The per-channel codes (``num_channel x hidden_size``) and the global code
    (``hidden_size``) are concatenated and fed to a two-layer classifier head.

    Args:
        input_size: length of each channel's raw feature vector. ``AE_0`` and
            ``AE_1`` hard-code their fully connected sizes for 23 x 129 inputs.
        hidden_size: size of every autoencoder code and of the head's hidden layer.
        num_layers: stored only; not used by this module.
        num_classes: number of output logits.
        num_channel: number of input channels.
        dropout_rate: dropout probability applied to the concatenated codes.
        num_selected_channels: during training, only this many channels per
            sample (those with the largest code sum) are kept; the rest are
            zeroed by ``channel_selection``. Defaults to 11.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, num_channel,
                 dropout_rate=0.5, num_selected_channels=11):
        super(Net, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.num_channel = num_channel
        self.num_selected_channels = num_selected_channels
        self.relu = nn.ReLU()

        self.ae0 = AE_0(input_size * num_channel, hidden_size)
        self.ae1 = AE_1(input_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

        # Head input = num_channel per-channel codes + 1 global code, each of
        # size hidden_size (23 * 128 + 128 = 3072 for the original setup).
        self.fc1 = nn.Linear((num_channel + 1) * hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Run both autoencoders and the classifier head.

        Args:
            x: tensor of shape ``(batch, num_channel, input_size)``.

        Returns:
            ``(decoded_c, decoded_g, out)``:
              * ``decoded_c``: per-channel reconstructions, ``(batch, num_channel, input_size)``.
              * ``decoded_g``: global reconstruction, ``(batch, num_channel, input_size)``.
              * ``out``: class logits, ``(batch, num_classes)``.
        """
        batch, num_channel, raw_len = x.shape

        # Global branch: the whole channel x feature map as a 1-channel image.
        # encoded_g: batch x hidden_size; decoded_g: batch x 1 x channel x input_size
        encoded_g, decoded_g = self.ae0(x.unsqueeze(1))
        decoded_g = decoded_g[:, 0, :, :]  # drop the singleton image-channel dimension

        # Per-channel branch: fold channels into the batch so the shared AE_1
        # runs once instead of once per channel. AE_1 only uses per-sample ops
        # (conv, pool, linear, activations), so this matches a per-channel loop.
        # Outputs inherit x's device, so no separate buffers need to be placed.
        per_channel = x.reshape(batch * num_channel, 1, raw_len)
        encoded_c, decoded_c = self.ae1(per_channel)
        encoded_c = encoded_c.view(batch, num_channel, -1)  # batch x channel x hidden_size
        decoded_c = decoded_c.view(batch, num_channel, -1)  # batch x channel x input_size

        if self.training:
            # NOTE(review): masking is applied only in training, so at eval time
            # the head sees every channel, with no rescaling.
            encoded_c = self.channel_selection(encoded_c, self.num_selected_channels)

        out = torch.cat([encoded_c.reshape(batch, -1), encoded_g.reshape(batch, -1)], 1)
        out = self.dropout(out)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)

        return decoded_c, decoded_g, out

    def channel_selection(self, x, k):
        """Keep, per sample, only the ``k`` channels with the largest code sum.

        Args:
            x: tensor of shape ``(batch, channel, hidden_size)``.
            k: number of channels to keep per sample.

        Returns:
            Tensor shaped like ``x``. For each sample, every channel outside
            the top-``k`` (ranked by ``x.sum(2)``) is zeroed.
        """
        scores = x.sum(2)                      # batch x channel
        _, indices = torch.topk(scores, k, 1)  # batch x k
        # Scatter ones at the selected positions. The mask is built on x's own
        # device and dtype, and the topk indices carry no gradient, so gradients
        # flow only through the kept entries of x.
        mask = torch.zeros_like(scores).scatter_(1, indices, 1.0)
        return x * mask.unsqueeze(2)           # broadcast over hidden_size

In [71]:
# Build the model and place it on the same device as the input batch `dd`,
# instead of hard-coding 'cuda' (which fails on CPU-only machines).
djgnet = Net(input_size=129, hidden_size=128, num_layers=2, num_classes=2, num_channel=23)
djgnet = djgnet.to(dd.device)

In [72]:
# Smoke test: one forward pass on the dummy batch. The model has not been
# switched to eval mode, so dropout and channel selection are active here.
yy = djgnet(dd)
print(type(yy))
print(len(yy))
for part in yy:
    print(part.shape)

# Pin the expected output shapes so a regression fails loudly rather than
# only changing the printed text.
decoded_c, decoded_g, logits = yy
assert decoded_c.shape == dd.shape
assert decoded_g.shape == dd.shape
assert logits.shape == (dd.size(0), djgnet.num_classes)

<class 'tuple'>
3
torch.Size([10, 23, 129])
torch.Size([10, 23, 129])
torch.Size([10, 2])
