In [22]:
# Standard library
import math
import time

# Third-party
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
from sksurv.metrics import concordance_index_censored

# Project-local
from CoxPASNet.coxpasnet.DataLoader import load_data, load_pathway
from CoxPASNet.coxpasnet.Train import trainCoxPASNet
from CoxPASNet.coxpasnet.Survival_CostFunc_CIndex import R_set, neg_par_log_likelihood, c_index

# NOTE(review): both modules export `arguments`; the second import shadows the first.
from run_pipeline.bayesian_custom_nn import run_custom_bnn,arguments
from run_pipeline.cpath_bnn import run_cpath_bnn,arguments

from src.models.variational_layers.linear_reparam import LinearReparam,LinearGroupNJ
from src.data_prep.torch_datasets import cpath_dataset


In [3]:
# Tensor type passed to the loader calls in the data-loading cell.
dtype = torch.FloatTensor

# --- Net settings ---
In_Nodes = 5567        # number of genes
Pathway_Nodes = 860    # number of pathways
Hidden_Nodes = 100     # number of hidden nodes
Out_Nodes = 30         # number of hidden nodes in the last hidden layer

# --- Initialize ---
# The single-element lists below are reduced grids; the alternatives that were
# left commented out in the notebook are recorded next to each value.
Initial_Learning_Rate = [0.03]   # alternatives: [0.03, 0.01, 0.001, 0.00075]
L2_Lambda = [0.01]               # alternatives: [0.1, 0.01, 0.005, 0.001]
num_epochs = 10                  # for grid search (commented-out alternative: 3000)
Num_EPOCHS = 15                  # for training (commented-out alternative: 20000)

# --- Sub-network setup ---
Dropout_Rate = [0.7, 0.5]

In [4]:
# Load the pathway mask and the train / validation / test splits.
pathway_mask = load_pathway("../data/pathway_mask.csv", dtype)

# Each load_data call is unpacked into four values: features, survival time,
# event indicator and age (names taken from the variables they are bound to).
(x_train, ytime_train, yevent_train, age_train) = load_data(
    "../data/train.csv", dtype
)
(x_valid, ytime_valid, yevent_valid, age_valid) = load_data(
    "../data/validation.csv", dtype
)
(x_test, ytime_test, yevent_test, age_test) = load_data(
    "../data/test.csv", dtype
)



In [21]:
class LinearGroupNJ_Masked(nn.Module):
    """Fully connected Group Normal-Jeffreys layer (aka Group Variational Dropout)
    with an optional fixed connectivity mask.

    The mask has shape ``(out_features, in_features)`` and is multiplied
    element-wise into the weight means and weight variances every time they are
    used (initialisation, forward pass, posterior parameters, KL term).  Entries
    equal to 0 therefore never contribute to the output and receive no gradient
    from it.  With ``mask=None`` the layer is dense.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        cuda: kept for signature compatibility.  Noise is sampled on the same
            device as the parameters, so the flag is not needed for sampling.
            It is stored as ``use_cuda`` because assigning it to ``self.cuda``
            would shadow ``nn.Module.cuda()``.
        init_weight: optional array-like of shape ``(out_features, in_features)``
            used to initialise the weight means.
        init_bias: optional array-like of shape ``(out_features,)`` used to
            initialise the bias means.
        clip_var: optional upper bound on the weight / bias variances, applied
            by :meth:`clip_variances`.
        mask: optional tensor / array-like of shape
            ``(out_features, in_features)``.

    Raises:
        ValueError: if ``mask`` does not have shape
            ``(out_features, in_features)``.

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None,
                 mask=None):
        super().__init__()
        self.use_cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        # trainable params according to [3] Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # The mask must exist before reset_parameters() runs, because the
        # initial weight means are masked there.  It is registered as a buffer
        # so that it follows the module across devices and is saved in the
        # state dict, without being a trainable parameter.
        if mask is not None:
            mask = torch.as_tensor(mask, dtype=self.weight_mu.dtype).detach().clone()
            expected_shape = (out_features, in_features)
            if tuple(mask.shape) != expected_shape:
                raise ValueError(
                    "mask must have shape (out_features, in_features) = %s, got %s"
                    % (expected_shape, tuple(mask.shape))
                )
        self.register_buffer("mask", mask)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

    def _masked(self, tensor):
        """Return ``tensor`` multiplied by the mask, or unchanged when no mask is set."""
        if self.mask is None:
            return tensor
        return tensor * self.mask

    @staticmethod
    def _reparametrize(mu, logvar, sampling=True):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, 1)``; return ``mu`` when not sampling."""
        if not sampling:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like samples on the device / dtype of `std`.
        return mu + std * torch.randn_like(std)

    def reset_parameters(self, init_weight, init_bias):
        """(Re-)initialise all parameters, optionally from pretrained values."""
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        with torch.no_grad():
            # init means
            self.z_mu.normal_(1, 1e-2)

            if init_weight is not None:
                self.weight_mu.copy_(torch.as_tensor(init_weight, dtype=self.weight_mu.dtype))
            else:
                self.weight_mu.normal_(0, stdv)
            if self.mask is not None:
                # zero the disabled connections so the stored means match the
                # effective (masked) weights
                self.weight_mu.mul_(self.mask)

            if init_bias is not None:
                self.bias_mu.copy_(torch.as_tensor(init_bias, dtype=self.bias_mu.dtype))
            else:
                self.bias_mu.fill_(0)

            # init logvars
            self.z_logvar.normal_(-9, 1e-2)
            self.weight_logvar.normal_(-9, 1e-2)
            self.bias_logvar.normal_(-9, 1e-2)

    def clip_variances(self):
        """Clamp weight and bias log-variances to ``log(clip_var)`` when ``clip_var`` is set."""
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        """Return ``log(alpha) = z_logvar - log(z_mu^2 + epsilon)`` per input feature."""
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        """Compute, store and return the masked posterior weight mean and variance."""
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self._masked(
            self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        )
        self.post_weight_mu = self._masked(self.weight_mu * self.z_mu)
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        """Apply the layer to ``x``; indexing assumes ``x`` is ``(batch, in_features)``."""
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            # post_weight_mu is whatever compute_posterior_params() (or the
            # caller) stored last; it is intentionally not recomputed here.
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = self._reparametrize(self.z_mu.expand(batch_size, -1), self.z_logvar.expand(batch_size, -1),
                                sampling=self.training)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self._masked(self.weight_mu), self.bias_mu)
        if not self.training:
            # no sampling in eval mode, so the variance term is not needed
            return mu_activations

        # disabled connections contribute neither mean nor variance
        var_activations = F.linear(xz.pow(2), self._masked(self.weight_logvar.exp()), self.bias_logvar.exp())
        return self._reparametrize(mu_activations, var_activations.log(), sampling=True)

    def kl_divergence(self):
        """Return the KL term of the layer as a scalar tensor.

        Weight entries disabled by the mask are excluded from the sum.
        """
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(self._masked(KLD_element))

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [53]:
# Scratch weight matrix (860 x 5567) used in the cells below to try out
# element-wise multiplication with the pathway mask.
test_tens = Parameter(torch.empty(860, 5567))
test_tens.data.normal_(mean=0, std=0.2)

tensor([[-0.3019,  0.3189, -0.2703,  ..., -0.0904, -0.2369, -0.0382],
        [-0.1145,  0.3655,  0.1703,  ..., -0.1354,  0.2717, -0.0121],
        [-0.1308,  0.1574,  0.1233,  ...,  0.0948, -0.1653,  0.0116],
        ...,
        [ 0.2184,  0.3030,  0.0014,  ..., -0.0902,  0.0674,  0.0728],
        [-0.2955,  0.2499, -0.2631,  ...,  0.2403,  0.0836, -0.0193],
        [-0.1089, -0.2964,  0.0898,  ..., -0.1847,  0.1355,  0.1329]])

In [55]:
test_tens.shape

torch.Size([860, 5567])

In [43]:
)

RuntimeError: The size of tensor a (860) must match the size of tensor b (5567) at non-singleton dimension 1

In [45]:
# Transposed view of the pathway mask (dimensions 0 and 1 swapped).
l = torch.transpose(pathway_mask, 0, 1)

In [49]:
# Element-wise product of the transposed mask with the scratch weights.
# Per the recorded output this raises RuntimeError: the recorded shapes
# (5567, 860) and (860, 5567) cannot be broadcast together.
torch.mul(l, test_tens)

RuntimeError: The size of tensor a (860) must match the size of tensor b (5567) at non-singleton dimension 1

In [56]:
# Element-wise product of the un-transposed mask with the scratch weights.
w = torch.mul(pathway_mask, test_tens)

In [50]:
# Inspect the shape of the transposed mask.
l.size()

torch.Size([5567, 860])

In [51]:
# Inspect the shape of the pathway mask as loaded.
pathway_mask.size()

torch.Size([860, 5567])