In [178]:
from CoxPASNet.coxpasnet.DataLoader import load_data, load_pathway
from CoxPASNet.coxpasnet.Train import trainCoxPASNet

import torch
import torch.nn as nn
import numpy as np
from CoxPASNet.coxpasnet.Survival_CostFunc_CIndex import R_set, neg_par_log_likelihood, c_index

import torch.optim as optim
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

import time


In [2]:
dtype = torch.FloatTensor
''' Net Settings'''
In_Nodes = 5567 ###number of genes
Pathway_Nodes = 860 ###number of pathways
Hidden_Nodes = 100 ###number of hidden nodes
Out_Nodes = 30 ###number of hidden nodes in the last hidden layer
''' Initialize '''
Initial_Learning_Rate = [0.03, 0.00075] #[0.03, 0.01, 0.001, 0.00075]
L2_Lambda = [0.1, 0.001]  #[0.1, 0.01, 0.005, 0.001]
num_epochs = 10 #3000 ###for grid search
Num_EPOCHS = 5 #20000 ###for training
###sub-network setup
Dropout_Rate = [0.7, 0.5]

In [3]:
''' load data and pathway '''
pathway_mask = load_pathway("../data/pathway_mask.csv", dtype)

x_train, ytime_train, yevent_train, age_train = load_data("../data/train.csv", dtype)
x_valid, ytime_valid, yevent_valid, age_valid = load_data("../data/validation.csv", dtype)
x_test, ytime_test, yevent_test, age_test = load_data("../data/test.csv", dtype)


In [4]:
opt_l2_loss = 0
opt_lr_loss = 0
opt_loss = torch.Tensor([float("Inf")])
###if gpu is being used
if torch.cuda.is_available():
	opt_loss = opt_loss.cuda()
###
opt_c_index_va = 0
opt_c_index_tr = 0
###grid search the optimal hyperparameters using train and validation data
for l2 in L2_Lambda:
	for lr in Initial_Learning_Rate:
		loss_train, loss_valid, c_index_tr, c_index_va = trainCoxPASNet(x_train, age_train, ytime_train, yevent_train, \
																x_valid, age_valid, ytime_valid, yevent_valid, pathway_mask, \
																In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
																lr, l2, num_epochs, Dropout_Rate)
		if loss_valid < opt_loss:
			opt_l2_loss = l2
			opt_lr_loss = lr
			opt_loss = loss_valid
			opt_c_index_tr = c_index_tr
			opt_c_index_va = c_index_va
		print ("L2: ", l2, "LR: ", lr, "Loss in Validation: ", loss_valid)



###train Cox-PASNet with optimal hyperparameters using train data, and then evaluate the trained model with test data
###Note that test data are only used to evaluate the trained Cox-PASNet
loss_train, loss_test, c_index_tr, c_index_te = trainCoxPASNet(x_train, age_train, ytime_train, yevent_train, \
							x_test, age_test, ytime_test, yevent_test, pathway_mask, \
							In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
							opt_lr_loss, opt_l2_loss, Num_EPOCHS, Dropout_Rate)
print ("Optimal L2: ", opt_l2_loss, "Optimal LR: ", opt_lr_loss)
print("C-index in Test: ", c_index_te)


Loss in Train:  tensor([4.7945], grad_fn=<ViewBackward0>)


KeyboardInterrupt: 

In [3]:
import bayesian_torch.layers as bayesian_layers
from bayesian_torch.utils.util import get_rho

In [4]:
import torch
import torch.nn as nn
import torch.distributions as distributions


class BaseVariationalLayer_(nn.Module):
    def __init__(self):
        super().__init__()
        self._dnn_to_bnn_flag = False

    @property
    def dnn_to_bnn_flag(self):
        return self._dnn_to_bnn_flag

    @dnn_to_bnn_flag.setter
    def dnn_to_bnn_flag(self, value):
        self._dnn_to_bnn_flag = value

    def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
        """
        Calculates kl divergence between two gaussians (Q || P)
        Parameters:
             * mu_q: torch.Tensor -> mu parameter of distribution Q
             * sigma_q: torch.Tensor -> sigma parameter of distribution Q
             * mu_p: float -> mu parameter of distribution P
             * sigma_p: float -> sigma parameter of distribution P
        returns torch.Tensor of shape 0
        """
        kl = torch.log(sigma_p) - torch.log(
            sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 *
                                                          (sigma_p**2)) - 0.5
        return kl.mean()
    

def bnn_linear_layer_cust(params, d):
    # Get BNN layer
    bnn_layer = LinearReparam(
        in_features=d.in_features,
        out_features=d.out_features,
        prior_mean=params["prior_mu"],
        prior_variance=params["prior_sigma"],
        posterior_mu_init=params["posterior_mu_init"],
        posterior_rho_init=params["posterior_rho_init"],
        bias=d.bias is not None,
    )
    bnn_layer.dnn_to_bnn_flag = True
    return bnn_layer

def dnn_to_bnn_bcoxpas(m, bnn_prior_parameters):
    for name, value in list(m._modules.items()):
        if m._modules[name]._modules:
            dnn_to_bnn_bcoxpas(m._modules[name], bnn_prior_parameters)
        if name == "sc1":
            pass
        elif "Linear" in m._modules[name].__class__.__name__:
            setattr(
                m,
                name,
                bnn_linear_layer_cust(
                    bnn_prior_parameters,
                    m._modules[name]))
        else:
            pass
    return

In [5]:
class LinearReparameterization(BaseVariationalLayer_):
    def __init__(self,
                 in_features,
                 out_features,
                 prior_mean=0,
                 prior_variance=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        """
        Implements Linear layer with reparameterization trick.
        Inherits from bayesian_torch.layers.BaseVariationalLayer_
        Parameters:
            in_features: int -> size of each input sample,
            out_features: int -> size of each output sample,
            prior_mean: float -> mean of the prior arbitrary distribution to be used on the complexity cost,
            prior_variance: float -> variance of the prior arbitrary distribution to be used on the complexity cost,
            posterior_mu_init: float -> init trainable mu parameter representing mean of the approximate posterior,
            posterior_rho_init: float -> init trainable rho parameter representing the sigma of the approximate posterior through softplus function,
            bias: bool -> if set to False, the layer will not learn an additive bias. Default: True,
        """
        super(LinearReparameterization, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init,  # mean of weight
        # variance of weight --> sigma = log (1 + exp(rho))
        self.posterior_rho_init = posterior_rho_init,
        self.bias = bias

        self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
        self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('eps_weight',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        self.register_buffer('prior_weight_mu',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        self.register_buffer('prior_weight_sigma',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        if bias:
            self.mu_bias = Parameter(torch.Tensor(out_features))
            self.rho_bias = Parameter(torch.Tensor(out_features))
            self.register_buffer(
                'eps_bias',
                torch.Tensor(out_features),
                persistent=False)
            self.register_buffer(
                'prior_bias_mu',
                torch.Tensor(out_features),
                persistent=False)
            self.register_buffer('prior_bias_sigma',
                                 torch.Tensor(out_features),
                                 persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.init_parameters()

    def init_parameters(self):
        self.prior_weight_mu.fill_(self.prior_mean)
        self.prior_weight_sigma.fill_(self.prior_variance)

        self.mu_weight.data.normal_(mean=self.posterior_mu_init[0], std=0.1)
        self.rho_weight.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
        if self.mu_bias is not None:
            self.prior_bias_mu.fill_(self.prior_mean)
            self.prior_bias_sigma.fill_(self.prior_variance)
            self.mu_bias.data.normal_(mean=self.posterior_mu_init[0], std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                       std=0.1)

    def kl_loss(self):
        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
        kl = self.kl_div(
            self.mu_weight,
            sigma_weight,
            self.prior_weight_mu,
            self.prior_weight_sigma)
        if self.mu_bias is not None:
            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
            kl += self.kl_div(self.mu_bias, sigma_bias,
                              self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, input, return_kl=True):
        if self.dnn_to_bnn_flag:
            return_kl = False
        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
        weight = self.mu_weight + \
            (sigma_weight * self.eps_weight.data.normal_())
        if return_kl:
            kl_weight = self.kl_div(self.mu_weight, sigma_weight,
                                    self.prior_weight_mu, self.prior_weight_sigma)
        bias = None

        if self.mu_bias is not None:
            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
            bias = self.mu_bias + (sigma_bias * self.eps_bias.data.normal_())
            if return_kl:
                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                      self.prior_bias_sigma)

        out = F.linear(input, weight, bias)
        if return_kl:
            if self.mu_bias is not None:
                kl = kl_weight + kl_bias
            else:
                kl = kl_weight

            return out, kl

        return out

In [6]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter


class LinearReparam(BaseVariationalLayer_):
    def __init__(self,
                 in_features,
                 out_features,
                 prior_means,
                 prior_variances,
                 posterior_mu_init=0,
                 posterior_rho_init=-3.0,
                 bias=True):
        """
        Implements Linear layer with reparameterization trick.
        Inherits from bayesian_torch.layers.BaseVariationalLayer_
        Parameters:
            in_features: int -> size of each input sample,
            out_features: int -> size of each output sample,
            prior_mean: float -> mean of the prior arbitrary distribution to be used on the complexity cost,
            prior_variance: float -> variance of the prior arbitrary distribution to be used on the complexity cost,
            posterior_mu_init: float -> init trainable mu parameter representing mean of the approximate posterior,
            posterior_rho_init: float -> init trainable rho parameter representing the sigma of the approximate posterior through softplus function,
            bias: bool -> if set to False, the layer will not learn an additive bias. Default: True,
        """
        super(LinearReparam, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.prior_means = prior_means
        self.prior_variances = prior_variances
        self.posterior_mu_init = posterior_mu_init,  # mean of weight
        # variance of weight --> sigma = log (1 + exp(rho))
        self.posterior_rho_init = posterior_rho_init,
        self.bias = bias

        self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
        self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('eps_weight',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        self.register_buffer('prior_weight_mu',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        self.register_buffer('prior_weight_sigma',
                             torch.Tensor(out_features, in_features),
                             persistent=False)
        if bias:
            self.mu_bias = Parameter(torch.Tensor(out_features))
            self.rho_bias = Parameter(torch.Tensor(out_features))
            self.register_buffer(
                'eps_bias',
                torch.Tensor(out_features),
                persistent=False)
            self.register_buffer(
                'prior_bias_mu',
                torch.Tensor(out_features),
                persistent=False)
            self.register_buffer('prior_bias_sigma',
                                 torch.Tensor(out_features),
                                 persistent=False)
        else:
            self.register_buffer('prior_bias_mu', None, persistent=False)
            self.register_buffer('prior_bias_sigma', None, persistent=False)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None, persistent=False)

        self.init_parameters()

    def init_parameters(self):
        self.prior_weight_mu = torch.from_numpy(self.prior_means)
        self.prior_weight_sigma = torch.from_numpy(self.prior_variances)

        self.mu_weight.data.normal_(mean=self.posterior_mu_init[0], std=0.1)
        self.rho_weight.data.normal_(mean=self.posterior_rho_init[0], std=0.1)
        if self.mu_bias is not None:
            self.prior_bias_mu =torch.from_numpy(np.mean(self.prior_means,axis = 1))
            self.prior_bias_sigma = torch.from_numpy(np.mean(self.prior_variances,axis =1 ))
            self.mu_bias.data.normal_(mean=self.posterior_mu_init[0], std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                       std=0.1)
    def kl_loss(self):
        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
        kl = self.kl_div(
            self.mu_weight,
            sigma_weight,
            self.prior_weight_mu,
            self.prior_weight_sigma)
        if self.mu_bias is not None:
            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
            kl += self.kl_div(self.mu_bias, sigma_bias,
                              self.prior_bias_mu, self.prior_bias_sigma)
        return kl

    def forward(self, input, return_kl=True):
        if self.dnn_to_bnn_flag:
            return_kl = False
        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
        weight = self.mu_weight + \
            (sigma_weight * self.eps_weight.data.normal_())
        if return_kl:
            kl_weight = self.kl_div(self.mu_weight, sigma_weight,
                                    self.prior_weight_mu, self.prior_weight_sigma)
        bias = None

        if self.mu_bias is not None:
            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
            bias = self.mu_bias + (sigma_bias * self.eps_bias.data.normal_())
            if return_kl:
                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                      self.prior_bias_sigma)

        out = F.linear(input, weight, bias)
        if return_kl:
            if self.mu_bias is not None:
                kl = kl_weight + kl_bias
            else:
                kl = kl_weight

            return out, kl

        return out

In [96]:
test_layer = LinearReparam(
        in_features=1,
        out_features=2,
        prior_means=np.array([[0.],[0.]]),
        prior_variances=np.array([[0.1],[0.1]]),
        posterior_mu_init=0.5,
        posterior_rho_init=-3.0,
        bias=False,
    )

#test_layer.kl_loss()

#inp = torch.tensor([[1]],dtype = torch.float)
#r = test_layer.forward(inp)


In [102]:
test_lay_normal = LinearReparameterization(
        in_features=1,
        out_features=2,
        prior_mean=0,
        prior_variance=0.00000001,
        posterior_mu_init=0.5,
        posterior_rho_init=-3,
        bias=False,)

In [76]:
test_lay_normal.posterior_mu_init[0]
test_lay_normal.mu_weight
test_lay_normal.prior_weight_sigma
test_lay_normal.prior_bias_mu
test_lay_normal.prior_weight_mu

tensor([[0.],
        [0.]])

In [97]:
##Tmr check why these layers dont yield the same result. Maybe assign some variables
# and call get kl function to see if that one works

test_layer.prior_bias_sigma

tensor([0.1000, 0.1000], dtype=torch.float64)

In [100]:
inp = torch.tensor([[1]],dtype = torch.float)

res_normal_1 = []
res_normal_2 = []
res_normal_kl = []
res_new_1 = []
res_new_2 = []
res_new_kl = []

for i in range(10000):
    norm_res = test_lay_normal.forward(inp)
    new_res = test_layer.forward(inp)

    res_normal_1.append(norm_res[0].detach().numpy()[0][0])
    res_normal_2.append(norm_res[0].detach().numpy()[0][1])
    res_normal_kl.append(norm_res[1].detach().numpy())

    res_new_1.append(new_res[0].detach().numpy()[0][0])
    res_new_2.append(new_res[0].detach().numpy()[0][1])
    res_new_kl.append(new_res[1].detach().numpy())

In [104]:
x = np.array([[0.],[0.]])
x

array([[0.],
       [0.]])

In [107]:
np.full((3,1),0.)

array([[0.],
       [0.],
       [0.]])

In [174]:
class Bay_TestNet(nn.Module):
    
    def __init__(self, In_Nodes, Hidden_Nodes, Out_Nodes,mean=0.,variance=.1):
        super(Bay_TestNet, self).__init__()
        #self.tanh = nn.Tanh()
        self.l1 = LinearReparam(in_features=In_Nodes,
                                out_features=Out_Nodes,
                                prior_means=np.full((Out_Nodes,In_Nodes),mean),
                                prior_variances=np.full((Out_Nodes,In_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        
        '''
        self.l2 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Hidden_Nodes,
                                prior_means=np.full((Hidden_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Hidden_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        
        self.l3 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Hidden_Nodes,
                                prior_means=np.full((Hidden_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Hidden_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        self.l4 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Out_Nodes,
                                prior_means=np.full((Out_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Out_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        '''
    def forward(self, x):
        
        lin_pred = self.l1(x)

        return lin_pred

In [None]:
def trainbaynet(train_x, train_age, train_ytime, train_yevent, \
			eval_x, eval_age, eval_ytime, eval_yevent, pathway_mask, \
			In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
			Learning_Rate, Num_Epochs,bnn_prior_params):

    net = Bay_CPASNet(In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, pathway_mask)
    dnn_to_bnn_bcoxpas(net, bnn_prior_params)
    
    ###
    ###optimizer
    opt = optim.Adam(net.parameters(), lr=Learning_Rate)

    for epoch in range(Num_Epochs+1):
        net.train()
        opt.zero_grad() ###reset gradients to zeros

        pred = net(train_x, train_age) ###Forward
        ce_loss = neg_par_log_likelihood(pred, train_ytime, train_yevent)
        kl = get_kl_loss(net)
        loss = ce_loss + kl
        loss.backward() ###calculate gradients
        opt.step()
        net.sc1.weight.data = net.sc1.weight.data.mul(net.pathway_mask) ###force the connections between gene layer and pathway layer

        if epoch % 20 == 0:
            with torch.no_grad():
                net.train()
                train_output_mc = []
                for mc_run in range(10):
                    output = net(train_x, train_age)
                    train_output_mc.append(output)
                    outputs = torch.stack(train_output_mc)
                train_pred = outputs.mean(dim=0)
                train_loss = neg_par_log_likelihood(train_pred, train_ytime, train_yevent).view(1,)

                eval_output_mc = []
                for mc_run in range(10):
                    output = net(eval_x, eval_age)
                    eval_output_mc.append(output)
                    eval_outputs = torch.stack(eval_output_mc)
                eval_pred = eval_outputs.mean(dim=0)
                eval_loss = neg_par_log_likelihood(eval_pred, eval_ytime, eval_yevent).view(1,)

                train_cindex = c_index(train_pred, train_ytime, train_yevent)
                eval_cindex = c_index(eval_pred, eval_ytime, eval_yevent)
                print(f"Epoch: {epoch}, Train Loss: {train_loss},Eval Loss: {eval_loss}, "
                      f" Train Cindex: {train_cindex}, Eval Cindex: {eval_cindex}")

    return (train_loss, eval_loss, train_cindex, eval_cindex)

In [173]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [186]:
def train(args,
          train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          tb_writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cpu()
        input_var = input.cpu()
        target_var = target


        # compute output
        output_ = []
        kl_ = []
        for mc_run in range(args.num_mc):
            output, kl = model(input_var)
            output_.append(output)
            kl_.append(kl)
        output = torch.mean(torch.stack(output_), dim=0)
        kl = torch.mean(torch.stack(kl_), dim=0)
        cross_entropy_loss = criterion(output, target_var)
        scaled_kl = kl / args.batch_size 
        #ELBO loss
        loss = cross_entropy_loss + scaled_kl

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        mse = criterion(output.data, target)
        losses.update(loss.item(), input.size(0))
        mses.update(mse.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Mses {mses.val:.3f} ({mses.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      mses=mses))

        if tb_writer is not None:
            tb_writer.add_scalar('train/cross_entropy_loss',
                                 cross_entropy_loss.item(), epoch)
            tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
            tb_writer.add_scalar('train/elbo_loss', loss.item(), epoch)
            tb_writer.add_scalar('train/accuracy', prec1.item(), epoch)
            tb_writer.flush()


In [249]:
def validate(args, val_loader, model, criterion, epoch, tb_writer=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    errors = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cpu()
            input_var = input.cpu()
            target_var = target.cpu()


            # compute output
            output_ = []
            kl_ = []
            for mc_run in range(args.num_mc):
                output, kl = model(input_var)
                output_.append(output)
                kl_.append(kl)
            output = torch.mean(torch.stack(output_), dim=0)
            kl = torch.mean(torch.stack(kl_), dim=0)
            cross_entropy_loss = criterion(output, target_var)
            scaled_kl = kl / args.batch_size 
            #ELBO loss
            loss = cross_entropy_loss + scaled_kl

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))
            errors.update(error.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Error{error.val:.3f} ({error.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          error=errors))
        '''
            if tb_writer is not None:
                tb_writer.add_scalar('val/cross_entropy_loss',
                                     cross_entropy_loss.item(), epoch)
                tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch)
                tb_writer.add_scalar('val/elbo_loss', loss.item(), epoch)
                tb_writer.add_scalar('val/accuracy', prec1.item(), epoch)
                tb_writer.flush()
        '''
    print(' * Error {error.avg:.3f}'.format(error=errors))

    return errors.avg

In [324]:
def evaluate(args, model,criterion, val_loader):
    pred_probs_mc = []
    test_loss = 0
    correct = 0
    output_list = []
    labels_list = []
    model.eval()
    with torch.no_grad():
        begin = time.time()
        for data, target in val_loader:
            data, target = data.cpu(), target.cpu()
            output_mc = []
            for mc_run in range(args.num_mc):
                output, _ = model.forward(data)
                output_mc.append(output)
            output_ = torch.mean(torch.stack(output_mc),dim=0)
            output_list.append(output_)
        end = time.time()

        print('Test Error:',
              (criterion(output_list,target)))
        

In [252]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Save the training model
    """
    torch.save(state, filename)

In [319]:
class arguments():
    def __init__(self,num_mc,batch_size,print_freq,epochs,mode,lr,workers,model_name,save_dir):
        self.num_mc = num_mc
        self.batch_size = batch_size
        self.print_freq = print_freq
        self.epochs = epochs
        self.mode = mode
        self.lr = lr
        self.workers = workers
        self.model_name = model_name
        self.save_dir = save_dir
        
args = arguments(200,1,500,3,"test",0.00009,0,"test_model_no_noise","model_checkpoints")



In [302]:
x1 = np.random.normal(size = 1500)
x2 = np.random.normal(size = 1500)
X = np.stack((x1,x2),axis = -1)
y = x1 + x2  
#+ np.random.normal(0,0.01,size = 1500)

In [303]:
X_train = X[0:1000]
y_train = y[0:1000]
X_test = X[1000:1500]
y_test = y[1000:1500]

In [304]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, X, y, scale_data=False):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X).float()
            self.y = torch.from_numpy(y).float()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

train_dataset = Dataset(X_train, y_train)
test_dataset = Dataset(X_test,y_test)
#trainloader = torch.utils.data.DataLoader(train_dataset)


In [325]:
model = Bay_TestNet(2,3,1)
model.cpu()

best_pred = 0

tb_writer = None

train_loader = torch.utils.data.DataLoader( dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=args.workers,
                                            pin_memory=True)

val_loader = torch.utils.data.DataLoader(test_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

criterion = nn.MSELoss().cpu()

  
'''

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
'''
if args.mode == 'train':

    for epoch in range(args.epochs):

            lr = args.lr
            if (epoch >= 80 and epoch < 120):
                lr = 0.1 * args.lr
            elif (epoch >= 120 and epoch < 160):
                lr = 0.01 * args.lr
            elif (epoch >= 160 and epoch < 180):
                lr = 0.001 * args.lr
            elif (epoch >= 180):
                lr = 0.0005 * args.lr

            optimizer = torch.optim.Adam(model.parameters(), lr)

            # train for one epoch
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
            train(args, train_loader, model, criterion, optimizer, epoch,
                  tb_writer)

            val_score = validate(args, val_loader, model, criterion, epoch,
                             tb_writer)

            is_best = val_score >= best_pred
            best_val = max(val_score, best_pred)

            if is_best:
                save_checkpoint(
                    {
                    'epoch': epoch + 1,
                     'state_dict': model.state_dict(),
                    'best_val': best_val,
                   },
                    is_best,
                    filename=os.path.join(
                        args.save_dir,
                       'bayesian_{}.pth'.format(args.model_name)))

elif args.mode == 'test':
    checkpoint_file = args.save_dir + '/bayesian_{}.pth'.format(
            args.model_name)
    if torch.cuda.is_available():
            checkpoint = torch.load(checkpoint_file)
    else:
            checkpoint = torch.load(checkpoint_file,
                                    map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluate(args, model, criterion,val_loader)



AttributeError: 'list' object has no attribute 'size'

In [None]:
#sth still seems to be wrong with the error function or train method -> increased steadily
#while training while val loss decreased

In [306]:
checkpoint_file = args.save_dir + '/bayesian_{}.pth'.format(
            args.model_name)
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [307]:
ttest_dataset = Dataset(np.array([[0.2,0.1]]),np.array([[0.3]]))
ttest_loader = torch.utils.data.DataLoader(ttest_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

In [308]:
pred_probs_mc = []
test_loss = 0
correct = 0
output_list = []
labels_list = []
model.eval()
with torch.no_grad():
        begin = time.time()
        for data, target in ttest_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            else:
                data, target = data.cpu(), target.cpu()
            output_mc = []
            for mc_run in range(1000):
                output, _ = model.forward(data)
                output_mc.append(output)
            output_ = torch.stack(output_mc)
            output_list.append(output_)
        end = time.time()
        output = torch.stack(output_list)
        

In [309]:
torch.mean(output_, dim=0)

tensor([[0.0111]])

In [130]:
bay_net = Bay_TestNet(2,3,1)


In [244]:
inp = torch.tensor([[0.5,0.5]],dtype = torch.float)
res = []

for i in range(50000):
    res.append(model(inp)[0].detach().cpu().numpy()[0][0])
np.mean(res)

0.04336347

In [248]:
model(inp)

(tensor([[0.1276]], grad_fn=<MmBackward0>),
 tensor(0.0965, dtype=torch.float64, grad_fn=<MeanBackward0>))

In [134]:
bay_net(inp)

(tensor([[1.0395]], grad_fn=<MmBackward0>),
 tensor(14.7322, dtype=torch.float64, grad_fn=<MeanBackward0>))

In [26]:
class Bay_CPASNet(nn.Module):
	def __init__(self, In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, Pathway_Mask):
		super(Bay_CPASNet, self).__init__()
		self.tanh = nn.Tanh()
		self.pathway_mask = Pathway_Mask
		###gene layer --> pathway layer
		self.sc1 = nn.Linear(In_Nodes, Pathway_Nodes)
		###pathway layer --> hidden layer
		self.sc2 = nn.Linear(Pathway_Nodes, Hidden_Nodes)
		###hidden layer --> hidden layer 2
		self.sc3 = nn.Linear(Hidden_Nodes, Out_Nodes, bias=False)
		###hidden layer 2 + age --> Cox layer
		self.sc4 = nn.Linear(Out_Nodes+1, 1, bias = False)
		self.sc4.weight.data.uniform_(-0.001, 0.001)

	def forward(self, x_1, x_2):
		###force the connections between gene layer and pathway layer w.r.t. 'pathway_mask'
		self.sc1.weight.data = self.sc1.weight.data.mul(self.pathway_mask)
		x_1 = self.tanh(self.sc1(x_1))
		x_1 = self.tanh(self.sc2(x_1))
		x_1 = self.tanh(self.sc3(x_1))
		###combine age with hidden layer 2
		x_cat = torch.cat((x_1, x_2), 1)
		lin_pred = self.sc4(x_cat)

		return lin_pred

def trainBayCoxPASNet(train_x, train_age, train_ytime, train_yevent, \
			eval_x, eval_age, eval_ytime, eval_yevent, pathway_mask, \
			In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
			Learning_Rate, Num_Epochs,bnn_prior_params):

    net = Bay_CPASNet(In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, pathway_mask)
    dnn_to_bnn_bcoxpas(net, bnn_prior_params)
    
    ###
    ###optimizer
    opt = optim.Adam(net.parameters(), lr=Learning_Rate)

    for epoch in range(Num_Epochs+1):
        net.train()
        opt.zero_grad() ###reset gradients to zeros

        pred = net(train_x, train_age) ###Forward
        ce_loss = neg_par_log_likelihood(pred, train_ytime, train_yevent)
        kl = get_kl_loss(net)
        loss = ce_loss + kl
        loss.backward() ###calculate gradients
        opt.step()
        net.sc1.weight.data = net.sc1.weight.data.mul(net.pathway_mask) ###force the connections between gene layer and pathway layer

        if epoch % 20 == 0:
            with torch.no_grad():
                net.train()
                train_output_mc = []
                for mc_run in range(10):
                    output = net(train_x, train_age)
                    train_output_mc.append(output)
                    outputs = torch.stack(train_output_mc)
                train_pred = outputs.mean(dim=0)
                train_loss = neg_par_log_likelihood(train_pred, train_ytime, train_yevent).view(1,)

                eval_output_mc = []
                for mc_run in range(10):
                    output = net(eval_x, eval_age)
                    eval_output_mc.append(output)
                    eval_outputs = torch.stack(eval_output_mc)
                eval_pred = eval_outputs.mean(dim=0)
                eval_loss = neg_par_log_likelihood(eval_pred, eval_ytime, eval_yevent).view(1,)

                train_cindex = c_index(train_pred, train_ytime, train_yevent)
                eval_cindex = c_index(eval_pred, eval_ytime, eval_yevent)
                print(f"Epoch: {epoch}, Train Loss: {train_loss},Eval Loss: {eval_loss}, "
                      f" Train Cindex: {train_cindex}, Eval Cindex: {eval_cindex}")

    return (train_loss, eval_loss, train_cindex, eval_cindex)

In [30]:
const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",  # Flipout or Reparameterization
        "moped_enable": True,  # True to initialize mu/sigma from the pretrained dnn weights
        "moped_delta": 0.5,
}
    
opt_lr_loss = 0.01
loss_train_bnn, loss_test_bnn, c_index_tr_bnn, c_index_te_bnn = trainBayCoxPASNet(x_train, age_train, ytime_train, yevent_train, \
							x_test, age_test, ytime_test, yevent_test, pathway_mask, \
							In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
							opt_lr_loss, Num_EPOCHS,const_bnn_prior_parameters)
   

Epoch: 0, Train Loss: tensor([4.8027]),Eval Loss: tensor([3.5799]),  Train Cindex: 0.629298210144043, Eval Cindex: 0.6694067716598511


In [40]:
net(x_train[0:2],age_train[0:2])

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [55]:
torch.cat((x_train[0][None:],age_train[0][None:]),1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [53]:
age_train[0][None:].shape

torch.Size([1])