In [None]:
# A variant of the attention-based convolutional neural network (ABCNN):
# Yin, Wenpeng, et al. "ABCNN: Attention-Based Convolutional Neural Network for Modeling Sentence Pairs." Transactions of the Association for Computational Linguistics, 2016.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn.utils import weight_norm

class ABCNN(nn.Module):
    """Attention-based CNN that scores a pair of token-id sequences.

    Both sequences are embedded with one shared embedding table. An attention
    matrix ``1 / (1 + L1 distance)`` between every pair of positions is
    projected by ``w0`` / ``w1`` into attention feature maps, which go through
    the same convolution stack as the embeddings and are summed with them
    before a ``tanh``. A second attention matrix, computed on the convolved
    features, supplies per-position weights for pooling each sequence into a
    single vector. The two pooled vectors are concatenated and mapped to one
    sigmoid score.

    Args:
        vocab_size: Number of rows of the embedding table.
        embedding_size: Width of each embedding vector; also the number of
            input channels of the first convolution level.
        input_size: Number of columns of ``w0`` / ``w1``. Because these
            matrices multiply the attention matrix directly, both input
            sequences must have exactly this length.
        num_filters: Output channels of every convolution level.
        levels: Number of stacked ``ABCNN_Block`` levels; level ``i`` uses
            dilation ``2 ** i``.
        kernel_size: Convolution kernel width. The padding used,
            ``(kernel_size - 1) // 2 * dilation``, keeps the sequence length
            unchanged for odd kernel sizes.
        padding_idx: Passed to ``nn.Embedding`` as ``padding_idx``.
        dropout: Dropout probability used inside every block.
    """

    def __init__(self, vocab_size, embedding_size, input_size, num_filters=50, levels=1,
                 kernel_size=3, padding_idx=0, dropout=0.2):
        super(ABCNN, self).__init__()

        self.embeds = nn.Embedding(vocab_size, embedding_size, padding_idx=padding_idx)

        layers = []
        for i in range(levels):
            dilation_size = 2 ** i
            # Only the first level sees raw embeddings; later levels consume
            # the previous level's filters.
            in_channels = embedding_size if i == 0 else num_filters
            layers.append(
                ABCNN_Block(in_channels, num_filters, kernel_size, stride=1,
                            dilation=dilation_size,
                            padding=(kernel_size - 1) // 2 * dilation_size,
                            dropout=dropout,
                            dim=1)
            )
        self.net = nn.Sequential(*layers)

        # Projections turning the attention matrix into attention feature
        # maps with `embedding_size` channels; filled in by init_weights().
        self.w0 = Parameter(torch.empty(embedding_size, input_size))
        self.w1 = Parameter(torch.empty(embedding_size, input_size))

        # Scores the concatenation of the two pooled sequence vectors.
        self.out_f = nn.Linear(num_filters * 2, 1, bias=False)

        self.init_weights()

    def init_weights(self):
        """Initialise ``w0`` and ``w1`` with Kaiming-uniform (``a = sqrt(5)``)."""
        init.kaiming_uniform_(self.w0, a=math.sqrt(5))
        init.kaiming_uniform_(self.w1, a=math.sqrt(5))

    def pairdist(self, Q, A):
        """Return the pairwise L1 distance between the columns of Q and A.

        Args:
            Q: Tensor of shape ``(batch, channels, len_q)``.
            A: Tensor of shape ``(batch, channels, len_a)``.

        Returns:
            Tensor of shape ``(batch, len_q, len_a)`` whose entry ``[b, i, j]``
            is ``sum_c |Q[b, c, i] - A[b, c, j]|``.
        """
        # (B, C, len_q, 1) - (B, C, 1, len_a) broadcasts to
        # (B, C, len_q, len_a); no explicit expand() is needed and the two
        # lengths are free to differ.
        return (Q.unsqueeze(-1) - A.unsqueeze(-2)).abs().sum(1)

    def forward(self, x1, x2):
        """Score a batch of sequence pairs.

        Args:
            x1: Integer token ids of shape ``(batch, input_size)``.
            x2: Integer token ids of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores.
        """
        # (batch, length, embedding) -> (batch, embedding, length) for Conv1d.
        Q = self.embeds(x1).transpose(1, 2)
        A = self.embeds(x2).transpose(1, 2)

        # Attention on the raw embeddings: identical positions score 1,
        # distant ones approach 0.
        attention_before = 1 / (1 + self.pairdist(Q, A))

        # Attention feature maps, one per sequence, with embedding_size channels.
        F_Q = self.w0 @ attention_before.transpose(1, 2)
        F_A = self.w1 @ attention_before

        # The same convolution stack processes the embeddings and the
        # attention feature maps; their outputs are summed.
        Q = torch.tanh(self.net(Q) + self.net(F_Q))
        A = torch.tanh(self.net(A) + self.net(F_A))

        # Attention on the convolved features drives the pooling weights.
        attention_after = 1 / (1 + self.pairdist(Q, A))

        # Row sums weight the positions of Q, column sums those of A.
        attention_Q = attention_after.sum(2).unsqueeze(-1)
        attention_A = attention_after.sum(1).unsqueeze(-1)

        # Attention-weighted sum over positions: (batch, num_filters, 1).
        r_Q = Q @ attention_Q
        r_A = A @ attention_A

        output = torch.cat((r_Q, r_A), 1).squeeze(-1)
        return torch.sigmoid(self.out_f(output))
    
class ABCNN_Block(nn.Module):
    """One convolution followed by dropout.

    Args:
        c_input: Number of input channels.
        c_output: Number of output channels.
        kernel_size: Kernel extent along the convolved axis.
        stride: Convolution stride.
        dilation: Dilation along the convolved axis.
        padding: Zero padding along the convolved axis.
        dim: ``1`` builds an ``nn.Conv1d``; ``2`` builds an ``nn.Conv2d``
            whose kernel is ``(kernel_size, 1)``, so dilation and padding
            apply to the first spatial axis only.
        dropout: Dropout probability applied to the convolution output.

    Raises:
        ValueError: If ``dim`` is neither 1 nor 2.
    """

    def __init__(self, c_input, c_output, kernel_size, stride, dilation, padding, dim=1, dropout=0.1):
        super(ABCNN_Block, self).__init__()

        if dim == 1:
            self.conv = nn.Conv1d(c_input, c_output, kernel_size,
                                  stride=stride, dilation=dilation, padding=padding)
        elif dim == 2:
            self.conv = nn.Conv2d(c_input, c_output, (kernel_size, 1), stride=stride,
                                  dilation=(dilation, 1), padding=(padding, 0))
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing `self.conv` a few lines below.
            raise ValueError("dim must be 1 or 2, got {!r}".format(dim))
        self.dropout = nn.Dropout(dropout)
        # The conv is reachable both as `self.conv` and as `self.net[0]`;
        # both registrations are kept so state_dict keys stay the same.
        self.net = nn.Sequential(self.conv, self.dropout)
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01); the bias is left as created."""
        nn.init.normal_(self.conv.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """Apply the convolution and then dropout to ``x``."""
        return self.net(x)
