In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

class Basic_AP(nn.Module):
    """Attentive pooling (AP) scorer for a pair of token-index sequences.

    Both sequences are embedded with one shared embedding table. A bilinear
    soft-alignment matrix ``G = tanh(Q W A^T)`` is built between them. Each
    side is then attention-pooled, using a softmax over the max of ``G``
    along the other side's axis. The two pooled vectors are concatenated and
    mapped to a single sigmoid score.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_size: Embedding dimension ``E``; also the size of the
            square bilinear weight ``W`` (``E x E``).
        input_size: Sequence length that the previous fixed-window
            max-pooling layer was sized for. It is kept for backward
            compatibility and stored as ``self.input_size``. Pooling now
            takes the max over the whole sequence axis, so inputs of any
            length are accepted, and ``x1``/``x2`` may differ in length.
        padding_idx: Index passed to ``nn.Embedding``; that row is
            zero-initialized and receives no gradient.

    Note:
        Padding positions are not masked out of the softmax. With a zero
        padding embedding, a padded position yields ``tanh(0) = 0`` entries
        in ``G``, and those positions still receive attention mass.
    """

    def __init__(self, vocab_size, embedding_size, input_size, padding_idx=0):
        super().__init__()

        # Retained only for API compatibility / introspection; no longer
        # constrains the accepted sequence length.
        self.input_size = input_size

        self.embeds = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx)
        # Bilinear interaction matrix W, initialized like nn.Linear's weight.
        self.weight = Parameter(torch.empty(embedding_size, embedding_size))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Maps the concatenated pooled vectors [r_Q; r_A] to one logit.
        self.out_f = nn.Linear(embedding_size * 2, 1, bias=False)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Score a batch of sequence pairs.

        Args:
            x1: Token indices of shape ``(batch, len_q)``.
            x2: Token indices of shape ``(batch, len_a)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores.
        """
        Q = self.embeds(x1)  # (B, Lq, E)
        A = self.embeds(x2)  # (B, La, E)
        G = torch.tanh(Q @ self.weight @ A.transpose(1, 2))  # (B, Lq, La)

        # Row-wise max (over A's positions) gives one score per Q token.
        # Column-wise max (over Q's positions) gives one score per A token.
        # A max over the full axis equals the old MaxPool1d(input_size)
        # result when the length is input_size, and works for any length.
        g_Q = G.max(dim=2, keepdim=True).values                 # (B, Lq, 1)
        g_A = G.transpose(1, 2).max(dim=2, keepdim=True).values  # (B, La, 1)

        # Normalize over the sequence axis to get attention weights.
        attention_Q = torch.softmax(g_Q, dim=1)
        attention_A = torch.softmax(g_A, dim=1)

        # Attention-weighted sum of embeddings for each side.
        r_Q = Q.transpose(1, 2) @ attention_Q  # (B, E, 1)
        r_A = A.transpose(1, 2) @ attention_A  # (B, E, 1)

        output = torch.cat((r_Q, r_A), 1).squeeze(-1)  # (B, 2E)
        output = torch.sigmoid(self.out_f(output))       # (B, 1)

        return output