# In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# qEncoder constructor arguments: word_dim, n_head, n_hid, dropout, nlayers
class customizedModule(nn.Module):
    """Base ``nn.Module`` offering a factory for initialized linear blocks."""

    def __init__(self):
        super(customizedModule, self).__init__()

    # linear transformation (w/ initialization) + activation + dropout
    def customizedLinear(self, in_dim, out_dim, activation=None, dropout=False):
        """Build ``Linear -> [activation] -> [Dropout]`` as an ``nn.Sequential``.

        The linear weight is Xavier-uniform initialized and the bias is set
        to zero.

        :param in_dim: input feature size of the linear layer
        :param out_dim: output feature size of the linear layer
        :param activation: optional module appended after the linear layer
        :param dropout: falsy for no dropout; a ``float`` is used directly as
            the drop probability; any other truthy value (e.g. ``True``)
            takes the probability from ``self.args.dropout``
        :return: the assembled ``nn.Sequential``
        :raises AttributeError: if dropout is requested via a non-float
            truthy value but ``self.args.dropout`` is not defined
        """
        linear = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(linear.weight)
        nn.init.constant_(linear.bias, 0)
        cl = nn.Sequential(linear)

        if activation is not None:
            cl.add_module(str(len(cl)), activation)

        if dropout:
            if isinstance(dropout, float):
                # An explicit rate needs no configuration object.
                rate = dropout
            else:
                # This base class never sets ``self.args``; a subclass must
                # provide it before asking for config-driven dropout.
                args = getattr(self, "args", None)
                if not hasattr(args, "dropout"):
                    raise AttributeError(
                        "customizedLinear(dropout=True) requires "
                        "self.args.dropout; set self.args or pass the "
                        "dropout probability as a float"
                    )
                rate = args.dropout
            cl.add_module(str(len(cl)), nn.Dropout(p=rate))

        return cl

class Q_S2T(customizedModule):
    """Source2token self-attention: pools a sequence into a single vector."""

    def __init__(self, hidden_size):
        super().__init__()

        # Linear + ReLU, then a plain linear layer that yields the scores.
        self.s2t_W1 = self.customizedLinear(
            hidden_size, hidden_size, activation=nn.ReLU()
        )
        self.s2t_W = self.customizedLinear(hidden_size, hidden_size)

    def forward(self, x):
        """
        Apply source2token self-attention.

        :param x: (batch, seq_len, hidden_size)
        :return: (batch, hidden_size), the attention-weighted sum of ``x``
            over its second-to-last axis
        """
        # (batch, (block_num), seq_len, word_dim)
        hidden = self.s2t_W1(x)
        scores = self.s2t_W(hidden)
        # Normalize over the second-to-last axis, separately per feature.
        weights = torch.softmax(scores, dim=-2)
        # (batch, (block_num), word_dim)
        return (weights * x).sum(dim=-2)


class qEncoder(nn.Module):
    """Transformer encoder followed by source2token pooling."""

    def __init__(self, word_dim, n_head, n_hid, dropout, nlayers):
        """
        :param word_dim: feature size of the input (transformer ``d_model``)
        :param n_head: number of attention heads per encoder layer
        :param n_hid: feed-forward hidden size per encoder layer
        :param dropout: dropout probability inside the encoder layers
        :param nlayers: number of stacked encoder layers
        """
        super(qEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(word_dim, n_head, n_hid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, nlayers)
        self.s2t = Q_S2T(word_dim)

    def forward(self, x):
        """
        Encode a batch of sequences and pool each one into a vector.

        :param x: (batch, sequence, word_dim)
        :return: the output of ``self.s2t`` applied to the encoded sequence
        """
        # The encoder layers are built with the default sequence-first
        # layout, while this module takes batch-first input. Swap the first
        # two axes around the encoder so attention runs over the sequence
        # axis rather than across different samples in the batch.
        batched = x.dim() == 3
        if batched:
            x = x.transpose(0, 1)
        x = self.transformer_encoder(x, None)
        if batched:
            x = x.transpose(0, 1)
        return self.s2t(x)