In [24]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.metrics import brier_score
from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
)

from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from sklearn.metrics import mean_squared_error

import pycox
import torch
from torch import nn
import torchtuples as tt

from pycox.datasets import metabric
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

from laplace import Laplace

from src.data_prep.torch_datasets import Dataset
from src.models.train_torch_net import train_torch_net
from configs.model_configs import BDNN_Zhang


In [5]:
# Load the pseudo-observation survival data. Judging by the rendered output, each
# subject `id` appears on several rows, one per evaluation time `tpseudo`.
# The data directory defaults to the original author's checkout; set the
# SURV_DATA_DIR environment variable to run this notebook elsewhere instead of
# editing the path.
# NOTE(review): "Unnamed: 0" looks like the index saved by to_csv; it is not used
# below, so it is left in place (index_col=0 would drop it).
DATA_DIR = Path(
    os.environ.get(
        "SURV_DATA_DIR",
        "/Users/alexandermollers/Documents/GitHub/survival_analysis/data",
    )
)
survival_pseudo_data = pd.read_csv(DATA_DIR / "surv_pseudo_data.csv")

# Fail early if the columns the model cells rely on are missing.
assert {"X1", "X2", "X3", "tpseudo", "pseudo"} <= set(survival_pseudo_data.columns), \
    "surv_pseudo_data.csv is missing expected columns"

survival_pseudo_data

Unnamed: 0,X1,X2,X3,y,failed,pseudo,tpseudo,id
0,0.446645,0.017643,0.298416,21,True,1.000724,2.0,1
1,0.446645,0.017643,0.298416,21,True,1.001876,3.0,1
2,0.446645,0.017643,0.298416,21,True,1.004832,5.0,1
3,0.446645,0.017643,0.298416,21,True,1.010563,11.0,1
4,0.446645,0.017643,0.298416,21,True,-0.055363,27.0,1
...,...,...,...,...,...,...,...,...
8995,-0.241216,0.495190,0.116989,174,True,1.019840,27.0,1000
8996,-0.241216,0.495190,0.116989,174,True,1.033707,88.0,1000
8997,-0.241216,0.495190,0.116989,174,True,-0.082001,185.0,1000
8998,-0.241216,0.495190,0.116989,174,True,-0.056356,256.0,1000


In [7]:
# Model inputs are the three covariates plus the evaluation time `tpseudo`; the
# regression target is the `pseudo` column (kept 2-D, shape (n_rows, 1)).
# Every row goes into `dataset`, so all metrics computed on it later are in-sample.
FEATURE_COLUMNS = ["X1", "X2", "X3", "tpseudo"]
TARGET_COLUMNS = ["pseudo"]
BATCH_SIZE = 10
SEED = 0

# Seed the global RNGs so the run is reproducible: this DataLoader has no explicit
# generator, so its shuffling draws from torch's global RNG, as (presumably) does the
# network initialisation in the training cell — TODO confirm for BDNN_Zhang.
torch.manual_seed(SEED)
np.random.seed(SEED)

X = survival_pseudo_data[FEATURE_COLUMNS].values
y = survival_pseudo_data[TARGET_COLUMNS].values

dataset = Dataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1
)


In [35]:
N_EPOCHS = 50

# Fit a fresh BDNN_Zhang network to the `pseudo` targets with a mean-squared-error
# loss. No optimizer is created in this cell; presumably train_torch_net builds its
# own — TODO confirm.
# NOTE(review): training is given `dataset` directly, not `trainloader`; the
# BATCH_SIZE/shuffle settings of `trainloader` only affect the Laplace fit below.
nn_zhang = BDNN_Zhang()
loss_function = nn.MSELoss()
trained_nn, loss = train_torch_net(
    dataset,
    neural_net=nn_zhang,
    loss_function=loss_function,
    n_epochs=N_EPOCHS,
)

In [20]:
N_HYPER_STEPS = 50
HYPER_LR = 1e-1

# Post-hoc Laplace approximation around the trained network, using a full
# Hessian over all weights, fitted on the batches from `trainloader`.
la_nn = Laplace(trained_nn, 'regression', subset_of_weights='all', hessian_structure='full')
la_nn.fit(trainloader)

# Tune the prior precision and the observation-noise scale by maximising the
# Laplace log marginal likelihood. Both are optimised in log-space so that the
# exponentiated values passed to the library are always positive.
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=HYPER_LR)
for _ in range(N_HYPER_STEPS):
    hyper_optimizer.zero_grad()
    neg_marglik = -la_nn.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()

# log_marginal_likelihood overwrites la_nn's hyperparameters with the values it
# was called with, i.e. the ones from *before* the last optimizer step (and still
# attached to the autograd graph). Store the final optimised values, detached.
la_nn.prior_precision = log_prior.detach().exp()
la_nn.sigma_noise = log_sigma.detach().exp()



In [22]:
# Laplace predictive on the full training feature matrix: the call returns the
# predictive mean and variance of the network output (dataset.X presumably a
# float tensor — TODO confirm against Dataset).
f_mu, f_var = la_nn(dataset.X)
f_mu = f_mu.squeeze().detach().cpu().numpy()
# Detach the variance too, matching f_mu: .numpy() raises on a tensor that still
# requires grad, which depends on the installed laplace version.
f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy()
# Total predictive std: the functional (weight-uncertainty) std combined with the
# fitted observation-noise scale.
pred_std = np.sqrt(f_sigma ** 2 + la_nn.sigma_noise.item() ** 2)

In [29]:
# Plain (non-Bayesian) predictions of the trained network, for comparison with the
# Laplace predictive mean f_mu. Inference only, so no autograd graph is built.
# NOTE(review): if BDNN_Zhang contains dropout/batch-norm, trained_nn.eval() should
# be called first — TODO confirm whether train_torch_net leaves it in eval mode.
with torch.no_grad():
    nn_pred = trained_nn(dataset.X)
nn_pred = nn_pred.squeeze().cpu().numpy()

In [25]:
# In-sample MSE of the Laplace predictive mean against the `pseudo` targets.
mean_squared_error(y_true=y, y_pred=f_mu)

0.20985746216890544

In [30]:
# In-sample MSE of the plain network's predictions against the `pseudo` targets.
mean_squared_error(y_true=y, y_pred=nn_pred)

0.20985746216890544

9000