In [38]:
import torch
from torch import nn,optim
from torch.nn import functional as F
import pandas as pd
import numpy as np
import random
import pickle
import logging
from tqdm import tqdm
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence
USE_CUDA = torch.cuda.is_available()

random.seed(2020)
np.random.seed(2020)
torch.manual_seed(2020)
if USE_CUDA:
    torch.cuda.manual_seed(2020)
# set cuda
gpu = 0
use_cuda = gpu >= 0 and torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(gpu)
    device = torch.device("cuda", gpu)
else:
    device = torch.device("cpu")
logging.info("Use cuda: %s, gpu id: %d.", use_cuda, gpu)

In [39]:
# 将train_data和test_data读取存入list中
def ReadTxtName(rootdir):
    lines = []
    with open(rootdir, 'r') as file_to_read:
        while True:
            line = file_to_read.readline()
            if not line:
                break
            line = line.strip('\n')
            line = list(eval(line))
            lines.append(line)
    return lines

In [40]:
train_data_path = 'data_example/train_data.txt'
test_data_path = 'data_example/test_data.txt'

In [41]:
train_set = ReadTxtName(train_data_path)
test_set = ReadTxtName(test_data_path)

In [42]:
train_data = pd.DataFrame(columns=['user_id', 'item_seq_temp','cat_list','time_list','time_last_list','time_now_list','position_list','target','item_seq_len'],data = train_set)
test_data = pd.DataFrame(columns=['user_id', 'item_seq_temp','cat_list','time_list','time_last_list','time_now_list','position_list','target','item_seq_len'],data = test_set)

In [43]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len
0,3522,"[118873, 190989, 73693, 354311, 73693, 354311,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",26
1,3839,"[298668, 297576, 305398, 107015, 32892, 76809,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",23
2,3719,"[211347, 153658, 7223, 81509, 19699, 81509, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",17
3,4362,"[95057, 68099, 161481, 366635, 150257, 303230,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",19
4,1104,"[228127, 202790, 37771, 210693, 316530, 186423...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",18


In [44]:
train_data['item_seq_len'] = train_data['item_seq_len']-1

TypeError: can only concatenate list (not "int") to list

In [45]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len
0,3522,"[118873, 190989, 73693, 354311, 73693, 354311,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",25
1,3839,"[298668, 297576, 305398, 107015, 32892, 76809,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",22
2,3719,"[211347, 153658, 7223, 81509, 19699, 81509, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",16
3,4362,"[95057, 68099, 161481, 366635, 150257, 303230,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",18
4,1104,"[228127, 202790, 37771, 210693, 316530, 186423...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",17


In [46]:
train_data['item_seq_temp'] = train_data.item_seq_temp.apply(lambda x:[i+1 for i in x])

In [47]:
train_data = train_data[:1000]

In [48]:
for i in tqdm(range(len(train_data))):
    train_data['item_seq_temp'][i][train_data.item_seq_len[i]] = 0

100%|███████████████████████████████████| 1000/1000 [00:00<00:00, 33308.75it/s]


In [49]:
train_data.shape

(1000, 9)

In [50]:
input_length = 50
train_data['item_seq_temp'] = train_data['item_seq_temp'].apply(lambda x:x+[0]*(input_length-len(x))if len(x)<input_length else x[:input_length])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [51]:
# 对train_data和test_data进行padding,即item_count+1作为padding_idx
# item_id,user_id,cat_list_id进行了重新编码
train_data['target_item'] = train_data['target'].apply(lambda x:x[0]) 
test_data['target_item'] = test_data['target'].apply(lambda x:x[0]) 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [52]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len,target_item
0,3522,"[118874, 190990, 73694, 354312, 73694, 354312,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",25,125165
1,3839,"[298669, 297577, 305399, 107016, 32893, 76810,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",22,168523
2,3719,"[211348, 153659, 7224, 81510, 19700, 81510, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",16,329307
3,4362,"[95058, 68100, 161482, 366636, 150258, 303231,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",18,146956
4,1104,"[228128, 202791, 37772, 210694, 316531, 186424...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",17,111155


In [60]:
# 随便padding
def pad(data):
    for i in data.columns:
        if i =='cat_list':
            data[i] = data[i].apply(lambda x:x+[x[-1]-1]*(input_length-len(x))if len(x)<input_length else x[:input_length])
        elif i=='position_list':
            data[i] = data[i].apply(lambda x:[i for i in range(input_length)])

In [65]:
# input_length = 50
pad(train_data)

In [66]:
# 获得train_data和test_data
columns = ['user_id','item_seq_temp','cat_list','position_list','item_seq_len','target_item']
train_data= train_data[columns]
test_data = test_data[columns]

In [67]:
# 方便处理数据，转换为numpy格式
user_id = np.array(train_data['user_id'].tolist(),dtype = np.int32)
item_seq_temp = np.array(train_data['item_seq_temp'].tolist(),dtype = np.int32)
cat_list = np.array(train_data['cat_list'].tolist(),dtype = np.int32)
position_list = np.array(train_data['position_list'].tolist(),dtype = np.int32)
item_seq_len = np.array(train_data['item_seq_len'].tolist(),dtype = np.int64)
target_item = np.array(train_data['target_item'].tolist(),dtype = np.int32)

# 生成batch

In [68]:
def get_batch(x_user_id,x_item_seq_temp,x_cat_list,x_position_list,x_item_seq_len,y,batch_size,shuffle = True):
    assert x_user_id.shape[0] == y.shape[0]
    if shuffle:
        shuffled_index = np.random.permutation(y.shape[0])
        x_user_id,x_item_seq_len,x_cat_list,x_position_list,x_item_seq_len = x_user_id[shuffled_index],x_item_seq_len[shuffled_index],x_cat_list[shuffled_index],x_position_list[shuffled_index],x_item_seq_len[shuffled_index]
        y = y[shuffled_index]
    
    n_batches = int(x_user_id.shape[0]/batch_size)
    for i in range(n_batches-1):
        x_user_id_batch = x_user_id[i*batch_size:(i+1)*batch_size]
        x_item_seq_temp_batch = x_item_seq_temp[i*batch_size:(i+1)*batch_size]
        x_cat_list_batch = x_cat_list[i*batch_size:(i+1)*batch_size]
        x_position_list_batch = x_position_list[i*batch_size:(i+1)*batch_size]
        x_item_seq_len_batch = x_item_seq_len[i*batch_size:(i+1)*batch_size]
        y_batch = y[i*batch_size:(i+1)*batch_size]
        yield x_user_id_batch,x_item_seq_temp_batch,x_cat_list_batch,x_position_list_batch,x_item_seq_len_batch,y_batch

# 构造模型

In [69]:
class GetEmbedding(nn.Module):
    def __init__(self,parameter_path,input_length,embedding_dim):
        super(GetEmbedding,self).__init__()
        self.parameter_path = parameter_path
        self.parameters = self.get_parameter(self.parameter_path)
        self.user_count = self.parameters['user_count']
        self.item_count = self.parameters['item_count']
        self.category_count = self.parameters['category_count']
        self.input_length = input_length
        self.embedding_dim = embedding_dim
        self.user_id_embedding = nn.Embedding(self.user_count+3,self.embedding_dim)
        self.item_list_embedding = nn.Embedding(self.item_count+3,self.embedding_dim,padding_idx = 0)
        self.item_list_embedding_weight = self.item_list_embedding.weight
        self.category_list_embedding = nn.Embedding(self.category_count+3,self.embedding_dim)
        self.position_list_embeddig = nn.Embedding(self.input_length,self.embedding_dim)
        self.apply(self.init_weights)
    def init_weights(self, module):
        if isinstance(module, nn.Embedding):
            nn.init.xavier_uniform_(module.weight)
        
    def get_parameter(self,file_path):
        with open(file_path, 'rb') as f:  
            parameters = pickle.loads(f.read())
        return parameters

In [143]:
class NARM(nn.Module):
    def __init__(self,parameter_path,input_length,embedding_dim,hidden_size,n_layers = 1):
        super(NARM,self).__init__()
        self.parameter_path = parameter_path
        self.input_length = input_length
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.emb = GetEmbedding(self.parameter_path,self.input_length,self.embedding_dim)
        self.emb_dropout = nn.Dropout(0.25)
        self.gru = nn.GRU(self.embedding_dim,self.hidden_size,self.n_layers,batch_first = True)
        self.a_1 = nn.Linear(self.hidden_size,self.hidden_size,bias = False)
        self.a_2 = nn.Linear(self.hidden_size,self.hidden_size,bias = False)
        self.v_t = nn.Linear(self.hidden_size,1,bias = False)
        self.ct_dropout = nn.Dropout(0.5)
        self.b = nn.Linear(self.embedding_dim,2*self.hidden_size,bias = False)
        
        
       
    def forward(self,x_item_seq_temp_batch,x_item_seq_len_batch):
        item_list_emb = self.emb.item_list_embedding(x_item_seq_temp_batch)
        item_list_emb_dropout = self.emb_dropout(item_list_emb)
        item_list_emb_nopad = pack_padded_sequence(item_list_emb_dropout,x_item_seq_len_batch,batch_first=True,enforce_sorted=False)
        gru_out,hidden = self.gru(item_list_emb_nopad)
        gru_out,lengths = pad_packed_sequence(gru_out,batch_first=True,total_length=input_length)
        
        #fetch the last hidden state of last timestamp
        ht = hidden[-1]#[batch_size,hidden_size]
        c_global = ht
        
        q1 = self.a_1(gru_out.contiguous().view(-1,self.hidden_size)).view(gru_out.size())
        q2 = self.a_2(ht)
        
        mask = torch.where(x_item_seq_temp_batch>0,torch.tensor([1.]),torch.tensor([0.]))
        q2_expand = q2.unsqueeze(1).expand_as(q1)
        q2_masked = mask.unsqueeze(2).expand_as(q1)*q2_expand
        
        alpha = self.v_t(torch.sigmoid(q1+q2_masked).view(-1,self.hidden_size)).view(mask.size())
        c_local = torch.sum(alpha.unsqueeze(2).expand_as(gru_out)*gru_out,1)
        
        c_t = torch.cat([c_local,c_global],1)
        c_t = self.ct_dropout(c_t)
        
        return c_t,self.emb.item_list_embedding_weight
        
        
    

In [144]:
class Myloss(nn.Module):
    def __init__(self):
        super(Myloss,self).__init__()
        self.log_softmax = nn.LogSoftmax()
    def forward(self,emb,pred,truth):
        item_lookup_table_T = emb.t()
        logits = torch.matmul(pred,item_lookup_table_T)
        log_probs = self.log_softmax(logits)
        truth = torch.reshape(truth,[-1])
        one_hot_labels = F.one_hot(truth, num_classes=emb.shape[0])
        loss_origin = -torch.sum(log_probs.float() * one_hot_labels.float(), dim=-1)
        loss = torch.mean(loss_origin)
        return loss    

In [145]:
# train
batch_size = 10
parameter_path = 'data_example/parameters.pkl'
input_length = 50
embedding_dim = 32
hidden_size=16
n_layers = 1
narm = NARM(parameter_path,input_length,embedding_dim,hidden_size,n_layers = 1)
narm = narm.to(device)
narm.train()
criterion = Myloss()
train_loss = []
optimizer = optim.Adam(narm.parameters(),lr=0.001)
for epoch in range(10):
    for batch_idx,data in enumerate(get_batch(user_id,item_seq_temp,cat_list,position_list,item_seq_len,target_item,batch_size)):
        x_user_id_batch = torch.LongTensor(data[0])
        x_item_seq_temp_batch = torch.LongTensor(data[1])
        x_cat_list_batch = torch.LongTensor(data[2])
        x_position_list_batch = torch.LongTensor(data[3])
        x_item_seq_len_batch = torch.LongTensor(data[4])
        y_batch = torch.LongTensor(data[5])
        pred,item_list_emb_weight = narm(x_item_seq_temp_batch,x_item_seq_len_batch)
#         print('item_list_emb.shape',item_list_emb_weight.shape)
        loss = criterion(item_list_emb_weight,pred,y_batch)
        train_loss.append(loss.item())
        # Backward and optimizer
        # 1.优化器保存先前的梯度信息
        optimizer.zero_grad()
        # 2.计算梯度
        loss.backward()
#         3.梯度更新 w' = w - lr*grads
        optimizer.step()
        if batch_idx%200 == 0:
            print(epoch,batch_idx,loss.item())

  


0 0 12.804776191711426
1 0 12.44099235534668
2 0 10.49498462677002
3 0 10.475838661193848
4 0 9.599950790405273
5 0 9.551443099975586
6 0 7.664181709289551
7 0 8.336762428283691
8 0 7.855986595153809
9 0 7.7232513427734375


KeyboardInterrupt: 

In [96]:
input = torch.LongTensor([[1,2,1],[1,0,0]])

In [97]:
seq_len = torch.LongTensor([3,1])

In [98]:
input_emb = nn.Embedding(3,5,padding_idx = 0)(input)

In [99]:
input_emb 

tensor([[[-2.7515, -1.4493,  0.4747, -0.9737,  0.4582],
         [-0.8478,  0.0812,  1.1355, -1.6338,  1.3905],
         [-2.7515, -1.4493,  0.4747, -0.9737,  0.4582]],

        [[-2.7515, -1.4493,  0.4747, -0.9737,  0.4582],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<EmbeddingBackward>)

In [100]:
input_emb_nopad = pack_padded_sequence(input_emb,seq_len,batch_first=True)
input_emb_nopad

PackedSequence(data=tensor([[-2.7515, -1.4493,  0.4747, -0.9737,  0.4582],
        [-2.7515, -1.4493,  0.4747, -0.9737,  0.4582],
        [-0.8478,  0.0812,  1.1355, -1.6338,  1.3905],
        [-2.7515, -1.4493,  0.4747, -0.9737,  0.4582]],
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([2, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [101]:
embedding_size,hidden_size = 5,3
gru_output,hidden = nn.GRU(embedding_size,hidden_size,batch_first=True)(input_emb_nopad)

In [102]:
gru_output

PackedSequence(data=tensor([[-0.4256,  0.7522,  0.0735],
        [-0.4256,  0.7522,  0.0735],
        [-0.6064,  0.0024,  0.0175],
        [-0.7467,  0.7776,  0.0862]], grad_fn=<CatBackward>), batch_sizes=tensor([2, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [103]:
hidden

tensor([[[-0.7467,  0.7776,  0.0862],
         [-0.4256,  0.7522,  0.0735]]], grad_fn=<StackBackward>)

In [104]:
hidden = hidden[-1]

In [105]:
hidden

tensor([[-0.7467,  0.7776,  0.0862],
        [-0.4256,  0.7522,  0.0735]], grad_fn=<SelectBackward>)

In [106]:
gru_output,lengths = pad_packed_sequence(gru_output,batch_first=True)

In [107]:
gru_output

tensor([[[-0.4256,  0.7522,  0.0735],
         [-0.6064,  0.0024,  0.0175],
         [-0.7467,  0.7776,  0.0862]],

        [[-0.4256,  0.7522,  0.0735],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]], grad_fn=<TransposeBackward0>)

In [95]:
lengths

tensor([2, 1])

In [108]:
mask = torch.where(input>0,torch.tensor([1.]),torch.tensor([0.]))
mask

tensor([[1., 1., 1.],
        [1., 0., 0.]])

In [116]:
mask

tensor([[1., 1., 1.],
        [1., 0., 0.]])

In [109]:
# hidden:[batch_size,embedding_size]
#gru_output:[batch_size,seq_len.embedding_size]

hidden_expand = hidden.unsqueeze(1).expand_as(gru_output)
hidden_expand

tensor([[[-0.7467,  0.7776,  0.0862],
         [-0.7467,  0.7776,  0.0862],
         [-0.7467,  0.7776,  0.0862]],

        [[-0.4256,  0.7522,  0.0735],
         [-0.4256,  0.7522,  0.0735],
         [-0.4256,  0.7522,  0.0735]]], grad_fn=<ExpandBackward>)

In [110]:
hidden_masked = mask.unsqueeze(2).expand_as(gru_output)
hidden_masked

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]])

In [111]:
hidden_masked = hidden_masked*hidden_expand
hidden_masked

tensor([[[-0.7467,  0.7776,  0.0862],
         [-0.7467,  0.7776,  0.0862],
         [-0.7467,  0.7776,  0.0862]],

        [[-0.4256,  0.7522,  0.0735],
         [-0.0000,  0.0000,  0.0000],
         [-0.0000,  0.0000,  0.0000]]], grad_fn=<MulBackward0>)

In [113]:
vt = nn.Linear(hidden_size,1,bias = False)
alpha = vt(torch.sigmoid(gru_output+hidden_masked).view(-1,hidden_size))
alpha

tensor([[-0.3909],
        [-0.3208],
        [-0.4048],
        [-0.3747],
        [-0.1497],
        [-0.1497]], grad_fn=<MmBackward>)

In [117]:
alpha = alpha.view(mask.size())
alpha

tensor([[-0.3909, -0.3208, -0.4048],
        [-0.3747, -0.1497, -0.1497]], grad_fn=<ViewBackward>)

In [120]:
alpha = alpha.unsqueeze(2).expand_as(gru_output)
alpha

tensor([[[-0.3909, -0.3909, -0.3909],
         [-0.3208, -0.3208, -0.3208],
         [-0.4048, -0.4048, -0.4048]],

        [[-0.3747, -0.3747, -0.3747],
         [-0.1497, -0.1497, -0.1497],
         [-0.1497, -0.1497, -0.1497]]], grad_fn=<ExpandBackward>)

In [119]:
gru_output

tensor([[[-0.4256,  0.7522,  0.0735],
         [-0.6064,  0.0024,  0.0175],
         [-0.7467,  0.7776,  0.0862]],

        [[-0.4256,  0.7522,  0.0735],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]], grad_fn=<TransposeBackward0>)

In [122]:
c_local = torch.sum(alpha*gru_output,1)
c_local

tensor([[ 0.6631, -0.6095, -0.0692],
        [ 0.1595, -0.2819, -0.0275]], grad_fn=<SumBackward2>)