# In [1]:
# attentive_pooling
# Santos, Cicero dos, et al. "Attentive pooling networks." arXiv preprint arXiv:1602.03609 (2016).

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn.utils import weight_norm
import numpy as np

class Attentive_Pooling(nn.Module):
    """Attentive pooling over a single shared wide convolution.

    Both inputs go through the same embedding table and the same
    ``Conv1d`` + dropout encoder.  The two encodings are compared through a
    bilinear alignment matrix ``G = tanh(Q^T W A)``; the row-wise and
    column-wise maxima of ``G`` are soft-maxed into attention weights that
    pool each encoding into one vector.  The two pooled vectors are
    concatenated and scored by a bias-free linear layer and a sigmoid.

    Args:
        vocab_size: number of rows of the embedding table.
        embedding_size: embedding dimension (convolution input channels).
        input_size: nominal sequence length.  Only used to size ``self.mp``;
            ``forward`` itself accepts any sequence length.
        num_filters: number of convolution output channels.
        kernel_size: convolution kernel width.
        padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        dropout: dropout probability applied after the convolution.
    """

    def __init__(self, vocab_size, embedding_size, input_size, num_filters=50,
                 kernel_size=4, padding_idx=0, dropout = 0.1):
        super(Attentive_Pooling, self).__init__()

        self.embeds = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx)

        # Padding by kernel_size - 1 on both sides makes the convolution output
        # kernel_size - 1 positions longer than its input.
        padding = kernel_size - 1
        self.conv = nn.Conv1d(embedding_size, num_filters, kernel_size, padding=padding)
        self.dropout = nn.Dropout(dropout)

        # ``conv``/``dropout`` stay registered both directly and inside ``net``
        # so the state_dict keys are unchanged.
        self.net = nn.Sequential(self.conv, self.dropout)

        # Retained so code that accesses ``model.mp`` keeps working.  forward()
        # no longer uses it: a pooling window fixed at construction time only
        # works for inputs of exactly ``input_size`` tokens.
        width = input_size + 2*padding - (kernel_size-1)
        self.mp = nn.MaxPool1d(width, stride=1)

        # Bilinear matrix W of the alignment score; filled in by init_weights().
        self.weight = Parameter(torch.empty(num_filters, num_filters))

        self.out_f = nn.Linear(num_filters*2, 1, bias=False)

        self.init_weights()

    def init_weights(self):
        """Re-initialise the convolution kernel and the bilinear matrix."""
        init.normal_(self.conv.weight, mean=0.0, std=0.01)
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x1, x2):
        """Score a pair of index sequences.

        Args:
            x1: integer index tensor of shape ``(batch, len_1)``.
            x2: integer index tensor of shape ``(batch, len_2)``; ``len_2``
                does not have to equal ``len_1``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores.
        """
        # (batch, len, emb) -> (batch, emb, len) as expected by Conv1d.
        Q = self.embeds(x1).transpose(1, 2)
        A = self.embeds(x2).transpose(1, 2)
        Q = self.net(Q)
        A = self.net(A)

        # Alignment matrix of shape (batch, len_Q, len_A).
        G = torch.tanh(Q.transpose(1, 2) @ self.weight @ A)

        # Max over the *whole* opposite axis.  For inputs of ``input_size``
        # tokens this equals the former fixed-width MaxPool1d, and it also
        # handles any other (or unequal) sequence lengths.
        attention_Q = torch.softmax(G.max(dim=2, keepdim=True)[0], dim=1)
        attention_A = torch.softmax(G.max(dim=1)[0].unsqueeze(-1), dim=1)

        # Attention-weighted sums: (batch, filters, len) @ (batch, len, 1).
        r_Q = Q @ attention_Q
        r_A = A @ attention_A

        output = torch.cat((r_Q, r_A), 1).squeeze(-1)
        return torch.sigmoid(self.out_f(output))

class Attentive_Pooling_1d(nn.Module):
    """Attentive pooling over a stack of dilated 1-D convolution blocks.

    Both inputs share one embedding table and one encoder made of ``levels``
    ``ConvBlock_1d`` layers with dilation ``2 ** i``.  The first block reads
    ``embedding_size`` channels and the last block writes ``embedding_size``
    channels, so the encoder output can be added to the raw embeddings before
    pooling.  Attention weights come from the row/column maxima of the
    bilinear alignment matrix ``G = tanh(Q^T W A)``.

    Args:
        vocab_size: number of rows of the embedding table.
        embedding_size: embedding dimension; also the encoder output channels.
        input_size: nominal sequence length.  Only used to size ``self.mp``;
            ``forward`` itself accepts any sequence length.
        levels: number of stacked convolution blocks.
        num_filters: channel count between blocks.
        kernel_size: kernel width of every block.  The padding
            ``(kernel_size - 1) // 2 * dilation`` preserves sequence length
            only for odd values, which the residual add in ``forward`` needs.
        padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        dropout: dropout probability used in every block.
    """

    def __init__(self, vocab_size, embedding_size, input_size, levels = 3, num_filters=50,
                 kernel_size=3, padding_idx=0, dropout = 0.1):
        super(Attentive_Pooling_1d, self).__init__()

        layers = []
        for i in range(levels):
            dilation_size = 2 ** i
            # Channel sizes: embedding -> num_filters -> ... -> embedding.
            in_channels = embedding_size if i == 0 else num_filters
            out_channels = embedding_size if i == levels - 1 else num_filters
            layers.append(ConvBlock_1d(in_channels, out_channels, kernel_size, stride=1,
                                       dilation=dilation_size,
                                       padding=(kernel_size-1)//2 * dilation_size,
                                       dropout=dropout))

        self.net = nn.Sequential(*layers)

        self.embeds = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx)

        # Retained so code that accesses ``model.mp`` keeps working.  forward()
        # no longer uses it: a pooling window fixed at construction time only
        # works for inputs of exactly ``input_size`` tokens.
        self.mp = nn.MaxPool1d(input_size, stride=1)

        # Bilinear matrix W of the alignment score; filled in by init_weights().
        self.weight = Parameter(torch.empty(embedding_size, embedding_size))

        self.out_f = nn.Linear(embedding_size*2, 1, bias=False)

        self.init_weights()

    def init_weights(self):
        """Initialise the bilinear matrix."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x1, x2):
        """Score a pair of index sequences.

        Args:
            x1: integer index tensor of shape ``(batch, len_1)``.
            x2: integer index tensor of shape ``(batch, len_2)``; ``len_2``
                does not have to equal ``len_1``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores.
        """
        # (batch, len, emb) -> (batch, emb, len) as expected by Conv1d.
        Q_i = self.embeds(x1).transpose(1, 2)
        A_i = self.embeds(x2).transpose(1, 2)
        Q = self.net(Q_i)
        A = self.net(A_i)

        # Alignment matrix of shape (batch, len_Q, len_A).
        G = torch.tanh(Q.transpose(1, 2) @ self.weight @ A)

        # Max over the *whole* opposite axis.  For inputs of ``input_size``
        # tokens this equals the former fixed-width MaxPool1d, and it also
        # handles any other (or unequal) sequence lengths.
        attention_Q = torch.softmax(G.max(dim=2, keepdim=True)[0], dim=1)
        attention_A = torch.softmax(G.max(dim=1)[0].unsqueeze(-1), dim=1)

        # Residual connection: pool encoder output plus raw embeddings.
        r_Q = (Q + Q_i) @ attention_Q
        r_A = (A + A_i) @ attention_A

        output = torch.cat((r_Q, r_A), 1).squeeze(-1)
        return torch.sigmoid(self.out_f(output))

class ConvBlock_1d(nn.Module):
    """A single (optionally dilated) ``Conv1d`` followed by dropout.

    The convolution and the dropout layer are exposed as ``self.conv`` and
    ``self.dropout`` and are also chained in ``self.net``, which is what
    ``forward`` runs.
    """

    def __init__(self, c_input, c_output, kernel_size, stride, dilation, padding, dropout=0.1):
        super().__init__()

        conv_layer = nn.Conv1d(
            c_input,
            c_output,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        drop_layer = nn.Dropout(dropout)

        self.conv = conv_layer
        self.dropout = drop_layer
        self.net = nn.Sequential(conv_layer, drop_layer)

        self.init_weights()

    def init_weights(self):
        """Redraw the convolution kernel from N(0, 0.01); the bias is left as is."""
        init.normal_(self.conv.weight, mean=0, std=0.01)

    def forward(self, x):
        """Apply convolution then dropout to ``x``."""
        return self.net(x)
    
    
class Attentive_Pooling_2d(nn.Module):
    """Attentive pooling over single-channel 2-D convolution blocks.

    The embedded sequence is treated as a one-channel image of shape
    ``(len, emb)`` and passed through ``levels`` ``ConvBlock_2d`` layers
    (1 -> 1 channel) with dilation ``kernel_size ** i``.  The encoder output
    is added to the raw embeddings before pooling.  Attention weights come
    from the row/column maxima of the bilinear alignment matrix
    ``G = tanh(Q W A^T)``.

    Args:
        vocab_size: number of rows of the embedding table.
        embedding_size: embedding dimension.
        input_size: nominal sequence length.  Only used to size ``self.BN``
            and ``self.mp``; ``forward`` itself accepts any sequence length.
        levels: number of stacked convolution blocks.
        kernel_size: kernel height of every block.  The padding
            ``(kernel_size - 1) // 2 * dilation`` preserves sequence length
            only for odd values, which the residual add in ``forward`` needs.
        padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        dropout: dropout probability used in every block.
    """

    def __init__(self, vocab_size, embedding_size, input_size, levels = 3,
                 kernel_size=3, padding_idx=0, dropout = 0.1):
        super(Attentive_Pooling_2d, self).__init__()

        # Not used by forward(); kept so the module's state_dict (running
        # statistics buffers) stays loadable from existing checkpoints.
        self.BN = nn.BatchNorm1d(input_size, affine=False)

        layers = []
        for i in range(levels):
            dilation_size = kernel_size ** i
            layers.append(ConvBlock_2d(1, 1, kernel_size, stride=1,
                                       dilation=dilation_size,
                                       padding=(kernel_size-1)//2 * dilation_size,
                                       dropout=dropout))

        self.net = nn.Sequential(*layers)

        self.embeds = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx)

        # Retained so code that accesses ``model.mp`` keeps working.  forward()
        # no longer uses it: a pooling window fixed at construction time only
        # works for inputs of exactly ``input_size`` tokens.
        self.mp = nn.MaxPool1d(input_size, stride=1)

        # Bilinear matrix W of the alignment score; filled in by init_weights().
        self.weight = Parameter(torch.empty(embedding_size, embedding_size))

        self.out_f = nn.Linear(embedding_size*2, 1, bias=False)

        self.init_weights()

    def init_weights(self):
        """Initialise the bilinear matrix."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x1, x2):
        """Score a pair of index sequences.

        Args:
            x1: integer index tensor of shape ``(batch, len_1)``.
            x2: integer index tensor of shape ``(batch, len_2)``; ``len_2``
                does not have to equal ``len_1``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores.
        """
        # Add a channel axis: (batch, len, emb) -> (batch, 1, len, emb).
        Q_i = self.embeds(x1).unsqueeze(1)
        A_i = self.embeds(x2).unsqueeze(1)
        Q = self.net(Q_i).squeeze(1)
        A = self.net(A_i).squeeze(1)

        # Alignment matrix of shape (batch, len_Q, len_A).
        G = torch.tanh(Q @ self.weight @ A.transpose(1, 2))

        # Max over the *whole* opposite axis.  For inputs of ``input_size``
        # tokens this equals the former fixed-width MaxPool1d, and it also
        # handles any other (or unequal) sequence lengths.
        attention_Q = torch.softmax(G.max(dim=2, keepdim=True)[0], dim=1)
        attention_A = torch.softmax(G.max(dim=1)[0].unsqueeze(-1), dim=1)

        # Residual connection, then attention-weighted sum over positions:
        # (batch, emb, len) @ (batch, len, 1).
        Q_i = Q_i.squeeze(1)
        A_i = A_i.squeeze(1)
        r_Q = (Q_i + Q).transpose(1, 2) @ attention_Q
        r_A = (A_i + A).transpose(1, 2) @ attention_A

        output = torch.cat((r_Q, r_A), 1).squeeze(-1)
        return torch.sigmoid(self.out_f(output))

class ConvBlock_2d(nn.Module):
    """``Conv2d`` -> ReLU -> dropout, convolving along the first spatial axis only.

    The kernel is ``(kernel_size, 1)`` with dilation ``(dilation, 1)`` and
    padding ``(padding, 0)``, so the last input axis is never mixed.  The
    layers are exposed as ``self.conv``, ``self.relu`` and ``self.dropout``
    and are also chained in ``self.net``, which is what ``forward`` runs.
    """

    def __init__(self, c_input, c_output, kernel_size, stride, dilation, padding, dropout=0.1):
        super().__init__()

        conv_layer = nn.Conv2d(
            c_input,
            c_output,
            (kernel_size, 1),
            stride=stride,
            padding=(padding, 0),
            dilation=(dilation, 1),
        )
        activation = nn.ReLU()
        drop_layer = nn.Dropout(dropout)

        self.conv = conv_layer
        self.relu = activation
        self.dropout = drop_layer
        self.net = nn.Sequential(conv_layer, activation, drop_layer)

        self.init_weights()

    def init_weights(self):
        """Redraw the convolution kernel from N(0, 0.01); the bias is left as is."""
        init.normal_(self.conv.weight, mean=0, std=0.01)

    def forward(self, x):
        """Apply convolution, ReLU and dropout to ``x``."""
        return self.net(x)