In [None]:
'''
adapted from:
https://github.com/nlhkh/dropout-in-rnn
'''

from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from Models_12_gh import ConcreteDropout
import numpy as np


class StochasticLSTMCell(nn.Module):
    '''
    Single LSTM layer with (concrete) dropout applied to inputs and hidden states.

    One relaxed dropout mask per gate is sampled once per forward pass and then
    reused at every time step of the sequence. All tensors created internally
    are placed on the device of the layer's own weights, so the module works on
    CPU as well as on GPU (move it with the usual `.to(device)` / `.cuda()`).
    '''

    def __init__(self, input_size: int, hidden_size: int, p_fix=0.01, concrete=True,
                 weight_regularizer=.1, dropout_regularizer=.1, Bayes=True):
        '''
        ARGUMENTS:
        input_size: number of features (after embedding layer)
        hidden_size: number of nodes in LSTM layers
        p_fix: dropout parameter used in case of not self.concrete (NaN falls back to 0.5)
        concrete: dropout parameter is fixed when "False". If "True", then concrete dropout
        weight_regularizer: parameter for weight regularization in reformulated ELBO
        dropout_regularizer: parameter for dropout regularization in reformulated ELBO
        Bayes: BNN if "True", deterministic model if "False" (only sampled once for inference)
        '''

        super(StochasticLSTMCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.concrete = concrete
        self.wr = weight_regularizer
        self.dr = dropout_regularizer
        self.Bayes = Bayes

        if concrete:
            # Learnable logit; the dropout probability is sigmoid(p_logit).
            self.p_logit = nn.Parameter(torch.empty(1).normal_())
        else:
            if np.isnan(p_fix):
                p_fix = .5
            # Despite the name, in the fixed case this tensor holds the dropout
            # probability itself (no sigmoid is applied when it is read).
            self.p_logit = torch.full([1], p_fix)

        # Input-to-hidden projections, one per gate (input, forget, output, cell).
        self.Wi = nn.Linear(self.input_size, self.hidden_size)
        self.Wf = nn.Linear(self.input_size, self.hidden_size)
        self.Wo = nn.Linear(self.input_size, self.hidden_size)
        self.Wg = nn.Linear(self.input_size, self.hidden_size)

        # Hidden-to-hidden projections, one per gate.
        self.Ui = nn.Linear(self.hidden_size, self.hidden_size)
        self.Uf = nn.Linear(self.hidden_size, self.hidden_size)
        self.Uo = nn.Linear(self.hidden_size, self.hidden_size)
        self.Ug = nn.Linear(self.hidden_size, self.hidden_size)

        self.init_weights()

    def init_weights(self):
        '''
        Re-initialise all gate weights and biases from U(-k, k) with k = 1 / sqrt(hidden_size).

        The parameters are filled in place and stay on whatever device they
        already live on.
        '''
        k = torch.tensor(self.hidden_size, dtype=torch.float32).reciprocal().sqrt().item()

        gate_layers = (self.Wi, self.Wf, self.Wo, self.Wg,
                       self.Ui, self.Uf, self.Uo, self.Ug)
        with torch.no_grad():
            for layer in gate_layers:
                layer.weight.uniform_(-k, k)
                layer.bias.uniform_(-k, k)

    def _module_device(self):
        '''Device on which the layer's weights (and hence its computations) live.'''
        return self.Wi.weight.device

    def _dropout_probability(self):
        '''
        OUTPUTS:
        p: dropout probability as a tensor of shape (1,) on the module's device.
           sigmoid(p_logit) for concrete dropout, the stored fixed value otherwise.
        '''
        if self.concrete:
            return torch.sigmoid(self.p_logit)
        return self.p_logit.to(self._module_device())

    def _sample_mask(self, batch_size, stop_dropout):
        '''
        ARGUMENTS:
        batch_size: batch size
        stop_dropout: if "True" prevents dropout during inference for deterministic models

        OUTPUTS:
        zx: dropout masks for inputs. Tensor (GATES x batch_size x input size (after embedding))
        zh: dropout masks for hiddens states. Tensor (GATES x batch_size x number hidden states)
        '''

        GATES = 4
        device = self._module_device()

        if stop_dropout:
            # All-ones masks leave inputs and hidden states untouched.
            zx = torch.ones(GATES, batch_size, self.input_size, device=device)
            zh = torch.ones(GATES, batch_size, self.hidden_size, device=device)
            return zx, zh

        p = self._dropout_probability()
        eps = torch.tensor(1e-7, device=device)
        t = 1e-1  # temperature of the relaxed Bernoulli distribution

        # Drawn from the default (CPU) generator and then moved, so seeded runs
        # consume the same random stream regardless of the target device.
        ux = torch.rand(GATES, batch_size, self.input_size).to(device)
        uh = torch.rand(GATES, batch_size, self.hidden_size).to(device)

        def relaxed_drop(logit_p, u):
            # Relaxed "drop" indicator: sigmoid((logit(p) + logit(u)) / t).
            return torch.sigmoid((logit_p + torch.log(u + eps) - torch.log(1 - u + eps)) / t)

        logit_p = torch.log(p + eps) - torch.log(1 - p + eps)

        if self.input_size == 1:
            # A single input feature uses a drop probability of (almost) zero
            # and no rescaling, instead of the configured/learned p.
            zx = 1 - relaxed_drop(torch.log(eps) - torch.log(1 + eps), ux)
        else:
            zx = (1 - relaxed_drop(logit_p, ux)) / (1 - p)
        zh = (1 - relaxed_drop(logit_p, uh)) / (1 - p)

        return zx, zh

    def regularizer(self):
        '''
        OUTPUTS (when self.Bayes, as a 3-tuple):
        self.wr * weight_sum: weight regularization in reformulated ELBO
        self.wr * bias_sum: bias regularization in reformulated ELBO
        self.dr * dropout_reg: dropout regularization in reformulated ELBO

        When not self.Bayes a single zero tensor of shape (1,) is returned.
        '''

        device = self._module_device()

        if not self.Bayes:
            return torch.zeros(1, device=device)

        p = self._dropout_probability()

        # torch.stack keeps the squared norms attached to the autograd graph so
        # the penalty actually produces gradients for the weights and biases.
        weight_sq = [torch.sum(params ** 2)
                     for name, params in self.named_parameters() if name.endswith("weight")]
        bias_sq = [torch.sum(params ** 2)
                   for name, params in self.named_parameters() if name.endswith("bias")]

        weight_sum = torch.stack(weight_sq).sum() / (1. - p)
        bias_sum = torch.stack(bias_sq).sum()

        if self.concrete:
            # Negative entropy of the Bernoulli dropout distribution, scaled by input size.
            dropout_reg = self.input_size * (p * torch.log(p) + (1 - p) * torch.log(1 - p))
        else:
            dropout_reg = torch.zeros(1, device=device)

        return self.wr * weight_sum, self.wr * bias_sum, self.dr * dropout_reg

    def forward(self, input: Tensor, stop_dropout=False) -> Tuple[
        Tensor, Tuple[Tensor, Tensor]]:
        '''
        ARGUMENTS:
        input: Tensor (sequence length x batch size x input size(after embedding) )
        stop_dropout: if "True" prevents dropout during inference for deterministic models

        OUTPUTS:
        hn: tensor of hidden states h_t. Dimension (sequence_length x batch_size x hidden size)
        h_t: hidden states at time t. Dimension (batch size x hidden size (number of nodes in LSTM layer)
        c_t: cell states. Dimension (batch size x hidden size (number of nodes in LSTM layer)
        '''

        seq_len, batch_size = input.shape[0:2]

        # Initial hidden and cell states are zero, on the same device as the input.
        h_t = torch.zeros(batch_size, self.hidden_size, dtype=input.dtype, device=input.device)
        c_t = torch.zeros(batch_size, self.hidden_size, dtype=input.dtype, device=input.device)

        # The same masks are shared by all time steps of this forward pass.
        zx, zh = self._sample_mask(batch_size, stop_dropout)

        hidden_states = []
        for t in range(seq_len):
            x_i, x_f, x_o, x_g = (input[t] * zx_ for zx_ in zx)
            h_i, h_f, h_o, h_g = (h_t * zh_ for zh_ in zh)

            i = torch.sigmoid(self.Ui(h_i) + self.Wi(x_i))
            f = torch.sigmoid(self.Uf(h_f) + self.Wf(x_f))
            o = torch.sigmoid(self.Uo(h_o) + self.Wo(x_o))
            g = torch.tanh(self.Ug(h_g) + self.Wg(x_g))

            c_t = f * c_t + i * g
            h_t = o * torch.tanh(c_t)
            hidden_states.append(h_t)

        if hidden_states:
            hn = torch.stack(hidden_states)
        else:
            # Empty sequence: torch.stack rejects an empty list.
            hn = torch.empty(0, batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)

        return hn, (h_t, c_t)


class StochasticLSTM(nn.Module):
    """LSTM stacked layers with dropout and MCMC"""


    def __init__(self, **kwargs):
        '''
        Build the embedding layers, the stacked stochastic LSTM cells and the
        dropout-wrapped linear output heads. Any keyword not supplied falls
        back to the default listed below.

        ARGUMENTS (all keyword, with defaults):
        emb_dims (None): list of tuples (a, b) for each categorical variable,
                  with a: number of levels, and b: embedding dimension
        hidden_size (10): number of nodes in LSTM layers
        weight_regularizer (.1): parameter for weight regularization in reformulated ELBO
        dropout_regularizer (.1): parameter for dropout regularization in reformulated ELBO
        input_size (1): number of range variables; the summed embedding
                  dimensions are added to it before the first LSTM layer is built
        hs (True): "True" if heteroscedastic, "False" if homoscedastic
        dropout (True): in case of deterministic model, apply dropout if "True", otherwise no dropout
        concrete (True): dropout parameter is fixed when "False". If "True", then concrete dropout
        p_fix (.01): dropout parameter in case "concrete"="False"
        Bayes (True): BNN if "True", deterministic model if "False" (only sampled once for inference)
        nr_lstm_layers (3): number of LSTM layers
        '''

        super(StochasticLSTM, self).__init__()

        settings = {
            "emb_dims": None,
            "hidden_size": 10,
            "weight_regularizer": .1,
            "dropout_regularizer": .1,
            "input_size": 1,
            "hs": True,
            "dropout": True,
            "concrete": True,
            "p_fix": .01,
            "Bayes": True,
            "nr_lstm_layers": 3,
        }
        # Caller-supplied values override the defaults.
        settings.update(kwargs)

        self.emb_dims = settings["emb_dims"]
        self.hidden_size = settings["hidden_size"]
        self.weight_regularizer = settings["weight_regularizer"]
        self.dropout_regularizer = settings["dropout_regularizer"]
        self.input_size = settings["input_size"]
        self.heteroscedastic = settings["hs"]
        self.dropout = settings["dropout"]
        self.concrete = settings["concrete"]
        self.p_fix = settings["p_fix"]
        self.Bayes = settings["Bayes"]
        self.nr_layers = settings["nr_lstm_layers"]

        # One embedding per categorical variable; their widths extend the LSTM input.
        self.no_of_embs = 0
        if self.emb_dims:
            embeddings = []
            for n_levels, emb_dim in self.emb_dims:
                embeddings.append(nn.Embedding(n_levels, emb_dim))
            self.emb_layers = nn.ModuleList(embeddings)
            self.no_of_embs = sum(emb_dim for _, emb_dim in self.emb_dims)

        self.input_size = self.input_size + self.no_of_embs

        # Options shared by every LSTM cell in the stack.
        cell_options = dict(p_fix=self.p_fix,
                            concrete=self.concrete,
                            weight_regularizer=self.weight_regularizer,
                            dropout_regularizer=self.dropout_regularizer,
                            Bayes=self.Bayes)

        self.first_layer = StochasticLSTMCell(self.input_size, self.hidden_size, **cell_options)

        # The remaining nr_layers - 1 cells map hidden_size -> hidden_size.
        stacked_cells = []
        for _ in range(self.nr_layers - 1):
            stacked_cells.append(
                StochasticLSTMCell(self.hidden_size, self.hidden_size, **cell_options))
        self.hidden_layers = nn.ModuleList(stacked_cells)

        # Output head: hidden_size -> 5 -> 1 (mean), plus a log-variance head if heteroscedastic.
        self.linear1 = nn.Linear(self.hidden_size, 5)
        self.linear2_mu = nn.Linear(5, 1)
        if self.heteroscedastic:
            self.linear2_logvar = nn.Linear(5, 1)

        # Options shared by every dropout wrapper around the linear layers.
        dropout_options = dict(dropout=self.dropout,
                               concrete=self.concrete,
                               p_fix=self.p_fix,
                               weight_regularizer=self.weight_regularizer,
                               dropout_regularizer=self.dropout_regularizer,
                               conv="lin",
                               Bayes=self.Bayes)

        self.conc_drop1 = ConcreteDropout(**dropout_options)
        self.conc_drop2_mu = ConcreteDropout(**dropout_options)
        if self.heteroscedastic:
            self.conc_drop2_logvar = ConcreteDropout(**dropout_options)

        self.relu = nn.ReLU()


    def regularizer(self):
        total_weight_reg = self.first_layer.regularizer()
        for l in self.hidden_layers:
            total_weight_reg += l.regularizer()
        return total_weight_reg


    def forward(self, x_cat, x_range, stop_dropout=False):
        '''
        ARGUMENTS:
        x_cat: categorical variables. Torch tensor (batch size x sequence length x number of variables)
        x_range: range variables. Torch tensor (batch size x sequence length x number of variables)
        stop_dropout: if "True" prevents dropout during inference for deterministic models

        OUTPUTS:
        mean: outputs (point estimates). Torch tensor (batch size x number of outputs)
        log_var: log of uncertainty estimates. Torch tensor (batch size x number of outputs)
        regularization.sum(): sum of KL regularizers over all model layers

        '''

        regularization = torch.empty(4, device=x_range.device)

        if self.no_of_embs != 0:
            x = [emb_layer(x_cat[:, :, i])
                 for i, emb_layer in enumerate(self.emb_layers)]
            x = torch.cat(x, -1)
            x = torch.cat([x, x_range], -1)
        else:
            x = x_range

        x = x.transpose(0, 1)

        batch_size = x.shape[1]
        h_n = torch.zeros(self.nr_layers, batch_size, self.first_layer.hidden_size)
        c_n = torch.zeros(self.nr_layers, batch_size, self.first_layer.hidden_size)

        outputs, (h, c) = self.first_layer(x)
        h_n[0] = h
        c_n[0] = c