In [None]:
import os
import numpy as np
from time import time

from sklearn.metrics import roc_auc_score

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

# Additive bias for padded positions ahead of a softmax: a huge negative
# number (-(2**32) + 1, as a float) so those positions get ~0 weight.
mask_num = 1.0 - 2 ** 32
class ESIM(nn.Module):
    """
    Implementation of the ESIM model for scoring a pair of token sequences.

    Pipeline (see ``forward``): embed both sequences, batch-normalise the
    embeddings, encode them with a shared BiLSTM, softly align each sequence
    against the other, concatenate [h; h_align; h - h_align; h * h_align],
    re-encode with a second shared BiLSTM, average/max pool each side, and
    feed the concatenation to an MLP that outputs one unnormalised value per
    pair (no sigmoid is applied here).

    NOTE(review): hyper-parameters are read from a ``Parameters`` object that
    is not defined in this section; it must be in scope when ESIM is built.
    """

    def __init__(self, vocab_size, embedding_matrix):
        """
        Build the layers and copy in pre-trained embeddings.

        Args:
            vocab_size: number of real tokens; the embedding table gets one
                extra row because id 0 is reserved for padding.
            embedding_matrix: numpy array copied into the embedding weights;
                presumably shaped (vocab_size + 1, Parameters.EMBEDDING_DIM)
                -- confirm against the caller.
        """
        super(ESIM, self).__init__()
        self.verbose = True
        self.use_cuda = True
        self.vocab_size = vocab_size
        self.embedding_dim = Parameters.EMBEDDING_DIM
        self.hidden_size = Parameters.HIDDEN_SIZE
        self.linear_size = Parameters.DENSE_SIZE
        self.drop_out = Parameters.DROPOUT_RATE
        # Training settings: not read anywhere in this class, kept as
        # attributes for outside training/evaluation code.
        self.learning_rate = 0.0001
        self.optimizer = "adam"
        self.eval_metric = roc_auc_score
        self.n_epochs = 1
        self.batch_size = Parameters.BATCH_SIZE

        # Id 0 is padding (padding_idx=0; forward() masks q.eq(0) / t.eq(0)).
        # NOTE(review): the copy below overwrites row 0 with
        # embedding_matrix[0], so padding embeddings are only zero if the
        # supplied matrix has a zero first row.
        self.embed = nn.Embedding(self.vocab_size + 1, self.embedding_dim, padding_idx=0)
        with torch.no_grad():
            self.embed.weight.copy_(torch.from_numpy(embedding_matrix))

        self.bat_nor_embed = nn.BatchNorm1d(self.embedding_dim)

        # A bidirectional LSTM emits forward + backward states per step.
        bilstm_width = 2 * self.hidden_size
        self.lstm1 = nn.LSTM(self.embedding_dim, self.hidden_size, batch_first=True, bidirectional=True)
        # lstm2 consumes [h; h_align; h - h_align; h * h_align], four pieces
        # each bilstm_width wide (8 * hidden_size in total). The previous
        # embedding_dim * 8 only worked when embedding_dim == hidden_size.
        self.lstm2 = nn.LSTM(4 * bilstm_width, self.hidden_size, batch_first=True, bidirectional=True)

        # Each side is pooled to [avg; max] (2 * bilstm_width) and the two
        # sides are concatenated -> 4 * bilstm_width == 8 * hidden_size.
        merged_width = 4 * bilstm_width
        self.fc = nn.Sequential(
            nn.BatchNorm1d(merged_width),
            nn.Linear(merged_width, self.linear_size),
            nn.ReLU(),
            nn.BatchNorm1d(self.linear_size),
            nn.Dropout(self.drop_out),
            nn.Linear(self.linear_size, 1),
        )

    def sub_mul_work(self, d1, d2):
        """Return [d1 - d2; d1 * d2] concatenated along the last dimension."""
        sub = d1 - d2
        mul = d1 * d2
        return torch.cat([sub, mul], dim=-1)

    def pool_work(self, d):
        """
        Mean- and max-pool ``d`` over dim 1 and concatenate the results.

        Dim 1 is the sequence axis for the batch_first LSTM outputs this is
        applied to, so a (batch, seq, feat) input yields (batch, 2 * feat).
        NOTE(review): padding positions are included in both pools.
        """
        channels_first = d.transpose(1, 2)
        p1 = F.avg_pool1d(channels_first, d.size(1)).squeeze(-1)
        p2 = F.max_pool1d(channels_first, d.size(1)).squeeze(-1)
        return torch.cat([p1, p2], 1)

    def soft_attention(self, d1, d2, mask1, mask2):
        """
        Softly align each sequence against the other.

        Args:
            d1, d2: encoded sequences; assumed (batch, len1, dim) and
                (batch, len2, dim) given the batched matmuls below.
            mask1, mask2: boolean tensors, True at padding positions of d1 / d2;
                padded positions receive ~0 attention weight.

        Returns:
            (d1_align, d2_align): d1_align is an attention-weighted sum of d2
            rows for every position of d1, and d2_align the reverse.
        """
        attention = torch.matmul(d1, d2.transpose(1, 2))
        # Convert boolean masks into additive biases: 0.0 for real tokens,
        # mask_num (huge negative) for padding, so softmax ignores padding.
        bias1 = mask1.float().masked_fill(mask1, mask_num)
        bias2 = mask2.float().masked_fill(mask2, mask_num)

        atten_1 = F.softmax(attention + bias2.unsqueeze(1), dim=-1)
        d1_align = torch.matmul(atten_1, d2)
        atten_2 = F.softmax(attention.transpose(1, 2) + bias1.unsqueeze(1), dim=-1)
        d2_align = torch.matmul(atten_2, d1)

        return d1_align, d2_align

    def _embed_and_normalize(self, tokens):
        """
        Embed token ids and batch-normalise each embedding channel.

        BatchNorm1d normalises dim 1, so the embeddings (assumed
        (batch, seq_len, emb) for (batch, seq_len) ids) are transposed to put
        the embedding channels on dim 1 and transposed back afterwards.
        """
        embedded = self.embed(tokens).transpose(1, 2).contiguous()
        return self.bat_nor_embed(embedded).transpose(1, 2)

    def forward(self, *inputs):
        """
        Score a pair of token-id batches.

        Args:
            *inputs: inputs[0] and inputs[1] are the two token-id batches
                (any further positional arguments are ignored). Presumably
                each is (batch, seq_len) of integer ids -- confirm. Id 0 marks
                padding.

        Returns:
            The output of ``self.fc``: one unnormalised value per pair,
            shape (batch, 1).
        """
        q, t = inputs[0], inputs[1]
        mask1, mask2 = q.eq(0), t.eq(0)

        q = self._embed_and_normalize(q)
        t = self._embed_and_normalize(t)

        # Shared input encoder.
        q_lstm, _ = self.lstm1(q)
        t_lstm, _ = self.lstm1(t)

        q_attn, t_attn = self.soft_attention(q_lstm, t_lstm, mask1, mask2)

        # Local inference enhancement: [h; h_align; h - h_align; h * h_align].
        q_enhanced = torch.cat([q_lstm, q_attn, self.sub_mul_work(q_lstm, q_attn)], -1)
        t_enhanced = torch.cat([t_lstm, t_attn, self.sub_mul_work(t_lstm, t_attn)], -1)

        # Shared composition encoder.
        v_q, _ = self.lstm2(q_enhanced)
        v_t, _ = self.lstm2(t_enhanced)

        merged_q = self.pool_work(v_q)
        merged_t = self.pool_work(v_t)

        merged = torch.cat([merged_q, merged_t], -1)
        return self.fc(merged)