# In [2]:
import numpy as np
import torch
import torch.nn as nn

class basket_MGU(nn.Module):
    """
    Class for the basket MGU without covariates

    Recurrent gated model over sequences of baskets. Each time step consumes one
    input vector of size `input_dim` (the assortment size), updates a gated
    hidden state of size `num_hidden`, and emits one sigmoid score per item,
    presumably the predicted probability that the item appears in the next
    basket.

    NOTE(review): the cell uses separate update and reset gates plus a candidate
    state, which matches the GRU formulation; the original Minimal Gated Unit
    shares a single forget gate for both roles. Confirm which is intended.
    """

    def __init__(self, seq_length, input_dim, num_hidden):
        """
        Initializes a basket_MGU instance

        Arguments:
          seq_length    number of time steps in each (padded) sequence, scalar
                        (stored only; forward() uses the length of its input x)
          input_dim     dimensionality of the MGU input (equals the assortment size here)
          num_hidden    dimensionality of the hidden MGU state, scalar hyperparameter
        """
        super(basket_MGU, self).__init__()
        self.seq_length = seq_length
        self.input_dim = input_dim
        self.num_hidden = num_hidden

        # Initial hidden state. It is a plain tensor (not a parameter or buffer),
        # so Module.to() does not move it; forward() moves it to x's device.
        self.h_init = torch.zeros(num_hidden)

        # Update gate parameters
        self.W_ux = nn.Parameter(torch.Tensor(input_dim, num_hidden))
        nn.init.xavier_uniform_(self.W_ux.data)
        self.W_uh = nn.Parameter(torch.Tensor(num_hidden, num_hidden))
        nn.init.xavier_uniform_(self.W_uh.data)
        self.b_u = nn.Parameter(torch.zeros(num_hidden))

        # Reset gate parameters
        self.W_rx = nn.Parameter(torch.Tensor(input_dim, num_hidden))
        nn.init.xavier_uniform_(self.W_rx.data)
        self.W_rh = nn.Parameter(torch.Tensor(num_hidden, num_hidden))
        nn.init.xavier_uniform_(self.W_rh.data)
        self.b_r = nn.Parameter(torch.zeros(num_hidden))

        # Candidate state parameters
        self.W_cx = nn.Parameter(torch.Tensor(input_dim, num_hidden))
        nn.init.xavier_uniform_(self.W_cx.data)
        self.W_ch = nn.Parameter(torch.Tensor(num_hidden, num_hidden))
        nn.init.xavier_uniform_(self.W_ch.data)
        self.b_c = nn.Parameter(torch.zeros(num_hidden))

        # Output layer parameters (hidden state -> one score per item)
        self.W_ph = nn.Parameter(torch.Tensor(num_hidden, input_dim))
        nn.init.xavier_uniform_(self.W_ph.data)
        self.b_p = nn.Parameter(torch.zeros(input_dim))

    @staticmethod
    def _dropout_rate(value, name):
        """Convert a dropout argument (False/None/0 or a rate) into a float rate in [0, 1)."""
        if not value:
            return 0.0
        rate = float(value)
        if not 0.0 <= rate < 1.0:
            # A rate of 1 would divide by zero in the inverted-dropout rescaling.
            raise ValueError(f"{name} must be False or a rate in [0, 1), got {value!r}")
        return rate

    @staticmethod
    def _dropout_mask(batch_size, dim, rate, device):
        """Return a [batch_size, dim] mask on `device` with entries 0 (prob. `rate`) or 1 / (1 - rate)."""
        keep_prob = torch.ones(batch_size, dim, device=device) - rate
        return torch.bernoulli(keep_prob) / (1 - rate)

    def forward(self, x, fw_dropout=False, rc_dropout=False, stepwise=False, track_hiddens=False):
        """
        Executes a forward step of the basket MGU model, for a given input x

        Arguments:
          x               batch of input sequences ([batch size] x [sequence length] x [assortment size])
          fw_dropout      dropout rate applied to the input x at every time step
                          (either False (0, default) or a scalar value in [0, 1))
          rc_dropout      dropout rate applied to the candidate hidden state before the update
                          (either False (0, default) or a scalar value in [0, 1))
          stepwise        whether to drop out different nodes in each time step (True) or the same nodes in every time step (False, default)
          track_hiddens   whether to return the hidden states computed in each time step (default is False)

        Returns:
          pred_sequence   sigmoid item scores ([batch size] x [sequence length] x [input_dim]) on x's
                          device; pred_sequence[:, t] is computed from the hidden state after x[:, t]
          hiddens         only if track_hiddens: hidden states ([batch size] x [sequence length] x [num_hidden])

        Raises:
          ValueError      if a dropout rate is neither False/0 nor a value in [0, 1)
        """
        fw_rate = self._dropout_rate(fw_dropout, "fw_dropout")
        rc_rate = self._dropout_rate(rc_dropout, "rc_dropout")

        batch_size, n_steps = x.size(0), x.size(1)
        # Allocate every intermediate on the input's device so the model works on CPU and GPU alike.
        device = x.device

        hidden = self.h_init.to(device)  # 1-D; broadcasts over the batch on the first step
        pred_sequence = torch.zeros(batch_size, n_steps, self.input_dim, device=device)
        if track_hiddens:
            hiddens = torch.zeros(batch_size, n_steps, self.num_hidden, device=device)

        if not stepwise:
            # The same masks are reused at every time step.
            fw_dropout_mask = self._dropout_mask(batch_size, self.input_dim, fw_rate, device)
            rc_dropout_mask = self._dropout_mask(batch_size, self.num_hidden, rc_rate, device)

        for t in range(n_steps):
            if stepwise:
                # Fresh masks are drawn at every time step.
                fw_dropout_mask = self._dropout_mask(batch_size, self.input_dim, fw_rate, device)
                rc_dropout_mask = self._dropout_mask(batch_size, self.num_hidden, rc_rate, device)

            xt = x[:, t, :]  # [batch size] x [input_dim]

            if fw_rate > 0:
                xt = xt * fw_dropout_mask

            update_gate = torch.sigmoid(
                torch.matmul(xt, self.W_ux) + torch.matmul(hidden, self.W_uh) + self.b_u)
            reset_gate = torch.sigmoid(
                torch.matmul(xt, self.W_rx) + torch.matmul(hidden, self.W_rh) + self.b_r)
            candidate_gate = torch.tanh(
                torch.matmul(xt, self.W_cx) + torch.matmul(reset_gate * hidden, self.W_ch) + self.b_c)

            if rc_rate > 0:
                # Dropout acts on the candidate state only, not on the carried-over hidden state.
                candidate_gate = candidate_gate * rc_dropout_mask

            # Interpolate between the previous hidden state and the candidate.
            hidden = (1 - update_gate) * hidden + update_gate * candidate_gate

            if track_hiddens:
                hiddens[:, t, :] = hidden

            next_basket = torch.sigmoid(torch.matmul(hidden, self.W_ph) + self.b_p)
            pred_sequence[:, t, :] = next_basket

        if track_hiddens:
            return pred_sequence, hiddens
        return pred_sequence
