In [None]:
class StochasticCNN_1D(nn.Module):
    '''
    1D convolutional network whose layers are each wrapped in a ConcreteDropout module.

    Optional embedding layers encode the categorical inputs; their outputs are
    concatenated with the range features, passed through 1 to 3 blocks of
    (Conv1d -> ReLU -> MaxPool1d), flattened, and fed through two dense layers
    and a mean head (plus a log-variance head when heteroscedastic).

    NOTE(review): ConcreteDropout is defined outside this class. From the call
    sites here it is invoked as wrapper(x, layer, stop_dropout) and is expected
    to return (layer output, regularization term) -- confirm against its definition.
    '''

    def __init__(self, emb_dims=None, hidden_size=1024,
                 weight_regularizer=.1, dropout_regularizer=.1,
                 conv_dropout_traditional=False, dropout_dense_only=False,
                 input_size=1, c=1, hs=True, kernel_size=3, nr_kernels=20,
                 nr_conv_layers=2, seq_len=6,
                 dropout=True, concrete=True, p_fix=0.01, Bayes=True):
        '''
        ARGUMENTS:
        emb_dims: list of tuples (a, b) for each categorical variable,
                  with a: number of levels, and b: embedding dimension
        hidden_size: number of nodes in the first dense layer
                     (the second dense layer has int(hidden_size / 10) nodes)
        weight_regularizer: parameter for weight regularization in reformulated ELBO
        dropout_regularizer: parameter for dropout regularization in reformulated ELBO
        conv_dropout_traditional: if "True" the convolutional dropout wrappers are built
                                  with conv="lin" instead of conv="1D"
        dropout_dense_only: if "True" then only dropout in dense layers, not in convolutional layers
        input_size: number of range features
        c: number of outputs of the mean head (one for remaining time prediction)
        hs: "True" if heteroscedastic, "False" if homoscedastic
        kernel_size: size of kernels in convolutional layers
        nr_kernels: number of kernels in convolutional layers
        nr_conv_layers: number of convolutional layers (1, 2 or 3)
        seq_len: sequence length
        dropout: in case of deterministic model, apply dropout if "True", otherwise no dropout
        concrete: dropout parameter is fixed when "False". If "True", then concrete dropout
        p_fix: dropout parameter in case "concrete"="False"
        Bayes: BNN if "True", deterministic model if "False" (only sampled once for inference)

        RAISES:
        ValueError: if nr_conv_layers is not 1, 2 or 3, or if seq_len is too short
                    to survive the requested convolution/pooling blocks
        '''
        super().__init__()

        # forward() only implements the 1-, 2- and 3-layer variants; any other
        # value would make len_lin3 disagree with the real flattened size.
        if nr_conv_layers not in (1, 2, 3):
            raise ValueError(
                "nr_conv_layers must be 1, 2 or 3, got {!r}".format(nr_conv_layers))

        # Every block is an unpadded convolution (length shrinks by kernel_size - 1)
        # followed by a stride-1 max-pool of size 3 (length shrinks by 3 - 1).
        pool_size = 3
        flat_len = (seq_len
                    - nr_conv_layers * (kernel_size - 1)
                    - nr_conv_layers * (pool_size - 1))
        if flat_len <= 0:
            raise ValueError(
                "seq_len={} is too short for {} convolutional layer(s) with "
                "kernel_size={}: remaining sequence length would be {}".format(
                    seq_len, nr_conv_layers, kernel_size, flat_len))

        self.heteroscedastic = hs
        self.nr_conv_layers = nr_conv_layers

        # One embedding layer per categorical variable (only when emb_dims is given).
        self.no_of_embs = 0
        if emb_dims:
            self.emb_layers = nn.ModuleList(
                [nn.Embedding(levels, dim) for levels, dim in emb_dims])
            self.no_of_embs = sum(dim for _, dim in emb_dims)

        # Mode string handed to the dropout wrappers of the convolutional layers.
        self.conv = "lin" if conv_dropout_traditional else "1D"

        # Dropout can be switched off for the convolutional part only.
        dropout_conv = False if dropout_dense_only else dropout

        # NOTE: sub-module attribute names and creation order are kept stable so
        # that state_dict keys and seeded initialisation do not change.
        len_conv1 = self.no_of_embs + input_size
        self.conv1 = nn.Conv1d(len_conv1, nr_kernels, kernel_size=kernel_size)
        self.maxpool1 = nn.MaxPool1d(pool_size, 1)
        if self.nr_conv_layers == 3:
            self.conv2 = nn.Conv1d(nr_kernels, nr_kernels * 2, kernel_size=kernel_size)
            self.maxpool2 = nn.MaxPool1d(pool_size, 1)
            self.conv3a = nn.Conv1d(nr_kernels * 2, nr_kernels, kernel_size=kernel_size)
        elif self.nr_conv_layers == 2:
            self.conv3b = nn.Conv1d(nr_kernels, nr_kernels, kernel_size=kernel_size)
        self.maxpool3 = nn.MaxPool1d(pool_size, 1)

        # Dense part: flattened conv output -> hidden_size -> hidden_size / 10 -> heads.
        self.len_lin3 = nr_kernels * flat_len
        small_hidden = int(hidden_size / 10)
        self.linear4 = nn.Linear(self.len_lin3, hidden_size)
        self.linear5 = nn.Linear(hidden_size, small_hidden)

        self.linear6_mu = nn.Linear(small_hidden, c)
        if self.heteroscedastic:
            # The log-variance head always has a single output, independent of c.
            self.linear6_logvar = nn.Linear(small_hidden, 1)

        def make_dropout(active, mode):
            # All wrappers share the same hyper-parameters; only the on/off
            # switch and the mode string differ between conv and dense layers.
            return ConcreteDropout(dropout=active, concrete=concrete, p_fix=p_fix,
                                   weight_regularizer=weight_regularizer,
                                   dropout_regularizer=dropout_regularizer,
                                   conv=mode, Bayes=Bayes)

        # concrete dropout for convolutional layers
        self.conc_drop1 = make_dropout(dropout_conv, self.conv)
        self.conc_drop2 = make_dropout(dropout_conv, self.conv)
        self.conc_drop3a = make_dropout(dropout_conv, self.conv)
        self.conc_drop3b = make_dropout(dropout_conv, self.conv)
        # concrete dropout for dense layers
        self.conc_drop4 = make_dropout(dropout, "lin")
        self.conc_drop5 = make_dropout(dropout, "lin")
        self.conc_drop6_mu = make_dropout(dropout, "lin")
        if self.heteroscedastic:
            self.conc_drop6_logvar = make_dropout(dropout, "lin")

        self.relu = nn.ReLU()

    def forward(self, x_cat, x_range, stop_dropout=False):
        '''
        ARGUMENTS:
        x_cat: categorical variables. Torch tensor (batch size x sequence length x number of variables).
               Only read when the model has embedding layers.
        x_range: range variables. Torch tensor (batch size x sequence length x number of variables)
        stop_dropout: if "True" prevents dropout during inference for deterministic models
                      (passed through unchanged to every dropout wrapper)

        OUTPUTS:
        mean: outputs (point estimates). Torch tensor (batch size x number of outputs)
        log_var: log of uncertainty estimates. Output of the single-unit log-variance head
                 when heteroscedastic; otherwise a zero-filled placeholder shaped like mean
        regularization.sum(): sum of KL regularizers over all model layers
        '''
        # One slot per wrapped layer. This must start at zero: depending on
        # nr_conv_layers and hs some slots are never written, and an
        # uninitialised buffer would leak arbitrary values into the sum.
        regularization = torch.zeros(7, device=x_range.device)

        if self.no_of_embs != 0:
            # Embed each categorical variable and append the range features
            # along the feature axis.
            embedded = [emb_layer(x_cat[:, :, i])
                        for i, emb_layer in enumerate(self.emb_layers)]
            x = torch.cat(embedded + [x_range], -1)
        else:
            x = x_range

        # Conv1d expects channels first: (N, seq_len, nr_features) -> (N, nr_features, seq_len)
        x = torch.transpose(x, 1, 2)

        # The nn.Sequential containers are built per call on purpose: registering
        # them as attributes would add duplicate entries to the state_dict.
        x, regularization[0] = self.conc_drop1(
            x, nn.Sequential(self.conv1, self.relu, self.maxpool1), stop_dropout)
        if self.nr_conv_layers == 2:
            x, regularization[2] = self.conc_drop3b(
                x, nn.Sequential(self.conv3b, self.relu, self.maxpool3), stop_dropout)
        elif self.nr_conv_layers == 3:
            x, regularization[1] = self.conc_drop2(
                x, nn.Sequential(self.conv2, self.relu, self.maxpool2), stop_dropout)
            x, regularization[2] = self.conc_drop3a(
                x, nn.Sequential(self.conv3a, self.relu, self.maxpool3), stop_dropout)
        # nr_conv_layers == 1: the first block is the only convolutional block.

        # Flatten (channels x remaining length) for the dense layers.
        flat = x.view(-1, self.len_lin3)
        hidden4, regularization[3] = self.conc_drop4(
            flat, nn.Sequential(self.linear4, self.relu), stop_dropout)
        hidden5, regularization[4] = self.conc_drop5(
            hidden4, nn.Sequential(self.linear5, self.relu), stop_dropout)
        mean, regularization[5] = self.conc_drop6_mu(hidden5, self.linear6_mu, stop_dropout)

        if self.heteroscedastic:
            log_var, regularization[6] = self.conc_drop6_logvar(
                hidden5, self.linear6_logvar, stop_dropout)
        else:
            # No variance head: return a deterministic placeholder that lives on
            # the same device (and has the same dtype) as mean.
            log_var = torch.zeros_like(mean)

        return mean, log_var, regularization.sum()