In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import LSTM


In [5]:
class TPD(nn.Module):
    """Feature-attention model built on one LSTM per input feature.

    Every input feature (last axis of ``x``) is fed, as a 1-channel series, to
    its own single-layer ``nn.LSTM``. The LSTM output at the last time step of
    each feature is passed through two small MLP heads:

    * the *left* head scores each feature and normalises the scores with a
      softmax across features (the attention weights ``alpha``);
    * the *right* head maps each feature to a value in (0, 1) with a sigmoid.

    The returned ``prob`` is the attention-weighted sum of the right-head values.

    Args:
        input_dim: number of input features; one LSTM is created per feature.
        output_dim: not used by this module (kept for call compatibility).
        n_units: LSTM hidden size. The MLP heads narrow it to ``n_units // 2``,
            then ``n_units // 4``, then 1, so ``n_units`` should be >= 4 for
            every layer to have non-zero width.
        batch_first: input layout. ``False`` (default, same as ``nn.LSTM``)
            means ``x`` is ``(seq_len, batch, input_dim)``; ``True`` means
            ``(batch, seq_len, input_dim)``.
    """

    __constants__ = ["n_units", "input_dim", "batch_first"]

    def __init__(self, input_dim, output_dim, n_units, batch_first=False):
        super().__init__()

        self.n_units = n_units
        self.input_dim = input_dim
        self.batch_first = batch_first

        half, quarter = n_units // 2, n_units // 4

        # Modules are created in the original order so that, under a fixed
        # torch seed, parameter initialisation is unchanged.
        self.left_fc_1 = nn.Linear(n_units, half)
        self.left_fc_2 = nn.Linear(half, quarter)
        self.left_fc_3 = nn.Linear(quarter, 1)

        self.right_fc_1 = nn.Linear(n_units, half)
        self.right_fc_2 = nn.Linear(half, quarter)
        self.right_fc_3 = nn.Linear(quarter, 1)

        self.layers = nn.ModuleList(
            [
                nn.LSTM(
                    input_size=1,
                    hidden_size=n_units,
                    num_layers=1,
                    batch_first=batch_first,
                )
                for _ in range(input_dim)
            ]
        )

        self.sigmoid = nn.Sigmoid()
        # dim=1 is the feature axis of the stacked (batch, input_dim, 1) scores.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Compute the attention-weighted probability and the attention weights.

        Args:
            x: 3-d tensor, ``(seq_len, batch, input_dim)`` or, when
                ``batch_first=True``, ``(batch, seq_len, input_dim)``.

        Returns:
            ``(prob, alpha)``. Before squeezing, ``prob`` is ``(batch, 1)`` and
            ``alpha`` is ``(batch, input_dim, 1)``; all size-1 dimensions are
            removed with ``torch.squeeze`` (so with ``batch == 1`` ``prob`` is
            0-d and ``alpha`` is ``(input_dim,)``).

        Raises:
            ValueError: if ``x`` is not 3-d or its last dimension is not
                ``input_dim``.
        """
        if x.dim() != 3 or x.size(-1) != self.input_dim:
            raise ValueError(
                f"expected a 3-d input whose last dimension is {self.input_dim}, "
                f"got shape {tuple(x.shape)}"
            )

        # Run each feature through its own LSTM and keep only the last time
        # step, so the sequence axis is reduced before features are combined.
        last_steps = []
        for i, lstm in enumerate(self.layers):
            out, _ = lstm(x[..., i:i + 1])
            last_steps.append(out[:, -1] if self.batch_first else out[-1])

        # (batch, input_dim, n_units): features live on dim 1.
        features = torch.stack(last_steps, dim=1)

        left = torch.relu(self.left_fc_1(features))
        left = torch.tanh(self.left_fc_2(left))
        left = self.softmax(self.left_fc_3(left))

        right = torch.relu(self.right_fc_1(features))
        right = torch.relu(self.right_fc_2(right))
        right = self.sigmoid(self.right_fc_3(right))

        prob = torch.sum(left * right, dim=1)

        return torch.squeeze(prob), torch.squeeze(left)

In [6]:
# Smoke test: run an untrained model on one random sequence.
torch.manual_seed(0)  # reproducible random weights and input

N = 11   # number of input features (one LSTM per feature)
M = 512  # LSTM hidden size (n_units)
model = TPD(input_dim=N, output_dim=1, n_units=M)

t = 100  # sequence length
# nn.LSTM defaults to (seq_len, batch, features): one sequence of t steps.
x = torch.randn((t, 1, N))
p, alpha = model(x)

# p is a softmax-weighted sum of sigmoids, so it lies in [0, 1];
# the attention weights alpha sum to 1.
assert 0.0 <= float(p.min()) and float(p.max()) <= 1.0
assert bool(torch.isclose(alpha.sum(), torch.tensor(1.0)))