In [1]:
import torch
import torch.nn as nn
import scipy.io
import torch.nn.functional as F
from torch.autograd import Variable

In [8]:
# Random smoke-test input, later fed to Attention.forward, whose expected layout is
# batch x seq x channel x hidden_size (here 2 x 32 x 128 x 128).
dd1 = torch.randn(2, 32, 128, 128)
# Only move to the GPU when one is present so the notebook also runs on
# CPU-only machines (an unconditional .cuda() raises there).
if torch.cuda.is_available():
    dd1 = dd1.cuda()

In [3]:
def to_var(x):
    """Return ``x`` wrapped in an autograd ``Variable``.

    The tensor is first moved to the GPU when CUDA is available; otherwise it
    is wrapped where it already lives.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)

In [4]:
class CNN_score(nn.Module):
    """Score the current hidden state against one previous hidden state.

    Three branches are summed together with a projection of the previous
    score and squashed with tanh:

    * ``layer0`` + ``fc00``: 1-D conv over the current state, then linear.
    * ``layer0`` + ``fc01``: the *same* 1-D conv over the previous state,
      then a separate linear layer.
    * ``layer1`` + ``fc1``: 2-D conv over both states stacked as two input
      planes, then linear.
    * ``fc2``: linear map of the previous score.

    Args:
        num_channel: size of the channel axis of the hidden states; also the
            size of the returned score vector.
        hidden_size: length of the last axis of the hidden states. The convs
            preserve this length (kernel 9, padding 4), and the linear layers
            expect exactly ``hidden_size`` input features.
    """

    def __init__(self, num_channel, hidden_size):
        super(CNN_score, self).__init__()
        self.num_channel = num_channel
        self.hidden_size = hidden_size
        # +: layer1->16
        # ++: layer1, layer0->16
        # Each layer is initialised right after it is constructed so the order
        # of random draws matches a plain construct-then-init sequence.
        # Conv1d collapses the channel axis to 1; kernel 9 / padding 4 keeps the length.
        self.layer0 = self._xavier_tanh(
            nn.Conv1d(num_channel, 1, kernel_size=9, padding=4))

        # Kernel and stride both span every channel, so the channel axis collapses to 1.
        self.layer1 = self._xavier_tanh(
            nn.Conv2d(2, 1, kernel_size=(num_channel, 9), padding=(0, 4),
                      stride=(num_channel, 1)))

        self.fc00 = self._xavier_tanh(nn.Linear(hidden_size, num_channel))
        self.fc01 = self._xavier_tanh(nn.Linear(hidden_size, num_channel))

        self.fc1 = self._xavier_tanh(nn.Linear(hidden_size, num_channel))
        self.fc2 = self._xavier_tanh(nn.Linear(num_channel, num_channel))

        self.relu = nn.ReLU()  # not used in forward; kept as a public attribute
        self.tanh = nn.Tanh()

    @staticmethod
    def _xavier_tanh(layer):
        """Xavier-uniform initialise ``layer.weight`` with the tanh gain and return the layer."""
        # xavier_uniform_ is the in-place successor of the deprecated xavier_uniform.
        nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('tanh'))
        return layer

    def forward(self, h_i, pre_h_i, pre_s):
        """Compute the score vector.

        Args:
            h_i: current hidden state, batch x channel x hidden_size.
            pre_h_i: previous hidden state, batch x channel x hidden_size.
            pre_s: previous score, batch x channel.

        Returns:
            Tensor of shape batch x channel with values in [-1, 1] (tanh output).
        """
        batch_size = h_i.size(0)

        # batch x 1 x hidden_size -> batch x hidden_size -> batch x channel
        out_h_i = self.fc00(self.layer0(h_i).reshape(batch_size, -1))
        # The previous state goes through the same conv (layer0) but its own linear layer.
        out_pre_h_i = self.fc01(self.layer0(pre_h_i).reshape(batch_size, -1))

        # Stack both states as two input planes: batch x 2 x channel x hidden_size.
        # torch.stack also works for non-contiguous slices, unlike a bare view().
        hh = torch.stack((h_i, pre_h_i), dim=1)

        # batch x 1 x 1 x hidden_size -> batch x hidden_size -> batch x channel
        out_hh = self.fc1(self.layer1(hh).reshape(batch_size, -1))
        pre_s = self.fc2(pre_s)

        return self.tanh(out_h_i + out_pre_h_i + out_hh + pre_s)  # batch x channel


In [5]:
class Attention(nn.Module):
    """Causal attention over a sequence of channel-wise hidden states.

    For every step ``i`` a context tensor is built as a weighted sum of the
    hidden states of steps ``0..i-1``; the weights come from ``CNN_score``
    followed by :meth:`normalization`. Step 0 has no history, so its context
    is all zeros.

    Args:
        num_channel: size of the channel axis of the input.
        hidden_size: size of the last axis of the input.
    """

    def __init__(self, num_channel, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.tanh = nn.Tanh()  # not used by this class; kept as a public attribute
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

        self.attention_layer = CNN_score(num_channel, hidden_size)

    def forward(self, h):
        """Pair every hidden state with its attention context.

        Args:
            h: tensor of shape batch x seq x channel x hidden_size.

        Returns:
            Tensor of shape batch x seq x 2 x channel x hidden_size where
            index 0 of the third axis is the context and index 1 is ``h``.
        """
        batch_size = h.size(0)
        seq_size = h.size(1)
        num_channel = h.size(2)

        # Allocate on the input's device/dtype so CPU inputs also work on a
        # machine that happens to have CUDA available.
        context_matrix = h.new_zeros(batch_size, seq_size, num_channel, self.hidden_size)

        # Step 0 has no previous states; its context stays at the zero initialisation.
        for i in range(1, seq_size):
            hh_i = h[:, i]      # current hidden state: batch x channel x hidden_size
            pre_hh_i = h[:, :i]  # previous hidden states: batch x i x channel x hidden_size

            # Scores are computed left to right; each one is conditioned on
            # the score of the preceding history step (zeros for the first).
            pre_score = h.new_zeros(batch_size, num_channel)
            step_scores = []
            for j in range(i):
                pre_score = self.energy(hh_i, pre_hh_i[:, j], pre_score)  # batch x channel
                step_scores.append(pre_score)
            scores = torch.stack(step_scores, dim=1)  # batch x i x channel

            scores = self.normalization(scores, 2)  # batch x i x channel

            # Broadcast the per-channel weight over hidden_size and sum over history.
            context_matrix[:, i] = (pre_hh_i * scores.unsqueeze(-1)).sum(1)

        # batch x seq x 2 x channel x hidden_size
        return torch.stack([context_matrix, h], dim=2)

    def energy(self, hidden_i, pre_hidden_i, pre_scores):
        """Score ``hidden_i`` against ``pre_hidden_i`` with the CNN scorer.

        Args:
            hidden_i: current hidden state, batch x channel x hidden_size.
            pre_hidden_i: one previous hidden state, same shape.
            pre_scores: score of the preceding history step, batch x channel.

        Returns:
            Tensor of shape batch x channel.
        """
        # Slices taken along the sequence axis are not contiguous; the scorer
        # reshapes its inputs, so hand it contiguous tensors.
        h_i = hidden_i.contiguous()
        pre_h_i = pre_hidden_i.contiguous()
        return self.attention_layer(h_i, pre_h_i, pre_scores)  # batch x channel

    def normalization(self, scores, gamma):
        """Turn raw scores into weights that sum to 1 per batch element.

        Args:
            scores: raw scores, batch x sub_seq_size x channel.
            gamma: multiplier applied before the softmax.

        Returns:
            Tensor of the same shape as ``scores``; the softmax is taken
            jointly over the (sub_seq_size * channel) entries of each batch
            element.
        """
        batch_size = scores.size(0)
        sub_seq_size = scores.size(1)
        num_channel = scores.size(2)

        # Share of each history step: ReLU mass of the step over the total ReLU mass.
        gamma_d = self.relu(scores).sum(2, keepdim=True)            # batch x sub_seq_size x 1
        gamma_d = gamma_d / (gamma_d.sum(1, keepdim=True) + 1e-8)   # batch x sub_seq_size x 1

        # Share of each channel within a step: sigmoid score over the step's sigmoid sum.
        channel_share = self.sigmoid(scores)
        channel_share = channel_share / (channel_share.sum(2, keepdim=True) + 1e-8)

        out = (gamma_d * channel_share).reshape(batch_size, -1)  # batch x (sub_seq_size * channel)
        # Explicit dim: the implicit-dimension form of softmax is deprecated.
        out = F.softmax(gamma * out, dim=1)
        return out.view(batch_size, sub_seq_size, num_channel)   # batch x sub_seq_size x channel

In [9]:
# Attention module for inputs with 128 channels and a hidden size of 128.
djgnet = Attention(
    num_channel=128,
    hidden_size=128,
)

  if __name__ == '__main__':
  if sys.path[0] == '':
  from ipykernel import kernelapp as app


In [10]:
# Put the model on whatever device the input lives on instead of hard-coding
# 'cuda', so the cell cannot hit a device mismatch or fail without a GPU.
djgnet = djgnet.to(dd1.device)
yy = djgnet(dd1)
print(type(yy))
print(len(yy))   # length of the first (batch) axis
print(yy.shape)  # batch x seq x 2 x channel x hidden_size



<class 'torch.Tensor'>
2
torch.Size([2, 32, 2, 128, 128])
