In [1]:
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
import torch.optim as optim
import copy
from scipy.interpolate import interp1d
dtype = torch.FloatTensor  # legacy tensor-type handle (CPU float32); load_data casts every loaded array with .type(dtype)

In [2]:
def sort_data(path):
    """Read a survival CSV and split it into model inputs and per-patient covariates.

    Rows are ordered by descending ``OS.time`` before anything is extracted.

    Returns an 8-tuple of numpy arrays, in this order:
        x          -- every column except Patient_ID, the clinical covariates and the outcomes
        ytime      -- ``OS.time`` as an (n, 1) array
        yevent     -- ``OS`` as an (n, 1) array
        age, cstage, hgrade, race_black, race_white
                   -- ``age``, ``stageh``, ``gradeh``, ``race_black``, ``race_white``,
                      each as an (n, 1) array
    """
    frame = pd.read_csv(path).sort_values("OS.time", ascending=False)

    excluded = ["Patient_ID", "race_black", "race_white", "age", "stageh", "gradeh", "OS", "OS.time"]
    # The order of this list fixes the order of the single-column arrays in the result.
    single_cols = ["OS.time", "OS", "age", "stageh", "gradeh", "race_black", "race_white"]

    features = frame.drop(excluded, axis=1).to_numpy()
    singles = [frame[[col]].to_numpy() for col in single_cols]
    return (features, *singles)

def load_data(path, dtype):
    """Load a CSV through ``sort_data`` and turn each returned array into a torch tensor.

    Every array is cast with ``.type(dtype)``; when CUDA is available the tensors are
    additionally moved with ``.cuda()``. The tuple order matches ``sort_data``:
    (X, YTIME, YEVENT, AGE, CSTAGE, HGRADE, RACE_BLACK, RACE_WHITE).
    """
    use_cuda = torch.cuda.is_available()
    converted = []
    for array in sort_data(path):
        tensor = torch.from_numpy(array).type(dtype)
        converted.append(tensor.cuda() if use_cuda else tensor)
    return tuple(converted)

In [3]:
class EarlyStopping:
    """Flag when a monitored validation loss stops improving, checkpointing the best model.

    Each call turns ``val_loss`` into ``score = -val_loss``. A call whose score is below
    ``best_score + delta`` counts as non-improving and increments ``counter``; once
    ``counter`` reaches ``patience``, ``early_stop`` is set. Any other call records the new
    best score, writes ``model.state_dict()`` to ``path`` and resets ``counter``.

    Args:
        patience: consecutive non-improving calls tolerated before ``early_stop`` is set.
        verbose: if True, print a line each time a checkpoint is written.
        delta: margin added to the best score when deciding whether a call improved.
        path: file the best ``state_dict`` is saved to (default ``'checkpoint.pt'``).
        report_every: print the counter every this many non-improving calls
            (default 200); a falsy value (0 / None) disables the message.
    """

    def __init__(self, patience, verbose=False, delta=0, path='checkpoint.pt', report_every=200):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.report_every = report_every

    def __call__(self, val_loss, model):
        # Store plain Python floats: a loss tensor would keep its autograd graph alive in
        # best_score / val_loss_min, and a shape-(1,) tensor cannot be formatted with ':.6f'.
        val_loss = float(val_loss.detach()) if torch.is_tensor(val_loss) else float(val_loss)
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.report_every and self.counter % self.report_every == 0:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss`` as the new minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [4]:
def reconstruction_loss(x, x_recon):
    """Summed squared reconstruction error divided by the batch size ``x.size(0)``.

    An empty batch trips the assertion instead of dividing by zero.
    """
    n = x.size(0)
    assert n != 0

    squared_error_sum = F.mse_loss(x_recon, x, reduction='sum')
    return squared_error_sum / n

def kl_divergence(mu, logvar):
    """KL divergence of the diagonal Gaussian N(mu, exp(logvar)) from N(0, I).

    For (batch, latent) inputs, returns a tuple:
        total_kld          -- KL summed over latent dims, averaged over the batch; shape (1,)
        dimension_wise_kld -- KL per latent dim, averaged over the batch; shape (latent,)
        mean_kld           -- KL averaged over latent dims, then over the batch; shape (1,)
    """
    n = mu.size(0)
    assert n != 0

    # Closed-form per-element KL of N(mu, sigma^2) against a standard normal.
    elementwise = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())

    per_dimension = elementwise.mean(0)
    summed_over_dims = elementwise.sum(1).mean(0, True)
    averaged_over_dims = elementwise.mean(1).mean(0, True)
    return summed_over_dims, per_dimension, averaged_over_dims

In [5]:
def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.

    z = mu + std * eps, where std = exp(logvar / 2) and eps ~ N(0, I) is drawn with the
    same shape, dtype and device as ``std``. ``eps`` does not require grad, so gradients
    reach ``mu`` and ``logvar`` only through the deterministic part of the expression.
    """
    std = logvar.div(2).exp()
    # torch.randn_like replaces the deprecated Variable(std.data.new(...).normal_()) idiom.
    eps = torch.randn_like(std)
    return mu + std * eps

In [6]:
def kaiming_init(m):
    """Re-initialise an ``nn.Linear`` in place: Kaiming-normal weight, zero bias.

    Modules of any other type are left untouched, so it is safe to call on every
    submodule of a network.
    """
    if not isinstance(m, nn.Linear):
        return
    init.kaiming_normal_(m.weight)
    if m.bias is not None:
        init.zeros_(m.bias)

In [7]:
class BetaVAE_B(nn.Module):
    """Model proposed in "Understanding disentangling in beta-VAE" (Burgess et al., arXiv:1804.03599, 2018), modified to fit our data.

    Fully connected VAE. The encoder maps ``input_n`` features to ``2 * z_dim`` outputs:
    the first ``z_dim`` columns are read as ``mu``, the rest as ``logvar``. The decoder
    maps a ``z_dim`` sample back to ``input_n`` features.

    Args:
        z_dim: latent dimensionality.
        input_n: number of input features.
        hidden_dims: encoder hidden-layer widths, input side first; the decoder uses them
            in reverse. The default (200, 50) reproduces the original architecture, with
            the same ``nn.Sequential`` indices and therefore the same ``state_dict`` keys.
    """

    def __init__(self, z_dim, input_n, hidden_dims=(200, 50)):
        super(BetaVAE_B, self).__init__()
        self.z_dim = z_dim
        self.nc = input_n
        hidden_dims = list(hidden_dims)
        self.encoder = self._mlp([input_n] + hidden_dims, z_dim*2)
        self.decoder = self._mlp([z_dim] + hidden_dims[::-1], input_n)

        self.weight_init()

    @staticmethod
    def _mlp(widths, out_dim):
        """Linear+ReLU for each consecutive pair in ``widths``, then a final Linear to ``out_dim``."""
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(True)]
        layers.append(nn.Linear(widths[-1], out_dim))
        return nn.Sequential(*layers)

    def weight_init(self):
        """Call ``kaiming_init`` on every submodule (it only re-initialises ``nn.Linear`` layers)."""
        # Public self.modules() instead of walking the private self._modules dict by hand;
        # the Linear layers are still visited in construction order.
        for m in self.modules():
            kaiming_init(m)

    def forward(self, x):
        """Encode ``x``, draw ``z`` with ``reparametrize`` and decode it.

        Returns:
            (x_recon, mu, logvar)
        """
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)

In [8]:
def _kl_capacity(epoch, C_max, C_stop_iter):
    """KL capacity C at ``epoch``: ramps linearly from 0 to C_max over C_stop_iter epochs, then stays at C_max.

    Same formula as torch.clamp(C_max / C_stop_iter * epoch, 0, C_max), but it returns a plain
    float, so subtracting it from a loss works on any device. ``C_max`` may be a number or a
    one-element tensor.
    """
    c_max = float(C_max)
    if C_stop_iter <= 0:
        # No ramp requested; the tensor formula divided by zero here and produced NaN at epoch 0.
        return c_max
    return min(max(c_max / C_stop_iter * epoch, 0.0), c_max)


def _beta_vae_objective(net, x, gamma, C):
    """Forward ``x`` through ``net``; return (recon_loss + gamma * |total_kld - C|, mu, logvar)."""
    x_recon, mu, logvar = net(x)
    recon_loss = reconstruction_loss(x, x_recon)
    total_kld, _, _ = kl_divergence(mu, logvar)
    return recon_loss + gamma*(total_kld - C).abs(), mu, logvar


def trainBetaVAE_B(train_x, eval_x, z_dim, input_n, Learning_Rate, L2, Num_Epochs, patience, gamma, C_max, C_stop_iter, log_every=1000):
    """Train a BetaVAE_B with full-batch Adam on the capacity-controlled beta-VAE objective.

    Each epoch takes one optimisation step on all of ``train_x`` with
    ``loss = recon_loss + gamma * |total_kld - C|``. ``C`` ramps linearly from 0 to ``C_max``
    over ``C_stop_iter`` epochs. The epoch then evaluates the same objective on ``eval_x``.
    The network samples z in its forward pass, so the reported losses are stochastic.

    Args:
        train_x, eval_x: input tensors for training and for held-out evaluation.
        z_dim, input_n: forwarded to ``BetaVAE_B``.
        Learning_Rate, L2: Adam learning rate and weight decay.
        Num_Epochs: index of the last epoch; epochs 0..Num_Epochs are run.
        patience: kept for call compatibility; unused while early stopping is disabled.
        gamma: weight of the |KL - C| term.
        C_max: final KL capacity (number or one-element tensor).
        C_stop_iter: number of epochs over which C ramps up.
        log_every: print the training loss every this many epochs (positive int, default 1000).

    Returns:
        (tr_loss, val_loss, tr_mu, tr_logvar, val_mu, val_logvar), all computed with the
        weights after the final epoch. None of the returned tensors carries an autograd graph.

    Raises:
        ValueError: if ``Num_Epochs`` is negative (no epoch would run).
    """
    if Num_Epochs < 0:
        raise ValueError(f"Num_Epochs must be >= 0, got {Num_Epochs}")

    net = BetaVAE_B(z_dim, input_n)
    if torch.cuda.is_available():
        net.cuda()
    opt = optim.Adam(net.parameters(), lr=Learning_Rate, weight_decay=L2)

    for epoch in range(Num_Epochs + 1):
        C = _kl_capacity(epoch, C_max, C_stop_iter)

        net.train()
        opt.zero_grad()
        beta_vae_loss, _, _ = _beta_vae_objective(net, train_x, gamma, C)
        beta_vae_loss.backward()
        opt.step()

        # Monitoring passes need no gradients; no_grad avoids building graphs that are thrown away.
        net.eval()
        with torch.no_grad():
            val_loss, val_mu, val_logvar = _beta_vae_objective(net, eval_x, gamma, C)

        # Also evaluate on the last epoch, so the returned training statistics always come from
        # the final weights even when Num_Epochs is not a multiple of log_every.
        if epoch % log_every == 0 or epoch == Num_Epochs:
            net.train()
            with torch.no_grad():
                tr_loss, tr_mu, tr_logvar = _beta_vae_objective(net, train_x, gamma, C)
            print("Number of epochs: ", epoch, ", Loss in Train: ", tr_loss)

    return (tr_loss, val_loss, tr_mu, tr_logvar, val_mu, val_logvar)

In [30]:
# Fixed seed so Restart & Run All reproduces weight initialisation and the sampled latents.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

z_dim = 10
input_n = 929
Initial_Learning_Rate = [0.01]
L2_Lambda = [0.1, 0.01, 0.005]
patience = 1000  # forwarded to trainBetaVAE_B, whose early stopping is currently disabled
gamma = 500
C_max = torch.tensor([25.])
C_stop_iter = 5000
num_epochs = 6000   # epochs for each grid-search run
Num_EPOCHS = 7000   # epochs for the final run with the selected hyper-parameters
DATA_DIR = "D:/DL/Variational autoencoder/Tryout_12_30_2020/divided_data/exp_20"

x_train, ytime_train, yevent_train, age_train, cstage_train, hgrade_train, race_black_train, race_white_train = load_data(f"{DATA_DIR}/data_tr_20.csv", dtype)
x_valid, ytime_valid, yevent_valid, age_valid, cstage_valid, hgrade_valid, race_black_valid, race_white_valid = load_data(f"{DATA_DIR}/data_val_20.csv", dtype)
x_test, ytime_test, yevent_test, age_test, cstage_test, hgrade_test, race_black_test, race_white_test = load_data(f"{DATA_DIR}/data_tes_20.csv", dtype)

opt_l2_loss = 0
opt_lr_loss = 0
opt_loss = float("inf")
for l2 in L2_Lambda:
    for lr in Initial_Learning_Rate:
        loss_train, loss_valid, tr_mu, tr_logvar, val_mu, val_logvar = trainBetaVAE_B(x_train, x_valid, z_dim, input_n, lr, l2, num_epochs, patience, gamma, C_max, C_stop_iter)
        # Compare as a Python float: device-independent and keeps no autograd graph alive.
        valid_value = loss_valid.item()
        if valid_value < opt_loss:
            opt_l2_loss = l2
            opt_lr_loss = lr
            opt_loss = valid_value
        print("L2: ", l2, ", LR: ", lr, ", Loss in Validation: ", loss_valid)

# NaN losses never win the comparison; without this check the final run would silently use lr=0.
if not math.isfinite(opt_loss):
    raise RuntimeError("No hyper-parameter setting produced a finite validation loss")

loss_train, loss_test, tr_mu, tr_logvar, tes_mu, tes_logvar = trainBetaVAE_B(x_train, x_test, z_dim, input_n, opt_lr_loss, opt_l2_loss, Num_EPOCHS, patience, gamma, C_max, C_stop_iter)
print("Optimal L2: ", opt_l2_loss, ", Optimal LR: ", opt_lr_loss)

Number of epochs:  0 , Loss in Train:  tensor([6292.1270], grad_fn=<AddBackward0>)
Number of epochs:  1000 , Loss in Train:  tensor([951.2902], grad_fn=<AddBackward0>)
Number of epochs:  2000 , Loss in Train:  tensor([935.1714], grad_fn=<AddBackward0>)
Number of epochs:  3000 , Loss in Train:  tensor([943.6773], grad_fn=<AddBackward0>)
Number of epochs:  4000 , Loss in Train:  tensor([921.4156], grad_fn=<AddBackward0>)
Number of epochs:  5000 , Loss in Train:  tensor([932.7629], grad_fn=<AddBackward0>)
Number of epochs:  6000 , Loss in Train:  tensor([930.8020], grad_fn=<AddBackward0>)
L2:  0.1 , LR:  0.01 , Loss in Validation:  tensor([1006.1946], grad_fn=<AddBackward0>)
Number of epochs:  0 , Loss in Train:  tensor([10005.4824], grad_fn=<AddBackward0>)
Number of epochs:  1000 , Loss in Train:  tensor([936.8536], grad_fn=<AddBackward0>)
Number of epochs:  2000 , Loss in Train:  tensor([924.2093], grad_fn=<AddBackward0>)
Number of epochs:  3000 , Loss in Train:  tensor([935.7556], grad

In [31]:
# Draw one latent code per patient from each fitted posterior (reparametrize samples fresh noise).
tr_z = reparametrize(tr_mu, tr_logvar)
tes_z = reparametrize(tes_mu, tes_logvar)

print(tr_z.size())

# Column names follow the torch.cat order below: Z_1..Z_k, then outcome and clinical columns.
latent_cols = [f"Z_{i}" for i in range(1, tr_z.size(1) + 1)]
clinical_cols = ['OS.time', 'OS.event', 'age', 'stageh', 'gradeh', 'race_black', 'race_white']

processed_tr_pre = torch.cat((tr_z, ytime_train, yevent_train, age_train, cstage_train, hgrade_train, race_black_train, race_white_train), 1)
processed_tes_pre = torch.cat((tes_z, ytime_test, yevent_test, age_test, cstage_test, hgrade_test, race_black_test, race_white_test), 1)

# Convert explicitly to a CPU numpy array. Recent pandas may build the frame via
# np.asarray(tensor), which fails for tensors that require grad or live on the GPU.
processed_tr = pd.DataFrame(processed_tr_pre.detach().cpu().numpy(), columns=latent_cols + clinical_cols).astype(float)
processed_tes = pd.DataFrame(processed_tes_pre.detach().cpu().numpy(), columns=latent_cols + clinical_cols).astype(float)

torch.Size([270, 10])


In [32]:
# Bare last expression: Jupyter renders the DataFrame with its rich (HTML) repr instead of plain text.
processed_tes

         Z_1       Z_2       Z_3       Z_4       Z_5       Z_6       Z_7  \
0   1.316117  1.813122 -1.568688 -1.869874 -1.236799 -0.998041 -2.408681   
1   1.362314  1.526981 -1.432127 -1.903233  3.350794 -2.169665 -1.809779   
2   2.998492  1.540541 -1.896181 -2.402537  4.602023 -1.782509 -1.474900   
3   2.163852  1.411169 -2.118542 -2.560777 -1.219030 -2.511004 -2.890308   
4   1.218537  1.952838 -1.455242 -1.277790  1.768396 -2.052961 -2.911281   
..       ...       ...       ...       ...       ...       ...       ...   
73  1.867604  1.369310 -1.977803 -2.306443  2.900714 -1.236399 -2.946648   
74  1.780923  1.613906 -1.808143 -1.868042 -0.599716 -1.951150 -2.428445   
75  1.726435  1.605304 -1.928070 -2.128355  4.077435 -1.754271 -2.928323   
76  2.099422  2.042235 -1.984882 -1.611757  2.922384 -2.261796 -2.214767   
77  1.227224  1.895802 -1.668252 -2.126025  0.620667 -1.213066 -2.171818   

         Z_8       Z_9      Z_10  OS.time  OS.event   age  stageh  gradeh  \
0   0.9338

In [12]:
import lifelines
from lifelines import CoxPHFitter

In [34]:
# Penalised Cox proportional-hazards model (pure L1: l1_ratio=1) on the latent + clinical features.
# NOTE(review): this fits on processed_tes (the test split), so the summary's concordance is an
# in-sample figure for that split; confirm this is the intended analysis.
cph = CoxPHFitter(penalizer=0.0001, l1_ratio=1.0).fit(
    processed_tes,
    duration_col='OS.time',
    event_col='OS.event',
)
cph.print_summary()

0,1
model,lifelines.CoxPHFitter
duration col,'OS.time'
event col,'OS.event'
penalizer,0.0001
l1 ratio,1
baseline estimation,breslow
number of observations,78
number of events observed,47
partial log-likelihood,-153.52
time fit was run,2021-01-07 21:21:40 UTC

Unnamed: 0,coef,exp(coef),se(coef),coef lower 95%,coef upper 95%,exp(coef) lower 95%,exp(coef) upper 95%,z,p,-log2(p)
Z_1,0.25,1.29,0.37,-0.48,0.98,0.62,2.66,0.68,0.5,1.01
Z_2,0.63,1.87,0.6,-0.55,1.81,0.58,6.1,1.04,0.3,1.75
Z_3,-0.55,0.58,0.49,-1.52,0.42,0.22,1.52,-1.11,0.27,1.91
Z_4,0.66,1.93,0.56,-0.43,1.75,0.65,5.76,1.19,0.24,2.08
Z_5,0.22,1.25,0.09,0.04,0.41,1.04,1.5,2.35,0.02,5.72
Z_6,0.71,2.04,0.49,-0.24,1.67,0.78,5.29,1.46,0.14,2.79
Z_7,-0.09,0.91,0.39,-0.86,0.67,0.42,1.96,-0.24,0.81,0.3
Z_8,0.37,1.44,0.48,-0.58,1.32,0.56,3.73,0.76,0.45,1.16
Z_9,0.42,1.52,0.51,-0.57,1.41,0.56,4.08,0.82,0.41,1.28
Z_10,-0.52,0.59,0.41,-1.33,0.28,0.26,1.33,-1.27,0.2,2.29

0,1
Concordance,0.67
Partial AIC,337.05
log-likelihood ratio test,23.70 on 15 df
-log2(p) of ll-ratio test,3.83
