In [1]:
import torch
from torch import nn
import pytorch_lightning as pl



In [2]:
# Reproducibility: seed the global RNG so that layer initialisation (and any
# random example data drawn afterwards) is identical on every
# "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# --- Data dimensions ---------------------------------------------------------
doc_len = 30   # sentences per document
sent_len = 50  # words per sentence

# --- Model hyperparameters ---------------------------------------------------
word_num = 1000  # rows in the word embedding table (vocabulary size)
E_D = 128        # word embedding dimension
H = 256          # GRU hidden size per direction (bidirectional output is 2*H)
S = 10           # rows in the relative-position embedding table
P_V = 50         # rows in the absolute-position embedding table
P_D = 64         # position embedding dimension

# --- Layers ------------------------------------------------------------------
# NOTE: the creation order below determines which random draws each layer
# receives from the seeded RNG; keep it stable to keep results comparable.

# Projects the pooled document vector before the tanh that yields Doc_Rep.
fc = nn.Linear(2*H, 2*H)

word_embed = nn.Embedding(word_num, E_D)
rel_pos_embed = nn.Embedding(S, P_D)
abs_pos_embed = nn.Embedding(P_V, P_D)

# Scoring terms of the per-sentence probability (all produce a single logit).
content = nn.Linear(2*H, 1, bias=False)
salience = nn.Bilinear(2*H, 2*H, 1, bias=False)
novelty = nn.Bilinear(2*H, 2*H, 1, bias=False)
abs_pos_imp_layer = nn.Linear(P_D, 1, bias=False)
rel_pos_imp_layer = nn.Linear(P_D, 1, bias=False)
# Scalar bias of the scoring function, initialised uniformly in [-0.1, 0.1].
# torch.empty replaces the legacy torch.FloatTensor(1) constructor.
bias = nn.Parameter(torch.empty(1).uniform_(-0.1, 0.1))


# Word-level RNN: runs over the words of each sentence.
word_rnn = nn.GRU(
    input_size = E_D,
    hidden_size = H,
    bidirectional = True,
    batch_first = True
)

# Sentence-level RNN: runs over the sentence vectors of each document.
sent_rnn = nn.GRU(
    input_size = 2*H,
    hidden_size = H,
    bidirectional = True,
    batch_first = True
)

In [3]:
def Avg_pool(inputs, kernel_size):
    """Average-pool a batch of sequences along the sequence axis.

    The previous implementation built an ``nn.MaxPool1d`` despite the name,
    so it returned the per-feature maximum instead of the mean.

    Args:
        inputs: 3-D tensor indexed as (batch, seq_len, features).
        kernel_size: pooling window along ``seq_len``; the stride defaults to
            the same value, so windows do not overlap.

    Returns:
        Tensor of shape (batch, features) when exactly one window fits
        (``seq_len // kernel_size == 1``), otherwise
        (batch, seq_len // kernel_size, features).
    """
    pool = nn.AvgPool1d(kernel_size=kernel_size)

    # AvgPool1d pools the last axis of a (batch, channels, length) tensor, so
    # move the sequence axis last, pool the whole batch at once, and move it back.
    pooled = pool(inputs.transpose(1, 2)).transpose(1, 2)

    # Drop the sequence axis when it collapsed to a single window, matching the
    # shape callers received from the per-sample squeeze in the old version.
    if pooled.size(1) == 1:
        pooled = pooled.squeeze(1)

    return pooled

In [4]:
# Example Data

batch_size = 10  # number of documents in the example batch

# x holds random word ids, indexed (document, sentence, word).
# Ids are drawn from [1, word_num) so they always fit the word embedding table,
# even if word_num is changed; id 0 is never generated.
x = torch.randint(1, word_num, (batch_size, doc_len, sent_len))

In [5]:
# Input Layer
# Look up an embedding vector for every word id in x via word_embed; the result
# has one extra trailing axis holding the embedding.
# NOTE(review): x is rebound from integer ids to float embeddings here, so this
# cell is not idempotent — re-running it without first re-running the example
# data cell feeds embeddings back into the lookup, which presumably fails.
x = word_embed(x)

In [6]:
# Word Layer
#
# x is the embedded batch, indexed (document, sentence, word, embedding).
# Each document is passed through the word-level GRU with its sentences acting
# as the batch axis, then pooled over the word axis so that every sentence is
# reduced to a single vector of size 2*H.

word_out = []

for doc in x:
    # Encode *this* document. The previous code called word_rnn(x[0]), which
    # re-encoded the first document on every iteration, so all documents ended
    # up with identical sentence vectors.
    outputs, _ = word_rnn(doc)

    word_out.append(Avg_pool(outputs, sent_len))

# word_out: (documents, sentences, 2*H) — one pooled vector per sentence
word_out = torch.stack(word_out, dim=0)

In [25]:
# Sentence Layer
#
# Scores every sentence of every document with
#   sigmoid(content + salience - novelty + abs_position + rel_position + bias)
# while maintaining a running summary vector `s` per document.

# Contextualised sentence states from the sentence-level GRU:
# (documents, sentences, 2*H).
sent_hidden, _ = sent_rnn(word_out)

# Pool over the sentence axis to get one vector per document. Kept under the
# name `sent_out` so the value this cell leaves behind is unchanged.
sent_out = Avg_pool(sent_hidden, doc_len)

# Flat list of per-sentence probabilities, documents concatenated in order.
probs = []

for index, doc in enumerate(sent_out):
    Doc_Rep = torch.tanh(fc(doc)).unsqueeze(0)

    # Running summary representation; starts at zero for each document. It is
    # not a learnable leaf, so it does not need requires_grad=True — gradients
    # still flow through it via `prob` and `h`.
    s = torch.zeros(1, 2*H)

    # Score the sentence-level GRU states. The previous code iterated over
    # word_out[index], so the sentence RNN output never reached the scores.
    sent_states = sent_hidden[index]
    n_sents = sent_states.size(0)

    for position, h in enumerate(sent_states):

        h = h.view(1, -1)

        abs_pos = abs_pos_embed(torch.tensor(position))

        # Relative position bucket in [0, S-1]. The previous formula divided by
        # sent_len (words per sentence) instead of the number of sentences and
        # hard-coded 10 instead of S; scaling by S-1 keeps the last sentence
        # inside the S-row embedding table.
        rel_index = torch.tensor(round((position + 1) * (S - 1) / n_sents))
        rel_pos = rel_pos_embed(rel_index)

        prob = torch.sigmoid(
            content(h) +
            salience(h, Doc_Rep) -
            novelty(h, torch.tanh(s)) +
            abs_pos_imp_layer(abs_pos) +
            rel_pos_imp_layer(rel_pos) +
            bias
        )

        # Accumulate the probability-weighted sentence state into the summary.
        s = s + torch.matmul(prob, h)

        probs.append(prob)
        
#return torch.cat(probs).squeeze()

