In [10]:
from CoxPASNet.coxpasnet.DataLoader import load_data, load_pathway
from CoxPASNet.coxpasnet.Train import trainCoxPASNet
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
from CoxPASNet.coxpasnet.Survival_CostFunc_CIndex import R_set, neg_par_log_likelihood, c_index
from sksurv.metrics import concordance_index_censored



import torch.optim as optim
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

from run_pipeline.bayesian_custom_nn import run_custom_bnn,arguments
from run_pipeline.cpath_bnn import run_cpath_bnn,arguments


import time

from src.data_prep.torch_datasets import cpath_dataset


In [4]:
dtype = torch.FloatTensor
''' Net Settings'''
In_Nodes = 5567 ###number of genes
Pathway_Nodes = 860 ###number of pathways
Hidden_Nodes = 100 ###number of hidden nodes
Out_Nodes = 30 ###number of hidden nodes in the last hidden layer
''' Initialize '''
Initial_Learning_Rate = [0.03] #[0.03, 0.01, 0.001, 0.00075]
L2_Lambda = [0.01]  #[0.1, 0.01, 0.005, 0.001]
num_epochs = 10 #3000 ###for grid search
Num_EPOCHS = 15 #20000 ###for training
###sub-network setup
Dropout_Rate = [0.7,0.5]

In [22]:
''' load data and pathway '''
pathway_mask = load_pathway("../data/pathway_mask.csv", dtype)

x_train, ytime_train, yevent_train, age_train = load_data("../data/train.csv", dtype)
x_valid, ytime_valid, yevent_valid, age_valid = load_data("../data/validation.csv", dtype)
x_test, ytime_test, yevent_test, age_test = load_data("../data/test.csv", dtype)


In [23]:
pat_mask

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [24]:
pathway_mask

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [21]:
pat_mask = pd.read_csv("../data/pathway_mask.csv", index_col=0).values
pat_mask = torch.from_numpy(pat_mask).type(torch.FloatTensor)

In [6]:
X_train_np,tb = pd.read_csv("../data/train.csv")

ValueError: too many values to unpack (expected 2)

In [76]:
opt_l2_loss = 0
opt_lr_loss = 0
opt_loss = torch.Tensor([float("Inf")])
###if gpu is being used
if torch.cuda.is_available():
	opt_loss = opt_loss.cuda()
###
opt_c_index_va = 0
opt_c_index_tr = 0
###grid search the optimal hyperparameters using train and validation data
for l2 in L2_Lambda:
	for lr in Initial_Learning_Rate:
		loss_train, loss_valid, c_index_tr, c_index_va = trainCoxPASNet(x_train, age_train, ytime_train, yevent_train, \
																x_valid, age_valid, ytime_valid, yevent_valid, pathway_mask, \
																In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
																lr, l2, num_epochs, Dropout_Rate)
		if loss_valid < opt_loss:
			opt_l2_loss = l2
			opt_lr_loss = lr
			opt_loss = loss_valid
			opt_c_index_tr = c_index_tr
			opt_c_index_va = c_index_va
		print ("L2: ", l2, "LR: ", lr, "Loss in Validation: ", loss_valid)



###train Cox-PASNet with optimal hyperparameters using train data, and then evaluate the trained model with test data
###Note that test data are only used to evaluate the trained Cox-PASNet
loss_train, loss_test, c_index_tr, c_index_te = trainCoxPASNet(x_train, age_train, ytime_train, yevent_train, \
							x_test, age_test, ytime_test, yevent_test, pathway_mask, \
							In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
							opt_lr_loss, opt_l2_loss, Num_EPOCHS, Dropout_Rate)
print ("Optimal L2: ", opt_l2_loss, "Optimal LR: ", opt_lr_loss)
print("C-index in Test: ", c_index_te)




Loss in Train:  tensor([4.7949], grad_fn=<ViewBackward0>)
L2:  0.01 LR:  0.03 Loss in Validation:  tensor([3.4735], grad_fn=<ViewBackward0>)
Loss in Train:  tensor([4.7948], grad_fn=<ViewBackward0>)
Optimal L2:  0.01 Optimal LR:  0.03
C-index in Test:  tensor(0.6699)


In [None]:
class Cox_PASNet(nn.Module):
	def __init__(self, In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, Pathway_Mask):
		super(Cox_PASNet, self).__init__()
		self.tanh = nn.Tanh()
		self.pathway_mask = Pathway_Mask
		###gene layer --> pathway layer
		self.sc1 = nn.Linear(In_Nodes, Pathway_Nodes)
		###pathway layer --> hidden layer
		self.sc2 = nn.Linear(Pathway_Nodes, Hidden_Nodes)
		###hidden layer --> hidden layer 2
		self.sc3 = nn.Linear(Hidden_Nodes, Out_Nodes, bias=False)
		###hidden layer 2 + age --> Cox layer
		self.sc4 = nn.Linear(Out_Nodes+1, 1, bias = False)
		self.sc4.weight.data.uniform_(-0.001, 0.001)
		###randomly select a small sub-network
		self.do_m1 = torch.ones(Pathway_Nodes)
		self.do_m2 = torch.ones(Hidden_Nodes)
		###if gpu is being used
		if torch.cuda.is_available():
			self.do_m1 = self.do_m1.cuda()
			self.do_m2 = self.do_m2.cuda()
		###

	def forward(self, x_1, x_2):
		###force the connections between gene layer and pathway layer w.r.t. 'pathway_mask'
		self.sc1.weight.data = self.sc1.weight.data.mul(self.pathway_mask)
		x_1 = self.tanh(self.sc1(x_1))
		if self.training == True: ###construct a small sub-network for training only
			x_1 = x_1.mul(self.do_m1)
		x_1 = self.tanh(self.sc2(x_1))
		if self.training == True: ###construct a small sub-network for training only
			x_1 = x_1.mul(self.do_m2)
		x_1 = self.tanh(self.sc3(x_1))
		###combine age with hidden layer 2
		x_cat = torch.cat((x_1, x_2), 1)
		lin_pred = self.sc4(x_cat)
		
		return lin_pred


In [None]:
#make scripts so that they work with the age variable

In [3]:
class Bay_CPathNet(nn.Module):

    def __init__(self, In_Nodes, Hidden_Nodes,Last_Nodes, mean=0., variance=0.1):
        super(Bay_CPathNet, self).__init__()
        # self.tanh = nn.Tanh()
        self.l1 = LinearReparam(in_features=In_Nodes,
                                    out_features=Hidden_Nodes,
                                    prior_means=np.full((Hidden_Nodes, In_Nodes), mean),
                                    prior_variances=np.full((Hidden_Nodes, In_Nodes), variance),
                                    posterior_mu_init=np.full((Hidden_Nodes, In_Nodes), 0.5),
                                    posterior_rho_init=np.full((Hidden_Nodes, In_Nodes), -3.),
                                    bias=False,
                                    )
        
        self.l2 = LinearReparam(in_features=Hidden_Nodes,
                                    out_features=Hidden_Nodes,
                                    prior_means=np.full((Hidden_Nodes,Hidden_Nodes),mean),
                                    prior_variances=np.full((Hidden_Nodes,Hidden_Nodes),variance),
                                    posterior_mu_init=np.full((Hidden_Nodes, Hidden_Nodes), 0.5),
                                    posterior_rho_init=np.full((Hidden_Nodes, Hidden_Nodes), -3.),
                                    bias=False,
                                    )
        
        
        self.l3 = LinearReparam(in_features=Hidden_Nodes,
                                    out_features=Last_Nodes,
                                    prior_means=np.full((Last_Nodes,Hidden_Nodes),mean),
                                    prior_variances=np.full((Last_Nodes,Hidden_Nodes),variance),
                                    posterior_mu_init=np.full((Last_Nodes, Hidden_Nodes), 0.5),
                                    posterior_rho_init=np.full((Last_Nodes, Hidden_Nodes), -3.),
                                    bias=False,
                                    )


        self.l4 = LinearReparam(in_features=Last_Nodes+1,
                                    out_features=1,
                                    prior_means=np.full((1, Last_Nodes+1), mean),
                                    prior_variances=np.full((1, Last_Nodes+1), variance),
                                    posterior_mu_init=np.full((1, Last_Nodes+1), 0.5),
                                    posterior_rho_init=np.full((1, Last_Nodes+1), -3.),
                                    bias=False,
                                    )


    def forward(self, x, clinical_vars):
            pred = nn.Tanh(self.l1.forward(x))
            pred = nn.Tanh(self.l2(pred,return_kl = False))
            pred = nn.Tanh(self.l3(pred,return_kl = False))
            x_cat = torch.cat((x_1, clinical_vars), 1)
            lin_pred = self.l4(x_cat,return_kl = False)

            return lin_pred


In [7]:
train_data = pd.read_csv("../data/train.csv")
X_train_np =  train_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis = 1).values
tb_train = train_data.loc[:, ["OS_MONTHS"]].values
e_train = train_data.loc[:, ["OS_EVENT"]].values
clinical_vars_train = train_data.loc[:, ["AGE"]].values

val_data = pd.read_csv("../data/validation.csv")
X_val_np =  val_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis = 1).values
tb_val = val_data.loc[:, ["OS_MONTHS"]].values
e_val = val_data.loc[:, ["OS_EVENT"]].values
clinical_vars_val = val_data.loc[:, ["AGE"]].values

In [12]:
cpath_train_dataset = cpath_dataset(X_train_np,
                                    clinical_vars_train,
                                    tb_train,
                                    e_train)

cpath_val_dataset = cpath_dataset(X_val_np,
                                  clinical_vars_val,
                                  tb_val,
                                  e_val)

    #run model

#args = arguments(200, 84, 50, 10, "train", 0.01, 0, "test_model_cpath", "model_checkpoints")

#model = Bay_CPathNet(5567, 100, 30,variance = 200)

In [8]:
tes = torch.from_numpy(X_train_np)
tes = tes.float()

In [19]:
class CosPas_BNN_NJ(nn.Module):
        def __init__(self, In_Nodes, Hidden_Nodes,Last_Nodes):
            super(CosPas_BNN_NJ, self).__init__()
            # activation
            self.tanh = nn.Tanh()
            # layers
            self.fc1 = LinearGroupNJ(In_Nodes, Hidden_Nodes, clip_var=0.04, cuda=BNN_NJ_Flags.cuda)
            self.fc2 = LinearGroupNJ(Hidden_Nodes, Hidden_Nodes, cuda=BNN_NJ_Flags.cuda)
            self.fc3 = LinearGroupNJ(Hidden_Nodes, Hidden_Nodes, cuda=BNN_NJ_Flags.cuda)
            self.fc4 = LinearGroupNJ(Hidden_Nodes+1, 1, cuda=BNN_NJ_Flags.cuda)
            # layers including kl_divergence
            self.kl_list = [self.fc1, self.fc2, self.fc3,self.fc4]

        def forward(self, x,clinical_vars):
            x = self.tanh(self.fc1(x))
            x = self.tanh(self.fc2(x))
            x = self.tanh(self.fc3(x))
            x_cat = torch.cat((x, clinical_vars), 1)
            lin_pred = self.fc4(x_cat)
            
            return lin_pred

        def get_masks(self,thresholds):
            weight_masks = []
            mask = None
            for i, (layer, threshold) in enumerate(zip(self.kl_list, thresholds)):
                # compute dropout mask
                if mask is None:
                    log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
                    mask = log_alpha < threshold
                else:
                    mask = np.copy(next_mask)
                try:
                    log_alpha = layers[i + 1].get_log_dropout_rates().cpu().data.numpy()
                    next_mask = log_alpha < thresholds[i + 1]
                except:
                    # must be the last mask
                    next_mask = np.ones(10)

                weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1)
                weight_masks.append(weight_mask.astype(np.float))
            return weight_masks

        def kl_divergence(self):
            KLD = 0
            for layer in self.kl_list:
                KLD += layer.kl_divergence()
            return KLD



In [22]:
def partial_ll_loss(lrisks, tb, eb, eps=1e-3):

    tb = tb + eps*np.random.random(len(tb))
    sindex = np.argsort(-tb)

    tb = tb[sindex]
    eb = eb[sindex]

    lrisks = lrisks[sindex]
    lrisksdenom = torch.logcumsumexp(lrisks, dim = 0)

    plls = lrisks - lrisksdenom
    pll = plls[eb == 1]

    pll = torch.sum(pll)

    return -pll

In [38]:
class BNN_NJ_Flags_class():
        def __init__(self, num_mc,epochs,batch_size,print_freq,thresholds,cuda = False):
            self.num_mc = num_mc
            self.epochs = epochs
            self.batch_size = batch_size
            self.print_freq = print_freq
            self.thresholds = thresholds
            self.cuda = cuda
        

In [39]:
BNN_NJ_Flags= BNN_NJ_Flags_class(200,10,84,2,[-2.8, -3., -5.])

In [15]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [14]:
train_data = pd.read_csv("../data/train.csv")
X_train_np = train_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis=1).values
tb_train = train_data.loc[:, ["OS_MONTHS"]].values
e_train = train_data.loc[:, ["OS_EVENT"]].values
clinical_vars_train = train_data.loc[:, ["AGE"]].values

val_data = pd.read_csv("../data/validation.csv")
X_val_np = val_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis=1).values
tb_val = val_data.loc[:, ["OS_MONTHS"]].values
e_val = val_data.loc[:, ["OS_EVENT"]].values
clinical_vars_val = val_data.loc[:, ["AGE"]].values

cpath_train_dataset = cpath_dataset(X_train_np,
                                        clinical_vars_train,
                                        tb_train,
                                        e_train)

cpath_val_dataset = cpath_dataset(X_val_np,
                                      clinical_vars_val,
                                      tb_val,
                                      e_val)

    # import data
cpath_train_loader = torch.utils.data.DataLoader(cpath_train_dataset,
                                            batch_size=50,
                                            shuffle=True,
                                            num_workers=0)

cpath_val_loader = torch.utils.data.DataLoader(cpath_val_dataset,
                                            batch_size=40,
                                            shuffle=False,
                                            num_workers=0)



In [19]:
len(cpath_val_loader.dataset)

84

In [47]:
train_data = pd.read_csv("../data/train.csv")
X_train_np = train_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis=1).values
tb_train = train_data.loc[:, ["OS_MONTHS"]].values
e_train = train_data.loc[:, ["OS_EVENT"]].values
clinical_vars_train = train_data.loc[:, ["AGE"]].values

val_data = pd.read_csv("../data/validation.csv")
X_val_np = val_data.drop(["SAMPLE_ID", "OS_MONTHS", "OS_EVENT", "AGE"], axis=1).values
tb_val = val_data.loc[:, ["OS_MONTHS"]].values
e_val = val_data.loc[:, ["OS_EVENT"]].values
clinical_vars_val = val_data.loc[:, ["AGE"]].values

cpath_train_dataset = cpath_dataset(X_train_np,
                                        clinical_vars_train,
                                        tb_train,
                                        e_train)

cpath_val_dataset = cpath_dataset(X_val_np,
                                      clinical_vars_val,
                                      tb_val,
                                      e_val)

    # import data
cpath_train_loader = torch.utils.data.DataLoader(cpath_train_dataset,
                                            batch_size=BNN_NJ_Flags.batch_size,
                                            shuffle=True,
                                            num_workers=0)

cpath_val_loader = torch.utils.data.DataLoader(cpath_val_dataset,
                                            batch_size=BNN_NJ_Flags.batch_size,
                                            shuffle=False,
                                            num_workers=0)


    # for later analysis we take some sample digits
    #mask = 255. * (np.ones((1, 28, 28)))
    #examples = train_loader.sampler.data_source.train_data[0:5].numpy()
    #images = np.vstack([mask, examples])

    # build a simple MLP
    
    # init model
model = CosPas_BNN_NJ(5567, 100, 30)
if BNN_NJ_Flags.cuda:
    model.cuda()

    # init optimizer
optimizer = optim.Adam(model.parameters())

    # we optimize the variational lower bound scaled by the number of data
    # points (so we can keep our intuitions about hyper-params such as the learning rate)

 

def train(BNN_NJ_Flags,model,cpath_train_loader,epoch,optimizer):        
        
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        plls = AverageMeter()
        c_indexs = AverageMeter()

        # switch to train mode
        model.train()

        end = time.time()
        for i, (input, target) in enumerate(cpath_train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            tb = target["tb"].cpu()
            e = target["e"].cpu()
            input_var = input["X"].cpu()
            clinical_var = input["clinical_vars"].cpu()
            
            output_ = []
            for mc_run in range(BNN_NJ_Flags.num_mc):
                output = model(input_var,clinical_var)
                output_.append(output)
            output = torch.mean(torch.stack(output_), dim=0)
            loss_crit_metric = partial_ll_loss(output.reshape(-1),tb.reshape(-1),e.reshape(-1))
            scaled_loss_crit_metric = loss_crit_metric / BNN_NJ_Flags.batch_size
            scaled_kl = model.kl_divergence() / BNN_NJ_Flags.batch_size
            loss =  loss_crit_metric + scaled_kl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for layer in model.kl_list:
                    layer.clip_variances()
                
            conc_metric = concordance_index_censored(e.detach().numpy().astype(bool).reshape(-1),
                                                     tb.detach().numpy().reshape(-1),
                                                     output.reshape(-1).detach().numpy())[0]
            output = output.float()
            loss = loss.float()
            # measure accuracy and record loss
            losses.update(loss.item(), input["X"].size(0))
            plls.update(scaled_loss_crit_metric.item(), input["X"].size(0))
            c_indexs.update(conc_metric, input["X"].size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % BNN_NJ_Flags.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'PLL {plls.val:.3f} ({plls.avg:.3f})\t'
                      'C-Index {c_ind.val:.3f} ({c_ind.avg:.3f})'.format(
                          epoch,
                          i,
                          len(cpath_train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          plls=plls,
                          c_ind=c_indexs))

def test():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if FLAGS.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            test_loss += discrimination_loss(output, target, size_average=False).data[0]
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

    # train the model and save some visualisations on the way



In [52]:

def validate_cpath_nj_model(args, cpath_val_loader, model, epoch, tb_writer=None):

    batch_time = AverageMeter()
    losses = AverageMeter()
    errors = AverageMeter()
    c_indexs = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(cpath_val_loader):

            tb = target["tb"].cpu()
            e = target["e"].cpu()
            input_var = input["X"].cpu()
            clinical_var = input["clinical_vars"].cpu()

            # compute output
            output_ = []
            for mc_run in range(args.num_mc):
                output = model(input_var,clinical_var)
                output_.append(output)
            output = torch.mean(torch.stack(output_), dim=0)
            error_metric = partial_ll_loss(output.reshape(-1), tb.reshape(-1), e.reshape(-1))
            scaled_error_metric = error_metric /  args.batch_size
            scaled_kl = model.kl_divergence() / args.batch_size

            #ELBO loss
            loss = error_metric + scaled_kl

            conc_metric = concordance_index_censored(e.detach().numpy().astype(bool).reshape(-1),
                                                     tb.detach().numpy().reshape(-1),
                                                     output.reshape(-1).detach().numpy())[0]

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            losses.update(loss.item(), input["X"].size(0))
            errors.update(scaled_error_metric.item(), input["X"].size(0))
            c_indexs.update(conc_metric,input["X"].size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Error {error.val:.3f} ({error.avg:.3f})\t'
                      'C-Index {c_ind.val:.3f} ({c_ind.avg:.3f}) '.format(
                          i,
                          len(cpath_val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          error=errors,
                          c_ind=c_indexs))
       
    print(' * Error {error.avg:.3f}'.format(error=errors))

    return errors.avg

In [53]:
for epoch in range(1, BNN_NJ_Flags.epochs + 1):
    train(BNN_NJ_Flags,model,cpath_train_loader,epoch,optimizer)
            #test()
            # visualizations
    val_score = validate_cpath_nj_model(BNN_NJ_Flags, cpath_val_loader, model, epoch)
                             
    weight_mus = [model.fc1.weight_mu, model.fc2.weight_mu]
    log_alphas = [model.fc1.get_log_dropout_rates(), model.fc2.get_log_dropout_rates(),
                          model.fc3.get_log_dropout_rates()]
            #visualise_weights(weight_mus, log_alphas, epoch=epoch)
            #log_alpha = model.fc1.get_log_dropout_rates().cpu().data.numpy()
            #visualize_pixel_importance(images, log_alpha=log_alpha, epoch=str(epoch))

            # compute compression rate and new model accuracy
    layers = [model.fc1, model.fc2, model.fc3]
#thresholds = FLAGS.thresholds
#compute_compression_rate(layers, model.get_masks(thresholds))

    #print("Test error after with reduced bit precision:")

    #weights = compute_reduced_weights(layers, model.get_masks(thresholds))
    #for layer, weight in zip(layers, weights):
     #   if FLAGS.cuda:
      #      layer.post_weight_mu.data = torch.Tensor(weight).cuda()
       # else:
        #    layer.post_weight_mu.data = torch.Tensor(weight)
    #for layer in layers: layer.deterministic = True
    #test()

Epoch: [1][0/4]	Time 5.696 (5.696)	Data 0.009 (0.009)	Loss 28061.5723 (28061.5723)	PLL 2.946 (2.946)	C-Index 0.666 (0.666)
Epoch: [1][2/4]	Time 5.598 (5.634)	Data 0.001 (0.004)	Loss 28012.2598 (28038.8249)	PLL 2.442 (2.717)	C-Index 0.788 (0.717)
Test: [0/1]	Time 1.315 (1.315)	Loss 28087.1992 (28087.1992)	Error 3.418 (3.418)	C-Index 0.435 (0.435) 
 * Error 3.418
Epoch: [2][0/4]	Time 5.787 (5.787)	Data 0.001 (0.001)	Loss 28039.2676 (28039.2676)	PLL 2.847 (2.847)	C-Index 0.740 (0.740)
Epoch: [2][2/4]	Time 5.597 (5.682)	Data 0.001 (0.001)	Loss 28007.4766 (28022.2253)	PLL 2.552 (2.686)	C-Index 0.781 (0.762)
Test: [0/1]	Time 1.317 (1.317)	Loss 28070.6875 (28070.6875)	Error 3.388 (3.388)	C-Index 0.483 (0.483) 
 * Error 3.388
Epoch: [3][0/4]	Time 5.649 (5.649)	Data 0.001 (0.001)	Loss 27992.9238 (27992.9238)	PLL 2.462 (2.462)	C-Index 0.857 (0.857)
Epoch: [3][2/4]	Time 5.617 (5.679)	Data 0.001 (0.001)	Loss 28015.6621 (28001.8327)	PLL 2.816 (2.610)	C-Index 0.772 (0.814)
Test: [0/1]	Time 1.341 (1.

In [17]:
pred = model.l1.forward(tes,return_kl = False).detach()

In [18]:
pred

tensor([[-301.4059, -298.6887, -304.2516,  ..., -295.9987, -301.2021,
         -300.9999],
        [-840.2748, -849.6366, -844.2791,  ..., -834.6867, -840.2144,
         -844.5341],
        [ 743.9192,  739.9576,  732.7419,  ...,  748.3130,  747.7260,
          735.9771],
        ...,
        [-265.4668, -263.0511, -258.8149,  ..., -266.2320, -262.6032,
         -258.1574],
        [ 951.9144,  952.3220,  944.4562,  ...,  932.3268,  941.6860,
          954.7036],
        [ 406.7895,  391.3521,  388.9161,  ...,  393.7029,  395.9680,
          398.4599]])

In [21]:
torch.nn.Tanh(pred[0])

TypeError: __init__() takes 1 positional argument but 2 were given

In [23]:
test = torch.tensor([[2,3,4,5],[2,3,23,3]])
v = torch.nn.Tanh()
v(test)

tensor([[0.9640, 0.9951, 0.9993, 0.9999],
        [0.9640, 0.9951, 1.0000, 0.9951]])

In [19]:


run_cpath_bnn(cpath_train_dataset,cpath_val_dataset,model,args,"train")



    #inp = torch.tensor([[0.5, 0.5]], dtype=torch.float)

    #print(model(inp))

current lr 1.00000e-02


TypeError: __init__() takes 1 positional argument but 2 were given

In [3]:
import bayesian_torch.layers as bayesian_layers
from bayesian_torch.utils.util import get_rho

In [4]:
test_layer = LinearReparam(
        in_features=2,
        out_features=1,
        prior_means=np.array([[0.,0.]]),
        prior_variances=np.array([[0.1,0.1]]),
        posterior_mu_init=np.array([[0.,0.]]),
        posterior_rho_init=np.array([[0.1,0.1]]),
        bias=False,
    )

#test_layer.kl_loss()

#inp = torch.tensor([[1]],dtype = torch.float)
#r = test_layer.forward(inp)


In [548]:
test_layer.mu_weight.shape

torch.Size([1, 2])

In [517]:
np.array([[0.,0.]])

array([[0., 0.]])

In [512]:
test_layer.prior_weight_mu.shape

torch.Size([2, 1])

In [549]:
test_lay_normal = LinearReparameterization(
        in_features=3,
        out_features=2,
        prior_mean=0,
        prior_variance=0.00000001,
        posterior_mu_init=0.5,
        posterior_rho_init=-3,
        bias=False,)

In [552]:
test_lay_normal.mu_weight.data

tensor([[0.5124, 0.4482, 0.5500],
        [0.6314, 0.5597, 0.4710]])

In [76]:
test_lay_normal.posterior_mu_init[0]
test_lay_normal.mu_weight
test_lay_normal.prior_weight_sigma
test_lay_normal.prior_bias_mu
test_lay_normal.prior_weight_mu

tensor([[0.],
        [0.]])

In [97]:
##Tmr check why these layers dont yield the same result. Maybe assign some variables
# and call get kl function to see if that one works

test_layer.prior_bias_sigma

tensor([0.1000, 0.1000], dtype=torch.float64)

In [100]:
inp = torch.tensor([[1]],dtype = torch.float)

res_normal_1 = []
res_normal_2 = []
res_normal_kl = []
res_new_1 = []
res_new_2 = []
res_new_kl = []

for i in range(10000):
    norm_res = test_lay_normal.forward(inp)
    new_res = test_layer.forward(inp)

    res_normal_1.append(norm_res[0].detach().numpy()[0][0])
    res_normal_2.append(norm_res[0].detach().numpy()[0][1])
    res_normal_kl.append(norm_res[1].detach().numpy())

    res_new_1.append(new_res[0].detach().numpy()[0][0])
    res_new_2.append(new_res[0].detach().numpy()[0][1])
    res_new_kl.append(new_res[1].detach().numpy())

In [104]:
x = np.array([[0.],[0.]])
x

array([[0.],
       [0.]])

array([[0., 0., 0.]])

In [64]:
class Bay_TestNet(nn.Module):
    
    def __init__(self, In_Nodes, Hidden_Nodes, Out_Nodes,mean=0.,variance=.1):
        super(Bay_TestNet, self).__init__()
        #self.tanh = nn.Tanh()
        self.l1 = LinearReparam(in_features=In_Nodes,
                                out_features=Out_Nodes,
                                prior_means=np.array([[0.,3.5]]), #np.full((Out_Nodes,In_Nodes),mean)
                                prior_variances= np.array([[0.1,0.5]]),#np.full((Out_Nodes,In_Nodes),variance)
                                posterior_mu_init=np.array([[2.,0.5]]),
                                posterior_rho_init=np.array([[-3.,-3.]]),
                                bias=False,
                                )
        
        '''
        self.l2 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Hidden_Nodes,
                                prior_means=np.full((Hidden_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Hidden_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        
        self.l3 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Hidden_Nodes,
                                prior_means=np.full((Hidden_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Hidden_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        self.l4 = LinearReparam(in_features=Hidden_Nodes,
                                out_features=Out_Nodes,
                                prior_means=np.full((Out_Nodes,Hidden_Nodes),mean),
                                prior_variances=np.full((Out_Nodes,Hidden_Nodes),variance),
                                posterior_mu_init=0.5,
                                posterior_rho_init=-3.0,
                                bias=False,
                                )
        '''
    def forward(self, x):
        
        lin_pred = self.l1(x)

        return lin_pred

In [8]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, X, y, scale_data=False):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X).float()
            self.y = torch.from_numpy(y).float()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]



In [6]:
class arguments():
    def __init__(self,num_mc,batch_size,print_freq,epochs,mode,lr,workers,model_name,save_dir):
        self.num_mc = num_mc
        self.batch_size = batch_size
        self.print_freq = print_freq
        self.epochs = epochs
        self.mode = mode
        self.lr = lr
        self.workers = workers
        self.model_name = model_name
        self.save_dir = save_dir
        
args = arguments(200,1,500,10,"test",0.00009,0,"test_model_no_noise","model_checkpoints")



In [65]:
weight_1 = []
weight_2 = []

for i in range(1):

    x1 = np.random.normal(size=40)
    x2 = np.random.normal(size=40)
    X = np.stack((x1, x1), axis=-1)
    y = x1+x1 + x1 #x1 + x2

    X_train = X[0:30]
    y_train = y[0:30]
    X_test = X[30:40]
    y_test = y[30:40]

    train_dataset = Dataset(X_train, y_train)
    val_dataset = Dataset(X_test, y_test)

    #run model

    args = arguments(200, 5, 5, 200, "train", 0.01, 0, "test_model_no_noise", "model_checkpoints")

    model = Bay_TestNet(2, 3, 1)

    criterion = nn.MSELoss().cpu()

    run_custom_bnn(train_dataset,val_dataset,args,model,criterion)

    inp = torch.tensor([[0.5, 0.5]], dtype=torch.float)

    weight_1.append(model.l1.mu_weight.detach().numpy()[0][0])
    weight_2.append(model.l1.mu_weight.detach().numpy()[0][1])

current lr 1.00000e-02
Epoch: [0][0/6]	Time 0.064 (0.064)	Data 0.001 (0.001)	Loss 19.8398 (19.8398)	Mses 0.207 (0.207)
Epoch: [0][5/6]	Time 0.041 (0.045)	Data 0.000 (0.000)	Loss 19.3413 (19.4961)	Mses 0.700 (0.361)
Test: [0/2]	Time 0.014 (0.014)	Loss 18.7888 (18.7888)	Error0.341 (0.341)
 * Error 0.411
current lr 1.00000e-02
Epoch: [1][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 18.7847 (18.7847)	Mses 0.337 (0.337)
Epoch: [1][5/6]	Time 0.042 (0.038)	Data 0.000 (0.000)	Loss 18.2065 (18.3227)	Mses 0.721 (0.358)
Test: [0/2]	Time 0.014 (0.014)	Loss 17.6394 (17.6394)	Error0.342 (0.342)
 * Error 0.408
current lr 1.00000e-02
Epoch: [2][0/6]	Time 0.038 (0.038)	Data 0.000 (0.000)	Loss 17.7836 (17.7836)	Mses 0.486 (0.486)
Epoch: [2][5/6]	Time 0.046 (0.042)	Data 0.000 (0.000)	Loss 17.5481 (17.1908)	Mses 1.180 (0.360)
Test: [0/2]	Time 0.016 (0.016)	Loss 16.5211 (16.5211)	Error0.332 (0.332)
 * Error 0.414
current lr 1.00000e-02
Epoch: [3][0/6]	Time 0.043 (0.043)	Data 0.000 (0.000)	Loss 16.9041 (

Epoch: [27][5/6]	Time 0.044 (0.044)	Data 0.000 (0.000)	Loss 1.6780 (1.2808)	Mses 0.931 (0.460)
Test: [0/2]	Time 0.015 (0.015)	Loss 1.1575 (1.1575)	Error0.432 (0.432)
 * Error 0.503
current lr 1.00000e-02
Epoch: [28][0/6]	Time 0.041 (0.041)	Data 0.000 (0.000)	Loss 0.8125 (0.8125)	Mses 0.087 (0.087)
Epoch: [28][5/6]	Time 0.045 (0.044)	Data 0.000 (0.000)	Loss 1.3473 (1.1164)	Mses 0.730 (0.451)
Test: [0/2]	Time 0.015 (0.015)	Loss 1.0166 (1.0166)	Error0.413 (0.413)
 * Error 0.493
current lr 1.00000e-02
Epoch: [29][0/6]	Time 0.041 (0.041)	Data 0.000 (0.000)	Loss 1.6027 (1.6027)	Mses 0.999 (0.999)
Epoch: [29][5/6]	Time 0.040 (0.041)	Data 0.000 (0.000)	Loss 1.1015 (1.0013)	Mses 0.540 (0.408)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.8946 (0.8946)	Error0.347 (0.347)
 * Error 0.446
current lr 1.00000e-02
Epoch: [30][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 1.8254 (1.8254)	Mses 1.278 (1.278)
Epoch: [30][5/6]	Time 0.040 (0.039)	Data 0.000 (0.000)	Loss 1.0195 (0.9265)	Mses 0.489 (0.378)
Test: [0

Epoch: [55][5/6]	Time 0.042 (0.039)	Data 0.000 (0.000)	Loss 0.0500 (0.0558)	Mses 0.002 (0.008)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0527 (0.0527)	Error0.006 (0.006)
 * Error 0.010
current lr 1.00000e-02
Epoch: [56][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0648 (0.0648)	Mses 0.018 (0.018)
Epoch: [56][5/6]	Time 0.041 (0.041)	Data 0.000 (0.000)	Loss 0.0479 (0.0538)	Mses 0.000 (0.004)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0484 (0.0484)	Error0.003 (0.003)
 * Error 0.004
current lr 1.00000e-02
Epoch: [57][0/6]	Time 0.038 (0.038)	Data 0.000 (0.000)	Loss 0.0481 (0.0481)	Mses 0.002 (0.002)
Epoch: [57][5/6]	Time 0.042 (0.042)	Data 0.000 (0.000)	Loss 0.0594 (0.0507)	Mses 0.022 (0.010)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0469 (0.0469)	Error0.009 (0.009)
 * Error 0.012
current lr 1.00000e-02
Epoch: [58][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0396 (0.0396)	Mses 0.001 (0.001)
Epoch: [58][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0467 (0.0461)	Mses 0.010 (0.010)
Test: [0

Epoch: [83][5/6]	Time 0.041 (0.038)	Data 0.000 (0.000)	Loss 0.0391 (0.0446)	Mses 0.005 (0.011)
Test: [0/2]	Time 0.013 (0.013)	Loss 0.0413 (0.0413)	Error0.007 (0.007)
 * Error 0.015
current lr 1.00000e-03
Epoch: [84][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0471 (0.0471)	Mses 0.013 (0.013)
Epoch: [84][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0352 (0.0397)	Mses 0.001 (0.005)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0344 (0.0344)	Error0.000 (0.000)
 * Error 0.009
current lr 1.00000e-03
Epoch: [85][0/6]	Time 0.038 (0.038)	Data 0.000 (0.000)	Loss 0.0376 (0.0376)	Mses 0.003 (0.003)
Epoch: [85][5/6]	Time 0.044 (0.042)	Data 0.000 (0.000)	Loss 0.0337 (0.0387)	Mses 0.000 (0.005)
Test: [0/2]	Time 0.015 (0.015)	Loss 0.0376 (0.0376)	Error0.004 (0.004)
 * Error 0.010
current lr 1.00000e-03
Epoch: [86][0/6]	Time 0.040 (0.040)	Data 0.000 (0.000)	Loss 0.0346 (0.0346)	Mses 0.001 (0.001)
Epoch: [86][5/6]	Time 0.041 (0.039)	Data 0.000 (0.000)	Loss 0.0387 (0.0421)	Mses 0.006 (0.009)
Test: [0

Epoch: [111][5/6]	Time 0.039 (0.043)	Data 0.000 (0.000)	Loss 0.0335 (0.0384)	Mses 0.001 (0.006)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0337 (0.0337)	Error0.001 (0.001)
 * Error 0.004
current lr 1.00000e-03
Epoch: [112][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0429 (0.0429)	Mses 0.010 (0.010)
Epoch: [112][5/6]	Time 0.041 (0.039)	Data 0.000 (0.000)	Loss 0.0427 (0.0396)	Mses 0.010 (0.007)
Test: [0/2]	Time 0.013 (0.013)	Loss 0.0433 (0.0433)	Error0.010 (0.010)
 * Error 0.005
current lr 1.00000e-03
Epoch: [113][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0361 (0.0361)	Mses 0.003 (0.003)
Epoch: [113][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0428 (0.0375)	Mses 0.011 (0.005)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0438 (0.0438)	Error0.012 (0.012)
 * Error 0.007
current lr 1.00000e-03
Epoch: [114][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0325 (0.0325)	Mses 0.000 (0.000)
Epoch: [114][5/6]	Time 0.041 (0.038)	Data 0.000 (0.000)	Loss 0.0403 (0.0392)	Mses 0.008 (0.007)
T

Epoch: [139][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0394 (0.0477)	Mses 0.006 (0.014)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0404 (0.0404)	Error0.007 (0.007)
 * Error 0.007
current lr 1.00000e-04
Epoch: [140][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0415 (0.0415)	Mses 0.008 (0.008)
Epoch: [140][5/6]	Time 0.041 (0.039)	Data 0.000 (0.000)	Loss 0.0462 (0.0396)	Mses 0.012 (0.006)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0479 (0.0479)	Error0.014 (0.014)
 * Error 0.018
current lr 1.00000e-04
Epoch: [141][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0347 (0.0347)	Mses 0.001 (0.001)
Epoch: [141][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0368 (0.0389)	Mses 0.003 (0.005)
Test: [0/2]	Time 0.013 (0.013)	Loss 0.0468 (0.0468)	Error0.013 (0.013)
 * Error 0.008
current lr 1.00000e-04
Epoch: [142][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0449 (0.0449)	Mses 0.011 (0.011)
Epoch: [142][5/6]	Time 0.039 (0.038)	Data 0.000 (0.000)	Loss 0.0496 (0.0422)	Mses 0.016 (0.009)
T

Epoch: [167][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0334 (0.0388)	Mses 0.000 (0.006)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0407 (0.0407)	Error0.008 (0.008)
 * Error 0.008
current lr 1.00000e-05
Epoch: [168][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0455 (0.0455)	Mses 0.012 (0.012)
Epoch: [168][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0381 (0.0392)	Mses 0.005 (0.006)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0415 (0.0415)	Error0.008 (0.008)
 * Error 0.013
current lr 1.00000e-05
Epoch: [169][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0352 (0.0352)	Mses 0.002 (0.002)
Epoch: [169][5/6]	Time 0.040 (0.038)	Data 0.000 (0.000)	Loss 0.0337 (0.0391)	Mses 0.001 (0.006)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0549 (0.0549)	Error0.022 (0.022)
 * Error 0.014
current lr 1.00000e-05
Epoch: [170][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0357 (0.0357)	Mses 0.003 (0.003)
Epoch: [170][5/6]	Time 0.041 (0.038)	Data 0.000 (0.000)	Loss 0.0342 (0.0471)	Mses 0.001 (0.014)
T

Epoch: [195][5/6]	Time 0.040 (0.039)	Data 0.000 (0.000)	Loss 0.0344 (0.0374)	Mses 0.001 (0.004)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0402 (0.0402)	Error0.007 (0.007)
 * Error 0.011
current lr 5.00000e-06
Epoch: [196][0/6]	Time 0.037 (0.037)	Data 0.000 (0.000)	Loss 0.0592 (0.0592)	Mses 0.026 (0.026)
Epoch: [196][5/6]	Time 0.041 (0.038)	Data 0.000 (0.000)	Loss 0.0403 (0.0439)	Mses 0.007 (0.011)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0418 (0.0418)	Error0.009 (0.009)
 * Error 0.007
current lr 5.00000e-06
Epoch: [197][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0345 (0.0345)	Mses 0.001 (0.001)
Epoch: [197][5/6]	Time 0.040 (0.039)	Data 0.000 (0.000)	Loss 0.0364 (0.0437)	Mses 0.003 (0.011)
Test: [0/2]	Time 0.014 (0.014)	Loss 0.0382 (0.0382)	Error0.005 (0.005)
 * Error 0.007
current lr 5.00000e-06
Epoch: [198][0/6]	Time 0.036 (0.036)	Data 0.000 (0.000)	Loss 0.0440 (0.0440)	Mses 0.011 (0.011)
Epoch: [198][5/6]	Time 0.043 (0.039)	Data 0.000 (0.000)	Loss 0.0356 (0.0387)	Mses 0.003 (0.006)
T

In [66]:
np.mean(weight_1)

-0.014919555

In [67]:
np.mean(weight_2)

3.1012974

In [11]:
np.full((1 ,2) ,0.5)

array([[0.5, 0.5]])

In [479]:
model.l1.mu_weight

Parameter containing:
tensor([[15.0195, 15.0738]], requires_grad=True)

In [542]:
l = np.array([[0.,5,6],[0.,5,6]])
var = np.array([[0.1,1,0.01],[0.01,0.01,0.3]])

In [543]:
torch.from_numpy(l[0])

tensor([0., 5., 6.], dtype=torch.float64)

In [497]:
torch.arange(1, 0, -0.1)

tensor([1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,
        0.1000])

In [503]:
np.array([[0.],[20]]).shape

(2, 1)

In [505]:
torch.Tensor(1,2).shape

torch.Size([1, 2])

In [553]:
test = torch.Tensor(2, 3).detach()
for row in range(l.shape[0]):
    test[row] = torch.normal(mean=torch.from_numpy(l[row]), std=0.1, generator=None, out=None)

In [546]:
test

tensor([[-0.0183,  5.0091,  5.9324],
        [-0.0260,  5.1976,  5.9561]])

In [481]:
torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1), generator=None, out=None)

tensor([ 3.1018,  2.5745,  3.1913,  3.9363,  4.8456,  5.6560,  6.9963,  8.2421,
         8.9690, 10.2445])

In [375]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print (name, param.data)

l1.mu_weight tensor([[0.6448, 0.5888]])
l1.rho_weight tensor([[-2.8683, -3.1510]])


In [302]:
x1 = np.random.normal(size = 1500)
x2 = np.random.normal(size = 1500)
X = np.stack((x1,x2),axis = -1)
y = x1 + x2  
#+ np.random.normal(0,0.01,size = 1500)

In [303]:
X_train = X[0:1000]
y_train = y[0:1000]
X_test = X[1000:1500]
y_test = y[1000:1500]

In [304]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, X, y, scale_data=False):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X).float()
            self.y = torch.from_numpy(y).float()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

train_dataset = Dataset(X_train, y_train)
test_dataset = Dataset(X_test,y_test)
#trainloader = torch.utils.data.DataLoader(train_dataset)


In [None]:
#sth still seems to be wrong with the error function or train method -> increased steadily
#while training while val loss decreased

In [306]:
checkpoint_file = args.save_dir + '/bayesian_{}.pth'.format(
            args.model_name)
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [307]:
ttest_dataset = Dataset(np.array([[0.2,0.1]]),np.array([[0.3]]))
ttest_loader = torch.utils.data.DataLoader(ttest_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

In [308]:
pred_probs_mc = []
test_loss = 0
correct = 0
output_list = []
labels_list = []
model.eval()
with torch.no_grad():
        begin = time.time()
        for data, target in ttest_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            else:
                data, target = data.cpu(), target.cpu()
            output_mc = []
            for mc_run in range(1000):
                output, _ = model.forward(data)
                output_mc.append(output)
            output_ = torch.stack(output_mc)
            output_list.append(output_)
        end = time.time()
        output = torch.stack(output_list)
        

In [26]:
class Bay_CPASNet(nn.Module):
	def __init__(self, In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, Pathway_Mask):
		super(Bay_CPASNet, self).__init__()
		self.tanh = nn.Tanh()
		self.pathway_mask = Pathway_Mask
		###gene layer --> pathway layer
		self.sc1 = nn.Linear(In_Nodes, Pathway_Nodes)
		###pathway layer --> hidden layer
		self.sc2 = nn.Linear(Pathway_Nodes, Hidden_Nodes)
		###hidden layer --> hidden layer 2
		self.sc3 = nn.Linear(Hidden_Nodes, Out_Nodes, bias=False)
		###hidden layer 2 + age --> Cox layer
		self.sc4 = nn.Linear(Out_Nodes+1, 1, bias = False)
		self.sc4.weight.data.uniform_(-0.001, 0.001)

	def forward(self, x_1, x_2):
		###force the connections between gene layer and pathway layer w.r.t. 'pathway_mask'
		self.sc1.weight.data = self.sc1.weight.data.mul(self.pathway_mask)
		x_1 = self.tanh(self.sc1(x_1))
		x_1 = self.tanh(self.sc2(x_1))
		x_1 = self.tanh(self.sc3(x_1))
		###combine age with hidden layer 2
		x_cat = torch.cat((x_1, x_2), 1)
		lin_pred = self.sc4(x_cat)

		return lin_pred

def trainBayCoxPASNet(train_x, train_age, train_ytime, train_yevent, \
			eval_x, eval_age, eval_ytime, eval_yevent, pathway_mask, \
			In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, \
			Learning_Rate, Num_Epochs,bnn_prior_params):

    net = Bay_CPASNet(In_Nodes, Pathway_Nodes, Hidden_Nodes, Out_Nodes, pathway_mask)
    dnn_to_bnn_bcoxpas(net, bnn_prior_params)
    
    ###
    ###optimizer
    opt = optim.Adam(net.parameters(), lr=Learning_Rate)

    for epoch in range(Num_Epochs+1):
        net.train()
        opt.zero_grad() ###reset gradients to zeros

        pred = net(train_x, train_age) ###Forward
        ce_loss = neg_par_log_likelihood(pred, train_ytime, train_yevent)
        kl = get_kl_loss(net)
        loss = ce_loss + kl
        loss.backward() ###calculate gradients
        opt.step()
        net.sc1.weight.data = net.sc1.weight.data.mul(net.pathway_mask) ###force the connections between gene layer and pathway layer

        if epoch % 20 == 0:
            with torch.no_grad():
                net.train()
                train_output_mc = []
                for mc_run in range(10):
                    output = net(train_x, train_age)
                    train_output_mc.append(output)
                    outputs = torch.stack(train_output_mc)
                train_pred = outputs.mean(dim=0)
                train_loss = neg_par_log_likelihood(train_pred, train_ytime, train_yevent).view(1,)

                eval_output_mc = []
                for mc_run in range(10):
                    output = net(eval_x, eval_age)
                    eval_output_mc.append(output)
                    eval_outputs = torch.stack(eval_output_mc)
                eval_pred = eval_outputs.mean(dim=0)
                eval_loss = neg_par_log_likelihood(eval_pred, eval_ytime, eval_yevent).view(1,)

                train_cindex = c_index(train_pred, train_ytime, train_yevent)
                eval_cindex = c_index(eval_pred, eval_ytime, eval_yevent)
                print(f"Epoch: {epoch}, Train Loss: {train_loss},Eval Loss: {eval_loss}, "
                      f" Train Cindex: {train_cindex}, Eval Cindex: {eval_cindex}")

    return (train_loss, eval_loss, train_cindex, eval_cindex)